Select Git revision
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Lebedev.py 1.60 KiB
from glob import glob
from os.path import join, basename
import numpy as np
from . import CDDataset
from .common import default_loader
class LebedevDataset(CDDataset):
    """Change-detection dataset organised into 'real' and 'Model' subsets.

    File paths are collected from up to three subset directories below
    ``root``:

    * ``'real'``          -> ``<root>/Real/subset``
    * ``'with_shift'``    -> ``<root>/Model/with_shift``
    * ``'without_shift'`` -> ``<root>/Model/without_shift``

    Inside each subset directory the files are looked up under
    ``<phase>/A``, ``<phase>/B`` and ``<phase>/OUT``.
    """

    def __init__(
        self,
        root, phase='train',
        transforms=(None, None, None),
        repeats=1,
        subsets=('real', 'with_shift', 'without_shift')
    ):
        """Remember the requested subsets, then defer to ``CDDataset``.

        Args:
            root: Dataset root directory, forwarded to ``CDDataset``.
            phase: Name of the split sub-directory (e.g. ``'train'``),
                forwarded to ``CDDataset``.
            transforms: Forwarded unchanged to ``CDDataset``.
            repeats: Forwarded unchanged to ``CDDataset``.
            subsets: Iterable of subset keys to include; each must be one of
                ``'real'``, ``'with_shift'`` or ``'without_shift'``.
        """
        # Assigned before the base initializer runs.
        # NOTE(review): presumably CDDataset.__init__ calls
        # _read_file_paths(), which reads self.subsets -- confirm.
        self.subsets = subsets
        super().__init__(root, phase, transforms, repeats)

    def _read_file_paths(self):
        """Collect the image and label paths for every selected subset.

        Label files are discovered by globbing ``<subset_dir>/<phase>/OUT``;
        the two input image paths are then derived by reusing each label's
        file name under ``A`` and ``B``. The derived paths are not checked
        for existence.

        Returns:
            A tuple ``(t1_list, t2_list, label_list)`` of three lists with
            matching order and length.

        Raises:
            RuntimeError: If ``self.subsets`` contains an unknown key.
        """
        def locate(subset):
            # Map a subset key onto its directory below the dataset root.
            if subset == 'real':
                return join(self.root, 'Real', 'subset')
            if subset == 'with_shift':
                return join(self.root, 'Model', 'with_shift')
            if subset == 'without_shift':
                return join(self.root, 'Model', 'without_shift')
            raise RuntimeError('unrecognized key encountered')

        t1_paths, t2_paths, label_paths = [], [], []
        for subset in self.subsets:
            subset_dir = locate(subset)
            phase = self.phase

            # The 'with_shift' subset is globbed as .bmp for the test and
            # val phases; every other combination is globbed as .jpg.
            if subset == 'with_shift' and phase in ('test', 'val'):
                pattern = '*.bmp'
            else:
                pattern = '*.jpg'

            # Sort so the resulting order does not depend on the order in
            # which glob happens to return the files.
            label_files = glob(join(subset_dir, phase, 'OUT', pattern))
            label_files.sort()
            file_names = [basename(path) for path in label_files]

            label_paths += label_files
            t1_paths += [join(subset_dir, phase, 'A', name) for name in file_names]
            t2_paths += [join(subset_dir, phase, 'B', name) for name in file_names]
        return t1_paths, t2_paths, label_paths

    def fetch_label(self, label_path):
        """Load a label through the parent class and rescale it to {0,1}.

        The parent's result is divided by 255.0 and cast to ``np.uint8``.

        NOTE(review): the cast truncates, so assuming the parent returns
        values in [0, 255], only pixels equal to 255 become 1 -- confirm
        this is intended for lossy (.jpg) label files.
        """
        raw_label = super().fetch_label(label_path)
        scaled = raw_label / 255.0
        return scaled.astype(np.uint8)