ONNX模型是什么?

好的,我來(lái)詳細(xì)解釋ONNX模型以及如何在Java和C#中與它們交互。這是在不依賴Python的情況下使用開(kāi)源嵌入模型的關(guān)鍵技術(shù)。


一、ONNX模型詳解

什么是ONNX?

ONNX 代表 Open Neural Network Exchange,是一個(gè)開(kāi)放的格式標(biāo)準(zhǔn),用于表示深度學(xué)習(xí)模型。它讓模型可以在不同的框架之間移植。

ONNX的核心價(jià)值:

  1. 框架互操作性

    • 你可以在PyTorch、TensorFlow、Scikit-learn中訓(xùn)練模型
    • 然后導(dǎo)出為ONNX格式
    • 最后在C#、Java、C++等環(huán)境中運(yùn)行
  2. 性能優(yōu)化

    • ONNX Runtime是針對(duì)ONNX模型的高度優(yōu)化推理引擎
    • 支持CPU、GPU(CUDA、DirectML)、TensorRT等后端
    • 提供線程安全、內(nèi)存高效的計(jì)算
  3. 部署友好

    • 單個(gè).onnx文件包含完整的模型架構(gòu)和權(quán)重
    • 無(wú)需原始訓(xùn)練框架的依賴
    • 跨平臺(tái)支持(Windows、Linux、macOS)

ONNX模型的工作流程:

[PyTorch/TensorFlow訓(xùn)練模型] → [導(dǎo)出為ONNX格式] → [ONNX Runtime加載] → [各種語(yǔ)言調(diào)用]

二、使用ONNX模型的準(zhǔn)備工作

1. 獲取ONNX模型

方式一:從Hugging Face直接下載

# 許多流行模型都提供了ONNX版本
# 例如:BAAI/bge-small-zh-v1.5 的ONNX版本

方式二:從PyTorch/TensorFlow模型轉(zhuǎn)換

# Python代碼示例:將PyTorch模型轉(zhuǎn)換為ONNX
import torch
from transformers import AutoModel, AutoTokenizer
import onnxruntime as ort

model_name = "BAAI/bge-small-zh-v1.5"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 示例輸入
dummy_input = tokenizer("hello world", return_tensors="pt")

# 導(dǎo)出為ONNX
torch.onnx.export(
    model,
    tuple(dummy_input.values()),
    "bge-small-zh.onnx",
    input_names=['input_ids', 'attention_mask'],
    output_names=['last_hidden_state'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}
    },
    opset_version=14
)

2. 環(huán)境依賴

Java環(huán)境:

<!-- Maven依賴 -->
<dependencies>
    <dependency>
        <groupId>com.microsoft.onnxruntime</groupId>
        <artifactId>onnxruntime</artifactId>
        <version>1.17.0</version>
    </dependency>
    <!-- 如果需要GPU支持 -->
    <dependency>
        <groupId>com.microsoft.onnxruntime</groupId>
        <artifactId>onnxruntime_gpu</artifactId>
        <version>1.17.0</version>
    </dependency>
</dependencies>

C#環(huán)境:

<!-- 項(xiàng)目文件配置 -->
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.17.1" />
<!-- 如果需要GPU支持 -->
<PackageReference Include="Microsoft.ML.OnnxRuntime.Gpu" Version="1.17.1" />

3. 模型預(yù)處理知識(shí)

需要了解原始模型的:

  • 輸入格式:輸入張量的形狀、數(shù)據(jù)類型
  • 預(yù)處理要求:tokenization、normalization等
  • 輸出格式:如何從輸出中提取嵌入向量

三、Java與ONNX模型交互

完整示例:文本嵌入模型

import ai.onnxruntime.*;
import java.util.*;
import java.nio.*;

public class OnnxEmbeddingModel {
    private OrtEnvironment environment;
    private OrtSession session;
    private Map<String, Integer> vocab;
    
    public OnnxEmbeddingModel(String modelPath) throws OrtException {
        // 1. 初始化ONNX Runtime環(huán)境
        this.environment = OrtEnvironment.getEnvironment();
        
        // 2. 創(chuàng)建會(huì)話選項(xiàng)(可配置CPU/GPU)
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
        // sessionOptions.addCPU(true); // 使用CPU
        // sessionOptions.addCUDA(0);   // 使用GPU
        
        // 3. 加載ONNX模型
        this.session = environment.createSession(modelPath, sessionOptions);
        
        // 4. 加載詞匯表(需要根據(jù)具體模型準(zhǔn)備)
        this.vocab = loadVocabulary();
    }
    
