Java CountDownLatch、CyclicBarrier、Semaphore源碼分析(基于API 29 JDK8)

CountDownLatch

在多線程的情況下,主線程需要等待子線程執(zhí)行完畢之后才能進(jìn)行接下來(lái)的操作,在CountDownLatch出現(xiàn)之前,一般通過(guò)join來(lái)實(shí)現(xiàn),但是join不夠靈活,不能滿足豐富場(chǎng)景下的需求,所以CountDownLatch類誕生了。舉個(gè)例子:

public class ExampleTest {
    @Test
    public void main() {
        ExecutorService service = Executors.newFixedThreadPool(3);
        for (int i = 0; i < 3; i++) {
            service.execute(runnable);
        }
        try {
            latch.await();
            System.out.println("All Thread finish");
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

    private CountDownLatch latch = new CountDownLatch(3);
    private Runnable runnable = new Runnable() {
        @Override
        public void run() {
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(Thread.currentThread().getName()+" : finish work");
            latch.countDown();
        }
    };
}

/***************************************
輸出:
pool-1-thread-1 : finish work
pool-1-thread-2 : finish work
pool-1-thread-3 : finish work
All Thread finish
***************************************/

接下來(lái)看CountDownLatch的源碼,從構(gòu)造函數(shù)開始:

public class CountDownLatch {
    ...
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }
    ...
}

只有帶一個(gè)int參數(shù)的構(gòu)造函數(shù),這個(gè)Sync類是它的一個(gè)內(nèi)部類,查看源碼:

private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;
        // 構(gòu)造函數(shù)調(diào)用了setState,是父類的方法
        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // 循環(huán)進(jìn)行CAS,直到當(dāng)前線程成功完成CAS是計(jì)數(shù)器值(state)減1并更新到state
            for (;;) {
                // 獲取volatile變量的state
                int c = getState();
                // state為0直接返回
                if (c == 0)
                    return false;
                int nextc = c - 1;
                // cas讓state-1
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

先看一下AbstractQueuedSynchronizer的setState/getState方法做了些什么:

public abstract class AbstractQueuedSynchronizer
    extends AbstractOwnableSynchronizer
    implements java.io.Serializable {
    ...
    private volatile int state;
    protected final void setState(int newState) {
        state = newState;
    }
    protected final int getState() {
        return state;
    }
    ...
}

從上面可以看到CountDownLatch的初始化設(shè)置了一個(gè)volatile的變量state,接下來(lái)看countDown方法做了什么:

    // CountDownLatch.java
    public void countDown() {
        sync.releaseShared(1);
    }

    // AbstractQueuedSynchronizer.java
    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            // 釋放資源
            doReleaseShared();
            return true;
        }
        return false;
    }

    private void doReleaseShared() {
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !h.compareAndSetWaitStatus(0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

可以看到releaseShared是調(diào)用了tryReleaseShared,在循環(huán)進(jìn)行CAS,直到當(dāng)前線程成功完成CAS是計(jì)數(shù)器值(state)減1并更新到state,CAS成功之后調(diào)用doReleaseShared釋放資源。之后看await方法做了些什么:

    // CountDownLatch.java
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }

    // AbstractQueuedSynchronizer.java
    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        // state不為0的時(shí)候進(jìn)入等待隊(duì)列
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
    }

    private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        final Node node = addWaiter(Node.SHARED);
        try {
            // 阻塞
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    // 當(dāng)state為0時(shí),返回
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        return;
                    }
                }
                // 當(dāng)節(jié)點(diǎn)獲取失敗或者中斷的時(shí)候拋出異常
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } catch (Throwable t) {
            cancelAcquire(node);
            throw t;
        }
    }

