from concurrent.futures import ThreadPoolExecutor
from tornado.concurrent import run_on_executor
import asyncio
class ChatHandler(tornado_factory("rest")):
'''
流式問答
'''
executor = ThreadPoolExecutor(max_workers=4) # 創(chuàng)建線程池
async def post(self):
data = self.get_json()
self.set_header("Content-Type", "text/event-stream")
self.set_header("Cache-Control", "no-cache")
queue = asyncio.Queue() # 創(chuàng)建一個隊列用于數(shù)據(jù)傳遞
asyncio.create_task(self.run_prediction(data, queue))
# 在主線程中不斷從隊列中取數(shù)據(jù)并發(fā)送
while not self.request.connection.stream.closed():
result = await queue.get()
if result is None:
break
self.write(result + '\n')
await self.flush()
async def run_prediction(self, data, queue):
# 將同步生成器放入線程池執(zhí)行,逐步將生成結(jié)果放入隊列
loop = asyncio.get_running_loop()
def blocking_predict():
# predict 是同步生成器,逐步生成預(yù)測結(jié)果
for result in predict(data):
# 使用 run_coroutine_threadsafe 將結(jié)果安全地放入隊列
asyncio.run_coroutine_threadsafe(queue.put(result), loop)
await loop.run_in_executor(self.executor, blocking_predict)
# 放入 None 表示生成器結(jié)束
await queue.put(None)
class ChatNostreamHandler(tornado_factory("rest")):
'''
非流式問答
'''
executor = ThreadPoolExecutor(max_workers=4) # 創(chuàng)建線程池
async def post(self):
data = self.get_json()
await self.run_prediction(data)
@run_on_executor
def run_prediction(self, data):
try:
result = ''
for char in predict(data):
result += char
return self.json_response(status="OK", result=result)
except:
return self.json_response(status="Failed", code=500)
最后編輯于 :
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。