LLM 流式通信之 SSE

寫在前面

SSE是LLM進行流式通信常用的技術方案, 下圖是 kimi 的示例

kimi回答時使用SSE

SSE 簡介

Server-Sent Events(SSE)是一種允許服務器向客戶端實時推送數(shù)據(jù)的技術。它基于HTTP協(xié)議,允許服務器通過一個持久的HTTP連接向客戶端發(fā)送事件流。以下是SSE的一些關鍵點:

  1. SSE的本質:SSE利用HTTP協(xié)議的流信息(streaming)特性,實現(xiàn)服務器向客戶端的單向通信。客戶端保持連接打開,等待服務器發(fā)送新的數(shù)據(jù)流。

  2. SSE的特點

    • 使用HTTP協(xié)議,現(xiàn)有的服務器軟件都支持。
    • 輕量級,使用簡單,與WebSocket相比,協(xié)議相對簡單。
    • 默認支持斷線重連,而WebSocket需要自己實現(xiàn)。
    • 一般只用來傳送文本數(shù)據(jù),二進制數(shù)據(jù)需要編碼后傳送。
    • 支持自定義發(fā)送的消息類型。
  3. 客戶端API

    • EventSource對象用于創(chuàng)建與服務器的連接并接收事件。
    • 通過監(jiān)聽message事件接收服務器發(fā)送的消息。
    • 可以監(jiān)聽自定義事件,不僅限于message事件。
  4. 服務器端發(fā)送事件

    • 服務器端腳本需要使用text/event-streamMIME類型響應內容。
    • 每個通知以文本塊形式發(fā)送,并以一對換行符結尾。
    • 消息由字段組成,包括event、dataidretry等。
  5. 事件流格式

    • 事件流是一個簡單的文本數(shù)據(jù)流,使用UTF-8編碼。
    • 消息由一對換行符分開,以冒號開頭的行為注釋行,會被忽略。
    • 每條消息由一行或多行文字組成,列出該消息的字段。
  6. 瀏覽器兼容性

    • SSE在現(xiàn)代瀏覽器中得到了廣泛支持,除了IE/Edge外,其他瀏覽器如Firefox、Chrome、Safari等都支持SSE。

SSE適用于需要服務器向客戶端單向實時推送數(shù)據(jù)的場景,如實時通知、股票行情、新聞推送等。它是一種有效降低服務器負載和網(wǎng)絡資源消耗的技術,通過服務器主動向客戶端發(fā)送更新事件,實現(xiàn)實時通信。

py 中使用 SSE

  • py 中異步: async + await
  • py 中流式接收 SSE: httpx
  • py 中流式返回 SSE: from fastapi.responses import StreamingResponse as FastapiStreamingResponse
  • 路由定義
@router.post("/stream", tags=["chat"])
async def streaming_chat(
    params: QuestionParams, current_user: TokenData = Depends(get_current_user)
):
    if not params.user_id:
        params.user_id = current_user.uid
    async_generator = RetrievalController().stream_answer(params)
    return StreamingResponse(async_generator)
  • 流式輸出定義
from typing import Mapping

from fastapi.responses import StreamingResponse as FastapiStreamingResponse
from starlette.background import BackgroundTask
from starlette.responses import ContentStream


class StreamingResponse(FastapiStreamingResponse):
    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
    ) -> None:
        default_headers = {"Content-Type": "text/event-stream", "Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
        default_headers.update(headers or {})
        super().__init__(content, status_code, default_headers, media_type, background)
  • 流式接收并流式返回
    @LogDecorate(
        func_name="retrieval_controller::process_stream_answer", raise_exc=True
    )
    async def stream_answer(self, params: QuestionParams, model: int = 1):
        """
        :param model: 1-8B 2-32B
        """
        session_id = params.session_id
        if params.new_session:
            session_id = str(uuid.uuid1()).replace("-", "")
        request_body = dict(
            messages=msgs,
            user_id=params.user_id,
        )
        stream_answer_api = f"{AI_DOMAIN}{STREAM_ANSWER_API}"

        answer = ""
        # 流式接收
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                stream_answer_api,
                json=request_body,
                timeout=60,
                headers=dict(trace_id=get_req_ctx("trace_id")),
            ) as response:
                async for chunk in response.aiter_text():
                    answer += chunk
                    yield self.get_yield_data(
                        {"content": chunk, "create_at": int(time.time() * 1000)}
                    )

        yield self.get_yield_data("[DONE]")
        yield self.get_yield_data({"session_id": session_id})
        yield self.get_yield_data("[END]")

        # 落庫
        await user_qa_dao.save_user_qa(params.q, answer, session_id, params.user_id)

Go中使用SSE

使用 https://github.com/hertz-contrib/sse

import (
    "context"
    "encoding/json"
    "fmt"
    "time"

    "github.com/cloudwego/hertz/pkg/app"
    "github.com/cloudwego/hertz/pkg/common/hlog"
    "github.com/google/uuid"
    "github.com/hertz-contrib/sse"
    "github.com/spf13/cast"
)

