Commit eb48f74c authored by Bobholamovic

Update framework

parent d6d3faca
Part of merge request !2: Update outdated code
Showing changed files with 790 additions and 951 deletions
@@ -8,17 +8,18 @@ This is an unofficial implementation of the paper
[paper link](https://ieeexplore.ieee.org/abstract/document/8451652)
# Prerequisites
# Dependencies
opencv-python==4.1.1
pytorch==1.2.0
pytorch==1.3.1
torchvision==0.4.2
pyyaml==5.1.2
scikit-image==0.15.0
scikit-learn==0.21.3
scipy==1.3.1
tqdm==4.35.0
Tested on Python 3.7.4, Ubuntu 16.04 and Python 3.6.8, Windows 10.
Tested using Python 3.7.4 on Ubuntu 16.04 and Python 3.6.8 on Windows 10.
# Basic usage
@@ -30,84 +31,25 @@ mkdir exp
cd src
```
In `src/constants.py`, change the dataset directories to your own. In `config_base.yaml`, feel free to change some configurations.
In `src/constants.py`, change the dataset locations to your own. In `config_base.yaml`, set specific configurations.
For training, try
```bash
python train.py train --exp-config ../configs/config_base.yaml
python train.py train --exp_config ../configs/config_base.yaml
```
For evaluation, try
```bash
python train.py val --exp-config ../configs/config_base.yaml --resume path_to_checkpoint --save-on
python train.py eval --exp_config ../configs/config_base.yaml --resume path_to_checkpoint --save-on
```
You can find the checkpoints in `exp/base/weights/`, the log files in `exp/base/logs`, and the output change maps in `exp/base/outs`.
# Train on Air Change dataset and OSCD dataset
To carry out full training on these two datasets with all three architectures, run the `train9.sh` script in the root folder of this repo.
```bash
. ./train9.sh
```
Then check the results in the subdirectories of `./exp/`.
# Create your own configuration file
In scientific research, it is common to run many experiments with different settings, which is why configuration files are needed to better manage those settings. In this repo, you can create a `yaml` file following the naming convention below:
`config_TAG{_SUFFIX}.yaml`
The part in curly braces can be omitted. `TAG` usually stands for an experiment group, for example, a set of experiments for an architecture or a dataset. It becomes the name of the subdirectory that holds all the checkpoints, log files, and output images. `SUFFIX` can be used to distinguish different experiments within a group. If it is specified, the files generated by the experiment are tagged with `SUFFIX` in their names. In plain English, `TAG1` and `TAG2` have major differences, while `SUFFIX1` and `SUFFIX2` of the same `TAG` share most of their configuration. Combining `TAG` and `SUFFIX` thus allows convenient coarse-grained and fine-grained control of experimental configurations.
Here is an example. Suppose I am going to run experiments on two datasets, OSCD and Lebedev, and I am not sure which batch size achieves the best performance. So I create these 5 config files:
```
config_OSCD_bs4.yaml
config_OSCD_bs8.yaml
config_OSCD_bs16.yaml
config_Lebedev_bs16.yaml
config_Lebedev_bs32.yaml
```
After training, I get my `exp/` folder like this:
```
-exp/
--OSCD/
---weights/
----model_best_bs4.pth
----model_best_bs8.pth
----model_best_bs16.pth
---outs/
---logs/
---config_OSCD_bs4.yaml
---config_OSCD_bs8.yaml
---config_OSCD_bs16.yaml
--Lebedev/
---weights/
----model_best_bs16.pth
----model_best_bs32.pth
---outs/
---logs/
---config_Lebedev_bs16.yaml
---config_Lebedev_bs32.yaml
```
Now the experiment results are organized in a structured way, which makes it easier to collect statistics. Also, since past experiments are neatly arranged, you can quickly recall what you did when you come back to these results, even after a long time.
Alternatively, you can set configuration items from the command line. This is useful when there is only a minor change between two runs, because configuration items given on the command line overwrite those from the `yaml` file. That is, the final value of each configuration item is evaluated and applied in the following order:
```
default_value -> value_from_config_file -> value_from_command_line
```
At least one of the three values above should be given. This way, you don't have to include every config item in the `yaml` file or on the command line; use either of them, or combine them, according to your preference and circumstances.
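To make the precedence concrete, here is a minimal Python sketch (with hypothetical values, not taken from this repo's code) of how later sources overwrite earlier ones:

```python
# Later dicts overwrite earlier ones, mirroring
# default_value -> value_from_config_file -> value_from_command_line
defaults  = {'lr': 1e-4, 'batch_size': 8}
from_file = {'batch_size': 32}   # read from the yaml file
from_cli  = {'lr': 0.001}        # given on the command line
final = {**defaults, **from_file, **from_cli}
assert final == {'lr': 0.001, 'batch_size': 32}
```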
You can check the model weight files in `exp/base/weights/`, the log files in `exp/base/logs`, and the output change maps in `exp/base/out`.
---
# Changed
- 2020.3.14 Add the configuration files of my experiments.
- 2020.3.14 Add configuration files.
- 2020.4.14 Detail README.md.
- 2020.12.8 Update framework.
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Szada
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: EF
num_feats_in: 6
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Tiszadob
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: EF
num_feats_in: 6
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: OSCD
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: EF
num_feats_in: 26
\ No newline at end of file
@@ -2,51 +2,54 @@
# Data
# Common
dataset: Lebedev
crop_size: 224
num_workers: 1
repeats: 1
dataset: AC_Szada
num_workers: 0
repeats: 3200
subset: val
crop_size: 112
# Optimizer
optimizer: Adam
lr: 1e-4
lr_mode: step
weight_decay: 0.0
step: 5
optimizer: SGD
lr: 0.001
weight_decay: 0.0005
load_optim: False
save_optim: False
lr_mode: const
step: 2
# Training related
batch_size: 8
num_epochs: 15
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
# DO NOT specify exp_config
debug_on: False
inherit_off: True
log_off: False
track_intvl: 1
tb_on: False
tb_intvl: 100
suffix_off: False
save_on: False
out_dir: ''
val_iters: 16
# Criterion
criterion: NLL
weights:
- 0.117 # Weight of no-change class
- 0.883 # Weight of change class
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: EF
num_feats_in: 6
\ No newline at end of file
model: Unet
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Szada
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_conc
num_feats_in: 3
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Tiszadob
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_conc
num_feats_in: 3
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: OSCD
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_conc
num_feats_in: 13
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Szada
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_diff
num_feats_in: 3
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Tiszadob
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_diff
num_feats_in: 3
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: OSCD
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
save_optim: True
anew: False
track_intvl: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_diff
num_feats_in: 13
\ No newline at end of file
# Global constants
# Dataset directories
IMDB_OSCD = '~/Datasets/OSCDDataset/'
IMDB_AIRCHANGE = '~/Datasets/SZTAKI_AirChange_Benchmark/'
IMDB_LEBEDEV = '~/Datasets/HR/ChangeDetectionDataset/'
# Checkpoint templates
CKP_LATEST = 'checkpoint_latest.pth'
CKP_BEST = 'model_best.pth'
CKP_COUNTED = 'checkpoint_{e:03d}.pth'
# Dataset locations
IMDB_OSCD = "~/Datasets/OSCDDataset/"
IMDB_AIRCHANGE = "~/Datasets/SZTAKI_AirChange_Benchmark/"
# Template strings
CKP_LATEST = "checkpoint_latest.pth"
CKP_BEST = "model_best.pth"
CKP_COUNTED = "checkpoint_{e:03d}.pth"
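# For instance, CKP_COUNTED.format(e=7) yields 'checkpoint_007.pth'.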
# Initialize all
import core.misc
import core.data
import core.config
import core.builders
import core.factories
import core.trainer
import impl.builders
import impl.trainers
\ No newline at end of file
# Built-in builders
import torch
import torch.nn as nn
import torch.nn.functional as F
from .misc import (MODELS, OPTIMS, CRITNS, DATA)
# Optimizer builders
@OPTIMS.register_func('Adam_optim')
def build_Adam_optim(params, C):
return torch.optim.Adam(
params,
betas=(0.9, 0.999),
lr=C['lr'],
weight_decay=C['weight_decay']
)
@OPTIMS.register_func('SGD_optim')
def build_SGD_optim(params, C):
return torch.optim.SGD(
params,
lr=C['lr'],
momentum=0.9,
weight_decay=C['weight_decay']
)
# Criterion builders
@CRITNS.register_func('L1_critn')
def build_L1_critn(C):
return nn.L1Loss()
@CRITNS.register_func('MSE_critn')
def build_MSE_critn(C):
return nn.MSELoss()
@CRITNS.register_func('CE_critn')
def build_CE_critn(C):
return nn.CrossEntropyLoss(torch.Tensor(C['weights']))
@CRITNS.register_func('NLL_critn')
def build_NLL_critn(C):
return nn.NLLLoss(torch.Tensor(C['weights']))
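# A hypothetical extra builder (not part of this commit), sketched here to
# illustrate how the registration pattern extends; registering it would make
# 'AdamW' a valid optimizer name.
@OPTIMS.register_func('AdamW_optim')
def build_AdamW_optim(params, C):
    return torch.optim.AdamW(
        params,
        lr=C['lr'],
        weight_decay=C['weight_decay']
    )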
import argparse
import os.path as osp
from collections import ChainMap
import yaml
def read_config(config_path):
with open(config_path, 'r') as f:
cfg = yaml.load(f.read(), Loader=yaml.FullLoader)
return cfg or {}
def parse_configs(cfg_path, inherit=True):
# Read and parse config files
cfg_dir = osp.dirname(cfg_path)
cfg_name = osp.basename(cfg_path)
cfg_name, ext = osp.splitext(cfg_name)
parts = cfg_name.split('_')
cfg_path = osp.join(cfg_dir, parts[0])
cfgs = []
for part in parts[1:]:
cfg_path = '_'.join([cfg_path, part])
if osp.exists(cfg_path+ext):
cfgs.append(read_config(cfg_path+ext))
cfgs.reverse()
if len(parts)>=2:
return ChainMap(*cfgs, dict(tag=parts[1], suffix='_'.join(parts[2:])))
else:
return ChainMap(*cfgs)
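# Illustrative behavior (hypothetical file name): for 'configs/config_OSCD_bs4.yaml',
# parse_configs reads config_OSCD.yaml and then config_OSCD_bs4.yaml (when they
# exist) and chains them so that the more specific file takes priority, with
# tag='OSCD' and suffix='bs4'.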
def parse_args(parser_configurator=None):
# Parse necessary arguments
# Global settings
parser = argparse.ArgumentParser(conflict_handler='resolve')
parser.add_argument('cmd', choices=['train', 'eval'])
# Data
group_data = parser.add_argument_group('data')
group_data.add_argument('--dataset', type=str)
group_data.add_argument('--num_workers', type=int, default=4)
group_data.add_argument('--repeats', type=int, default=1)
group_data.add_argument('--subset', type=str, default='val')
# Optimizer
group_optim = parser.add_argument_group('optimizer')
group_optim.add_argument('--optimizer', type=str, default='Adam')
group_optim.add_argument('--lr', type=float, default=1e-4)
group_optim.add_argument('--weight_decay', type=float, default=1e-4)
group_optim.add_argument('--load_optim', action='store_true')
group_optim.add_argument('--save_optim', action='store_true')
# Training related
group_train = parser.add_argument_group('training related')
group_train.add_argument('--batch_size', type=int, default=8)
group_train.add_argument('--num_epochs', type=int)
group_train.add_argument('--resume', type=str, default='')
group_train.add_argument('--anew', action='store_true',
help="clear history and start from epoch 0 with weights updated")
group_train.add_argument('--device', type=str, default='cpu')
# Experiment
group_exp = parser.add_argument_group('experiment related')
group_exp.add_argument('--exp_dir', default='../exp/')
group_exp.add_argument('--tag', type=str, default='')
group_exp.add_argument('--suffix', type=str, default='')
group_exp.add_argument('--exp_config', type=str, default='')
group_exp.add_argument('--debug_on', action='store_true')
group_exp.add_argument('--inherit_off', action='store_true')
group_exp.add_argument('--log_off', action='store_true')
group_exp.add_argument('--track_intvl', type=int, default=1)
# Criterion
group_critn = parser.add_argument_group('criterion related')
group_critn.add_argument('--criterion', type=str, default='NLL')
group_critn.add_argument('--weights', type=float, nargs='+', default=None)
# Model
group_model = parser.add_argument_group('model')
group_model.add_argument('--model', type=str)
if parser_configurator is not None:
parser = parser_configurator(parser)
args, unparsed = parser.parse_known_args()
if osp.exists(args.exp_config):
cfg = parse_configs(args.exp_config, not args.inherit_off)
group_config = parser.add_argument_group('from_file')
def _cfg2arg(cfg, parser, prefix=''):
for k, v in cfg.items():
if isinstance(v, (list, tuple)):
# Only apply to homogeneous lists and tuples
parser.add_argument('--'+prefix+k, type=type(v[0]), nargs='*', default=v)
elif isinstance(v, dict):
# Recursively parse a dict
_cfg2arg(v, parser, prefix+k+'.')
elif isinstance(v, bool):
parser.add_argument('--'+prefix+k, action='store_true', default=v)
else:
parser.add_argument('--'+prefix+k, type=type(v), default=v)
_cfg2arg(cfg, group_config, '')
args = parser.parse_args()
elif len(unparsed)!=0:
raise RuntimeError("Unrecognized arguments")
def _arg2cfg(cfg, args):
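# Dotted argument names are expanded into nested dict keys, e.g. with
# hypothetical values: {'lr': 0.001, 'optim.step': 2} becomes
# {'lr': 0.001, 'optim': {'step': 2}}.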
args = vars(args)
for k, v in args.items():
pos = k.find('.')
if pos != -1:
# Iteratively parse a dict
dict_ = cfg
while pos != -1:
dict_.setdefault(k[:pos], {})
dict_ = dict_[k[:pos]]
k = k[pos+1:]
pos = k.find('.')
dict_[k] = v
else:
cfg[k] = v
return cfg
return _arg2cfg(dict(), args)
\ No newline at end of file
import os.path
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
# Data builder utilities
def build_train_dataloader(cls, configs, C):
return data.DataLoader(
cls(**configs),
batch_size=C['batch_size'],
shuffle=True,
num_workers=C['num_workers'],
pin_memory=C['device']!='cpu',
drop_last=True
)
def build_eval_dataloader(cls, configs):
return data.DataLoader(
cls(**configs),
batch_size=1,
shuffle=False,
num_workers=1,
pin_memory=False,
drop_last=False
)
def get_common_train_configs(C):
return dict(phase='train', repeats=C['repeats'])
def get_common_eval_configs(C):
return dict(phase='eval', transforms=[None, None, None], subset=C['subset'])
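# Illustrative call (hypothetical dataset class and path): C must provide
# 'batch_size', 'num_workers', and 'device' for the training loader, e.g.
#   loader = build_train_dataloader(MyDataset,
#       dict(root='~/data', phase='train'), C)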
# Dataset prototype
class DatasetBase(data.Dataset, metaclass=ABCMeta):
def __init__(
self,
root, phase,
transforms,
repeats,
subset
):
super().__init__()
self.root = os.path.expanduser(root)
if not os.path.exists(self.root):
raise FileNotFoundError
# phase stands for the working mode,
# 'train' for training and 'eval' for validating or testing.
assert phase in ('train', 'eval')
# subset is the sub-dataset to use.
# For some datasets there are three subsets,
# while for others there are only train and test(val).
assert subset in ('train', 'val', 'test')
self.phase = phase
self.transforms = transforms
self.repeats = int(repeats)
# Use 'train' subset during training.
self.subset = 'train' if self.phase == 'train' else subset
def __len__(self):
return self.len * self.repeats
def __getitem__(self, index):
if index >= len(self):
raise IndexError
index = index % self.len
item = self.fetch_and_preprocess(index)
return item
@abstractmethod
def fetch_and_preprocess(self, index):
return None
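# A minimal illustrative subclass (hypothetical, not part of this commit):
# a concrete dataset only needs to set self.len and implement fetch_and_preprocess.
class _ToyDataset(DatasetBase):
    def __init__(self, root, phase='train', transforms=(None, None, None),
                 repeats=1, subset='val'):
        super().__init__(root, phase, transforms, repeats, subset)
        self.len = 4   # number of distinct samples before repetition
    def fetch_and_preprocess(self, index):
        return index   # a real dataset would load and transform sample #index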
from functools import wraps
# from functools import wraps
from inspect import isfunction, isgeneratorfunction, getmembers
from collections.abc import Iterable
from collections.abc import Sequence
from abc import ABC, ABCMeta
from itertools import chain
from importlib import import_module
@@ -8,73 +9,76 @@ import torch
import torch.nn as nn
import torch.utils.data as data
import constants
import utils.metrics as metrics
from utils.misc import R
from data.augmentation import *
from .misc import (R, MODELS, OPTIMS, CRITNS, DATA)
class _Desc:
class _AttrDesc:
def __init__(self, key):
self.key = key
def __get__(self, instance, owner):
return tuple(getattr(instance[_],self.key) for _ in range(len(instance)))
def __set__(self, instance, values):
if not (isinstance(values, Iterable) and len(values)==len(instance)):
raise TypeError("incorrect type or number of values")
for i, v in zip(range(len(instance)), values):
setattr(instance[i], self.key, v)
return tuple(getattr(ele, self.key) for ele in instance)
def __set__(self, instance, value):
for ele in instance:
setattr(ele, self.key, value)
def _func_deco(func_name):
def _wrapper(self, *args):
return tuple(getattr(ins, func_name)(*args) for ins in self)
# FIXME: The signature of the wrapped function will be lost.
def _wrapper(self, *args, **kwargs):
return tuple(getattr(ele, func_name)(*args, **kwargs) for ele in self)
return _wrapper
def _generator_deco(func_name):
# FIXME: The signature of the wrapped function will be lost.
def _wrapper(self, *args, **kwargs):
for ins in self:
yield from getattr(ins, func_name)(*args, **kwargs)
for ele in self:
yield from getattr(ele, func_name)(*args, **kwargs)
return _wrapper
# Duck typing
class Duck(tuple):
class Duck(Sequence, ABC):
__ducktype__ = object
def __new__(cls, *args):
if any(not isinstance(a, cls.__ducktype__) for a in args):
raise TypeError("please check the input type")
return tuple.__new__(cls, args)
def __init__(self, *args):
if any(not isinstance(arg, self.__ducktype__) for arg in args):
raise TypeError("Please check the input type.")
self._seq = tuple(args)
def __getitem__(self, key):
return self._seq[key]
def __add__(self, tup):
raise NotImplementedError
def __len__(self):
return len(self._seq)
def __mul__(self, tup):
raise NotImplementedError
def __repr__(self):
return repr(self._seq)
class DuckMeta(type):
class DuckMeta(ABCMeta):
def __new__(cls, name, bases, attrs):
assert len(bases) == 1
for k, v in getmembers(bases[0]):
if k.startswith('__'):
continue
assert len(bases) == 1 # Multiple inheritance is not yet supported.
members = dict(getmembers(bases[0])) # Trade space for time
for k in attrs['__ava__']:
if k in members:
v = members[k]
if isgeneratorfunction(v):
attrs.setdefault(k, _generator_deco(k))
elif isfunction(v):
attrs.setdefault(k, _func_deco(k))
else:
attrs.setdefault(k, _Desc(k))
attrs.setdefault(k, _AttrDesc(k))
attrs['__ducktype__'] = bases[0]
return super().__new__(cls, name, (Duck,), attrs)
class DuckModel(nn.Module):
__ava__ = ('state_dict', 'load_state_dict', 'forward', '__call__', 'train', 'eval', 'to', 'training')
def __init__(self, *models):
super().__init__()
## XXX: The state_dict will be a little larger in size
# Since some extra bytes are stored in every key
# XXX: The state_dict will be a little larger in size,
# since some extra bytes are stored in every key.
self._m = nn.ModuleList(models)
def __len__(self):
@@ -83,27 +87,39 @@ class DuckModel(nn.Module):
def __getitem__(self, idx):
return self._m[idx]
def __contains__(self, m):
return m in self._m
def __repr__(self):
return repr(self._m)
def forward(self, *args, **kwargs):
return tuple(m(*args, **kwargs) for m in self._m)
Duck.register(DuckModel)
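# Illustrative usage (hypothetical modules): forward() fans out to every
# wrapped model and returns a tuple of their outputs, e.g.
#   dm = DuckModel(nn.Linear(4, 2), nn.Linear(4, 2))
#   y1, y2 = dm(torch.randn(1, 4))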
class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
# Cuz this is an instance method
__ava__ = ('param_groups', 'state_dict', 'load_state_dict', 'zero_grad', 'step')
# An instance attribute cannot be automatically handled by the metaclass
@property
def param_groups(self):
return list(chain.from_iterable(ins.param_groups for ins in self))
return list(chain.from_iterable(ele.param_groups for ele in self))
# This is special in dispatching
# Special dispatching rule
def load_state_dict(self, state_dicts):
for optim, state_dict in zip(self, state_dicts):
optim.load_state_dict(state_dict)
class DuckCriterion(nn.Module, metaclass=DuckMeta):
__ava__ = ('forward', '__call__', 'train', 'eval', 'to')
pass
class DuckDataset(data.Dataset, metaclass=DuckMeta):
class DuckDataLoader(data.DataLoader, metaclass=DuckMeta):
__ava__ = ()
pass
@@ -116,140 +132,45 @@ def _import_module(pkg: str, mod: str, rel=False):
def single_model_factory(model_name, C):
name = model_name.strip().upper()
if name == 'SIAMUNET_CONC':
from models.siamunet_conc import SiamUnet_conc
return SiamUnet_conc(C.num_feats_in, 2)
elif name == 'SIAMUNET_DIFF':
from models.siamunet_diff import SiamUnet_diff
return SiamUnet_diff(C.num_feats_in, 2)
elif name == 'EF':
from models.unet import Unet
return Unet(C.num_feats_in, 2)
builder_name = '_'.join([model_name, C['model'], C['dataset'], 'model'])
if builder_name in MODELS:
return MODELS[builder_name](C)
builder_name = '_'.join([model_name, C['dataset'], 'model'])
if builder_name in MODELS:
return MODELS[builder_name](C)
builder_name = '_'.join([model_name, 'model'])
if builder_name in MODELS:
return MODELS[builder_name](C)
else:
raise NotImplementedError("{} is not a supported architecture".format(model_name))
raise NotImplementedError("{} is not a supported architecture.".format(model_name))
def single_optim_factory(optim_name, params, C):
optim_name = optim_name.strip()
name = optim_name.upper()
if name == 'ADAM':
return torch.optim.Adam(
params,
betas=(0.9, 0.999),
lr=C.lr,
weight_decay=C.weight_decay
)
elif name == 'SGD':
return torch.optim.SGD(
params,
lr=C.lr,
momentum=0.9,
weight_decay=C.weight_decay
)
else:
raise NotImplementedError("{} is not a supported optimizer type".format(optim_name))
builder_name = '_'.join([optim_name, 'optim'])
if builder_name not in OPTIMS:
raise NotImplementedError("{} is not a supported optimizer type.".format(optim_name))
return OPTIMS[builder_name](params, C)
def single_critn_factory(critn_name, C):
import losses
critn_name = critn_name.strip()
try:
criterion, params = {
'L1': (nn.L1Loss, ()),
'MSE': (nn.MSELoss, ()),
'CE': (nn.CrossEntropyLoss, (torch.Tensor(C.weights),)),
'NLL': (nn.NLLLoss, (torch.Tensor(C.weights),))
}[critn_name.upper()]
return criterion(*params)
except KeyError:
raise NotImplementedError("{} is not a supported criterion type".format(critn_name))
def _get_basic_configs(ds_name, C):
if ds_name == 'OSCD':
return dict(
root = constants.IMDB_OSCD
)
elif ds_name.startswith('AC'):
return dict(
root = constants.IMDB_AIRCHANGE
)
elif ds_name == 'Lebedev':
return dict(
root = constants.IMDB_LEBEDEV
)
builder_name = '_'.join([critn_name, 'critn'])
if builder_name not in CRITNS:
raise NotImplementedError("{} is not a supported criterion type.".format(critn_name))
return CRITNS[builder_name](C)
def single_data_factory(dataset_name, phase, C):
builder_name = '_'.join([dataset_name, C['dataset'], C['model'], phase, 'dataset'])
if builder_name in DATA:
return DATA[builder_name](C)
builder_name = '_'.join([dataset_name, C['model'], phase, 'dataset'])
if builder_name in DATA:
return DATA[builder_name](C)
builder_name = '_'.join([dataset_name, phase, 'dataset'])
if builder_name in DATA:
return DATA[builder_name](C)
else:
return dict()
def single_train_ds_factory(ds_name, C):
ds_name = ds_name.strip()
module = _import_module('data', ds_name)
dataset = getattr(module, ds_name+'Dataset')
configs = dict(
phase='train',
transforms=(Compose(Crop(C.crop_size), Flip()), None, None),
repeats=C.repeats
)
# Update some common configurations
configs.update(_get_basic_configs(ds_name, C))
# Set phase-specific ones
if ds_name == 'Lebedev':
configs.update(
dict(
subsets = ('real',)
)
)
else:
pass
dataset_obj = dataset(**configs)
return data.DataLoader(
dataset_obj,
batch_size=C.batch_size,
shuffle=True,
num_workers=C.num_workers,
pin_memory=not (C.device == 'cpu'), drop_last=True
)
def single_val_ds_factory(ds_name, C):
ds_name = ds_name.strip()
module = _import_module('data', ds_name)
dataset = getattr(module, ds_name+'Dataset')
configs = dict(
phase='val',
transforms=(None, None, None),
repeats=1
)
# Update some common configurations
configs.update(_get_basic_configs(ds_name, C))
# Set phase-specific ones
if ds_name == 'Lebedev':
configs.update(
dict(
subsets = ('real',)
)
)
else:
pass
dataset_obj = dataset(**configs)
# Create eval set
return data.DataLoader(
dataset_obj,
batch_size=1,
shuffle=False,
num_workers=1,
pin_memory=False, drop_last=False
)
raise NotImplementedError("{} is not a supported dataset.".format(dataset_name))
def _parse_input_names(name_str):
@@ -268,7 +189,7 @@ def optim_factory(optim_names, models, C):
name_list = _parse_input_names(optim_names)
num_models = len(models) if isinstance(models, DuckModel) else 1
if len(name_list) != num_models:
raise ValueError("the number of optimizers does not match the number of models")
raise ValueError("The number of optimizers does not match the number of models.")
if num_models > 1:
optims = []
@@ -298,16 +219,7 @@ def critn_factory(critn_names, C):
def data_factory(dataset_names, phase, C):
name_list = _parse_input_names(dataset_names)
if phase not in ('train', 'val'):
raise ValueError("phase should be either 'train' or 'val'")
fact = globals()['single_'+phase+'_ds_factory']
if len(name_list) > 1:
return DuckDataset(*(fact(name, C) for name in name_list))
return DuckDataLoader(*(single_data_factory(name, phase, C) for name in name_list))
else:
return fact(dataset_names, C)
def metric_factory(metric_names, C):
from utils import metrics
name_list = _parse_input_names(metric_names)
return [getattr(metrics, name.strip())() for name in name_list]
return single_data_factory(dataset_names, phase, C)
\ No newline at end of file
import logging
import os
import os.path as osp
import sys
from time import localtime
from collections import OrderedDict
from collections import OrderedDict, deque
from weakref import proxy
FORMAT_LONG = "[%(asctime)-15s %(funcName)s] %(message)s"
FORMAT_SHORT = "%(message)s"
@@ -16,6 +18,7 @@ class _LessThanFilter(logging.Filter):
def filter(self, record):
return record.levelno < self.max_level
class Logger:
_count = 0
@@ -38,11 +41,11 @@ class Logger:
self._logger.addHandler(self._scrn_handler)
if log_dir and phase:
self.log_path = os.path.join(log_dir,
'{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}.log'.format(
self.log_path = osp.join(log_dir,
"{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}.log".format(
phase, *localtime()[:6]
))
self.show_nl("log into {}\n\n".format(self.log_path))
self.show_nl("Log into {}\n\n".format(self.log_path))
self._file_handler = logging.FileHandler(filename=self.log_path)
self._file_handler.setLevel(logging.DEBUG)
self._file_handler.setFormatter(logging.Formatter(fmt=FORMAT_LONG))
@@ -58,7 +61,7 @@ class Logger:
def dump(self, *args, **kwargs):
return self._logger.debug(*args, **kwargs)
def warning(self, *args, **kwargs):
def warn(self, *args, **kwargs):
return self._logger.warning(*args, **kwargs)
def error(self, *args, **kwargs):
@@ -67,16 +70,7 @@ class Logger:
def fatal(self, *args, **kwargs):
return self._logger.critical(*args, **kwargs)
@staticmethod
def make_desc(counter, total, *triples, opt_str=''):
desc = "[{}/{}] {}".format(counter, total, opt_str)
# The three elements of each triple are
# (name to display, AverageMeter object, formatting string)
for name, obj, fmt in triples:
desc += (" {} {obj.val:"+fmt+"} ({obj.avg:"+fmt+"})").format(name, obj=obj)
return desc
_default_logger = Logger()
_logger = Logger()
class _WeakAttribute:
@@ -91,12 +85,12 @@ class _WeakAttribute:
class _TreeNode:
_sep = '/'
_none = None
parent = _WeakAttribute() # To avoid circular reference
def __init__(self, name, value=None, parent=None, children=None):
def __init__(
self, name, value=None, parent=None, children=None,
sep='/', none_val=None
):
super().__init__()
self.name = name
self.val = value
@@ -106,46 +100,42 @@ class _TreeNode:
for child in children:
self._add_child(child)
self.path = name
self._sep = sep
self._none = none_val
def get_child(self, name, def_val=None):
return self.children.get(name, def_val)
def set_child(self, name, val=None):
r"""
Set the value of an existing node.
If the node does not exist, return nothing
"""
child = self.get_child(name)
if child is not None:
child.val = val
return child
def get_child(self, name):
return self.children.get(name, None)
def add_place_holder(self, name):
return self.add_child(name, val=self._none)
def add_placeholder(self, name):
return self.add_child(name, value=self._none)
def add_child(self, name, val):
def add_child(self, name, value, warning=False):
r"""
If not exists or is a placeholder, create it
Otherwise skips and returns the existing node
If node does not exist or is a placeholder, create it,
otherwise skip and return the existing node.
"""
child = self.get_child(name, None)
child = self.get_child(name)
if child is None:
child = _TreeNode(name, val, parent=self)
child = _TreeNode(name, value, parent=self, sep=self._sep, none_val=self._none)
self._add_child(child)
elif child.val == self._none:
# Retain the links of the placeholder
# i.e. just fill in it
child.val = val
elif child.is_placeholder():
# Retain the links of a placeholder,
# i.e. just fill in it.
child.val = value
else:
if warning:
_logger.warn("Node already exists!")
return child
def is_leaf(self):
return len(self.children) == 0
def is_placeholder(self):
return self.val == self._none
def __repr__(self):
try:
repr = self.path + ' ' + str(self.val)
repr = self.path + " " + str(self.val)
except TypeError:
repr = self.path
return repr
@@ -157,7 +147,10 @@ class _TreeNode:
return self.get_child(key)
def _add_child(self, node):
r""" Into children dictionary and set path and parent """
r"""
Add a child node into self.children.
If the node already exists, just update its information.
"""
self.children.update({
node.name: node
})
@@ -166,8 +159,8 @@ class _TreeNode:
def apply(self, func):
r"""
Apply a callback function on ALL descendants
This is useful for the recursive traversal
Apply a callback function to ALL descendants.
This is useful for recursive traversal.
"""
ret = [func(self)]
for _, node in self.children.items():
@@ -175,69 +168,53 @@ class _TreeNode:
return ret
def bfs_tracker(self):
queue = []
queue.insert(0, self)
queue = deque()
queue.append(self)
while(queue):
curr = queue.pop()
curr = queue.popleft()
yield curr
if curr.is_leaf():
continue
for c in curr.children.values():
queue.insert(0, c)
queue.append(c)
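# Illustrative order (hypothetical tree): for root -> (a, b) and a -> (c),
# [n.name for n in root.bfs_tracker()] yields ['root', 'a', 'b', 'c'].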
class _Tree:
def __init__(
self, name, value=None, strc_ele=None,
sep=_TreeNode._sep, def_val=_TreeNode._none
self, name, value=None, eles=None,
sep='/', none_val=None
):
super().__init__()
self._sep = sep
self._def_val = def_val
self._none = none_val
self.root = _TreeNode(name, value, parent=None, children={})
if strc_ele is not None:
assert isinstance(strc_ele, dict)
# This is to avoid mutable parameter default
self.build_tree(OrderedDict(strc_ele or {}))
self.root = _TreeNode(name, value, parent=None, children={}, sep=self._sep, none_val=self._none)
if eles is not None:
assert isinstance(eles, dict)
self.build_tree(OrderedDict(eles or {}))
def build_tree(self, elements):
# The siblings could be out-of-order
# The order of the siblings is not retained
for path, ele in elements.items():
self.add_node(path, ele)
def get_root(self):
r""" Get separated root node """
return _TreeNode(
self.root.name, self.root.val,
parent=None, children=None
)
def __repr__(self):
return self.__dumps__()
def __dumps__(self):
r""" Dump to string """
_str = ''
_str = ""
# DFS
stack = []
stack.append((self.root, 0))
while(stack):
root, layer = stack.pop()
_str += ' '*layer + '-' + root.__repr__() + '\n'
_str += " "*layer + "-" + root.__repr__() + "\n"
if root.is_leaf():
continue
# Note that the order of the siblings is not retained
for c in reversed(list(root.children.values())):
# Note that the siblings are printed in alphabetical order.
for c in sorted(list(root.children.values()), key=lambda n: n.name, reverse=True):
stack.append((c, layer+1))
return _str
def vis(self):
r""" Visualize the structure of the tree """
_default_logger.show(self.__dumps__())
def __contains__(self, obj):
return any(self.perform(lambda node: obj in node))
@@ -246,14 +223,15 @@ class _Tree:
def get_node(self, tar, mode='name'):
r"""
This is different from the traversal in that
the search allows early stop
This is different from a traversal in that this search allows early stop.
"""
assert mode in ('name', 'path', 'val')
if mode == 'path':
nodes = self.parse_path(tar)
root = self.root
for r in nodes:
if root is None:
break
root = root.get_child(r)
return root
else:
@@ -264,28 +242,20 @@ class _Tree:
for node in bfs_tracker:
if getattr(node, mode) == tar:
return node
return
def set_node(self, path, val):
node = self.get_node(path, mode='path')
if node is not None:
node.val = val
return node
return None
def add_node(self, path, val=None):
def add_node(self, path, val):
if not path.strip():
raise ValueError("the path is null")
path = path.strip('/')
if val is None:
val = self._def_val
raise ValueError("The path is null.")
path = path.rstrip(self._sep)
names = self.parse_path(path)
root = self.root
nodes = [root]
for name in names[:-1]:
# Add placeholders
root = root.add_child(name, self._def_val)
# Add a placeholder or skip an existing node
root = root.add_placeholder(name)
nodes.append(root)
root = root.add_child(names[-1], val)
root = root.add_child(names[-1], val, True)
return root, nodes
def parse_path(self, path):
@@ -296,22 +266,29 @@ class _Tree:
class OutPathGetter:
def __init__(self, root='', log='logs', out='outs', weight='weights', suffix='', **subs):
def __init__(self, root='', log='logs', out='out', weight='weights', suffix='', **subs):
super().__init__()
self._root = root.rstrip('/') # Work robustly for multiple ending '/'s
self._root = root.rstrip(os.sep) # Work robustly on multiple ending '/'s
if len(self._root) == 0 and len(root) > 0:
self._root = '/' # In case of the system root dir
self._root = os.sep # In case of the system root dir in linux
self._suffix = suffix
self._keys = dict(log=log, out=out, weight=weight, **subs)
for k, v in self._keys.items():
v_ = v.rstrip(os.sep)
if len(v_) == 0 or not self.check_path(v_):
_logger.warn("{} is not a valid path.".format(v))
continue
self._keys[k] = v_
self._dir_tree = _Tree(
self._root, 'root',
strc_ele=dict(zip(self._keys.values(), self._keys.keys())),
sep='/',
def_val=''
eles=dict(zip(self._keys.values(), self._keys.keys())),
sep=os.sep, none_val=''
)
self.update_keys(False)
self.update_tree(False)
self.add_keys(False)
self.update_vfs(False)
self.__counter = 0
@@ -326,89 +303,109 @@ class OutPathGetter:
def root(self):
return self._root
def _update_key(self, key, val, add=False, prefix=False):
if prefix:
val = os.path.join(self._root, val)
if add:
# Do not edit if exists
def _add_key(self, key, val):
self._keys.setdefault(key, val)
else:
self._keys.__setitem__(key, val)
def _add_node(self, key, val, prefix=False):
if not prefix and key.startswith(self._root):
key = key[len(self._root)+1:]
return self._dir_tree.add_node(key, val)
def update_keys(self, verbose=False):
def add_keys(self, verbose=False):
for k, v in self._keys.items():
self._update_key(k, v, prefix=True)
self._add_key(k, v)
if verbose:
_default_logger.show(self._keys)
_logger.show(self._keys)
def update_tree(self, verbose=False):
def update_vfs(self, verbose=False):
self._dir_tree.perform(lambda x: self.make_dir(x.path))
if verbose:
_default_logger.show("\nFolder structure:")
_default_logger.show(self._dir_tree)
_logger.show("\nFolder structure:")
_logger.show(self._dir_tree)
@staticmethod
def check_path(path):
# This is to prevent stuff like A/../B or A/./.././C.d
# Note that paths like A.B/.C/D are not supported, either.
return osp.dirname(path).find('.') == -1
@staticmethod
def make_dir(path):
if not os.path.exists(path):
if not osp.exists(path):
os.mkdir(path)
elif not osp.isdir(path):
raise RuntimeError("Cannot create directory.")
def get_dir(self, key):
return self._keys.get(key, '') if key != 'root' else self.root
return osp.join(self.root, self._keys[key])
def get_path(
self, key, file,
name='', auto_make=False,
suffix=True, underline=False
suffix=False, underline=True
):
folder = self.get_dir(key)
if len(folder) < 1:
raise KeyError("key not found")
if len(file) == 0:
return self.get_dir(key)
if not self.check_path(file):
raise ValueError("{} is not a valid path.".format(file))
folder = self._keys[key]
if suffix:
path = os.path.join(folder, self.add_suffix(file, underline=underline))
path = osp.join(folder, self._add_suffix(file, underline=underline))
else:
path = os.path.join(folder, file)
path = osp.join(folder, file)
if auto_make:
base_dir = os.path.dirname(path)
base_dir = osp.dirname(path)
# O(n) search for base_dir
# Never update an existing key!
if base_dir in self:
return path
if name:
self._update_key(name, base_dir, add=True)
'''
else:
name = 'new_{:03d}'.format(self.__counter)
self._update_key(name, base_dir, add=True)
self.__counter += 1
'''
des, visit = self._add_node(base_dir, name)
_logger.warn("Cannot assign a new key to an existing path!")
return osp.join(self.root, path)
node = self._dir_tree.get_node(base_dir, mode='path')
# Note that if name is an empty string,
# the directory tree will be updated, but the name will not be added into self._keys.
if node is None or node.is_placeholder():
# Update directory tree
des, visit = self._dir_tree.add_node(base_dir, name)
# Create directories along the visiting path
for d in visit: self.make_dir(d.path)
self.make_dir(des.path)
return path
else:
node.val = name
if len(name) > 0:
# Add new key
self._add_key(name, base_dir)
return osp.join(self.root, path)
def add_suffix(self, path, suffix='', underline=False):
def _add_suffix(self, path, suffix='', underline=False):
pos = path.rfind('.')
if pos == -1:
pos = len(path)
_suffix = self._suffix if len(suffix) < 1 else suffix
_suffix = self._suffix if len(suffix) == 0 else suffix
return path[:pos] + ('_' if underline and _suffix else '') + _suffix + path[pos:]
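# For example (hypothetical names), with self._suffix == 'bs32':
#   _add_suffix('map.png', underline=True) -> 'map_bs32.png'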
def __contains__(self, value):
return value in self._keys.values()
return value in self._keys.values() or value == self._root
def contains_key(self, key):
return key in self._keys
class Registry(dict):
def register(self, key, val):
if key in self: _default_logger.warning("key {} already registered".format(key))
if key in self: _logger.warn("Key {} has already been registered!".format(key))
self[key] = val
def register_func(self, key):
def _wrapper(func):
self.register(key, func)
return func
return _wrapper
# Registry for global objects
R = Registry()
R.register('DEFAULT_LOGGER', _default_logger)
R.register('Logger', _logger)
register = R.register
# Registries for builders
MODELS = Registry()
OPTIMS = Registry()
CRITNS = Registry()
DATA = Registry()
\ No newline at end of file
@@ -2,75 +2,81 @@ import shutil
import os
from types import MappingProxyType
from copy import deepcopy
from abc import ABCMeta, abstractmethod
import torch
from skimage import io
from tqdm import tqdm
import constants
from data.common import to_array
from utils.misc import R
from utils.metrics import AverageMeter
from utils.utils import mod_crop
from .factories import (model_factory, optim_factory, critn_factory, data_factory, metric_factory)
from .misc import Logger, OutPathGetter, R
from .factories import (model_factory, optim_factory, critn_factory, data_factory)
class Trainer:
class Trainer(metaclass=ABCMeta):
def __init__(self, model, dataset, criterion, optimizer, settings):
super().__init__()
# Make a copy of settings in case of unexpected changes
context = deepcopy(settings)
self.ctx = MappingProxyType(vars(context))
self.mode = ('train', 'val').index(context.cmd)
self.logger = R['LOGGER']
self.gpc = R['GPC'] # Global Path Controller
# self.ctx is a proxy so that context will be read-only outside __init__
self.ctx = MappingProxyType(context)
self.mode = ('train', 'eval').index(context['cmd'])
self.debug = context['debug_on']
self.log = not context['log_off']
self.batch_size = context['batch_size']
self.checkpoint = context['resume']
self.load_checkpoint = (len(self.checkpoint)>0)
self.num_epochs = context['num_epochs']
self.lr = float(context['lr'])
self.track_intvl = int(context['track_intvl'])
self.device = torch.device(context['device'])
self.gpc = OutPathGetter(
root=os.path.join(context['exp_dir'], context['tag']),
suffix=context['suffix']
) # Global Path Controller
self.logger = Logger(
scrn=True,
log_dir=self.gpc.get_dir('log') if self.log else '',
phase=context['cmd']
)
self.path = self.gpc.get_path
self.batch_size = context.batch_size
self.checkpoint = context.resume
self.load_checkpoint = (len(self.checkpoint)>0)
self.num_epochs = context.num_epochs
self.lr = float(context.lr)
self.save = context.save_on or context.out_dir
self.out_dir = context.out_dir
self.track_intvl = int(context.track_intvl)
self.device = torch.device(context.device)
self.suffix_off = context.suffix_off
for k, v in sorted(self.ctx.items()):
for k, v in sorted(context.items()):
self.logger.show("{}: {}".format(k,v))
self.model = model_factory(model, context)
self.model.to(self.device)
self.criterion = critn_factory(criterion, context)
self.criterion.to(self.device)
self.metrics = metric_factory(context.metrics, context)
if self.is_training:
self.train_loader = data_factory(dataset, 'train', context)
self.val_loader = data_factory(dataset, 'val', context)
self.eval_loader = data_factory(dataset, 'eval', context)
self.optimizer = optim_factory(optimizer, self.model, context)
else:
self.val_loader = data_factory(dataset, 'val', context)
self.eval_loader = data_factory(dataset, 'eval', context)
self.start_epoch = 0
self._init_max_acc_and_epoch = (0.0, 0)
self._init_acc_epoch = (0.0, -1)
@property
def is_training(self):
return self.mode == 0
@abstractmethod
def train_epoch(self, epoch):
raise NotImplementedError
pass
def validate_epoch(self, epoch=0, store=False):
raise NotImplementedError
@abstractmethod
def evaluate_epoch(self, epoch):
return 0.0
def _write_prompt(self):
self.logger.dump(input("\nWrite some notes: "))
def run(self):
if self.is_training:
if self.log and not self.debug:
self._write_prompt()
self.train()
else:
@@ -80,23 +86,20 @@
if self.load_checkpoint:
self._resume_from_checkpoint()
max_acc, best_epoch = self._init_max_acc_and_epoch
max_acc, best_epoch = self._init_acc_epoch
lr = self.init_learning_rate()
for epoch in range(self.start_epoch, self.num_epochs):
lr = self._adjust_learning_rate(epoch)
self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr))
# Train for one epoch
self.model.train()
self.train_epoch(epoch)
# Clear the history of metric objects
for m in self.metrics:
m.reset()
# Evaluate the model on validation set
self.logger.show_nl("Validate")
acc = self.validate_epoch(epoch=epoch, store=self.save)
# Evaluate the model
self.logger.show_nl("Evaluate")
self.model.eval()
acc = self.evaluate_epoch(epoch=epoch)
is_best = acc > max_acc
if is_best:
@@ -105,77 +108,74 @@
self.logger.show_nl("Current: {:.6f} ({:03d})\tBest: {:.6f} ({:03d})\t".format(
acc, epoch, max_acc, best_epoch))
# The checkpoint saves next epoch
# Do not save checkpoints in debugging mode
if not self.debug:
self._save_checkpoint(
self.model.state_dict(),
self.optimizer.state_dict() if self.ctx['save_optim'] else {},
(max_acc, best_epoch), epoch+1, is_best
(max_acc, best_epoch), epoch, is_best
)
lr = self.adjust_learning_rate(epoch, acc)
def evaluate(self):
if self.checkpoint:
if self._resume_from_checkpoint():
self.validate_epoch(self.ckp_epoch, self.save)
else:
self.logger.warning("Warning: no checkpoint assigned!")
def _adjust_learning_rate(self, epoch):
if self.ctx['lr_mode'] == 'step':
lr = self.lr * (0.5 ** (epoch // self.ctx['step']))
elif self.ctx['lr_mode'] == 'poly':
lr = self.lr * (1 - epoch / self.num_epochs) ** 1.1
elif self.ctx['lr_mode'] == 'const':
lr = self.lr
self.model.eval()
self.evaluate_epoch(self.start_epoch)
else:
raise ValueError('unknown lr mode {}'.format(self.ctx['lr_mode']))
self.logger.error("No checkpoint assigned!")
def init_learning_rate(self):
return self.lr
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
return lr
def adjust_learning_rate(self, epoch, acc):
return self.lr
def _resume_from_checkpoint(self):
## XXX: This could be slow!
# XXX: This could be slow!
if not os.path.isfile(self.checkpoint):
self.logger.error("=> No checkpoint was found at '{}'.".format(self.checkpoint))
return False
self.logger.show("=> Loading checkpoint '{}'".format(
self.checkpoint))
self.logger.show("=> Loading checkpoint '{}'...".format(self.checkpoint))
checkpoint = torch.load(self.checkpoint, map_location=self.device)
state_dict = self.model.state_dict()
ckp_dict = checkpoint.get('state_dict', checkpoint)
update_dict = {k:v for k,v in ckp_dict.items()
if k in state_dict and state_dict[k].shape == v.shape}
update_dict = {
k:v for k,v in ckp_dict.items()
if k in state_dict and state_dict[k].shape == v.shape and state_dict[k].dtype == v.dtype
}
num_to_update = len(update_dict)
if (num_to_update < len(state_dict)) or (len(state_dict) < len(ckp_dict)):
if not self.is_training and (num_to_update < len(state_dict)):
self.logger.error("=> Mismatched checkpoint for evaluation")
return False
self.logger.warning("Warning: trying to load an mismatched checkpoint.")
self.logger.warn("Trying to load a mismatched checkpoint.")
if num_to_update == 0:
self.logger.error("=> No parameter is to be loaded.")
return False
else:
self.logger.warning("=> {} params are to be loaded.".format(num_to_update))
elif (not self.ctx['anew']) or not self.is_training:
self.start_epoch = checkpoint.get('epoch', 0)
max_acc_and_epoch = checkpoint.get('max_acc', (0.0, self.ckp_epoch))
# For backward compatibility
if isinstance(max_acc_and_epoch, (float, int)):
self._init_max_acc_and_epoch = (max_acc_and_epoch, self.ckp_epoch)
else:
self._init_max_acc_and_epoch = max_acc_and_epoch
self.logger.warn("=> {} params are to be loaded.".format(num_to_update))
elif not self.ctx['anew'] or not self.is_training:
ckp_epoch = checkpoint.get('epoch', -1)
self.start_epoch = ckp_epoch+1
self._init_acc_epoch = checkpoint.get('max_acc', (0.0, ckp_epoch))
if self.ctx['load_optim'] and self.is_training:
# Note that weight decay might be modified here
# XXX: Note that weight decay might be modified here.
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.logger.warn("Weight decay might have been modified.")
state_dict.update(update_dict)
self.model.load_state_dict(state_dict)
self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {})".format(
self.checkpoint, self.ckp_epoch, *self._init_max_acc_and_epoch
if self.start_epoch == 0:
self.logger.show("=> Loaded checkpoint '{}'".format(self.checkpoint))
else:
self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {}).".format(
self.checkpoint, self.start_epoch-1, *self._init_acc_epoch
))
return True
@@ -187,117 +187,46 @@ class Trainer:
'max_acc': max_acc
}
# Save history
history_path = self.path('weight', constants.CKP_COUNTED.format(e=epoch), underline=True)
# epoch+1 instead of epoch is used in the checkpoint name so that it is easy
# for one to recognize "the next start_epoch".
history_path = self.path(
'weight', constants.CKP_COUNTED.format(e=epoch+1),
suffix=True
)
if epoch % self.track_intvl == 0:
torch.save(state, history_path)
# Save latest
latest_path = self.path(
'weight', constants.CKP_LATEST,
underline=True
suffix=True
)
torch.save(state, latest_path)
if is_best:
shutil.copyfile(
latest_path, self.path(
'weight', constants.CKP_BEST,
underline=True
suffix=True
)
)
@property
def ckp_epoch(self):
# Get current epoch of the checkpoint
# For a mismatched ckp or no ckp, set to 0
return max(self.start_epoch-1, 0)
def save_image(self, file_name, image, epoch):
file_path = os.path.join(
'epoch_{}/'.format(epoch),
self.out_dir,
file_name
)
out_path = self.path(
'out', file_path,
suffix=not self.suffix_off,
auto_make=True,
underline=True
)
return io.imsave(out_path, image)
class CDTrainer(Trainer):
def __init__(self, arch, dataset, optimizer, settings):
super().__init__(arch, dataset, 'NLL', optimizer, settings)
def train_epoch(self, epoch):
losses = AverageMeter()
len_train = len(self.train_loader)
pb = tqdm(self.train_loader)
self.model.train()
for i, (t1, t2, label) in enumerate(pb):
t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
prob = self.model(t1, t2)
loss = self.criterion(prob, label)
class TrainerSwitcher:
r"""A simple utility class to help dispatch actions to different trainers."""
def __init__(self, *pairs):
self._trainer_list = list(pairs)
losses.update(loss.item(), n=self.batch_size)
def __call__(self, args, return_obj=True):
for p, t in self._trainer_list:
if p(args):
return t(args) if return_obj else t
return None
# Compute gradients and do SGD step
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
def add_item(self, predicate, trainer):
# Newly added items have higher priority
self._trainer_list.insert(0, (predicate, trainer))
desc = self.logger.make_desc(
i+1, len_train,
('loss', losses, '.4f')
)
pb.set_description(desc)
self.logger.dump(desc)
def validate_epoch(self, epoch=0, store=False):
self.logger.show_nl("Epoch: [{0}]".format(epoch))
losses = AverageMeter()
len_val = len(self.val_loader)
pb = tqdm(self.val_loader)
self.model.eval()
with torch.no_grad():
for i, (name, t1, t2, label) in enumerate(pb):
if self.is_training and i >= 16:
# Do not validate all images on training phase
pb.close()
self.logger.warning("validation ends early")
break
t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
prob = self.model(t1, t2)
loss = self.criterion(prob, label)
losses.update(loss.item(), n=self.batch_size)
# Convert to numpy arrays
CM = to_array(torch.argmax(prob[0], 0)).astype('uint8')
label = to_array(label[0]).astype('uint8')
for m in self.metrics:
m.update(CM, label)
desc = self.logger.make_desc(
i+1, len_val,
('loss', losses, '.4f'),
*(
(m.__name__, m, '.4f')
for m in self.metrics
)
)
pb.set_description(desc)
self.logger.dump(desc)
def add_default(self, trainer):
self._trainer_list.append((lambda args: True, trainer))
if store:
self.save_image(name[0], CM*255, epoch)
return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc_and_epoch[0])
\ No newline at end of file
R.register('Trainer_switcher', TrainerSwitcher())
\ No newline at end of file
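# Illustrative wiring (hypothetical predicate and trainer class): newly added
# items take priority, so a specialized trainer can shadow more generic ones.
#   R['Trainer_switcher'].add_item(lambda args: args['model'] == 'EF', CDTrainer)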
from glob import glob
from os.path import join, basename
import numpy as np
from . import CDDataset
from .common import default_loader
class LebedevDataset(CDDataset):
def __init__(
self,
root, phase='train',
transforms=(None, None, None),
repeats=1,
subsets=('real', 'with_shift', 'without_shift')
):
self.subsets = subsets
super().__init__(root, phase, transforms, repeats)
def _read_file_paths(self):
t1_list, t2_list, label_list = [], [], []
for subset in self.subsets:
# Get subset directory
if subset == 'real':
subset_dir = join(self.root, 'Real', 'subset')
elif subset == 'with_shift':
subset_dir = join(self.root, 'Model', 'with_shift')
elif subset == 'without_shift':
subset_dir = join(self.root, 'Model', 'without_shift')
else:
raise RuntimeError('unrecognized key encountered')
pattern = '*.bmp' if (subset == 'with_shift' and self.phase in ('test', 'val')) else '*.jpg'
refs = sorted(glob(join(subset_dir, self.phase, 'OUT', pattern)))
t1s = (join(subset_dir, self.phase, 'A', basename(ref)) for ref in refs)
t2s = (join(subset_dir, self.phase, 'B', basename(ref)) for ref in refs)
label_list.extend(refs)
t1_list.extend(t1s)
t2_list.extend(t2s)
return t1_list, t2_list, label_list
def fetch_label(self, label_path):
# To {0,1}
return (super().fetch_label(label_path) > 127).astype(np.uint8)
\ No newline at end of file