# Source code for example.data_augmentation_mp

"""データ拡張をマルチプロセスで行うコード例.

Summary:
    このエグザンプルコードでは、時間のかかるデータ拡張においてマルチプロセス処理を行います。

Example:
    .. code-block:: bash

        python3 example/data_augmentation_mp.py "/path/to/dir" 224 10 -0.1 0.1 0.01

Source code:
    `data_augmentation_mp.py <https://github.com/yKesamaru/FACE01_DEV/blob/master/example/data_augmentation_mp.py>`_
"""
import concurrent.futures
# Operate directory: Common to all examples
import os.path
import sys
from concurrent.futures import ProcessPoolExecutor
from glob import glob

# Take the directory containing this file, then append that directory's parent
# to sys.path -- presumably so the `face01lib` imports that follow resolve
# when the example is run as a script (confirm against the repository layout).
# NOTE: the module-level name `dir` shadows the builtin `dir()`.
dir: str = os.path.dirname(__file__)
parent_dir, _ = os.path.split(dir)
sys.path.append(parent_dir)

from typing import Dict, Optional

from face01lib.Initialize import Initialize
from face01lib.logger import Logger
from face01lib.utils import Utils

# Build the configuration mapping by calling Initialize('DEFAULT', 'info').initialize().
# NOTE(review): presumably 'DEFAULT' is the config section and 'info' the log
# level -- confirm against face01lib.Initialize.
CONFIG: Dict = Initialize('DEFAULT', 'info').initialize()
# Create the module-level logger from the configured log level, this file's
# path, and the configured 'RootDir' value.
logger = Logger(CONFIG['log_level']).logger(__file__, CONFIG['RootDir'])
"""Initialize and Setup logger.
"""

# Module-level Utils instance; `_process_data` calls its `distort_barrel` and
# `get_jitter_image` methods.
utils = Utils(CONFIG['log_level'])


def _process_data(
    dir: str,
    size: int,
    initial_value: float,
    closing_value: float,
    step_value: float,
    num_jitters: int
):
    """Apply barrel distortion and jitter processing to one directory.

    The working directory is switched to ``dir`` while
    ``utils.distort_barrel`` and then ``utils.get_jitter_image`` run, and is
    restored afterwards.

    Args:
        dir (str): Directory path to process. May be relative; it is resolved
            to an absolute path before the working directory is changed.
        size (int): Image size, forwarded to both ``utils`` calls.
        initial_value (float): Initial value, forwarded to
            ``utils.distort_barrel``.
        closing_value (float): Closing value, forwarded to
            ``utils.distort_barrel``.
        step_value (float): Step value, forwarded to ``utils.distort_barrel``.
        num_jitters (int): Number of jitters, forwarded to
            ``utils.get_jitter_image``.

    Raises:
        OSError: If ``dir`` cannot be entered (for example, it does not exist).
    """
    # Resolve the path *before* chdir: a relative path would otherwise be
    # re-interpreted against the new working directory by the calls below.
    target_dir = os.path.abspath(dir)
    previous_cwd = os.getcwd()
    os.chdir(target_dir)
    try:
        utils.distort_barrel(
            target_dir,
            size=size,
            initial_value=initial_value,
            closing_value=closing_value,
            step_value=step_value,
        )
        utils.get_jitter_image(
            target_dir,
            num_jitters=num_jitters,
            size=size,
            disturb_color=True,
        )
    finally:
        # The calling process may run this function again for another
        # directory; put the cwd back so later relative paths still resolve
        # against the original location.
        os.chdir(previous_cwd)


def main(
    dir_path: str,
    size: int = 224,
    num_jitters: int = 10,
    initial_value: float = -0.1,
    closing_value: float = 0.1,
    step_value: float = 0.01,
    max_workers: Optional[int] = None,
):
    """Run data augmentation in parallel for the sub-directories of ``dir_path``.

    Every directory found below ``dir_path`` (searched recursively) is handed
    to ``_process_data`` in a process pool, except directories that directly
    contain an entry whose name includes ``'jitter'``; those are skipped.

    Args:
        dir_path (str): Root data directory. May be relative; it is converted
            to an absolute path before any work is dispatched.
        size (int, optional): Image size. Defaults to 224.
        num_jitters (int, optional): Number of jitter passes. Defaults to 10.
        initial_value (float, optional): Initial value forwarded to
            ``_process_data``. Defaults to -0.1.
        closing_value (float, optional): Closing value forwarded to
            ``_process_data``. Defaults to 0.1.
        step_value (float, optional): Step value forwarded to
            ``_process_data``. Defaults to 0.01.
        max_workers (Optional[int], optional): ``max_workers`` passed to
            ``ProcessPoolExecutor``. Defaults to None.

    Raises:
        Exception: Any exception raised while processing a directory is
            re-raised here by ``future.result()``.
    """
    # Workers change their working directory, so give them absolute paths;
    # a relative path would stop resolving once a worker has moved.
    data_dir = os.path.abspath(dir_path)

    directory_list = []
    # Search recursively under the data directory.
    for root, dirs, _ in os.walk(data_dir):
        for directory in dirs:
            directory_path = os.path.join(root, directory)
            entries = glob(os.path.join(directory_path, "*"))
            # Look at the entry names only: matching against the full path
            # would skip everything whenever an ancestor directory happens to
            # contain 'jitter' in its name.
            if any('jitter' in os.path.basename(entry) for entry in entries):
                continue
            directory_list.append(directory_path)

    # Nothing to do: avoid spawning a process pool for zero tasks.
    if not directory_list:
        return

    # Execute data processing in parallel, one task per directory.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                _process_data,
                target_dir,
                size,
                initial_value,
                closing_value,
                step_value,
                num_jitters,
            )
            for target_dir in directory_list
        ]
        # Wait until all processing is complete; result() surfaces any
        # exception raised inside a worker.
        for future in concurrent.futures.as_completed(futures):
            future.result()
if __name__ == '__main__':
    # Command-line entry point. Positional arguments, in order:
    # dir_path, size, num_jitters, initial_value, closing_value, step_value.
    args: list = sys.argv
    # Fail with a usage message instead of a bare IndexError when arguments
    # are missing. Extra trailing arguments are ignored, as before.
    if len(args) < 7:
        sys.exit(
            f"Usage: python3 {args[0]} DIR_PATH SIZE NUM_JITTERS "
            "INITIAL_VALUE CLOSING_VALUE STEP_VALUE"
        )
    main(
        args[1],
        int(args[2]),
        int(args[3]),
        float(args[4]),
        float(args[5]),
        float(args[6]),
        max_workers=4
    )