func ChatStream(ctx context.Context, c *app.RequestContext) {
    u := ctl.CtxUser(c)

    var req struct {
        Query string `form:"query" json:"query"`
        Model int    `form:"model" json:"model"`
        Sid   string `form:"sid" json:"sid"` // session id
    }
    if err := c.BindAndValidate(&req); err != nil {
        utils.RespErr(c, err)
        return
    }

    // 聊天消息支持多輪對話
    var sid string
    if req.Sid != "" {
        sid = req.Sid
    } else {
        sid = uuid.New().String()
    }
    msg := chat.SaveUserMsg(ctx, sid, req.Query)
    content := &chat.Content{
        Messages: msg,
        UserId:   cast.ToString(u.ID),
        UserName: u.Name,
    }
    b, _ := json.Marshal(content)

    // https://github.com/hertz-contrib/sse/blob/main/examples/client/quickstart/main.go
    cli := sse.NewClient(conf.GetConf().Dev.AIDomain + "xxx")
    cli.SetMethod("POST")
    cli.SetHeaders(map[string]string{"Content-Type": "application/json", "trace_id": httpx.TraceId()})
    cli.SetBody(b)

    var ans, allAns string // AI 返回內容
    var flag bool          // reply正文標識
    events := make(chan *sse.Event)
    errChan := make(chan error)
    s := sse.NewStream(c)
    go func() {
        cErr := cli.Subscribe(func(msg *sse.Event) {
            if msg != nil && msg.Data != nil {
                events <- msg
                return
            }
        })
        errChan <- cErr
    }()
    for {
        select {
        case e := <-events:
            m := map[string]any{}
            _ = json.Unmarshal(e.Data, &m)
            if v, ok := m["content"]; ok {
                allAns += v.(string)
                if flag {
                    ans += v.(string)
                }
                if v == "__REPLY_START__" {
                    flag = true
                }
                da := map[string]any{
                    "content":   v,
                    "create_at": time.Now().Unix(),
                }
                jsonData, _ := json.Marshal(da)
                hlog.Info("publish event data = %s", string(jsonData))
                _ = s.Publish(&sse.Event{Data: jsonData})
            } else {
                hlog.Info("invalid event data = %s", string(e.Data))
            }
        case err := <-errChan:
            if err != nil {
                hlog.CtxErrorf(context.Background(), "err = %s", err.Error())
            }
            chat.SaveAssistantMsg(ctx, sid, ans, msg)
            chat.SaveQA(u.ID, sid, req.Query, allAns)
            _ = s.Publish(&sse.Event{Data: []byte("[DONE]")})
            _ = s.Publish(&sse.Event{Data: []byte(fmt.Sprintf(`{"session_id": "%s"}`, sid))})
            _ = s.Publish(&sse.Event{Data: []byte("[END]")})
            hlog.Info("cli get all event")
            return
        }
    }
}

寫在最后

需要注意的點

  • py 使用 httpx 接收 SSE 流式數(shù)據(jù), 對數(shù)據(jù)結構沒有要求, 比如 SSE event 常見的 data: xxx, 可以不帶 data 標識返回
  • go 中使用 https://github.com/hertz-contrib/sse 接收 SSE 流式數(shù)據(jù)
    • 底層會解析 SSE 數(shù)據(jù)格式, 需要判斷 data 標識, 如果沒有, 會導致解析失敗
    • 如果數(shù)據(jù)包含 \n換行, 也會導致數(shù)據(jù)解析失敗, 比較簡單的做法 data: json 格式數(shù)據(jù)
// go 中對應 SSE 庫數(shù)據(jù)解析源碼
func (c *Client) processEvent(msg []byte) (event *Event, err error) {
    var e Event

    if len(msg) < 1 {
        return nil, fmt.Errorf("event message was empty")
    }

    // Normalize the crlf to lf to make it easier to split the lines.
    // Split the line by "\n" or "\r", per the spec.
    for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) {
        switch {
        case bytes.HasPrefix(line, headerID):
            e.ID = string(append([]byte(nil), trimHeader(len(headerID), line)...))
        case bytes.HasPrefix(line, headerData):
            // The spec allows for multiple data fields per event, concatenated them with "\n".
            e.Data = append(e.Data[:], append(trimHeader(len(headerData), line), byte('\n'))...)
        // The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body.
        case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))):
            e.Data = append(e.Data, byte('\n'))
        case bytes.HasPrefix(line, headerEvent):
            e.Event = string(append([]byte(nil), trimHeader(len(headerEvent), line)...))
        case bytes.HasPrefix(line, headerRetry):
            e.Retry, err = strconv.ParseUint(b2s(append([]byte(nil), trimHeader(len(headerRetry), line)...)), 10, 64)
            if err != nil {
                return nil, fmt.Errorf("process message `retry` failed, err is %s", err)
            }
        default:
            // Ignore any garbage that doesn't match what we're looking for.
        }
    }

    // Trim the last "\n" per the spec.
    e.Data = bytes.TrimSuffix(e.Data, []byte("\n"))

    if c.encodingBase64 {
        buf := make([]byte, base64.StdEncoding.DecodedLen(len(e.Data)))

        n, err := base64.StdEncoding.Decode(buf, e.Data)
        if err != nil {
            err = fmt.Errorf("failed to decode event message: %s", err)
            return &e, err
        }
        e.Data = buf[:n]
    }
    return &e, err
}
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容