PyTorch自定義數(shù)據(jù)集示例(2019-12-19)

文章結(jié)構(gòu)

  • 自定義Dataset的基本結(jié)構(gòu)

  • 使用Torchvisiom進(jìn)行類(lèi)型轉(zhuǎn)換

  • 使用Torchvision的另一種方法

  • Incorporating Pandas

  • Incorporating Pandas with More Logic

  • 使用Data Loader

自定義Dataset的基本結(jié)構(gòu)

  • 首先最重要的是要?jiǎng)?chuàng)建dataset類(lèi)
from torch.utils.data.dataset import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # 填充
        
    def __getitem__(self, index):
        # 填充
        return (img, label)

    def __len__(self):
        return count # 你有多少?gòu)垐D片
  • 這是必須填充用來(lái)獲得自定義數(shù)據(jù)集的框架。數(shù)據(jù)集必須包含以下函數(shù),以便稍后由數(shù)據(jù)加載程序使用。
__init__() #函數(shù)是初始邏輯發(fā)生的地方,比如讀取csv、分配轉(zhuǎn)換等
__getitem__()#函數(shù)返回?cái)?shù)據(jù)和標(biāo)簽。這個(gè)函數(shù)是從dataloader中被調(diào)用的,如下所示:
img, label = MyCustomDataset.__getitem__(99)  # 有99個(gè)數(shù)據(jù)

  • 因此,索引參數(shù)(index)是你要返回的第n個(gè)數(shù)據(jù)/圖像(tensor)。
__len__()#返回你的樣本數(shù)量
  • 注意__getitem__()返回一個(gè)特殊的數(shù)據(jù)類(lèi)型,比如tensor,numpy array等,如果不是這些類(lèi)型,在data loader將會(huì)報(bào)錯(cuò)。
    TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'PIL.PngImagePlugin.PngImageFile'>

使用Torchvisiom進(jìn)行類(lèi)型轉(zhuǎn)換

  • 一般在__init__()里面都會(huì)寫(xiě)成transforms = None,這是為了方便在調(diào)用dataset類(lèi)的時(shí)候傳入自定義的transforms
from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    def __init__(self, ..., transforms=None):
        # 填充
        #...
        self.transforms = transforms
        
    def __getitem__(self, index):
        # 填充
        #...
        data = # 從文件或者圖像中讀取的數(shù)據(jù)
        if self.transforms is not None:
            data = self.transforms(data)
        # 如果轉(zhuǎn)換變量不是空
        # 按照傳入的轉(zhuǎn)換格式來(lái)轉(zhuǎn)換數(shù)據(jù)
        return (img, label)

    def __len__(self):
        return count
        
if __name__ == '__main__':
    # 自定義transforms
    transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()])
    # 調(diào)用數(shù)據(jù)集
    custom_dataset = MyCustomDataset(..., transformations)

使用Torchvision的另一種方法

  • 如果不喜歡在外面自定義transforms,可以在dataset類(lèi)里面定義好,不過(guò)這樣降低了程序的可讀性。
from torch.utils.data.dataset import Dataset
from torchvision import transforms

class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # 填充
        #...
        # 單獨(dú)定義轉(zhuǎn)換
        self.center_crop = transforms.CenterCrop(100)
        self.to_tensor = transforms.ToTensor()
        
        # 也可以組合定義
        self.transformations = transforms.Compose([
                                transforms.CenterCrop(100),
                                transforms.ToTensor()])
        
    def __getitem__(self, index):
        # 填充
        #...
        data = # 從文件或者圖像中讀取的數(shù)據(jù)
        
        #對(duì)應(yīng)了在__init__()中定義的三個(gè)transforms
        data = self.center_crop(data)  
        data = self.to_tensor(data)  
        data = self.trasnformations(data) 
        
        return (img, label)

    def __len__(self):
        return count 
        
if __name__ == '__main__':
    # 調(diào)用dataset
    custom_dataset = MyCustomDataset(...)

Incorporating Pandas

  • 假設(shè),我們想通過(guò)pandas從csv文件中讀取數(shù)據(jù)。第一個(gè)例子如下的csv文件,包含文件名和標(biāo)簽,和一個(gè)額外的操作指示器根據(jù)這個(gè)額外的操作標(biāo)志我們對(duì)圖像做一些操作。
