Skip to content
Snippets Groups Projects
dataset.py 2.33 KiB
Newer Older
  • Learn to ignore specific revisions
  • ZeroAct's avatar
    ZeroAct committed
    import os, cv2
    import numpy as np
    
    import torch
    from torch.utils.data import Dataset
    
    def mat_to_tensor(mat):
        """Convert a channel-last 3-D array into a channel-first tensor.

        Args:
            mat: NumPy array with exactly three axes; its last axis is moved
                to the front (axis order ``(2, 0, 1)``).

        Returns:
            A ``torch.Tensor`` built from the reordered array.
        """
        channel_first_axes = (2, 0, 1)
        return torch.Tensor(mat.transpose(channel_first_axes))
    
    def tensor_to_mat(tensor):
        """Convert a 4-D channel-first tensor into a channel-last NumPy array.

        The tensor is detached from the autograd graph and moved to the CPU
        before conversion, so it may require gradients or live on another
        device.

        Args:
            tensor: ``torch.Tensor`` with four axes.

        Returns:
            A NumPy array whose axes are the input's axes in the order
            ``(0, 2, 3, 1)``.
        """
        channel_last_axes = (0, 2, 3, 1)
        array = tensor.detach().cpu().numpy()
        return array.transpose(channel_last_axes)
    
    def preprocess_image(img, target_shape: tuple):
        """Resize an image, scale it by 1/255 and guarantee a channel axis.

        Args:
            img: Image array accepted by ``cv2.resize``.
            target_shape: Size tuple forwarded to ``cv2.resize`` as ``dsize``.
                NOTE(review): OpenCV reads ``dsize`` as (width, height) --
                confirm callers agree when the target is not square.

        Returns:
            A ``float32`` array of the resized image divided by 255. A 2-D
            result gains a trailing axis of length 1, so the output always
            has three axes for 2-D or 3-D inputs.
        """
        resized = cv2.resize(img, target_shape, interpolation=cv2.INTER_CUBIC)
        scaled = resized.astype(np.float32) / 255.

        if scaled.ndim == 2:
            # Single-channel images come back without a channel axis; add one.
            scaled = scaled[..., np.newaxis]

        return scaled
    
    def postprocess_image(img):
        """Map a [0, 1]-scaled image array back to 8-bit pixel values.

        Args:
            img: NumPy array of intensities, nominally in ``[0, 1]``.

        Returns:
            A ``uint8`` array of the same shape. Values are multiplied by 255,
            saturated to ``[0, 255]`` and rounded to the nearest integer.
        """
        scaled = np.clip(img * 255, 0, 255)

        # Round before casting: a bare astype(np.uint8) truncates toward zero,
        # which darkens the image (254.7 -> 254) and can lose a level when a
        # value that was divided by 255 is scaled back up.
        return np.rint(scaled).astype(np.uint8)
    
    class CustomDataset(Dataset):
        """Dataset of aligned image quadruples stored under ``<data_dir>/<set_name>``.

        The split directory must contain four sub-directories -- ``I``,
        ``Itegt``, ``Mm`` and ``Msgt`` -- holding files with matching names.
        ``I`` and ``Itegt`` are read with OpenCV's default flags; ``Mm`` and
        ``Msgt`` are read as grayscale.
        """

        def __init__(self,
                     data_dir,
                     set_name="train",
                     target_size=(256, 256)):
            """Index the samples of one split.

            Args:
                data_dir: Directory that contains the split directories.
                set_name: Name of the split sub-directory to use.
                target_size: Size every image is resized to; forwarded to
                    ``preprocess_image``.
            """
            super().__init__()

            self.root_dir = os.path.join(data_dir, set_name)
            self.target_size = target_size

            self.I_dir = os.path.join(self.root_dir, "I")
            self.Itegt_dir = os.path.join(self.root_dir, "Itegt")
            self.Mm_dir = os.path.join(self.root_dir, "Mm")
            self.Msgt_dir = os.path.join(self.root_dir, "Msgt")

            # os.listdir returns entries in arbitrary order; sort them so an
            # index refers to the same sample on every run and machine.
            self.datas = sorted(os.listdir(self.I_dir))

        def __len__(self):
            """Return the number of file names found in the ``I`` directory."""
            return len(self.datas)

        def _load(self, directory, img_name, grayscale=False):
            """Read one image from *directory* and convert it to a tensor."""
            path = os.path.join(directory, img_name)
            if grayscale:
                img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            else:
                img = cv2.imread(path)

            if img is None:
                # cv2.imread reports a missing or undecodable file by
                # returning None; fail here with the offending path instead
                # of letting cv2.resize raise an opaque error later.
                raise OSError(f"Failed to read image: {path}")

            return mat_to_tensor(preprocess_image(img, self.target_size))

        def __getitem__(self, idx):
            """Load the four images that share the file name at position *idx*.

            Args:
                idx: Index into the sorted file-name list.

            Returns:
                Tuple ``(I, Itegt, Mm, Msgt)`` of tensors produced by
                ``mat_to_tensor(preprocess_image(...))``.

            Raises:
                OSError: If any of the four files cannot be read.
            """
            img_name = self.datas[idx]

            I = self._load(self.I_dir, img_name)
            Itegt = self._load(self.Itegt_dir, img_name)
            Mm = self._load(self.Mm_dir, img_name, grayscale=True)
            Msgt = self._load(self.Msgt_dir, img_name, grayscale=True)

            return I, Itegt, Mm, Msgt
            
    
    if __name__ == "__main__":
        # Smoke test: load the first training sample and report tensor shapes.
        dataset = CustomDataset('dataset', 'train')

        I, Itegt, Mm, Ms = dataset[0]
        print(f"Dataset length : {len(dataset)}")
        for label, tensor in (("I", I), ("Itegt", Itegt), ("Mm", Mm), ("Ms", Ms)):
            print(f"{label} shape : {tensor.shape}")