天天看点

封装自己的pytorch数据集

1. pytorch 对于数据的标签要求是长整形,因此要对标签进行转换

train_label = train_label.long()
           

2.对于数据的特征部分,要转换为tensor 形式,可以通过torch.tensor 将数据从numpy 转为tensor

train_fea = torch.tensor(train_fea, dtype=torch.float32)
           

3.封装数据

data = torch.utils.data.TensorDataset(train_fea, train_label)#(特征,标签)
           

4.封装成第三步的形式,就可以采用torch 中的数据加载器为模型提供数据,数据加载器可以自动分批喂给模型数据

train_loader = torch.utils.data.DataLoader(data,batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
           

继续阅读