Skip to content
Snippets Groups Projects
Commit eb48f74c authored by Bobholamovic's avatar Bobholamovic
Browse files

Update framework

parent d6d3faca
No related branches found
No related tags found
1 merge request!2Update outdated code
Showing
with 790 additions and 951 deletions
...@@ -8,17 +8,18 @@ This is an unofficial implementation of the paper ...@@ -8,17 +8,18 @@ This is an unofficial implementation of the paper
[paper link](https://ieeexplore.ieee.org/abstract/document/8451652) [paper link](https://ieeexplore.ieee.org/abstract/document/8451652)
# Prerequisites # Dependencies
> opencv-python==4.1.1 > opencv-python==4.1.1
pytorch==1.2.0 pytorch==1.3.1
torchvision==0.4.2
pyyaml==5.1.2 pyyaml==5.1.2
scikit-image==0.15.0 scikit-image==0.15.0
scikit-learn==0.21.3 scikit-learn==0.21.3
scipy==1.3.1 scipy==1.3.1
tqdm==4.35.0 tqdm==4.35.0
Tested on Python 3.7.4, Ubuntu 16.04 and Python 3.6.8, Windows 10. Tested using Python 3.7.4 on Ubuntu 16.04 and Python 3.6.8 on Windows 10.
# Basic usage # Basic usage
...@@ -30,84 +31,25 @@ mkdir exp ...@@ -30,84 +31,25 @@ mkdir exp
cd src cd src
``` ```
In `src/constants.py`, change the dataset directories to your own. In `config_base.yaml`, feel free to change some configurations. In `src/constants.py`, change the dataset locations to your own. In `config_base.yaml`, set specific configurations.
For training, try For training, try
```bash ```bash
python train.py train --exp-config ../configs/config_base.yaml python train.py train --exp_config ../configs/config_base.yaml
``` ```
For evaluation, try For evaluation, try
```bash ```bash
python train.py val --exp-config ../configs/config_base.yaml --resume path_to_checkpoint --save-on python train.py eval --exp_config ../configs/config_base.yaml --resume path_to_checkpoint --save-on
``` ```
You can find the checkpoints in `exp/base/weights/`, the log files in `exp/base/logs`, and the output change maps in `exp/base/outs`. You can check the model weight files in `exp/base/weights/`, the log files in `exp/base/logs`, and the output change maps in `exp/base/out`.
# Train on Air Change dataset and OSCD dataset
To carry out a full training on these two datasets and with all three architectures, run the `train9.sh` script under the root folder of this repo.
```bash
. ./train9.sh
```
And check the results in different subdirectories of `./exp/`.
# Create your own configuration file
During scientific research, it is common case that we have to do a lot of experiments with different settings, and that's why we need the configuration files to better manage those settings. In this repo, you can create a `yaml` file under the naming convention below:
`config_TAG{_SUFFIX}.yaml`
Those in the curly braces can be omitted. `TAG` usually stands for an experiment group. For example, a set of experiments for an architecture, a dataset, etc. It will be the name of the subdirectory that holds all the checkpoints, log files, and output images. `SUFFIX` can be used to distinguish different experiments in an experiment group. If it is specified, the generated files of this experiment will be tagged with `SUFFIX` in their file names. In plain English, `TAG1` and `TAG2` have major differences, while `SUFFIX1` and `SUFFIX2` of the same `TAG` share most of the configurations. By combining `TAG` and `SUFFIX`, it is convenient for both coarse-grained and find-grained control of experimental configurations.
Here is an example to help you understand. Suppose I'm going to finish my experiments on two datasets, OSCD and Lebedev, and I'm not sure which batch size achieves best performance. So I create these 5 config files.
```
config_OSCD_bs4.yaml
config_OSCD_bs8.yaml
config_OSCD_bs16.yaml
config_Lebedev_bs16.yaml
config_Lebedev_bs32.yaml
```
After training, I get my `exp/` folder like this:
```
-exp/
--OSCD/
---weights/
----model_best_bs4.pth
----model_best_bs8.pth
----model_best_bs16.pth
---outs/
---logs/
---config_OSCD_bs4.yaml
---config_OSCD_bs8.yaml
---config_OSCD_bs16.yaml
--Lebedev/
---weights/
----model_best_bs16.pth
----model_best_bs32.pth
---outs/
---logs/
---config_Lebedev_bs16.yaml
---config_Lebedev_bs32.yaml
```
Now the experiment results are organized in a more structured way, and I think it would be a little bit easier to collect the statistics. Also, since the historical experiments are arranged in neat order, you will soon remember what you'd done when you come back to these results, even after a long time.
Alternatively, you can configure from the command line. This can be useful when there is only minor change between two single runs, because the configuration items from the command line is set to overwrite those from the `yaml` file. That is, the final value of each configuration item is evaluated and applied in the following order:
```
default_value -> value_from_config_file -> value_from_command_line
```
At least one of the above three values should be given. In this way, you don't have to include all of the config items in the `yaml` file or in the command-line input. You can use either of them, or combine them. Make your choice according to preference and circumstances.
--- ---
# Changed # Changed
- 2020.3.14 Add the configuration files of my experiments. - 2020.3.14 Add configuration files.
- 2020.4.14 Detail README.md. - 2020.4.14 Detail README.md.
- 2020.12.8 Update framework.
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Szada
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: EF
num_feats_in: 6
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Tiszadob
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: EF
num_feats_in: 6
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: OSCD
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: EF
num_feats_in: 26
\ No newline at end of file
...@@ -2,51 +2,54 @@ ...@@ -2,51 +2,54 @@
# Data # Data
# Common dataset: AC_Szada
dataset: Lebedev num_workers: 0
crop_size: 224 repeats: 3200
num_workers: 1 subset: val
repeats: 1 crop_size: 112
# Optimizer # Optimizer
optimizer: Adam optimizer: SGD
lr: 1e-4 lr: 0.001
lr_mode: step weight_decay: 0.0005
weight_decay: 0.0 load_optim: False
step: 5 save_optim: False
lr_mode: const
step: 2
# Training related # Training related
batch_size: 8 batch_size: 32
num_epochs: 15 num_epochs: 10
resume: '' resume: ''
load_optim: True
save_optim: True
anew: False anew: False
track_intvl: 1
device: cuda device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment # Experiment
exp_dir: ../exp/ exp_dir: ../exp/
out_dir: ''
# tag: '' # tag: ''
# suffix: '' # suffix: ''
# DO NOT specify exp-config term # DO NOT specify exp_config
save_on: False debug_on: False
inherit_off: True
log_off: False log_off: False
track_intvl: 1
tb_on: False
tb_intvl: 100
suffix_off: False suffix_off: False
save_on: False
out_dir: ''
val_iters: 16
# Criterion # Criterion
criterion: NLL criterion: NLL
weights: weights:
- 0.117 # Weight of no-change class - 1.0 # Weight of no-change class
- 0.883 # Weight of change class - 10.0 # Weight of change class
# Model # Model
model: EF model: Unet
num_feats_in: 6 \ No newline at end of file
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Szada
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_conc
num_feats_in: 3
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Tiszadob
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_conc
num_feats_in: 3
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: OSCD
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_conc
num_feats_in: 13
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Szada
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_diff
num_feats_in: 3
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Tiszadob
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_diff
num_feats_in: 3
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: OSCD
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_diff
num_feats_in: 13
\ No newline at end of file
# Global constants # Global constants
# Dataset directories # Dataset locations
IMDB_OSCD = '~/Datasets/OSCDDataset/' IMDB_OSCD = "~/Datasets/OSCDDataset/"
IMDB_AIRCHANGE = '~/Datasets/SZTAKI_AirChange_Benchmark/' IMDB_AIRCHANGE = "~/Datasets/SZTAKI_AirChange_Benchmark/"
IMDB_LEBEDEV = '~/Datasets/HR/ChangeDetectionDataset/'
# Checkpoint templates # Template strings
CKP_LATEST = 'checkpoint_latest.pth' CKP_LATEST = "checkpoint_latest.pth"
CKP_BEST = 'model_best.pth' CKP_BEST = "model_best.pth"
CKP_COUNTED = 'checkpoint_{e:03d}.pth' CKP_COUNTED = "checkpoint_{e:03d}.pth"
# Initialize all
import core.misc
import core.data
import core.config
import core.builders
import core.factories
import core.trainer
import impl.builders
import impl.trainers
\ No newline at end of file
# Built-in builders
import torch
import torch.nn as nn
import torch.nn.functional as F
from .misc import (MODELS, OPTIMS, CRITNS, DATA)
# Optimizer builders
@OPTIMS.register_func('Adam_optim')
def build_Adam_optim(params, C):
return torch.optim.Adam(
params,
betas=(0.9, 0.999),
lr=C['lr'],
weight_decay=C['weight_decay']
)
@OPTIMS.register_func('SGD_optim')
def build_SGD_optim(params, C):
return torch.optim.SGD(
params,
lr=C['lr'],
momentum=0.9,
weight_decay=C['weight_decay']
)
# Criterion builders
@CRITNS.register_func('L1_critn')
def build_L1_critn(C):
return nn.L1Loss()
@CRITNS.register_func('MSE_critn')
def build_MSE_critn(C):
return nn.MSELoss()
@CRITNS.register_func('CE_critn')
def build_CE_critn(C):
return nn.CrossEntropyLoss(torch.Tensor(C['weights']))
@CRITNS.register_func('NLL_critn')
def build_NLL_critn(C):
return nn.NLLLoss(torch.Tensor(C['weights']))
import argparse
import os.path as osp
from collections import ChainMap
import yaml
def read_config(config_path):
with open(config_path, 'r') as f:
cfg = yaml.load(f.read(), Loader=yaml.FullLoader)
return cfg or {}
def parse_configs(cfg_path, inherit=True):
# Read and parse config files
cfg_dir = osp.dirname(cfg_path)
cfg_name = osp.basename(cfg_path)
cfg_name, ext = osp.splitext(cfg_name)
parts = cfg_name.split('_')
cfg_path = osp.join(cfg_dir, parts[0])
cfgs = []
for part in parts[1:]:
cfg_path = '_'.join([cfg_path, part])
if osp.exists(cfg_path+ext):
cfgs.append(read_config(cfg_path+ext))
cfgs.reverse()
if len(parts)>=2:
return ChainMap(*cfgs, dict(tag=parts[1], suffix='_'.join(parts[2:])))
else:
return ChainMap(*cfgs)
def parse_args(parser_configurator=None):
# Parse necessary arguments
# Global settings
parser = argparse.ArgumentParser(conflict_handler='resolve')
parser.add_argument('cmd', choices=['train', 'eval'])
# Data
group_data = parser.add_argument_group('data')
group_data.add_argument('--dataset', type=str)
group_data.add_argument('--num_workers', type=int, default=4)
group_data.add_argument('--repeats', type=int, default=1)
group_data.add_argument('--subset', type=str, default='val')
# Optimizer
group_optim = parser.add_argument_group('optimizer')
group_optim.add_argument('--optimizer', type=str, default='Adam')
group_optim.add_argument('--lr', type=float, default=1e-4)
group_optim.add_argument('--weight_decay', type=float, default=1e-4)
group_optim.add_argument('--load_optim', action='store_true')
group_optim.add_argument('--save_optim', action='store_true')
# Training related
group_train = parser.add_argument_group('training related')
group_train.add_argument('--batch_size', type=int, default=8)
group_train.add_argument('--num_epochs', type=int)
group_train.add_argument('--resume', type=str, default='')
group_train.add_argument('--anew', action='store_true',
help="clear history and start from epoch 0 with weights updated")
group_train.add_argument('--device', type=str, default='cpu')
# Experiment
group_exp = parser.add_argument_group('experiment related')
group_exp.add_argument('--exp_dir', default='../exp/')
group_exp.add_argument('--tag', type=str, default='')
group_exp.add_argument('--suffix', type=str, default='')
group_exp.add_argument('--exp_config', type=str, default='')
group_exp.add_argument('--debug_on', action='store_true')
group_exp.add_argument('--inherit_off', action='store_true')
group_exp.add_argument('--log_off', action='store_true')
group_exp.add_argument('--track_intvl', type=int, default=1)
# Criterion
group_critn = parser.add_argument_group('criterion related')
group_critn.add_argument('--criterion', type=str, default='NLL')
group_critn.add_argument('--weights', type=float, nargs='+', default=None)
# Model
group_model = parser.add_argument_group('model')
group_model.add_argument('--model', type=str)
if parser_configurator is not None:
parser = parser_configurator(parser)
args, unparsed = parser.parse_known_args()
if osp.exists(args.exp_config):
cfg = parse_configs(args.exp_config, not args.inherit_off)
group_config = parser.add_argument_group('from_file')
def _cfg2arg(cfg, parser, prefix=''):
for k, v in cfg.items():
if isinstance(v, (list, tuple)):
# Only apply to homogeneous lists and tuples
parser.add_argument('--'+prefix+k, type=type(v[0]), nargs='*', default=v)
elif isinstance(v, dict):
# Recursively parse a dict
_cfg2arg(v, parser, prefix+k+'.')
elif isinstance(v, bool):
parser.add_argument('--'+prefix+k, action='store_true', default=v)
else:
parser.add_argument('--'+prefix+k, type=type(v), default=v)
_cfg2arg(cfg, group_config, '')
args = parser.parse_args()
elif len(unparsed)!=0:
raise RuntimeError("Unrecognized arguments")
def _arg2cfg(cfg, args):
args = vars(args)
for k, v in args.items():
pos = k.find('.')
if pos != -1:
# Iteratively parse a dict
dict_ = cfg
while pos != -1:
dict_.setdefault(k[:pos], {})
dict_ = dict_[k[:pos]]
k = k[pos+1:]
pos = k.find('.')
dict_[k] = v
else:
cfg[k] = v
return cfg
return _arg2cfg(dict(), args)
\ No newline at end of file
import os.path
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
# Data builder utilities
def build_train_dataloader(cls, configs, C):
return data.DataLoader(
cls(**configs),
batch_size=C['batch_size'],
shuffle=True,
num_workers=C['num_workers'],
pin_memory=C['device']!='cpu',
drop_last=True
)
def build_eval_dataloader(cls, configs):
return data.DataLoader(
cls(**configs),
batch_size=1,
shuffle=False,
num_workers=1,
pin_memory=False,
drop_last=False
)
def get_common_train_configs(C):
return dict(phase='train', repeats=C['repeats'])
def get_common_eval_configs(C):
return dict(phase='eval', transforms=[None, None, None], subset=C['subset'])
# Dataset prototype
class DatasetBase(data.Dataset, metaclass=ABCMeta):
def __init__(
self,
root, phase,
transforms,
repeats,
subset
):
super().__init__()
self.root = os.path.expanduser(root)
if not os.path.exists(self.root):
raise FileNotFoundError
# phase stands for the working mode,
# 'train' for training and 'eval' for validating or testing.
assert phase in ('train', 'eval')
# subset is the sub-dataset to use.
# For some datasets there are three subsets,
# while for others there are only train and test(val).
assert subset in ('train', 'val', 'test')
self.phase = phase
self.transforms = transforms
self.repeats = int(repeats)
# Use 'train' subset during training.
self.subset = 'train' if self.phase == 'train' else subset
def __len__(self):
return self.len * self.repeats
def __getitem__(self, index):
if index >= len(self):
raise IndexError
index = index % self.len
item = self.fetch_and_preprocess(index)
return item
@abstractmethod
def fetch_and_preprocess(self, index):
return None
from functools import wraps # from functools import wraps
from inspect import isfunction, isgeneratorfunction, getmembers from inspect import isfunction, isgeneratorfunction, getmembers
from collections.abc import Iterable from collections.abc import Sequence
from abc import ABC, ABCMeta
from itertools import chain from itertools import chain
from importlib import import_module from importlib import import_module
...@@ -8,73 +9,76 @@ import torch ...@@ -8,73 +9,76 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.data as data import torch.utils.data as data
import constants from .misc import (R, MODELS, OPTIMS, CRITNS, DATA)
import utils.metrics as metrics
from utils.misc import R
from data.augmentation import *
class _Desc: class _AttrDesc:
def __init__(self, key): def __init__(self, key):
self.key = key self.key = key
def __get__(self, instance, owner): def __get__(self, instance, owner):
return tuple(getattr(instance[_],self.key) for _ in range(len(instance))) return tuple(getattr(ele, self.key) for ele in instance)
def __set__(self, instance, values): def __set__(self, instance, value):
if not (isinstance(values, Iterable) and len(values)==len(instance)): for ele in instance:
raise TypeError("incorrect type or number of values") setattr(ele, self.key, value)
for i, v in zip(range(len(instance)), values):
setattr(instance[i], self.key, v)
def _func_deco(func_name): def _func_deco(func_name):
def _wrapper(self, *args): # FIXME: The signature of the wrapped function will be lost.
return tuple(getattr(ins, func_name)(*args) for ins in self) def _wrapper(self, *args, **kwargs):
return tuple(getattr(ele, func_name)(*args, **kwargs) for ele in self)
return _wrapper return _wrapper
def _generator_deco(func_name): def _generator_deco(func_name):
# FIXME: The signature of the wrapped function will be lost.
def _wrapper(self, *args, **kwargs): def _wrapper(self, *args, **kwargs):
for ins in self: for ele in self:
yield from getattr(ins, func_name)(*args, **kwargs) yield from getattr(ele, func_name)(*args, **kwargs)
return _wrapper return _wrapper
# Duck typing # Duck typing
class Duck(tuple): class Duck(Sequence, ABC):
__ducktype__ = object __ducktype__ = object
def __new__(cls, *args): def __init__(self, *args):
if any(not isinstance(a, cls.__ducktype__) for a in args): if any(not isinstance(arg, self.__ducktype__) for arg in args):
raise TypeError("please check the input type") raise TypeError("Please check the input type.")
return tuple.__new__(cls, args) self._seq = tuple(args)
def __getitem__(self, key):
return self._seq[key]
def __add__(self, tup): def __len__(self):
raise NotImplementedError return len(self._seq)
def __mul__(self, tup): def __repr__(self):
raise NotImplementedError return repr(self._seq)
class DuckMeta(type): class DuckMeta(ABCMeta):
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
assert len(bases) == 1 assert len(bases) == 1 # Multiple inheritance is not yet supported.
for k, v in getmembers(bases[0]): members = dict(getmembers(bases[0])) # Trade space for time
if k.startswith('__'):
continue for k in attrs['__ava__']:
if k in members:
v = members[k]
if isgeneratorfunction(v): if isgeneratorfunction(v):
attrs.setdefault(k, _generator_deco(k)) attrs.setdefault(k, _generator_deco(k))
elif isfunction(v): elif isfunction(v):
attrs.setdefault(k, _func_deco(k)) attrs.setdefault(k, _func_deco(k))
else: else:
attrs.setdefault(k, _Desc(k)) attrs.setdefault(k, _AttrDesc(k))
attrs['__ducktype__'] = bases[0] attrs['__ducktype__'] = bases[0]
return super().__new__(cls, name, (Duck,), attrs) return super().__new__(cls, name, (Duck,), attrs)
class DuckModel(nn.Module): class DuckModel(nn.Module):
__ava__ = ('state_dict', 'load_state_dict', 'forward', '__call__', 'train', 'eval', 'to', 'training')
def __init__(self, *models): def __init__(self, *models):
super().__init__() super().__init__()
## XXX: The state_dict will be a little larger in size # XXX: The state_dict will be a little larger in size,
# Since some extra bytes are stored in every key # since some extra bytes are stored in every key.
self._m = nn.ModuleList(models) self._m = nn.ModuleList(models)
def __len__(self): def __len__(self):
...@@ -83,27 +87,39 @@ class DuckModel(nn.Module): ...@@ -83,27 +87,39 @@ class DuckModel(nn.Module):
def __getitem__(self, idx): def __getitem__(self, idx):
return self._m[idx] return self._m[idx]
def __contains__(self, m):
return m in self._m
def __repr__(self): def __repr__(self):
return repr(self._m) return repr(self._m)
def forward(self, *args, **kwargs):
return tuple(m(*args, **kwargs) for m in self._m)
Duck.register(DuckModel)
class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta): class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
# Cuz this is an instance method __ava__ = ('param_groups', 'state_dict', 'load_state_dict', 'zero_grad', 'step')
# An instance attribute can not be automatically handled by metaclass
@property @property
def param_groups(self): def param_groups(self):
return list(chain.from_iterable(ins.param_groups for ins in self)) return list(chain.from_iterable(ele.param_groups for ele in self))
# This is special in dispatching # Sepcial dispatching rule
def load_state_dict(self, state_dicts): def load_state_dict(self, state_dicts):
for optim, state_dict in zip(self, state_dicts): for optim, state_dict in zip(self, state_dicts):
optim.load_state_dict(state_dict) optim.load_state_dict(state_dict)
class DuckCriterion(nn.Module, metaclass=DuckMeta): class DuckCriterion(nn.Module, metaclass=DuckMeta):
__ava__ = ('forward', '__call__', 'train', 'eval', 'to')
pass pass
class DuckDataset(data.Dataset, metaclass=DuckMeta): class DuckDataLoader(data.DataLoader, metaclass=DuckMeta):
__ava__ = ()
pass pass
...@@ -116,140 +132,45 @@ def _import_module(pkg: str, mod: str, rel=False): ...@@ -116,140 +132,45 @@ def _import_module(pkg: str, mod: str, rel=False):
def single_model_factory(model_name, C): def single_model_factory(model_name, C):
name = model_name.strip().upper() builder_name = '_'.join([model_name, C['model'], C['dataset'], 'model'])
if name == 'SIAMUNET_CONC': if builder_name in MODELS:
from models.siamunet_conc import SiamUnet_conc return MODELS[builder_name](C)
return SiamUnet_conc(C.num_feats_in, 2) builder_name = '_'.join([model_name, C['dataset'], 'model'])
elif name == 'SIAMUNET_DIFF': if builder_name in MODELS:
from models.siamunet_diff import SiamUnet_diff return MODELS[builder_name](C)
return SiamUnet_diff(C.num_feats_in, 2) builder_name = '_'.join([model_name, 'model'])
elif name == 'EF': if builder_name in MODELS:
from models.unet import Unet return MODELS[builder_name](C)
return Unet(C.num_feats_in, 2)
else: else:
raise NotImplementedError("{} is not a supported architecture".format(model_name)) raise NotImplementedError("{} is not a supported architecture.".format(model_name))
def single_optim_factory(optim_name, params, C): def single_optim_factory(optim_name, params, C):
optim_name = optim_name.strip() builder_name = '_'.join([optim_name, 'optim'])
name = optim_name.upper() if builder_name not in OPTIMS:
if name == 'ADAM': raise NotImplementedError("{} is not a supported optimizer type.".format(optim_name))
return torch.optim.Adam( return OPTIMS[builder_name](params, C)
params,
betas=(0.9, 0.999),
lr=C.lr,
weight_decay=C.weight_decay
)
elif name == 'SGD':
return torch.optim.SGD(
params,
lr=C.lr,
momentum=0.9,
weight_decay=C.weight_decay
)
else:
raise NotImplementedError("{} is not a supported optimizer type".format(optim_name))
def single_critn_factory(critn_name, C): def single_critn_factory(critn_name, C):
import losses builder_name = '_'.join([critn_name, 'critn'])
critn_name = critn_name.strip() if builder_name not in CRITNS:
try: raise NotImplementedError("{} is not a supported criterion type.".format(critn_name))
criterion, params = { return CRITNS[builder_name](C)
'L1': (nn.L1Loss, ()),
'MSE': (nn.MSELoss, ()),
'CE': (nn.CrossEntropyLoss, (torch.Tensor(C.weights),)), def single_data_factory(dataset_name, phase, C):
'NLL': (nn.NLLLoss, (torch.Tensor(C.weights),)) builder_name = '_'.join([dataset_name, C['dataset'], C['model'], phase, 'dataset'])
}[critn_name.upper()] if builder_name in DATA:
return criterion(*params) return DATA[builder_name](C)
except KeyError: builder_name = '_'.join([dataset_name, C['model'], phase, 'dataset'])
raise NotImplementedError("{} is not a supported criterion type".format(critn_name)) if builder_name in DATA:
return DATA[builder_name](C)
builder_name = '_'.join([dataset_name, phase, 'dataset'])
def _get_basic_configs(ds_name, C): if builder_name in DATA:
if ds_name == 'OSCD': return DATA[builder_name](C)
return dict(
root = constants.IMDB_OSCD
)
elif ds_name.startswith('AC'):
return dict(
root = constants.IMDB_AIRCHANGE
)
elif ds_name == 'Lebedev':
return dict(
root = constants.IMDB_LEBEDEV
)
else: else:
return dict() raise NotImplementedError("{} is not a supported dataset.".format(dataset_name))
def single_train_ds_factory(ds_name, C):
ds_name = ds_name.strip()
module = _import_module('data', ds_name)
dataset = getattr(module, ds_name+'Dataset')
configs = dict(
phase='train',
transforms=(Compose(Crop(C.crop_size), Flip()), None, None),
repeats=C.repeats
)
# Update some common configurations
configs.update(_get_basic_configs(ds_name, C))
# Set phase-specific ones
if ds_name == 'Lebedev':
configs.update(
dict(
subsets = ('real',)
)
)
else:
pass
dataset_obj = dataset(**configs)
return data.DataLoader(
dataset_obj,
batch_size=C.batch_size,
shuffle=True,
num_workers=C.num_workers,
pin_memory=not (C.device == 'cpu'), drop_last=True
)
def single_val_ds_factory(ds_name, C):
ds_name = ds_name.strip()
module = _import_module('data', ds_name)
dataset = getattr(module, ds_name+'Dataset')
configs = dict(
phase='val',
transforms=(None, None, None),
repeats=1
)
# Update some common configurations
configs.update(_get_basic_configs(ds_name, C))
# Set phase-specific ones
if ds_name == 'Lebedev':
configs.update(
dict(
subsets = ('real',)
)
)
else:
pass
dataset_obj = dataset(**configs)
# Create eval set
return data.DataLoader(
dataset_obj,
batch_size=1,
shuffle=False,
num_workers=1,
pin_memory=False, drop_last=False
)
def _parse_input_names(name_str): def _parse_input_names(name_str):
...@@ -268,7 +189,7 @@ def optim_factory(optim_names, models, C): ...@@ -268,7 +189,7 @@ def optim_factory(optim_names, models, C):
name_list = _parse_input_names(optim_names) name_list = _parse_input_names(optim_names)
num_models = len(models) if isinstance(models, DuckModel) else 1 num_models = len(models) if isinstance(models, DuckModel) else 1
if len(name_list) != num_models: if len(name_list) != num_models:
raise ValueError("the number of optimizers does not match the number of models") raise ValueError("The number of optimizers does not match the number of models.")
if num_models > 1: if num_models > 1:
optims = [] optims = []
...@@ -298,16 +219,7 @@ def critn_factory(critn_names, C): ...@@ -298,16 +219,7 @@ def critn_factory(critn_names, C):
def data_factory(dataset_names, phase, C): def data_factory(dataset_names, phase, C):
name_list = _parse_input_names(dataset_names) name_list = _parse_input_names(dataset_names)
if phase not in ('train', 'val'):
raise ValueError("phase should be either 'train' or 'val'")
fact = globals()['single_'+phase+'_ds_factory']
if len(name_list) > 1: if len(name_list) > 1:
return DuckDataset(*(fact(name, C) for name in name_list)) return DuckDataLoader(*(single_data_factory(name, phase, C) for name in name_list))
else: else:
return fact(dataset_names, C) return single_data_factory(dataset_names, phase, C)
\ No newline at end of file
def metric_factory(metric_names, C):
from utils import metrics
name_list = _parse_input_names(metric_names)
return [getattr(metrics, name.strip())() for name in name_list]
import logging import logging
import os import os
import os.path as osp
import sys import sys
from time import localtime from time import localtime
from collections import OrderedDict from collections import OrderedDict, deque
from weakref import proxy from weakref import proxy
FORMAT_LONG = "[%(asctime)-15s %(funcName)s] %(message)s" FORMAT_LONG = "[%(asctime)-15s %(funcName)s] %(message)s"
FORMAT_SHORT = "%(message)s" FORMAT_SHORT = "%(message)s"
...@@ -16,6 +18,7 @@ class _LessThanFilter(logging.Filter): ...@@ -16,6 +18,7 @@ class _LessThanFilter(logging.Filter):
def filter(self, record): def filter(self, record):
return record.levelno < self.max_level return record.levelno < self.max_level
class Logger: class Logger:
_count = 0 _count = 0
...@@ -38,11 +41,11 @@ class Logger: ...@@ -38,11 +41,11 @@ class Logger:
self._logger.addHandler(self._scrn_handler) self._logger.addHandler(self._scrn_handler)
if log_dir and phase: if log_dir and phase:
self.log_path = os.path.join(log_dir, self.log_path = osp.join(log_dir,
'{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}.log'.format( "{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}.log".format(
phase, *localtime()[:6] phase, *localtime()[:6]
)) ))
self.show_nl("log into {}\n\n".format(self.log_path)) self.show_nl("Log into {}\n\n".format(self.log_path))
self._file_handler = logging.FileHandler(filename=self.log_path) self._file_handler = logging.FileHandler(filename=self.log_path)
self._file_handler.setLevel(logging.DEBUG) self._file_handler.setLevel(logging.DEBUG)
self._file_handler.setFormatter(logging.Formatter(fmt=FORMAT_LONG)) self._file_handler.setFormatter(logging.Formatter(fmt=FORMAT_LONG))
...@@ -58,7 +61,7 @@ class Logger: ...@@ -58,7 +61,7 @@ class Logger:
def dump(self, *args, **kwargs): def dump(self, *args, **kwargs):
return self._logger.debug(*args, **kwargs) return self._logger.debug(*args, **kwargs)
def warning(self, *args, **kwargs): def warn(self, *args, **kwargs):
return self._logger.warning(*args, **kwargs) return self._logger.warning(*args, **kwargs)
def error(self, *args, **kwargs): def error(self, *args, **kwargs):
...@@ -67,16 +70,7 @@ class Logger: ...@@ -67,16 +70,7 @@ class Logger:
def fatal(self, *args, **kwargs): def fatal(self, *args, **kwargs):
return self._logger.critical(*args, **kwargs) return self._logger.critical(*args, **kwargs)
@staticmethod _logger = Logger()
def make_desc(counter, total, *triples, opt_str=''):
desc = "[{}/{}] {}".format(counter, total, opt_str)
# The three elements of each triple are
# (name to display, AverageMeter object, formatting string)
for name, obj, fmt in triples:
desc += (" {} {obj.val:"+fmt+"} ({obj.avg:"+fmt+"})").format(name, obj=obj)
return desc
_default_logger = Logger()
class _WeakAttribute: class _WeakAttribute:
...@@ -91,12 +85,12 @@ class _WeakAttribute: ...@@ -91,12 +85,12 @@ class _WeakAttribute:
class _TreeNode: class _TreeNode:
_sep = '/'
_none = None
parent = _WeakAttribute() # To avoid circular reference parent = _WeakAttribute() # To avoid circular reference
def __init__(self, name, value=None, parent=None, children=None): def __init__(
self, name, value=None, parent=None, children=None,
sep='/', none_val=None
):
super().__init__() super().__init__()
self.name = name self.name = name
self.val = value self.val = value
...@@ -106,46 +100,42 @@ class _TreeNode: ...@@ -106,46 +100,42 @@ class _TreeNode:
for child in children: for child in children:
self._add_child(child) self._add_child(child)
self.path = name self.path = name
self._sep = sep
self._none = none_val
def get_child(self, name, def_val=None): def get_child(self, name):
return self.children.get(name, def_val) return self.children.get(name, None)
def set_child(self, name, val=None):
r"""
Set the value of an existing node.
If the node does not exist, return nothing
"""
child = self.get_child(name)
if child is not None:
child.val = val
return child
def add_place_holder(self, name): def add_placeholder(self, name):
return self.add_child(name, val=self._none) return self.add_child(name, value=self._none)
def add_child(self, name, val): def add_child(self, name, value, warning=False):
r""" r"""
If not exists or is a placeholder, create it If node does not exist or is a placeholder, create it,
Otherwise skips and returns the existing node otherwise skip and return the existing node.
""" """
child = self.get_child(name, None) child = self.get_child(name)
if child is None: if child is None:
child = _TreeNode(name, val, parent=self) child = _TreeNode(name, value, parent=self, sep=self._sep, none_val=self._none)
self._add_child(child) self._add_child(child)
elif child.val == self._none: elif child.is_placeholder():
# Retain the links of the placeholder # Retain the links of a placeholder,
# i.e. just fill in it # i.e. just fill in it.
child.val = val child.val = value
else:
if warning:
_logger.warn("Node already exists!")
return child return child
def is_leaf(self): def is_leaf(self):
return len(self.children) == 0 return len(self.children) == 0
def is_placeholder(self):
return self.val == self._none
def __repr__(self): def __repr__(self):
try: try:
repr = self.path + ' ' + str(self.val) repr = self.path + " " + str(self.val)
except TypeError: except TypeError:
repr = self.path repr = self.path
return repr return repr
...@@ -157,7 +147,10 @@ class _TreeNode: ...@@ -157,7 +147,10 @@ class _TreeNode:
return self.get_child(key) return self.get_child(key)
def _add_child(self, node): def _add_child(self, node):
r""" Into children dictionary and set path and parent """ r"""
Add a child node into self.children.
If the node already exists, just update its information.
"""
self.children.update({ self.children.update({
node.name: node node.name: node
}) })
...@@ -166,8 +159,8 @@ class _TreeNode: ...@@ -166,8 +159,8 @@ class _TreeNode:
def apply(self, func): def apply(self, func):
r""" r"""
Apply a callback function on ALL descendants Apply a callback function to ALL descendants.
This is useful for the recursive traversal This is useful for recursive traversal.
""" """
ret = [func(self)] ret = [func(self)]
for _, node in self.children.items(): for _, node in self.children.items():
...@@ -175,69 +168,53 @@ class _TreeNode: ...@@ -175,69 +168,53 @@ class _TreeNode:
return ret return ret
def bfs_tracker(self): def bfs_tracker(self):
queue = [] queue = deque()
queue.insert(0, self) queue.append(self)
while(queue): while(queue):
curr = queue.pop() curr = queue.popleft()
yield curr yield curr
if curr.is_leaf(): if curr.is_leaf():
continue continue
for c in curr.children.values(): for c in curr.children.values():
queue.insert(0, c) queue.append(c)
class _Tree: class _Tree:
def __init__( def __init__(
self, name, value=None, strc_ele=None, self, name, value=None, eles=None,
sep=_TreeNode._sep, def_val=_TreeNode._none sep='/', none_val=None
): ):
super().__init__() super().__init__()
self._sep = sep self._sep = sep
self._def_val = def_val self._none = none_val
self.root = _TreeNode(name, value, parent=None, children={}) self.root = _TreeNode(name, value, parent=None, children={}, sep=self._sep, none_val=self._none)
if strc_ele is not None: if eles is not None:
assert isinstance(strc_ele, dict) assert isinstance(eles, dict)
# This is to avoid mutable parameter default self.build_tree(OrderedDict(eles or {}))
self.build_tree(OrderedDict(strc_ele or {}))
def build_tree(self, elements): def build_tree(self, elements):
# The siblings could be out-of-order # The order of the siblings is not retained
for path, ele in elements.items(): for path, ele in elements.items():
self.add_node(path, ele) self.add_node(path, ele)
def get_root(self):
r""" Get separated root node """
return _TreeNode(
self.root.name, self.root.value,
parent=None, children=None
)
def __repr__(self): def __repr__(self):
return self.__dumps__() _str = ""
def __dumps__(self):
r""" Dump to string """
_str = ''
# DFS # DFS
stack = [] stack = []
stack.append((self.root, 0)) stack.append((self.root, 0))
while(stack): while(stack):
root, layer = stack.pop() root, layer = stack.pop()
_str += ' '*layer + '-' + root.__repr__() + '\n' _str += " "*layer + "-" + root.__repr__() + "\n"
if root.is_leaf(): if root.is_leaf():
continue continue
# Note that the order of the siblings is not retained # Note that the siblings are printed in alphabetical order.
for c in reversed(list(root.children.values())): for c in sorted(list(root.children.values()), key=lambda n: n.name, reverse=True):
stack.append((c, layer+1)) stack.append((c, layer+1))
return _str return _str
def vis(self):
r""" Visualize the structure of the tree """
_default_logger.show(self.__dumps__())
def __contains__(self, obj): def __contains__(self, obj):
return any(self.perform(lambda node: obj in node)) return any(self.perform(lambda node: obj in node))
...@@ -246,14 +223,15 @@ class _Tree: ...@@ -246,14 +223,15 @@ class _Tree:
def get_node(self, tar, mode='name'): def get_node(self, tar, mode='name'):
r""" r"""
This is different from the travasal in that This is different from a travasal in that this search allows early stop.
the search allows early stop
""" """
assert mode in ('name', 'path', 'val')
if mode == 'path': if mode == 'path':
nodes = self.parse_path(tar) nodes = self.parse_path(tar)
root = self.root root = self.root
for r in nodes: for r in nodes:
if root is None: if root is None:
break
root = root.get_child(r) root = root.get_child(r)
return root return root
else: else:
...@@ -264,28 +242,20 @@ class _Tree: ...@@ -264,28 +242,20 @@ class _Tree:
for node in bfs_tracker: for node in bfs_tracker:
if getattr(node, mode) == tar: if getattr(node, mode) == tar:
return node return node
return return None
def set_node(self, path, val):
node = self.get_node(path, mode=path)
if node is not None:
node.val = val
return node
def add_node(self, path, val=None): def add_node(self, path, val):
if not path.strip(): if not path.strip():
raise ValueError("the path is null") raise ValueError("The path is null.")
path = path.strip('/') path = path.rstrip(self._sep)
if val is None:
val = self._def_val
names = self.parse_path(path) names = self.parse_path(path)
root = self.root root = self.root
nodes = [root] nodes = [root]
for name in names[:-1]: for name in names[:-1]:
# Add placeholders # Add a placeholder or skip an existing node
root = root.add_child(name, self._def_val) root = root.add_placeholder(name)
nodes.append(root) nodes.append(root)
root = root.add_child(names[-1], val) root = root.add_child(names[-1], val, True)
return root, nodes return root, nodes
def parse_path(self, path): def parse_path(self, path):
...@@ -296,22 +266,29 @@ class _Tree: ...@@ -296,22 +266,29 @@ class _Tree:
class OutPathGetter: class OutPathGetter:
def __init__(self, root='', log='logs', out='outs', weight='weights', suffix='', **subs): def __init__(self, root='', log='logs', out='out', weight='weights', suffix='', **subs):
super().__init__() super().__init__()
self._root = root.rstrip('/') # Work robustly for multiple ending '/'s self._root = root.rstrip(os.sep) # Work robustly on multiple ending '/'s
if len(self._root) == 0 and len(root) > 0: if len(self._root) == 0 and len(root) > 0:
self._root = '/' # In case of the system root dir self._root = os.sep # In case of the system root dir in linux
self._suffix = suffix self._suffix = suffix
self._keys = dict(log=log, out=out, weight=weight, **subs) self._keys = dict(log=log, out=out, weight=weight, **subs)
for k, v in self._keys.items():
v_ = v.rstrip(os.sep)
if len(v_) == 0 or not self.check_path(v_):
_logger.warn("{} is not a valid path.".format(v))
continue
self._keys[k] = v_
self._dir_tree = _Tree( self._dir_tree = _Tree(
self._root, 'root', self._root, 'root',
strc_ele=dict(zip(self._keys.values(), self._keys.keys())), eles=dict(zip(self._keys.values(), self._keys.keys())),
sep='/', sep=os.sep, none_val=''
def_val=''
) )
self.update_keys(False) self.add_keys(False)
self.update_tree(False) self.update_vfs(False)
self.__counter = 0 self.__counter = 0
...@@ -326,89 +303,109 @@ class OutPathGetter: ...@@ -326,89 +303,109 @@ class OutPathGetter:
def root(self): def root(self):
return self._root return self._root
def _update_key(self, key, val, add=False, prefix=False): def _add_key(self, key, val):
if prefix:
val = os.path.join(self._root, val)
if add:
# Do not edit if exists
self._keys.setdefault(key, val) self._keys.setdefault(key, val)
else:
self._keys.__setitem__(key, val)
def _add_node(self, key, val, prefix=False):
if not prefix and key.startswith(self._root):
key = key[len(self._root)+1:]
return self._dir_tree.add_node(key, val)
def update_keys(self, verbose=False): def add_keys(self, verbose=False):
for k, v in self._keys.items(): for k, v in self._keys.items():
self._update_key(k, v, prefix=True) self._add_key(k, v)
if verbose: if verbose:
_default_logger.show(self._keys) _logger.show(self._keys)
def update_tree(self, verbose=False): def update_vfs(self, verbose=False):
self._dir_tree.perform(lambda x: self.make_dir(x.path)) self._dir_tree.perform(lambda x: self.make_dir(x.path))
if verbose: if verbose:
_default_logger.show("\nFolder structure:") _logger.show("\nFolder structure:")
_default_logger.show(self._dir_tree) _logger.show(self._dir_tree)
@staticmethod
def check_path(path):
# This is to prevent stuff like A/../B or A/./.././C.d
# Note that paths like A.B/.C/D are not supported, either.
return osp.dirname(path).find('.') == -1
@staticmethod @staticmethod
def make_dir(path): def make_dir(path):
if not os.path.exists(path): if not osp.exists(path):
os.mkdir(path) os.mkdir(path)
elif not osp.isdir(path):
raise RuntimeError("Cannot create directory.")
def get_dir(self, key): def get_dir(self, key):
return self._keys.get(key, '') if key != 'root' else self.root return osp.join(self.root, self._keys[key])
def get_path( def get_path(
self, key, file, self, key, file,
name='', auto_make=False, name='', auto_make=False,
suffix=True, underline=False suffix=False, underline=True
): ):
folder = self.get_dir(key) if len(file) == 0:
if len(folder) < 1: return self.get_dir(key)
raise KeyError("key not found") if not self.check_path(file):
raise ValueError("{} is not a valid path.".format(file))
folder = self._keys[key]
if suffix: if suffix:
path = os.path.join(folder, self.add_suffix(file, underline=underline)) path = osp.join(folder, self._add_suffix(file, underline=underline))
else: else:
path = os.path.join(folder, file) path = osp.join(folder, file)
if auto_make: if auto_make:
base_dir = os.path.dirname(path) base_dir = osp.dirname(path)
# O(n) search for base_dir
# Never update an existing key!
if base_dir in self: if base_dir in self:
return path _logger.warn("Cannot assign a new key to an existing path!")
if name: return osp.join(self.root, path)
self._update_key(name, base_dir, add=True) node = self._dir_tree.get_node(base_dir, mode='path')
'''
else: # Note that if name is an empty string,
name = 'new_{:03d}'.format(self.__counter) # the directory tree will be updated, but the name will not be added into self._keys.
self._update_key(name, base_dir, add=True) if node is None or node.is_placeholder():
self.__counter += 1 # Update directory tree
''' des, visit = self._dir_tree.add_node(base_dir, name)
des, visit = self._add_node(base_dir, name)
# Create directories along the visiting path # Create directories along the visiting path
for d in visit: self.make_dir(d.path) for d in visit: self.make_dir(d.path)
self.make_dir(des.path) self.make_dir(des.path)
return path else:
node.val = name
if len(name) > 0:
# Add new key
self._add_key(name, base_dir)
return osp.join(self.root, path)
def add_suffix(self, path, suffix='', underline=False): def _add_suffix(self, path, suffix='', underline=False):
pos = path.rfind('.') pos = path.rfind('.')
if pos == -1: if pos == -1:
pos = len(path) pos = len(path)
_suffix = self._suffix if len(suffix) < 1 else suffix _suffix = self._suffix if len(suffix) == 0 else suffix
return path[:pos] + ('_' if underline and _suffix else '') + _suffix + path[pos:] return path[:pos] + ('_' if underline and _suffix else '') + _suffix + path[pos:]
def __contains__(self, value): def __contains__(self, value):
return value in self._keys.values() return value in self._keys.values() or value == self._root
def contains_key(self, key):
return key in self._keys
class Registry(dict): class Registry(dict):
def register(self, key, val): def register(self, key, val):
if key in self: _default_logger.warning("key {} already registered".format(key)) if key in self: _logger.warn("Key {} has already been registered!".format(key))
self[key] = val self[key] = val
def register_func(self, key):
def _wrapper(func):
self.register(key, func)
return func
return _wrapper
# Registry for global objects
R = Registry() R = Registry()
R.register('DEFAULT_LOGGER', _default_logger) R.register('Logger', _logger)
register = R.register register = R.register
# Registries for builders
MODELS = Registry()
OPTIMS = Registry()
CRITNS = Registry()
DATA = Registry()
\ No newline at end of file
...@@ -2,75 +2,81 @@ import shutil ...@@ -2,75 +2,81 @@ import shutil
import os import os
from types import MappingProxyType from types import MappingProxyType
from copy import deepcopy from copy import deepcopy
from abc import ABCMeta, abstractmethod
import torch import torch
from skimage import io
from tqdm import tqdm
import constants import constants
from data.common import to_array from .misc import Logger, OutPathGetter, R
from utils.misc import R from .factories import (model_factory, optim_factory, critn_factory, data_factory)
from utils.metrics import AverageMeter
from utils.utils import mod_crop
from .factories import (model_factory, optim_factory, critn_factory, data_factory, metric_factory)
class Trainer: class Trainer(metaclass=ABCMeta):
def __init__(self, model, dataset, criterion, optimizer, settings): def __init__(self, model, dataset, criterion, optimizer, settings):
super().__init__() super().__init__()
# Make a copy of settings in case of unexpected changes
context = deepcopy(settings) context = deepcopy(settings)
self.ctx = MappingProxyType(vars(context)) # self.ctx is a proxy so that context will be read-only outside __init__
self.mode = ('train', 'val').index(context.cmd) self.ctx = MappingProxyType(context)
self.mode = ('train', 'eval').index(context['cmd'])
self.logger = R['LOGGER'] self.debug = context['debug_on']
self.gpc = R['GPC'] # Global Path Controller self.log = not context['log_off']
self.batch_size = context['batch_size']
self.checkpoint = context['resume']
self.load_checkpoint = (len(self.checkpoint)>0)
self.num_epochs = context['num_epochs']
self.lr = float(context['lr'])
self.track_intvl = int(context['track_intvl'])
self.device = torch.device(context['device'])
self.gpc = OutPathGetter(
root=os.path.join(context['exp_dir'], context['tag']),
suffix=context['suffix']
) # Global Path Controller
self.logger = Logger(
scrn=True,
log_dir=self.gpc.get_dir('log') if self.log else '',
phase=context['cmd']
)
self.path = self.gpc.get_path self.path = self.gpc.get_path
self.batch_size = context.batch_size for k, v in sorted(context.items()):
self.checkpoint = context.resume
self.load_checkpoint = (len(self.checkpoint)>0)
self.num_epochs = context.num_epochs
self.lr = float(context.lr)
self.save = context.save_on or context.out_dir
self.out_dir = context.out_dir
self.track_intvl = int(context.track_intvl)
self.device = torch.device(context.device)
self.suffix_off = context.suffix_off
for k, v in sorted(self.ctx.items()):
self.logger.show("{}: {}".format(k,v)) self.logger.show("{}: {}".format(k,v))
self.model = model_factory(model, context) self.model = model_factory(model, context)
self.model.to(self.device) self.model.to(self.device)
self.criterion = critn_factory(criterion, context) self.criterion = critn_factory(criterion, context)
self.criterion.to(self.device) self.criterion.to(self.device)
self.metrics = metric_factory(context.metrics, context)
if self.is_training: if self.is_training:
self.train_loader = data_factory(dataset, 'train', context) self.train_loader = data_factory(dataset, 'train', context)
self.val_loader = data_factory(dataset, 'val', context) self.eval_loader = data_factory(dataset, 'eval', context)
self.optimizer = optim_factory(optimizer, self.model, context) self.optimizer = optim_factory(optimizer, self.model, context)
else: else:
self.val_loader = data_factory(dataset, 'val', context) self.eval_loader = data_factory(dataset, 'eval', context)
self.start_epoch = 0 self.start_epoch = 0
self._init_max_acc_and_epoch = (0.0, 0) self._init_acc_epoch = (0.0, -1)
@property @property
def is_training(self): def is_training(self):
return self.mode == 0 return self.mode == 0
@abstractmethod
def train_epoch(self, epoch): def train_epoch(self, epoch):
raise NotImplementedError pass
def validate_epoch(self, epoch=0, store=False): @abstractmethod
raise NotImplementedError def evaluate_epoch(self, epoch):
return 0.0
def _write_prompt(self): def _write_prompt(self):
self.logger.dump(input("\nWrite some notes: ")) self.logger.dump(input("\nWrite some notes: "))
def run(self): def run(self):
if self.is_training: if self.is_training:
if self.log and not self.debug:
self._write_prompt() self._write_prompt()
self.train() self.train()
else: else:
...@@ -80,23 +86,20 @@ class Trainer: ...@@ -80,23 +86,20 @@ class Trainer:
if self.load_checkpoint: if self.load_checkpoint:
self._resume_from_checkpoint() self._resume_from_checkpoint()
max_acc, best_epoch = self._init_max_acc_and_epoch max_acc, best_epoch = self._init_acc_epoch
lr = self.init_learning_rate()
for epoch in range(self.start_epoch, self.num_epochs): for epoch in range(self.start_epoch, self.num_epochs):
lr = self._adjust_learning_rate(epoch)
self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr)) self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr))
# Train for one epoch # Train for one epoch
self.model.train()
self.train_epoch(epoch) self.train_epoch(epoch)
# Clear the history of metric objects # Evaluate the model
for m in self.metrics: self.logger.show_nl("Evaluate")
m.reset() self.model.eval()
acc = self.evaluate_epoch(epoch=epoch)
# Evaluate the model on validation set
self.logger.show_nl("Validate")
acc = self.validate_epoch(epoch=epoch, store=self.save)
is_best = acc > max_acc is_best = acc > max_acc
if is_best: if is_best:
...@@ -105,77 +108,74 @@ class Trainer: ...@@ -105,77 +108,74 @@ class Trainer:
self.logger.show_nl("Current: {:.6f} ({:03d})\tBest: {:.6f} ({:03d})\t".format( self.logger.show_nl("Current: {:.6f} ({:03d})\tBest: {:.6f} ({:03d})\t".format(
acc, epoch, max_acc, best_epoch)) acc, epoch, max_acc, best_epoch))
# The checkpoint saves next epoch # Do not save checkpoints in debugging mode
if not self.debug:
self._save_checkpoint( self._save_checkpoint(
self.model.state_dict(), self.model.state_dict(),
self.optimizer.state_dict() if self.ctx['save_optim'] else {}, self.optimizer.state_dict() if self.ctx['save_optim'] else {},
(max_acc, best_epoch), epoch+1, is_best (max_acc, best_epoch), epoch, is_best
) )
lr = self.adjust_learning_rate(epoch, acc)
def evaluate(self): def evaluate(self):
if self.checkpoint: if self.checkpoint:
if self._resume_from_checkpoint(): if self._resume_from_checkpoint():
self.validate_epoch(self.ckp_epoch, self.save) self.model.eval()
else: self.evaluate_epoch(self.start_epoch)
self.logger.warning("Warning: no checkpoint assigned!")
def _adjust_learning_rate(self, epoch):
if self.ctx['lr_mode'] == 'step':
lr = self.lr * (0.5 ** (epoch // self.ctx['step']))
elif self.ctx['lr_mode'] == 'poly':
lr = self.lr * (1 - epoch / self.num_epochs) ** 1.1
elif self.ctx['lr_mode'] == 'const':
lr = self.lr
else: else:
raise ValueError('unknown lr mode {}'.format(self.ctx['lr_mode'])) self.logger.error("No checkpoint assigned!")
def init_learning_rate(self):
return self.lr
for param_group in self.optimizer.param_groups: def adjust_learning_rate(self, epoch, acc):
param_group['lr'] = lr return self.lr
return lr
def _resume_from_checkpoint(self): def _resume_from_checkpoint(self):
## XXX: This could be slow! # XXX: This could be slow!
if not os.path.isfile(self.checkpoint): if not os.path.isfile(self.checkpoint):
self.logger.error("=> No checkpoint was found at '{}'.".format(self.checkpoint)) self.logger.error("=> No checkpoint was found at '{}'.".format(self.checkpoint))
return False return False
self.logger.show("=> Loading checkpoint '{}'".format( self.logger.show("=> Loading checkpoint '{}'...".format(self.checkpoint))
self.checkpoint))
checkpoint = torch.load(self.checkpoint, map_location=self.device) checkpoint = torch.load(self.checkpoint, map_location=self.device)
state_dict = self.model.state_dict() state_dict = self.model.state_dict()
ckp_dict = checkpoint.get('state_dict', checkpoint) ckp_dict = checkpoint.get('state_dict', checkpoint)
update_dict = {k:v for k,v in ckp_dict.items() update_dict = {
if k in state_dict and state_dict[k].shape == v.shape} k:v for k,v in ckp_dict.items()
if k in state_dict and state_dict[k].shape == v.shape and state_dict[k].dtype == v.dtype
}
num_to_update = len(update_dict) num_to_update = len(update_dict)
if (num_to_update < len(state_dict)) or (len(state_dict) < len(ckp_dict)): if (num_to_update < len(state_dict)) or (len(state_dict) < len(ckp_dict)):
if not self.is_training and (num_to_update < len(state_dict)): if not self.is_training and (num_to_update < len(state_dict)):
self.logger.error("=> Mismatched checkpoint for evaluation") self.logger.error("=> Mismatched checkpoint for evaluation")
return False return False
self.logger.warning("Warning: trying to load an mismatched checkpoint.") self.logger.warn("Trying to load a mismatched checkpoint.")
if num_to_update == 0: if num_to_update == 0:
self.logger.error("=> No parameter is to be loaded.") self.logger.error("=> No parameter is to be loaded.")
return False return False
else: else:
self.logger.warning("=> {} params are to be loaded.".format(num_to_update)) self.logger.warn("=> {} params are to be loaded.".format(num_to_update))
elif (not self.ctx['anew']) or not self.is_training: elif not self.ctx['anew'] or not self.is_training:
self.start_epoch = checkpoint.get('epoch', 0) ckp_epoch = checkpoint.get('epoch', -1)
max_acc_and_epoch = checkpoint.get('max_acc', (0.0, self.ckp_epoch)) self.start_epoch = ckp_epoch+1
# For backward compatibility self._init_acc_epoch = checkpoint.get('max_acc', (0.0, ckp_epoch))
if isinstance(max_acc_and_epoch, (float, int)):
self._init_max_acc_and_epoch = (max_acc_and_epoch, self.ckp_epoch)
else:
self._init_max_acc_and_epoch = max_acc_and_epoch
if self.ctx['load_optim'] and self.is_training: if self.ctx['load_optim'] and self.is_training:
# Note that weight decay might be modified here # XXX: Note that weight decay might be modified here.
self.optimizer.load_state_dict(checkpoint['optimizer']) self.optimizer.load_state_dict(checkpoint['optimizer'])
self.logger.warn("Weight decay might have been modified.")
state_dict.update(update_dict) state_dict.update(update_dict)
self.model.load_state_dict(state_dict) self.model.load_state_dict(state_dict)
self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {})".format( if self.start_epoch == 0:
self.checkpoint, self.ckp_epoch, *self._init_max_acc_and_epoch self.logger.show("=> Loaded checkpoint '{}'".format(self.checkpoint))
else:
self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {}).".format(
self.checkpoint, self.start_epoch-1, *self._init_acc_epoch
)) ))
return True return True
...@@ -187,117 +187,46 @@ class Trainer: ...@@ -187,117 +187,46 @@ class Trainer:
'max_acc': max_acc 'max_acc': max_acc
} }
# Save history # Save history
history_path = self.path('weight', constants.CKP_COUNTED.format(e=epoch), underline=True) # epoch+1 instead of epoch is contained in the checkpoint name so that it will be easy for
# one to recognize "the next start_epoch".
history_path = self.path(
'weight', constants.CKP_COUNTED.format(e=epoch+1),
suffix=True
)
if epoch % self.track_intvl == 0: if epoch % self.track_intvl == 0:
torch.save(state, history_path) torch.save(state, history_path)
# Save latest # Save latest
latest_path = self.path( latest_path = self.path(
'weight', constants.CKP_LATEST, 'weight', constants.CKP_LATEST,
underline=True suffix=True
) )
torch.save(state, latest_path) torch.save(state, latest_path)
if is_best: if is_best:
shutil.copyfile( shutil.copyfile(
latest_path, self.path( latest_path, self.path(
'weight', constants.CKP_BEST, 'weight', constants.CKP_BEST,
underline=True suffix=True
) )
) )
@property
def ckp_epoch(self):
# Get current epoch of the checkpoint
# For dismatched ckp or no ckp, set to 0
return max(self.start_epoch-1, 0)
def save_image(self, file_name, image, epoch):
file_path = os.path.join(
'epoch_{}/'.format(epoch),
self.out_dir,
file_name
)
out_path = self.path(
'out', file_path,
suffix=not self.suffix_off,
auto_make=True,
underline=True
)
return io.imsave(out_path, image)
class CDTrainer(Trainer):
def __init__(self, arch, dataset, optimizer, settings):
super().__init__(arch, dataset, 'NLL', optimizer, settings)
def train_epoch(self, epoch):
losses = AverageMeter()
len_train = len(self.train_loader)
pb = tqdm(self.train_loader)
self.model.train()
for i, (t1, t2, label) in enumerate(pb):
t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
prob = self.model(t1, t2)
loss = self.criterion(prob, label) class TrainerSwitcher:
r"""A simple utility class to help dispatch actions to different trainers."""
def __init__(self, *pairs):
self._trainer_list = list(pairs)
losses.update(loss.item(), n=self.batch_size) def __call__(self, args, return_obj=True):
for p, t in self._trainer_list:
if p(args):
return t(args) if return_obj else t
return None
# Compute gradients and do SGD step def add_item(self, predicate, trainer):
self.optimizer.zero_grad() # Newly added items have higher priority
loss.backward() self._trainer_list.insert(0, (predicate, trainer))
self.optimizer.step()
desc = self.logger.make_desc( def add_default(self, trainer):
i+1, len_train, self._trainer_list.append((lambda: True, trainer))
('loss', losses, '.4f')
)
pb.set_description(desc)
self.logger.dump(desc)
def validate_epoch(self, epoch=0, store=False):
self.logger.show_nl("Epoch: [{0}]".format(epoch))
losses = AverageMeter()
len_val = len(self.val_loader)
pb = tqdm(self.val_loader)
self.model.eval()
with torch.no_grad():
for i, (name, t1, t2, label) in enumerate(pb):
if self.is_training and i >= 16:
# Do not validate all images on training phase
pb.close()
self.logger.warning("validation ends early")
break
t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
prob = self.model(t1, t2)
loss = self.criterion(prob, label)
losses.update(loss.item(), n=self.batch_size)
# Convert to numpy arrays
CM = to_array(torch.argmax(prob[0], 0)).astype('uint8')
label = to_array(label[0]).astype('uint8')
for m in self.metrics:
m.update(CM, label)
desc = self.logger.make_desc(
i+1, len_val,
('loss', losses, '.4f'),
*(
(m.__name__, m, '.4f')
for m in self.metrics
)
)
pb.set_description(desc)
self.logger.dump(desc)
if store:
self.save_image(name[0], CM*255, epoch)
return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc) R.register('Trainer_switcher', TrainerSwitcher())
\ No newline at end of file \ No newline at end of file
from glob import glob
from os.path import join, basename
import numpy as np
from . import CDDataset
from .common import default_loader
class LebedevDataset(CDDataset):
def __init__(
self,
root, phase='train',
transforms=(None, None, None),
repeats=1,
subsets=('real', 'with_shift', 'without_shift')
):
self.subsets = subsets
super().__init__(root, phase, transforms, repeats)
def _read_file_paths(self):
t1_list, t2_list, label_list = [], [], []
for subset in self.subsets:
# Get subset directory
if subset == 'real':
subset_dir = join(self.root, 'Real', 'subset')
elif subset == 'with_shift':
subset_dir = join(self.root, 'Model', 'with_shift')
elif subset == 'without_shift':
subset_dir = join(self.root, 'Model', 'without_shift')
else:
raise RuntimeError('unrecognized key encountered')
pattern = '*.bmp' if (subset == 'with_shift' and self.phase in ('test', 'val')) else '*.jpg'
refs = sorted(glob(join(subset_dir, self.phase, 'OUT', pattern)))
t1s = (join(subset_dir, self.phase, 'A', basename(ref)) for ref in refs)
t2s = (join(subset_dir, self.phase, 'B', basename(ref)) for ref in refs)
label_list.extend(refs)
t1_list.extend(t1s)
t2_list.extend(t2s)
return t1_list, t2_list, label_list
def fetch_label(self, label_path):
# To {0,1}
return (super().fetch_label(label_path) > 127).astype(np.uint8)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment