天天看點

使用pytorch自定義DataSet,以加載圖像資料集為例,實作一些騷操作

使用pytorch自定義DataSet,以加載圖像資料集為例,實作一些騷操作

總共分為四步

  • 構造一個 `my_dataset` 類,繼承自 `torch.utils.data.Dataset`

  • 重寫 `__getitem__` 和 `__len__` 類函數
  • 建立兩個函數 `find_classes`、`has_file_allowed_extension`,直接從 torchvision 原始碼(`torchvision/datasets/folder.py`)copy 過去
  • 建立 `my_make_dataset` 函數用來構造 (path, label) 對

一、構造一個 `my_dataset` 類,繼承自 `torch.utils.data.Dataset`

二、重寫 `__getitem__` 和 `__len__` 類函數

要構造 Dataset 的子類,就必須要實作兩個方法:

  • `__getitem__(self, index)`:根據 index 來傳回資料集中標號為 index 的元素及其標籤。
  • `__len__(self)`:傳回資料集的長度。
class my_dataset(Dataset):
    """Paired image dataset: item *i* is (original image *i*, processed image *i*).

    Both roots are scanned with ``my_make_dataset`` and the two resulting
    sample lists are paired purely by position, so the two folder trees are
    presumably expected to share the same class/file layout — TODO confirm.

    Args:
        root_original: root folder of the original images.
        root_cdtfed: root folder of the corresponding processed images.
        transform: optional callable applied to each PIL image separately.

    Raises:
        ValueError: if the two roots yield a different number of images,
            which would make the index-based pairing meaningless.
    """

    def __init__(self, root_original, root_cdtfed, transform=None):
        super().__init__()
        self.transform = transform
        self.root_original = root_original
        self.root_cdtfed = root_cdtfed

        # Lists of (img_path, class_folder_name) for each root.
        self.original_imgs = my_make_dataset(root_original, class_to_idx=None, extensions=('.jpg', '.png'), is_valid_file=None)
        # Bug fix: this list was previously built from root_original as well,
        # so every "pair" was the same image twice and root_cdtfed was ignored.
        self.cdtfed_imgs = my_make_dataset(root_cdtfed, class_to_idx=None, extensions=('.jpg', '.png'), is_valid_file=None)

        # Pairing is by index, so a length mismatch means misaligned pairs
        # (or an IndexError mid-epoch); fail early with a clear message.
        if len(self.original_imgs) != len(self.cdtfed_imgs):
            raise ValueError(
                f"Found {len(self.original_imgs)} images in {root_original} but "
                f"{len(self.cdtfed_imgs)} images in {root_cdtfed}; cannot pair them by index."
            )

    @staticmethod
    def _load_rgb(path):
        """Open *path* with PIL, convert it to RGB and close the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return ``([img_original, img_cdtfed], label, name)`` for *index*.

        ``label`` is the class folder name and ``name`` the file path, both
        taken from the original-image entry. Whatever is returned here is
        exactly what each batch yields when iterating a DataLoader.
        """
        fn1, label1 = self.original_imgs[index]
        fn2, _ = self.cdtfed_imgs[index]

        img1 = self._load_rgb(fn1)
        img2 = self._load_rgb(fn2)

        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return [img1, img2], label1, fn1

    def __len__(self):
        """Number of image pairs (not the number of DataLoader batches)."""
        return len(self.original_imgs)

三、建立兩個函數 `find_classes`、`has_file_allowed_extension`,直接從 torchvision 原始碼(`torchvision/datasets/folder.py`)copy 過去

def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Return the sorted sub-folder names of *directory* and a name -> index map.

    Each immediate sub-directory of ``directory`` is one class; indices are
    assigned 0, 1, 2, ... in sorted-name order. See :class:`DatasetFolder`.

    Raises:
        FileNotFoundError: if ``directory`` has no sub-directories.
    """
    folder_names = [item.name for item in os.scandir(directory) if item.is_dir()]
    folder_names.sort()
    if len(folder_names) == 0:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    name_to_index = dict(zip(folder_names, range(len(folder_names))))
    return folder_names, name_to_index

def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Tell whether *filename* ends with one of *extensions*.

    Only ``filename`` is lower-cased before comparing, so ``extensions``
    must already be lower-case (e.g. ``('.jpg', '.png')``).

    Args:
        filename: file name or path to check.
        extensions: allowed suffixes, in lower case.

    Returns:
        True if the lower-cased filename ends with any of ``extensions``.
    """
    lowered_name = filename.lower()
    return lowered_name.endswith(extensions)
四、建立 `my_make_dataset` 函數用來構造 (path, label) 對,其中 label 為類別資料夾名稱(字串)
def my_make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, str]]:
    """Collect ``(path_to_sample, class_name)`` pairs from a class-per-folder tree.

    Adapted from torchvision's ``make_dataset`` (see :class:`DatasetFolder`),
    with one deliberate difference: the label stored for each sample is the
    class *folder name* (a ``str``), not the integer index from
    ``class_to_idx``. ``class_to_idx`` therefore only selects which
    sub-folders are scanned; its values are ignored.

    Args:
        directory: root folder; a leading ``~`` is expanded.
        class_to_idx: class-folder name -> index. When ``None`` it is built
            with ``find_classes(directory)``. Only its keys are used.
        extensions: allowed lower-case file suffixes. Exactly one of
            ``extensions`` / ``is_valid_file`` must be given.
        is_valid_file: predicate called with the bare file name (not the
            full path). Exactly one of ``extensions`` / ``is_valid_file``
            must be given.

    Returns:
        List of ``(path, class_name)`` tuples, ordered by class name, then by
        walked directory path, then by file name.

    Raises:
        ValueError: if ``class_to_idx`` is empty, or if both or neither of
            ``extensions`` and ``is_valid_file`` are given.
        FileNotFoundError: if any class in ``class_to_idx`` yields no valid file.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances: List[Tuple[str, str]] = []
    available_classes = set()
    for target_class in sorted(class_to_idx):
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                # NOTE: unlike torchvision, the predicate sees the bare file
                # name here, not the joined path.
                if is_valid_file(fname):
                    # The label is the folder name itself; callers that need
                    # numeric labels can parse it, e.g.
                    # [int(cl) for cl in target_class.split('_')].
                    instances.append((os.path.join(root, fname), target_class))
                    available_classes.add(target_class)

    empty_classes = set(class_to_idx) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances

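附錄:完整代碼(執行前需先匯入:`import os`、`from typing import Callable, Dict, List, Optional, Tuple, cast`、`from PIL import Image`、`from torch.utils.data import Dataset`)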

class my_dataset(Dataset):
    """Paired image dataset: item *i* is (original image *i*, processed image *i*).

    Both roots are scanned with ``my_make_dataset`` and the two resulting
    sample lists are paired purely by position, so the two folder trees are
    presumably expected to share the same class/file layout — TODO confirm.

    Args:
        root_original: root folder of the original images.
        root_cdtfed: root folder of the corresponding processed images.
        transform: optional callable applied to each PIL image separately.

    Raises:
        ValueError: if the two roots yield a different number of images,
            which would make the index-based pairing meaningless.
    """

    def __init__(self, root_original, root_cdtfed, transform=None):
        super().__init__()
        self.transform = transform
        self.root_original = root_original
        self.root_cdtfed = root_cdtfed

        # Lists of (img_path, class_folder_name) for each root.
        self.original_imgs = my_make_dataset(root_original, class_to_idx=None, extensions=('.jpg', '.png'), is_valid_file=None)
        # Bug fix: this list was previously built from root_original as well,
        # so every "pair" was the same image twice and root_cdtfed was ignored.
        self.cdtfed_imgs = my_make_dataset(root_cdtfed, class_to_idx=None, extensions=('.jpg', '.png'), is_valid_file=None)

        # Pairing is by index, so a length mismatch means misaligned pairs
        # (or an IndexError mid-epoch); fail early with a clear message.
        if len(self.original_imgs) != len(self.cdtfed_imgs):
            raise ValueError(
                f"Found {len(self.original_imgs)} images in {root_original} but "
                f"{len(self.cdtfed_imgs)} images in {root_cdtfed}; cannot pair them by index."
            )

    @staticmethod
    def _load_rgb(path):
        """Open *path* with PIL, convert it to RGB and close the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return ``([img_original, img_cdtfed], label, name)`` for *index*.

        ``label`` is the class folder name and ``name`` the file path, both
        taken from the original-image entry. Whatever is returned here is
        exactly what each batch yields when iterating a DataLoader.
        """
        fn1, label1 = self.original_imgs[index]
        fn2, _ = self.cdtfed_imgs[index]

        img1 = self._load_rgb(fn1)
        img2 = self._load_rgb(fn2)

        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return [img1, img2], label1, fn1

    def __len__(self):
        """Number of image pairs (not the number of DataLoader batches)."""
        return len(self.original_imgs)


def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """List the class sub-folders of *directory* and number them.

    Every immediate sub-directory is treated as a class. The names are
    returned sorted, together with a dict mapping each name to its position
    in that sorted list. See :class:`DatasetFolder`.

    Raises:
        FileNotFoundError: if no sub-directory exists in ``directory``.
    """
    with os.scandir(directory) as entries:
        classes = sorted(e.name for e in entries if e.is_dir())

    if len(classes) == 0:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {}
    for position, name in enumerate(classes):
        class_to_idx[name] = position
    return classes, class_to_idx

def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Check whether *filename* carries one of the allowed *extensions*.

    The check is case-insensitive on the file name side only: ``filename``
    is lower-cased, ``extensions`` is used as given and so should be
    lower-case already.

    Args:
        filename: file name or path to test.
        extensions: tuple of allowed lower-case suffixes.

    Returns:
        True if the lower-cased name ends with any of the suffixes.
    """
    normalized = filename.lower()
    matches = normalized.endswith(extensions)
    return matches

def my_make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, str]]:
    """Collect ``(path_to_sample, class_name)`` pairs from a class-per-folder tree.

    Adapted from torchvision's ``make_dataset`` (see :class:`DatasetFolder`),
    with one deliberate difference: the label stored for each sample is the
    class *folder name* (a ``str``), not the integer index from
    ``class_to_idx``. ``class_to_idx`` therefore only selects which
    sub-folders are scanned; its values are ignored.

    Args:
        directory: root folder; a leading ``~`` is expanded.
        class_to_idx: class-folder name -> index. When ``None`` it is built
            with ``find_classes(directory)``. Only its keys are used.
        extensions: allowed lower-case file suffixes. Exactly one of
            ``extensions`` / ``is_valid_file`` must be given.
        is_valid_file: predicate called with the bare file name (not the
            full path). Exactly one of ``extensions`` / ``is_valid_file``
            must be given.

    Returns:
        List of ``(path, class_name)`` tuples, ordered by class name, then by
        walked directory path, then by file name.

    Raises:
        ValueError: if ``class_to_idx`` is empty, or if both or neither of
            ``extensions`` and ``is_valid_file`` are given.
        FileNotFoundError: if any class in ``class_to_idx`` yields no valid file.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances: List[Tuple[str, str]] = []
    available_classes = set()
    for target_class in sorted(class_to_idx):
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                # NOTE: unlike torchvision, the predicate sees the bare file
                # name here, not the joined path.
                if is_valid_file(fname):
                    # The label is the folder name itself; callers that need
                    # numeric labels can parse it, e.g.
                    # [int(cl) for cl in target_class.split('_')].
                    instances.append((os.path.join(root, fname), target_class))
                    available_classes.add(target_class)

    empty_classes = set(class_to_idx) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances

繼續閱讀