調(diào)用的Sync的acquireSharedInterruptibly方法,當(dāng)state不為0的時(shí)候,進(jìn)入doAcquireSharedInterruptibly阻塞,當(dāng)state為0時(shí)返回,或中斷時(shí)拋出異常。再看帶了超時(shí)參數(shù)的await方法:

    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        return tryAcquireShared(arg) >= 0 ||
            doAcquireSharedNanos(arg, nanosTimeout);
    }

    private boolean doAcquireSharedNanos(int arg, long nanosTimeout)
            throws InterruptedException {
        if (nanosTimeout <= 0L)
            return false;
        final long deadline = System.nanoTime() + nanosTimeout;
        final Node node = addWaiter(Node.SHARED);
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        return true;
                    }
                }
                nanosTimeout = deadline - System.nanoTime();
                if (nanosTimeout <= 0L) {
                    cancelAcquire(node);
                    return false;
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    nanosTimeout > SPIN_FOR_TIMEOUT_THRESHOLD)
                    LockSupport.parkNanos(this, nanosTimeout);
                if (Thread.interrupted())
                    throw new InterruptedException();
            }
        } catch (Throwable t) {
            cancelAcquire(node);
            throw t;
        }
    }

與await無(wú)參方法不一樣的是,doAcquireSharedNanos方法多了一個(gè)nanosTimeout參數(shù),當(dāng)nanosTimeout小于0的時(shí)候,釋放資源并返回false,程序主線程將繼續(xù)運(yùn)行。

CyclicBarrier

從上面的分析可以看到,當(dāng)CountDownLatch執(zhí)行countDown到state為0的時(shí)候,就結(jié)束了,沒(méi)有重置的辦法,因此CyclicBarrier來(lái)了,CyclicBarrier是回環(huán)屏障的意思,它可以讓一組線程全部達(dá)到一個(gè)狀態(tài)后再全部同步執(zhí)行。這里之所以叫回環(huán)是因?yàn)楫?dāng)所有等待線程執(zhí)行完畢,重置CyclicBarrier的狀態(tài)后可以被重用。寫一個(gè)測(cè)試用例:

public class ExampleTest {
    @Test
    public void main() {
        ExecutorService service = Executors.newFixedThreadPool(3);
        for (int i = 0; i < 3; i++) {
            service.execute(runnable);
        }
    }
    
    private CyclicBarrier barrier = new CyclicBarrier(3, new Runnable() {
        @Override
        public void run() {
            System.out.println(Thread.currentThread().getName() + " : finish work");
        }
    });

    private Runnable runnable = new Runnable() {
        @Override
        public void run() {
            try {
                System.out.println(Thread.currentThread().getName() + " : start work");
                barrier.await();
                System.out.println(Thread.currentThread().getName() + " : barrier out 1");
                barrier.await();
                System.out.println(Thread.currentThread().getName() + " : barrier out 2");
            } catch (BrokenBarrierException | InterruptedException e) {
                e.printStackTrace();
            }
        }
    };
}

/***************************************
輸出:
pool-1-thread-1 : start work
pool-1-thread-2 : start work
pool-1-thread-3 : start work
pool-1-thread-3 : finish work
pool-1-thread-3 : barrier out 1
pool-1-thread-2 : barrier out 1
pool-1-thread-1 : barrier out 1
pool-1-thread-1 : finish work
pool-1-thread-1 : barrier out 2
pool-1-thread-2 : barrier out 2
pool-1-thread-3 : barrier out 2
***************************************/

測(cè)試用例新建了一個(gè)CyclicBarrier對(duì)象,傳遞參數(shù)為計(jì)數(shù)器初始值和當(dāng)計(jì)數(shù)器為0時(shí)執(zhí)行的runnable。一開始計(jì)數(shù)器的值為3,當(dāng)?shù)谝粋€(gè)線程調(diào)用await方法時(shí),計(jì)數(shù)器減1,此時(shí)計(jì)數(shù)器不為0,線程阻塞,直到3個(gè)線程全部執(zhí)行await,計(jì)數(shù)器為0,最后一個(gè)進(jìn)入await的線程執(zhí)行CyclicBarrier中的runnable,執(zhí)行完畢后結(jié)束阻塞,并喚醒其他線程,執(zhí)行完barrier out 1的任務(wù)之后再次阻塞在await方法,這是單個(gè)CountDownLatch無(wú)法完成的。分析源碼實(shí)現(xiàn),首先是構(gòu)造函數(shù):

