CountDownLatch
在多線程的情況下,主線程需要等待子線程執(zhí)行完畢之后才能進(jìn)行接下來(lái)的操作,在CountDownLatch出現(xiàn)之前,一般通過(guò)join來(lái)實(shí)現(xiàn),但是join不夠靈活,不能滿足豐富場(chǎng)景下的需求,所以CountDownLatch類誕生了。舉個(gè)例子:
public class ExampleTest {
@Test
public void main() {
ExecutorService service = Executors.newFixedThreadPool(3);
for (int i = 0; i < 3; i++) {
service.execute(runnable);
}
try {
latch.await();
System.out.println("All Thread finish");
} catch (InterruptedException e) {
e.printStackTrace();
}
}
private CountDownLatch latch = new CountDownLatch(3);
private Runnable runnable = new Runnable() {
@Override
public void run() {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName()+" : finish work");
latch.countDown();
}
};
}
/***************************************
輸出:
pool-1-thread-1 : finish work
pool-1-thread-2 : finish work
pool-1-thread-3 : finish work
All Thread finish
***************************************/
接下來(lái)看CountDownLatch的源碼,從構(gòu)造函數(shù)開始:
public class CountDownLatch {
...
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
...
}
只有帶一個(gè)int參數(shù)的構(gòu)造函數(shù),這個(gè)Sync類是它的一個(gè)內(nèi)部類,查看源碼:
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
// 構(gòu)造函數(shù)調(diào)用了setState,是父類的方法
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// 循環(huán)進(jìn)行CAS,直到當(dāng)前線程成功完成CAS是計(jì)數(shù)器值(state)減1并更新到state
for (;;) {
// 獲取volatile變量的state
int c = getState();
// state為0直接返回
if (c == 0)
return false;
int nextc = c - 1;
// cas讓state-1
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
先看一下AbstractQueuedSynchronizer的setState/getState方法做了些什么:
public abstract class AbstractQueuedSynchronizer
extends AbstractOwnableSynchronizer
implements java.io.Serializable {
...
private volatile int state;
protected final void setState(int newState) {
state = newState;
}
protected final int getState() {
return state;
}
...
}
從上面可以看到CountDownLatch的初始化設(shè)置了一個(gè)volatile的變量state,接下來(lái)看countDown方法做了什么:
// CountDownLatch.java
public void countDown() {
sync.releaseShared(1);
}
// AbstractQueuedSynchronizer.java
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
// 釋放資源
doReleaseShared();
return true;
}
return false;
}
private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0))
continue; // loop to recheck cases
unparkSuccessor(h);
}
else if (ws == 0 &&
!h.compareAndSetWaitStatus(0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}
可以看到releaseShared是調(diào)用了tryReleaseShared,在循環(huán)進(jìn)行CAS,直到當(dāng)前線程成功完成CAS是計(jì)數(shù)器值(state)減1并更新到state,CAS成功之后調(diào)用doReleaseShared釋放資源。之后看await方法做了些什么:
// CountDownLatch.java
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
// AbstractQueuedSynchronizer.java
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
// state不為0的時(shí)候進(jìn)入等待隊(duì)列
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
try {
// 阻塞
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
// 當(dāng)state為0時(shí),返回
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
return;
}
}
// 當(dāng)節(jié)點(diǎn)獲取失敗或者中斷的時(shí)候拋出異常
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} catch (Throwable t) {
cancelAcquire(node);
throw t;
}
}
調(diào)用的Sync的acquireSharedInterruptibly方法,當(dāng)state不為0的時(shí)候,進(jìn)入doAcquireSharedInterruptibly阻塞,當(dāng)state為0時(shí)返回,或中斷時(shí)拋出異常。再看帶了超時(shí)參數(shù)的await方法:
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
return tryAcquireShared(arg) >= 0 ||
doAcquireSharedNanos(arg, nanosTimeout);
}
private boolean doAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
if (nanosTimeout <= 0L)
return false;
final long deadline = System.nanoTime() + nanosTimeout;
final Node node = addWaiter(Node.SHARED);
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
return true;
}
}
nanosTimeout = deadline - System.nanoTime();
if (nanosTimeout <= 0L) {
cancelAcquire(node);
return false;
}
if (shouldParkAfterFailedAcquire(p, node) &&
nanosTimeout > SPIN_FOR_TIMEOUT_THRESHOLD)
LockSupport.parkNanos(this, nanosTimeout);
if (Thread.interrupted())
throw new InterruptedException();
}
} catch (Throwable t) {
cancelAcquire(node);
throw t;
}
}
與await無(wú)參方法不一樣的是,doAcquireSharedNanos方法多了一個(gè)nanosTimeout參數(shù),當(dāng)nanosTimeout小于0的時(shí)候,釋放資源并返回false,程序主線程將繼續(xù)運(yùn)行。
CyclicBarrier
從上面的分析可以看到,當(dāng)CountDownLatch執(zhí)行countDown到state為0的時(shí)候,就結(jié)束了,沒(méi)有重置的辦法,因此CyclicBarrier來(lái)了,CyclicBarrier是回環(huán)屏障的意思,它可以讓一組線程全部達(dá)到一個(gè)狀態(tài)后再全部同步執(zhí)行。這里之所以叫回環(huán)是因?yàn)楫?dāng)所有等待線程執(zhí)行完畢,重置CyclicBarrier的狀態(tài)后可以被重用。寫一個(gè)測(cè)試用例:
public class ExampleTest {
@Test
public void main() {
ExecutorService service = Executors.newFixedThreadPool(3);
for (int i = 0; i < 3; i++) {
service.execute(runnable);
}
}
private CyclicBarrier barrier = new CyclicBarrier(3, new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread().getName() + " : finish work");
}
});
private Runnable runnable = new Runnable() {
@Override
public void run() {
try {
System.out.println(Thread.currentThread().getName() + " : start work");
barrier.await();
System.out.println(Thread.currentThread().getName() + " : barrier out 1");
barrier.await();
System.out.println(Thread.currentThread().getName() + " : barrier out 2");
} catch (BrokenBarrierException | InterruptedException e) {
e.printStackTrace();
}
}
};
}
/***************************************
輸出:
pool-1-thread-1 : start work
pool-1-thread-2 : start work
pool-1-thread-3 : start work
pool-1-thread-3 : finish work
pool-1-thread-3 : barrier out 1
pool-1-thread-2 : barrier out 1
pool-1-thread-1 : barrier out 1
pool-1-thread-1 : finish work
pool-1-thread-1 : barrier out 2
pool-1-thread-2 : barrier out 2
pool-1-thread-3 : barrier out 2
***************************************/
測(cè)試用例新建了一個(gè)CyclicBarrier對(duì)象,傳遞參數(shù)為計(jì)數(shù)器初始值和當(dāng)計(jì)數(shù)器為0時(shí)執(zhí)行的runnable。一開始計(jì)數(shù)器的值為3,當(dāng)?shù)谝粋€(gè)線程調(diào)用await方法時(shí),計(jì)數(shù)器減1,此時(shí)計(jì)數(shù)器不為0,線程阻塞,直到3個(gè)線程全部執(zhí)行await,計(jì)數(shù)器為0,最后一個(gè)進(jìn)入await的線程執(zhí)行CyclicBarrier中的runnable,執(zhí)行完畢后結(jié)束阻塞,并喚醒其他線程,執(zhí)行完barrier out 1的任務(wù)之后再次阻塞在await方法,這是單個(gè)CountDownLatch無(wú)法完成的。分析源碼實(shí)現(xiàn),首先是構(gòu)造函數(shù):
public class CyclicBarrier {
private final ReentrantLock lock = new ReentrantLock();
private Generation generation = new Generation();
private final Condition trip = lock.newCondition();
private final Runnable barrierCommand;
private final int parties;
private int count;
private static class Generation {
boolean broken; // initially false
}
public CyclicBarrier(int parties) {
this(parties, null);
}
public CyclicBarrier(int parties, Runnable barrierAction) {
if (parties <= 0) throw new IllegalArgumentException();
this.parties = parties;
this.count = parties;
this.barrierCommand = barrierAction;
}
}
初始化了parties、count和barrierCommand,接著看await方法:
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe); // cannot happen
}
}
public int await(long timeout, TimeUnit unit)
throws InterruptedException,
BrokenBarrierException,
TimeoutException {
return dowait(true, unit.toNanos(timeout));
}
private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
// 獲取ReentrantLock
final ReentrantLock lock = this.lock;
// 上鎖
lock.lock();
try {
// 默認(rèn)為false
final Generation g = generation;
if (g.broken)
throw new BrokenBarrierException();
// 中斷
if (Thread.interrupted()) {
// 結(jié)束回環(huán)并拋出異常
breakBarrier();
throw new InterruptedException();
}
// count自減1
int index = --count;
if (index == 0) { // tripped
// 執(zhí)行barrierCommand
boolean ranAction = false;
try {
final Runnable command = barrierCommand;
if (command != null)
command.run();
ranAction = true;
// 更新狀態(tài)并喚醒所有處于鎖定狀態(tài)的線程
nextGeneration();
return 0;
} finally {
// 如果沒(méi)有正常返回,則結(jié)束回環(huán)
if (!ranAction)
breakBarrier();
}
}
// 循環(huán)直到count=0,屏障破壞,中斷,或者超時(shí)
for (;;) {
try {
if (!timed)
// 阻塞,直到收到通知
trip.await();
else if (nanos > 0L)
// 阻塞,直到超時(shí)
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
// 發(fā)生中斷
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
Thread.currentThread().interrupt();
}
}
if (g.broken)
throw new BrokenBarrierException();
if (g != generation)
return index;
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
// 解鎖
lock.unlock();
}
}
private void breakBarrier() {
generation.broken = true;
count = parties;
trip.signalAll();
}
從上面可以看到,當(dāng)進(jìn)入await的時(shí)候,會(huì)通過(guò)ReentrantLock對(duì)代碼段進(jìn)行一個(gè)上鎖,操作count自減1,當(dāng)減為0的時(shí)候,執(zhí)行barrierCommand并調(diào)用trip.signalAll()來(lái)喚醒所有阻塞中的線程,并將count重新初始化為parties,否則在后續(xù)中通過(guò)調(diào)用trip.await()或者trip.awaitNanos(nanos)進(jìn)入阻塞狀態(tài)并釋放鎖,直到收到通知信號(hào)加入鎖的競(jìng)爭(zhēng)中,獲取到鎖之后在finally中釋放鎖,其他線程依次如此,最終所有線程往下繼續(xù)運(yùn)行。
Semaphore
Semaphore也是java的一個(gè)同步器,與CountDownLatch、CyclicBarrier不同的是,它不需要在初始化的時(shí)候指定同步線程的個(gè)數(shù),而是在需要同步的地方調(diào)用acquire方法時(shí),指定需要同步的線程數(shù)。
public class ExampleTest {
private Semaphore semaphore = new Semaphore(0);
@Test
public void main() throws InterruptedException {
ExecutorService service = Executors.newFixedThreadPool(3);
for (int i = 0; i < 3; i++) {
service.execute(runnable);
}
System.out.println(Thread.currentThread().getName() + " : acquire");
semaphore.acquire(3);
System.out.println(Thread.currentThread().getName() + " : release");
}
private Runnable runnable = new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread().getName() + " : start work");
semaphore.release();
}
};
}
/***************************************
輸出:
main : acquire
pool-1-thread-1 : start work
pool-1-thread-2 : start work
pool-1-thread-3 : start work
main : release
***************************************/
上述代碼創(chuàng)建了一個(gè)Semaphore實(shí)例,構(gòu)造函數(shù)傳參為0,說(shuō)明當(dāng)前信號(hào)量的計(jì)數(shù)器的值為0,然后在main中向線程池添加了3個(gè)線程任務(wù),在子線程中調(diào)用release方法,在main的最后調(diào)用acquire方法,傳入?yún)?shù)為線程數(shù)3,之后進(jìn)入阻塞狀態(tài),等待信號(hào)量的計(jì)數(shù)變?yōu)?。接下來(lái)看源碼,從構(gòu)造函數(shù)開始:
public class Semaphore implements java.io.Serializable {
private final Sync sync;
public Semaphore(int permits) {
sync = new NonfairSync(permits);
}
public Semaphore(int permits, boolean fair) {
sync = fair ? new FairSync(permits) : new NonfairSync(permits);
}
}
從上面可以看到Semaphore提供了兩種構(gòu)造方法,其中permits用于初始化信號(hào)量,fair用于確定sync的實(shí)例類型,默認(rèn)是非公平,Sync是內(nèi)部類繼承于AbstractQueuedSynchronizer ,F(xiàn)airSync和NonfairSync是Sync的子類。
abstract static class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 1192457210091910933L;
Sync(int permits) {
setState(permits);
}
final int getPermits() {
return getState();
}
final int nonfairTryAcquireShared(int acquires) {
for (;;) {
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
protected final boolean tryReleaseShared(int releases) {
for (;;) {
int current = getState();
int next = current + releases;
if (next < current) // overflow
throw new Error("Maximum permit count exceeded");
if (compareAndSetState(current, next))
return true;
}
}
final void reducePermits(int reductions) {
for (;;) {
int current = getState();
int next = current - reductions;
if (next > current) // underflow
throw new Error("Permit count underflow");
if (compareAndSetState(current, next))
return;
}
}
final int drainPermits() {
for (;;) {
int current = getState();
if (current == 0 || compareAndSetState(current, 0))
return current;
}
}
}
/**
* NonFair version
*/
static final class NonfairSync extends Sync {
private static final long serialVersionUID = -2694183684443567898L;
NonfairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
return nonfairTryAcquireShared(acquires);
}
}
/**
* Fair version
*/
static final class FairSync extends Sync {
private static final long serialVersionUID = 2014338818796000944L;
FairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
for (;;) {
if (hasQueuedPredecessors())
return -1;
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
}
AbstractQueuedSynchronizer的內(nèi)容之前已經(jīng)分析過(guò)了,先從Semaphore的acquire方法入手,
public void acquire(int permits) throws InterruptedException {
if (permits < 0) throw new IllegalArgumentException();
sync.acquireSharedInterruptibly(permits);
}
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
// 調(diào)用sync子類方法嘗試獲取,默認(rèn)使用非公平策略
if (tryAcquireShared(arg) < 0)
// 如果獲取失敗則添加到阻塞隊(duì)列,然后再次嘗試,繼續(xù)失敗則調(diào)用part方法掛起當(dāng)前線程
doAcquireSharedInterruptibly(arg);
}
// 非公平策略
protected int tryAcquireShared(int acquires) {
return nonfairTryAcquireShared(acquires);
}
final int nonfairTryAcquireShared(int acquires) {
for (;;) {
// 當(dāng)前信號(hào)量
int available = getState();
// 剩余值
int remaining = available - acquires;
// 如果當(dāng)前剩余值小于0或者CAS設(shè)置成功則返回
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
// 公平策略
protected int tryAcquireShared(int acquires) {
for (;;) {
// 查看當(dāng)前線程的前驅(qū)節(jié)點(diǎn)是否也在等待獲取資源
// 如果是則放棄獲取并加入AQS阻塞隊(duì)列,否則就去獲取資源
if (hasQueuedPredecessors())
return -1;
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
public final boolean hasQueuedPredecessors() {
Node t = tail; // Read fields in reverse initialization order
Node h = head;
Node s;
return h != t &&
((s = h.next) == null || s.thread != Thread.currentThread());
}
acquire會(huì)調(diào)用tryAcquireShared,該方法在構(gòu)造函數(shù)中存在公平與非公平兩種策略,其中非公平策略下線程會(huì)直接嘗試獲取資源,而公平策略通過(guò)hasQueuedPredecessors節(jié)點(diǎn)的前節(jié)點(diǎn)是否也在等待獲取資源,如果有前節(jié)點(diǎn)則放棄獲取并加入阻塞隊(duì)列,否則通過(guò)獲取信號(hào)量,并計(jì)算差值,如果差值小于0或者CAS操作state成功時(shí)返回,從上面分析不難發(fā)現(xiàn)Semaphore也是支持回環(huán)的,每次調(diào)用acquire會(huì)更新信號(hào)量,相當(dāng)于CyclicBarrier中將信號(hào)量重新初始化為備份過(guò)的初始值一樣。接下來(lái)看release方法:
public void release() {
sync.releaseShared(1);
}
public final boolean releaseShared(int arg) {
// 嘗試釋放資源
if (tryReleaseShared(arg)) {
// 資源釋放成功則調(diào)用part方法喚醒AQS隊(duì)列里面最先掛起的線程
doReleaseShared();
return true;
}
return false;
}
protected final boolean tryReleaseShared(int releases) {
for (;;) {
// 當(dāng)前信號(hào)量
int current = getState();
// 信號(hào)量增加release
int next = current + releases;
// releases<0的情況
if (next < current) // overflow
throw new Error("Maximum permit count exceeded");
// CAS更新信號(hào)量,保證原子性
if (compareAndSetState(current, next))
return true;
}
}
private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0))
continue; // loop to recheck cases
unparkSuccessor(h);
}
else if (ws == 0 &&
!h.compareAndSetWaitStatus(0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}
從上面可以看到release方法是調(diào)用releaseShared方法并傳參1,在releaseShared方法中調(diào)用tryReleaseShared方法來(lái)釋放資源,通過(guò)CAS操作更新信號(hào)量,釋放資源成功后,調(diào)用doReleaseShared方法喚醒AQS隊(duì)列中最先掛起的線程,結(jié)束release。