Commit 9fb61787 authored by Bobholamovic's avatar Bobholamovic

Update custom framework

parent c8bfd29b
@@ -130,9 +130,9 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
-# Config files
-config*.yaml
-!/config_base.yaml
+# # Config files
+# config*.yaml
+# !/config_base.yaml
 
 # Experiment folder
 /exp/
\ No newline at end of file
@@ -45,3 +45,8 @@ python train.py val --exp-config ../config_base.yaml --resume path_to_checkpoint
 ```
 You can find the checkpoints in `exp/base/weights/`, the log files in `exp/base/logs`, and the output change maps in `exp/outs`.
+
+---
+
+# Changed
+2020.3.14 Add the configuration files of my experiments.
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Szada
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
anew: False
trace_freq: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify the exp-config option in this file
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: EF
num_feats_in: 6
\ No newline at end of file
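Each of the experiment configuration files added in this commit follows the template above, and the README passes such a file to `train.py` via `--exp-config`. One common way to wire this up is to let the YAML values replace the parser defaults while explicit command-line flags still win; this is only a sketch of that idea, not necessarily the repository's actual parser:

```python
import argparse
import yaml  # PyYAML

# YAML text standing in for one of the config files above
cfg_text = """
batch_size: 32
num_epochs: 10
lr: 0.001
"""

parser = argparse.ArgumentParser()
parser.add_argument('cmd', choices=['train', 'val'])
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--num-epochs', type=int, default=1)
parser.add_argument('--lr', type=float, default=0.01)

# Values from the YAML file become the new defaults; flags given on the
# command line still take precedence over them.
parser.set_defaults(**yaml.safe_load(cfg_text))
args = parser.parse_args(['train', '--lr', '0.005'])
print(args.batch_size, args.num_epochs, args.lr)   # 32 10 0.005
```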
# Basic configurations
# Data
# Common
dataset: AC_Tiszadob
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
anew: False
trace_freq: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify the exp-config option in this file
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: EF
num_feats_in: 6
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: OSCD
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
anew: False
trace_freq: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify the exp-config option in this file
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: EF
num_feats_in: 26
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Szada
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
anew: False
trace_freq: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify the exp-config option in this file
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_conc
num_feats_in: 3
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Tiszadob
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
anew: False
trace_freq: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify the exp-config option in this file
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_conc
num_feats_in: 3
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: OSCD
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
anew: False
trace_freq: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify the exp-config option in this file
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_conc
num_feats_in: 13
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Szada
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
anew: False
trace_freq: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify the exp-config option in this file
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_diff
num_feats_in: 3
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: AC_Tiszadob
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
anew: False
trace_freq: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify the exp-config option in this file
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_diff
num_feats_in: 3
\ No newline at end of file
# Basic configurations
# Data
# Common
dataset: OSCD
crop_size: 112
num_workers: 1
repeats: 3200
# Optimizer
optimizer: SGD
lr: 0.001
lr_mode: const
weight_decay: 0.0005
step: 2
# Training related
batch_size: 32
num_epochs: 10
resume: ''
load_optim: True
anew: False
trace_freq: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify the exp-config option in this file
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_diff
num_feats_in: 13
\ No newline at end of file
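Across the nine configs, `num_feats_in` appears to track the model type and the per-image band count: the early-fusion model `EF` sees both temporal images stacked along the channel axis (6 for the 3-band AC_Szada/AC_Tiszadob scenes, 26 for the 13-band Sentinel-2 OSCD images), while the siamese variants see one image per branch (3 or 13). This pattern is inferred from the configs rather than stated in the commit; a hypothetical helper expressing it:

```python
# Hypothetical helper mirroring the pattern in the configs above.
def num_feats_in(model, bands_per_image):
    if model == 'EF':                        # early fusion: concat(t1, t2) on channels
        return 2 * bands_per_image
    elif model in ('siamunet_conc', 'siamunet_diff'):
        return bands_per_image               # one image per siamese branch
    raise ValueError(model)

assert num_feats_in('EF', 3) == 6            # AC_Szada / AC_Tiszadob
assert num_feats_in('EF', 13) == 26          # OSCD (13 Sentinel-2 bands)
assert num_feats_in('siamunet_diff', 13) == 13
```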
@@ -11,6 +11,7 @@ import torch.utils.data as data
 
 import constants
 import utils.metrics as metrics
 from utils.misc import R
+from data.augmentation import *
 
 
 class _Desc:
@@ -38,16 +39,6 @@ def _generator_deco(func_name):
     return _wrapper
 
 
-def _mark(func):
-    func.__marked__ = True
-    return func
-
-
-def _unmark(func):
-    func.__marked__ = False
-    return func
-
-
 # Duck typing
 class Duck(tuple):
     __ducktype__ = object
@@ -56,6 +47,12 @@ class Duck(tuple):
             raise TypeError("please check the input type")
         return tuple.__new__(cls, args)
 
+    def __add__(self, tup):
+        raise NotImplementedError
+
+    def __mul__(self, tup):
+        raise NotImplementedError
+
 
 class DuckMeta(type):
     def __new__(cls, name, bases, attrs):
@@ -63,61 +60,43 @@ class DuckMeta(type):
         for k, v in getmembers(bases[0]):
             if k.startswith('__'):
                 continue
-            if k in attrs and hasattr(attrs[k], '__marked__'):
-                if attrs[k].__marked__:
-                    continue
             if isgeneratorfunction(v):
-                attrs[k] = _generator_deco(k)
+                attrs.setdefault(k, _generator_deco(k))
             elif isfunction(v):
-                attrs[k] = _func_deco(k)
+                attrs.setdefault(k, _func_deco(k))
             else:
-                attrs[k] = _Desc(k)
+                attrs.setdefault(k, _Desc(k))
         attrs['__ducktype__'] = bases[0]
         return super().__new__(cls, name, (Duck,), attrs)
 
 
-class DuckModel(nn.Module, metaclass=DuckMeta):
-    DELIM = ':'
-    @_mark
-    def load_state_dict(self, state_dict):
-        dicts = [dict() for _ in range(len(self))]
-        for k, v in state_dict.items():
-            i, *k = k.split(self.DELIM)
-            k = self.DELIM.join(k)
-            i = int(i)
-            dicts[i][k] = v
-        for i in range(len(self)): self[i].load_state_dict(dicts[i])
-    @_mark
-    def state_dict(self):
-        dict_ = dict()
-        for i, ins in enumerate(self):
-            dict_.update({self.DELIM.join([str(i), key]):val for key, val in ins.state_dict().items()})
-        return dict_
+class DuckModel(nn.Module):
+    def __init__(self, *models):
+        super().__init__()
+        ## XXX: The state_dict will be a little larger in size
+        # Since some extra bytes are stored in every key
+        self._m = nn.ModuleList(models)
+
+    def __len__(self):
+        return len(self._m)
+
+    def __getitem__(self, idx):
+        return self._m[idx]
+
+    def __repr__(self):
+        return repr(self._m)
 
 
 class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
-    DELIM = ':'
+    # Cuz this is an instance method
     @property
     def param_groups(self):
         return list(chain.from_iterable(ins.param_groups for ins in self))
-    @_mark
-    def state_dict(self):
-        dict_ = dict()
-        for i, ins in enumerate(self):
-            dict_.update({self.DELIM.join([str(i), key]):val for key, val in ins.state_dict().items()})
-        return dict_
-    @_mark
-    def load_state_dict(self, state_dict):
-        dicts = [dict() for _ in range(len(self))]
-        for k, v in state_dict.items():
-            i, *k = k.split(self.DELIM)
-            k = self.DELIM.join(k)
-            i = int(i)
-            dicts[i][k] = v
-        for i in range(len(self)): self[i].load_state_dict(dicts[i])
+
+    # This is special in dispatching
+    def load_state_dict(self, state_dicts):
+        for optim, state_dict in zip(self, state_dicts):
+            optim.load_state_dict(state_dict)
 
 
 class DuckCriterion(nn.Module, metaclass=DuckMeta):
@@ -205,7 +184,6 @@ def _get_basic_configs(ds_name, C):
 
 
 def single_train_ds_factory(ds_name, C):
-    from data.augmentation import Compose, Crop, Flip
     ds_name = ds_name.strip()
     module = _import_module('data', ds_name)
     dataset = getattr(module, ds_name+'Dataset')
...
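For clarity, a stripped-down sketch of how the rewritten DuckModel behaves (the class itself is in the diff above; torch is the only dependency). Registering the sub-models in an `nn.ModuleList` lets the stock `nn.Module` state_dict/load_state_dict machinery replace the old `DELIM`-based key packing, at the cost of a `_m.<index>.` prefix on every key, which is what the "a little larger in size" remark refers to:

```python
import torch.nn as nn

class DuckModel(nn.Module):
    def __init__(self, *models):
        super().__init__()
        self._m = nn.ModuleList(models)   # registers every sub-model as a child
    def __len__(self):
        return len(self._m)
    def __getitem__(self, idx):
        return self._m[idx]

dm = DuckModel(nn.Linear(4, 2), nn.Linear(4, 2))
print(list(dm.state_dict().keys()))
# ['_m.0.weight', '_m.0.bias', '_m.1.weight', '_m.1.bias']
print(len(dm), dm[0])
```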
@@ -20,7 +20,7 @@ class Trainer:
         super().__init__()
         context = deepcopy(settings)
         self.ctx = MappingProxyType(vars(context))
-        self.phase = context.cmd
+        self.mode = ('train', 'val').index(context.cmd)
 
         self.logger = R['LOGGER']
         self.gpc = R['GPC']   # Global Path Controller
@@ -44,27 +44,43 @@ class Trainer:
         self.model.to(self.device)
         self.criterion = critn_factory(criterion, context)
         self.criterion.to(self.device)
-        self.optimizer = optim_factory(optimizer, self.model, context)
         self.metrics = metric_factory(context.metrics, context)
 
-        self.train_loader = data_factory(dataset, 'train', context)
-        self.val_loader = data_factory(dataset, 'val', context)
+        if self.is_training:
+            self.train_loader = data_factory(dataset, 'train', context)
+            self.val_loader = data_factory(dataset, 'val', context)
+            self.optimizer = optim_factory(optimizer, self.model, context)
+        else:
+            self.val_loader = data_factory(dataset, 'val', context)
 
         self.start_epoch = 0
-        self._init_max_acc = 0.0
+        self._init_max_acc_and_epoch = (0.0, 0)
 
+    @property
+    def is_training(self):
+        return self.mode == 0
+
-    def train_epoch(self):
+    def train_epoch(self, epoch):
         raise NotImplementedError
 
     def validate_epoch(self, epoch=0, store=False):
         raise NotImplementedError
 
+    def _write_prompt(self):
+        self.logger.dump(input("\nWrite some notes: "))
+
+    def run(self):
+        if self.is_training:
+            self._write_prompt()
+            self.train()
+        else:
+            self.evaluate()
+
     def train(self):
         if self.load_checkpoint:
             self._resume_from_checkpoint()
-        max_acc = self._init_max_acc
-        best_epoch = self.get_ckp_epoch()
+        max_acc, best_epoch = self._init_max_acc_and_epoch
 
         for epoch in range(self.start_epoch, self.num_epochs):
             lr = self._adjust_learning_rate(epoch)
@@ -72,7 +88,7 @@ class Trainer:
             self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr))
 
             # Train for one epoch
-            self.train_epoch()
+            self.train_epoch(epoch)
 
             # Clear the history of metric objects
             for m in self.metrics:
@@ -90,14 +106,14 @@ class Trainer:
                 acc, epoch, max_acc, best_epoch))
 
             # The checkpoint saves next epoch
-            self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), max_acc, epoch+1, is_best)
+            self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), (max_acc, best_epoch), epoch+1, is_best)
 
-    def validate(self):
+    def evaluate(self):
         if self.checkpoint:
             if self._resume_from_checkpoint():
-                self.validate_epoch(self.get_ckp_epoch(), self.save)
+                self.validate_epoch(self.ckp_epoch, self.save)
         else:
-            self.logger.warning("no checkpoint assigned!")
+            self.logger.warning("Warning: no checkpoint assigned!")
 
     def _adjust_learning_rate(self, epoch):
         if self.ctx['lr_mode'] == 'step':
@@ -114,13 +130,14 @@ class Trainer:
         return lr
 
     def _resume_from_checkpoint(self):
+        ## XXX: This could be slow!
         if not os.path.isfile(self.checkpoint):
-            self.logger.error("=> no checkpoint found at '{}'".format(self.checkpoint))
+            self.logger.error("=> No checkpoint was found at '{}'.".format(self.checkpoint))
             return False
 
-        self.logger.show("=> loading checkpoint '{}'".format(
+        self.logger.show("=> Loading checkpoint '{}'".format(
             self.checkpoint))
-        checkpoint = torch.load(self.checkpoint)
+        checkpoint = torch.load(self.checkpoint, map_location=self.device)
 
         state_dict = self.model.state_dict()
         ckp_dict = checkpoint.get('state_dict', checkpoint)
@@ -129,32 +146,35 @@ class Trainer:
         num_to_update = len(update_dict)
         if (num_to_update < len(state_dict)) or (len(state_dict) < len(ckp_dict)):
-            if self.phase == 'val' and (num_to_update < len(state_dict)):
-                self.logger.error("=> mismatched checkpoint for validation")
+            if not self.is_training and (num_to_update < len(state_dict)):
+                self.logger.error("=> Mismatched checkpoint for evaluation")
                 return False
-            self.logger.warning("warning: trying to load an mismatched checkpoint")
+            self.logger.warning("Warning: trying to load an mismatched checkpoint.")
             if num_to_update == 0:
-                self.logger.error("=> no parameter is to be loaded")
+                self.logger.error("=> No parameter is to be loaded.")
                 return False
             else:
-                self.logger.warning("=> {} params are to be loaded".format(num_to_update))
+                self.logger.warning("=> {} params are to be loaded.".format(num_to_update))
-        elif (not self.ctx['anew']) or (self.phase != 'train'):
-            # Note in the non-anew mode, it is not guaranteed that the contained field
-            # max_acc be the corresponding one of the loaded checkpoint.
-            self.start_epoch = checkpoint.get('epoch', self.start_epoch)
-            self._init_max_acc = checkpoint.get('max_acc', self._init_max_acc)
-            if self.ctx['load_optim']:
+        elif (not self.ctx['anew']) or not self.is_training:
+            self.start_epoch = checkpoint.get('epoch', 0)
+            max_acc_and_epoch = checkpoint.get('max_acc', (0.0, self.ckp_epoch))
+            # For backward compatibility
+            if isinstance(max_acc_and_epoch, (float, int)):
+                self._init_max_acc_and_epoch = (max_acc_and_epoch, self.ckp_epoch)
+            else:
+                self._init_max_acc_and_epoch = max_acc_and_epoch
+            if self.ctx['load_optim'] and self.is_training:
                 try:
                     # Note that weight decay might be modified here
                     self.optimizer.load_state_dict(checkpoint['optimizer'])
                 except KeyError:
-                    self.logger.warning("warning: failed to load optimizer parameters")
+                    self.logger.warning("Warning: failed to load optimizer parameters.")
 
         state_dict.update(update_dict)
         self.model.load_state_dict(state_dict)
 
-        self.logger.show("=> loaded checkpoint '{}' (epoch {}, max_acc {:.4f})".format(
-            self.checkpoint, self.get_ckp_epoch(), self._init_max_acc
+        self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {})".format(
+            self.checkpoint, self.ckp_epoch, *self._init_max_acc_and_epoch
         ))
         return True
@@ -183,7 +203,8 @@ class Trainer:
             )
         )
 
-    def get_ckp_epoch(self):
+    @property
+    def ckp_epoch(self):
         # Get current epoch of the checkpoint
         # For dismatched ckp or no ckp, set to 0
         return max(self.start_epoch-1, 0)
@@ -207,7 +228,7 @@ class CDTrainer(Trainer):
     def __init__(self, arch, dataset, optimizer, settings):
         super().__init__(arch, dataset, 'NLL', optimizer, settings)
 
-    def train_epoch(self):
+    def train_epoch(self, epoch):
         losses = AverageMeter()
         len_train = len(self.train_loader)
         pb = tqdm(self.train_loader)
@@ -246,7 +267,7 @@ class CDTrainer(Trainer):
         with torch.no_grad():
             for i, (name, t1, t2, label) in enumerate(pb):
-                if self.phase == 'train' and i >= 16:
+                if self.is_training and i >= 16:
                     # Do not validate all images on training phase
                     pb.close()
                     self.logger.warning("validation ends early")
...
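The subtlest part of the Trainer changes is the checkpoint field that used to hold a bare `max_acc` float and now holds a `(max_acc, best_epoch)` tuple. A self-contained sketch of the backward-compatibility rule from the diff above (a plain dict stands in for the result of `torch.load`):

```python
# Old checkpoints stored a bare number under 'max_acc'; new ones store a
# (max_acc, best_epoch) tuple. Legacy values fall back to the checkpoint epoch.
def read_max_acc(checkpoint, ckp_epoch):
    value = checkpoint.get('max_acc', (0.0, ckp_epoch))
    if isinstance(value, (float, int)):        # legacy checkpoint
        return (float(value), ckp_epoch)
    return tuple(value)                        # new-style checkpoint

assert read_max_acc({'max_acc': 0.87}, ckp_epoch=5) == (0.87, 5)
assert read_max_acc({'max_acc': (0.91, 8)}, ckp_epoch=5) == (0.91, 8)
assert read_max_acc({}, ckp_epoch=0) == (0.0, 0)
```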
-from os.path import join, expanduser, basename
+from os.path import join, expanduser, basename, exists
 
 import torch
 import torch.utils.data as data
@@ -16,9 +16,12 @@ class CDDataset(data.Dataset):
     ):
         super().__init__()
         self.root = expanduser(root)
+        if not exists(self.root):
+            raise FileNotFoundError
         self.phase = phase
-        self.transforms = transforms
-        self.repeats = repeats
+        self.transforms = list(transforms)
+        self.transforms += [None]*(3-len(self.transforms))
+        self.repeats = int(repeats)
 
         self.t1_list, self.t2_list, self.label_list = self._read_file_paths()
         self.len = len(self.label_list)
...
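The constructor now pads `self.transforms` to exactly three slots with `None`, so downstream code can index a fixed number of transform positions without length checks. How each slot is consumed is not shown in this diff; the snippet below only illustrates the padding behaviour itself:

```python
# Purely illustrative: pad a transform list to a fixed number of slots.
def pad_transforms(transforms, n=3):
    tfs = list(transforms)
    tfs += [None] * (n - len(tfs))
    return tfs

print(pad_transforms([]))            # [None, None, None]
print(pad_transforms([str.upper]))   # [<method 'upper' of 'str' objects>, None, None]
```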
 import random
+import math
 from functools import partial, wraps
 
 import numpy as np
 import cv2
 
+__all__ = [
+    'Compose', 'Choose',
+    'Scale', 'DiscreteScale',
+    'Flip', 'HorizontalFlip', 'VerticalFlip', 'Rotate',
+    'Crop', 'MSCrop',
+    'Shift', 'XShift', 'YShift',
+    'HueShift', 'SaturationShift', 'RGBShift', 'RShift', 'GShift', 'BShift',
+    'PCAJitter',
+    'ContraBrightScale', 'ContrastScale', 'BrightnessScale',
+    'AddGaussNoise'
+]
+
 rand = random.random
 randi = random.randint
 choice = random.choice
@@ -11,11 +26,10 @@ uniform = random.uniform
 # gauss = random.gauss
 gauss = random.normalvariate    # This one is thread-safe
 
-# The transformations treat numpy ndarrays only
+# The transformations treat 2-D or 3-D numpy ndarrays only, with the optional 3rd dim as the channel dim
 def _istuple(x): return isinstance(x, (tuple, list))
 
 class Transform:
     def __init__(self, random_state=False):
         self.random_state = random_state
@@ -28,6 +42,7 @@ class Transform:
     def _set_rand_param(self):
         raise NotImplementedError
 
+
 class Compose:
     def __init__(self, *tf):
         assert len(tf) > 0
@@ -40,16 +55,26 @@ class Compose:
         for tf in self.tfs: x = tf(x)
         return x
 
+
+class Choose:
+    def __init__(self, *tf):
+        assert len(tf) > 1
+        self.tfs = tf
+
+    def __call__(self, *x):
+        idx = randi(0, len(self.tfs)-1)
+        return self.tfs[idx](*x)
+
+
 class Scale(Transform):
     def __init__(self, scale=(0.5,1.0)):
         if _istuple(scale):
             assert len(scale) == 2
-            self.scale_range = scale    #sorted(scale)
-            self.scale = scale[0]
+            self.scale_range = tuple(scale)    #sorted(scale)
+            self.scale = float(scale[0])
             super(Scale, self).__init__(random_state=True)
         else:
             super(Scale, self).__init__(random_state=False)
-            self.scale = scale
+            self.scale = float(scale)
     def _transform(self, x):
         # assert x.ndim == 3
         h, w = x.shape[:2]
@@ -61,11 +86,12 @@ class Scale(Transform):
     def _set_rand_param(self):
         self.scale = uniform(*self.scale_range)
 
+
 class DiscreteScale(Scale):
     def __init__(self, bins=(0.5, 0.75), keep_prob=0.5):
         super(DiscreteScale, self).__init__(scale=(min(bins), 1.0))
-        self.bins = bins
-        self.keep_prob = keep_prob
+        self.bins = tuple(bins)
+        self.keep_prob = float(keep_prob)
     def _set_rand_param(self):
         self.scale = 1.0 if rand()<self.keep_prob else choice(self.bins)
@@ -117,6 +143,10 @@ class VerticalFlip(Flip):
         super(VerticalFlip, self).__init__(direction=flip)
 
+
+class Rotate(Flip):
+    _directions = ('90', '180', '270', 'no')
+
 class Crop(Transform):
     _inner_bounds = ('bl', 'br', 'tl', 'tr', 't', 'b', 'l', 'r')
     def __init__(self, crop_size=None, bounds=None):
@@ -148,8 +178,10 @@ class Crop(Transform):
         elif self.bounds == 'r':
             return x[:,w//2:]
         elif len(self.bounds) == 2:
-            assert self.crop_size < (h, w)
+            assert self.crop_size <= (h, w)
             ch, cw = self.crop_size
+            if (ch,cw) == (h,w):
+                return x
             cx, cy = int((w-cw+1)*self.bounds[0]), int((h-ch+1)*self.bounds[1])
             return x[cy:cy+ch, cx:cx+cw]
         else:
@@ -188,6 +220,59 @@ class MSCrop(Crop):
         self.bounds = (left, top, left+cw, top+ch)
 
+
+class Shift(Transform):
+    def __init__(self, x_shift=(-0.0625, 0.0625), y_shift=(-0.0625, 0.0625), circular=True):
+        super(Shift, self).__init__(random_state=_istuple(x_shift) or _istuple(y_shift))
+        if _istuple(x_shift):
+            self.xshift_range = tuple(x_shift)
+            self.xshift = float(x_shift[0])
+        else:
+            self.xshift = float(x_shift)
+            self.xshift_range = (self.xshift, self.xshift)
+        if _istuple(y_shift):
+            self.yshift_range = tuple(y_shift)
+            self.yshift = float(y_shift[0])
+        else:
+            self.yshift = float(y_shift)
+            self.yshift_range = (self.yshift, self.yshift)
+        self.circular = circular
+
+    def _transform(self, im):
+        h, w = im.shape[:2]
+        xsh = -int(self.xshift*w)
+        ysh = -int(self.yshift*h)
+        if self.circular:
+            # Shift along the x-axis
+            im_shifted = np.concatenate((im[:, xsh:], im[:, :xsh]), axis=1)
+            # Shift along the y-axis
+            im_shifted = np.concatenate((im_shifted[ysh:], im_shifted[:ysh]), axis=0)
+        else:
+            zeros = np.zeros(im.shape)
+            im1, im2 = (zeros, im) if xsh < 0 else (im, zeros)
+            im_shifted = np.concatenate((im1[:, xsh:], im2[:, :xsh]), axis=1)
+            im1, im2 = (zeros, im_shifted) if ysh < 0 else (im_shifted, zeros)
+            im_shifted = np.concatenate((im1[ysh:], im2[:ysh]), axis=0)
+        return im_shifted
+
+    def _set_rand_param(self):
+        self.xshift = uniform(*self.xshift_range)
+        self.yshift = uniform(*self.yshift_range)
+
+
+class XShift(Shift):
+    def __init__(self, x_shift=(-0.0625, 0.0625), circular=True):
+        super(XShift, self).__init__(x_shift, 0.0, circular)
+
+
+class YShift(Shift):
+    def __init__(self, y_shift=(-0.0625, 0.0625), circular=True):
+        super(YShift, self).__init__(0.0, y_shift, circular)
+
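A quick check of the circular Shift added above, using a stripped-down copy of its circular branch: a positive `xshift` of 0.25 on a width-4 image rolls the columns right by one position (`xsh = -1` slices from the end, so the last column wraps around to the front):

```python
import numpy as np

class _Shift:
    # Minimal stand-in containing only the circular branch of Shift._transform.
    def __init__(self, xshift, yshift):
        self.xshift, self.yshift = xshift, yshift
    def __call__(self, im):
        xsh = -int(self.xshift * im.shape[1])
        ysh = -int(self.yshift * im.shape[0])
        im = np.concatenate((im[:, xsh:], im[:, :xsh]), axis=1)
        return np.concatenate((im[ysh:], im[:ysh]), axis=0)

im = np.arange(4)[None].repeat(2, axis=0)   # [[0 1 2 3], [0 1 2 3]]
print(_Shift(0.25, 0.0)(im))                # [[3 0 1 2], [3 0 1 2]]
```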
 # Color jittering and transformation
 # The followings partially refer to https://github.com/albu/albumentations/
 class _ValueTransform(Transform):
@@ -201,8 +286,12 @@ class _ValueTransform(Transform):
         def wrapper(obj, x):
             # # Make a copy
             # x = x.copy()
-            x = tf(obj, np.clip(x, *obj.limit))
-            return np.clip(x, *obj.limit)
+            dtype = x.dtype
+            # The calculations are done with floating type in case of overflow
+            # This is a stupid yet simple way
+            x = tf(obj, np.clip(x.astype(np.float32), *obj.limit))
+            # Convert back to the original type
+            return np.clip(x, *obj.limit).astype(dtype)
         return wrapper
@@ -222,7 +311,7 @@ class ColorJitter(_ValueTransform):
         else:
             if _istuple(shift):
                 if len(shift) != _nc:
-                    raise ValueError("specify the shift value (or range) for every channel")
+                    raise ValueError("please specify the shift value (or range) for every channel.")
                 rs = all(_istuple(s) for s in shift)
                 self.shift = self.range = shift
             else:
@@ -233,23 +322,20 @@ class ColorJitter(_ValueTransform):
         self.random_state = rs
 
         def _(x):
-            return x, ()
+            return x
         self.convert_to = _
         self.convert_back = _
 
     @_ValueTransform.keep_range
     def _transform(self, x):
-        # CAUTION!
-        # Type conversion here
-        x, params = self.convert_to(x)
+        x = self.convert_to(x)
         for i, c in enumerate(self._channel):
-            x[...,c] += self.shift[i]
-            x[...,c] = self._clip(x[...,c])
-        x, _ = self.convert_back(x, *params)
+            x[...,c] = self._clip(x[...,c]+float(self.shift[i]))
+        x = self.convert_back(x)
         return x
 
     def _clip(self, x):
-        return np.clip(x, *self.limit)
+        return x
 
     def _set_rand_param(self):
         if len(self._channel) == 1:
@@ -262,20 +348,22 @@ class HSVShift(ColorJitter):
     def __init__(self, shift, limit):
         super().__init__(shift, limit)
         def _convert_to(x):
-            type_x = x.dtype
             x = x.astype(np.float32)
             # Normalize to [0,1]
             x -= self.limit[0]
             x /= self.limit_range
             x = cv2.cvtColor(x, code=cv2.COLOR_RGB2HSV)
-            return x, (type_x,)
+            return x
-        def _convert_back(x, type_x):
+        def _convert_back(x):
             x = cv2.cvtColor(x.astype(np.float32), code=cv2.COLOR_HSV2RGB)
-            return x.astype(type_x) * self.limit_range + self.limit[0], ()
+            return x * self.limit_range + self.limit[0]
         # Pack conversion methods
         self.convert_to = _convert_to
         self.convert_back = _convert_back
 
+    def _clip(self, x):
+        raise NotImplementedError
+
 class HueShift(HSVShift):
     _channel = (0,)
@@ -332,7 +420,7 @@ class PCAJitter(_ValueTransform):
         old_shape = x.shape
         x = np.reshape(x, (-1,3), order='F')    # For RGB
         x_mean = np.mean(x, 0)
-        x -= x_mean
+        x = x - x_mean
         cov_x = np.cov(x, rowvar=False)
         eig_vals, eig_vecs = np.linalg.eig(np.mat(cov_x))
         # The eigen vectors are already unit "length"
@@ -354,9 +442,9 @@ class ContraBrightScale(_ValueTransform):
     @_ValueTransform.keep_range
     def _transform(self, x):
-        if self.alpha != 1:
+        if not math.isclose(self.alpha, 1.0):
             x *= self.alpha
-        if self.beta != 0:
+        if not math.isclose(self.beta, 0.0):
             x += self.beta*np.mean(x)
         return x
@@ -387,7 +475,7 @@ class _AddNoise(_ValueTransform):
     def __call__(self, *args):
         shape = args[0].shape
         if any(im.shape != shape for im in args):
-            raise ValueError("the input images should be of same size")
+            raise ValueError("the input images should be of same size.")
         self._im_shape = shape
         return super().__call__(*args)
@@ -399,16 +487,3 @@ class AddGaussNoise(_AddNoise):
         self.sigma = sigma
     def _set_rand_param(self):
         self.noise_map = np.random.randn(*self._im_shape)*self.sigma + self.mu
-
-
-def __test():
-    a = np.arange(12).reshape((2,2,3)).astype(np.float64)
-    tf = Compose(BrightnessScale(), AddGaussNoise(), HueShift())
-    print(a[...,0])
-    c = tf(a)
-    print(c[...,0])
-    print(a[...,0])
-
-
-if __name__ == '__main__':
-    __test()
\ No newline at end of file
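A self-contained illustration of the `keep_range` change above: the computation is done in float32, then clipped and cast back to the caller's original dtype, so uint8 inputs no longer wrap around on overflow. This mirrors the decorator's new body without depending on the Transform class hierarchy:

```python
import numpy as np

def keep_range(limit):
    def deco(tf):
        def wrapper(x):
            dtype = x.dtype
            # Do the arithmetic in floating point to avoid integer overflow
            x = tf(np.clip(x.astype(np.float32), *limit))
            # Clip to the valid range and restore the original dtype
            return np.clip(x, *limit).astype(dtype)
        return wrapper
    return deco

@keep_range((0, 255))
def add_100(x):
    return x + 100.0

a = np.array([10, 200], dtype=np.uint8)
print(add_100(a))          # [110 255] -- no uint8 wrap-around
print(add_100(a).dtype)    # uint8
```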
@@ -131,7 +131,7 @@ def main():
     args = parse_args()
     gpc, logger = set_gpc_and_logger(args)
 
-    if exists(args.exp_config):
+    if args.exp_config:
         # Make a copy of the config file
         cfg_path = gpc.get_path('root', basename(args.exp_config), suffix=False)
         shutil.copy(args.exp_config, cfg_path)
@@ -147,16 +147,11 @@ def main():
     try:
         trainer = CDTrainer(args.model, args.dataset, args.optimizer, args)
-        if args.cmd == 'train':
-            trainer.train()
-        elif args.cmd == 'val':
-            trainer.validate()
-        else:
-            pass
+        trainer.run()
     except BaseException as e:
         import traceback
         # Catch ALL kinds of exceptions
-        logger.error(traceback.format_exc())
+        logger.fatal(traceback.format_exc())
         exit(1)
 
 if __name__ == '__main__':
...
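For reference, a minimal sketch of the catch-all pattern used in `main()`: any exception, including KeyboardInterrupt and SystemExit (hence `BaseException`), is logged with its full traceback at the highest severity before exiting. Standard `logging` stands in for the project's own Logger here:

```python
import logging
import sys
import traceback

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('train')

def main():
    try:
        raise RuntimeError('something broke inside the trainer')
    except BaseException:
        # Log the full traceback, then exit with a non-zero status code
        logger.critical(traceback.format_exc())
        sys.exit(1)

if __name__ == '__main__':
    main()
```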
 import logging
 import os
+import sys
 from time import localtime
 from collections import OrderedDict
 from weakref import proxy
@@ -17,8 +18,13 @@ class Logger:
         Logger._count += 1
 
         self._logger.setLevel(logging.DEBUG)
+        self._err_handler = logging.StreamHandler(stream=sys.stderr)
+        self._err_handler.setLevel(logging.ERROR)
+        self._err_handler.setFormatter(logging.Formatter(fmt=FORMAT_SHORT))
+        self._logger.addHandler(self._err_handler)
 
         if scrn:
-            self._scrn_handler = logging.StreamHandler()
+            self._scrn_handler = logging.StreamHandler(stream=sys.stdout)
             self._scrn_handler.setLevel(logging.INFO)
             self._scrn_handler.setFormatter(logging.Formatter(fmt=FORMAT_SHORT))
             self._logger.addHandler(self._scrn_handler)
@@ -50,9 +56,12 @@ class Logger:
     def error(self, *args, **kwargs):
         return self._logger.error(*args, **kwargs)
 
+    def fatal(self, *args, **kwargs):
+        return self._logger.critical(*args, **kwargs)
+
     @staticmethod
-    def make_desc(counter, total, *triples):
-        desc = "[{}/{}]".format(counter, total)
+    def make_desc(counter, total, *triples, opt_str=''):
+        desc = "[{}/{}] {}".format(counter, total, opt_str)
         # The three elements of each triple are
         # (name to display, AverageMeter object, formatting string)
         for name, obj, fmt in triples:
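The Logger now keeps two stream handlers: INFO and above goes to stdout when screen output is enabled, while ERROR and above always also reaches stderr through a dedicated handler. A sketch of the same layout on a plain `logging.Logger` (a simple format string replaces `FORMAT_SHORT`):

```python
import logging
import sys

log = logging.getLogger('demo')
log.setLevel(logging.DEBUG)

# Errors always go to stderr
err = logging.StreamHandler(stream=sys.stderr)
err.setLevel(logging.ERROR)
err.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
log.addHandler(err)

# Normal progress output goes to stdout
scrn = logging.StreamHandler(stream=sys.stdout)
scrn.setLevel(logging.INFO)
scrn.setFormatter(logging.Formatter('%(message)s'))
log.addHandler(scrn)

log.info('progress message')        # stdout only
log.error('something went wrong')   # stdout and stderr
log.critical('fatal error')         # the level Logger.fatal() maps to
```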
@@ -258,6 +267,7 @@ class _Tree:
     def add_node(self, path, val=None):
         if not path.strip():
             raise ValueError("the path is null")
+        path = path.strip('/')
         if val is None:
             val = self._def_val
         names = self.parse_path(path)
@@ -281,6 +291,8 @@ class OutPathGetter:
     def __init__(self, root='', log='logs', out='outs', weight='weights', suffix='', **subs):
         super().__init__()
         self._root = root.rstrip('/')    # Work robustly for multiple ending '/'s
+        if len(self._root) == 0 and len(root) > 0:
+            self._root = '/'    # In case of the system root dir
         self._suffix = suffix
         self._keys = dict(log=log, out=out, weight=weight, **subs)
         self._dir_tree = _Tree(
...
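The OutPathGetter tweak guards one edge case: `rstrip('/')` reduces the filesystem root "/" to an empty string, so the constructor now restores it. The same logic in isolation:

```python
def normalize_root(root):
    r = root.rstrip('/')              # tolerate multiple trailing slashes
    if len(r) == 0 and len(root) > 0:
        r = '/'                       # the system root dir itself
    return r

assert normalize_root('/data/exp///') == '/data/exp'
assert normalize_root('/') == '/'
assert normalize_root('') == ''
```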
 #!/bin/bash
 
-# Activate conda environment
-source activate $ME
+# # Activate conda environment
+# source activate $ME
 
 # Change directory
 cd src
...