public class CyclicBarrier {
    private final ReentrantLock lock = new ReentrantLock();
    private Generation generation = new Generation();
    private final Condition trip = lock.newCondition();

    private final Runnable barrierCommand;
    private final int parties;
    private int count;

    private static class Generation {
        boolean broken;         // initially false
    }

    public CyclicBarrier(int parties) {
        this(parties, null);
    }

    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }
}

初始化了parties、count和barrierCommand,接著看await方法:

    public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen
        }
    }

    public int await(long timeout, TimeUnit unit)
        throws InterruptedException,
               BrokenBarrierException,
               TimeoutException {
        return dowait(true, unit.toNanos(timeout));
    }
 
    private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
        // 獲取ReentrantLock
        final ReentrantLock lock = this.lock;
        // 上鎖
        lock.lock();
        try {
            // 默認(rèn)為false
            final Generation g = generation;

            if (g.broken)
                throw new BrokenBarrierException();
            // 中斷
            if (Thread.interrupted()) {
                // 結(jié)束回環(huán)并拋出異常
                breakBarrier();
                throw new InterruptedException();
            }
            // count自減1
            int index = --count;
            if (index == 0) {  // tripped
                // 執(zhí)行barrierCommand
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    if (command != null)
                        command.run();
                    ranAction = true;
                    // 更新狀態(tài)并喚醒所有處于鎖定狀態(tài)的線程
                    nextGeneration();
                    return 0;
                } finally {
                    // 如果沒(méi)有正常返回,則結(jié)束回環(huán)
                    if (!ranAction)
                        breakBarrier();
                }
            }

            // 循環(huán)直到count=0,屏障破壞,中斷,或者超時(shí)
            for (;;) {
                try {
                    if (!timed)
                        // 阻塞,直到收到通知
                        trip.await();
                    else if (nanos > 0L)
                        // 阻塞,直到超時(shí)
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    // 發(fā)生中斷
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        Thread.currentThread().interrupt();
                    }
                }

                if (g.broken)
                    throw new BrokenBarrierException();

                if (g != generation)
                    return index;

                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            // 解鎖
            lock.unlock();
        }
    }

    private void breakBarrier() {
        generation.broken = true;
        count = parties;
        trip.signalAll();
    }

從上面可以看到,當(dāng)進(jìn)入await的時(shí)候,會(huì)通過(guò)ReentrantLock對(duì)代碼段進(jìn)行一個(gè)上鎖,操作count自減1,當(dāng)減為0的時(shí)候,執(zhí)行barrierCommand并調(diào)用trip.signalAll()來(lái)喚醒所有阻塞中的線程,并將count重新初始化為parties,否則在后續(xù)中通過(guò)調(diào)用trip.await()或者trip.awaitNanos(nanos)進(jìn)入阻塞狀態(tài)并釋放鎖,直到收到通知信號(hào)加入鎖的競(jìng)爭(zhēng)中,獲取到鎖之后在finally中釋放鎖,其他線程依次如此,最終所有線程往下繼續(xù)運(yùn)行。

Semaphore

Semaphore也是java的一個(gè)同步器,與CountDownLatch、CyclicBarrier不同的是,它不需要在初始化的時(shí)候指定同步線程的個(gè)數(shù),而是在需要同步的地方調(diào)用acquire方法時(shí),指定需要同步的線程數(shù)。

public class ExampleTest {
    private Semaphore semaphore = new Semaphore(0);

