物理の駅 Physics station by 現役研究者

テクノロジーは共有されてこそ栄える

Pythonで引数付きのマルチスレッドとマルチプロセスを簡単に実装する

concurrent.futures の ThreadPoolExecutor と ProcessPoolExecutor を使うのがセオリーだろう。スレッド内・プロセス内で例外が出たら print する機能と、終了したスレッドやプロセスを順番に回収して残りのスレッド数・プロセス数を表示する機能を実装してある。使い方は __main__ を参考にしてほしい。

import time
import random
import concurrent.futures
from logging import StreamHandler, Formatter, INFO, getLogger


def test_func(i):
    import random
    time.sleep(3 + random.random() * 3)
    if i % 5 == 4:
        raise Exception("Error: i % 5 == 4")


def init_logger():
    """Give the root logger an INFO-level stderr handler, unless it has one.

    When the root logger already has any handler, it is left untouched,
    so calling this repeatedly does not produce duplicated log lines.
    """
    root = getLogger()
    if root.hasHandlers():
        # Someone (possibly an earlier call) already configured logging.
        return

    console = StreamHandler()
    console.setFormatter(Formatter("[%(asctime)s] %(message)s"))
    console.setLevel(INFO)

    root.addHandler(console)
    root.setLevel(INFO)


def _run_in_pool(executor_cls, label, func, max_workers, targets):
    """Run ``func`` once per target on a fresh ``executor_cls`` pool and wait.

    Shared implementation of ``multi_process`` and ``multi_thread``.

    Args:
        executor_cls: A ``concurrent.futures`` executor class accepting
            ``max_workers=`` (``ProcessPoolExecutor`` or ``ThreadPoolExecutor``).
        label: Prefix used in the progress log messages.
        func: Callable to run for each target. With ``ProcessPoolExecutor``
            it, its arguments and its results must be picklable.
        max_workers: Passed straight to ``executor_cls``.
        targets: Iterable of call arguments. A plain ``tuple`` is unpacked
            into positional arguments; any other value is passed as the single
            argument.

    Returns:
        A list with one entry per target, in the order of ``targets``. Each
        entry is the value ``func`` returned, or ``None`` if that call raised.

    Exceptions raised by ``func`` are not propagated. Their message is printed
    to stdout and the remaining tasks keep running.
    """
    init_logger()
    logger = getLogger()
    logger.info("%s begin", label)

    with executor_cls(max_workers=max_workers) as executor:
        # Exact type check on purpose: tuple subclasses (e.g. namedtuples)
        # are treated as a single argument, matching the original behavior.
        future_to_index = {
            executor.submit(
                func, *(target if type(target) is tuple else (target,))): index
            for index, target in enumerate(targets)
        }
        results = [None] * len(future_to_index)
        remaining = len(future_to_index)

        # as_completed blocks until the next task finishes, instead of
        # busy-polling every future with a fixed sleep.
        for future in concurrent.futures.as_completed(future_to_index):
            remaining -= 1
            exc = future.exception()
            if exc is not None:
                print(exc)
            else:
                results[future_to_index[future]] = future.result()
            logger.info("%s %d tasks rest", label, remaining)

    logger.info("%s end", label)
    return results


def multi_process(func, max_workers, targets):
    """Run ``func`` over ``targets`` in parallel worker processes.

    ``func`` must be defined at module level so the worker processes can
    pickle it. Task exceptions are printed rather than raised.

    Returns:
        Per-target results in ``targets`` order (``None`` for tasks that
        raised). Earlier versions returned ``None``; callers that ignore the
        return value are unaffected.
    """
    return _run_in_pool(concurrent.futures.ProcessPoolExecutor,
                        "multi_process", func, max_workers, targets)


def multi_thread(func, max_workers, targets):
    """Run ``func`` over ``targets`` on a pool of worker threads.

    Task exceptions are printed rather than raised.

    Returns:
        Per-target results in ``targets`` order (``None`` for tasks that
        raised). Earlier versions returned ``None``; callers that ignore the
        return value are unaffected.
    """
    return _run_in_pool(concurrent.futures.ThreadPoolExecutor,
                        "multi_thread", func, max_workers, targets)


if __name__ == "__main__":
    # Demo: run the same ten tasks first on processes, then on threads.
    for runner in (multi_process, multi_thread):
        runner(test_func, max_workers=4, targets=range(10))