Skip to content
Snippets Groups Projects
Select Git revision
  • 7f0846c13cd76c56fc4c60280e3875a16822056b
  • master default protected
  • github/fork/Bobholamovic/master
3 results

Lebedev.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    Lebedev.py 1.60 KiB
    from glob import glob
    from os.path import join, basename
    
    import numpy as np
    
    from . import CDDataset
    from .common import default_loader
    
    class LebedevDataset(CDDataset):
        """Change-detection dataset that collects image pairs and labels from a
        Lebedev-style directory tree.

        For every requested subset the files are looked up under::

            <root>/<subset_dir>/<phase>/A/<name>      (first image)
            <root>/<subset_dir>/<phase>/B/<name>      (second image)
            <root>/<subset_dir>/<phase>/OUT/<name>    (label)

        where ``<subset_dir>`` is ``Real/subset``, ``Model/with_shift`` or
        ``Model/without_shift``.

        Args:
            root: Root directory of the dataset.
            phase: Name of the split sub-directory (e.g. ``'train'``, ``'val'``,
                ``'test'``).
            transforms: Forwarded unchanged to ``CDDataset.__init__``.
            repeats: Forwarded unchanged to ``CDDataset.__init__``.
            subsets: Iterable of subset keys to include; each must be one of
                ``'real'``, ``'with_shift'`` or ``'without_shift'``.
        """

        # Maps each accepted ``subsets`` key to its path components below ``root``.
        _SUBSET_DIRS = {
            'real': ('Real', 'subset'),
            'with_shift': ('Model', 'with_shift'),
            'without_shift': ('Model', 'without_shift'),
        }

        def __init__(
            self,
            root, phase='train',
            transforms=(None, None, None),
            repeats=1,
            subsets=('real', 'with_shift', 'without_shift')
        ):
            # Set before calling the base initializer so the attribute already
            # exists if the base class reads file paths during construction
            # (NOTE(review): presumably it calls _read_file_paths -- confirm).
            self.subsets = subsets
            super().__init__(root, phase, transforms, repeats)

        def _read_file_paths(self):
            """Collect the file paths of all samples in the selected subsets.

            Label files found in each subset's ``OUT`` directory drive the
            listing; the matching ``A`` and ``B`` paths are built from the label
            file names and are not checked for existence here.

            Returns:
                A tuple ``(t1_list, t2_list, label_list)`` of path lists that are
                aligned by index. Within each subset the entries are sorted by
                label path; subsets appear in the order given by
                ``self.subsets``.

            Raises:
                RuntimeError: If ``self.subsets`` contains an unknown key.
            """
            t1_list, t2_list, label_list = [], [], []

            for subset in self.subsets:
                # Get subset directory
                try:
                    dir_parts = self._SUBSET_DIRS[subset]
                except (KeyError, TypeError):
                    # TypeError covers unhashable keys, which are just as invalid.
                    raise RuntimeError('unrecognized key encountered') from None
                phase_dir = join(self.root, *dir_parts, self.phase)

                # Only the 'with_shift' subset in the 'test'/'val' phases is
                # globbed as .bmp; every other combination is globbed as .jpg.
                uses_bmp = subset == 'with_shift' and self.phase in ('test', 'val')
                pattern = '*.bmp' if uses_bmp else '*.jpg'

                refs = sorted(glob(join(phase_dir, 'OUT', pattern)))
                names = [basename(ref) for ref in refs]

                label_list.extend(refs)
                t1_list.extend(join(phase_dir, 'A', name) for name in names)
                t2_list.extend(join(phase_dir, 'B', name) for name in names)

            return t1_list, t2_list, label_list

        def fetch_label(self, label_path):
            """Load a label through the base class and binarize it to {0, 1}.

            The value is scaled by 1/255 and rounded to the nearest integer
            before the cast to ``uint8``. A plain cast truncates toward zero, so
            any pixel below exactly 255 (e.g. 254 from lossy compression of a
            ``.jpg`` mask) would silently become 0.

            Args:
                label_path: Path of the label file, passed to
                    ``CDDataset.fetch_label``.

            Returns:
                ``np.uint8`` array with the same shape as the loaded label.
            """
            # NOTE(review): assumes the base loader returns a numeric array in
            # the 0..255 range -- confirm against CDDataset.fetch_label.
            label = super().fetch_label(label_path)
            # To {0,1}
            return np.rint(label / 255.0).astype(np.uint8)