    @Test
    public void main() throws InterruptedException {
        ExecutorService service = Executors.newFixedThreadPool(3);
        for (int i = 0; i < 3; i++) {
            service.execute(runnable);
        }
        System.out.println(Thread.currentThread().getName() + " : acquire");
        semaphore.acquire(3);
        System.out.println(Thread.currentThread().getName() + " : release");
    }
    
    private Runnable runnable = new Runnable() {
        @Override
        public void run() {
            System.out.println(Thread.currentThread().getName() + " : start work");
            semaphore.release();
        }
    };
}
/***************************************
輸出:
main : acquire
pool-1-thread-1 : start work
pool-1-thread-2 : start work
pool-1-thread-3 : start work
main : release
***************************************/

上述代碼創(chuàng)建了一個(gè)Semaphore實(shí)例,構(gòu)造函數(shù)傳參為0,說(shuō)明當(dāng)前信號(hào)量的計(jì)數(shù)器的值為0,然后在main中向線程池添加了3個(gè)線程任務(wù),在子線程中調(diào)用release方法,在main的最后調(diào)用acquire方法,傳入?yún)?shù)為線程數(shù)3,之后進(jìn)入阻塞狀態(tài),等待信號(hào)量的計(jì)數(shù)變?yōu)?。接下來(lái)看源碼,從構(gòu)造函數(shù)開始:

public class Semaphore implements java.io.Serializable {
    private final Sync sync;
    public Semaphore(int permits) {
        sync = new NonfairSync(permits);
    }
    public Semaphore(int permits, boolean fair) {
        sync = fair ? new FairSync(permits) : new NonfairSync(permits);
    }
}

從上面可以看到Semaphore提供了兩種構(gòu)造方法,其中permits用于初始化信號(hào)量,fair用于確定sync的實(shí)例類型,默認(rèn)是非公平,Sync是內(nèi)部類繼承于AbstractQueuedSynchronizer ,F(xiàn)airSync和NonfairSync是Sync的子類。

    abstract static class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 1192457210091910933L;

        Sync(int permits) {
            setState(permits);
        }

        final int getPermits() {
            return getState();
        }

        final int nonfairTryAcquireShared(int acquires) {
            for (;;) {
                int available = getState();
                int remaining = available - acquires;
                if (remaining < 0 ||
                    compareAndSetState(available, remaining))
                    return remaining;
            }
        }

        protected final boolean tryReleaseShared(int releases) {
            for (;;) {
                int current = getState();
                int next = current + releases;
                if (next < current) // overflow
                    throw new Error("Maximum permit count exceeded");
                if (compareAndSetState(current, next))
                    return true;
            }
        }

        final void reducePermits(int reductions) {
            for (;;) {
                int current = getState();
                int next = current - reductions;
                if (next > current) // underflow
                    throw new Error("Permit count underflow");
                if (compareAndSetState(current, next))
                    return;
            }
        }

        final int drainPermits() {
            for (;;) {
                int current = getState();
                if (current == 0 || compareAndSetState(current, 0))
                    return current;
            }
        }
    }

    /**
     * NonFair version
     */
    static final class NonfairSync extends Sync {
        private static final long serialVersionUID = -2694183684443567898L;

        NonfairSync(int permits) {
            super(permits);
        }

        protected int tryAcquireShared(int acquires) {
            return nonfairTryAcquireShared(acquires);
        }
    }

    /**
     * Fair version
     */
    static final class FairSync extends Sync {
        private static final long serialVersionUID = 2014338818796000944L;

        FairSync(int permits) {
            super(permits);
        }

        protected int tryAcquireShared(int acquires) {
            for (;;) {
                if (hasQueuedPredecessors())
                    return -1;
                int available = getState();
                int remaining = available - acquires;
                if (remaining < 0 ||
                    compareAndSetState(available, remaining))
                    return remaining;
            }
        }
    }

AbstractQueuedSynchronizer的內(nèi)容之前已經(jīng)分析過(guò)了,先從Semaphore的acquire方法入手,

    public void acquire(int permits) throws InterruptedException {
        if (permits < 0) throw new IllegalArgumentException();
        sync.acquireSharedInterruptibly(permits);
    }

    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        // 調(diào)用sync子類方法嘗試獲取,默認(rèn)使用非公平策略
        if (tryAcquireShared(arg) < 0)
            // 如果獲取失敗則添加到阻塞隊(duì)列,然后再次嘗試,繼續(xù)失敗則調(diào)用part方法掛起當(dāng)前線程
            doAcquireSharedInterruptibly(arg);
    }
    // 非公平策略
    protected int tryAcquireShared(int acquires) {
        return nonfairTryAcquireShared(acquires);
    }

    final int nonfairTryAcquireShared(int acquires) {
        for (;;) {
            // 當(dāng)前信號(hào)量
            int available = getState();
            // 剩余值
            int remaining = available - acquires;
            // 如果當(dāng)前剩余值小于0或者CAS設(shè)置成功則返回
            if (remaining < 0 ||
                compareAndSetState(available, remaining))
                return remaining;
        }
    }

    // 公平策略
    protected int tryAcquireShared(int acquires) {
        for (;;) {
            // 查看當(dāng)前線程的前驅(qū)節(jié)點(diǎn)是否也在等待獲取資源
            // 如果是則放棄獲取并加入AQS阻塞隊(duì)列,否則就去獲取資源
            if (hasQueuedPredecessors())
                return -1;
            int available = getState();
            int remaining = available - acquires;
            if (remaining < 0 ||
                compareAndSetState(available, remaining))
                return remaining;
        }
    }

    public final boolean hasQueuedPredecessors() {
        Node t = tail; // Read fields in reverse initialization order
        Node h = head;
        Node s;
        return h != t &&
            ((s = h.next) == null || s.thread != Thread.currentThread());
    }

acquire會(huì)調(diào)用tryAcquireShared,該方法在構(gòu)造函數(shù)中存在公平與非公平兩種策略,其中非公平策略下線程會(huì)直接嘗試獲取資源,而公平策略通過(guò)hasQueuedPredecessors節(jié)點(diǎn)的前節(jié)點(diǎn)是否也在等待獲取資源,如果有前節(jié)點(diǎn)則放棄獲取并加入阻塞隊(duì)列,否則通過(guò)獲取信號(hào)量,并計(jì)算差值,如果差值小于0或者CAS操作state成功時(shí)返回,從上面分析不難發(fā)現(xiàn)Semaphore也是支持回環(huán)的,每次調(diào)用acquire會(huì)更新信號(hào)量,相當(dāng)于CyclicBarrier中將信號(hào)量重新初始化為備份過(guò)的初始值一樣。接下來(lái)看release方法:

    public void release() {
        sync.releaseShared(1);
    }

    public final boolean releaseShared(int arg) {
        // 嘗試釋放資源
        if (tryReleaseShared(arg)) {
            // 資源釋放成功則調(diào)用part方法喚醒AQS隊(duì)列里面最先掛起的線程
            doReleaseShared();
            return true;
        }
        return false;
    }

    protected final boolean tryReleaseShared(int releases) {
        for (;;) {
            // 當(dāng)前信號(hào)量
            int current = getState();
            // 信號(hào)量增加release
            int next = current + releases;
            // releases<0的情況
            if (next < current) // overflow
                throw new Error("Maximum permit count exceeded");
            // CAS更新信號(hào)量,保證原子性
            if (compareAndSetState(current, next))
                return true;
        }
    }

    private void doReleaseShared() {
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !h.compareAndSetWaitStatus(0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

從上面可以看到release方法是調(diào)用releaseShared方法并傳參1,在releaseShared方法中調(diào)用tryReleaseShared方法來(lái)釋放資源,通過(guò)CAS操作更新信號(hào)量,釋放資源成功后,調(diào)用doReleaseShared方法喚醒AQS隊(duì)列中最先掛起的線程,結(jié)束release。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容