Skip to content
Snippets Groups Projects
Select Git revision
  • 1b0d6980779c823cb79e95c96fc3a424ac6e77bc
  • main default protected
2 results

regression.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    data.py 2.17 KiB
    import os.path
    from abc import ABCMeta, abstractmethod
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.utils.data as data
    
    
    # Data builder utilities
    def build_train_dataloader(cls, configs, C):
        """Create a shuffling training ``DataLoader`` over ``cls(**configs)``.

        Args:
            cls: Dataset class (or any callable) invoked as ``cls(**configs)``.
            configs: Keyword arguments forwarded to ``cls``.
            C: Mapping that must provide ``'batch_size'``, ``'num_workers'``
                and ``'device'``.

        Returns:
            A ``torch.utils.data.DataLoader`` that shuffles and drops the last
            incomplete batch.
        """
        dataset = cls(**configs)
        loader_options = {
            'batch_size': C['batch_size'],
            'shuffle': True,
            'num_workers': C['num_workers'],
            # Pinned memory is requested whenever the device is not 'cpu'.
            'pin_memory': C['device'] != 'cpu',
            'drop_last': True,
        }
        return data.DataLoader(dataset, **loader_options)
    
    
    def build_eval_dataloader(cls, configs):
        """Create a fixed-order evaluation ``DataLoader`` over ``cls(**configs)``.

        The loader uses batch size 1, no shuffling, a single worker, no pinned
        memory, and keeps every sample (``drop_last=False``).

        Args:
            cls: Dataset class (or any callable) invoked as ``cls(**configs)``.
            configs: Keyword arguments forwarded to ``cls``.

        Returns:
            A ``torch.utils.data.DataLoader`` for evaluation.
        """
        dataset = cls(**configs)
        loader_options = {
            'batch_size': 1,
            'shuffle': False,
            'num_workers': 1,
            'pin_memory': False,
            'drop_last': False,
        }
        return data.DataLoader(dataset, **loader_options)
    
    
    def get_common_train_configs(C):
        """Return the dataset keyword arguments shared by all training datasets.

        Args:
            C: Mapping that must provide ``'repeats'``.

        Returns:
            A dict with ``phase`` set to ``'train'`` and ``repeats`` copied from ``C``.
        """
        train_configs = {'phase': 'train', 'repeats': C['repeats']}
        return train_configs
    
    
    def get_common_eval_configs(C):
        """Return the dataset keyword arguments shared by all evaluation datasets.

        Args:
            C: Mapping that must provide ``'subset'``.

        Returns:
            A dict with ``phase`` set to ``'eval'``, ``transforms`` set to a fresh
            list of three ``None`` entries, and ``subset`` copied from ``C``.
        """
        eval_configs = {
            'phase': 'eval',
            'transforms': [None] * 3,
            'subset': C['subset'],
        }
        return eval_configs
    
    
    # Dataset prototype
    class DatasetBase(data.Dataset, metaclass=ABCMeta):
        """Abstract base for datasets with a train/eval phase and sample repetition.

        Subclasses must implement ``fetch_and_preprocess`` and are expected to set
        ``self.len`` (the number of distinct samples); this class reads it in
        ``__len__`` and ``__getitem__`` but never assigns it.
        """

        def __init__(
            self, 
            root, phase,
            transforms,
            repeats, 
            subset
        ):
            """Validate and store the common dataset settings.

            Args:
                root: Dataset root directory; ``~`` is expanded.
                phase: Working mode, ``'train'`` or ``'eval'``.
                transforms: Stored unchanged on ``self.transforms``; its structure
                    is defined by subclasses.
                repeats: Converted with ``int()``; multiplies the reported length.
                subset: One of ``'train'``, ``'val'``, ``'test'``. It is validated
                    even in the ``'train'`` phase, where it is then ignored.

            Raises:
                FileNotFoundError: If the expanded ``root`` does not exist.
                ValueError: If ``phase`` or ``subset`` is not an accepted value.
            """
            super().__init__()
            self.root = os.path.expanduser(root)
            if not os.path.exists(self.root):
                # Name the missing path so the failure can be diagnosed.
                raise FileNotFoundError(
                    "Dataset root does not exist: {}".format(self.root)
                )
            # phase stands for the working mode,
            # 'train' for training and 'eval' for validating or testing.
            if phase not in ('train', 'eval'):
                raise ValueError("Invalid phase")
            # subset is the sub-dataset to use.
            # For some datasets there are three subsets,
            # while for others there are only train and test(val).
            if subset not in ('train', 'val', 'test'):
                raise ValueError("Invalid subset")
            self.phase = phase
            self.transforms = transforms
            self.repeats = int(repeats)
            # Use 'train' subset during training.
            self.subset = 'train' if self.phase == 'train' else subset

        def __len__(self):
            """Return the number of distinct samples times ``repeats``."""
            return self.len * self.repeats

        def __getitem__(self, index):
            """Return the preprocessed sample for ``index``.

            Indices at or beyond ``self.len`` wrap around onto the distinct
            samples. Negative indices count from the end.

            Raises:
                IndexError: If ``index`` is outside ``[-len(self), len(self))``.
            """
            total = len(self)
            # Check both bounds: the modulo below would otherwise silently wrap
            # an out-of-range negative index onto a valid sample.
            if index >= total or index < -total:
                raise IndexError("dataset index out of range")
            index = index % self.len

            return self.fetch_and_preprocess(index)

        @abstractmethod
        def fetch_and_preprocess(self, index):
            """Load and preprocess the sample at ``index`` (``0 <= index < self.len``).

            Must be overridden by subclasses; the base implementation returns None.
            """
            return None