diff --git a/Data_Loader.py b/Data_Loader.py index ac84b4df74e4d99ceb3a30c822c5070f6f5711fd..5a371ff239f30a8a1b0dffe3458cbed386853eda 100644 --- a/Data_Loader.py +++ b/Data_Loader.py @@ -6,6 +6,8 @@ import torch.utils.data import torchvision from skimage import io from torch.utils.data import Dataset +import random +import numpy as np class Images_Dataset(Dataset): @@ -80,6 +82,7 @@ class Images_Dataset_folder(torch.utils.data.Dataset): self.lx = torchvision.transforms.Compose([ # torchvision.transforms.Resize((128,128)), torchvision.transforms.CenterCrop(96), + torchvision.transforms.RandomRotation((-10,10)), torchvision.transforms.Grayscale(), torchvision.transforms.ToTensor(), #torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0)) @@ -93,5 +96,19 @@ class Images_Dataset_folder(torch.utils.data.Dataset): i1 = Image.open(self.images_dir + self.images[i]) l1 = Image.open(self.labels_dir + self.labels[i]) - return self.tx(i1), self.lx(l1) + seed=np.random.randint(0,2**32) # make a seed with numpy generator + + # apply this seed to img tranfsorms + random.seed(seed) + torch.manual_seed(seed) + img = self.tx(i1) + + # apply this seed to target/label tranfsorms + random.seed(seed) + torch.manual_seed(seed) + label = self.lx(l1) + + + + return img, label