今天看小明大神的博客《深入理解asyncio(三)》,里面有一段将同步函数改为协程使用的代码,其中提到了 run_in_executor,主要就是用这个方法将同步函数变为异步执行。
我们先看下如何将一个同步函数变为异步执行:
In [35]: import time
In [36]: import asyncio
In [37]: def a():
    ...:     # A plain synchronous function that blocks for one second.
    ...:     time.sleep(1)
    ...:     return 'A'
    ...:
In [38]: async def c():
    ...:     loop = asyncio.get_running_loop()
    ...:     # executor=None: the loop falls back to its default executor.
    ...:     return await loop.run_in_executor(None, a)
    ...:
In [39]: asyncio.run(c())
Out[39]: 'A'
上面使用 run_in_executor 可以将同步函数 a 以协程的方式执行。我们看下它的源码:
def run_in_executor(self, executor, func, *args):
    # Quoted from asyncio's event loop: run the blocking call func(*args)
    # in `executor` and return an asyncio future the caller can await.
    # Refuse to schedule work on a loop that has already been closed.
    self._check_closed()
    if self._debug:
        # Debug mode only: func is first passed to _check_callback
        # (presumably a sanity check on the callable — see asyncio source).
        self._check_callback(func, 'run_in_executor')
    if executor is None:
        # No executor given: reuse the loop's cached default executor,
        # creating a ThreadPoolExecutor lazily the first time it is needed.
        executor = self._default_executor
        if executor is None:
            executor = concurrent.futures.ThreadPoolExecutor()
            self._default_executor = executor
    # executor.submit() returns a concurrent.futures.Future; wrap_future
    # turns it into an asyncio future bound to this loop so it can be awaited.
    return futures.wrap_future(
        executor.submit(func, *args), loop=self)
我们看到,当没有设置 executor 的时候,会默认使用 concurrent.futures.ThreadPoolExecutor()(并把它缓存为 loop 的默认 executor)。那我们自己设置一个 executor 试试看。
In [40]: from concurrent.futures import ThreadPoolExecutor
In [41]: thread_executor = ThreadPoolExecutor(5)
In [44]: async def c():
    ...:     loop = asyncio.get_running_loop()
    ...:
    ...:     # Pass our own 5-thread executor instead of None, so the
    ...:     # loop's default executor is not used.
    ...:     return await loop.run_in_executor(thread_executor, a)
    ...:
In [45]: asyncio.run(c())
Out[45]: 'A'
正确输出。
还有其他方法实现吗?
我们看到,run_in_executor 的源码做了 loop 是否已关闭的校验、是否为 debug 模式的判断,以及 executor 的判空和赋值。假如我们不做这些操作,直接使用 ThreadPoolExecutor 是否可以呢?
In [1]: from concurrent.futures import ThreadPoolExecutor
In [2]: import time
In [3]: import asyncio
In [4]: def a():
   ...:     time.sleep(1)
   ...:     return 'A'
   ...:
In [5]: thread_executor = ThreadPoolExecutor(5)
In [6]: async def c():
   ...:     # submit() returns a concurrent.futures.Future, which cannot be
   ...:     # awaited directly; wrap_future turns it into an asyncio future.
   ...:     future = thread_executor.submit(a)
   ...:     return await asyncio.wrap_future(future)
   ...:
In [7]: asyncio.run(c())
Out[7]: 'A'
正确输出。
上面的 thread_executor.submit 返回的是一个 future 对象,但它是 concurrent.futures.Future,而不是 asyncio 模块的 Future,因此是不可等待的,即无法直接用 await 去等待该对象。
代码中的 wrap_future 是一个比较关键的函数,看下源码:
def wrap_future(future, *, loop=None):
    """Wrap concurrent.futures.Future object."""
    # Already an asyncio-style future: nothing to wrap.
    if isfuture(future):
        return future
    assert isinstance(future, concurrent.futures.Future), \
        f'concurrent.futures.Future is expected, got {future!r}'
    # No loop given: fall back to events.get_event_loop(). This is why the
    # caller does not have to fetch the loop explicitly.
    if loop is None:
        loop = events.get_event_loop()
    # Create an awaitable future on that loop and chain the two, so the
    # result (or exception) of `future` is copied over when it completes.
    new_future = loop.create_future()
    _chain_future(future, new_future)
    return new_future
结合其中用到的 _chain_future 源码一起看:
def _chain_future(source, destination):
    """Chain two futures so that when one completes, so does the other.
    The result (or exception) of source will be copied to destination.
    If destination is cancelled, source gets cancelled too.
    Compatible with both asyncio.Future and concurrent.futures.Future.
    """
    # Each argument must be an asyncio-style future or a
    # concurrent.futures.Future.
    if not isfuture(source) and not isinstance(source,
                                               concurrent.futures.Future):
        raise TypeError('A future is required for source argument')
    if not isfuture(destination) and not isinstance(destination,
                                                    concurrent.futures.Future):
        raise TypeError('A future is required for destination argument')
    # Only asyncio-style futures have an owning loop; a concurrent future
    # gets None here.
    source_loop = _get_loop(source) if isfuture(source) else None
    dest_loop = _get_loop(destination) if isfuture(destination) else None

    def _set_state(future, other):
        # Copy the finished state of `other` onto `future`, using the
        # helper that matches the kind of `future`.
        if isfuture(future):
            _copy_future_state(other, future)
        else:
            _set_concurrent_future_state(future, other)

    def _call_check_cancel(destination):
        # Propagate cancellation from destination back to source. When
        # source belongs to a different loop, the cancel is scheduled on
        # that loop via call_soon_threadsafe instead of being called here.
        if destination.cancelled():
            if source_loop is None or source_loop is dest_loop:
                source.cancel()
            else:
                source_loop.call_soon_threadsafe(source.cancel)

    def _call_set_state(source):
        # Runs once source is done. In the thread-pool case this is
        # presumably invoked from the worker thread that set the result,
        # hence the thread-safe hop to dest_loop below.
        # Nothing to deliver if destination was cancelled and its loop
        # is already closed.
        if (destination.cancelled() and
                dest_loop is not None and dest_loop.is_closed()):
            return
        if dest_loop is None or dest_loop is source_loop:
            _set_state(destination, source)
        else:
            dest_loop.call_soon_threadsafe(_set_state, destination, source)

    destination.add_done_callback(_call_check_cancel)
    source.add_done_callback(_call_set_state)
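通过 wrap_future 函数,可以把 concurrent.futures.Future 包装成 asyncio.Future,从而变成可等待(awaitable)的对象:当源 future 完成时,_chain_future 会把结果或异常复制到新的 asyncio.Future 上;如果两者不在同一个 loop(例如源 future 在工作线程中完成),则通过 call_soon_threadsafe 安全地回到目标 loop 中设置结果。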
这样我们不用显式地获取当前的 loop,也可以直接将同步函数变成协程去执行了(未传入 loop 时,wrap_future 会自行调用 get_event_loop() 获取)。
装饰器模式:将同步函数变为异步方式
import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor


class ThreadPool:
    """Run blocking callables on a private thread pool and await the result."""

    def __init__(self, max_workers):
        # max_workers caps how many callables can run at the same time.
        self._thread_pool = ThreadPoolExecutor(max_workers)

    async def run(self, _callable, *args, **kwargs):
        """Run ``_callable(*args, **kwargs)`` in the pool and return its result.

        Unlike ``loop.run_in_executor``, keyword arguments are forwarded too,
        because ``ThreadPoolExecutor.submit`` accepts them directly.
        An exception raised by the callable propagates to the awaiting
        coroutine.
        """
        future = self._thread_pool.submit(_callable, *args, **kwargs)
        # submit() returns a concurrent.futures.Future, which is not
        # awaitable; wrap_future bridges it onto the loop running this
        # coroutine. Passing that loop explicitly avoids relying on the
        # implicit get_event_loop() lookup inside wrap_future.
        return await asyncio.wrap_future(
            future, loop=asyncio.get_running_loop())

    def shutdown(self, wait=True):
        """Release the worker threads; nothing may be submitted afterwards."""
        self._thread_pool.shutdown(wait=wait)


class ThreadWorker:
    """Decorator: ``@ThreadWorker(n)`` turns a blocking function into a
    coroutine function whose calls run on a shared ``n``-thread pool."""

    def __init__(self, max_workers):
        self._thread_pool = ThreadPool(max_workers)

    def __call__(self, func):
        # The wrapper is itself ``async def`` so the decorated name is a real
        # coroutine function (inspect.iscoroutinefunction() reports True),
        # matching how it has to be called: ``await func(...)``.
        @functools.wraps(func)
        async def _wrapper(*args, **kwargs):
            return await self._thread_pool.run(func, *args, **kwargs)
        return _wrapper

    def shutdown(self, wait=True):
        """Shut down the pool shared by every function this decorated."""
        self._thread_pool.shutdown(wait=wait)


thread_worker = ThreadWorker(32)


@thread_worker
def some_io_block():
    return 1


asyncio.run(some_io_block())
# evaluates to 1 (an interactive session would echo it as the output)
import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor


class ThreadPool:
    """Run blocking callables on a private thread pool and await the result."""

    def __init__(self, max_workers):
        # max_workers caps how many callables can run at the same time.
        self._thread_pool = ThreadPoolExecutor(max_workers)

    async def run(self, _callable, *args, **kwargs):
        """Run ``_callable(*args, **kwargs)`` in the pool and return its result.

        Unlike ``loop.run_in_executor``, keyword arguments are forwarded too,
        because ``ThreadPoolExecutor.submit`` accepts them directly.
        An exception raised by the callable propagates to the awaiting
        coroutine.
        """
        future = self._thread_pool.submit(_callable, *args, **kwargs)
        # submit() returns a concurrent.futures.Future, which is not
        # awaitable; wrap_future bridges it onto the loop running this
        # coroutine. Passing that loop explicitly avoids relying on the
        # implicit get_event_loop() lookup inside wrap_future.
        return await asyncio.wrap_future(
            future, loop=asyncio.get_running_loop())

    def shutdown(self, wait=True):
        """Release the worker threads; nothing may be submitted afterwards."""
        self._thread_pool.shutdown(wait=wait)


class ThreadWorker:
    """Decorator: ``@ThreadWorker(n)`` turns a blocking function into a
    coroutine function whose calls run on a shared ``n``-thread pool."""

    def __init__(self, max_workers):
        self._thread_pool = ThreadPool(max_workers)

    def __call__(self, func):
        # The wrapper is itself ``async def`` so the decorated name is a real
        # coroutine function (inspect.iscoroutinefunction() reports True),
        # matching how it has to be called: ``await func(...)``.
        @functools.wraps(func)
        async def _wrapper(*args, **kwargs):
            return await self._thread_pool.run(func, *args, **kwargs)
        return _wrapper

    def shutdown(self, wait=True):
        """Shut down the pool shared by every function this decorated."""
        self._thread_pool.shutdown(wait=wait)


thread_worker = ThreadWorker(32)


@thread_worker
def some_io_block():
    return 1


asyncio.run(some_io_block())
# evaluates to 1 (an interactive session would echo it as the output)
其中的 max_workers 参数就是线程池中能够同时执行任务的最大线程数。
再深入看下源码
上面我们知道,可以直接使用 ThreadPoolExecutor 的 submit 方法获取 future 对象。下面通过源码来分析具体的流程。
class ThreadPoolExecutor(_base.Executor):
    # Excerpt: only the class attribute, __init__ and submit are quoted here.

    # Used to assign unique thread names when thread_name_prefix is not supplied.
    _counter = itertools.count().__next__

    def __init__(self, max_workers=None, thread_name_prefix='',
                 initializer=None, initargs=()):
        # Validate the arguments and set up empty pool state. No worker
        # threads are started here; they are created later, from submit().
        if max_workers is None:
            # Use this number because ThreadPoolExecutor is often
            # used to overlap I/O instead of CPU work.
            max_workers = (os.cpu_count() or 1) * 5
        if max_workers <= 0:
            raise ValueError("max_workers must be greater than 0")
        if initializer is not None and not callable(initializer):
            raise TypeError("initializer must be a callable")
        self._max_workers = max_workers
        # Pending _WorkItem objects; workers also receive None wake-up
        # sentinels through this same queue.
        self._work_queue = queue.SimpleQueue()
        # Worker threads started so far.
        self._threads = set()
        # Truthy once the pool is unusable; submit() then raises BrokenThreadPool.
        self._broken = False
        self._shutdown = False
        self._shutdown_lock = threading.Lock()
        self._thread_name_prefix = (thread_name_prefix or
                                    ("ThreadPoolExecutor-%d" % self._counter()))
        # Optional per-thread setup hook and its arguments, handed to each worker.
        self._initializer = initializer
        self._initargs = initargs

    def submit(self, fn, *args, **kwargs):
        # Schedule fn(*args, **kwargs) on the pool and return a
        # concurrent.futures Future that will receive its result or exception.
        # The state checks and the enqueue happen under _shutdown_lock
        # (presumably so they cannot interleave with shutdown()).
        with self._shutdown_lock:
            if self._broken:
                raise BrokenThreadPool(self._broken)
            if self._shutdown:
                raise RuntimeError('cannot schedule new futures after shutdown')
            # Module-level flag: the interpreter itself is shutting down.
            if _shutdown:
                raise RuntimeError('cannot schedule new futures after '
                                   'interpreter shutdown')
            f = _base.Future()  # This Future is the concurrent.futures Future, not asyncio's.
            # Bundle the future with the call so a worker thread can run it
            # and store the outcome on f.
            w = _WorkItem(f, fn, args, kwargs)
            self._work_queue.put(w)
            # Start another worker thread if the pool is not yet full.
            self._adjust_thread_count()
            return f
    # Reuse the base Executor.submit docstring.
    submit.__doc__ = _base.Executor.submit.__doc__
我们看到,submit 函数里主要是生成了一个 concurrent.futures 里的 Future 对象,然后用它生成一个 _WorkItem 实例,接着把这个 _WorkItem 实例放进一个队列,最后执行 self._adjust_thread_count() 函数。这个 _WorkItem 实例是什么呢?self._adjust_thread_count() 函数里又做了什么呢?我们看下源码。
class _WorkItem(object):
    # One queued call: the callable, its arguments, and the concurrent
    # Future that receives its outcome.
    def __init__(self, future, fn, args, kwargs):
        self.future = future
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def run(self):
        # Called from a worker thread: invoke fn and store its result or the
        # raised exception on self.future.
        # Skip the call if the future was cancelled before it started (per
        # the concurrent.futures docs, set_running_or_notify_cancel() returns
        # False in that case and otherwise marks the future as running).
        if not self.future.set_running_or_notify_cancel():
            return
        try:
            result = self.fn(*self.args, **self.kwargs)
        except BaseException as exc:
            # Any exception, including BaseException subclasses, is stored
            # on the future instead of escaping into the worker loop.
            self.future.set_exception(exc)
            # Break a reference cycle with the exception 'exc'
            self = None
        else:
            self.future.set_result(result)
我们发现,_WorkItem 类中的 run 方法执行了真正的同步函数,并将执行结果或者异常放到之前生成的 future 对象中。
那这个 run 方法什么时候真正执行呢?我们回头看 self._adjust_thread_count() 的源码:
def _adjust_thread_count(self):
    # Start at most one new worker thread per call, and only while the pool
    # has fewer than _max_workers threads. Existing threads are never
    # retired here.
    # When the executor gets lost, the weakref callback will wake up
    # the worker threads.
    def weakref_cb(_, q=self._work_queue):
        # A None item in the queue is the wake-up sentinel seen by _worker.
        q.put(None)
    # TODO(bquinlan): Should avoid creating new threads if there are more
    # idle threads than items in the work queue.
    num_threads = len(self._threads)
    if num_threads < self._max_workers:
        thread_name = '%s_%d' % (self._thread_name_prefix or self,
                                 num_threads)
        # The worker gets only a weak reference to the executor, so a
        # running thread does not by itself keep the executor alive.
        t = threading.Thread(name=thread_name, target=_worker,
                             args=(weakref.ref(self, weakref_cb),
                                   self._work_queue,
                                   self._initializer,
                                   self._initargs))
        t.daemon = True
        t.start()
        self._threads.add(t)
        # Module-level thread -> queue registry (presumably used at
        # interpreter exit to wake and join the workers — confirm in source).
        _threads_queues[t] = self._work_queue
我们看到,在 _adjust_thread_count 方法中,只要当前线程数小于 max_workers,就会新建一个线程并启动它,线程执行的函数是 _worker。这个 _worker 是什么呢?
def _worker(executor_reference, work_queue, initializer, initargs):
    # Body of every pool thread: run the optional initializer once, then
    # keep taking _WorkItem objects from work_queue and running them.
    # executor_reference is a weak reference to the owning executor.
    if initializer is not None:
        try:
            initializer(*initargs)
        except BaseException:
            _base.LOGGER.critical('Exception in initializer:', exc_info=True)
            # Report the failure to the executor if it still exists
            # (presumably marks it broken, cf. the _broken check in submit).
            executor = executor_reference()
            if executor is not None:
                executor._initializer_failed()
            return
    try:
        while True:
            # Block until there is work or a None wake-up sentinel.
            work_item = work_queue.get(block=True)
            if work_item is not None:
                work_item.run()
                # Delete references to object. See issue16284
                del work_item
                continue
            # Got a None sentinel: decide whether this thread should exit.
            executor = executor_reference()
            # Exit if:
            #   - The interpreter is shutting down OR
            #   - The executor that owns the worker has been collected OR
            #   - The executor that owns the worker has been shutdown.
            if _shutdown or executor is None or executor._shutdown:
                # Flag the executor as shutting down as early as possible if it
                # is not gc-ed yet.
                if executor is not None:
                    executor._shutdown = True
                # Notice other workers
                work_queue.put(None)
                return
            # Drop the strong reference before blocking again; otherwise this
            # thread would keep the executor alive while it waits.
            del executor
    except BaseException:
        _base.LOGGER.critical('Exception in worker', exc_info=True)
我们看到,这个函数会在循环中不断从队列里取出之前放进去的 _WorkItem 实例,然后执行它的 run 方法;如果取到的是 None,则检查解释器或 executor 是否已关闭,以决定线程是否退出。
这样整个过程就串起来了:同步函数在线程池的工作线程中执行,结果写入 concurrent.futures.Future,再通过 wrap_future 回传给事件循环中的 asyncio.Future,从而成功将一个同步函数变成异步方式执行。