Skip to content
Snippets Groups Projects
Commit a4d70bc2 authored by Bobholamovic's avatar Bobholamovic
Browse files

Fix mp error on Windows

parent 909b6859
No related branches found
No related tags found
No related merge requests found
......@@ -175,7 +175,7 @@ def _get_basic_configs(ds_name, C):
return dict(
root = constants.IMDB_AIRCHANGE
)
elif ds_name.startswith('Lebedev'):
elif ds_name == 'Lebedev':
return dict(
root = constants.IMDB_LEBEDEV
)
......
import abc
from os.path import join, basename
from multiprocessing import Manager
from functools import lru_cache
import numpy as np
......@@ -19,11 +19,6 @@ class _AirChangeDataset(CDDataset):
super().__init__(root, phase, transforms, repeats)
self.cropper = Crop(bounds=(0, 0, 748, 448))
self._manager = Manager()
sync_list = self._manager.list
self.images = sync_list([sync_list([None]*self.N_PAIRS), sync_list([None]*self.N_PAIRS)])
self.labels = sync_list([None]*self.N_PAIRS)
@property
@abc.abstractmethod
def LOCATION(self):
......@@ -41,32 +36,28 @@ class _AirChangeDataset(CDDataset):
def _read_file_paths(self):
if self.phase == 'train':
sample_ids = range(self.N_PAIRS)
t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS]
t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS]
label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in sample_ids if i not in self.TEST_SAMPLE_IDS]
sample_ids = [i for i in range(self.N_PAIRS) if i not in self.TEST_SAMPLE_IDS]
t1_list = [join(self.root, self.LOCATION, str(i+1), 'im1') for i in sample_ids]
t2_list = [join(self.root, self.LOCATION, str(i+1), 'im2') for i in sample_ids]
label_list = [join(self.root, self.LOCATION, str(i+1), 'gt') for i in sample_ids]
else:
t1_list = ['-'.join([self.LOCATION,str(i),'0.bmp']) for i in self.TEST_SAMPLE_IDS]
t2_list = ['-'.join([self.LOCATION,str(i),'1.bmp']) for i in self.TEST_SAMPLE_IDS]
label_list = ['-'.join([self.LOCATION,str(i),'cm.bmp']) for i in self.TEST_SAMPLE_IDS]
t1_list = [join(self.root, self.LOCATION, str(i+1), 'im1') for i in self.TEST_SAMPLE_IDS]
t2_list = [join(self.root, self.LOCATION, str(i+1), 'im2') for i in self.TEST_SAMPLE_IDS]
label_list = [join(self.root, self.LOCATION, str(i+1), 'gt') for i in self.TEST_SAMPLE_IDS]
return t1_list, t2_list, label_list
@lru_cache(maxsize=8)
def fetch_image(self, image_name):
_, i, t = image_name.split('-')
i, t = int(i), int(t[:-4])
if self.images[t][i] is None:
image = self._bmp_loader(join(self.root, self.LOCATION, str(i+1), 'im'+str(t+1)))
self.images[t][i] = image if self.phase == 'train' else self.cropper(image)
return self.images[t][i]
image = self._bmp_loader(image_name)
return image if self.phase == 'train' else self.cropper(image)
@lru_cache(maxsize=8)
def fetch_label(self, label_name):
index = int(label_name.split('-')[1])
if self.labels[index] is None:
label = self._bmp_loader(join(self.root, self.LOCATION, str(index+1), 'gt'))
label = (label / 255.0).astype(np.uint8) # To 0,1
self.labels[index] = label if self.phase == 'train' else self.cropper(label)
return self.labels[index]
label = self._bmp_loader(label_name)
label = (label / 255.0).astype(np.uint8) # To 0,1
return label if self.phase == 'train' else self.cropper(label)
@staticmethod
def _bmp_loader(bmp_path_wo_ext):
......@@ -74,6 +65,4 @@ class _AirChangeDataset(CDDataset):
try:
return default_loader(bmp_path_wo_ext+'.bmp')
except FileNotFoundError:
return default_loader(bmp_path_wo_ext+'.BMP')
\ No newline at end of file
return default_loader(bmp_path_wo_ext+'.BMP')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment