天天看點

封裝自己的pytorch資料集

1. pytorch 對于資料的标簽要求是長整形,是以要對标簽進行轉換

train_label = train_label.long()
           

2.對于資料的特征部分,要轉換為tensor 形式,可以通過torch.tensor 将資料從numpy 轉為tensor

train_fea = torch.tensor(train_fea, dtype=torch.float32)
           

3.封裝資料

data = torch.utils.data.TensorDataset(train_fea, train_label)#(特征,标簽)
           

4.封裝成第三步的形式,就可以采用torch 中的資料加載器為模型提供資料,資料加載器可以自動分批喂給模型資料

train_loader = torch.utils.data.DataLoader(data,batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
           

繼續閱讀