浅析 Fork/Join 基本概念和实战

478 阅读5分钟

「这是我参与2022首次更文挑战的第7天,活动详情查看:2022首次更文挑战」。

在 JDK 1.7 版本中提供了 Fork/Join 并行执行任务框架,它主要的作用是把大任务分割成若干个小任务,再对每个小任务得到的结果进行汇总,此种开发方法也叫做分治编程,分治编程可以极大的利用 CPU 资源,提高任务执行效率。

Fork/Join 分治编程

在 JDK 中并行执行框架 Fork-Join 使用了 “工作窃取(work-stealing)”算法,它是指某个线程从其他队列中窃取任务来执行。

比如要完成一个比较大的任务,完全可以把这个大的任务分割为若千互不依赖的子任务/小任务,为了更加方便地管理这些任务,于是把这些子任务分别放到不同的队列里,这时就会处理,完成任务的线程与其等着,不如去帮助其他线程分担要执行的任务,于是它就去其他线程的队列里窃取一一个任务来执行,这就是所谓的“工作窃取(work-stealing)” 算法。

工作窃取

ForkJoinPool与ThreadPoolExecutor有个很大的不同之处在于,ForkJoinPool存在引入了工作窃取设计,它是其性能保证的关键之一。工作窃取,就是允许空闲线程从繁忙线程的双端队列中窃取任务。默认情况下,工作线程从它自己的双端队列的头部获取任务。但是,当自己的任务为空时,线程会从其他繁忙线程双端队列的尾部中获取任务。这种方法,最大限度地减少了线程竞争任务的可能性。

ForkJoinPool的大部分操作都发生在工作窃取队列(work-stealingqueues)中,该队列由内部类WorkQueue实现。它是Deques的特殊形式,但仅支持三种操作方式:push、pop和poll(也称为窃取) 。在ForkJoinPool中,队列的读取有着严格的约束,push和pop仅能从其所属线程调用,而poll则可以从其他线程调用。

工作窃取的运行流程如下图所示:

  • 工作窃取算法的优点是充分利用线程进行并行计算,并减少了线程间的竞争;
  • 工作窃取算法缺点是在某些情况下还是存在竞争,比如双端队列里只有一个任务时。并且消耗了更多的系统资源,比如创建多个线程和多个双端队列。

为什么工作线程从队列头部获取,工作窃取从尾部窃取?

这样做的主要原因是为了提高性能,通过始终选择最近提交的任务,可以增加资源仍分配在CPU缓存中的机会,这样CPU处理起来要快一些。而窃取者之所以从尾部获取任务,则是为了降低线程之间的竞争可能,毕竟大家都从一个部分拿任务,竞争的可能要大很多。 此外,这样的设计还有一种考虑。由于任务是可分割的,那队列中较旧的任务最有可能粒度较大,因为它们可能还没有被分割,而空闲的线程则相对更有“精力”来完成这些粒度较大的任务。

分治算法

分治算法的基本思想是将一个规模为N的问题分解为K个规模较小的子问题,这些子问题相互独立且与原问题性质相同。求出子问题的解,就可得到原问题的解。即一种分目标完成程序算法,简单问题可用二分法完成。

分治法解题的一般步骤:

(1)分解,将要解决的问题划分成若干规模较小的同类问题;

(2)求解,当子问题划分得足够小时,用较简单的方法解决;

(3)合并,按原问题的要求,将子问题的解逐层合并构成原问题的解。

例子和代码可以参考后面的 ForkJoinPoolTest中累计求和的例子, 注释中有每个步骤的代码开始部分。

ForkJoinPool 和 ForkJoinTask

ForkJoinPool

ForkJoinPool 是用来执行 ForJoinTask 任务的任务池,区别于线程池的 Worker + Queue 的组合,而是维护了一个队列数组 WorkQuque(WorkQuque[]) 在提交任务和线程任务的时候大幅度减少碰撞。

构造方法代码如下:

public ForkJoinPool(int parallelism,
                    ForkJoinWorkerThreadFactory factory,
                    UncaughtExceptionHandler handler,
                    boolean asyncMode) {
    this(checkParallelism(parallelism),
         checkFactory(factory),
         handler,
         asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
         "ForkJoinPool-" + nextPoolId() + "-worker-");
    checkPermission();
}


private ForkJoinPool(int parallelism,
                     ForkJoinWorkerThreadFactory factory,
                     UncaughtExceptionHandler handler,
                     int mode,
                     String workerNamePrefix) {
    this.workerNamePrefix = workerNamePrefix;
    this.factory = factory;
    this.ueh = handler;
    this.config = (parallelism & SMASK) | mode;
    long np = (long)(-parallelism); // offset ctl counts
    this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}

