import os, cv2
import numpy as np

import torch
from torch.utils.data import Dataset

def mat_to_tensor(mat):
    """Build a ``torch.Tensor`` from a 3-axis array, moving the last axis first.

    Args:
        mat: Array with exactly three axes (e.g. an ``(H, W, C)`` image as
            produced by ``preprocess_image``).

    Returns:
        A ``torch.Tensor`` whose axes are the input's axes in the order
        ``(2, 0, 1)``.
    """
    channel_first = mat.transpose(2, 0, 1)
    return torch.Tensor(channel_first)

def tensor_to_mat(tensor):
    """Convert a 4-axis tensor to a NumPy array with axis 1 moved to the end.

    The tensor is detached from the autograd graph and copied to the CPU
    before conversion, so it is safe to call on tensors that require grad.

    Args:
        tensor: ``torch.Tensor`` with exactly four axes (e.g. a batch laid
            out as ``(N, C, H, W)``).

    Returns:
        ``numpy.ndarray`` whose axes are the input's axes in the order
        ``(0, 2, 3, 1)``.
    """
    return tensor.detach().cpu().numpy().transpose(0, 2, 3, 1)

def preprocess_image(img, target_shape: tuple):
    """Resize an image with bicubic interpolation and scale it by 1/255.

    Args:
        img: Image array accepted by ``cv2.resize``.
        target_shape: Forwarded unchanged as the ``dsize`` argument of
            ``cv2.resize``; OpenCV reads that tuple as ``(width, height)``.

    Returns:
        ``float32`` array with three axes. A 2-D (single-channel) resize
        result gets a trailing axis of length 1 so callers can always treat
        the output as having a channel axis.
    """
    resized = cv2.resize(img, target_shape, interpolation=cv2.INTER_CUBIC)
    scaled = resized.astype(np.float32) / 255.

    # Single-channel results come back 2-D; append an explicit channel axis.
    if scaled.ndim == 2:
        scaled = scaled.reshape(scaled.shape + (1,))

    return scaled

def postprocess_image(img):
    """Convert a float image on a nominal [0, 1] scale to ``uint8``.

    The input is multiplied by 255, rounded to the nearest integer, clamped
    to ``[0, 255]`` and cast to ``uint8``. Rounding (instead of the
    truncation a bare ``astype`` performs) avoids a systematic downward bias
    and makes ``uint8 -> float / 255 -> uint8`` round trips exact.

    Args:
        img: NumPy array (or NumPy scalar) of float values; values outside
            ``[0, 1]`` are saturated.

    Returns:
        ``uint8`` result with the same shape as ``img``.
    """
    scaled = np.rint(img * 255)
    # Clamp before the cast so out-of-range values saturate instead of wrapping.
    return np.clip(scaled, 0, 255).astype(np.uint8)

class CustomDataset(Dataset):
    """Dataset of four aligned images per sample, read from sibling folders.

    The directory ``<data_dir>/<set_name>`` must contain the sub-directories
    ``I``, ``Itegt``, ``Mm`` and ``Msgt``. The file names found in ``I``
    define the samples; a file with the same name is read from each of the
    other three directories.
    """

    def __init__(self,
                 data_dir,
                 set_name="train",
                 target_size=(256, 256)):
        """Index the samples of one split.

        Args:
            data_dir: Root directory holding one sub-directory per split.
            set_name: Name of the split sub-directory to use.
            target_size: Size every image is resized to; forwarded to
                ``preprocess_image`` (and from there to ``cv2.resize``).
        """
        super().__init__()

        self.root_dir = os.path.join(data_dir, set_name)
        self.target_size = target_size

        self.I_dir = os.path.join(self.root_dir, "I")
        self.Itegt_dir = os.path.join(self.root_dir, "Itegt")
        self.Mm_dir = os.path.join(self.root_dir, "Mm")
        self.Msgt_dir = os.path.join(self.root_dir, "Msgt")

        # os.listdir returns entries in arbitrary order; sort so that a given
        # index maps to the same sample on every run and every machine.
        self.datas = sorted(os.listdir(self.I_dir))

    def __len__(self):
        """Return the number of entries found in the ``I`` directory."""
        return len(self.datas)

    def _load(self, directory, img_name, flags):
        """Read one image, preprocess it and return it as a tensor.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file (``cv2.imread``
                signals a missing or undecodable file by returning ``None``).
        """
        path = os.path.join(directory, img_name)
        img = cv2.imread(path, flags)
        if img is None:
            raise FileNotFoundError(
                f"Failed to read image (missing or undecodable): {path}")
        return mat_to_tensor(preprocess_image(img, self.target_size))

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            Tuple ``(I, Itegt, Mm, Msgt)`` of tensors. ``I`` and ``Itegt``
            are read as 3-channel colour images, ``Mm`` and ``Msgt`` as
            single-channel grayscale images; all are resized to
            ``target_size``, divided by 255 and laid out channel-first.

        Raises:
            FileNotFoundError: If any of the four files cannot be read.
        """
        img_name = self.datas[idx]

        I      = self._load(self.I_dir,     img_name, cv2.IMREAD_COLOR)
        Itegt  = self._load(self.Itegt_dir, img_name, cv2.IMREAD_COLOR)
        Mm     = self._load(self.Mm_dir,    img_name, cv2.IMREAD_GRAYSCALE)
        Msgt   = self._load(self.Msgt_dir,  img_name, cv2.IMREAD_GRAYSCALE)

        return I, Itegt, Mm, Msgt
        

if __name__ == "__main__":
    # Smoke test: load the first training sample from ./dataset and report
    # the dataset size together with the shape of each returned tensor.
    dataset = CustomDataset('dataset', 'train')

    first_sample = dataset[0]
    I, Itegt, Mm, Ms = first_sample

    print(f"Dataset length : {len(dataset)}")
    print(f"I shape : {I.shape}")
    print(f"Itegt shape : {Itegt.shape}")
    print(f"Mm shape : {Mm.shape}")
    print(f"Ms shape : {Ms.shape}")