yolov5预测 (YOLOv5 prediction)

import sys

import torch

# Request FP16 inference; it is only applied when running on CUDA (see below).
half = False

# Make the YOLOv5 repository importable. This must happen before torch.load:
# the checkpoint's 'model' entry is a module object (it is moved/evaluated
# below), so unpickling it presumably needs the repo's model classes, and
# inference_detector imports from the repo's `utils` package.
sys.path.insert(0, '/kaggle/input/yolov5/yolov5/')

# Fall back to CPU instead of crashing when no GPU is present.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Many ops lack Half support on CPU in older PyTorch releases, so keep FP16 CUDA-only.
half = half and device.type == 'cuda'

# NOTE(security): torch.load unpickles arbitrary Python objects; only load
# checkpoints from a trusted source.
model = (
    torch.load('/kaggle/input/wheat-submit/last_wheat.pt', map_location=device)['model']
    .to(device)
    .float()
    .eval()
)
if half:
    model.half()
def inference_detector(model, img_path, conf_thres=0.1, iou_thres=0.5, img_size=640):
    """Run the YOLOv5 model on a single image and return its detections.

    Relies on the module-level ``device`` and ``half`` settings and on the
    YOLOv5 repository being on ``sys.path`` (for ``utils.datasets`` and
    ``utils.utils``).

    Args:
        model: loaded YOLOv5 model, already moved to ``device``.
        img_path: path of the image, handed to ``LoadImages``.
        conf_thres: confidence threshold passed to non-maximum suppression.
        iou_thres: IoU threshold passed to non-maximum suppression.
        img_size: inference size passed to ``LoadImages``.

    Returns:
        A float64 ``np.ndarray`` of shape ``(N, 5)``. Each row is
        ``[x1, y1, x2, y2, score]``, with the box corners rescaled to the
        original image and rounded. ``N`` is 0 when nothing is detected, so
        callers can always slice columns.
    """
    # Local imports: numpy is imported at module level only after this
    # function is defined, and the `utils` package comes from the YOLOv5 repo.
    import numpy as np
    from utils.datasets import LoadImages
    from utils.utils import non_max_suppression, scale_coords

    dataset = LoadImages(img_path, img_size=img_size)
    # Only the network-input image and the original image (im0) are needed.
    _, img, im0, _ = next(iter(dataset))
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)  # add a batch dimension

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        pred = model(img, augment=False)[0]
        pred = non_max_suppression(pred, conf_thres=conf_thres, iou_thres=iou_thres,
                                   classes=None, agnostic=True)

    rows = []
    for det in pred:  # detections per image
        if det is not None and len(det):
            # Map boxes from the network input size back to the original image.
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
            # Columns 0-3 are the box corners and column 4 the confidence
            # (column 5, the class, is dropped). Converting to float64 keeps the
            # values identical to the former per-element .item() conversion.
            rows.append(det[:, :5].detach().cpu().double().numpy())
    if not rows:
        return np.zeros((0, 5))
    return np.concatenate(rows, axis=0)
# test
import numpy as np
import cv2
def vis(image_path, det, score_thr=0.1):
    """Draw detections on an image and display it with matplotlib.

    Args:
        image_path: path of the image to draw on (read by OpenCV in BGR order).
        det: array whose rows start with ``[x1, y1, x2, y2, score]``, e.g. the
            output of ``inference_detector``. May be empty.
        score_thr: only detections with a score strictly greater than this
            value are drawn.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    import matplotlib.pyplot as plt

    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:  # cv2.imread reports failure by returning None, not raising
        raise FileNotFoundError(f'could not read image: {image_path}')

    color = (255, 0, 0)  # blue in BGR
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    box_thickness = 1
    text_thickness = 2  # line thickness of the score label, in px

    det = np.asarray(det)
    if det.size == 0:
        # Normalise any empty input (e.g. shape (0,)) so the column slices work.
        det = np.zeros((0, 5))
    # Plain Python ints: some OpenCV builds reject NumPy scalars as point coordinates.
    boxes = det[:, :4].astype(np.int32).tolist()
    scores = det[:, 4]

    for (x1, y1, x2, y2), score in zip(boxes, scores):
        if score > score_thr:
            image = cv2.rectangle(image, (x1, y1), (x2, y2), color, box_thickness)
            # Random 0-19 px horizontal offset, presumably so labels of
            # neighbouring boxes are less likely to overlap exactly.
            label_org = (x1 + np.random.randint(20), y1)
            image = cv2.putText(image, '{:.2}'.format(score), label_org, font,
                                font_scale, color, text_thickness, cv2.LINE_AA)

    plt.figure(figsize=[6, 6])
    plt.imshow(image[:, :, ::-1])  # BGR -> RGB for matplotlib
    plt.show()
import glob

# Run the detector on one test image and display the result.
# Sort so the chosen image does not depend on filesystem listing order.
img_paths = sorted(glob.glob('/kaggle/input/global-wheat-detection/test/*.jpg'))
if not img_paths:
    raise FileNotFoundError(
        'no images matched /kaggle/input/global-wheat-detection/test/*.jpg'
    )
img_path = img_paths[0]
det = inference_detector(model, img_path)
vis(img_path, det)
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容