    public float[] embedText(String text) throws OrtException {
        // 1. 文本預(yù)處理和tokenization
        int[] tokenIds = tokenizeText(text);
        
        // 2. 創(chuàng)建輸入張量
        long[] shape = {1, tokenIds.length}; // [batch_size, sequence_length]
        
        OnnxTensor inputTensor = OnnxTensor.createTensor(
            environment,
            IntBuffer.wrap(tokenIds),
            shape
        );
        
        // 3. 準(zhǔn)備輸入Map
        Map<String, OnnxTensor> inputs = new HashMap<>();
        inputs.put("input_ids", inputTensor);
        
        // 如果需要attention mask
        long[] attentionShape = {1, tokenIds.length};
        int[] attentionMask = new int[tokenIds.length];
        Arrays.fill(attentionMask, 1);
        
        OnnxTensor attentionTensor = OnnxTensor.createTensor(
            environment,
            IntBuffer.wrap(attentionMask),
            attentionShape
        );
        inputs.put("attention_mask", attentionTensor);
        
        // 4. 運(yùn)行推理
        try (OrtSession.Result results = session.run(inputs)) {
            // 5. 獲取輸出
            OnnxTensor outputTensor = (OnnxTensor) results.get(0);
            float[][] embeddings = (float[][]) outputTensor.getValue();
            
            // 6. 處理輸出(例如:取[CLS] token或平均池化)
            return poolEmbeddings(embeddings[0]);
        } finally {
            // 7. 清理資源
            inputTensor.close();
            attentionTensor.close();
        }
    }
    
    private int[] tokenizeText(String text) {
        // 簡(jiǎn)化的tokenization - 實(shí)際需要根據(jù)模型實(shí)現(xiàn)
        String[] tokens = text.toLowerCase().split(" ");
        int[] tokenIds = new int[tokens.length];
        
        for (int i = 0; i < tokens.length; i++) {
            tokenIds[i] = vocab.getOrDefault(tokens[i], 0); // 0 for UNK
        }
        
        // 截?cái)嗷蛱畛涞侥P推谕拈L(zhǎng)度
        int maxLength = 512;
        if (tokenIds.length > maxLength) {
            tokenIds = Arrays.copyOf(tokenIds, maxLength);
        }
        
        return tokenIds;
    }
    
    private float[] poolEmbeddings(float[] tokenEmbeddings) {
        // 簡(jiǎn)單的平均池化
        int embeddingDim = 384; // 根據(jù)模型輸出維度調(diào)整
        float[] sentenceEmbedding = new float[embeddingDim];
        
        // 這里應(yīng)該是2D數(shù)組的處理,簡(jiǎn)化示例
        // 實(shí)際需要根據(jù)模型輸出結(jié)構(gòu)處理
        
        return sentenceEmbedding;
    }
    
    private Map<String, Integer> loadVocabulary() {
        // 加載詞匯表文件
        Map<String, Integer> vocabMap = new HashMap<>();
        // 這里應(yīng)該從文件加載真實(shí)的詞匯表
        vocabMap.put("hello", 1);
        vocabMap.put("world", 2);
        // ... 加載完整的詞匯表
        return vocabMap;
    }
    
    public void close() throws OrtException {
        if (session != null) {
            session.close();
        }
        if (environment != null) {
            environment.close();
        }
    }
    
    // 使用示例
    public static void main(String[] args) {
        try {
            OnnxEmbeddingModel model = new OnnxEmbeddingModel("path/to/embedding_model.onnx");
            
            float[] embedding = model.embedText("這是一個(gè)測(cè)試句子");
            System.out.println("嵌入向量維度: " + embedding.length);
            System.out.println("前5個(gè)值: " + 
                Arrays.toString(Arrays.copyOf(embedding, 5)));
            
            model.close();
        } catch (OrtException e) {
            e.printStackTrace();
        }
    }
}

四、C#與ONNX模型交互

完整示例:文本嵌入模型

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

public class OnnxEmbeddingModel : IDisposable
{
    private InferenceSession _session;
    private Dictionary<string, int> _vocab;
    
    public OnnxEmbeddingModel(string modelPath)
    {
        // 1. 創(chuàng)建會(huì)話選項(xiàng)
        var sessionOptions = new SessionOptions();
        
        // 配置執(zhí)行提供商
        sessionOptions.AppendExecutionProvider_CPU(); // 使用CPU
        // sessionOptions.AppendExecutionProvider_CUDA(0); // 使用GPU
        // sessionOptions.AppendExecutionProvider_DML(0); // 使用DirectML (Windows)
        
        // 2. 加載ONNX模型
        _session = new InferenceSession(modelPath, sessionOptions);
        
        // 3. 加載詞匯表
        _vocab = LoadVocabulary();
    }
    
    public float[] EmbedText(string text)
    {
        // 1. 文本預(yù)處理和tokenization
        var tokenIds = TokenizeText(text);
        var attentionMask = Enumerable.Repeat(1, tokenIds.Length).ToArray();
        
        // 2. 創(chuàng)建輸入張量
        var inputIdsTensor = new DenseTensor<int>(tokenIds, new[] { 1, tokenIds.Length });
        var attentionMaskTensor = new DenseTensor<int>(attentionMask, new[] { 1, attentionMask.Length });
        
        // 3. 準(zhǔn)備輸入
        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("input_ids", inputIdsTensor),
            NamedOnnxValue.CreateFromTensor("attention_mask", attentionMaskTensor)
        };
        
