1. pytorch 對于資料的标簽要求是長整形,是以要對标簽進行轉換
train_label = train_label.long()
2.對于資料的特征部分,要轉換為tensor 形式,可以通過torch.tensor 将資料從numpy 轉為tensor
train_fea = torch.tensor(train_fea, dtype=torch.float32)
3.封裝資料
data = torch.utils.data.TensorDataset(train_fea, train_label)#(特征,标簽)
4.封裝成第三步的形式,就可以采用torch 中的資料加載器為模型提供資料,資料加載器可以自動分批喂給模型資料
train_loader = torch.utils.data.DataLoader(data,batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)