天天看点

PyTorch training is slow and stuck on data loading (I/O-bound): effective fixes

Multi-process data loading

  • In the DataLoader, set num_workers > 0 so that several worker processes load the dataset in parallel; the number of CPU cores is a sensible upper bound
  • Set pin_memory = True so that batches are placed in page-locked (pinned) host memory, which makes copying them to the GPU faster
  • Example:
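from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=batch_size * group_size,
    shuffle=True,
    collate_fn=dataset.collate_fn,
    num_workers=2,    # worker processes that load batches in parallel
    pin_memory=True,  # put batches in page-locked memory for faster GPU copies
)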
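A slightly fuller sketch of the two settings above, for when a CUDA device may or may not be present. The helpers make_loader and run_epoch, the cap of 4 workers, and the persistent_workers option are hypothetical additions for illustration, not part of the original example. Passing non_blocking=True to Tensor.to is what allows a copy from pinned memory to run asynchronously with respect to the host.

```python
import os

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size, num_workers=None):
    if num_workers is None:
        # Hypothetical cap of 4 workers; os.cpu_count() may return None
        num_workers = min(4, os.cpu_count() or 1)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # Pinned memory only pays off when batches are copied to a CUDA device
        pin_memory=torch.cuda.is_available(),
        # Keep worker processes alive across epochs (only valid with workers > 0)
        persistent_workers=num_workers > 0,
    )


def run_epoch(loader, device):
    seen = 0
    for x, y in loader:
        # With a pinned source tensor, this copy can overlap with GPU work
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        seen += x.shape[0]
    return seen


if __name__ == "__main__":
    data = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loader = make_loader(data, batch_size=16)
    print(run_epoch(loader, device))  # 100
```

Passing num_workers=0 makes the loader read in the main process, which is convenient when debugging a Dataset before turning on parallel loading.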

Preloading the dataset

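  • Principle: load the entire dataset into memory up front, so every read goes straight to RAM and there is no disk I/O during training. This only works if the whole dataset fits in memory.
  • Build your own Dataset class
  • Example: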
import json
import os

import numpy as np
from torch.utils.data import Dataset

# Not shown in the original article and assumed to come from the surrounding
# project: text_to_sequence(text, cleaners), which converts text to phoneme IDs,
# and the process_meta(filename) method called in __init__ below.


class My_Dataset(Dataset):
    def __init__(
        self, filename, preprocess_config, train_config, sort=False, drop_last=False
    ):
        self.dataset_name = preprocess_config["dataset"]
        self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
        self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
        self.batch_size = train_config["optimizer"]["batch_size"]

        self.basename, self.speaker, self.text, self.raw_text = self.process_meta(
            filename
        )
        with open(os.path.join(self.preprocessed_path, "speakers.json")) as f:
            self.speaker_map = json.load(f)
        self.sort = sort
        self.drop_last = drop_last

        # Added: preload every feature file into memory once, so that
        # __getitem__ never touches the disk
        self.mel_list = []
        self.pitch_list = []
        self.energy_list = []
        self.duration_list = []
        for basename, speaker in zip(self.basename, self.speaker):
            self.mel_list.append(self._load_feature("mel", speaker, basename))
            self.pitch_list.append(self._load_feature("pitch", speaker, basename))
            self.energy_list.append(self._load_feature("energy", speaker, basename))
            self.duration_list.append(
                self._load_feature("duration", speaker, basename)
            )

    def _load_feature(self, kind, speaker, basename):
        # Features live at <preprocessed_path>/<kind>/<speaker>-<kind>-<basename>.npy
        path = os.path.join(
            self.preprocessed_path,
            kind,
            "{}-{}-{}.npy".format(speaker, kind, basename),
        )
        return np.load(path)

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        basename = self.basename[idx]
        speaker = self.speaker[idx]
        speaker_id = self.speaker_map[speaker]
        raw_text = self.raw_text[idx]
        phone = np.array(text_to_sequence(self.text[idx], self.cleaners))

        # Served straight from the preloaded lists: no disk I/O here
        mel = self.mel_list[idx]
        pitch = self.pitch_list[idx]
        energy = self.energy_list[idx]
        duration = self.duration_list[idx]

        sample = {
            "id": basename,
            "speaker": speaker_id,
            "text": phone,
            "raw_text": raw_text,
            "mel": mel,
            "pitch": pitch,
            "energy": energy,
            "duration": duration,
        }

        return sample
  • In __init__, all of the data is loaded into memory
  • __getitem__(self, idx) then simply fetches each sample from the in-memory lists by index idx
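The same preloading idea, stripped of the speech-specific fields, can be sketched as a small generic class. This is a minimal sketch assuming each sample is stored as a single .npy file; the class name PreloadedNpyDataset and its nbytes helper are hypothetical, the latter added so you can check that the preloaded data fits in RAM.

```python
import os
import tempfile

import numpy as np
from torch.utils.data import Dataset


class PreloadedNpyDataset(Dataset):
    """Reads every .npy file once in __init__ and serves samples from RAM."""

    def __init__(self, paths):
        # All disk I/O happens here, once, before training starts
        self.arrays = [np.load(p) for p in paths]

    @property
    def nbytes(self):
        # Total memory held by the preloaded arrays
        return sum(a.nbytes for a in self.arrays)

    def __len__(self):
        return len(self.arrays)

    def __getitem__(self, idx):
        # Plain list indexing: no file is opened per sample
        return self.arrays[idx]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        paths = []
        for i in range(3):
            path = os.path.join(tmp, "sample-{}.npy".format(i))
            np.save(path, np.full((4, 80), i, dtype=np.float32))
            paths.append(path)
        ds = PreloadedNpyDataset(paths)
    # The files are deleted at this point, but the data is still in memory
    print(len(ds), ds[2][0, 0], ds.nbytes)  # 3 2.0 3840
```

Deleting the temporary files before indexing shows that every later read comes from memory rather than disk.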