参数的含义:

  • parallelism 指定并行级别(parallelism level)。ForkJoin 将根据这个设定来决定工作线程的数量。如果没有设置将使用 Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors())其实也就是 cpu 核心线程数。
  • factory ForkJoinPool 创建线程时,会通过factory 来创建,自定义需要实现 ForkJoinWorkerThreadFactory接口,默认使用 DefaultForkJoinWorkerThreadFactory
  • handler 指定异常处理器,当任务在运行中出错,将由设定的 handler 进行处理
  • mode 模式,设置队列工作模式:两种 FIFO_QUEUE, LIFO_QUEUE
  • workerNamePrefix 线程名的前缀

ForkJoinTask

ForkJoinTask 是 ForkJoinPool 的核心之一,它是任务的实际载体,定义了执行时间的具体逻辑和拆分逻辑。

ForkJoinTask 继承了 Future 接口,也可以当作是一个轻量级的 Future.

ForkJoinTask 是一个抽象类,它的方法有很多,最核心的方法是 fork() 方法和 join 方法, 承载了主要的任务协调作用,一个用于任务提交,一个用于获取结果

  • fork() 提交任务 : fork()方法用于向当前任务所运行的线程池中提交任务。如果当前线程是ForkJoinWorkerThread类型,将会放入该线程的工作队列,否则放入common线程池的工作队列中。
  • join() 获取任务结果: join()方法用于获取任务的执行结果。调用join()时,将阻塞当前线程直到对应的子任务完成运行并返回结果。

通常情况下,我们不需要直接继承 ForkJoinTask 类, 而只需要继承它的子类, Fork/Join 框架提供了ForkJoinTask 的三个子类:

  • RecursiveAction 用于递归执行且不需要返回结果的任务
  • RecursiveTask 用于递归执行且返回结果的任务
  • CountedCompleter:在执行完成任务后会触发一个自定义的钩子

ForkJoin 最适合纯粹的计算任务,也就是纯粹的函数计算,计算过程中都是独立运行的,没有外部数据/逻辑依赖。提交 ForkJoinPool 中的任务应该避免执行阻塞 I/O。

执行例子

通过实现 RecursiveTask实现 int a -> b 的累加,具体的代码如下:

public class ForkJoinPoolTest {

    static class MyForkJoinTask extends RecursiveTask<Integer> {

        // 每个任务的任务量
        private static final Integer MAX = 200;

        // 子任务开始计算的值
        private Integer startValue;
        // 子任务结束计算的值
        private Integer endValue;

        public MyForkJoinTask(Integer startValue, Integer endValue) {
            this.startValue = startValue;
            this.endValue = endValue;
        }

        @Override
        protected Integer compute() {
            //(2)求解,当子问题划分得足够小时,用较简单的方法解决;
            if (endValue - startValue < MAX) {
                System.out.println(Thread.currentThread().getName()+": 开始计算的部分:startValue = " + startValue + ";endValue = " + endValue);
                Integer totalValue = 0;
                for (int index = this.startValue; index <= this.endValue; index++) {
                    totalValue += index;
                }
                return totalValue;
            }
            //(1)分解,将要解决的问题划分成若干规模较小的同类问题;
            else {
                MyForkJoinTask subTask1 = new MyForkJoinTask(startValue, (startValue + endValue) / 2);
                subTask1.fork();
                MyForkJoinTask subTask2 = new MyForkJoinTask((startValue + endValue) / 2 + 1, endValue);
                subTask2.fork();
                // (3)合并,按原问题的要求,将子问题的解逐层合并构成原问题的解。
                return subTask1.join() + subTask2.join();
            }
        }
    }

    public static void main(String[] args) throws ExecutionException, InterruptedException {
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        // 0 - 1000 累加
        ForkJoinTask<Integer> task = forkJoinPool.submit(new MyForkJoinTask(0, 1000));
        // 获取结果
        Integer result = task.get();
        // 结果打印
        System.out.println(result);
    }
}

输出结果如下:

ForkJoinPool-1-worker-3: 开始计算的部分:startValue = 501;endValue = 625
ForkJoinPool-1-worker-2: 开始计算的部分:startValue = 0;endValue = 125
ForkJoinPool-1-worker-0: 开始计算的部分:startValue = 751;endValue = 875
ForkJoinPool-1-worker-3: 开始计算的部分:startValue = 626;endValue = 750
ForkJoinPool-1-worker-0: 开始计算的部分:startValue = 876;endValue = 1000
ForkJoinPool-1-worker-2: 开始计算的部分:startValue = 126;endValue = 250
ForkJoinPool-1-worker-0: 开始计算的部分:startValue = 251;endValue = 375
ForkJoinPool-1-worker-2: 开始计算的部分:startValue = 376;endValue = 500
500500

结论:我们从线程名称就可以看到,我全部都是使用的默认参数,一共启动了 4 个线程。每次大概分到 125 次计算

参考资料