物理の駅 Physics station by 現役研究者

テクノロジーは共有されてこそ栄える

C++でstd::asyncを使ったマルチスレッド処理

#include <future>
#include <thread>
#include <vector>
#include <iostream>
#include <mutex>

// 標準出力のmutex
std::mutex mtx_;

int long_calc(int i) {

    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    mtx_.lock();
    std::cout << "a" << i << std::endl;
    mtx_.unlock();

    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    mtx_.lock();
    std::cout << "b" << i << std::endl;
    mtx_.unlock();

    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    mtx_.lock();
    std::cout << "c" << i << std::endl;
    mtx_.unlock();
    if (i % 10 == 5) {
        // たまに例外を飛ばす
        throw std::exception("exception in long_func");
    }
    return 0;
}
class MyClass {
    // 最大スレッド数
    const int max_num_threads = 5;

    std::vector<std::future<int>> v_status;

    // max_num_threads未満になるまで待つ関数
    void wait_threads() {
        if (v_status.size() < max_num_threads)return;
        else {
            while (true) {
                for (int i = 0; i < v_status.size(); i++) {
                    auto status = v_status[i].wait_for(std::chrono::milliseconds(100));
                    if (status != std::future_status::timeout) {
                        try {
                            // timeout以外はスレッドが終わったことを意味するのでgetして戻り値と例外を取得
                            int j = v_status[i].get();
                        }
                        catch (std::exception & ex) {
                            std::cout << ex.what() << std::endl;
                        }
                        catch (...) {
                            std::cout << "Unknown exception" << std::endl;
                        }
                        // 終わったスレッドを削除
                        v_status.erase(v_status.begin() + i, v_status.begin() + i + 1);
                    }
                }
                if (v_status.size() < max_num_threads)return;
            }
        }
    }
public:
    // マルチスレッド可能な関数
    void my_func(int i) {
        wait_threads();
        v_status.emplace_back(std::async(std::launch::async, long_calc, i));
    }
    // 最後に全てのスレッドを同期する関数
    void wait_for_completion() {
        for (int i = 0; i < v_status.size(); i++) {
            v_status[i].wait();
            try {
                // getして戻り値と例外を取得
                int j = v_status[i].get();
            }
            catch (std::exception & ex) {
                std::cout << ex.what() << std::endl;
            }
            catch (...) {
                std::cout << "Unknown exception" << std::endl;
            }
        }
    }
};

int main() {
    MyClass a;
    for (int i = 0; i < 20; i++) {
        a.my_func(i);
        std::this_thread::sleep_for(std::chrono::milliseconds(200));
    }
    a.wait_for_completion();
    return 0;
}