Skip to content
Snippets Groups Projects
Unverified Commit c050f5ea authored by Malav Bateriwala's avatar Malav Bateriwala Committed by GitHub
Browse files

Merge pull request #30 from ppjerry/FixDataTransformation

fix the random transformation on both image and target
parents 0ba46df9 4b3580f9
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,8 @@ import torch.utils.data ...@@ -6,6 +6,8 @@ import torch.utils.data
import torchvision import torchvision
from skimage import io from skimage import io
from torch.utils.data import Dataset from torch.utils.data import Dataset
import random
import numpy as np
class Images_Dataset(Dataset): class Images_Dataset(Dataset):
...@@ -80,6 +82,7 @@ class Images_Dataset_folder(torch.utils.data.Dataset): ...@@ -80,6 +82,7 @@ class Images_Dataset_folder(torch.utils.data.Dataset):
self.lx = torchvision.transforms.Compose([ self.lx = torchvision.transforms.Compose([
# torchvision.transforms.Resize((128,128)), # torchvision.transforms.Resize((128,128)),
torchvision.transforms.CenterCrop(96), torchvision.transforms.CenterCrop(96),
torchvision.transforms.RandomRotation((-10,10)),
torchvision.transforms.Grayscale(), torchvision.transforms.Grayscale(),
torchvision.transforms.ToTensor(), torchvision.transforms.ToTensor(),
#torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0)) #torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0))
...@@ -93,5 +96,19 @@ class Images_Dataset_folder(torch.utils.data.Dataset): ...@@ -93,5 +96,19 @@ class Images_Dataset_folder(torch.utils.data.Dataset):
i1 = Image.open(self.images_dir + self.images[i]) i1 = Image.open(self.images_dir + self.images[i])
l1 = Image.open(self.labels_dir + self.labels[i]) l1 = Image.open(self.labels_dir + self.labels[i])
return self.tx(i1), self.lx(l1) seed=np.random.randint(0,2**32) # make a seed with numpy generator
# apply this seed to img tranfsorms
random.seed(seed)
torch.manual_seed(seed)
img = self.tx(i1)
# apply this seed to target/label tranfsorms
random.seed(seed)
torch.manual_seed(seed)
label = self.lx(l1)
return img, label
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment