天天看點

pytorch argmax_輕松學Pytorch-使用ResNet50實作圖像分類

pytorch argmax_輕松學Pytorch-使用ResNet50實作圖像分類

來源:磐創AI

本文約1161字,建議閱讀4分鐘。

本文介紹pytorch中最重要的元件torchvision,它包含了常見的資料集、模型架構與預訓練模型權重檔案、常見圖像變換、計算機視覺任務訓練。

Hello大家好,這篇文章給大家詳細介紹一下pytorch中最重要的元件torchvision,它包含了常見的資料集、模型架構與預訓練模型權重檔案、常見圖像變換、計算機視覺任務訓練。可以是說是pytorch中非常有用的模型遷移學習神器。本文将會介紹如何使用torchvison的預訓練模型ResNet50實作圖像分類。

模型

Torchvision.models包裡面包含了常見的各種基礎模型架構,主要包括:

AlexNetVGGResNetSqueezeNetDenseNetInception v3GoogLeNetShuffleNet v2MobileNet v2ResNeXtWide ResNetMNASNet

這裡我選擇了ResNet50,基于ImageNet訓練的基礎網絡來實作圖像分類, 網絡模型下載下傳與加載如下:

model = torchvision.models.resnet50(pretrained=True).eval().cuda()tf = transforms.Compose([            transforms.Resize(256),            transforms.CenterCrop(224),            transforms.ToTensor(),            transforms.Normalize(            mean=[0.485, 0.456, 0.406],            std=[0.229, 0.224, 0.225]        )]
           

使用模型實作圖像分類

這裡首先需要加載ImageNet的分類标簽,目的是最後顯示分類的文本标簽時候使用。然後對輸入圖像完成預處理,使用ResNet50模型實作分類預測,對預測結果解析之後,顯示标簽文本,完整的代碼示範如下:

1with open('imagenet_classes.txt') as f: 2    labels = [line.strip() for line in f.readlines()] 3 4src = cv.imread("D:/images/space_shuttle.jpg") # aeroplane.jpg 5image = cv.resize(src, (224, 224)) 6image = np.float32(image) / 255.0 7image[:,:,] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406)) 8image[:,:,] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225)) 9image = image.transpose((2, 0, 1))10input_x = torch.from_numpy(image).unsqueeze(0)11print(input_x.size())12pred = model(input_x.cuda())13pred_index = torch.argmax(pred, 1).cpu().detach().numpy()14print(pred_index)15print("current predict class name : %s"%labels[pred_index[0]])16cv.putText(src, labels[pred_index[0]], (50, 50), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)17cv.imshow("input", src)18cv.waitKey(0)19cv.destroyAllWindows()
           

運作結果如下:

pytorch argmax_輕松學Pytorch-使用ResNet50實作圖像分類

轉ONNX支援

在torchvision中的模型基本上都可以轉換為ONNX格式,而且被OpenCV DNN子產品所支援,是以,很友善的可以對torchvision自帶的模型轉為ONNX,實作OpenCV DNN的調用,首先轉為ONNX模型,直接使用torch.onnx.export即可轉換(還不知道怎麼轉,快點看前面的例子)。轉換之後使用OpenCV DNN調用的代碼如下:

1with open('imagenet_classes.txt') as f: 2    labels = [line.strip() for line in f.readlines()] 3net = cv.dnn.readNetFromONNX("resnet.onnx") 4src = cv.imread("D:/images/messi.jpg")  # aeroplane.jpg 5image = cv.resize(src, (224, 224)) 6image = np.float32(image) / 255.0 7image[:, :, ] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406)) 8image[:, :, ] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225)) 9blob = cv.dnn.blobFromImage(image, 1.0, (224, 224), (0, 0, 0), False)10net.setInput(blob)11probs = net.forward()12index = np.argmax(probs)13cv.putText(src, labels[index], (50, 50), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)14cv.imshow("input", src)15cv.waitKey(0)16cv.destroyAllWindows()
           

運作結果見上圖,這裡就不再貼了。

—完—

想要獲得更多資料科學領域相關動态,誠邀關注清華-青島資料科學研究院官方微信公衆平台“ 資料派THU ”。

繼續閱讀