本文主要学习了多线程的分支合并框架。
部分内容来自以下博客:
https://segmentfault.com/a/1190000016781127
https://segmentfault.com/a/1190000016877931
1 简介
JDK1.7版本引入了一套Fork/Join框架。Fork/Join框架的基本思想就是将一个大任务分解(Fork)成一系列子任务,子任务可以继续往下分解,当多个不同的子任务都执行完成后,可以将它们各自的结果合并(Join)成一个大结果,最终合并成大任务的结果。
Fork/Join 框架要完成两件事情:
1)Fork:把一个复杂任务进行分拆
2)Join:把分拆任务的结果进行合并
Fork/Join框架的实现非常复杂,内部大量运用了位操作和无锁算法。
Fork/Join框架内部还涉及到三大核心组件:ForkJoinPool(线程池)、ForkJoinTask(任务)、ForkJoinWorkerThread(工作线程),外加WorkQueue(任务队列)。
2 类和接口
2.1 ForkJoinPool
ForkJoinPool是分支合并池,类似于线程池ThreadPoolExecutor,同样是ExecutorService接口的一个实现类。
ForkJoinPool类的实现:
1 public class ForkJoinPool extends AbstractExecutorService {
在ForkJoinPool类中提供了三个构造方法:
1 public ForkJoinPool();
2 public ForkJoinPool(int parallelism);
3 public ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, boolean asyncMode);
最终调用的是下面这个私有构造器:
1 private ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, int mode, String workerNamePrefix);
其参数含义如下:
parallelism:并行级别,默认值为CPU核心数,ForkJoinPool里工作线程数量与该参数有关,但它不表示最大线程数。
factory:工作线程工厂,默认是DefaultForkJoinWorkerThreadFactory,其实就是用来创建ForkJoinWorkerThread工作线程对象。
handler:异常处理器。
mode:调度模式,true表示FIFO_QUEUE,false表示LIFO_QUEUE。
workerNamePrefix:工作线程的名称前缀。
2.2 ForkJoinTask
ForkJoinTask是Future接口的抽象实现类,提供了用于分解任务的fork()方法和用于合并任务的join()方法。
在ThreadPoolExecutor类中,使用线程池执行任务调用的execute()方法中要求传入Runnable接口的实例。但是在ForkJoinPool类中,除了可以传入Runnable接口的实例外,还可以传入ForkJoinTask抽象类的实例,并且传入Runnable接口的实例也会被适配为ForkJoinTask抽象类的实例。
2.3 RecursiveTask
通常情况下使用ForkJoinTask抽象类的实例,并不需要直接继承ForkJoinTask类,只需要继承其子类:
1)RecursiveAction:用于没有返回结果的任务
2)RecursiveTask:用于有返回结果的任务
其中,最常用的还是RecursiveTask类。
2.4 ForkJoinWorkerThread
ForkJoinWorkerThread类是Thread的子类,作为线程池中的工作线程执行任务,其内部维护了一个WorkerQueue类型的双向任务队列。
工作线程在执行任务时,优先处理自身任务队列中的任务(FIFO或者LIFO),当自身队列中的任务为空时,会窃取其他任务队列中的任务(FIFO)。
2.5 WorkerQueue
WorkerQueue类是ForkJoinPool类的一个内部类,代表存储ForkJoinTask实例的双端队列。
在ForkJoinPool类的私有构造方法中,有一个int类型的mode参数,其取值如下:
1 static final int LIFO_QUEUE = 0;
2 static final int FIFO_QUEUE = 1 << 16;
当入参为LIFO_QUEUE时,表示同步,对于工作线程(Worker)自身队列中的任务,采用后进先出(LIFO)的方式执行。
当入参为FIFO_QUEUE时,表示异步,对于工作线程(Worker)自身队列中的任务,采用先进先出(FIFO)的方式执行。
3 实现原理
3.1 提交任务
使用ForkJoinPool的submit方法提交任务得到ForkJoinTask对象:
1 public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
2 if (task == null)
3 throw new NullPointerException();
4 externalPush(task);
5 return task;
6 }
继续查看externalPush方法:
1 final void externalPush(ForkJoinTask<?> task) {
2 WorkQueue[] ws; WorkQueue q; int m;
3 int r = ThreadLocalRandom.getProbe();
4 int rs = runState;
5 if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
6 (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
7 U.compareAndSwapInt(q, QLOCK, 0, 1)) {
8 ForkJoinTask<?>[] a; int am, n, s;
9 if ((a = q.array) != null &&
10 (am = a.length - 1) > (n = (s = q.top) - q.base)) {
11 int j = ((am & s) << ASHIFT) + ABASE;
12 U.putOrderedObject(a, j, task);
13 U.putOrderedInt(q, QTOP, s + 1);
14 U.putIntVolatile(q, QLOCK, 0);
15 if (n <= 1)
16 signalWork(ws, q);
17 return;
18 }
19 U.compareAndSwapInt(q, QLOCK, 1, 0);
20 }
21 externalSubmit(task);
22 }
该方法包含两个部分:
1)尝试将任务添加到任务队列,添加后则创建或激活一个工作线程,在此过程中使用了CAS保证线程安全。
2)添加队列失败,则调用externalSubmit方法初始化队列,并将任务加入到队列。
3.2 分解任务
3.2.1 创建或唤醒工作线程
调用ForkJoinTask的fork方法完成任务分解:
1 public final ForkJoinTask<V> fork() {
2 Thread t;
3 if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)// 调用线程为工作线程
4 ((ForkJoinWorkerThread)t).workQueue.push(this);// 将任务添加到自身队列
5 else
6 ForkJoinPool.common.externalPush(this);// 调用ForkJoinPool的externalPush方法
7 return this;
8 }
1)调用线程为工作线程,将任务添加到自身队列。
2)调用线程为其他外部线程,继续调用ForkJoinPool的externalPush方法,尝试将任务添加到任务队列并激活工作线程。
继续查看push方法,添加任务到自身队列:
1 final void push(ForkJoinTask<?> task) {
2 ForkJoinTask<?>[] a; ForkJoinPool p;
3 int b = base, s = top, n;
4 if ((a = array) != null) { // ignore if queue removed
5 int m = a.length - 1; // fenced write for task visibility
6 U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
7 U.putOrderedInt(this, QTOP, s + 1);
8 if ((n = s - b) <= 1) {
9 if ((p = pool) != null)
10 p.signalWork(p.workQueues, this);// 唤醒或创建工作线程
11 }
12 else if (n >= m)
13 growArray();// 扩容
14 }
15 }
1)判断是否需要扩容,不需要扩容则唤醒或创建工作线程。
2)需要扩容,则进行扩容操作。
继续查看signalWork方法,创建或唤醒工作线程:
1 final void signalWork(WorkQueue[] ws, WorkQueue q) {
2 long c; int sp, i; WorkQueue v; Thread p;
3 while ((c = ctl) < 0L) { // too few active
4 if ((sp = (int)c) == 0) { // 没有空闲工作进程
5 if ((c & ADD_WORKER) != 0L) // 工作进程太少
6 tryAddWorker(c);// 增加工作进程
7 break;
8 }
9 // 有工作进程,唤醒
10 if (ws == null) // unstarted/terminated
11 break;
12 if (ws.length <= (i = sp & SMASK)) // terminated
13 break;
14 if ((v = ws[i]) == null) // terminating
15 break;
16 int vs = (sp + SS_SEQ) & ~INACTIVE; // next scanState
17 int d = sp - v.scanState; // screen CAS
18 long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
19 if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
20 v.scanState = vs; // activate v
21 if ((p = v.parker) != null)
22 U.unpark(p);
23 break;
24 }
25 if (q != null && q.base == q.top) // no more work
26 break;
27 }
28 }
继续查看tryAddWorker方法:
1 private void tryAddWorker(long c) {
2 boolean add = false;
3 do {
4 // 设置活跃工作线程数和总工作线程数
5 long nc = ((AC_MASK & (c + AC_UNIT)) |
6 (TC_MASK & (c + TC_UNIT)));
7 if (ctl == c) {
8 int rs, stop; // check if terminating
9 if ((stop = (rs = lockRunState()) & STOP) == 0)
10 add = U.compareAndSwapLong(this, CTL, c, nc);
11 unlockRunState(rs, rs & ~RSLOCK);
12 if (stop != 0)
13 break;
14 if (add) {
15 // 创建工作线程
16 createWorker();
17 break;
18 }
19 }
20 } while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
21 }
继续查看createWorker方法:
1 private boolean createWorker() {
2 ForkJoinWorkerThreadFactory fac = factory;
3 Throwable ex = null;
4 ForkJoinWorkerThread wt = null;
5 try {
6 // 使用线程池工厂创建线程
7 if (fac != null && (wt = fac.newThread(this)) != null) {
8 // 启动线程
9 wt.start();
10 return true;
11 }
12 } catch (Throwable rex) {
13 ex = rex;
14 }
15 // 出现异常,注销该工作线程
16 deregisterWorker(wt, ex);
17 return false;
18 }
3.2.2 启动任务
ForkJoinWorkerThread在执行start方法后,会执行run方法:
1 public void run() {
2 if (workQueue.array == null) { // only run once
3 Throwable exception = null;
4 try {
5 onStart();
6 pool.runWorker(workQueue);
7 } catch (Throwable ex) {
8 exception = ex;
9 } finally {
10 try {
11 onTermination(exception);
12 } catch (Throwable ex) {
13 if (exception == null)
14 exception = ex;
15 } finally {
16 pool.deregisterWorker(this, exception);
17 }
18 }
19 }
20 }
在run方法内部调用了ForkJoinPool对象的runWorker方法:
1 final void runWorker(WorkQueue w) {
2 w.growArray(); // 初始化任务队列
3 int seed = w.hint; // initially holds randomization hint
4 int r = (seed == 0) ? 1 : seed; // avoid 0 for xorShift
5 for (ForkJoinTask<?> t;;) {
6 if ((t = scan(w, r)) != null)// 尝试获取任务
7 w.runTask(t);// 执行任务
8 else if (!awaitWork(w, r))// 获取失败,加入等待任务队列
9 break;// 等待失败,跳出方法并注销工作线程
10 r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
11 }
12 }
3.2.3 窃取任务
使用scan方法窃取任务:
1 private ForkJoinTask<?> scan(WorkQueue w, int r) {
2 WorkQueue[] ws; int m;
3 if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
4 int ss = w.scanState; // initially non-negative
5 for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
6 WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
7 int b, n; long c;
8 if ((q = ws[k]) != null) {// 定位任务队列
9 if ((n = (b = q.base) - q.top) < 0 &&
10 (a = q.array) != null) { // non-empty
11 long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
12 if ((t = ((ForkJoinTask<?>)
13 U.getObjectVolatile(a, i))) != null &&
14 q.base == b) {
15 if (ss >= 0) {
16 if (U.compareAndSwapObject(a, i, t, null)) {
17 q.base = b + 1;
18 if (n < -1) // signal others
19 signalWork(ws, q);// 创建获唤醒工作线程执行任务
20 return t;
21 }
22 }
23 else if (oldSum == 0 && // try to activate
24 w.scanState < 0)
25 tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);// 唤醒栈顶工作线程
26 }
27 if (ss < 0) // refresh
28 ss = w.scanState;
29 r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
30 origin = k = r & m; // move and rescan
31 oldSum = checkSum = 0;
32 continue;
33 }
34 checkSum += b;
35 }
36 // 已扫描全部工作线程,但并未找到任务
37 if ((k = (k + 1) & m) == origin) { // continue until stable
38 if ((ss >= 0 || (ss == (ss = w.scanState))) &&
39 oldSum == (oldSum = checkSum)) {
40 if (ss < 0 || w.qlock < 0) // already inactive
41 break;
42 int ns = ss | INACTIVE; // 尝试对当前工作线程灭活
43 long nc = ((SP_MASK & ns) |
44 (UC_MASK & ((c = ctl) - AC_UNIT)));
45 w.stackPred = (int)c; // hold prev stack top
46 U.putInt(w, QSCANSTATE, ns);
47 if (U.compareAndSwapLong(this, CTL, c, nc))
48 ss = ns;
49 else
50 w.scanState = ss; // back out
51 }
52 checkSum = 0;
53 }
54 }
55 }
56 return null;
57 }
3.2.4 执行任务
窃取到任务后,调用runTask方法执行任务:
1 final void runTask(ForkJoinTask<?> task) {
2 if (task != null) {
3 scanState &= ~SCANNING; // mark as busy
4 (currentSteal = task).doExec();// 执行任务
5 U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
6 execLocalTasks();// 执行本地任务
7 ForkJoinWorkerThread thread = owner;
8 if (++nsteals < 0) // collect on overflow
9 transferStealCount(pool);// 增加窃取任务数
10 scanState |= SCANNING;
11 if (thread != null)
12 thread.afterTopLevelExec();// 执行钩子函数
13 }
14 }
3.2.5 阻塞等待
如何未窃取到任务,会调用awaitWork方法等待获取任务:
1 private boolean awaitWork(WorkQueue w, int r) {
2 if (w == null || w.qlock < 0) // w is terminating
3 return false;
4 for (int pred = w.stackPred, spins = SPINS, ss;;) {
5 if ((ss = w.scanState) >= 0)
6 break;
7 else if (spins > 0) {
8 r ^= r << 6; r ^= r >>> 21; r ^= r << 7;
9 if (r >= 0 && --spins == 0) { // randomize spins
10 WorkQueue v; WorkQueue[] ws; int s, j; AtomicLong sc;
11 if (pred != 0 && (ws = workQueues) != null &&
12 (j = pred & SMASK) < ws.length &&
13 (v = ws[j]) != null && // see if pred parking
14 (v.parker == null || v.scanState >= 0))
15 spins = SPINS; // continue spinning
16 }
17 }
18 else if (w.qlock < 0) // recheck after spins
19 return false;
20 else if (!Thread.interrupted()) {
21 long c, prevctl, parkTime, deadline;
22 int ac = (int)((c = ctl) >> AC_SHIFT) + (config & SMASK);
23 if ((ac <= 0 && tryTerminate(false, false)) ||
24 (runState & STOP) != 0) // pool terminating
25 return false;
26 if (ac <= 0 && ss == (int)c) { // is last waiter
27 prevctl = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & pred);
28 int t = (short)(c >>> TC_SHIFT); // shrink excess spares
29 if (t > 2 && U.compareAndSwapLong(this, CTL, c, prevctl))
30 return false; // else use timed wait
31 parkTime = IDLE_TIMEOUT * ((t >= 0) ? 1 : 1 - t);
32 deadline = System.nanoTime() + parkTime - TIMEOUT_SLOP;
33 }
34 else
35 prevctl = parkTime = deadline = 0L;
36 Thread wt = Thread.currentThread();
37 U.putObject(wt, PARKBLOCKER, this); // emulate LockSupport
38 w.parker = wt;
39 if (w.scanState < 0 && ctl == c) // recheck before park
40 U.park(false, parkTime);
41 U.putOrderedObject(w, QPARKER, null);
42 U.putObject(wt, PARKBLOCKER, null);
43 if (w.scanState >= 0)
44 break;
45 if (parkTime != 0L && ctl == c &&
46 deadline - System.nanoTime() <= 0L &&
47 U.compareAndSwapLong(this, CTL, c, prevctl))
48 return false; // shrink pool
49 }
50 }
51 return true;
52 }
3.3 合并任务
使用ForkJoinTask的join方法可以获取任务的执行结果:
1 public final V join() {
2 int s;
3 if ((s = doJoin() & DONE_MASK) != NORMAL)
4 reportException(s);
5 return getRawResult();
6 }
查看doJoin方法:
1 private int doJoin() {
2 int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
3 return (s = status) < 0 ? s :
4 ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
5 (w = (wt = (ForkJoinWorkerThread)t).workQueue).
6 tryUnpush(this) && (s = doExec()) < 0 ? s :
7 wt.pool.awaitJoin(w, this, 0L) :
8 externalAwaitDone();
9 }
4 使用
4.1 计算多个整数的和
任务类定义,因为需要返回结果,所以继承RecursiveTask,并覆写compute方法。
1 class SumTask extends RecursiveTask<Integer> {
2 private static final int THRESHOLD = 10;// 拆分阈值
3 private int begin;// 拆分开始值
4 private int end;// 拆分结束值
5 public SumTask(int begin, int end) {
6 this.begin = begin;
7 this.end = end;
8 }
9 @Override
10 protected Integer compute() {
11 Integer value = 0;
12 if (end - begin <= THRESHOLD) {// 小于阈值,直接计算
13 for (int i = begin; i <= end; i++) {
14 value += i;
15 }
16 } else {// 大于阈值,递归计算
17 int middle = (begin + end) / 2;
18 SumTask beginTask = new SumTask(begin, middle);
19 SumTask endTask = new SumTask(middle + 1, end);
20 beginTask.fork();
21 endTask.fork();
22 value = beginTask.join() + endTask.join();
23 }
24 return value;
25 }
26 }
27 public class DemoTest {
28 public static void main(String[] args) {
29 SumTask sumTask = new SumTask(1, 100);
30 ForkJoinPool pool = new ForkJoinPool();
31 try {
32 ForkJoinTask<Integer> task = pool.submit(sumTask);
33 System.out.println(task.get());
34 } catch (Exception e) {
35 e.printStackTrace();
36 } finally {
37 pool.shutdown();
38 }
39 }
40 }