        // 4. 運(yùn)行推理
        using (var results = _session.Run(inputs))
        {
            // 5. 獲取輸出
            var outputTensor = results.First().AsTensor<float>();
            var embeddings = outputTensor.ToArray();
            
            // 6. 處理輸出(平均池化)
            return PoolEmbeddings(embeddings, tokenIds.Length);
        }
    }
    
    private int[] TokenizeText(string text)
    {
        // 簡(jiǎn)化的tokenization - 實(shí)際需要根據(jù)具體模型實(shí)現(xiàn)
        var tokens = text.ToLower().Split(' ', StringSplitOptions.RemoveEmptyEntries);
        var tokenIds = new List<int>();
        
        foreach (var token in tokens)
        {
            if (_vocab.TryGetValue(token, out int tokenId))
            {
                tokenIds.Add(tokenId);
            }
            else
            {
                tokenIds.Add(0); // UNK token
            }
        }
        
        // 截?cái)嗷蛱畛?        int maxLength = 512;
        if (tokenIds.Count > maxLength)
        {
            tokenIds = tokenIds.Take(maxLength).ToList();
        }
        
        return tokenIds.ToArray();
    }
    
    private float[] PoolEmbeddings(float[] tokenEmbeddings, int sequenceLength)
    {
        int embeddingDim = tokenEmbeddings.Length / sequenceLength;
        var sentenceEmbedding = new float[embeddingDim];
        
        // 平均池化:對(duì)每個(gè)token的對(duì)應(yīng)維度求平均
        for (int dim = 0; dim < embeddingDim; dim++)
        {
            float sum = 0;
            for (int tokenIdx = 0; tokenIdx < sequenceLength; tokenIdx++)
            {
                sum += tokenEmbeddings[tokenIdx * embeddingDim + dim];
            }
            sentenceEmbedding[dim] = sum / sequenceLength;
        }
        
        return sentenceEmbedding;
    }
    
    private Dictionary<string, int> LoadVocabulary()
    {
        // 從文件加載詞匯表
        var vocab = new Dictionary<string, int>
        {
            ["hello"] = 1,
            ["world"] = 2,
            // ... 加載完整的詞匯表
        };
        
        return vocab;
    }
    
    public void Dispose()
    {
        _session?.Dispose();
    }
}

// 使用示例
class Program
{
    static void Main()
    {
        using (var model = new OnnxEmbeddingModel("path/to/embedding_model.onnx"))
        {
            var embedding = model.EmbedText("這是一個(gè)測(cè)試句子");
            Console.WriteLine($"嵌入向量維度: {embedding.Length}");
            Console.WriteLine($"前5個(gè)值: [{string.Join(", ", embedding.Take(5))}]");
        }
    }
}

五、實(shí)際集成到RAGflow的示例

Java + LangChain4j + ONNX

public class OnnxEmbeddingService implements EmbeddingModel {
    private final OnnxEmbeddingModel onnxModel;
    
    public OnnxEmbeddingService(String modelPath) throws OrtException {
        this.onnxModel = new OnnxEmbeddingModel(modelPath);
    }
    
    @Override
    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
        try {
            List<Embedding> embeddings = new ArrayList<>();
            
            for (TextSegment segment : textSegments) {
                float[] vector = onnxModel.embedText(segment.text());
                embeddings.add(Embedding.from(vector));
            }
            
            return Response.from(embeddings);
        } catch (OrtException e) {
            throw new RuntimeException("ONNX推理失敗", e);
        }
    }
}

// 在RAGflow中使用
EmbeddingModel embeddingModel = new OnnxEmbeddingService("bge-small-zh.onnx");
EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();

// 存儲(chǔ)文檔
for (TextSegment segment : segments) {
    Embedding embedding = embeddingModel.embed(segment.text()).content();
    embeddingStore.add(embedding, segment);
}

六、最佳實(shí)踐和注意事項(xiàng)

1. 性能優(yōu)化

  • 批處理:一次性處理多個(gè)文本,減少推理調(diào)用次數(shù)
  • 內(nèi)存管理:及時(shí)釋放ONNX tensor資源
  • 會(huì)話復(fù)用:避免重復(fù)創(chuàng)建InferenceSession

2. 錯(cuò)誤處理

  • 檢查模型輸入輸出維度
  • 處理詞匯表外的token
  • 監(jiān)控內(nèi)存使用情況

3. 模型選擇

  • 選擇適合你硬件配置的模型大小
  • 考慮量化版本(INT8)以獲得更好性能
  • 測(cè)試不同池化策略的效果

通過(guò)ONNX,你可以在Java和C#環(huán)境中獲得接近Python原生的模型性能,同時(shí)享受這些語(yǔ)言在工程化方面的優(yōu)勢(shì)。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容