#encoding=gbk
import tensorrt as trt
import numpy as np
import os
import cv2
import pycuda.driver as cuda
import pycuda.autoinit
from imutils import paths
from tqdm import tqdm
class HostDeviceMem(object):
def __init__(self, host_mem, device_mem):
self.host = host_mem
self.device = device_mem
def __str__(self):
return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
def __repr__(self):
return self.__str__()
class TrtModel:
def __init__(self, engine_path, max_batch_size=1, dtype=np.float32):
self.engine_path = engine_path
self.dtype = dtype
self.logger = trt.Logger(trt.Logger.WARNING)
self.runtime = trt.Runtime(self.logger)
self.engine = self.load_engine(self.runtime, self.engine_path)
self.max_batch_size = max_batch_size
self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers()
self.context = self.engine.create_execution_context()
@staticmethod
def load_engine(trt_runtime, engine_path):
trt.init_libnvinfer_plugins(None, "")
with open(engine_path, 'rb') as f:
engine_data = f.read()
engine = trt_runtime.deserialize_cuda_engine(engine_data)
return engine
def allocate_buffers(self):
inputs = []
outputs = []
bindings = []
stream = cuda.Stream()
for binding in self.engine:
# size = trt.volume(self.engine.get_binding_shape(binding)) * self.max_batch_size
#*******
ssize = self.engine.get_binding_shape(binding)
ssize[0]=self.max_batch_size
size=trt.volume(ssize)
#*******
host_mem = cuda.pagelocked_empty(size, self.dtype)
device_mem = cuda.mem_alloc(host_mem.nbytes)
bindings.append(int(device_mem))
if self.engine.binding_is_input(binding):
inputs.append(HostDeviceMem(host_mem, device_mem))
else:
outputs.append(HostDeviceMem(host_mem, device_mem))
return inputs, outputs, bindings, stream
def __call__(self, x: np.ndarray, batch_size=2):
x = x.astype(self.dtype)
np.copyto(self.inputs[0].host, x.ravel())
for inp in self.inputs:
cuda.memcpy_htod_async(inp.device, inp.host, self.stream)
#**********
origin_inputshape=self.engine.get_binding_shape(0)
origin_inputshape[0]=batch_size
self.context.set_binding_shape(0,(origin_inputshape))
#**********
self.context.execute_async(batch_size=batch_size, bindings=self.bindings, stream_handle=self.stream.handle)
for out in self.outputs:
cuda.memcpy_dtoh_async(out.host, out.device, self.stream)
self.stream.synchronize()
return [out.host.reshape(batch_size, -1) for out in self.outputs]
if __name__ == "__main__":
# 驗(yàn)證模式:fp32,fp16,int8
val_type='fp16'
#---------------------------------
path=r'./imgs/'
trt_engine_path = r'./model/{}.engine'.format(val_type)
out_path=r'./out/{}'.format(val_type)
if not os.path.exists(out_path):
os.makedirs(out_path)
#均值和方差
mean = (120, 114, 104)
std = (70, 69, 73)
#構(gòu)建模型
model = TrtModel(trt_engine_path)
pic_paths = list(paths.list_images(path))
for pic_path in tqdm(pic_paths):
name=os.path.basename(pic_path).split('.')[0]
# 輸入圖像預(yù)處理
img = cv2.imread(pic_path)
imgbak = img.copy()
img = img[:, :, ::-1]
img = np.array(img).astype(np.float32) # 注意輸入type一定要np.float32
img -= mean # 減均值
img /= std # 除方差
img = np.array([np.transpose(img, (2, 0, 1))])
#模型推理
result = model(img, 1)
# 保存圖像
img_out=np.reshape(result[0][0],(512,512))
img_out =img_out.astype('uint8')
# img_out=img_out*25
img_out[img_out>0]=255
cv2.imwrite(os.path.join(out_path,'{}_{}.png'.format(val_type,name)),img_out)
TensorRT Python驗(yàn)證代碼---分割類
最后編輯于 :
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。
相關(guān)閱讀更多精彩內(nèi)容
- 實(shí)現(xiàn)功能: python實(shí)現(xiàn)KNN建模,選擇最佳K值,對(duì)數(shù)據(jù)樣本進(jìn)行分類預(yù)測(cè),并驗(yàn)證評(píng)估。 實(shí)現(xiàn)代碼: # 導(dǎo)入需...
- 目前網(wǎng)上關(guān)于滑塊的缺口識(shí)別的方法很多,但是都不極簡(jiǎn),看起來繁雜,各種算法的都有,有遍歷的有二分法的,今天寫個(gè)最簡(jiǎn)單...
- Python中的random模塊 Python中的random模塊用于生成隨機(jī)數(shù)。 random.random r...
- 現(xiàn)在驗(yàn)證碼的種類真的是越來越多,短信驗(yàn)證碼、語音驗(yàn)證碼、圖片驗(yàn)證碼、滑塊驗(yàn)證碼 ... 我們?cè)?PC 的網(wǎng)頁端或者...
- 現(xiàn)在驗(yàn)證碼的種類真的是越來越多,短信驗(yàn)證碼、語音驗(yàn)證碼、圖片驗(yàn)證碼、滑塊驗(yàn)證碼 ... 我們?cè)?PC 的網(wǎng)頁端或者...