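PyTorch to ONNX Export: A Guide to the Pitfalls

Compared with ONNX models, PyTorch models tend to be loosely structured, and the PyTorch API imposes fewer constraints. As a result, export failures are almost unavoidable. It is also likely that this part of the API will change in the near future.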

ONNX export

The basic export call is fairly simple. The example from the official documentation is:

import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
model = torchvision.models.alexnet(pretrained=True).cuda()

# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

If only it were really that easy.

ONNX export validation script

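import onnxruntime
import numpy as np

sess = onnxruntime.InferenceSession('./model.onnx', None)

# Example input size (an assumption); replace with the size your model expects.
img_height, img_width = 224, 224

# Image classification example: a batch size of 2 checks that the exported model supports batching.
sess.run(None, {'input_1': np.random.rand(2, 3, img_height, img_width).astype('float32')})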

Making the exported model process several inputs at once (batching)

Supporting batching requires specifying dynamic axes, i.e. the dimensions whose size may vary.

Example:

torch.onnx.export(...,
  input_names=['input_1'],
  output_names=['output_1'],
  dynamic_axes={
    'input_1': [0],  # dimension 0 is the batch dimension
    'output_1': [0],
  })
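
As a minimal sketch of the dynamic_axes argument described above, the helper below builds the dictionary form, which maps each tensor name to {axis index: axis label}. The helper name make_batch_axes and the label "batch" are illustrative assumptions, not part of the torch.onnx API.

```python
def make_batch_axes(input_names, output_names, axis_name="batch"):
    """Mark dimension 0 of every named tensor as dynamic.

    Hypothetical helper: returns a dict in the {name: {axis: label}} form
    that torch.onnx.export accepts for its dynamic_axes argument.
    """
    names = list(input_names) + list(output_names)
    return {name: {0: axis_name} for name in names}


dynamic_axes = make_batch_axes(["input_1"], ["output_1"])
print(dynamic_axes)
```

The result can be passed as dynamic_axes=dynamic_axes together with the same input_names and output_names. The list form ([0]) is also accepted, in which case the exporter generates the axis names itself.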

Fixing Caffe2 runtime errors

When keep_initializers_as_inputs is False, Caffe2 fails with: IndexError: _Map_base::at. See https://github.com/onnx/onnx/issues/2458

Opset 11 models do not use the GPU when run in onnxruntime

The cause is fairly involved, and TensorFlow appears to have a similar problem. Adding do_constant_folding=True at export time may resolve it.
See https://github.com/NVIDIA/triton-inference-server/issues/1080

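Exporting a list of tensors

Fixed-length lists

A fixed-length list is exported as a tuple.

Variable-length lists

PyTorch 1.4 with ONNX opset 9 does not support exporting variable-length lists. Later PyTorch versions do support it, but they require a higher ONNX version.

Unsupported operations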

  • Tensor in-place indexed assignment like data[index] = new_data is currently not supported in exporting. One way to resolve this kind of issue is to use operator scatter, explicitly updating the original tensor.

  • There is no concept of tensor list in ONNX. Without this concept, it is very hard to export operators that consume or produce tensor list, especially when the length of the tensor list is not known at export time.

  • Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. Users need to verify their dict inputs carefully, and keep in mind that dynamic lookups are not available.

  • PyTorch and ONNX backends(Caffe2, ONNX Runtime, etc) often have implementations of operators with some numeric differences. Depending on model structure, these differences may be negligible, but they can also cause major divergences in behavior (especially on untrained models.) We allow Caffe2 to call directly to Torch implementations of operators, to help you smooth over these differences when precision is important, and to also document these differences.

Inconsistent operators

Expand

In PyTorch, a dim that Expand leaves unchanged can be given as -1. When exporting to ONNX, the size of every dim has to be written out explicitly. For example:

PyTorch:
a = a.expand(10, -1, -1)
ONNX:
a = a.expand(10, a.size(1), a.size(2))
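
To illustrate the rewrite above, here is a small sketch that replaces every -1 in a list of expand sizes with the concrete size taken from the tensor's shape. The function name resolve_expand_sizes is hypothetical, and the helper works on plain tuples of integers rather than on tensors.

```python
def resolve_expand_sizes(shape, sizes):
    """Replace each -1 in `sizes` with the matching size from `shape`.

    Hypothetical helper. `shape` is the current tensor shape and `sizes`
    holds the arguments that would be passed to Tensor.expand. Dimensions
    are aligned from the right, so `sizes` may add new leading dimensions.
    """
    if len(sizes) < len(shape):
        raise ValueError("expand cannot reduce the number of dimensions")
    offset = len(sizes) - len(shape)
    resolved = []
    for i, size in enumerate(sizes):
        if size != -1:
            resolved.append(size)
        elif i < offset:
            raise ValueError("-1 is not allowed for a new leading dimension")
        else:
            resolved.append(shape[i - offset])
    return resolved


print(resolve_expand_sizes((1, 3, 4), (10, -1, -1)))  # [10, 3, 4]
```

This is the same substitution as writing a.size(1) and a.size(2) by hand, just collected in one place.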

Squeeze

In PyTorch, squeezing a dim whose size is not 1 has no effect. ONNX raises an error instead.
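
One way to catch this difference before export is to check the shape explicitly. The sketch below uses a hypothetical helper, checked_squeeze_shape, that works on plain shape tuples and raises where PyTorch would silently do nothing. It only illustrates the check and is not part of either library.

```python
def checked_squeeze_shape(shape, dim):
    """Return `shape` with dimension `dim` removed, or raise if its size is not 1.

    Hypothetical helper that mirrors the stricter behaviour described above,
    so a no-op squeeze is caught before export rather than at runtime.
    """
    if not -len(shape) <= dim < len(shape):
        raise IndexError(f"dim {dim} is out of range for shape {tuple(shape)}")
    dim %= len(shape)
    if shape[dim] != 1:
        raise ValueError(
            f"cannot squeeze dim {dim} of size {shape[dim]}; "
            "PyTorch would silently return the tensor unchanged"
        )
    return tuple(shape[:dim]) + tuple(shape[dim + 1:])


print(checked_squeeze_shape((2, 1, 5), 1))  # (2, 5)
```

In model code, the equivalent is to assert that the dim has size 1 before calling squeeze on it.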
