From eb48f74c5b70deb08e7095dc706ee0441a28eece Mon Sep 17 00:00:00 2001
From: Bobholamovic <bob1998425@hotmail.com>
Date: Tue, 8 Dec 2020 22:10:22 +0800
Subject: [PATCH] Update framework

---
 README.md                                   |  80 +---
 configs/config_EF_AC_Szada.yaml             |  52 ---
 configs/config_EF_AC_Tiszadob.yaml          |  52 ---
 configs/config_EF_OSCD.yaml                 |  52 ---
 configs/config_base.yaml                    |  49 +-
 configs/config_siamconc_AC_Szada.yaml       |  52 ---
 configs/config_siamconc_AC_Tiszadob.yaml    |  52 ---
 configs/config_siamconc_OSCD.yaml           |  52 ---
 configs/config_siamdiff_AC_Szada.yaml       |  52 ---
 configs/config_siamdiff_AC_Tiszadob.yaml    |  52 ---
 configs/config_siamdiff_OSCD.yaml           |  52 ---
 src/constants.py                            |  18 +-
 src/core/__init__.py                        |  12 +
 src/core/builders.py                        |  49 ++
 src/core/config.py                          | 126 +++++
 src/core/data.py                            |  81 ++++
 src/core/factories.py                       | 272 ++++-------
 src/{utils => core}/misc.py                 | 307 ++++++------
 src/core/trainer.py                         | 232 ++++++++++
 src/core/trainers.py                        | 303 ------------
 src/data/Lebedev.py                         |  47 --
 src/data/__init__.py                        |  59 ++-
 src/data/{_AirChange.py => _airchange.py}   |  36 +-
 src/data/{AC_Szada.py => ac_szada.py}       |   7 +-
 src/data/{AC_Tiszadob.py => ac_tiszadob.py} |   7 +-
 src/data/augmentation.py                    | 489 --------------------
 src/data/augmentations.py                   | 414 +++++++++++++++++
 src/data/common.py                          |  34 --
 src/data/{OSCD.py => oscd.py}               |  50 +-
 src/impl/builders/__init__.py               |   6 +
 src/impl/builders/critn_builders.py         |   3 +
 src/impl/builders/data_builders.py          | 149 ++++++
 src/impl/builders/model_builders.py         |  39 ++
 src/impl/builders/optim_builders.py         |   3 +
 src/impl/trainers/__init__.py               |   8 +
 src/impl/trainers/cd_trainer.py             | 222 +++++++++
 src/losses.py                               |   5 -
 src/test.py                                 |   0
 src/train.py                                | 173 ++-----
 src/utils/data_utils.py                     |  63 +++
 src/utils/losses.py                         |   0
 src/utils/metrics.py                        |  25 +-
 src/utils/utils.py                          |  63 ++-
 train9.sh                                   |  20 -
 44 files changed, 1885 insertions(+), 2034 deletions(-)
 delete mode 100644 configs/config_EF_AC_Szada.yaml
 delete mode 100644 configs/config_EF_AC_Tiszadob.yaml
 delete mode 100644 configs/config_EF_OSCD.yaml
 delete mode 100644 configs/config_siamconc_AC_Szada.yaml
 delete mode 100644 configs/config_siamconc_AC_Tiszadob.yaml
 delete mode 100644 configs/config_siamconc_OSCD.yaml
 delete mode 100644 configs/config_siamdiff_AC_Szada.yaml
 delete mode 100644 configs/config_siamdiff_AC_Tiszadob.yaml
 delete mode 100644 configs/config_siamdiff_OSCD.yaml
 create mode 100644 src/core/__init__.py
 create mode 100644 src/core/builders.py
 create mode 100644 src/core/config.py
 create mode 100644 src/core/data.py
 rename src/{utils => core}/misc.py (53%)
 create mode 100644 src/core/trainer.py
 delete mode 100644 src/core/trainers.py
 delete mode 100644 src/data/Lebedev.py
 rename src/data/{_AirChange.py => _airchange.py} (63%)
 rename src/data/{AC_Szada.py => ac_szada.py} (69%)
 rename src/data/{AC_Tiszadob.py => ac_tiszadob.py} (69%)
 delete mode 100644 src/data/augmentation.py
 create mode 100644 src/data/augmentations.py
 delete mode 100644 src/data/common.py
 rename src/data/{OSCD.py => oscd.py} (53%)
 create mode 100644 src/impl/builders/__init__.py
 create mode 100644 src/impl/builders/critn_builders.py
 create mode 100644 src/impl/builders/data_builders.py
 create mode 100644 src/impl/builders/model_builders.py
 create mode 100644 src/impl/builders/optim_builders.py
 create mode 100644 src/impl/trainers/__init__.py
 create mode 100644 src/impl/trainers/cd_trainer.py
 delete mode 100644 src/losses.py
 create mode 100644 src/test.py
 create mode 100644 src/utils/data_utils.py
 create mode 100644 src/utils/losses.py
 delete mode 100755 train9.sh

diff --git a/README.md b/README.md
index 722b479..385c929 100644
--- a/README.md
+++ b/README.md
@@ -8,17 +8,18 @@ This is an unofficial implementation of the paper
 
 [paper link](https://ieeexplore.ieee.org/abstract/document/8451652)
 
-# Prerequisites
+# Dependencies
 
 > opencv-python==4.1.1  
-  pytorch==1.2.0  
+  pytorch==1.3.1  
+  torchvision==0.4.2  
   pyyaml==5.1.2  
   scikit-image==0.15.0  
   scikit-learn==0.21.3  
   scipy==1.3.1  
-  tqdm==4.35.0  
+  tqdm==4.35.0
 
-Tested on Python 3.7.4, Ubuntu 16.04 and Python 3.6.8, Windows 10.
+Tested using Python 3.7.4 on Ubuntu 16.04 and Python 3.6.8 on Windows 10.
 
 # Basic usage
 
@@ -30,84 +31,25 @@ mkdir exp
 cd src
 ```
 
-In `src/constants.py`, change the dataset directories to your own. In `config_base.yaml`, feel free to change some configurations.
+In `src/constants.py`, change the dataset locations to your own. In `config_base.yaml`, adjust the configuration items as needed.
 
 For training, try
 
 ```bash
-python train.py train --exp-config ../configs/config_base.yaml
+python train.py train --exp_config ../configs/config_base.yaml
 ```
 
 For evaluation, try
 
 ```bash
-python train.py val --exp-config ../configs/config_base.yaml --resume path_to_checkpoint --save-on
+python train.py eval --exp_config ../configs/config_base.yaml --resume path_to_checkpoint --save_on
 ```
 
-You can find the checkpoints in `exp/base/weights/`, the log files in `exp/base/logs`, and the output change maps in `exp/base/outs`.
-
-# Train on Air Change dataset and OSCD dataset
-
-To carry out a full training on these two datasets and with all three architectures, run the `train9.sh` script under the root folder of this repo.
-```bash
-. ./train9.sh
-```
-
-And check the results in different subdirectories of `./exp/`. 
-
-# Create your own configuration file
-
-During scientific research, it is common case that we have to do a lot of experiments with different settings, and that's why we need the configuration files to better manage those settings. In this repo, you can create a `yaml` file under the naming convention below:
-
-`config_TAG{_SUFFIX}.yaml`
-
-Those in the curly braces can be omitted. `TAG` usually stands for an experiment group. For example, a set of experiments for an architecture, a dataset, etc. It will be the name of the subdirectory that holds all the checkpoints, log files, and output images. `SUFFIX` can be used to distinguish different experiments in an experiment group. If it is specified, the generated files of this experiment will be tagged with `SUFFIX` in their file names. In plain English, `TAG1` and `TAG2` have major differences, while `SUFFIX1` and `SUFFIX2` of the same `TAG` share most of the configurations. By combining `TAG` and `SUFFIX`, it is convenient for both coarse-grained and find-grained control of experimental configurations.
-
-Here is an example to help you understand. Suppose I'm going to finish my experiments on two datasets, OSCD and Lebedev, and I'm not sure which batch size achieves best performance. So I create these 5 config files.
-```
-config_OSCD_bs4.yaml
-config_OSCD_bs8.yaml
-config_OSCD_bs16.yaml
-config_Lebedev_bs16.yaml
-config_Lebedev_bs32.yaml
-```
-
-After training, I get my `exp/` folder like this:
-
-```
--exp/
---OSCD/
----weights/
-----model_best_bs4.pth
-----model_best_bs8.pth
-----model_best_bs16.pth
----outs/
----logs/
----config_OSCD_bs4.yaml
----config_OSCD_bs8.yaml
----config_OSCD_bs16.yaml
---Lebedev/
----weights/
-----model_best_bs16.pth
-----model_best_bs32.pth
----outs/
----logs/
----config_Lebedev_bs16.yaml
----config_Lebedev_bs32.yaml
-```
-
-Now the experiment results are organized in a more structured way, and I think it would be a little bit easier to collect the statistics. Also, since the historical experiments are arranged in neat order, you will soon remember what you'd done when you come back to these results, even after a long time.
-
-Alternatively, you can configure from the command line. This can be useful when there is only minor change between two single runs, because the configuration items from the command line is set to overwrite those from the `yaml` file. That is, the final value of each configuration item is evaluated and applied in the following order:
-
-```
-default_value -> value_from_config_file -> value_from_command_line
-```
-
-At least one of the above three values should be given. In this way, you don't have to include all of the config items in the `yaml` file or in the command-line input. You can use either of them, or combine them. Make your choice according to preference and circumstances.
+You can find the model weights in `exp/base/weights/`, the log files in `exp/base/logs/`, and the output change maps in `exp/base/out/`.
 
 ---
 # Changed
 
-- 2020.3.14 Add the configuration files of my experiments. 
+- 2020.3.14 Add configuration files.
 - 2020.4.14 Detail README.md.
+- 2020.12.8 Update framework.
\ No newline at end of file
diff --git a/configs/config_EF_AC_Szada.yaml b/configs/config_EF_AC_Szada.yaml
deleted file mode 100644
index 8579c59..0000000
--- a/configs/config_EF_AC_Szada.yaml
+++ /dev/null
@@ -1,52 +0,0 @@
-# Basic configurations
-
-
-# Data
-# Common
-dataset: AC_Szada
-crop_size: 112
-num_workers: 1
-repeats: 3200
-
-
-# Optimizer
-optimizer: SGD
-lr: 0.001
-lr_mode: const
-weight_decay: 0.0005
-step: 2
-
-
-# Training related
-batch_size: 32
-num_epochs: 10
-resume: ''
-load_optim: True
-save_optim: True
-anew: False
-track_intvl: 1
-device: cuda
-metrics: 'F1Score+Accuracy+Recall+Precision'
-
-
-# Experiment
-exp_dir: ../exp/
-out_dir: ''
-# tag: ''
-# suffix: ''
-# DO NOT specify exp-config term
-save_on: False
-log_off: False
-suffix_off: False
-
-
-# Criterion
-criterion: NLL
-weights: 
-  - 1.0   # Weight of no-change class
-  - 10.0   # Weight of change class
-
-
-# Model
-model: EF
-num_feats_in: 6
\ No newline at end of file
diff --git a/configs/config_EF_AC_Tiszadob.yaml b/configs/config_EF_AC_Tiszadob.yaml
deleted file mode 100644
index 52a103a..0000000
--- a/configs/config_EF_AC_Tiszadob.yaml
+++ /dev/null
@@ -1,52 +0,0 @@
-# Basic configurations
-
-
-# Data
-# Common
-dataset: AC_Tiszadob
-crop_size: 112
-num_workers: 1
-repeats: 3200
-
-
-# Optimizer
-optimizer: SGD
-lr: 0.001
-lr_mode: const
-weight_decay: 0.0005
-step: 2
-
-
-# Training related
-batch_size: 32
-num_epochs: 10
-resume: ''
-load_optim: True
-save_optim: True
-anew: False
-track_intvl: 1
-device: cuda
-metrics: 'F1Score+Accuracy+Recall+Precision'
-
-
-# Experiment
-exp_dir: ../exp/
-out_dir: ''
-# tag: ''
-# suffix: ''
-# DO NOT specify exp-config term
-save_on: False
-log_off: False
-suffix_off: False
-
-
-# Criterion
-criterion: NLL
-weights: 
-  - 1.0   # Weight of no-change class
-  - 10.0   # Weight of change class
-
-
-# Model
-model: EF
-num_feats_in: 6
\ No newline at end of file
diff --git a/configs/config_EF_OSCD.yaml b/configs/config_EF_OSCD.yaml
deleted file mode 100644
index bb1005e..0000000
--- a/configs/config_EF_OSCD.yaml
+++ /dev/null
@@ -1,52 +0,0 @@
-# Basic configurations
-
-
-# Data
-# Common
-dataset: OSCD
-crop_size: 112
-num_workers: 1
-repeats: 3200
-
-
-# Optimizer
-optimizer: SGD
-lr: 0.001
-lr_mode: const
-weight_decay: 0.0005
-step: 2
-
-
-# Training related
-batch_size: 32
-num_epochs: 10
-resume: ''
-load_optim: True
-save_optim: True
-anew: False
-track_intvl: 1
-device: cuda
-metrics: 'F1Score+Accuracy+Recall+Precision'
-
-
-# Experiment
-exp_dir: ../exp/
-out_dir: ''
-# tag: ''
-# suffix: ''
-# DO NOT specify exp-config term
-save_on: False
-log_off: False
-suffix_off: False
-
-
-# Criterion
-criterion: NLL
-weights: 
-  - 1.0   # Weight of no-change class
-  - 10.0   # Weight of change class
-
-
-# Model
-model: EF
-num_feats_in: 26
\ No newline at end of file
diff --git a/configs/config_base.yaml b/configs/config_base.yaml
index 2e1119b..0f30ec5 100644
--- a/configs/config_base.yaml
+++ b/configs/config_base.yaml
@@ -2,51 +2,54 @@
 
 
 # Data
-# Common
-dataset: Lebedev
-crop_size: 224
-num_workers: 1
-repeats: 1
+dataset: AC_Szada
+num_workers: 0
+repeats: 3200
+subset: val
+crop_size: 112
 
 
 # Optimizer
-optimizer: Adam
-lr: 1e-4
-lr_mode: step
-weight_decay: 0.0
-step: 5
+optimizer: SGD
+lr: 0.001
+weight_decay: 0.0005
+load_optim: False
+save_optim: False
+lr_mode: const
+step: 2
 
 
 # Training related
-batch_size: 8 
-num_epochs: 15
+batch_size: 32
+num_epochs: 10
 resume: ''
-load_optim: True
-save_optim: True
 anew: False
-track_intvl: 1
 device: cuda
-metrics: 'F1Score+Accuracy+Recall+Precision'
 
 
 # Experiment
 exp_dir: ../exp/
-out_dir: ''
 # tag: ''
 # suffix: ''
-# DO NOT specify exp-config term
-save_on: False
+# DO NOT specify exp_config
+debug_on: False
+inherit_off: True
 log_off: False
+track_intvl: 1
+tb_on: False
+tb_intvl: 100
 suffix_off: False
+save_on: False
+out_dir: ''
+val_iters: 16
 
 
 # Criterion
 criterion: NLL
 weights: 
-  - 0.117   # Weight of no-change class
-  - 0.883   # Weight of change class
+  - 1.0   # Weight of no-change class
+  - 10.0   # Weight of change class
 
 
 # Model
-model: EF
-num_feats_in: 6
\ No newline at end of file
+model: Unet
\ No newline at end of file
diff --git a/configs/config_siamconc_AC_Szada.yaml b/configs/config_siamconc_AC_Szada.yaml
deleted file mode 100644
index d1972fb..0000000
--- a/configs/config_siamconc_AC_Szada.yaml
+++ /dev/null
@@ -1,52 +0,0 @@
-# Basic configurations
-
-
-# Data
-# Common
-dataset: AC_Szada
-crop_size: 112
-num_workers: 1
-repeats: 3200
-
-
-# Optimizer
-optimizer: SGD
-lr: 0.001
-lr_mode: const
-weight_decay: 0.0005
-step: 2
-
-
-# Training related
-batch_size: 32
-num_epochs: 10
-resume: ''
-load_optim: True
-save_optim: True
-anew: False
-track_intvl: 1
-device: cuda
-metrics: 'F1Score+Accuracy+Recall+Precision'
-
-
-# Experiment
-exp_dir: ../exp/
-out_dir: ''
-# tag: ''
-# suffix: ''
-# DO NOT specify exp-config term
-save_on: False
-log_off: False
-suffix_off: False
-
-
-# Criterion
-criterion: NLL
-weights: 
-  - 1.0   # Weight of no-change class
-  - 10.0   # Weight of change class
-
-
-# Model
-model: siamunet_conc
-num_feats_in: 3
\ No newline at end of file
diff --git a/configs/config_siamconc_AC_Tiszadob.yaml b/configs/config_siamconc_AC_Tiszadob.yaml
deleted file mode 100644
index 1ab90f9..0000000
--- a/configs/config_siamconc_AC_Tiszadob.yaml
+++ /dev/null
@@ -1,52 +0,0 @@
-# Basic configurations
-
-
-# Data
-# Common
-dataset: AC_Tiszadob
-crop_size: 112
-num_workers: 1
-repeats: 3200
-
-
-# Optimizer
-optimizer: SGD
-lr: 0.001
-lr_mode: const
-weight_decay: 0.0005
-step: 2
-
-
-# Training related
-batch_size: 32
-num_epochs: 10
-resume: ''
-load_optim: True
-save_optim: True
-anew: False
-track_intvl: 1
-device: cuda
-metrics: 'F1Score+Accuracy+Recall+Precision'
-
-
-# Experiment
-exp_dir: ../exp/
-out_dir: ''
-# tag: ''
-# suffix: ''
-# DO NOT specify exp-config term
-save_on: False
-log_off: False
-suffix_off: False
-
-
-# Criterion
-criterion: NLL
-weights: 
-  - 1.0   # Weight of no-change class
-  - 10.0   # Weight of change class
-
-
-# Model
-model: siamunet_conc
-num_feats_in: 3
\ No newline at end of file
diff --git a/configs/config_siamconc_OSCD.yaml b/configs/config_siamconc_OSCD.yaml
deleted file mode 100644
index a007e1e..0000000
--- a/configs/config_siamconc_OSCD.yaml
+++ /dev/null
@@ -1,52 +0,0 @@
-# Basic configurations
-
-
-# Data
-# Common
-dataset: OSCD
-crop_size: 112
-num_workers: 1
-repeats: 3200
-
-
-# Optimizer
-optimizer: SGD
-lr: 0.001
-lr_mode: const
-weight_decay: 0.0005
-step: 2
-
-
-# Training related
-batch_size: 32
-num_epochs: 10
-resume: ''
-load_optim: True
-save_optim: True
-anew: False
-track_intvl: 1
-device: cuda
-metrics: 'F1Score+Accuracy+Recall+Precision'
-
-
-# Experiment
-exp_dir: ../exp/
-out_dir: ''
-# tag: ''
-# suffix: ''
-# DO NOT specify exp-config term
-save_on: False
-log_off: False
-suffix_off: False
-
-
-# Criterion
-criterion: NLL
-weights: 
-  - 1.0   # Weight of no-change class
-  - 10.0   # Weight of change class
-
-
-# Model
-model: siamunet_conc
-num_feats_in: 13
\ No newline at end of file
diff --git a/configs/config_siamdiff_AC_Szada.yaml b/configs/config_siamdiff_AC_Szada.yaml
deleted file mode 100644
index e569a2b..0000000
--- a/configs/config_siamdiff_AC_Szada.yaml
+++ /dev/null
@@ -1,52 +0,0 @@
-# Basic configurations
-
-
-# Data
-# Common
-dataset: AC_Szada
-crop_size: 112
-num_workers: 1
-repeats: 3200
-
-
-# Optimizer
-optimizer: SGD
-lr: 0.001
-lr_mode: const
-weight_decay: 0.0005
-step: 2
-
-
-# Training related
-batch_size: 32
-num_epochs: 10
-resume: ''
-load_optim: True
-save_optim: True
-anew: False
-track_intvl: 1
-device: cuda
-metrics: 'F1Score+Accuracy+Recall+Precision'
-
-
-# Experiment
-exp_dir: ../exp/
-out_dir: ''
-# tag: ''
-# suffix: ''
-# DO NOT specify exp-config term
-save_on: False
-log_off: False
-suffix_off: False
-
-
-# Criterion
-criterion: NLL
-weights: 
-  - 1.0   # Weight of no-change class
-  - 10.0   # Weight of change class
-
-
-# Model
-model: siamunet_diff
-num_feats_in: 3
\ No newline at end of file
diff --git a/configs/config_siamdiff_AC_Tiszadob.yaml b/configs/config_siamdiff_AC_Tiszadob.yaml
deleted file mode 100644
index 5d6a35d..0000000
--- a/configs/config_siamdiff_AC_Tiszadob.yaml
+++ /dev/null
@@ -1,52 +0,0 @@
-# Basic configurations
-
-
-# Data
-# Common
-dataset: AC_Tiszadob
-crop_size: 112
-num_workers: 1
-repeats: 3200
-
-
-# Optimizer
-optimizer: SGD
-lr: 0.001
-lr_mode: const
-weight_decay: 0.0005
-step: 2
-
-
-# Training related
-batch_size: 32
-num_epochs: 10
-resume: ''
-load_optim: True
-save_optim: True
-anew: False
-track_intvl: 1
-device: cuda
-metrics: 'F1Score+Accuracy+Recall+Precision'
-
-
-# Experiment
-exp_dir: ../exp/
-out_dir: ''
-# tag: ''
-# suffix: ''
-# DO NOT specify exp-config term
-save_on: False
-log_off: False
-suffix_off: False
-
-
-# Criterion
-criterion: NLL
-weights: 
-  - 1.0   # Weight of no-change class
-  - 10.0   # Weight of change class
-
-
-# Model
-model: siamunet_diff
-num_feats_in: 3
\ No newline at end of file
diff --git a/configs/config_siamdiff_OSCD.yaml b/configs/config_siamdiff_OSCD.yaml
deleted file mode 100644
index c871805..0000000
--- a/configs/config_siamdiff_OSCD.yaml
+++ /dev/null
@@ -1,52 +0,0 @@
-# Basic configurations
-
-
-# Data
-# Common
-dataset: OSCD
-crop_size: 112
-num_workers: 1
-repeats: 3200
-
-
-# Optimizer
-optimizer: SGD
-lr: 0.001
-lr_mode: const
-weight_decay: 0.0005
-step: 2
-
-
-# Training related
-batch_size: 32
-num_epochs: 10
-resume: ''
-load_optim: True
-save_optim: True
-anew: False
-track_intvl: 1
-device: cuda
-metrics: 'F1Score+Accuracy+Recall+Precision'
-
-
-# Experiment
-exp_dir: ../exp/
-out_dir: ''
-# tag: ''
-# suffix: ''
-# DO NOT specify exp-config term
-save_on: False
-log_off: False
-suffix_off: False
-
-
-# Criterion
-criterion: NLL
-weights: 
-  - 1.0   # Weight of no-change class
-  - 10.0   # Weight of change class
-
-
-# Model
-model: siamunet_diff
-num_feats_in: 13
\ No newline at end of file
diff --git a/src/constants.py b/src/constants.py
index c9bfdd0..58c8a20 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -1,12 +1,12 @@
 # Global constants
 
 
-# Dataset directories
-IMDB_OSCD = '~/Datasets/OSCDDataset/'
-IMDB_AIRCHANGE = '~/Datasets/SZTAKI_AirChange_Benchmark/'
-IMDB_LEBEDEV = '~/Datasets/HR/ChangeDetectionDataset/'
-
-# Checkpoint templates
-CKP_LATEST = 'checkpoint_latest.pth'
-CKP_BEST = 'model_best.pth'
-CKP_COUNTED = 'checkpoint_{e:03d}.pth'
+# Dataset locations
+IMDB_OSCD = "~/Datasets/OSCDDataset/"
+IMDB_AIRCHANGE = "~/Datasets/SZTAKI_AirChange_Benchmark/"
+
+
+# Template strings
+CKP_LATEST = "checkpoint_latest.pth"
+CKP_BEST = "model_best.pth"
+CKP_COUNTED = "checkpoint_{e:03d}.pth"
diff --git a/src/core/__init__.py b/src/core/__init__.py
new file mode 100644
index 0000000..0adee14
--- /dev/null
+++ b/src/core/__init__.py
@@ -0,0 +1,12 @@
+# Initialize all sub-modules
+
+import core.misc
+import core.data
+import core.config
+
+import core.builders
+import core.factories
+import core.trainer
+
+import impl.builders
+import impl.trainers
\ No newline at end of file
diff --git a/src/core/builders.py b/src/core/builders.py
new file mode 100644
index 0000000..345f0ea
--- /dev/null
+++ b/src/core/builders.py
@@ -0,0 +1,49 @@
+# Built-in builders
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .misc import (MODELS, OPTIMS, CRITNS, DATA)
+
+
+# Optimizer builders
+@OPTIMS.register_func('Adam_optim')
+def build_Adam_optim(params, C):
+    return torch.optim.Adam(
+        params, 
+        betas=(0.9, 0.999),
+        lr=C['lr'],
+        weight_decay=C['weight_decay']
+    )
+
+
+@OPTIMS.register_func('SGD_optim')
+def build_SGD_optim(params, C):
+    return torch.optim.SGD(
+        params, 
+        lr=C['lr'],
+        momentum=0.9,
+        weight_decay=C['weight_decay']
+    )
+
+
+# Criterion builders
+@CRITNS.register_func('L1_critn')
+def build_L1_critn(C):
+    return nn.L1Loss()
+
+
+@CRITNS.register_func('MSE_critn')
+def build_MSE_critn(C):
+    return nn.MSELoss()
+
+
+@CRITNS.register_func('CE_critn')
+def build_CE_critn(C):
+    return nn.CrossEntropyLoss(torch.Tensor(C['weights']))
+
+
+@CRITNS.register_func('NLL_critn')
+def build_NLL_critn(C):
+    return nn.NLLLoss(torch.Tensor(C['weights']))
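
Not part of this patch: the registries above only ship optimizer and criterion builders, while the model and dataset builders live in `src/impl/builders/` (created by this commit but not reproduced here). As a hedged illustration of the intended pattern, a model builder for the `Unet` entry used in `config_base.yaml` might look roughly like the sketch below; the channel count is an assumption for illustration, not a value taken from this patch.

```python
# Hypothetical sketch: registering a model builder so that
# core.factories.single_model_factory() can resolve 'Unet' via the 'Unet_model' key.
from core.misc import MODELS
from models.unet import Unet

@MODELS.register_func('Unet_model')
def build_Unet_model(C):
    # 6 input channels (two stacked 3-band images) is an assumed value for illustration.
    return Unet(6, 2)
```
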
diff --git a/src/core/config.py b/src/core/config.py
new file mode 100644
index 0000000..949f029
--- /dev/null
+++ b/src/core/config.py
@@ -0,0 +1,126 @@
+import argparse
+import os.path as osp
+from collections import ChainMap
+
+import yaml
+
+
+def read_config(config_path):
+    with open(config_path, 'r') as f:
+        cfg = yaml.load(f.read(), Loader=yaml.FullLoader)
+    return cfg or {}
+
+
+def parse_configs(cfg_path, inherit=True):
+    # Read and parse config files
+    cfg_dir = osp.dirname(cfg_path)
+    cfg_name = osp.basename(cfg_path)
+    cfg_name, ext = osp.splitext(cfg_name)
+    parts = cfg_name.split('_')
+    cfg_path = osp.join(cfg_dir, parts[0])
+    cfgs = []
+    for part in parts[1:]:
+        cfg_path = '_'.join([cfg_path, part])
+        if osp.exists(cfg_path+ext):
+            cfgs.append(read_config(cfg_path+ext))
+    cfgs.reverse()
+    if len(parts)>=2:
+        return ChainMap(*cfgs, dict(tag=parts[1], suffix='_'.join(parts[2:])))
+    else:
+        return ChainMap(*cfgs)
+
+
+def parse_args(parser_configurator=None):
+    # Parse necessary arguments
+    # Global settings
+    parser = argparse.ArgumentParser(conflict_handler='resolve')
+    parser.add_argument('cmd', choices=['train', 'eval'])
+
+    # Data
+    group_data = parser.add_argument_group('data')
+    group_data.add_argument('--dataset', type=str)
+    group_data.add_argument('--num_workers', type=int, default=4)
+    group_data.add_argument('--repeats', type=int, default=1)
+    group_data.add_argument('--subset', type=str, default='val')
+
+    # Optimizer
+    group_optim = parser.add_argument_group('optimizer')
+    group_optim.add_argument('--optimizer', type=str, default='Adam')
+    group_optim.add_argument('--lr', type=float, default=1e-4)
+    group_optim.add_argument('--weight_decay', type=float, default=1e-4)
+    group_optim.add_argument('--load_optim', action='store_true')
+    group_optim.add_argument('--save_optim', action='store_true')
+
+    # Training related
+    group_train = parser.add_argument_group('training related')
+    group_train.add_argument('--batch_size', type=int, default=8)
+    group_train.add_argument('--num_epochs', type=int)
+    group_train.add_argument('--resume', type=str, default='')
+    group_train.add_argument('--anew', action='store_true',
+                        help="clear history and start from epoch 0 with weights updated")
+    group_train.add_argument('--device', type=str, default='cpu')
+
+    # Experiment
+    group_exp = parser.add_argument_group('experiment related')
+    group_exp.add_argument('--exp_dir', default='../exp/')
+    group_exp.add_argument('--tag', type=str, default='')
+    group_exp.add_argument('--suffix', type=str, default='')
+    group_exp.add_argument('--exp_config', type=str, default='')
+    group_exp.add_argument('--debug_on', action='store_true')
+    group_exp.add_argument('--inherit_off', action='store_true')
+    group_exp.add_argument('--log_off', action='store_true')
+    group_exp.add_argument('--track_intvl', type=int, default=1)
+
+    # Criterion
+    group_critn = parser.add_argument_group('criterion related')
+    group_critn.add_argument('--criterion', type=str, default='NLL')
+    group_critn.add_argument('--weights', type=float, nargs='+', default=None)
+
+    # Model
+    group_model = parser.add_argument_group('model')
+    group_model.add_argument('--model', type=str)
+
+    if parser_configurator is not None:
+        parser = parser_configurator(parser)
+
+    args, unparsed = parser.parse_known_args()
+    
+    if osp.exists(args.exp_config):
+        cfg = parse_configs(args.exp_config, not args.inherit_off)
+        group_config = parser.add_argument_group('from_file')
+        
+        def _cfg2arg(cfg, parser, prefix=''):
+            for k, v in cfg.items():
+                if isinstance(v, (list, tuple)):
+                    # Only apply to homogeneous lists and tuples
+                    parser.add_argument('--'+prefix+k, type=type(v[0]), nargs='*', default=v)
+                elif isinstance(v, dict):
+                    # Recursively parse a dict
+                    _cfg2arg(v, parser, prefix+k+'.')
+                elif isinstance(v, bool):
+                    parser.add_argument('--'+prefix+k, action='store_true', default=v)
+                else:
+                    parser.add_argument('--'+prefix+k, type=type(v), default=v)
+        _cfg2arg(cfg, group_config, '')
+        args = parser.parse_args()
+    elif len(unparsed)!=0:
+        raise RuntimeError("Unrecognized arguments")
+
+    def _arg2cfg(cfg, args):
+        args = vars(args)
+        for k, v in args.items():
+            pos = k.find('.')
+            if pos != -1:
+                # Iteratively parse a dict
+                dict_ = cfg
+                while pos != -1:
+                    dict_.setdefault(k[:pos], {})
+                    dict_ = dict_[k[:pos]]
+                    k = k[pos+1:]
+                    pos = k.find('.')
+                dict_[k] = v
+            else:
+                cfg[k] = v
+        return cfg
+
+    return _arg2cfg(dict(), args)
\ No newline at end of file
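
For context, `parse_configs()` above derives an inheritance chain from the config file name: every underscore-separated prefix that exists as its own YAML file is also read, and keys from the more specific file take precedence, with `tag` and `suffix` falling back to values derived from the file name. A minimal usage sketch with hypothetical file names (the concrete per-experiment configs were removed in this commit):

```python
# Hypothetical usage sketch of core.config.parse_configs(); file names are examples only.
from core.config import parse_configs

# Assuming both ../configs/config_EF.yaml and ../configs/config_EF_OSCD.yaml exist,
# both files are read and chained; keys in config_EF_OSCD.yaml win over config_EF.yaml.
cfg = parse_configs('../configs/config_EF_OSCD.yaml')

# 'tag' and 'suffix' default to values derived from the file name ('EF' and 'OSCD' here)
# unless one of the YAML files defines them explicitly.
print(cfg['tag'], cfg['suffix'])
```
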
diff --git a/src/core/data.py b/src/core/data.py
new file mode 100644
index 0000000..b1fcbcd
--- /dev/null
+++ b/src/core/data.py
@@ -0,0 +1,81 @@
+import os.path
+from abc import ABCMeta, abstractmethod
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data as data
+
+
+# Data builder utilities
+def build_train_dataloader(cls, configs, C):
+    return data.DataLoader(
+        cls(**configs),
+        batch_size=C['batch_size'],
+        shuffle=True,
+        num_workers=C['num_workers'],
+        pin_memory=C['device']!='cpu',
+        drop_last=True
+    )
+
+
+def build_eval_dataloader(cls, configs):
+    return data.DataLoader(
+        cls(**configs),
+        batch_size=1,
+        shuffle=False,
+        num_workers=1,
+        pin_memory=False,
+        drop_last=False
+    )
+
+
+def get_common_train_configs(C):
+    return dict(phase='train', repeats=C['repeats'])
+
+
+def get_common_eval_configs(C):
+    return dict(phase='eval', transforms=[None, None, None], subset=C['subset'])
+
+
+# Dataset prototype
+class DatasetBase(data.Dataset, metaclass=ABCMeta):
+    def __init__(
+        self, 
+        root, phase,
+        transforms,
+        repeats, 
+        subset
+    ):
+        super().__init__()
+        self.root = os.path.expanduser(root)
+        if not os.path.exists(self.root):
+            raise FileNotFoundError
+        # phase stands for the working mode,
+        # 'train' for training and 'eval' for validating or testing.
+        assert phase in ('train', 'eval')
+        # subset is the sub-dataset to use.
+        # For some datasets there are three subsets,
+        # while for others there are only train and test (val).
+        assert subset in ('train', 'val', 'test')
+        self.phase = phase
+        self.transforms = transforms
+        self.repeats = int(repeats)
+        # Use 'train' subset during training.
+        self.subset = 'train' if self.phase == 'train' else subset
+
+    def __len__(self):
+        return self.len * self.repeats
+
+    def __getitem__(self, index):
+        if index >= len(self):
+            raise IndexError
+        index = index % self.len
+
+        item = self.fetch_and_preprocess(index)
+
+        return item
+
+    @abstractmethod
+    def fetch_and_preprocess(self, index):
+        return None
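
The concrete datasets under `src/data/` subclass `DatasetBase`: they are expected to set `self.len` (which `__len__` multiplies by `repeats`) and to implement `fetch_and_preprocess()`. A purely illustrative subclass, with names and return values that are assumptions rather than code from this patch:

```python
# Hypothetical sketch of a DatasetBase subclass; not part of this patch.
from core.data import DatasetBase

class ToyDataset(DatasetBase):
    def __init__(self, root, phase='train', transforms=(None, None, None),
                 repeats=1, subset='val'):
        super().__init__(root, phase, transforms, repeats, subset)
        # A real dataset would scan self.root for image pairs and labels here.
        self.sample_ids = list(range(8))
        self.len = len(self.sample_ids)   # __len__ returns self.len * self.repeats

    def fetch_and_preprocess(self, index):
        # A real dataset would load the two images and the change map for this index,
        # apply self.transforms, and return the resulting tensors.
        return self.sample_ids[index]
```
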
diff --git a/src/core/factories.py b/src/core/factories.py
index 2e3cc1c..c39e426 100644
--- a/src/core/factories.py
+++ b/src/core/factories.py
@@ -1,6 +1,7 @@
-from functools import wraps
+# from functools import wraps
 from inspect import isfunction, isgeneratorfunction, getmembers
-from collections.abc import Iterable
+from collections.abc import Sequence
+from abc import ABC, ABCMeta
 from itertools import chain
 from importlib import import_module
 
@@ -8,73 +9,76 @@ import torch
 import torch.nn as nn
 import torch.utils.data as data
 
-import constants
-import utils.metrics as metrics
-from utils.misc import R
-from data.augmentation import *
+from .misc import (R, MODELS, OPTIMS, CRITNS, DATA)
 
 
-class _Desc:
+class _AttrDesc:
     def __init__(self, key):
         self.key = key
     def __get__(self, instance, owner):
-        return tuple(getattr(instance[_],self.key) for _ in range(len(instance)))
-    def __set__(self, instance, values):
-        if not (isinstance(values, Iterable) and len(values)==len(instance)):
-            raise TypeError("incorrect type or number of values")
-        for i, v in zip(range(len(instance)), values):
-            setattr(instance[i], self.key, v)
+        return tuple(getattr(ele, self.key) for ele in instance)
+    def __set__(self, instance, value):
+        for ele in instance:
+            setattr(ele, self.key, value)
 
 
 def _func_deco(func_name):
-    def _wrapper(self, *args):
-        return tuple(getattr(ins, func_name)(*args) for ins in self)
+    # FIXME: The signature of the wrapped function will be lost.
+    def _wrapper(self, *args, **kwargs):
+        return tuple(getattr(ele, func_name)(*args, **kwargs) for ele in self)
     return _wrapper
 
 
 def _generator_deco(func_name):
+    # FIXME: The signature of the wrapped function will be lost.
     def _wrapper(self, *args, **kwargs):
-        for ins in self:
-            yield from getattr(ins, func_name)(*args, **kwargs)
+        for ele in self:
+            yield from getattr(ele, func_name)(*args, **kwargs)
     return _wrapper
 
 
 # Duck typing
-class Duck(tuple):
+class Duck(Sequence, ABC):
     __ducktype__ = object
-    def __new__(cls, *args):
-        if any(not isinstance(a, cls.__ducktype__) for a in args):
-            raise TypeError("please check the input type")
-        return tuple.__new__(cls, args)
+    def __init__(self, *args):
+        if any(not isinstance(arg, self.__ducktype__) for arg in args):
+            raise TypeError("Please check the input type.")
+        self._seq = tuple(args)
+
+    def __getitem__(self, key):
+        return self._seq[key]
 
-    def __add__(self, tup):
-        raise NotImplementedError
+    def __len__(self):
+        return len(self._seq)
 
-    def __mul__(self, tup):
-        raise NotImplementedError
+    def __repr__(self):
+        return repr(self._seq)
 
 
-class DuckMeta(type):
+class DuckMeta(ABCMeta):
     def __new__(cls, name, bases, attrs):
-        assert len(bases) == 1
-        for k, v in getmembers(bases[0]):
-            if k.startswith('__'):
-                continue
-            if isgeneratorfunction(v):
-                attrs.setdefault(k, _generator_deco(k))
-            elif isfunction(v):
-                attrs.setdefault(k, _func_deco(k))
-            else:
-                attrs.setdefault(k, _Desc(k))
+        assert len(bases) == 1  # Multiple inheritance is not yet supported.
+        members = dict(getmembers(bases[0]))  # Trade space for time
+
+        for k in attrs['__ava__']:
+            if k in members:
+                v = members[k]
+                if isgeneratorfunction(v):
+                    attrs.setdefault(k, _generator_deco(k))
+                elif isfunction(v):
+                    attrs.setdefault(k, _func_deco(k))
+                else:
+                    attrs.setdefault(k, _AttrDesc(k))
         attrs['__ducktype__'] = bases[0]
         return super().__new__(cls, name, (Duck,), attrs)
 
 
 class DuckModel(nn.Module):
+    __ava__ = ('state_dict', 'load_state_dict', 'forward', '__call__', 'train', 'eval', 'to', 'training')
     def __init__(self, *models):
         super().__init__()
-        ## XXX: The state_dict will be a little larger in size
-        # Since some extra bytes are stored in every key
+        # XXX: The state_dict will be a little larger in size,
+        # since some extra bytes are stored in every key.
         self._m = nn.ModuleList(models)
 
     def __len__(self):
@@ -83,27 +87,39 @@ class DuckModel(nn.Module):
     def __getitem__(self, idx):
         return self._m[idx]
 
+    def __contains__(self, m):
+        return m in self._m
+
     def __repr__(self):
         return repr(self._m)
 
+    def forward(self, *args, **kwargs):
+        return tuple(m(*args, **kwargs) for m in self._m)
+
+
+Duck.register(DuckModel)
+
 
 class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
-    # Cuz this is an instance method
+    __ava__ = ('param_groups', 'state_dict', 'load_state_dict', 'zero_grad', 'step')
+    # An instance attribute cannot be automatically handled by the metaclass
     @property
     def param_groups(self):
-        return list(chain.from_iterable(ins.param_groups for ins in self))
+        return list(chain.from_iterable(ele.param_groups for ele in self))
 
-    # This is special in dispatching
+    # Special dispatching rule
     def load_state_dict(self, state_dicts):
         for optim, state_dict in zip(self, state_dicts):
             optim.load_state_dict(state_dict)
 
 
 class DuckCriterion(nn.Module, metaclass=DuckMeta):
+    __ava__ = ('forward', '__call__', 'train', 'eval', 'to')
     pass
 
 
-class DuckDataset(data.Dataset, metaclass=DuckMeta):
+class DuckDataLoader(data.DataLoader, metaclass=DuckMeta):
+    __ava__ = ()
     pass
 
 
@@ -116,140 +132,45 @@ def _import_module(pkg: str, mod: str, rel=False):
 
 
 def single_model_factory(model_name, C):
-    name = model_name.strip().upper()
-    if name == 'SIAMUNET_CONC':
-        from models.siamunet_conc import SiamUnet_conc
-        return SiamUnet_conc(C.num_feats_in, 2)
-    elif name == 'SIAMUNET_DIFF':
-        from models.siamunet_diff import SiamUnet_diff
-        return SiamUnet_diff(C.num_feats_in, 2)
-    elif name == 'EF':
-        from models.unet import Unet
-        return Unet(C.num_feats_in, 2)
+    builder_name = '_'.join([model_name, C['model'], C['dataset'], 'model'])
+    if builder_name in MODELS:
+        return MODELS[builder_name](C)
+    builder_name = '_'.join([model_name, C['dataset'], 'model'])
+    if builder_name in MODELS:
+        return MODELS[builder_name](C)
+    builder_name = '_'.join([model_name, 'model'])
+    if builder_name in MODELS:
+        return MODELS[builder_name](C)
     else:
-        raise NotImplementedError("{} is not a supported architecture".format(model_name))
+        raise NotImplementedError("{} is not a supported architecture.".format(model_name))
 
 
 def single_optim_factory(optim_name, params, C):
-    optim_name = optim_name.strip()
-    name = optim_name.upper()
-    if name == 'ADAM':
-        return torch.optim.Adam(
-            params, 
-            betas=(0.9, 0.999),
-            lr=C.lr,
-            weight_decay=C.weight_decay
-        )
-    elif name == 'SGD':
-        return torch.optim.SGD(
-            params, 
-            lr=C.lr,
-            momentum=0.9,
-            weight_decay=C.weight_decay
-        )
-    else:
-        raise NotImplementedError("{} is not a supported optimizer type".format(optim_name))
-
+    builder_name = '_'.join([optim_name, 'optim'])
+    if builder_name not in OPTIMS:
+        raise NotImplementedError("{} is not a supported optimizer type.".format(optim_name))
+    return OPTIMS[builder_name](params, C)
+        
 
 def single_critn_factory(critn_name, C):
-    import losses
-    critn_name = critn_name.strip()
-    try:
-        criterion, params = {
-            'L1': (nn.L1Loss, ()),
-            'MSE': (nn.MSELoss, ()),
-            'CE': (nn.CrossEntropyLoss, (torch.Tensor(C.weights),)),
-            'NLL': (nn.NLLLoss, (torch.Tensor(C.weights),))
-        }[critn_name.upper()]
-        return criterion(*params)
-    except KeyError:
-        raise NotImplementedError("{} is not a supported criterion type".format(critn_name))
-
-
-def _get_basic_configs(ds_name, C):
-    if ds_name == 'OSCD':
-        return dict(
-            root = constants.IMDB_OSCD
-        )
-    elif ds_name.startswith('AC'):
-        return dict(
-            root = constants.IMDB_AIRCHANGE
-        )
-    elif ds_name == 'Lebedev':
-        return dict(
-            root = constants.IMDB_LEBEDEV
-        )
-    else:
-        return dict()
+    builder_name = '_'.join([critn_name, 'critn'])
+    if builder_name not in CRITNS:
+        raise NotImplementedError("{} is not a supported criterion type.".format(critn_name))
+    return CRITNS[builder_name](C)
         
 
-def single_train_ds_factory(ds_name, C):
-    ds_name = ds_name.strip()
-    module = _import_module('data', ds_name)
-    dataset = getattr(module, ds_name+'Dataset')
-    configs = dict(
-        phase='train', 
-        transforms=(Compose(Crop(C.crop_size), Flip()), None, None),
-        repeats=C.repeats
-    )
-    
-    # Update some common configurations
-    configs.update(_get_basic_configs(ds_name, C))
-
-    # Set phase-specific ones
-    if ds_name == 'Lebedev':
-        configs.update(
-            dict(
-                subsets = ('real',)
-            )
-        )
+def single_data_factory(dataset_name, phase, C):
+    builder_name = '_'.join([dataset_name, C['dataset'], C['model'], phase, 'dataset'])
+    if builder_name in DATA:
+        return DATA[builder_name](C)
+    builder_name = '_'.join([dataset_name, C['model'], phase, 'dataset'])
+    if builder_name in DATA:
+        return DATA[builder_name](C)
+    builder_name = '_'.join([dataset_name, phase, 'dataset'])
+    if builder_name in DATA:
+        return DATA[builder_name](C)
     else:
-        pass
-    
-    dataset_obj = dataset(**configs)
-    
-    return data.DataLoader(
-        dataset_obj,
-        batch_size=C.batch_size,
-        shuffle=True,
-        num_workers=C.num_workers,
-        pin_memory=not (C.device == 'cpu'), drop_last=True
-    )
-
-
-def single_val_ds_factory(ds_name, C):
-    ds_name = ds_name.strip()
-    module = _import_module('data', ds_name)
-    dataset = getattr(module, ds_name+'Dataset')
-    configs = dict(
-        phase='val', 
-        transforms=(None, None, None),
-        repeats=1
-    )
-
-    # Update some common configurations
-    configs.update(_get_basic_configs(ds_name, C))
-
-    # Set phase-specific ones
-    if ds_name == 'Lebedev':
-        configs.update(
-            dict(
-                subsets = ('real',)
-            )
-        )
-    else:
-        pass
-    
-    dataset_obj = dataset(**configs)  
-
-    # Create eval set
-    return data.DataLoader(
-        dataset_obj,
-        batch_size=1,
-        shuffle=False,
-        num_workers=1,
-        pin_memory=False, drop_last=False
-    )
+        raise NotImplementedError("{} is not a supported dataset.".format(dataset_name))
 
 
 def _parse_input_names(name_str):
@@ -268,7 +189,7 @@ def optim_factory(optim_names, models, C):
     name_list = _parse_input_names(optim_names)
     num_models = len(models) if isinstance(models, DuckModel) else 1
     if len(name_list) != num_models:
-        raise ValueError("the number of optimizers does not match the number of models")
+        raise ValueError("The number of optimizers does not match the number of models.")
     
     if num_models > 1:
         optims = []
@@ -298,16 +219,7 @@ def critn_factory(critn_names, C):
 
 def data_factory(dataset_names, phase, C):
     name_list = _parse_input_names(dataset_names)
-    if phase not in ('train', 'val'):
-        raise ValueError("phase should be either 'train' or 'val'")
-    fact = globals()['single_'+phase+'_ds_factory']
     if len(name_list) > 1:
-        return DuckDataset(*(fact(name, C) for name in name_list))
+        return DuckDataLoader(*(single_data_factory(name, phase, C) for name in name_list))
     else:
-        return fact(dataset_names, C)
-
-
-def metric_factory(metric_names, C):
-    from utils import metrics
-    name_list = _parse_input_names(metric_names)
-    return [getattr(metrics, name.strip())() for name in name_list]
+        return single_data_factory(dataset_names, phase, C)
\ No newline at end of file
diff --git a/src/utils/misc.py b/src/core/misc.py
similarity index 53%
rename from src/utils/misc.py
rename to src/core/misc.py
index 7b194b5..20122ae 100644
--- a/src/utils/misc.py
+++ b/src/core/misc.py
@@ -1,10 +1,12 @@
 import logging
 import os
+import os.path as osp
 import sys
 from time import localtime
-from collections import OrderedDict
+from collections import OrderedDict, deque
 from weakref import proxy
 
+
 FORMAT_LONG = "[%(asctime)-15s %(funcName)s] %(message)s"
 FORMAT_SHORT = "%(message)s"
 
@@ -16,6 +18,7 @@ class _LessThanFilter(logging.Filter):
     def filter(self, record):
         return record.levelno < self.max_level
 
+
 class Logger:
     _count = 0
 
@@ -38,11 +41,11 @@ class Logger:
             self._logger.addHandler(self._scrn_handler)
             
         if log_dir and phase:
-            self.log_path = os.path.join(log_dir,
-                    '{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}.log'.format(
+            self.log_path = osp.join(log_dir,
+                    "{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}.log".format(
                         phase, *localtime()[:6]
                       ))
-            self.show_nl("log into {}\n\n".format(self.log_path))
+            self.show_nl("Log into {}\n\n".format(self.log_path))
             self._file_handler = logging.FileHandler(filename=self.log_path)
             self._file_handler.setLevel(logging.DEBUG)
             self._file_handler.setFormatter(logging.Formatter(fmt=FORMAT_LONG))
@@ -58,7 +61,7 @@ class Logger:
     def dump(self, *args, **kwargs):
         return self._logger.debug(*args, **kwargs)
 
-    def warning(self, *args, **kwargs):
+    def warn(self, *args, **kwargs):
         return self._logger.warning(*args, **kwargs)
 
     def error(self, *args, **kwargs):
@@ -67,16 +70,7 @@ class Logger:
     def fatal(self, *args, **kwargs):
         return self._logger.critical(*args, **kwargs)
 
-    @staticmethod
-    def make_desc(counter, total, *triples, opt_str=''):
-        desc = "[{}/{}] {}".format(counter, total, opt_str)
-        # The three elements of each triple are
-        # (name to display, AverageMeter object, formatting string)
-        for name, obj, fmt in triples:
-            desc += (" {} {obj.val:"+fmt+"} ({obj.avg:"+fmt+"})").format(name, obj=obj)
-        return desc
-
-_default_logger = Logger()
+_logger = Logger()
 
 
 class _WeakAttribute:
@@ -91,12 +85,12 @@ class _WeakAttribute:
 
 
 class _TreeNode:
-    _sep = '/'
-    _none = None
-
     parent = _WeakAttribute()   # To avoid circular reference
 
-    def __init__(self, name, value=None, parent=None, children=None):
+    def __init__(
+        self, name, value=None, parent=None, children=None,
+        sep='/', none_val=None
+    ):
         super().__init__()
         self.name = name
         self.val = value
@@ -106,46 +100,42 @@ class _TreeNode:
             for child in children:
                 self._add_child(child)
         self.path = name
+        self._sep = sep
+        self._none = none_val
     
-    def get_child(self, name, def_val=None):
-        return self.children.get(name, def_val)
+    def get_child(self, name):
+        return self.children.get(name, None)
 
-    def set_child(self, name, val=None):
-        r"""
-            Set the value of an existing node. 
-            If the node does not exist, return nothing
-        """
-        child = self.get_child(name)
-        if child is not None:
-            child.val = val
-
-        return child
+    def add_placeholder(self, name):
+        return self.add_child(name, value=self._none)
 
-    def add_place_holder(self, name):
-        return self.add_child(name, val=self._none)
-
-    def add_child(self, name, val):
+    def add_child(self, name, value, warning=False):
         r"""
-            If not exists or is a placeholder, create it
-            Otherwise skips and returns the existing node
+        If the node does not exist or is a placeholder, create it;
+        otherwise, skip creation and return the existing node.
         """
-        child = self.get_child(name, None)
+        child = self.get_child(name)
         if child is None:
-            child = _TreeNode(name, val, parent=self)
+            child = _TreeNode(name, value, parent=self, sep=self._sep, none_val=self._none)
             self._add_child(child)
-        elif child.val == self._none:
-            # Retain the links of the placeholder
-            # i.e. just fill in it
-            child.val = val
-
+        elif child.is_placeholder():
+            # Retain the links of a placeholder,
+            # i.e. just fill in it.
+            child.val = value
+        else:
+            if warning: 
+                _logger.warn("Node already exists!")
         return child
 
     def is_leaf(self):
         return len(self.children) == 0
 
+    def is_placeholder(self):
+        return self.val == self._none
+
     def __repr__(self):
         try:
-            repr = self.path + ' ' + str(self.val)
+            repr = self.path + " " + str(self.val)
         except TypeError:
             repr = self.path
         return repr
@@ -157,7 +147,10 @@ class _TreeNode:
         return self.get_child(key)
 
     def _add_child(self, node):
-        r""" Into children dictionary and set path and parent """
+        r"""
+        Add a child node into self.children.
+        If the node already exists, just update its information.
+        """
         self.children.update({
             node.name: node
         })
@@ -166,8 +159,8 @@ class _TreeNode:
 
     def apply(self, func):
         r"""
-            Apply a callback function on ALL descendants
-            This is useful for the recursive traversal
+        Apply a callback function to ALL descendants.
+        This is useful for recursive traversal.
         """
         ret = [func(self)]
         for _, node in self.children.items():
@@ -175,69 +168,53 @@ class _TreeNode:
         return ret
 
     def bfs_tracker(self):
-        queue = []
-        queue.insert(0, self)
+        queue = deque()
+        queue.append(self)
         while(queue):
-            curr = queue.pop()
+            curr = queue.popleft()
             yield curr
             if curr.is_leaf():
                 continue
             for c in curr.children.values():
-                queue.insert(0, c)
+                queue.append(c)
 
 
 class _Tree:
     def __init__(
-        self, name, value=None, strc_ele=None, 
-        sep=_TreeNode._sep, def_val=_TreeNode._none
+        self, name, value=None, eles=None, 
+        sep='/', none_val=None
     ):
         super().__init__()
         self._sep = sep
-        self._def_val = def_val
+        self._none = none_val
         
-        self.root = _TreeNode(name, value, parent=None, children={})
-        if strc_ele is not None:
-            assert isinstance(strc_ele, dict)
-            # This is to avoid mutable parameter default
-            self.build_tree(OrderedDict(strc_ele or {}))
+        self.root = _TreeNode(name, value, parent=None, children={}, sep=self._sep, none_val=self._none)
+        if eles is not None:
+            assert isinstance(eles, dict)
+            self.build_tree(OrderedDict(eles or {}))
 
     def build_tree(self, elements):
-        # The siblings could be out-of-order
+        # The order of the siblings is not retained
         for path, ele in elements.items():
             self.add_node(path, ele)
 
-    def get_root(self):
-        r""" Get separated root node """
-        return _TreeNode(
-            self.root.name, self.root.value, 
-            parent=None, children=None
-        )
-
     def __repr__(self):
-        return self.__dumps__()
-        
-    def __dumps__(self):
-        r""" Dump to string """
-        _str = ''
+        _str = ""
         # DFS
         stack = []
         stack.append((self.root, 0))
         while(stack):
             root, layer = stack.pop()
-            _str += ' '*layer + '-' + root.__repr__() + '\n'
+            _str += " "*layer + "-" + root.__repr__() + "\n"
 
             if root.is_leaf():
                 continue
-            # Note that the order of the siblings is not retained
-            for c in reversed(list(root.children.values())):
+            # Note that the siblings are printed in alphabetical order.
+            for c in sorted(list(root.children.values()), key=lambda n: n.name, reverse=True):
                 stack.append((c, layer+1))
 
         return _str
 
-    def vis(self):
-        r""" Visualize the structure of the tree """
-        _default_logger.show(self.__dumps__())
-
     def __contains__(self, obj):
         return any(self.perform(lambda node: obj in node))
 
@@ -246,15 +223,16 @@ class _Tree:
 
     def get_node(self, tar, mode='name'):
         r"""
-            This is different from the travasal in that
-            the search allows early stop
+        This is different from a traversal in that this search allows early stopping.
         """
+        assert mode in ('name', 'path', 'val')
         if mode == 'path':
             nodes = self.parse_path(tar)
             root = self.root
             for r in nodes:
                 if root is None:
-                    root = root.get_child(r)
+                    break
+                root = root.get_child(r)
             return root
         else:
             # BFS
@@ -264,28 +242,20 @@ class _Tree:
             for node in bfs_tracker:
                 if getattr(node, mode) == tar:
                     return node
-        return
+            return None
 
-    def set_node(self, path, val):
-        node = self.get_node(path, mode=path)
-        if node is not None:
-            node.val = val
-        return node
-
-    def add_node(self, path, val=None):
+    def add_node(self, path, val):
         if not path.strip():
-            raise ValueError("the path is null")
-        path = path.strip('/')
-        if val is None:
-            val = self._def_val
+            raise ValueError("The path is null.")
+        path = path.rstrip(self._sep)
         names = self.parse_path(path)
         root = self.root
         nodes = [root]
         for name in names[:-1]:
-            # Add placeholders
-            root = root.add_child(name, self._def_val)
+            # Add a placeholder or skip an existing node
+            root = root.add_placeholder(name)
             nodes.append(root)
-        root = root.add_child(names[-1], val)
+        root = root.add_child(names[-1], val, True)
         return root, nodes
 
     def parse_path(self, path):
@@ -296,22 +266,29 @@ class _Tree:
         
         
 class OutPathGetter:
-    def __init__(self, root='', log='logs', out='outs', weight='weights', suffix='', **subs):
+    def __init__(self, root='', log='logs', out='out', weight='weights', suffix='', **subs):
         super().__init__()
-        self._root = root.rstrip('/')    # Work robustly for multiple ending '/'s
+        self._root = root.rstrip(os.sep)    # Work robustly with multiple trailing separators
         if len(self._root) == 0 and len(root) > 0:
-            self._root = '/'    # In case of the system root dir
+            self._root = os.sep    # In case of the system root dir on Linux
         self._suffix = suffix
+
         self._keys = dict(log=log, out=out, weight=weight, **subs)
+        for k, v in self._keys.items():
+            v_ = v.rstrip(os.sep)
+            if len(v_) == 0 or not self.check_path(v_):
+                _logger.warn("{} is not a valid path.".format(v))
+                continue
+            self._keys[k] = v_
+
         self._dir_tree = _Tree(
             self._root, 'root',
-            strc_ele=dict(zip(self._keys.values(), self._keys.keys())),
-            sep='/', 
-            def_val=''
+            eles=dict(zip(self._keys.values(), self._keys.keys())),
+            sep=os.sep, none_val=''
         )
 
-        self.update_keys(False)
-        self.update_tree(False)
+        self.add_keys(False)
+        self.update_vfs(False)
 
         self.__counter = 0
 
@@ -326,89 +303,109 @@ class OutPathGetter:
     def root(self):
         return self._root
 
-    def _update_key(self, key, val, add=False, prefix=False):
-        if prefix:
-            val = os.path.join(self._root, val)
-        if add:
-            # Do not edit if exists
-            self._keys.setdefault(key, val)
-        else:
-            self._keys.__setitem__(key, val)
-
-    def _add_node(self, key, val, prefix=False):
-        if not prefix and key.startswith(self._root):
-            key = key[len(self._root)+1:]
-        return self._dir_tree.add_node(key, val)
+    def _add_key(self, key, val):
+        self._keys.setdefault(key, val)
 
-    def update_keys(self, verbose=False):
+    def add_keys(self, verbose=False):
         for k, v in self._keys.items():
-            self._update_key(k, v, prefix=True)
+            self._add_key(k, v)
         if verbose:
-            _default_logger.show(self._keys)
+            _logger.show(self._keys)
         
-    def update_tree(self, verbose=False):
+    def update_vfs(self, verbose=False):
         self._dir_tree.perform(lambda x: self.make_dir(x.path))
         if verbose:
-            _default_logger.show("\nFolder structure:")
-            _default_logger.show(self._dir_tree)
+            _logger.show("\nFolder structure:")
+            _logger.show(self._dir_tree)
+
+    @staticmethod
+    def check_path(path):
+        # This is to reject paths like A/../B or A/./.././C.d
+        # Note that paths like A.B/.C/D are not supported, either.
+        return osp.dirname(path).find('.') == -1
 
     @staticmethod
     def make_dir(path):
-        if not os.path.exists(path):
+        if not osp.exists(path):
             os.mkdir(path)
+        elif not osp.isdir(path):
+            raise RuntimeError("Cannot create directory.")
 
     def get_dir(self, key):
-        return self._keys.get(key, '') if key != 'root' else self.root
+        return osp.join(self.root, self._keys[key])
 
     def get_path(
         self, key, file, 
         name='', auto_make=False, 
-        suffix=True, underline=False
+        suffix=False, underline=True
     ):
-        folder = self.get_dir(key)
-        if len(folder) < 1:
-            raise KeyError("key not found") 
+        if len(file) == 0:
+            return self.get_dir(key)
+        if not self.check_path(file):
+            raise ValueError("{} is not a valid path.".format(file))
+        folder = self._keys[key]
         if suffix:
-            path = os.path.join(folder, self.add_suffix(file, underline=underline))
+            path = osp.join(folder, self._add_suffix(file, underline=underline))
         else:
-            path = os.path.join(folder, file)
+            path = osp.join(folder, file)
 
         if auto_make:
-            base_dir = os.path.dirname(path)
-
+            base_dir = osp.dirname(path)
+            # O(n) search for base_dir
+            # Never update an existing key!
             if base_dir in self:
-                return path
-            if name:
-                self._update_key(name, base_dir, add=True)
-            '''
+                _logger.warn("Cannot assign a new key to an existing path!")
+                return osp.join(self.root, path)
+            node = self._dir_tree.get_node(base_dir, mode='path')
+            
+            # Note that if name is an empty string, the directory tree will be
+            # updated, but no new key will be added to self._keys.
+            if node is None or node.is_placeholder():
+                # Update directory tree
+                des, visit = self._dir_tree.add_node(base_dir, name)
+                # Create directories along the visiting path
+                for d in visit: self.make_dir(d.path)
+                self.make_dir(des.path)
             else:
-                name = 'new_{:03d}'.format(self.__counter)
-                self._update_key(name, base_dir, add=True)
-                self.__counter += 1
-            '''
-            des, visit = self._add_node(base_dir, name)
-            # Create directories along the visiting path
-            for d in visit: self.make_dir(d.path)
-            self.make_dir(des.path)
-        return path
-
-    def add_suffix(self, path, suffix='', underline=False):
+                node.val = name
+            if len(name) > 0:
+                # Add new key
+                self._add_key(name, base_dir)
+        return osp.join(self.root, path)
+
+    def _add_suffix(self, path, suffix='', underline=False):
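+        # e.g. _add_suffix('model.pth', suffix='best', underline=True) -> 'model_best.pth'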
         pos = path.rfind('.')
         if pos == -1:
             pos = len(path)
-        _suffix = self._suffix if len(suffix) < 1 else suffix
+        _suffix = self._suffix if len(suffix) == 0 else suffix
         return path[:pos] + ('_' if underline and _suffix else '') + _suffix + path[pos:]
 
     def __contains__(self, value):
-        return value in self._keys.values()
+        return value in self._keys.values() or value == self._root
+
+    def contains_key(self, key):
+        return key in self._keys
 
 
 class Registry(dict):
     def register(self, key, val):
-        if key in self: _default_logger.warning("key {} already registered".format(key))
+        if key in self: _logger.warn("Key {} has already been registered!".format(key))
         self[key] = val
+    
+    def register_func(self, key):
+        def _wrapper(func):
+            self.register(key, func)
+            return func
+        return _wrapper
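+
+    # Illustrative usage (a sketch of the decorator form; names are placeholders):
+    #   MODELS = Registry()
+    #   @MODELS.register_func('UNet')
+    #   def build_unet_model(*args, **kwargs):
+    #       ...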
 
 
+# Registry for global objects
 R = Registry()
-R.register('DEFAULT_LOGGER', _default_logger)
-register = R.register
\ No newline at end of file
+R.register('Logger', _logger)
+register = R.register
+
+# Registries for builders
+MODELS = Registry()
+OPTIMS = Registry()
+CRITNS = Registry()
+DATA = Registry()
\ No newline at end of file
diff --git a/src/core/trainer.py b/src/core/trainer.py
new file mode 100644
index 0000000..ae87413
--- /dev/null
+++ b/src/core/trainer.py
@@ -0,0 +1,232 @@
+import shutil
+import os
+from types import MappingProxyType
+from copy import deepcopy
+from abc import ABCMeta, abstractmethod
+
+import torch
+
+import constants
+from .misc import Logger, OutPathGetter, R
+from .factories import (model_factory, optim_factory, critn_factory, data_factory)
+
+
+class Trainer(metaclass=ABCMeta):
+    def __init__(self, model, dataset, criterion, optimizer, settings):
+        super().__init__()
+        # Make a copy of settings in case of unexpected changes
+        context = deepcopy(settings)
+        # self.ctx is a proxy so that context will be read-only outside __init__
+        self.ctx = MappingProxyType(context)
+        self.mode = ('train', 'eval').index(context['cmd'])
+        self.debug = context['debug_on']
+        self.log = not context['log_off']
+        self.batch_size = context['batch_size']
+        self.checkpoint = context['resume']
+        self.load_checkpoint = (len(self.checkpoint)>0)
+        self.num_epochs = context['num_epochs']
+        self.lr = float(context['lr'])
+        self.track_intvl = int(context['track_intvl'])
+        self.device = torch.device(context['device'])
+
+        self.gpc = OutPathGetter(
+            root=os.path.join(context['exp_dir'], context['tag']), 
+            suffix=context['suffix']
+        )   # Global Path Controller
+        
+        self.logger = Logger(
+            scrn=True,
+            log_dir=self.gpc.get_dir('log') if self.log else '',
+            phase=context['cmd']
+        )
+        self.path = self.gpc.get_path
+
+        for k, v in sorted(context.items()):
+            self.logger.show("{}: {}".format(k,v))
+
+        self.model = model_factory(model, context)
+        self.model.to(self.device)
+        self.criterion = critn_factory(criterion, context)
+        self.criterion.to(self.device)
+
+        if self.is_training:
+            self.train_loader = data_factory(dataset, 'train', context)
+            self.eval_loader = data_factory(dataset, 'eval', context)
+            self.optimizer = optim_factory(optimizer, self.model, context)
+        else:
+            self.eval_loader = data_factory(dataset, 'eval', context)
+        
+        self.start_epoch = 0
+        self._init_acc_epoch = (0.0, -1)
+
+    @property
+    def is_training(self):
+        return self.mode == 0
+
+    @abstractmethod
+    def train_epoch(self, epoch):
+        pass
+
+    @abstractmethod
+    def evaluate_epoch(self, epoch):
+        return 0.0
+
+    def _write_prompt(self):
+        self.logger.dump(input("\nWrite some notes: "))
+
+    def run(self):
+        if self.is_training:
+            if self.log and not self.debug:
+                self._write_prompt()
+            self.train()
+        else:
+            self.evaluate()
+
+    def train(self):
+        if self.load_checkpoint:
+            self._resume_from_checkpoint()
+
+        max_acc, best_epoch = self._init_acc_epoch
+        lr = self.init_learning_rate()
+
+        for epoch in range(self.start_epoch, self.num_epochs):
+            self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr))
+
+            # Train for one epoch
+            self.model.train()
+            self.train_epoch(epoch)
+            
+            # Evaluate the model
+            self.logger.show_nl("Evaluate")
+            self.model.eval()
+            acc = self.evaluate_epoch(epoch=epoch)
+            
+            is_best = acc > max_acc
+            if is_best:
+                max_acc = acc
+                best_epoch = epoch
+            self.logger.show_nl("Current: {:.6f} ({:03d})\tBest: {:.6f} ({:03d})\t".format(
+                                acc, epoch, max_acc, best_epoch))
+
+            # Do not save checkpoints in debugging mode
+            if not self.debug:
+                self._save_checkpoint(
+                    self.model.state_dict(), 
+                    self.optimizer.state_dict() if self.ctx['save_optim'] else {}, 
+                    (max_acc, best_epoch), epoch, is_best
+                )
+
+            lr = self.adjust_learning_rate(epoch, acc)
+        
+    def evaluate(self):
+        if self.checkpoint: 
+            if self._resume_from_checkpoint():
+                self.model.eval()
+                self.evaluate_epoch(self.start_epoch)
+        else:
+            self.logger.error("No checkpoint assigned!")
+
+    def init_learning_rate(self):
+        return self.lr
+
+    def adjust_learning_rate(self, epoch, acc):
+        return self.lr
+
+    def _resume_from_checkpoint(self):
+        # XXX: This could be slow!
+        if not os.path.isfile(self.checkpoint):
+            self.logger.error("=> No checkpoint was found at '{}'.".format(self.checkpoint))
+            return False
+
+        self.logger.show("=> Loading checkpoint '{}'...".format(self.checkpoint))
+        checkpoint = torch.load(self.checkpoint, map_location=self.device)
+
+        state_dict = self.model.state_dict()
+        ckp_dict = checkpoint.get('state_dict', checkpoint)
+        update_dict = {
+            k:v for k,v in ckp_dict.items() 
+            if k in state_dict and state_dict[k].shape == v.shape and state_dict[k].dtype == v.dtype
+        }
+        
+        num_to_update = len(update_dict)
+        if (num_to_update < len(state_dict)) or (len(state_dict) < len(ckp_dict)):
+            if not self.is_training and (num_to_update < len(state_dict)):
+                self.logger.error("=> Mismatched checkpoint for evaluation")
+                return False
+            self.logger.warn("Trying to load a mismatched checkpoint.")
+            if num_to_update == 0:
+                self.logger.error("=> No parameter is to be loaded.")
+                return False
+            else:
+                self.logger.warn("=> {} params are to be loaded.".format(num_to_update))
+        elif not self.ctx['anew'] or not self.is_training:
+            ckp_epoch = checkpoint.get('epoch', -1)
+            self.start_epoch = ckp_epoch+1
+            self._init_acc_epoch = checkpoint.get('max_acc', (0.0, ckp_epoch))
+            if self.ctx['load_optim'] and self.is_training:
+                # XXX: Note that weight decay might be modified here.
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
+                self.logger.warn("Weight decay might have been modified.")
+
+        state_dict.update(update_dict)
+        self.model.load_state_dict(state_dict)
+
+        if self.start_epoch == 0:
+            self.logger.show("=> Loaded checkpoint '{}'".format(self.checkpoint))
+        else:
+            self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {}).".format(
+                self.checkpoint, self.start_epoch-1, *self._init_acc_epoch
+                ))
+        return True
+        
+    def _save_checkpoint(self, state_dict, optim_state, max_acc, epoch, is_best):
+        state = {
+            'epoch': epoch,
+            'state_dict': state_dict,
+            'optimizer': optim_state, 
+            'max_acc': max_acc
+        } 
+        # Save history
+        # The checkpoint name contains epoch+1 rather than epoch, so that it is
+        # easy to recognize "the next start_epoch".
+        history_path = self.path(
+            'weight', constants.CKP_COUNTED.format(e=epoch+1), 
+            suffix=True
+        )
+        if epoch % self.track_intvl == 0:
+            torch.save(state, history_path)
+        # Save latest
+        latest_path = self.path(
+            'weight', constants.CKP_LATEST, 
+            suffix=True
+        )
+        torch.save(state, latest_path)
+        if is_best:
+            shutil.copyfile(
+                latest_path, self.path(
+                    'weight', constants.CKP_BEST, 
+                    suffix=True
+                )
+            )
+
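+# A minimal sketch of a concrete trainer (illustrative only; names below are placeholders):
+#
+#   class MyCDTrainer(Trainer):
+#       def train_epoch(self, epoch):
+#           for t1, t2, tar in self.train_loader:
+#               ...  # forward pass, loss, backward pass, optimizer step
+#
+#       def evaluate_epoch(self, epoch):
+#           return 0.0  # the accuracy used for checkpoint selection
+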
+
+class TrainerSwitcher:
+    r"""A simple utility class to help dispatch actions to different trainers."""
+    def __init__(self, *pairs):
+        self._trainer_list = list(pairs)
+
+    def __call__(self, args, return_obj=True):
+        for p, t in self._trainer_list:
+            if p(args):
+                return t(args) if return_obj else t
+        return None
+
+    def add_item(self, predicate, trainer):
+        # Newly added items have higher priority
+        self._trainer_list.insert(0, (predicate, trainer))
+
+    def add_default(self, trainer):
+        self._trainer_list.append((lambda _: True, trainer))
+
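+# Illustrative usage (a sketch; 'args' and its attributes are hypothetical):
+#   switcher = TrainerSwitcher()
+#   switcher.add_item(lambda args: args.dataset.startswith('AC'), MyCDTrainer)
+#   trainer = switcher(args)  # a MyCDTrainer instance, or None if no predicate matches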
+
+R.register('Trainer_switcher', TrainerSwitcher())
\ No newline at end of file
diff --git a/src/core/trainers.py b/src/core/trainers.py
deleted file mode 100644
index 68cde97..0000000
--- a/src/core/trainers.py
+++ /dev/null
@@ -1,303 +0,0 @@
-import shutil
-import os
-from types import MappingProxyType
-from copy import deepcopy
-
-import torch
-from skimage import io
-from tqdm import tqdm
-
-import constants
-from data.common import to_array
-from utils.misc import R
-from utils.metrics import AverageMeter
-from utils.utils import mod_crop
-from .factories import (model_factory, optim_factory, critn_factory, data_factory, metric_factory)
-
-
-class Trainer:
-    def __init__(self, model, dataset, criterion, optimizer, settings):
-        super().__init__()
-        context = deepcopy(settings)
-        self.ctx = MappingProxyType(vars(context))
-        self.mode = ('train', 'val').index(context.cmd)
-
-        self.logger = R['LOGGER']
-        self.gpc = R['GPC']     # Global Path Controller
-        self.path = self.gpc.get_path
-
-        self.batch_size = context.batch_size
-        self.checkpoint = context.resume
-        self.load_checkpoint = (len(self.checkpoint)>0)
-        self.num_epochs = context.num_epochs
-        self.lr = float(context.lr)
-        self.save = context.save_on or context.out_dir
-        self.out_dir = context.out_dir
-        self.track_intvl = int(context.track_intvl)
-        self.device = torch.device(context.device)
-        self.suffix_off = context.suffix_off
-
-        for k, v in sorted(self.ctx.items()):
-            self.logger.show("{}: {}".format(k,v))
-
-        self.model = model_factory(model, context)
-        self.model.to(self.device)
-        self.criterion = critn_factory(criterion, context)
-        self.criterion.to(self.device)
-        self.metrics = metric_factory(context.metrics, context)
-
-        if self.is_training:
-            self.train_loader = data_factory(dataset, 'train', context)
-            self.val_loader = data_factory(dataset, 'val', context)
-            self.optimizer = optim_factory(optimizer, self.model, context)
-        else:
-            self.val_loader = data_factory(dataset, 'val', context)
-        
-        self.start_epoch = 0
-        self._init_max_acc_and_epoch = (0.0, 0)
-
-    @property
-    def is_training(self):
-        return self.mode == 0
-
-    def train_epoch(self, epoch):
-        raise NotImplementedError
-
-    def validate_epoch(self, epoch=0, store=False):
-        raise NotImplementedError
-
-    def _write_prompt(self):
-        self.logger.dump(input("\nWrite some notes: "))
-
-    def run(self):
-        if self.is_training:
-            self._write_prompt()
-            self.train()
-        else:
-            self.evaluate()
-
-    def train(self):
-        if self.load_checkpoint:
-            self._resume_from_checkpoint()
-
-        max_acc, best_epoch = self._init_max_acc_and_epoch
-
-        for epoch in range(self.start_epoch, self.num_epochs):
-            lr = self._adjust_learning_rate(epoch)
-
-            self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr))
-
-            # Train for one epoch
-            self.train_epoch(epoch)
-            
-            # Clear the history of metric objects
-            for m in self.metrics:
-                m.reset()
-                
-            # Evaluate the model on validation set
-            self.logger.show_nl("Validate")
-            acc = self.validate_epoch(epoch=epoch, store=self.save)
-            
-            is_best = acc > max_acc
-            if is_best:
-                max_acc = acc
-                best_epoch = epoch
-            self.logger.show_nl("Current: {:.6f} ({:03d})\tBest: {:.6f} ({:03d})\t".format(
-                                acc, epoch, max_acc, best_epoch))
-
-            # The checkpoint saves next epoch
-            self._save_checkpoint(
-                self.model.state_dict(), 
-                self.optimizer.state_dict() if self.ctx['save_optim'] else {}, 
-                (max_acc, best_epoch), epoch+1, is_best
-            )
-        
-    def evaluate(self):
-        if self.checkpoint: 
-            if self._resume_from_checkpoint():
-                self.validate_epoch(self.ckp_epoch, self.save)
-        else:
-            self.logger.warning("Warning: no checkpoint assigned!")
-
-    def _adjust_learning_rate(self, epoch):
-        if self.ctx['lr_mode'] == 'step':
-            lr = self.lr * (0.5 ** (epoch // self.ctx['step']))
-        elif self.ctx['lr_mode'] == 'poly':
-            lr = self.lr * (1 - epoch / self.num_epochs) ** 1.1
-        elif self.ctx['lr_mode'] == 'const':
-            lr = self.lr
-        else:
-            raise ValueError('unknown lr mode {}'.format(self.ctx['lr_mode']))
-
-        for param_group in self.optimizer.param_groups:
-            param_group['lr'] = lr
-        return lr
-
-    def _resume_from_checkpoint(self):
-        ## XXX: This could be slow!
-        if not os.path.isfile(self.checkpoint):
-            self.logger.error("=> No checkpoint was found at '{}'.".format(self.checkpoint))
-            return False
-
-        self.logger.show("=> Loading checkpoint '{}'".format(
-                        self.checkpoint))
-        checkpoint = torch.load(self.checkpoint, map_location=self.device)
-
-        state_dict = self.model.state_dict()
-        ckp_dict = checkpoint.get('state_dict', checkpoint)
-        update_dict = {k:v for k,v in ckp_dict.items() 
-            if k in state_dict and state_dict[k].shape == v.shape}
-        
-        num_to_update = len(update_dict)
-        if (num_to_update < len(state_dict)) or (len(state_dict) < len(ckp_dict)):
-            if not self.is_training and (num_to_update < len(state_dict)):
-                self.logger.error("=> Mismatched checkpoint for evaluation")
-                return False
-            self.logger.warning("Warning: trying to load an mismatched checkpoint.")
-            if num_to_update == 0:
-                self.logger.error("=> No parameter is to be loaded.")
-                return False
-            else:
-                self.logger.warning("=> {} params are to be loaded.".format(num_to_update))
-        elif (not self.ctx['anew']) or not self.is_training:
-            self.start_epoch = checkpoint.get('epoch', 0)
-            max_acc_and_epoch = checkpoint.get('max_acc', (0.0, self.ckp_epoch))
-            # For backward compatibility
-            if isinstance(max_acc_and_epoch, (float, int)):
-                self._init_max_acc_and_epoch = (max_acc_and_epoch, self.ckp_epoch)
-            else:
-                self._init_max_acc_and_epoch = max_acc_and_epoch
-            if self.ctx['load_optim'] and self.is_training:
-                # Note that weight decay might be modified here
-                self.optimizer.load_state_dict(checkpoint['optimizer'])
-
-        state_dict.update(update_dict)
-        self.model.load_state_dict(state_dict)
-
-        self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {})".format(
-            self.checkpoint, self.ckp_epoch, *self._init_max_acc_and_epoch
-            ))
-        return True
-        
-    def _save_checkpoint(self, state_dict, optim_state, max_acc, epoch, is_best):
-        state = {
-            'epoch': epoch,
-            'state_dict': state_dict,
-            'optimizer': optim_state, 
-            'max_acc': max_acc
-        } 
-        # Save history
-        history_path = self.path('weight', constants.CKP_COUNTED.format(e=epoch), underline=True)
-        if epoch % self.track_intvl == 0:
-            torch.save(state, history_path)
-        # Save latest
-        latest_path = self.path(
-            'weight', constants.CKP_LATEST, 
-            underline=True
-        )
-        torch.save(state, latest_path)
-        if is_best:
-            shutil.copyfile(
-                latest_path, self.path(
-                    'weight', constants.CKP_BEST, 
-                    underline=True
-                )
-            )
-    
-    @property
-    def ckp_epoch(self):
-        # Get current epoch of the checkpoint
-        # For dismatched ckp or no ckp, set to 0
-        return max(self.start_epoch-1, 0)
-
-    def save_image(self, file_name, image, epoch):
-        file_path = os.path.join(
-            'epoch_{}/'.format(epoch),
-            self.out_dir,
-            file_name
-        )
-        out_path = self.path(
-            'out', file_path,
-            suffix=not self.suffix_off,
-            auto_make=True,
-            underline=True
-        )
-        return io.imsave(out_path, image)
-
-
-class CDTrainer(Trainer):
-    def __init__(self, arch, dataset, optimizer, settings):
-        super().__init__(arch, dataset, 'NLL', optimizer, settings)
-
-    def train_epoch(self, epoch):
-        losses = AverageMeter()
-        len_train = len(self.train_loader)
-        pb = tqdm(self.train_loader)
-        
-        self.model.train()
-
-        for i, (t1, t2, label) in enumerate(pb):
-            t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
-            
-            prob = self.model(t1, t2)
-            
-            loss = self.criterion(prob, label)
-            
-            losses.update(loss.item(), n=self.batch_size)
-
-            # Compute gradients and do SGD step
-            self.optimizer.zero_grad()
-            loss.backward()
-            self.optimizer.step()
-
-            desc = self.logger.make_desc(
-                i+1, len_train,
-                ('loss', losses, '.4f')
-            )
-
-            pb.set_description(desc)
-            self.logger.dump(desc)
-
-    def validate_epoch(self, epoch=0, store=False):
-        self.logger.show_nl("Epoch: [{0}]".format(epoch))
-        losses = AverageMeter()
-        len_val = len(self.val_loader)
-        pb = tqdm(self.val_loader)
-
-        self.model.eval()
-
-        with torch.no_grad():
-            for i, (name, t1, t2, label) in enumerate(pb):
-                if self.is_training and i >= 16: 
-                    # Do not validate all images on training phase
-                    pb.close()
-                    self.logger.warning("validation ends early")
-                    break
-                t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
-
-                prob = self.model(t1, t2)
-
-                loss = self.criterion(prob, label)
-                losses.update(loss.item(), n=self.batch_size)
-
-                # Convert to numpy arrays
-                CM = to_array(torch.argmax(prob[0], 0)).astype('uint8')
-                label = to_array(label[0]).astype('uint8')
-                for m in self.metrics:
-                    m.update(CM, label)
-
-                desc = self.logger.make_desc(
-                    i+1, len_val,
-                    ('loss', losses, '.4f'),
-                    *(
-                        (m.__name__, m, '.4f')
-                        for m in self.metrics
-                    )
-                )
-                pb.set_description(desc)
-                self.logger.dump(desc)
-                    
-                if store:
-                    self.save_image(name[0], CM*255, epoch)
-
-        return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc)
\ No newline at end of file
diff --git a/src/data/Lebedev.py b/src/data/Lebedev.py
deleted file mode 100644
index ecdf5ba..0000000
--- a/src/data/Lebedev.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from glob import glob
-from os.path import join, basename
-
-import numpy as np
-
-from . import CDDataset
-from .common import default_loader
-
-class LebedevDataset(CDDataset):
-    def __init__(
-        self, 
-        root, phase='train', 
-        transforms=(None, None, None), 
-        repeats=1,
-        subsets=('real', 'with_shift', 'without_shift')
-    ):
-        self.subsets = subsets
-        super().__init__(root, phase, transforms, repeats)
-
-    def _read_file_paths(self):
-        t1_list, t2_list, label_list = [], [], []
-
-        for subset in self.subsets:
-            # Get subset directory
-            if subset == 'real':
-                subset_dir = join(self.root, 'Real', 'subset')
-            elif subset == 'with_shift':
-                subset_dir = join(self.root, 'Model', 'with_shift')
-            elif subset == 'without_shift':
-                subset_dir = join(self.root, 'Model', 'without_shift')
-            else:
-                raise RuntimeError('unrecognized key encountered')
-
-            pattern = '*.bmp' if (subset == 'with_shift' and self.phase in ('test', 'val')) else '*.jpg'
-            refs = sorted(glob(join(subset_dir, self.phase, 'OUT', pattern)))
-            t1s = (join(subset_dir, self.phase, 'A', basename(ref)) for ref in refs)
-            t2s = (join(subset_dir, self.phase, 'B', basename(ref)) for ref in refs)
-
-            label_list.extend(refs)
-            t1_list.extend(t1s)
-            t2_list.extend(t2s)
-
-        return t1_list, t2_list, label_list
-
-    def fetch_label(self, label_path):
-        # To {0,1}
-        return (super().fetch_label(label_path) > 127).astype(np.uint8)  
\ No newline at end of file
diff --git a/src/data/__init__.py b/src/data/__init__.py
index 1f81254..bb29ebc 100644
--- a/src/data/__init__.py
+++ b/src/data/__init__.py
@@ -1,69 +1,62 @@
-from os.path import join, expanduser, basename, exists, splitext
+from os.path import basename, splitext
 
 import torch
 import torch.utils.data as data
 import numpy as np
 
-from .common import (default_loader, to_tensor)
+from core.data import DatasetBase
+from utils.data_utils import (default_loader, to_tensor)
 
 
-class CDDataset(data.Dataset):
+class CDDataset(DatasetBase):
     def __init__(
         self, 
         root, phase,
         transforms,
-        repeats
+        repeats, 
+        subset
     ):
-        super().__init__()
-        self.root = expanduser(root)
-        if not exists(self.root):
-            raise FileNotFoundError
-        self.phase = phase
-        self.transforms = list(transforms)
+        super().__init__(root, phase, transforms, repeats, subset)
+        self.transforms = list(self.transforms)
         self.transforms += [None]*(3-len(self.transforms))
-        self.repeats = int(repeats)
-
-        self.t1_list, self.t2_list, self.label_list = self._read_file_paths()
-        self.len = len(self.label_list)
+        self.t1_list, self.t2_list, self.tar_list = self._read_file_paths()
+        self.len = len(self.tar_list)
 
     def __len__(self):
         return self.len * self.repeats
 
-    def __getitem__(self, index):
-        if index >= len(self):
-            raise IndexError
-        index = index % self.len
-        
+    def fetch_and_preprocess(self, index):
         t1 = self.fetch_image(self.t1_list[index])
         t2 = self.fetch_image(self.t2_list[index])
-        label = self.fetch_label(self.label_list[index])
-        t1, t2, label = self.preprocess(t1, t2, label)
+        tar = self.fetch_target(self.tar_list[index])
+        t1, t2, tar = self.preprocess(t1, t2, tar)
+        
         if self.phase == 'train':
-            return t1, t2, label
+            return t1, t2, tar
         else:
-            return self.get_name(index), t1, t2, label
+            return self.get_name(index), t1, t2, tar
 
     def _read_file_paths(self):
         raise NotImplementedError
         
-    def fetch_label(self, label_path):
-        return default_loader(label_path)
+    def fetch_target(self, target_path):
+        return default_loader(target_path)
 
     def fetch_image(self, image_path):
         return default_loader(image_path)
 
     def get_name(self, index):
-        return splitext(basename(self.label_list[index]))[0]+'.bmp'
+        return splitext(basename(self.tar_list[index]))[0]+'.bmp'
 
-    def preprocess(self, t1, t2, label):
+    def preprocess(self, t1, t2, tar):
         if self.transforms[0] is not None:
-            # Applied on all
-            t1, t2, label = self.transforms[0](t1, t2, label)
+            # Applied to all
+            t1, t2, tar = self.transforms[0](t1, t2, tar)
         if self.transforms[1] is not None:
-            # For images solely
+            # Solely for images
             t1, t2 = self.transforms[1](t1, t2)
         if self.transforms[2] is not None:
-            # For labels solely
-            label = self.transforms[2](label)
+            # Solely for labels
+            tar = self.transforms[2](tar)
         
-        return to_tensor(t1).float(), to_tensor(t2).float(), to_tensor(label).long()
\ No newline at end of file
+        return to_tensor(t1).float(), to_tensor(t2).float(), to_tensor(tar).long()
\ No newline at end of file
diff --git a/src/data/_AirChange.py b/src/data/_airchange.py
similarity index 63%
rename from src/data/_AirChange.py
rename to src/data/_airchange.py
index 00e17a1..d6b406f 100644
--- a/src/data/_AirChange.py
+++ b/src/data/_airchange.py
@@ -1,12 +1,11 @@
-import abc
 from os.path import join, basename
 from functools import lru_cache
 
 import numpy as np
 
+from utils.data_utils import default_loader
 from . import CDDataset
-from .common import default_loader
-from .augmentation import Crop
+from .augmentations import Crop
 
 
 class _AirChangeDataset(CDDataset):
@@ -14,57 +13,56 @@ class _AirChangeDataset(CDDataset):
         self, 
         root, phase='train', 
         transforms=(None, None, None), 
-        repeats=1
+        repeats=1,
+        subset='val'
     ):
-        super().__init__(root, phase, transforms, repeats)
+        super().__init__(root, phase, transforms, repeats, subset)
         self.cropper = Crop(bounds=(0, 0, 748, 448))
 
     @property
-    @abc.abstractmethod
     def LOCATION(self):
         return ''
 
     @property
-    @abc.abstractmethod
     def TEST_SAMPLE_IDS(self):
         return ()
 
     @property
-    @abc.abstractmethod
     def N_PAIRS(self):
         return 0
 
     def _read_file_paths(self):
-        if self.phase == 'train':
+        if self.subset == 'train':
             sample_ids = [i for i in range(self.N_PAIRS) if i not in self.TEST_SAMPLE_IDS]
             t1_list = [join(self.root, self.LOCATION, str(i+1), 'im1') for i in sample_ids]
             t2_list = [join(self.root, self.LOCATION, str(i+1), 'im2') for i in sample_ids]
-            label_list = [join(self.root, self.LOCATION, str(i+1), 'gt') for i in sample_ids]
+            tar_list = [join(self.root, self.LOCATION, str(i+1), 'gt') for i in sample_ids]
         else:
+            # val and test subsets are equal
             t1_list = [join(self.root, self.LOCATION, str(i+1), 'im1') for i in self.TEST_SAMPLE_IDS]
             t2_list = [join(self.root, self.LOCATION, str(i+1), 'im2') for i in self.TEST_SAMPLE_IDS]
-            label_list = [join(self.root, self.LOCATION, str(i+1), 'gt') for i in self.TEST_SAMPLE_IDS]
+            tar_list = [join(self.root, self.LOCATION, str(i+1), 'gt') for i in self.TEST_SAMPLE_IDS]
 
-        return t1_list, t2_list, label_list
+        return t1_list, t2_list, tar_list
 
-
-    @lru_cache(maxsize=8)
+    # XXX: In a multi-process environment, there might be multiple caches in memory, one per process.
+    @lru_cache(maxsize=16)
     def fetch_image(self, image_name):
         image = self._bmp_loader(image_name)
         return image if self.phase == 'train' else self.cropper(image)
 
     @lru_cache(maxsize=8)
-    def fetch_label(self, label_name):
-        label = self._bmp_loader(label_name)
-        label = (label / 255.0).astype(np.uint8)    # To 0,1
-        return label if self.phase == 'train' else self.cropper(label)
+    def fetch_target(self, target_name):
+        tar = self._bmp_loader(target_name)
+        tar = (tar > 0).astype(np.bool_)    # Binarize to {0,1}
+        return tar if self.phase == 'train' else self.cropper(tar)
 
     def get_name(self, index):
         return '{loc}-{id}-cm.bmp'.format(loc=self.LOCATION, id=index)
 
     @staticmethod
     def _bmp_loader(bmp_path_wo_ext):
-        # Case insensitive .bmp loader
+        # Case-insensitive .bmp loader
         try:
             return default_loader(bmp_path_wo_ext+'.bmp')
         except FileNotFoundError:
diff --git a/src/data/AC_Szada.py b/src/data/ac_szada.py
similarity index 69%
rename from src/data/AC_Szada.py
rename to src/data/ac_szada.py
index a9b9619..478f093 100644
--- a/src/data/AC_Szada.py
+++ b/src/data/ac_szada.py
@@ -1,4 +1,4 @@
-from ._AirChange import _AirChangeDataset
+from ._airchange import _AirChangeDataset
 
 
 class AC_SzadaDataset(_AirChangeDataset):
@@ -6,9 +6,10 @@ class AC_SzadaDataset(_AirChangeDataset):
         self, 
         root, phase='train', 
         transforms=(None, None, None), 
-        repeats=1
+        repeats=1,
+        subset='val'
     ):
-        super().__init__(root, phase, transforms, repeats)
+        super().__init__(root, phase, transforms, repeats, subset)
 
     @property
     def LOCATION(self):
diff --git a/src/data/AC_Tiszadob.py b/src/data/ac_tiszadob.py
similarity index 69%
rename from src/data/AC_Tiszadob.py
rename to src/data/ac_tiszadob.py
index 830450d..e931038 100644
--- a/src/data/AC_Tiszadob.py
+++ b/src/data/ac_tiszadob.py
@@ -1,4 +1,4 @@
-from ._AirChange import _AirChangeDataset
+from ._airchange import _AirChangeDataset
 
 
 class AC_TiszadobDataset(_AirChangeDataset):
@@ -6,9 +6,10 @@ class AC_TiszadobDataset(_AirChangeDataset):
         self, 
         root, phase='train', 
         transforms=(None, None, None), 
-        repeats=1
+        repeats=1,
+        subset='val'
     ):
-        super().__init__(root, phase, transforms, repeats)
+        super().__init__(root, phase, transforms, repeats, subset)
 
     @property
     def LOCATION(self):
diff --git a/src/data/augmentation.py b/src/data/augmentation.py
deleted file mode 100644
index 0c5c02b..0000000
--- a/src/data/augmentation.py
+++ /dev/null
@@ -1,489 +0,0 @@
-import random
-import math
-from functools import partial, wraps
-
-import numpy as np
-import cv2
-
-
-__all__ = [
-    'Compose', 'Choose', 
-    'Scale', 'DiscreteScale', 
-    'Flip', 'HorizontalFlip', 'VerticalFlip', 'Rotate', 
-    'Crop', 'MSCrop',
-    'Shift', 'XShift', 'YShift',
-    'HueShift', 'SaturationShift', 'RGBShift', 'RShift', 'GShift', 'BShift',
-    'PCAJitter', 
-    'ContraBrightScale', 'ContrastScale', 'BrightnessScale',
-    'AddGaussNoise'
-]
-
-
-rand = random.random
-randi = random.randint
-choice = random.choice
-uniform = random.uniform
-# gauss = random.gauss
-gauss = random.normalvariate    # This one is thread-safe
-
-# The transformations treat 2-D or 3-D numpy ndarrays only, with the optional 3rd dim as the channel dim
-
-def _istuple(x): return isinstance(x, (tuple, list))
-
-class Transform:
-    def __init__(self, random_state=False):
-        self.random_state = random_state
-    def _transform(self, x):
-        raise NotImplementedError
-    def __call__(self, *args):
-        if self.random_state: self._set_rand_param()
-        assert len(args) > 0
-        return self._transform(args[0]) if len(args) == 1 else tuple(map(self._transform, args))
-    def _set_rand_param(self):
-        raise NotImplementedError
-
-
-class Compose:
-    def __init__(self, *tf):
-        assert len(tf) > 0
-        self.tfs = tf
-    def __call__(self, *x):
-        if len(x) > 1:
-            for tf in self.tfs: x = tf(*x)
-        else:
-            x = x[0]
-            for tf in self.tfs: x = tf(x)
-        return x
-
-
-class Choose:
-    def __init__(self, *tf):
-        assert len(tf) > 1
-        self.tfs = tf
-    def __call__(self, *x):
-        idx = randi(0, len(self.tfs)-1)
-        return self.tfs[idx](*x)
-
-
-class Scale(Transform):
-    def __init__(self, scale=(0.5,1.0)):
-        if _istuple(scale):
-            assert len(scale) == 2
-            self.scale_range = tuple(scale) #sorted(scale)
-            self.scale = float(scale[0])
-            super(Scale, self).__init__(random_state=True)
-        else:
-            super(Scale, self).__init__(random_state=False)
-            self.scale = float(scale)
-    def _transform(self, x):
-        # assert x.ndim == 3
-        h, w = x.shape[:2]
-        size = (int(h*self.scale), int(w*self.scale))
-        if size == (h,w):
-            return x
-        interp = cv2.INTER_LINEAR if np.issubdtype(x.dtype, np.floating) else cv2.INTER_NEAREST
-        return cv2.resize(x, size, interpolation=interp)
-    def _set_rand_param(self):
-        self.scale = uniform(*self.scale_range)
-        
-
-class DiscreteScale(Scale):
-    def __init__(self, bins=(0.5, 0.75), keep_prob=0.5):
-        super(DiscreteScale, self).__init__(scale=(min(bins), 1.0))
-        self.bins = tuple(bins)
-        self.keep_prob = float(keep_prob)
-    def _set_rand_param(self):
-        self.scale = 1.0 if rand()<self.keep_prob else choice(self.bins)
-
-
-class Flip(Transform):
-    # Flip or rotate
-    _directions = ('ud', 'lr', 'no', '90', '180', '270')
-    def __init__(self, direction=None):
-        super(Flip, self).__init__(random_state=(direction is None))
-        self.direction = direction
-        if direction is not None: assert direction in self._directions
-    def _transform(self, x):
-        if self.direction == 'ud':
-            ## Current torch version doesn't support negative stride of numpy arrays
-            return np.ascontiguousarray(x[::-1])
-        elif self.direction == 'lr':
-            return np.ascontiguousarray(x[:,::-1])
-        elif self.direction == 'no':
-            return x
-        elif self.direction == '90':
-            # Clockwise
-            return np.ascontiguousarray(self._T(x)[:,::-1])
-        elif self.direction == '180':
-            return np.ascontiguousarray(x[::-1,::-1])
-        elif self.direction == '270':
-            return np.ascontiguousarray(self._T(x)[::-1])
-        else:
-            raise ValueError('invalid flipping direction')
-
-    def _set_rand_param(self):
-        self.direction = choice(self._directions)
-
-    @staticmethod
-    def _T(x):
-        return np.swapaxes(x, 0, 1)
-        
-
-class HorizontalFlip(Flip):
-    _directions = ('lr', 'no')
-    def __init__(self, flip=None):
-        if flip is not None: flip = self._directions[~flip]
-        super(HorizontalFlip, self).__init__(direction=flip)
-    
-
-class VerticalFlip(Flip):
-    _directions = ('ud', 'no')
-    def __init__(self, flip=None):
-        if flip is not None: flip = self._directions[~flip]
-        super(VerticalFlip, self).__init__(direction=flip)
-
-
-class Rotate(Flip):
-    _directions = ('90', '180', '270', 'no')
-
-
-class Crop(Transform):
-    _inner_bounds = ('bl', 'br', 'tl', 'tr', 't', 'b', 'l', 'r')
-    def __init__(self, crop_size=None, bounds=None):
-        __no_bounds = (bounds is None)
-        super(Crop, self).__init__(random_state=__no_bounds)
-        if __no_bounds:
-            assert crop_size is not None
-        else:
-            if not((_istuple(bounds) and len(bounds)==4) or (isinstance(bounds, str) and bounds in self._inner_bounds)):
-                raise ValueError('invalid bounds')
-        self.bounds = bounds
-        self.crop_size = crop_size if _istuple(crop_size) else (crop_size, crop_size)
-    def _transform(self, x):
-        h, w = x.shape[:2]
-        if self.bounds == 'bl':
-            return x[h//2:,:w//2]
-        elif self.bounds == 'br':
-            return x[h//2:,w//2:]
-        elif self.bounds == 'tl':
-            return x[:h//2,:w//2]
-        elif self.bounds == 'tr':
-            return x[:h//2,w//2:]
-        elif self.bounds == 't':
-            return x[:h//2]
-        elif self.bounds == 'b':
-            return x[h//2:]
-        elif self.bounds == 'l':
-            return x[:,:w//2]
-        elif self.bounds == 'r':
-            return x[:,w//2:]
-        elif len(self.bounds) == 2:
-            assert self.crop_size <= (h, w)
-            ch, cw = self.crop_size
-            if (ch,cw) == (h,w):
-                return x
-            cx, cy = int((w-cw+1)*self.bounds[0]), int((h-ch+1)*self.bounds[1])
-            return x[cy:cy+ch, cx:cx+cw]
-        else:
-            left, top, right, lower = self.bounds
-            return x[top:lower, left:right]
-    def _set_rand_param(self):
-        self.bounds = (rand(), rand())
-   
-
-class MSCrop(Crop):
-    def __init__(self, scale, crop_size=None):
-        super(MSCrop, self).__init__(crop_size)
-        self.scale = scale  # Scale factor
-
-    def __call__(self, lr, hr):
-        if self.random_state:
-            self._set_rand_param()
-        # I've noticed that random scaling bounds may cause pixel misalignment
-        # between the lr-hr pair, which significantly damages the training
-        # effect, thus the quadruple mode is desired
-        left, top, cw, ch = self._get_quad(*lr.shape[:2])
-        self._set_quad(left, top, cw, ch)
-        lr_crop = self._transform(lr)
-        left, top, cw, ch = [int(it*self.scale) for it in (left, top, cw, ch)]
-        self._set_quad(left, top, cw, ch)
-        hr_crop = self._transform(hr)
-
-        return lr_crop, hr_crop
-
-    def _get_quad(self, h, w):
-        ch, cw = self.crop_size
-        cx, cy = int((w-cw+1)*self.bounds[0]), int((h-ch+1)*self.bounds[1])
-        return cx, cy, cw, ch
-
-    def _set_quad(self, left, top, cw, ch):
-        self.bounds = (left, top, left+cw, top+ch)
-
-
-class Shift(Transform):
-    def __init__(self, x_shift=(-0.0625, 0.0625), y_shift=(-0.0625, 0.0625), circular=True):
-        super(Shift, self).__init__(random_state=_istuple(x_shift) or _istuple(y_shift))
-
-        if _istuple(x_shift):
-            self.xshift_range = tuple(x_shift)
-            self.xshift = float(x_shift[0])
-        else:
-            self.xshift = float(x_shift)
-            self.xshift_range = (self.xshift, self.xshift)
-
-        if _istuple(y_shift):
-            self.yshift_range = tuple(y_shift)
-            self.yshift = float(y_shift[0])
-        else:
-            self.yshift = float(y_shift)
-            self.yshift_range = (self.yshift, self.yshift)
-
-        self.circular = circular
-
-    def _transform(self, im):
-        h, w = im.shape[:2]
-        xsh = -int(self.xshift*w)
-        ysh = -int(self.yshift*h)
-        if self.circular:
-            # Shift along the x-axis
-            im_shifted = np.concatenate((im[:, xsh:], im[:, :xsh]), axis=1)
-            # Shift along the y-axis
-            im_shifted = np.concatenate((im_shifted[ysh:], im_shifted[:ysh]), axis=0)
-        else:
-            zeros = np.zeros(im.shape)
-            im1, im2 = (zeros, im) if xsh < 0 else (im, zeros)
-            im_shifted = np.concatenate((im1[:, xsh:], im2[:, :xsh]), axis=1)
-            im1, im2 = (zeros, im_shifted) if ysh < 0 else (im_shifted, zeros)
-            im_shifted = np.concatenate((im1[ysh:], im2[:ysh]), axis=0)
-
-        return im_shifted
-        
-    def _set_rand_param(self):
-        self.xshift = uniform(*self.xshift_range)
-        self.yshift = uniform(*self.yshift_range)
-
-
-class XShift(Shift):
-    def __init__(self, x_shift=(-0.0625, 0.0625), circular=True):
-        super(XShift, self).__init__(x_shift, 0.0, circular)
-
-
-class YShift(Shift):
-    def __init__(self, y_shift=(-0.0625, 0.0625), circular=True):
-        super(YShift, self).__init__(0.0, y_shift, circular)
-
-
-# Color jittering and transformation
-# The followings partially refer to https://github.com/albu/albumentations/
-class _ValueTransform(Transform):
-    def __init__(self, rs, limit=(0, 255)):
-        super().__init__(rs)
-        self.limit = limit
-        self.limit_range = limit[1] - limit[0]
-    @staticmethod
-    def keep_range(tf):
-        @wraps(tf)
-        def wrapper(obj, x):
-            # # Make a copy
-            # x = x.copy()
-            dtype = x.dtype
-            # The calculations are done with floating type in case of overflow
-            # This is a stupid yet simple way
-            x = tf(obj, np.clip(x.astype(np.float32), *obj.limit))
-            # Convert back to the original type
-            return np.clip(x, *obj.limit).astype(dtype)
-        return wrapper
-        
-
-class ColorJitter(_ValueTransform):
-    _channel = (0,1,2)
-    def __init__(self, shift=((-20,20), (-20,20), (-20,20)), limit=(0,255)):
-        super().__init__(False, limit)
-        _nc = len(self._channel)
-        if _nc == 1:
-            if _istuple(shift):
-                rs = True
-                self.shift = self.range = shift
-            else:
-                rs = False
-                self.shift = (shift,)
-                self.range = (shift, shift)
-        else:
-            if _istuple(shift):
-                if len(shift) != _nc:
-                    raise ValueError("please specify the shift value (or range) for every channel.")
-                rs = all(_istuple(s) for s in shift)
-                self.shift = self.range = shift
-            else:
-                rs = False
-                self.shift = [shift for _ in range(_nc)]
-                self.range = [(shift, shift) for _ in range(_nc)]
-                
-        self.random_state = rs
-        
-        def _(x):
-            return x
-        self.convert_to = _
-        self.convert_back = _
-    
-    @_ValueTransform.keep_range
-    def _transform(self, x):
-        x = self.convert_to(x)
-        for i, c in enumerate(self._channel):
-            x[...,c] = self._clip(x[...,c]+float(self.shift[i]))
-        x = self.convert_back(x)
-        return x
-        
-    def _clip(self, x):
-        return x
-        
-    def _set_rand_param(self):
-        if len(self._channel) == 1:
-            self.shift = [uniform(*self.range)]
-        else:
-            self.shift = [uniform(*r) for r in self.range]
-
-
-class HSVShift(ColorJitter):
-    def __init__(self, shift, limit):
-        super().__init__(shift, limit)
-        def _convert_to(x):
-            x = x.astype(np.float32)
-            # Normalize to [0,1]
-            x -= self.limit[0]
-            x /= self.limit_range
-            x = cv2.cvtColor(x, code=cv2.COLOR_RGB2HSV)
-            return x
-        def _convert_back(x):
-            x = cv2.cvtColor(x.astype(np.float32), code=cv2.COLOR_HSV2RGB)
-            return x * self.limit_range + self.limit[0]
-        # Pack conversion methods
-        self.convert_to = _convert_to
-        self.convert_back = _convert_back
-
-        def _clip(self, x):
-            raise NotImplementedError
-        
-
-class HueShift(HSVShift):
-    _channel = (0,)
-    def __init__(self, shift=(-20, 20), limit=(0, 255)):
-        super().__init__(shift, limit)
-    def _clip(self, x):
-        # Circular
-        # Note that this works in Opencv 3.4.3, not yet tested under other versions
-        x[x<0] += 360
-        x[x>360] -= 360
-        return x
-        
-
-class SaturationShift(HSVShift):    
-    _channel = (1,)
-    def __init__(self, shift=(-30, 30), limit=(0, 255)):
-        super().__init__(shift, limit)
-        self.range = tuple(r / self.limit_range for r in self.range)
-    def _clip(self, x):
-        return np.clip(x, 0, 1.0)
-        
-
-class RGBShift(ColorJitter):
-    def __init__(self, shift=((-20,20), (-20,20), (-20,20)), limit=(0, 255)):
-        super().__init__(shift, limit)        
-
-
-class RShift(RGBShift):
-    _channel = (0,)
-    def __init__(self, shift=(-20,20), limit=(0, 255)):
-        super().__init__(shift, limit)
-
-
-class GShift(RGBShift):
-    _channel = (1,)
-    def __init__(self, shift=(-20,20), limit=(0, 255)):
-        super().__init__(shift, limit)
-
-
-class BShift(RGBShift):
-    _channel = (2,)
-    def __init__(self, shift=(-20,20), limit=(0, 255)):
-        super().__init__(shift, limit)
-
-
-class PCAJitter(_ValueTransform):
-    def __init__(self, sigma=0.3, limit=(0, 255)):
-        # For RGB only
-        super().__init__(True, limit)
-        self.sigma = sigma
-        
-    @_ValueTransform.keep_range
-    def _transform(self, x):
-        old_shape = x.shape
-        x = np.reshape(x, (-1,3), order='F')   # For RGB
-        x_mean = np.mean(x, 0)
-        x = x - x_mean
-        cov_x = np.cov(x, rowvar=False)
-        eig_vals, eig_vecs = np.linalg.eig(np.mat(cov_x))
-        # The eigen vectors are already unit "length"
-        noise = (eig_vals * self.alpha) * eig_vecs
-        x += np.asarray(noise)
-        return np.reshape(x+x_mean, old_shape, order='F')
-    
-    def _set_rand_param(self):
-        self.alpha = [gauss(0, self.sigma) for _ in range(3)]
-        
-
-class ContraBrightScale(_ValueTransform):
-    def __init__(self, alpha=(-0.2, 0.2), beta=(-0.2, 0.2), limit=(0, 255)):
-        super().__init__(_istuple(alpha) or _istuple(beta), limit)
-        self.alpha = alpha
-        self.alpha_range = alpha if _istuple(alpha) else (alpha, alpha)
-        self.beta = beta
-        self.beta_range = beta if _istuple(beta) else (beta, beta)
-    
-    @_ValueTransform.keep_range
-    def _transform(self, x):
-        if not math.isclose(self.alpha, 1.0):
-            x *= self.alpha
-        if not math.isclose(self.beta, 0.0):
-            x += self.beta*np.mean(x)
-        return x
-    
-    def _set_rand_param(self):
-        self.alpha = uniform(*self.alpha_range)
-        self.beta = uniform(*self.beta_range)
-
-
-class ContrastScale(ContraBrightScale):
-    def __init__(self, alpha=(0.2, 0.8), limit=(0,255)):
-        super().__init__(alpha=alpha, beta=0, limit=limit)
-        
-
-class BrightnessScale(ContraBrightScale):
-    def __init__(self, beta=(-0.2, 0.2), limit=(0,255)):
-        super().__init__(alpha=1, beta=beta, limit=limit)
-
-
-class _AddNoise(_ValueTransform):
-    def __init__(self, limit):
-        super().__init__(True, limit)
-        self._im_shape = (0, 0)
-        
-    @_ValueTransform.keep_range
-    def _transform(self, x):
-        return x + self.noise_map
-        
-    def __call__(self, *args):
-        shape = args[0].shape
-        if any(im.shape != shape for im in args):
-            raise ValueError("the input images should be of same size.")
-        self._im_shape = shape
-        return super().__call__(*args)
-        
-
-class AddGaussNoise(_AddNoise):
-    def __init__(self, mu=0.0, sigma=0.1, limit=(0, 255)):
-        super().__init__(limit)
-        self.mu = mu
-        self.sigma = sigma
-    def _set_rand_param(self):
-        self.noise_map = np.random.randn(*self._im_shape)*self.sigma + self.mu
\ No newline at end of file
diff --git a/src/data/augmentations.py b/src/data/augmentations.py
new file mode 100644
index 0000000..4733723
--- /dev/null
+++ b/src/data/augmentations.py
@@ -0,0 +1,414 @@
+import random
+import math
+from functools import partial, wraps
+from copy import deepcopy
+
+import numpy as np
+import skimage.transform
+
+
+__all__ = [
+    'Compose', 'Choose', 
+    'Scale', 'DiscreteScale', 
+    'FlipRotate', 'Flip', 'HorizontalFlip', 'VerticalFlip', 'Rotate', 
+    'Crop', 'CenterCrop',
+    'Shift', 'XShift', 'YShift',
+    'ContrastBrightScale', 'ContrastScale', 'BrightnessScale',
+    'AddGaussNoise'
+]
+
+
+rand = random.random
+randi = random.randint
+choice = random.choice
+uniform = random.uniform
+# gauss = random.gauss
+gauss = random.normalvariate    # This one is thread-safe
+
+
+def _isseq(x): return isinstance(x, (tuple, list))
+
+
+class Transform:
+    def __init__(self, rand_state=False, prob_apply=1.0):
+        self._rand_state = bool(rand_state)
+        self.prob_apply = float(prob_apply)
+
+    def _transform(self, x, params):
+        raise NotImplementedError
+
+    def __call__(self, *args, copy=False):
+        # NOTE: A Transform object deals with 2-D or 3-D numpy ndarrays only, with an optional third dim as the channel dim.
+        if copy:
+            args = deepcopy(args)
+        if rand() > self.prob_apply:
+            # Keep the return type consistent with the transformed case below.
+            return args[0] if len(args) == 1 else args
+        if self._rand_state:
+            params = self._get_rand_params()
+        else:
+            params = None
+        return self._transform(args[0], params) if len(args) == 1 else tuple(self._transform(x, params) for x in args)
+
+    def _get_rand_params(self):
+        raise NotImplementedError
+
+    def info(self):
+        return ""
+
+    def __repr__(self):
+        return self.info()+"\nrand_state={}\nprob_apply={}\n".format(self._rand_state, self.prob_apply)
+
+
+class Compose:
+    def __init__(self, *tfs):
+        assert len(tfs) > 0
+        self.tfs = tfs
+
+    def __call__(self, *x):
+        if len(x) == 1:
+            x = x[0]
+            for tf in self.tfs: 
+                x = tf(x)
+        else:
+            for tf in self.tfs:
+                x = tf(*x)
+        return x
+
+    def __repr__(self):
+        return "Compose [ "+", ".join(tf.__repr__() for tf in self.tfs)+"]\n"
+
+
+class Choose:
+    def __init__(self, *tfs):
+        assert len(tfs) > 1
+        self.tfs = tfs
+
+    def __call__(self, *x):
+        return choice(self.tfs)(*x)
+
+    def __repr__(self):
+        return "Choose [ "+", ".join(tf.__repr__() for tf in self.tfs)+"]\n"
+
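+# Illustrative usage (a sketch): within one call, the same random parameters are applied to every input.
+#   aug = Compose(Flip(), Crop(crop_size=112))
+#   t1, t2, tar = aug(t1, t2, tar)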
+
+class Scale(Transform):
+    def __init__(self, scale=(0.5, 1.0), prob_apply=1.0):
+        super(Scale, self).__init__(rand_state=_isseq(scale), prob_apply=prob_apply)
+        if _isseq(scale):
+            assert len(scale) == 2
+            self.scale = tuple(scale)
+        else:
+            self.scale = float(scale)
+
+    def _transform(self, x, params):
+        if self._rand_state:
+            scale = params['scale']
+        else:
+            scale = self.scale
+        h, w = x.shape[:2]
+        size = (int(h*scale), int(w*scale))
+        if size == (h,w):
+            return x
+        order = 0 if x.dtype == np.bool_ else 1
+        return skimage.transform.resize(x, size, order=order, preserve_range=True).astype(x.dtype)
+
+    def _get_rand_params(self):
+        return {'scale': uniform(*self.scale)}
+
+    def info(self):
+        return "Scale\nscaling_factor={}".format(self.scale)
+        
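+# e.g. Scale(0.5) always halves the spatial size, while Scale((0.5, 1.0)) samples a factor per call.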
+
+class DiscreteScale(Scale):
+    def __init__(self, bins=(0.5, 0.75), prob_apply=1.0):
+        super(DiscreteScale, self).__init__(scale=(min(bins), max(bins)), prob_apply=prob_apply)
+        self.bins = tuple(bins)
+
+    def _get_rand_params(self):
+        return {'scale': choice(self.bins)}
+
+    def info(self):
+        return "DiscreteScale\nscaling_factors={}".format(self.bins)
+
+
+class FlipRotate(Transform):
+    # Flip or rotate
+    _DIRECTIONS = ('ud', 'lr', '90', '180', '270')
+    def __init__(self, direction=None, prob_apply=1.0):
+        super(FlipRotate, self).__init__(rand_state=(direction is None), prob_apply=prob_apply)
+        if direction is not None: 
+            assert direction in self._DIRECTIONS
+            self.direction = direction
+
+    def _transform(self, x, params):
+        if self._rand_state:
+            direction = params['direction']
+        else:
+            direction = self.direction
+
+        if direction == 'ud':
+            return np.flip(x, 0)
+        elif direction == 'lr':
+            return np.flip(x, 1)
+        elif direction == '90':
+            # Clockwise
+            return np.flip(self._T(x), 1)
+        elif direction == '180':
+            return np.flip(np.flip(x, 0), 1)
+        elif direction == '270':
+            return np.flip(self._T(x), 0)
+        else:
+            raise ValueError("Invalid direction")
+
+    def _get_rand_params(self):
+        return {'direction': choice(self._DIRECTIONS)}
+
+    @staticmethod
+    def _T(x):
+        return np.swapaxes(x, 0, 1)
+
+    def info(self):
+        return "FlipRotate"
+
+
+class Flip(FlipRotate):
+    _DIRECTIONS = ('ud', 'lr')
+
+    def info(self):
+        return "Flip"
+
+
+class HorizontalFlip(Flip):
+    def __init__(self, prob_apply=1.0):
+        super(HorizontalFlip, self).__init__(direction='lr', prob_apply=prob_apply)
+
+    def info(self):
+        return "HorizontalFlip"
+
+
+class VerticalFlip(Flip):
+    def __init__(self, prob_apply=1.0):
+        super(VerticalFlip, self).__init__(direction='ud', prob_apply=prob_apply)
+    
+    def info(self):
+        return "VerticalFlip"
+
+
+class Rotate(FlipRotate):
+    _DIRECTIONS = ('90', '180', '270')
+
+    def info(self):
+        return "Rotate"
+
+
+class Crop(Transform):
+    _INNER_BOUNDS = ('bl', 'br', 'tl', 'tr', 't', 'b', 'l', 'r')
+    def __init__(self, crop_size=None, bounds=None, prob_apply=1.0):
+        _no_bounds = (bounds is None)
+        super(Crop, self).__init__(rand_state=_no_bounds, prob_apply=prob_apply)
+        if _no_bounds:
+            assert crop_size is not None
+        else:
+            if not((_isseq(bounds) and len(bounds)==4) or (isinstance(bounds, str) and bounds in self._INNER_BOUNDS)):
+                raise ValueError("Invalid bounds")
+        self.bounds = bounds
+        self.crop_size = crop_size if _isseq(crop_size) else (crop_size, crop_size)
+
+    def _transform(self, x, params):
+        h, w = x.shape[:2]
+        if not self._rand_state:
+            bounds = self.bounds
+            if bounds == 'bl':
+                return x[h//2:,:w//2]
+            elif bounds == 'br':
+                return x[h//2:,w//2:]
+            elif bounds == 'tl':
+                return x[:h//2,:w//2]
+            elif bounds == 'tr':
+                return x[:h//2,w//2:]
+            elif bounds == 't':
+                return x[:h//2]
+            elif bounds == 'b':
+                return x[h//2:]
+            elif bounds == 'l':
+                return x[:,:w//2]
+            elif bounds == 'r':
+                return x[:,w//2:]
+            else:
+                left, top, right, lower = bounds
+                return x[top:lower, left:right]
+        else:
+            assert self.crop_size[0] <= h and self.crop_size[1] <= w
+            ch, cw = self.crop_size
+            if (ch,cw) == (h,w):
+                return x
+            cx, cy = int((w-cw+1)*params['rel_pos_x']), int((h-ch+1)*params['rel_pos_y'])
+            return x[cy:cy+ch, cx:cx+cw]
+
+    def _get_rand_params(self):
+        return {'rel_pos_x': rand(),
+                'rel_pos_y': rand()}
+
+    def info(self):
+        return "Crop\ncrop_size={}\nbounds={}".format(self.crop_size, self.bounds)
+
+
+class CenterCrop(Transform):
+    def __init__(self, crop_size, prob_apply=1.0):
+        super(CenterCrop, self).__init__(False, prob_apply=prob_apply)
+        self.crop_size = crop_size if _isseq(crop_size) else (crop_size, crop_size)
+
+    def _transform(self, x, params):
+        h, w = x.shape[:2]
+
+        ch, cw = self.crop_size
+
+        assert ch<=h and cw<=w
+        
+        offset_up = (h-ch)//2
+        offset_left = (w-cw)//2
+
+        return x[offset_up:offset_up+ch, offset_left:offset_left+cw]
+
+    def info(self):
+        return "CenterCrop\ncrop_size={}".format(self.crop_size)
+
+
+class Shift(Transform):
+    def __init__(self, xshift=(-0.0625, 0.0625), yshift=(-0.0625, 0.0625), circular=False, prob_apply=1.0):
+        super(Shift, self).__init__(rand_state=_isseq(xshift) or _isseq(yshift), prob_apply=prob_apply)
+
+        if _isseq(xshift):
+            self.xshift = tuple(xshift)
+        else:
+            self.xshift = float(xshift)
+
+        if _isseq(yshift):
+            self.yshift = tuple(yshift)
+        else:
+            self.yshift = float(yshift)
+
+        self.circular = circular
+
+    def _transform(self, x, params):
+        h, w = x.shape[:2]
+        if self._rand_state:
+            xshift = params['xshift']
+            yshift = params['yshift']
+        else:
+            xshift = self.xshift
+            yshift = self.yshift
+        xsh = -int(xshift*w)
+        ysh = -int(yshift*h)
+        if self.circular:
+            # Shift along the x-axis
+            x_shifted = np.concatenate((x[:, xsh:], x[:, :xsh]), axis=1)
+            # Shift along the y-axis
+            x_shifted = np.concatenate((x_shifted[ysh:], x_shifted[:ysh]), axis=0)
+        else:
+            zeros = np.zeros(x.shape, dtype=x.dtype)
+            x1, x2 = (zeros, x) if xsh < 0 else (x, zeros)
+            x_shifted = np.concatenate((x1[:, xsh:], x2[:, :xsh]), axis=1)
+            x1, x2 = (zeros, x_shifted) if ysh < 0 else (x_shifted, zeros)
+            x_shifted = np.concatenate((x1[ysh:], x2[:ysh]), axis=0)
+
+        return x_shifted
+        
+    def _get_rand_params(self):
+        return {'xshift': uniform(*self.xshift) if isinstance(self.xshift, tuple) else self.xshift,
+                'yshift': uniform(*self.yshift) if isinstance(self.yshift, tuple) else self.yshift}
+
+    def info(self):
+        return "Shift\nxshift={}\nyshift={}".format(self.xshift, self.yshift)
+
+
+class XShift(Shift):
+    def __init__(self, shift=(-0.0625, 0.0625), circular=False, prob_apply=1.0):
+        super(XShift, self).__init__(shift, 0.0, circular, prob_apply)
+
+
+class YShift(Shift):
+    def __init__(self, shift=(-0.0625, 0.0625), circular=False, prob_apply=1.0):
+        super(YShift, self).__init__(0.0, shift, circular, prob_apply)
+
+
+# Color jittering and transformation
+# Partially refer to https://github.com/albu/albumentations/
+class _ValueTransform(Transform):
+    def __init__(self, rand_state, prob_apply, limit):
+        super(_ValueTransform, self).__init__(rand_state, prob_apply)
+        self.limit = limit
+        self.limit_range = limit[1] - limit[0]
+
+    @staticmethod
+    def keep_range(tf):
+        @wraps(tf)
+        def wrapper(obj, x, params):
+            dtype = x.dtype
+            # NOTE: The calculations are done in floating point to prevent overflow.
+            # This is simple but wasteful in memory.
+            # FIXME: The current implementation always makes a copy.
+            x = tf(obj, np.clip(x.astype(np.float32), *obj.limit), params)
+            # Convert back to the original type
+            # TODO: Round instead of truncate if dtype is integer
+            return np.clip(x, *obj.limit).astype(dtype)
+        return wrapper
+        
+
+class ContrastBrightScale(_ValueTransform):
+    def __init__(self, alpha=(0.2, 0.8), beta=(-0.2, 0.2), prob_apply=1.0, limit=(0, 255)):
+        super(ContrastBrightScale, self).__init__(_isseq(alpha) or _isseq(beta), prob_apply, limit)
+
+        if _isseq(alpha):
+            self.alpha = tuple(alpha)
+        else:
+            self.alpha = float(alpha)
+
+        if _isseq(beta):
+            self.beta = tuple(beta)
+        else:
+            self.beta = float(beta)
+    
+    @_ValueTransform.keep_range
+    def _transform(self, x, params):
+        alpha = params['alpha'] if self._rand_state else self.alpha
+        beta = params['beta'] if self._rand_state else self.beta
+        if not math.isclose(alpha, 1.0):
+            x *= alpha
+        if not math.isclose(beta, 0.0):
+            x += beta*np.mean(x)
+        return x
+    
+    def _get_rand_params(self):
+        return {'alpha': uniform(*self.alpha) if isinstance(self.alpha, tuple) else self.alpha,
+                'beta': uniform(*self.beta) if isinstance(self.beta, tuple) else self.beta}
+
+    def info(self):
+        return "ContrastBrightScale\nalpha={}\nbeta={}\nlimit={}".format(self.alpha, self.beta, self.limit)
+
+
+class ContrastScale(ContrastBrightScale):
+    def __init__(self, alpha=(0.2, 0.8), prob_apply=1.0, limit=(0, 255)):
+        super(ContrastScale, self).__init__(alpha=alpha, beta=0.0, prob_apply=prob_apply, limit=limit)
+        
+
+class BrightnessScale(ContrastBrightScale):
+    def __init__(self, beta=(-0.2, 0.2), prob_apply=1.0, limit=(0, 255)):
+        super(BrightnessScale, self).__init__(alpha=1.0, beta=beta, prob_apply=prob_apply, limit=limit)
+        
+
+class AddGaussNoise(_ValueTransform):
+    def __init__(self, mu=0.0, sigma=0.1, prob_apply=1.0, limit=(0, 255)):
+        super().__init__(True, prob_apply, limit)
+        self.mu = float(mu)
+        self.sigma = float(sigma)
+
+    @_ValueTransform.keep_range
+    def _transform(self, x, params):
+        x += np.random.randn(*x.shape)*self.sigma + self.mu
+        return x
+
+    def _get_rand_params(self):
+        return {}
+
+    def info(self):
+        return "AddGaussNoise\nmu={}\nsigma={}\nlimit={}".format(self.mu, self.sigma, self.limit)
\ No newline at end of file
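A minimal usage sketch of the augmentation API defined above (this assumes that Compose and Choose are defined earlier in augmentations.py, as the data builders later in this patch use them, and that a transform called with several arrays applies the same random parameters to each of them):

    import numpy as np
    from data.augmentations import Compose, Choose, Crop, HorizontalFlip, VerticalFlip, Shift

    aug = Compose(Crop(112), Choose(HorizontalFlip(), VerticalFlip(), Shift()))
    t1 = np.random.rand(256, 256, 3).astype(np.float32)
    t2 = np.random.rand(256, 256, 3).astype(np.float32)
    t1_a, t2_a = aug(t1, t2)  # both patches receive an identical crop and flip/shift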
diff --git a/src/data/common.py b/src/data/common.py
deleted file mode 100644
index f79c58d..0000000
--- a/src/data/common.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import torch
-import numpy as np
-
-from scipy.io import loadmat
-from skimage.io import imread
-
-def default_loader(path_):
-    return imread(path_)
-
-def mat_loader(path_):
-    return loadmat(path_)
-
-def make_onehot(index_map, n):
-    # Only deals with tensors with no batch dim
-    old_size = index_map.size()
-    z = torch.zeros(n, *old_size[-2:]).type_as(index_map)
-    z.scatter_(0, index_map, 1)
-    return z
-    
-def to_tensor(arr):
-    if arr.ndim < 3:
-        return torch.from_numpy(arr)
-    elif arr.ndim == 3:
-        return torch.from_numpy(np.ascontiguousarray(np.transpose(arr, (2,0,1))))
-    else:
-        raise NotImplementedError
-
-def to_array(tensor):
-    if tensor.ndimension() < 3:
-        return tensor.data.cpu().numpy()
-    elif tensor.ndimension() in (3, 4):
-        return np.ascontiguousarray(np.moveaxis(tensor.data.cpu().numpy(), -3, -1))
-    else:
-        raise NotImplementedError
\ No newline at end of file
diff --git a/src/data/OSCD.py b/src/data/oscd.py
similarity index 53%
rename from src/data/OSCD.py
rename to src/data/oscd.py
index a471ae7..6f7c6ce 100644
--- a/src/data/OSCD.py
+++ b/src/data/oscd.py
@@ -1,12 +1,12 @@
 import os
 from glob import glob
 from os.path import join, basename
-from multiprocessing import Manager
 
 import numpy as np
 
+from utils.data_utils import default_loader
 from . import CDDataset
-from .common import default_loader
+
 
 class OSCDDataset(CDDataset):
     __BAND_NAMES = (
@@ -18,35 +18,35 @@ class OSCDDataset(CDDataset):
         root, phase='train', 
         transforms=(None, None, None), 
         repeats=1,
+        subset='val',
         cache_level=1
     ):
-        super().__init__(root, phase, transforms, repeats)
-        # 0 for no cache, 1 for caching labels only, 2 and higher for caching all
+        super().__init__(root, phase, transforms, repeats, subset)
+        # cache_level=0 for no cache, 1 to cache labels, 2 and higher to cache all.
         self.cache_level = int(cache_level)
         if self.cache_level > 0:
-            self._manager = Manager()
-            self._pool = self._manager.dict()
+            self._pool = dict()
 
     def _read_file_paths(self):
-        image_dir = join(self.root, 'Onera Satellite Change Detection dataset - Images')
-        label_dir = join(self.root, 'Onera Satellite Change Detection dataset - Train Labels')
-        txt_file = join(image_dir, 'train.txt')
+        image_dir = join(self.root, "Onera Satellite Change Detection dataset - Images")
+        target_dir = join(self.root, "Onera Satellite Change Detection dataset - Train Labels")
+        txt_file = join(image_dir, "train.txt")
         # Read cities
         with open(txt_file, 'r') as f:
             cities = [city.strip() for city in f.read().strip().split(',')]
-        if self.phase == 'train':
+        if self.subset == 'train':
             # For training, use the first 11 pairs
             cities = cities[:-3]
         else:
-            # For validation, use the remaining 3 pairs
+            # For validation and test, use the remaining 3 pairs
             cities = cities[-3:]
             
         # Use resampled images
-        t1_list = [[join(image_dir, city, 'imgs_1_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
-        t2_list = [[join(image_dir, city, 'imgs_2_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
-        label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities]
+        t1_list = [[join(image_dir, city, "imgs_1_rect", band+'.tif') for band in self.__BAND_NAMES] for city in cities]
+        t2_list = [[join(image_dir, city, "imgs_2_rect", band+'.tif') for band in self.__BAND_NAMES] for city in cities]
+        tar_list = [join(target_dir, city, 'cm', city+'-cm.tif') for city in cities]
 
-        return t1_list, t2_list, label_list
+        return t1_list, t2_list, tar_list
 
     def fetch_image(self, image_paths):
         key = '-'.join(image_paths[0].split(os.sep)[-3:-1])
@@ -59,15 +59,15 @@ class OSCDDataset(CDDataset):
             self._pool[key] = image
         return image
 
-    def fetch_label(self, label_path):
-        key = basename(label_path)
+    def fetch_target(self, target_path):
+        key = basename(target_path)
         if self.cache_level >= 1:
-            label = self._pool.get(key, None)
-            if label is not None:
-                return label
-        # In the tif labels, 1 for NC and 2 for C
-        # Thus a -1 offset is needed
-        label = default_loader(label_path) - 1
+            tar = self._pool.get(key, None)
+            if tar is not None:
+                return tar
+        # In the tif labels, 1 stands for no change (NC) and 2 for change (C),
+        # so a -1 offset is applied here.
+        tar = (default_loader(target_path) - 1).astype(np.bool_)
         if self.cache_level >= 1:
-            self._pool[key] = label
-        return label
+            self._pool[key] = tar
+        return tar
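Since the cache pool is now a plain in-process dict (rather than a multiprocessing.Manager dict), cached images and labels are only reused within a single process; the data builders below therefore keep num_workers at 0. A hedged construction sketch:

    import torch.utils.data as data
    import constants
    from data.oscd import OSCDDataset

    ds = OSCDDataset(constants.IMDB_OSCD, phase='train', subset='train', cache_level=2)
    loader = data.DataLoader(ds, batch_size=8, shuffle=True, num_workers=0)  # keep workers at 0 so the cache is actually hit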
diff --git a/src/impl/builders/__init__.py b/src/impl/builders/__init__.py
new file mode 100644
index 0000000..db65b3a
--- /dev/null
+++ b/src/impl/builders/__init__.py
@@ -0,0 +1,6 @@
+from .critn_builders import *
+from .data_builders import *
+from .model_builders import *
+from .optim_builders import *
+
+__all__ = []
\ No newline at end of file
diff --git a/src/impl/builders/critn_builders.py b/src/impl/builders/critn_builders.py
new file mode 100644
index 0000000..2c39ffc
--- /dev/null
+++ b/src/impl/builders/critn_builders.py
@@ -0,0 +1,3 @@
+# Custom criterion builders
+
+from core.misc import CRITNS
\ No newline at end of file
diff --git a/src/impl/builders/data_builders.py b/src/impl/builders/data_builders.py
new file mode 100644
index 0000000..97f1976
--- /dev/null
+++ b/src/impl/builders/data_builders.py
@@ -0,0 +1,149 @@
+# Custom data builders
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data as data
+
+import constants
+from data.augmentations import *
+from core.misc import DATA, R
+from core.data import (
+    build_train_dataloader, build_eval_dataloader, get_common_train_configs, get_common_eval_configs
+)
+
+
+@DATA.register_func('AC_Szada_train_dataset')
+def build_AC_Szada_train_dataset(C):
+    configs = get_common_train_configs(C)
+    configs.update(dict(
+        transforms=(Compose(
+            Crop(C['crop_size']),
+            Choose(                
+                HorizontalFlip(), VerticalFlip(), 
+                Rotate('90'), Rotate('180'), Rotate('270'),
+                Shift(),
+            )
+        ), None, None),
+        root=constants.IMDB_AIRCHANGE
+    ))
+
+    from data.ac_szada import AC_SzadaDataset
+    if C['num_workers'] != 0:
+        R['Logger'].warn("Will use num_workers=0.")
+    return data.DataLoader(
+        AC_SzadaDataset(**configs),
+        batch_size=C['batch_size'],
+        shuffle=True,
+        num_workers=0,  # No need to use multiprocessing
+        pin_memory=C['device']!='cpu',
+        drop_last=True
+    )
+
+
+@DATA.register_func('AC_Szada_eval_dataset')
+def build_AC_Szada_eval_dataset(C):
+    configs = get_common_eval_configs(C)
+    configs.update(dict(
+        root=constants.IMDB_AIRCHANGE
+    ))
+
+    from data.ac_szada import AC_SzadaDataset
+    return data.DataLoader(
+        AC_SzadaDataset(**configs),
+        batch_size=1,
+        shuffle=False,
+        num_workers=0,
+        pin_memory=False,
+        drop_last=False
+    )
+
+
+@DATA.register_func('AC_Tiszadob_train_dataset')
+def build_AC_Tiszadob_train_dataset(C):
+    configs = get_common_train_configs(C)
+    configs.update(dict(
+        transforms=(Compose(
+            Crop(C['crop_size']),
+            Choose(                
+                HorizontalFlip(), VerticalFlip(), 
+                Rotate('90'), Rotate('180'), Rotate('270'),
+                Shift(),
+            )
+        ), None, None),
+        root=constants.IMDB_AIRCHANGE
+    ))
+
+    from data.ac_tiszadob import AC_TiszadobDataset
+    if C['num_workers'] != 0:
+        R['Logger'].warn("Will use num_workers=0.")
+    return data.DataLoader(
+        AC_TiszadobDataset(**configs),
+        batch_size=C['batch_size'],
+        shuffle=True,
+        num_workers=0,  # No need to use multiprocessing
+        pin_memory=C['device']!='cpu',
+        drop_last=True
+    )
+
+
+@DATA.register_func('AC_Tiszadob_eval_dataset')
+def build_AC_Tiszadob_eval_dataset(C):
+    configs = get_common_eval_configs(C)
+    configs.update(dict(
+        root=constants.IMDB_AIRCHANGE
+    ))
+
+    from data.ac_tiszadob import AC_TiszadobDataset
+    return data.DataLoader(
+        AC_TiszadobDataset(**configs),
+        batch_size=1,
+        shuffle=False,
+        num_workers=0,
+        pin_memory=False,
+        drop_last=False
+    )
+
+
+@DATA.register_func('OSCD_train_dataset')
+def build_OSCD_train_dataset(C):
+    configs = get_common_train_configs(C)
+    configs.update(dict(
+        transforms=(Compose(
+            Crop(C['crop_size']),
+            FlipRotate()
+        ), None, None),
+        root=constants.IMDB_OSCD,
+        cache_level=2,
+    ))
+
+    from data.oscd import OSCDDataset
+    if C['num_workers'] != 0:
+        R['Logger'].warn("Will use num_workers=0.")
+    return data.DataLoader(
+        OSCDDataset(**configs),
+        batch_size=C['batch_size'],
+        shuffle=True,
+        num_workers=0,  # Disable multiprocessing
+        pin_memory=C['device']!='cpu',
+        drop_last=True
+    )
+
+
+@DATA.register_func('OSCD_eval_dataset')
+def build_OSCD_eval_dataset(C):
+    configs = get_common_eval_configs(C)
+    configs.update(dict(
+        root=constants.IMDB_OSCD,
+        cache_level=2
+    ))
+
+    from data.oscd import OSCDDataset
+    return data.DataLoader(
+        OSCDDataset(**configs),
+        batch_size=1,
+        shuffle=False,
+        num_workers=0,
+        pin_memory=False,
+        drop_last=False
+    )
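The same registration pattern should extend to other datasets. A hedged sketch of a hypothetical builder (the dataset class, module, and constant below are illustrative and not part of this patch):

    @DATA.register_func('MyCD_train_dataset')
    def build_MyCD_train_dataset(C):
        configs = get_common_train_configs(C)
        configs.update(dict(
            transforms=(Compose(Crop(C['crop_size']), FlipRotate()), None, None),
            root=constants.IMDB_MYCD  # hypothetical constant
        ))
        from data.my_cd import MyCDDataset  # hypothetical module
        return data.DataLoader(
            MyCDDataset(**configs),
            batch_size=C['batch_size'],
            shuffle=True,
            num_workers=C['num_workers'],
            pin_memory=C['device']!='cpu',
            drop_last=True
        )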
diff --git a/src/impl/builders/model_builders.py b/src/impl/builders/model_builders.py
new file mode 100644
index 0000000..e972da2
--- /dev/null
+++ b/src/impl/builders/model_builders.py
@@ -0,0 +1,39 @@
+# Custom model builders
+
+from core.misc import MODELS
+
+
+@MODELS.register_func('Unet_model')
+def build_Unet_model(C):
+    from models.unet import Unet
+    return Unet(6, 2)
+
+
+@MODELS.register_func('Unet_OSCD_model')
+def build_Unet_OSCD_model(C):
+    from models.unet import Unet
+    return Unet(26, 2)
+
+
+@MODELS.register_func('SiamUnet_diff_model')
+def build_SiamUnet_diff_model(C):
+    from models.siamunet_diff import SiamUnet_diff
+    return SiamUnet_diff(3, 2)
+
+
+@MODELS.register_func('SiamUnet_diff_OSCD_model')
+def build_SiamUnet_diff_OSCD_model(C):
+    from models.siamunet_diff import SiamUnet_diff
+    return SiamUnet_diff(13, 2)
+
+
+@MODELS.register_func('SiamUnet_conc_model')
+def build_SiamUnet_conc_model(C):
+    from models.siamunet_conc import SiamUnet_conc
+    return SiamUnet_conc(3, 2)
+
+
+@MODELS.register_func('SiamUnet_conc_OSCD_model')
+def build_SiamUnet_conc_OSCD_model(C):
+    from models.siamunet_conc import SiamUnet_conc
+    return SiamUnet_conc(13, 2)
\ No newline at end of file
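The input channel counts follow from the fusion strategy: the early-fusion Unet takes both temporal images concatenated along the channel axis (2 x 3 RGB bands = 6 for Air Change, 2 x 13 Sentinel-2 bands = 26 for OSCD), while the Siamese variants feed each branch one image's bands (3 or 13). For example:

    from models.unet import Unet
    model = Unet(26, 2)  # 2 x 13 Sentinel-2 bands in, 2 classes (change / no change) out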
diff --git a/src/impl/builders/optim_builders.py b/src/impl/builders/optim_builders.py
new file mode 100644
index 0000000..8e0f330
--- /dev/null
+++ b/src/impl/builders/optim_builders.py
@@ -0,0 +1,3 @@
+# Custom optimizer builders
+
+from core.misc import OPTIMS
\ No newline at end of file
diff --git a/src/impl/trainers/__init__.py b/src/impl/trainers/__init__.py
new file mode 100644
index 0000000..281e1d4
--- /dev/null
+++ b/src/impl/trainers/__init__.py
@@ -0,0 +1,8 @@
+from core.misc import R
+from .cd_trainer import CDTrainer
+
+__all__ = []
+
+trainer_switcher = R['Trainer_switcher']
+# Register (predicate, trainer class) pairs with trainer_switcher.
+trainer_switcher.add_item(lambda C: not C['tb_on'] or C['dataset'] != 'OSCD', CDTrainer)
\ No newline at end of file
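Routing additional configurations to other trainers works by registering further (predicate, class) pairs in the same way. The sketch below is purely illustrative (OSCDTrainer does not exist in this patch) and assumes the switcher checks more specific predicates first:

    # trainer_switcher.add_item(lambda C: C['tb_on'] and C['dataset'] == 'OSCD', OSCDTrainer)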
diff --git a/src/impl/trainers/cd_trainer.py b/src/impl/trainers/cd_trainer.py
new file mode 100644
index 0000000..e071086
--- /dev/null
+++ b/src/impl/trainers/cd_trainer.py
@@ -0,0 +1,222 @@
+import os
+import os.path as osp
+from random import randint
+from functools import partial
+
+import torch
+from torch.utils.tensorboard import SummaryWriter
+from torch.optim import lr_scheduler
+from skimage import io
+from tqdm import tqdm
+
+from core.trainer import Trainer
+from utils.data_utils import (
+    to_array, to_pseudo_color, 
+    normalize_8bit,
+    quantize_8bit as quantize
+)
+from utils.utils import mod_crop, HookHelper
+from utils.metrics import (AverageMeter, Precision, Recall, Accuracy, F1Score)
+
+
+class CDTrainer(Trainer):
+    def __init__(self, settings):
+        super().__init__(settings['model'], settings['dataset'], 'NLL', settings['optimizer'], settings)
+        self.tb_on = (hasattr(self.logger, 'log_path') or self.debug) and self.ctx['tb_on']
+        if self.tb_on:
+            # Initialize tensorboard
+            if hasattr(self.logger, 'log_path'):
+                tb_dir = self.path(
+                    'log', 
+                    osp.join('tb', osp.splitext(osp.basename(self.logger.log_path))[0], '.'), 
+                    name='tb', 
+                    auto_make=True, 
+                    suffix=False
+                )
+            else:
+                tb_dir = self.path(
+                    'log', 
+                    osp.join('tb', 'debug', '.'), 
+                    name='tb', 
+                    auto_make=True, 
+                    suffix=False
+                )
+                for root, dirs, files in os.walk(self.gpc.get_dir('tb'), False):
+                    for f in files:
+                        os.remove(osp.join(root, f))
+                    for d in dirs:
+                        os.rmdir(osp.join(root, d))
+            self.tb_writer = SummaryWriter(tb_dir)
+            self.logger.show_nl("\nTensorboard logdir: {}".format(osp.abspath(self.gpc.get_dir('tb'))))
+            self.tb_intvl = int(self.ctx['tb_intvl'])
+            
+            # Global steps
+            self.train_step = 0
+            self.eval_step = 0
+
+        # Whether to save network output
+        self.out_dir = self.ctx['out_dir']
+        self.save = (self.ctx['save_on'] or self.out_dir) and not self.debug
+
+        self.val_iters = float(self.ctx['val_iters'])
+            
+    def init_learning_rate(self):
+        # Set learning rate adjustment strategy
+        if self.ctx['lr_mode'] == 'const':
+            return self.lr
+        else:
+            def _simple_scheduler_step(self, epoch, acc):
+                self.scheduler.step()
+                return self.scheduler.get_lr()[0]
+            def _scheduler_step_with_acc(self, epoch, acc):
+                self.scheduler.step(acc)
+                # Only return the lr of the first param group
+                return self.optimizer.param_groups[0]['lr']
+            lr_mode = self.ctx['lr_mode']
+            if lr_mode == 'step':
+                self.scheduler = lr_scheduler.StepLR( 
+                    self.optimizer, self.ctx['step'], gamma=0.5
+                )
+                self.adjust_learning_rate = partial(_simple_scheduler_step, self)
+            elif lr_mode == 'exp':
+                self.scheduler = lr_scheduler.ExponentialLR(
+                    self.optimizer, gamma=0.9
+                )
+                self.adjust_learning_rate = partial(_simple_scheduler_step, self)
+            elif lr_mode == 'plateau':
+                if self.load_checkpoint:
+                    self.logger.warn("The old state of the lr scheduler will not be restored.")
+                self.scheduler = lr_scheduler.ReduceLROnPlateau(
+                    self.optimizer, mode='max', factor=0.5, threshold=1e-4
+                )
+                self.adjust_learning_rate = partial(_scheduler_step_with_acc, self)
+                return self.optimizer.param_groups[0]['lr']
+            else:
+                raise NotImplementedError
+
+            if self.start_epoch > 0:
+                # Restore previous state
+                # FIXME: This triggers the PyTorch warning "Detected call of `lr_scheduler.step()`
+                # before `optimizer.step()`" in PyTorch 1.1.0 and later.
+                # A cleaner fix would be to store the scheduler state in the checkpoint file and restore it from disk.
+                last_epoch = self.start_epoch
+                while self.scheduler.last_epoch < last_epoch:
+                    self.scheduler.step()
+            return self.scheduler.get_lr()[0]
+
+    def train_epoch(self, epoch):
+        losses = AverageMeter()
+        len_train = len(self.train_loader)
+        width = len(str(len_train))
+        start_pattern = "[{{:>{0}}}/{{:>{0}}}]".format(width)
+        pb = tqdm(self.train_loader)
+        
+        self.model.train()
+        
+        for i, (t1, t2, tar) in enumerate(pb):
+            t1, t2, tar = t1.to(self.device), t2.to(self.device), tar.to(self.device)
+            
+            show_imgs_on_tb = self.tb_on and (i%self.tb_intvl == 0)
+            
+            prob = self.model(t1, t2)
+            
+            loss = self.criterion(prob, tar)
+            
+            losses.update(loss.item(), n=self.batch_size)
+
+            self.optimizer.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+            desc = (start_pattern+" Loss: {:.4f} ({:.4f})").format(i+1, len_train, losses.val, losses.avg)
+
+            pb.set_description(desc)
+            if i % max(1, len_train//10) == 0:
+                self.logger.dump(desc)
+
+            if self.tb_on:
+                # Write to tensorboard
+                self.tb_writer.add_scalar("Train/loss", losses.val, self.train_step)
+                if show_imgs_on_tb:
+                    self.tb_writer.add_image("Train/t1_picked", normalize_8bit(t1.detach()[0]), self.train_step)
+                    self.tb_writer.add_image("Train/t2_picked", normalize_8bit(t2.detach()[0]), self.train_step)
+                    self.tb_writer.add_image("Train/labels_picked", tar[0].unsqueeze(0), self.train_step)
+                    self.tb_writer.flush()
+                self.train_step += 1
+
+    def evaluate_epoch(self, epoch):
+        self.logger.show_nl("Epoch: [{0}]".format(epoch))
+        losses = AverageMeter()
+        len_eval = len(self.eval_loader)
+        width = len(str(len_eval))
+        start_pattern = "[{{:>{0}}}/{{:>{0}}}]".format(width)
+        pb = tqdm(self.eval_loader)
+
+        # Construct metrics
+        metrics = (Precision(), Recall(), F1Score(), Accuracy())
+
+        self.model.eval()
+
+        with torch.no_grad():
+            for i, (name, t1, t2, tar) in enumerate(pb):
+                if self.is_training and i >= self.val_iters:
+                    # This saves time
+                    pb.close()
+                    self.logger.warn("Evaluation ends early.")
+                    break
+                t1, t2, tar = t1.to(self.device), t2.to(self.device), tar.to(self.device)
+
+                prob = self.model(t1, t2)
+
+                loss = self.criterion(prob, tar)
+                losses.update(loss.item(), n=self.batch_size)
+
+                # Convert to numpy arrays
+                cm = to_array(torch.argmax(prob[0], 0)).astype('uint8')
+                tar = to_array(tar[0]).astype('uint8')
+
+                for m in metrics:
+                    m.update(cm, tar)
+
+                desc = (start_pattern+" Loss: {:.4f} ({:.4f})").format(i+1, len_eval, losses.val, losses.avg)
+                for m in metrics:
+                    desc += " {} {:.4f} ({:.4f})".format(m.__name__, m.val, m.avg)
+
+                pb.set_description(desc)
+                self.logger.dump(desc)
+
+                if self.tb_on:
+                    self.tb_writer.add_image("Eval/t1", normalize_8bit(t1[0]), self.eval_step)
+                    self.tb_writer.add_image("Eval/t2", normalize_8bit(t2[0]), self.eval_step)
+                    self.tb_writer.add_image("Eval/labels", quantize(tar), self.eval_step, dataformats='HW')
+                    prob = quantize(to_array(torch.exp(prob[0,1])))
+                    self.tb_writer.add_image("Eval/prob", to_pseudo_color(prob), self.eval_step, dataformats='HWC')
+                    self.tb_writer.add_image("Eval/cm", quantize(cm), self.eval_step, dataformats='HW')
+                    self.eval_step += 1
+                
+                if self.save:
+                    self.save_image(name[0], quantize(cm), epoch)
+
+        if self.tb_on:
+            self.tb_writer.add_scalar("Eval/loss", losses.avg, self.eval_step)
+            self.tb_writer.add_scalars("Eval/metrics", {m.__name__.lower(): m.avg for m in metrics}, self.eval_step)
+
+        return metrics[2].avg   # F1-score
+
+    def save_image(self, file_name, image, epoch):
+        file_path = osp.join(
+            'epoch_{}'.format(epoch),
+            self.out_dir,
+            file_name
+        )
+        out_path = self.path(
+            'out', file_path,
+            suffix=not self.ctx['suffix_off'],
+            auto_make=True,
+            underline=True
+        )
+        return io.imsave(out_path, image)
+
+    # def __del__(self):
+    #     if self.tb_on:
+    #         self.tb_writer.close()
\ No newline at end of file
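A hedged note on how the learning-rate hooks above are presumably consumed by the base Trainer in core/trainer.py (whose call sites are not shown in this hunk): init_learning_rate() supplies the initial value, and adjust_learning_rate(epoch, acc) is expected to be called once per epoch with the validation score, which for this trainer is the F1 returned by evaluate_epoch:

    # sketch only; method names are from this file, the call order is assumed
    f1 = trainer.evaluate_epoch(epoch)
    lr = trainer.adjust_learning_rate(epoch, f1)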
diff --git a/src/losses.py b/src/losses.py
deleted file mode 100644
index 0ab648b..0000000
--- a/src/losses.py
+++ /dev/null
@@ -1,5 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
diff --git a/src/test.py b/src/test.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/train.py b/src/train.py
index 84b7769..125c803 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,142 +1,18 @@
 #!/usr/bin/env python3
-import argparse
 import os
 import shutil
 import random
-import ast
-from os.path import basename, exists, splitext
+import os.path as osp
 
 import torch
 import torch.backends.cudnn as cudnn
 import numpy as np
-import yaml
 
-from core.trainers import CDTrainer
-from utils.misc import OutPathGetter, Logger, register
-
-
-def read_config(config_path):
-    with open(config_path, 'r') as f:
-        cfg = yaml.load(f.read(), Loader=yaml.FullLoader)
-    return cfg or {}
-
-
-def parse_config(cfg_name, cfg):
-    # Parse the name of config file
-    sp = splitext(cfg_name)[0].split('_')
-    if len(sp) >= 2:
-        cfg.setdefault('tag', sp[1])
-        cfg.setdefault('suffix', '_'.join(sp[2:]))
-    
-    return cfg
-
-
-def parse_args():
-    # Training settings
-    parser = argparse.ArgumentParser()
-    parser.add_argument('cmd', choices=['train', 'val'])
-
-    # Data
-    # Common
-    group_data = parser.add_argument_group('data')
-    group_data.add_argument('-d', '--dataset', type=str, default='OSCD')
-    group_data.add_argument('-p', '--crop-size', type=int, default=256, metavar='P', 
-                        help='patch size (default: %(default)s)')
-    group_data.add_argument('--num-workers', type=int, default=8)
-    group_data.add_argument('--repeats', type=int, default=100)
-
-    # Optimizer
-    group_optim = parser.add_argument_group('optimizer')
-    group_optim.add_argument('--optimizer', type=str, default='Adam')
-    group_optim.add_argument('--lr', type=float, default=1e-4, metavar='LR',
-                        help='learning rate (default: %(default)s)')
-    group_optim.add_argument('--lr-mode', type=str, default='const')
-    group_optim.add_argument('--weight-decay', default=1e-4, type=float,
-                        metavar='W', help='weight decay (default: %(default)s)')
-    group_optim.add_argument('--step', type=int, default=200)
-
-    # Training related
-    group_train = parser.add_argument_group('training related')
-    group_train.add_argument('--batch-size', type=int, default=8, metavar='B',
-                        help='input batch size for training (default: %(default)s)')
-    group_train.add_argument('--num-epochs', type=int, default=1000, metavar='NE',
-                        help='number of epochs to train (default: %(default)s)')
-    group_train.add_argument('--load-optim', action='store_true')
-    group_train.add_argument('--save-optim', action='store_true')
-    group_train.add_argument('--resume', default='', type=str, metavar='PATH',
-                        help='path to latest checkpoint')
-    group_train.add_argument('--anew', action='store_true',
-                        help='clear history and start from epoch 0 with the checkpoint loaded')
-    group_train.add_argument('--track_intvl', type=int, default=50)
-    group_train.add_argument('--device', type=str, default='cpu')
-    group_train.add_argument('--metrics', type=str, default='F1Score+Accuracy+Recall+Precision')
-
-    # Experiment
-    group_exp = parser.add_argument_group('experiment related')
-    group_exp.add_argument('--exp-dir', default='../exp/')
-    group_exp.add_argument('-o', '--out-dir', default='')
-    group_exp.add_argument('--tag', type=str, default='')
-    group_exp.add_argument('--suffix', type=str, default='')
-    group_exp.add_argument('--exp-config', type=str, default='')
-    group_exp.add_argument('--save-on', action='store_true')
-    group_exp.add_argument('--log-off', action='store_true')
-    group_exp.add_argument('--suffix-off', action='store_true')
-
-    # Criterion
-    group_critn = parser.add_argument_group('criterion related')
-    group_critn.add_argument('--criterion', type=str, default='NLL')
-    group_critn.add_argument('--weights', type=str, default=(1.0, 1.0))
-
-    # Model
-    group_model = parser.add_argument_group('model')
-    group_model.add_argument('--model', type=str, default='siamunet_conc')
-    group_model.add_argument('--num-feats-in', type=int, default=13)
-
-    args = parser.parse_args()
-
-    if exists(args.exp_config):
-        cfg = read_config(args.exp_config)
-        cfg = parse_config(basename(args.exp_config), cfg)
-        # Settings from cfg file overwrite those in args
-        # Note that the non-default values will not be affected
-        parser.set_defaults(**cfg)  # Reset part of the default values
-        args = parser.parse_args()  # Parse again
-
-    # Handle args.weights
-    if isinstance(args.weights, str):
-        args.weights = ast.literal_eval(args.weights)
-    args.weights = tuple(args.weights)
-
-    return args
-
-
-def set_gpc_and_logger(args):
-    gpc = OutPathGetter(
-            root=os.path.join(args.exp_dir, args.tag), 
-            suffix=args.suffix)
-
-    log_dir = '' if args.log_off else gpc.get_dir('log')
-    logger = Logger(
-        scrn=True,
-        log_dir=log_dir,
-        phase=args.cmd
-    )
-
-    register('GPC', gpc)
-    register('LOGGER', logger)
-
-    return gpc, logger
+from core.misc import R
+from core.config import parse_args
     
 
 def main():
-    args = parse_args()
-    gpc, logger = set_gpc_and_logger(args)
-
-    if args.exp_config:
-        # Make a copy of the config file
-        cfg_path = gpc.get_path('root', basename(args.exp_config), suffix=False)
-        shutil.copy(args.exp_config, cfg_path)
-
     # Set random seed
     RNG_SEED = 1
     random.seed(RNG_SEED)
@@ -147,14 +23,41 @@ def main():
     cudnn.deterministic = True
     cudnn.benchmark = False
 
-    try:
-        trainer = CDTrainer(args.model, args.dataset, args.optimizer, args)
-        trainer.run()
-    except BaseException as e:
-        import traceback
-        # Catch ALL kinds of exceptions
-        logger.fatal(traceback.format_exc())
-        exit(1)
+    # Parse commandline arguments
+    def parser_configurator(parser):
+        parser.add_argument('--crop_size', type=int, default=256, metavar='P', 
+                            help="patch size (default: %(default)s)")
+        parser.add_argument('--tb_on', action='store_true')
+        parser.add_argument('--tb_intvl', type=int, default=100)
+        parser.add_argument('--suffix_off', action='store_true')
+        parser.add_argument('--lr_mode', type=str, default='const')
+        parser.add_argument('--step', type=int, default=200)
+        parser.add_argument('--save_on', action='store_true')
+        parser.add_argument('--out_dir', default='')
+        parser.add_argument('--val_iters', type=int, default=16)
+
+        return parser
+    args = parse_args(parser_configurator)
+
+    trainer = R['Trainer_switcher'](args)
+
+    if trainer is not None:
+        if args['exp_config']:
+            # Make a copy of the config file
+            cfg_path = osp.join(trainer.gpc.root, osp.basename(args['exp_config']))
+            shutil.copy(args['exp_config'], cfg_path)
+        try:
+            trainer.run()
+        except BaseException as e:
+            import traceback
+            # Catch ALL kinds of exceptions
+            trainer.logger.fatal(traceback.format_exc())
+            if args['debug_on']:
+                breakpoint()
+            exit(1)
+    else:
+        raise NotImplementedError("Cannot find an appropriate trainer!")
+
 
 if __name__ == '__main__':
     main()
\ No newline at end of file
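With argument parsing moved into core/config.py, a typical invocation of the new entry point would presumably look like "python train.py train --exp_config ../configs/config_base.yaml --tb_on --save_on"; the common flag names live in core/config.py and are assumptions here, only the options added by parser_configurator above are certain.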
diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py
new file mode 100644
index 0000000..3c41389
--- /dev/null
+++ b/src/utils/data_utils.py
@@ -0,0 +1,63 @@
+import torch
+import numpy as np
+import cv2
+from scipy.io import loadmat
+from skimage.io import imread
+
+
+def default_loader(path_):
+    return imread(path_)
+
+
+def mat_loader(path_):
+    return loadmat(path_)
+
+
+def make_onehot(index_map, n):
+    # Only deals with tensors with no batch dim
+    old_size = index_map.size()
+    z = torch.zeros(n, *old_size[-2:]).type_as(index_map)
+    z.scatter_(0, index_map, 1)
+    return z
+    
+
+def to_tensor(arr):
+    if any(s < 0 for s in arr.strides):
+        # Make the array contiguous, since torch.from_numpy() does not support negative strides.
+        arr = np.ascontiguousarray(arr)
+    if arr.ndim < 3:
+        return torch.from_numpy(arr)
+    elif arr.ndim == 3:
+        return torch.from_numpy(np.transpose(arr, (2,0,1)))
+    else:
+        raise NotImplementedError
+
+
+def to_array(tensor):
+    if tensor.ndimension() <= 4:
+        arr = tensor.data.cpu().numpy()
+        if tensor.ndimension() in (3, 4):
+            arr = np.moveaxis(arr, -3, -1)
+        return arr
+    else:
+        raise NotImplementedError
+
+
+def normalize_minmax(x):
+    EPS = 1e-32
+    return (x-x.min()) / (x.max()-x.min()+EPS)
+
+
+def normalize_8bit(x):
+    return x/255.0
+
+
+def to_pseudo_color(gray, color_map=cv2.COLORMAP_JET):
+    # Reverse channels to convert BGR to RGB
+    return cv2.applyColorMap(gray, color_map)[...,::-1]
+
+
+def quantize_8bit(x):
+    # [0.0,1.0] float => [0,255] uint8
+    # or [0,1] int => [0,255] uint8
+    return (x*255).astype('uint8')
\ No newline at end of file
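A quick sanity sketch of the conversion helpers (shapes are illustrative):

    import numpy as np
    from utils.data_utils import to_tensor, to_array, normalize_minmax, quantize_8bit, to_pseudo_color

    arr = np.random.rand(64, 64, 3).astype('float32')                   # HWC image
    t = to_tensor(arr)                                                   # CHW tensor
    back = to_array(t)                                                   # HWC array again
    vis = to_pseudo_color(quantize_8bit(normalize_minmax(arr[..., 0])))  # uint8 RGB heat map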
diff --git a/src/utils/losses.py b/src/utils/losses.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
index 9c71f97..f4668b2 100644
--- a/src/utils/metrics.py
+++ b/src/utils/metrics.py
@@ -5,10 +5,11 @@ from sklearn import metrics
 
 
 class AverageMeter:
-    def __init__(self, callback=None):
+    def __init__(self, callback=None, calc_avg=True):
         super().__init__()
         if callback is not None:
             self.compute = callback
+        self.calc_avg = calc_avg
         self.reset()
 
     def compute(self, *args):
@@ -19,9 +20,10 @@ class AverageMeter:
 
     def reset(self):
         self.val = 0
-        self.avg = 0
         self.sum = 0
         self.count = 0
+        if self.calc_avg:
+            self.avg = 0
 
         for attr in filter(lambda a: not a.startswith('__'), dir(self)):
             obj = getattr(self, attr)
@@ -32,23 +34,24 @@ class AverageMeter:
         self.val = self.compute(*args)
         self.sum += self.val * n
         self.count += n
-        self.avg = self.sum / self.count
+        if self.calc_avg:
+            self.avg = self.sum / self.count
 
     def __repr__(self):
-        return 'val: {} avg: {} cnt: {}'.format(self.val, self.avg, self.count)
+        return "val: {} avg: {} cnt: {}".format(self.val, self.avg, self.count)
 
 
 # These metrics only for numpy arrays
 class Metric(AverageMeter):
     __name__ = 'Metric'
     def __init__(self, n_classes=2, mode='separ', reduction='binary'):
-        super().__init__(None)
-        self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes)))
         assert mode in ('accum', 'separ')
-        self.mode = mode
         assert reduction in ('mean', 'none', 'binary')
+        super().__init__(None, mode!='accum')
+        self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes)), False)
+        self.mode = mode
         if reduction == 'binary' and n_classes != 2:
-            raise ValueError("binary reduction only works in 2-class cases")
+            raise ValueError("Binary reduction only works in 2-class cases.")
         self.reduction = reduction
     
     def _compute(self, cm):
@@ -68,10 +71,6 @@ class Metric(AverageMeter):
     def update(self, pred, true, n=1):
         self._cm.update(true.ravel(), pred.ravel())
         if self.mode == 'accum':
-            # Note that accumulation mode is special in that metric.val saves historical information.
-            # Therefore, metric.avg IS USUALLY NOT THE "AVERAGE" VALUE YOU WANT!!! 
-            # Instead, metric.val is the averaged result in the sense of metric.avg in separ mode, 
-            # while metric.avg can be considered as some average of average.
             cm = self._cm.sum
         elif self.mode == 'separ':
             cm = self._cm.val
@@ -80,7 +79,7 @@ class Metric(AverageMeter):
         super().update(cm, n=n)
 
     def __repr__(self):
-        return self.__name__+' '+super().__repr__()
+        return self.__name__+" "+super().__repr__()
 
 
 class Precision(Metric):
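The calc_avg flag separates the two accumulation modes of Metric: in 'separ' mode each update computes the metric on the current batch's confusion matrix and .avg averages those per-batch values, while in 'accum' mode the confusion matrices are summed and .val is recomputed from the running total, so no .avg is maintained. A small sketch (assuming the subclasses forward these keyword arguments, as their shared base suggests):

    from utils.metrics import F1Score
    f1_separ = F1Score(mode='separ')   # .avg is the mean of per-batch F1 scores
    f1_accum = F1Score(mode='accum')   # .val is the F1 of the accumulated confusion matrix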
diff --git a/src/utils/utils.py b/src/utils/utils.py
index 6d7167f..eb08225 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -1,4 +1,5 @@
 import math
+import weakref
 
 import torch
 import numpy as np
@@ -17,4 +18,64 @@ def mod_crop(blob, N):
             h, w = blob.shape[-2:]
             nh = h - h % N
             nw = w - w % N
-            return blob[..., :nh, :nw]
\ No newline at end of file
+            return blob[..., :nh, :nw]
+            
+
+class HookHelper:
+    def __init__(self, model, fetch_dict, out_dict, hook_type='forward_out'):
+        self.model = weakref.proxy(model)
+        self.fetch_dict = fetch_dict
+        # Subclass the built-in list so that it can be weakly referenced
+        class _list(list):
+            pass
+        for entry in self.fetch_dict.values():
+            # entry is expected to be a string or a non-nested tuple
+            if isinstance(entry, tuple):
+                for key in entry:
+                    out_dict[key] = _list()
+            else:
+                out_dict[entry] = _list()
+        self.out_dict = weakref.WeakValueDictionary(out_dict)
+        self._handles = []
+
+        assert hook_type in ('forward_in', 'forward_out', 'backward_out')
+
+        def _proto_hook(x, entry):
+            # x should be a tensor or a tuple
+            if isinstance(entry, tuple):
+                for key, f in zip(entry, x):
+                    self.out_dict[key].append(f.detach().clone())
+            else:
+                self.out_dict[entry].append(x.detach().clone())
+
+        def _forward_in_hook(m, x, y, entry):
+            # x is a tuple
+            return _proto_hook(x[0] if len(x)==1 else x, entry)
+
+        def _forward_out_hook(m, x, y, entry):
+            # y is a tensor or a tuple
+            return _proto_hook(y, entry)
+
+        def _backward_out_hook(m, grad_in, grad_out, entry):
+            # grad_out is a tuple
+            return _proto_hook(grad_out[0] if len(grad_out)==1 else grad_out, entry)
+
+        self._hook_func, self._reg_func_name = {
+            'forward_in': (_forward_in_hook, 'register_forward_hook'),
+            'forward_out': (_forward_out_hook, 'register_forward_hook'),
+            'backward_out': (_backward_out_hook, 'register_backward_hook'),
+        }[hook_type]
+
+    def __enter__(self):
+        for name, module in self.model.named_modules():
+            if name in self.fetch_dict:
+                entry = self.fetch_dict[name]
+                self._handles.append(
+                    getattr(module, self._reg_func_name)(
+                        lambda *args, entry=entry: self._hook_func(*args, entry=entry)
+                    )
+                )
+
+    def __exit__(self, exc_type, exc_val, ext_tb):
+        for handle in self._handles:
+            handle.remove()
\ No newline at end of file
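A hedged usage sketch of HookHelper as a context manager (the module name 'encoder.layer4' is illustrative):

    feats = {}
    with HookHelper(model, {'encoder.layer4': 'feat4'}, feats, hook_type='forward_out'):
        out = model(t1, t2)
    # feats['feat4'] now holds the activations captured during this forward pass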
diff --git a/train9.sh b/train9.sh
deleted file mode 100755
index 6edaf03..0000000
--- a/train9.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash
-
-# # Activate conda environment
-# source activate $ME
-
-# Change directory
-cd src
-
-# Define constants
-ARCHS=("siamdiff" "siamconc" "EF")
-DATASETS=("AC_Szada" "AC_Tiszadob" "OSCD")
-
-# LOOP
-for arch in ${ARCHS[@]}
-do
-    for dataset in ${DATASETS[@]}
-    do
-        python train.py train --exp-config ../configs/config_${arch}_${dataset}.yaml
-    done
-done
\ No newline at end of file
-- 
GitLab