好的,我來(lái)詳細(xì)解釋ONNX模型以及如何在Java和C#中與它們交互。這是在不依賴Python的情況下使用開(kāi)源嵌入模型的關(guān)鍵技術(shù)。
一、ONNX模型詳解
什么是ONNX?
ONNX 代表 Open Neural Network Exchange,是一個(gè)開(kāi)放的格式標(biāo)準(zhǔn),用于表示深度學(xué)習(xí)模型。它讓模型可以在不同的框架之間移植。
ONNX的核心價(jià)值:
-
框架互操作性:
- 你可以在PyTorch、TensorFlow、Scikit-learn中訓(xùn)練模型
- 然后導(dǎo)出為ONNX格式
- 最后在C#、Java、C++等環(huán)境中運(yùn)行
-
性能優(yōu)化:
- ONNX Runtime是針對(duì)ONNX模型的高度優(yōu)化推理引擎
- 支持CPU、GPU(CUDA、DirectML)、TensorRT等后端
- 提供線程安全、內(nèi)存高效的計(jì)算
-
部署友好:
- 單個(gè)
.onnx文件包含完整的模型架構(gòu)和權(quán)重 - 無(wú)需原始訓(xùn)練框架的依賴
- 跨平臺(tái)支持(Windows、Linux、macOS)
- 單個(gè)
ONNX模型的工作流程:
[PyTorch/TensorFlow訓(xùn)練模型] → [導(dǎo)出為ONNX格式] → [ONNX Runtime加載] → [各種語(yǔ)言調(diào)用]
二、使用ONNX模型的準(zhǔn)備工作
1. 獲取ONNX模型
方式一:從Hugging Face直接下載
# 許多流行模型都提供了ONNX版本
# 例如:BAAI/bge-small-zh-v1.5 的ONNX版本
方式二:從PyTorch/TensorFlow模型轉(zhuǎn)換
# Python代碼示例:將PyTorch模型轉(zhuǎn)換為ONNX
import torch
from transformers import AutoModel, AutoTokenizer
import onnxruntime as ort
model_name = "BAAI/bge-small-zh-v1.5"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 示例輸入
dummy_input = tokenizer("hello world", return_tensors="pt")
# 導(dǎo)出為ONNX
torch.onnx.export(
model,
tuple(dummy_input.values()),
"bge-small-zh.onnx",
input_names=['input_ids', 'attention_mask'],
output_names=['last_hidden_state'],
dynamic_axes={
'input_ids': {0: 'batch_size', 1: 'sequence_length'},
'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}
},
opset_version=14
)
2. 環(huán)境依賴
Java環(huán)境:
<!-- Maven依賴 -->
<dependencies>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.17.0</version>
</dependency>
<!-- 如果需要GPU支持 -->
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime_gpu</artifactId>
<version>1.17.0</version>
</dependency>
</dependencies>
C#環(huán)境:
<!-- 項(xiàng)目文件配置 -->
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.17.1" />
<!-- 如果需要GPU支持 -->
<PackageReference Include="Microsoft.ML.OnnxRuntime.Gpu" Version="1.17.1" />
3. 模型預(yù)處理知識(shí)
需要了解原始模型的:
- 輸入格式:輸入張量的形狀、數(shù)據(jù)類型
- 預(yù)處理要求:tokenization、normalization等
- 輸出格式:如何從輸出中提取嵌入向量
三、Java與ONNX模型交互
完整示例:文本嵌入模型
import ai.onnxruntime.*;
import java.util.*;
import java.nio.*;
public class OnnxEmbeddingModel {
private OrtEnvironment environment;
private OrtSession session;
private Map<String, Integer> vocab;
public OnnxEmbeddingModel(String modelPath) throws OrtException {
// 1. 初始化ONNX Runtime環(huán)境
this.environment = OrtEnvironment.getEnvironment();
// 2. 創(chuàng)建會(huì)話選項(xiàng)(可配置CPU/GPU)
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
// sessionOptions.addCPU(true); // 使用CPU
// sessionOptions.addCUDA(0); // 使用GPU
// 3. 加載ONNX模型
this.session = environment.createSession(modelPath, sessionOptions);
// 4. 加載詞匯表(需要根據(jù)具體模型準(zhǔn)備)
this.vocab = loadVocabulary();
}
public float[] embedText(String text) throws OrtException {
// 1. 文本預(yù)處理和tokenization
int[] tokenIds = tokenizeText(text);
// 2. 創(chuàng)建輸入張量
long[] shape = {1, tokenIds.length}; // [batch_size, sequence_length]
OnnxTensor inputTensor = OnnxTensor.createTensor(
environment,
IntBuffer.wrap(tokenIds),
shape
);
// 3. 準(zhǔn)備輸入Map
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input_ids", inputTensor);
// 如果需要attention mask
long[] attentionShape = {1, tokenIds.length};
int[] attentionMask = new int[tokenIds.length];
Arrays.fill(attentionMask, 1);
OnnxTensor attentionTensor = OnnxTensor.createTensor(
environment,
IntBuffer.wrap(attentionMask),
attentionShape
);
inputs.put("attention_mask", attentionTensor);
// 4. 運(yùn)行推理
try (OrtSession.Result results = session.run(inputs)) {
// 5. 獲取輸出
OnnxTensor outputTensor = (OnnxTensor) results.get(0);
float[][] embeddings = (float[][]) outputTensor.getValue();
// 6. 處理輸出(例如:取[CLS] token或平均池化)
return poolEmbeddings(embeddings[0]);
} finally {
// 7. 清理資源
inputTensor.close();
attentionTensor.close();
}
}
private int[] tokenizeText(String text) {
// 簡(jiǎn)化的tokenization - 實(shí)際需要根據(jù)模型實(shí)現(xiàn)
String[] tokens = text.toLowerCase().split(" ");
int[] tokenIds = new int[tokens.length];
for (int i = 0; i < tokens.length; i++) {
tokenIds[i] = vocab.getOrDefault(tokens[i], 0); // 0 for UNK
}
// 截?cái)嗷蛱畛涞侥P推谕拈L(zhǎng)度
int maxLength = 512;
if (tokenIds.length > maxLength) {
tokenIds = Arrays.copyOf(tokenIds, maxLength);
}
return tokenIds;
}
private float[] poolEmbeddings(float[] tokenEmbeddings) {
// 簡(jiǎn)單的平均池化
int embeddingDim = 384; // 根據(jù)模型輸出維度調(diào)整
float[] sentenceEmbedding = new float[embeddingDim];
// 這里應(yīng)該是2D數(shù)組的處理,簡(jiǎn)化示例
// 實(shí)際需要根據(jù)模型輸出結(jié)構(gòu)處理
return sentenceEmbedding;
}
private Map<String, Integer> loadVocabulary() {
// 加載詞匯表文件
Map<String, Integer> vocabMap = new HashMap<>();
// 這里應(yīng)該從文件加載真實(shí)的詞匯表
vocabMap.put("hello", 1);
vocabMap.put("world", 2);
// ... 加載完整的詞匯表
return vocabMap;
}
public void close() throws OrtException {
if (session != null) {
session.close();
}
if (environment != null) {
environment.close();
}
}
// 使用示例
public static void main(String[] args) {
try {
OnnxEmbeddingModel model = new OnnxEmbeddingModel("path/to/embedding_model.onnx");
float[] embedding = model.embedText("這是一個(gè)測(cè)試句子");
System.out.println("嵌入向量維度: " + embedding.length);
System.out.println("前5個(gè)值: " +
Arrays.toString(Arrays.copyOf(embedding, 5)));
model.close();
} catch (OrtException e) {
e.printStackTrace();
}
}
}
四、C#與ONNX模型交互
完整示例:文本嵌入模型
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
public class OnnxEmbeddingModel : IDisposable
{
private InferenceSession _session;
private Dictionary<string, int> _vocab;
public OnnxEmbeddingModel(string modelPath)
{
// 1. 創(chuàng)建會(huì)話選項(xiàng)
var sessionOptions = new SessionOptions();
// 配置執(zhí)行提供商
sessionOptions.AppendExecutionProvider_CPU(); // 使用CPU
// sessionOptions.AppendExecutionProvider_CUDA(0); // 使用GPU
// sessionOptions.AppendExecutionProvider_DML(0); // 使用DirectML (Windows)
// 2. 加載ONNX模型
_session = new InferenceSession(modelPath, sessionOptions);
// 3. 加載詞匯表
_vocab = LoadVocabulary();
}
public float[] EmbedText(string text)
{
// 1. 文本預(yù)處理和tokenization
var tokenIds = TokenizeText(text);
var attentionMask = Enumerable.Repeat(1, tokenIds.Length).ToArray();
// 2. 創(chuàng)建輸入張量
var inputIdsTensor = new DenseTensor<int>(tokenIds, new[] { 1, tokenIds.Length });
var attentionMaskTensor = new DenseTensor<int>(attentionMask, new[] { 1, attentionMask.Length });
// 3. 準(zhǔn)備輸入
var inputs = new List<NamedOnnxValue>
{
NamedOnnxValue.CreateFromTensor("input_ids", inputIdsTensor),
NamedOnnxValue.CreateFromTensor("attention_mask", attentionMaskTensor)
};
// 4. 運(yùn)行推理
using (var results = _session.Run(inputs))
{
// 5. 獲取輸出
var outputTensor = results.First().AsTensor<float>();
var embeddings = outputTensor.ToArray();
// 6. 處理輸出(平均池化)
return PoolEmbeddings(embeddings, tokenIds.Length);
}
}
private int[] TokenizeText(string text)
{
// 簡(jiǎn)化的tokenization - 實(shí)際需要根據(jù)具體模型實(shí)現(xiàn)
var tokens = text.ToLower().Split(' ', StringSplitOptions.RemoveEmptyEntries);
var tokenIds = new List<int>();
foreach (var token in tokens)
{
if (_vocab.TryGetValue(token, out int tokenId))
{
tokenIds.Add(tokenId);
}
else
{
tokenIds.Add(0); // UNK token
}
}
// 截?cái)嗷蛱畛? int maxLength = 512;
if (tokenIds.Count > maxLength)
{
tokenIds = tokenIds.Take(maxLength).ToList();
}
return tokenIds.ToArray();
}
private float[] PoolEmbeddings(float[] tokenEmbeddings, int sequenceLength)
{
int embeddingDim = tokenEmbeddings.Length / sequenceLength;
var sentenceEmbedding = new float[embeddingDim];
// 平均池化:對(duì)每個(gè)token的對(duì)應(yīng)維度求平均
for (int dim = 0; dim < embeddingDim; dim++)
{
float sum = 0;
for (int tokenIdx = 0; tokenIdx < sequenceLength; tokenIdx++)
{
sum += tokenEmbeddings[tokenIdx * embeddingDim + dim];
}
sentenceEmbedding[dim] = sum / sequenceLength;
}
return sentenceEmbedding;
}
private Dictionary<string, int> LoadVocabulary()
{
// 從文件加載詞匯表
var vocab = new Dictionary<string, int>
{
["hello"] = 1,
["world"] = 2,
// ... 加載完整的詞匯表
};
return vocab;
}
public void Dispose()
{
_session?.Dispose();
}
}
// 使用示例
class Program
{
static void Main()
{
using (var model = new OnnxEmbeddingModel("path/to/embedding_model.onnx"))
{
var embedding = model.EmbedText("這是一個(gè)測(cè)試句子");
Console.WriteLine($"嵌入向量維度: {embedding.Length}");
Console.WriteLine($"前5個(gè)值: [{string.Join(", ", embedding.Take(5))}]");
}
}
}
五、實(shí)際集成到RAGflow的示例
Java + LangChain4j + ONNX
public class OnnxEmbeddingService implements EmbeddingModel {
private final OnnxEmbeddingModel onnxModel;
public OnnxEmbeddingService(String modelPath) throws OrtException {
this.onnxModel = new OnnxEmbeddingModel(modelPath);
}
@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
try {
List<Embedding> embeddings = new ArrayList<>();
for (TextSegment segment : textSegments) {
float[] vector = onnxModel.embedText(segment.text());
embeddings.add(Embedding.from(vector));
}
return Response.from(embeddings);
} catch (OrtException e) {
throw new RuntimeException("ONNX推理失敗", e);
}
}
}
// 在RAGflow中使用
EmbeddingModel embeddingModel = new OnnxEmbeddingService("bge-small-zh.onnx");
EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
// 存儲(chǔ)文檔
for (TextSegment segment : segments) {
Embedding embedding = embeddingModel.embed(segment.text()).content();
embeddingStore.add(embedding, segment);
}
六、最佳實(shí)踐和注意事項(xiàng)
1. 性能優(yōu)化
- 批處理:一次性處理多個(gè)文本,減少推理調(diào)用次數(shù)
- 內(nèi)存管理:及時(shí)釋放ONNX tensor資源
- 會(huì)話復(fù)用:避免重復(fù)創(chuàng)建InferenceSession
2. 錯(cuò)誤處理
- 檢查模型輸入輸出維度
- 處理詞匯表外的token
- 監(jiān)控內(nèi)存使用情況
3. 模型選擇
- 選擇適合你硬件配置的模型大小
- 考慮量化版本(INT8)以獲得更好性能
- 測(cè)試不同池化策略的效果
通過(guò)ONNX,你可以在Java和C#環(huán)境中獲得接近Python原生的模型性能,同時(shí)享受這些語(yǔ)言在工程化方面的優(yōu)勢(shì)。