File Name Label Extra Operation
tr_0.png 5 TRUE
tr_1.png 0 FALSE
tr_2.png 4 FALSE
  • 如果我們想建立一個(gè)自定義數(shù)據(jù)集,讀取圖像位置從這個(gè)csv文件,然后我們可以做如下操作
class CustomDatasetFromImages(Dataset):
    def __init__(self, csv_path):
        '''
        Args:
            csv_path (string): csv文件路徑
            img_path (string): 圖片文件路徑
            transform: pytorch變換用于變換和張量轉(zhuǎn)換
        '''
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # 讀取csv文件
        self.data_info = pd.read_csv(csv_path, header=None)
        # 第一列包含圖像路徑
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # 第二列是標(biāo)簽
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # 第三列是操作指示符
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # 計(jì)算整個(gè)數(shù)據(jù)集的長(zhǎng)度
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # 從pandas df獲取圖片文件名
        single_image_name = self.image_arr[index]
        # 打開(kāi)圖片
        img_as_img = Image.open(single_image_name)

        # 檢查是否有操作
        some_operation = self.operation_arr[index]
        # 如果有操作的話(huà)
        if some_operation:
            # 對(duì)圖像做一些操作
            # ...
            # ...
            pass
        # 把圖像變換成張量
        img_as_tensor = self.to_tensor(img_as_img)

        # 根據(jù)裁剪的panda列獲取圖像的標(biāo)簽
        single_image_label = self.label_arr[index]

        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    # 調(diào)用 dataset
    custom_mnist_from_images = CustomDatasetFromImages('../data/mnist_labels.csv')

Incorporating Pandas with More Logic

  • 另一個(gè)從csv中讀取圖像的例子,其中每個(gè)像素的值都在一個(gè)列中。這時(shí),只需要返回張量以及其標(biāo)簽。數(shù)據(jù)被分成像素。
Lbel pixel_1 pixel_2 ...
1 50 99 ...
0 21 223 ...
9 44 112 ...
... ... ... ...
class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path, height, width, transforms=None):
        '''
        Args:
            csv_path (string): csv文件路徑
            height (int): 圖片高度
            width (int): 圖片寬度
            transform: pytorch transforms for transforms and tensor conversion
        '''
        self.data = pd.read_csv(csv_path)
        self.labels = np.asarray(self.data.iloc[:, 0])
        self.height = height
        self.width = width
        self.transforms = transform

    def __getitem__(self, index):
        single_image_label = self.labels[index]
        # Read each 784 pixels and reshape the 1D array ([784]) to 2D array ([28,28]) 
        img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(28,28).astype('uint8')
    # 將圖像從numpy數(shù)組轉(zhuǎn)換為PIL圖像,模式“L”為灰度
        img_as_img = Image.fromarray(img_as_np)
        img_as_img = img_as_img.convert('L')
        # 把圖像變換成tensor
        if self.transforms is not None:
            img_as_tensor = self.transforms(img_as_img)
        # 返回圖片和標(biāo)簽
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return len(self.data.index)
        

if __name__ == "__main__":
    transformations = transforms.Compose([transforms.ToTensor()])
    custom_mnist_from_csv = CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations)

使用Data Loader

  • 在pytorch中DataLoader只需要調(diào)用__getitem__()然后把他們打包成一個(gè)批次。我們也可以不使用Dataloader每調(diào)用__getitem()__一次就把數(shù)據(jù)傳入到模型(遠(yuǎn)沒(méi)有使用DataLoader方便)。從上面的示例繼續(xù),如果我們假設(shè)有一個(gè)名為CustomDatasetFromCSV的自定義數(shù)據(jù)集,那么我們可以像這樣調(diào)用DataLoader
if __name__ == "__main__":
    # 定義 transforms
    transformations = transforms.Compose([transforms.ToTensor()])
    # 定義dataset
    custom_mnist_from_csv = CustomDatasetFromCSV('../data/mnist_in_csv.csv',28, 28,transformations)
    # 定義data loader
    mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,
                                                    batch_size=10,
                                                    shuffle=False)
    
    for images, labels in mn_dataset_loader:
        # 將數(shù)據(jù)送入模型
  • DataLoader的第一個(gè)參數(shù)是數(shù)據(jù)集,從那里它調(diào)用該數(shù)據(jù)集的__getitem__().batch_size確定一個(gè)批次傳入的數(shù)據(jù)量,如果我們假設(shè)一張圖片的tensor是[1*28*28] ---> [D:1,H:28,W:28]那么用這個(gè)DataLoader返回的tensor是[10*1*28*28]
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容