Skip to content
Snippets Groups Projects
Commit 15422bd7 authored by Bobholamovic's avatar Bobholamovic
Browse files

Merge New Year Commit

parents 7f0846c1 36e94e06
Branches
Tags
No related merge requests found
...@@ -12,6 +12,7 @@ import constants ...@@ -12,6 +12,7 @@ import constants
import utils.metrics as metrics import utils.metrics as metrics
from utils.misc import R from utils.misc import R
class _Desc: class _Desc:
def __init__(self, key): def __init__(self, key):
self.key = key self.key = key
...@@ -26,15 +27,7 @@ class _Desc: ...@@ -26,15 +27,7 @@ class _Desc:
def _func_deco(func_name): def _func_deco(func_name):
def _wrapper(self, *args): def _wrapper(self, *args):
# TODO: Add key argument support return tuple(getattr(ins, func_name)(*args) for ins in self)
try:
# Dispatch type 1
ret = tuple(getattr(ins, func_name)(*args) for ins in self)
except Exception:
# Dispatch type 2
if len(args) > 1 or (len(args[0]) != len(self)): raise
ret = tuple(getattr(i, func_name)(a) for i, a in zip(self, args[0]))
return ret
return _wrapper return _wrapper
...@@ -45,6 +38,16 @@ def _generator_deco(func_name): ...@@ -45,6 +38,16 @@ def _generator_deco(func_name):
return _wrapper return _wrapper
def _mark(func):
func.__marked__ = True
return func
def _unmark(func):
func.__marked__ = False
return func
# Duck typing # Duck typing
class Duck(tuple): class Duck(tuple):
__ducktype__ = object __ducktype__ = object
...@@ -60,6 +63,9 @@ class DuckMeta(type): ...@@ -60,6 +63,9 @@ class DuckMeta(type):
for k, v in getmembers(bases[0]): for k, v in getmembers(bases[0]):
if k.startswith('__'): if k.startswith('__'):
continue continue
if k in attrs and hasattr(attrs[k], '__marked__'):
if attrs[k].__marked__:
continue
if isgeneratorfunction(v): if isgeneratorfunction(v):
attrs[k] = _generator_deco(k) attrs[k] = _generator_deco(k)
elif isfunction(v): elif isfunction(v):
...@@ -71,14 +77,48 @@ class DuckMeta(type): ...@@ -71,14 +77,48 @@ class DuckMeta(type):
class DuckModel(nn.Module, metaclass=DuckMeta): class DuckModel(nn.Module, metaclass=DuckMeta):
pass DELIM = ':'
@_mark
def load_state_dict(self, state_dict):
dicts = [dict() for _ in range(len(self))]
for k, v in state_dict.items():
i, *k = k.split(self.DELIM)
k = self.DELIM.join(k)
i = int(i)
dicts[i][k] = v
for i in range(len(self)): self[i].load_state_dict(dicts[i])
@_mark
def state_dict(self):
dict_ = dict()
for i, ins in enumerate(self):
dict_.update({self.DELIM.join([str(i), key]):val for key, val in ins.state_dict().items()})
return dict_
class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta): class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
DELIM = ':'
@property @property
def param_groups(self): def param_groups(self):
return list(chain.from_iterable(ins.param_groups for ins in self)) return list(chain.from_iterable(ins.param_groups for ins in self))
@_mark
def state_dict(self):
dict_ = dict()
for i, ins in enumerate(self):
dict_.update({self.DELIM.join([str(i), key]):val for key, val in ins.state_dict().items()})
return dict_
@_mark
def load_state_dict(self, state_dict):
dicts = [dict() for _ in range(len(self))]
for k, v in state_dict.items():
i, *k = k.split(self.DELIM)
k = self.DELIM.join(k)
i = int(i)
dicts[i][k] = v
for i in range(len(self)): self[i].load_state_dict(dicts[i])
class DuckCriterion(nn.Module, metaclass=DuckMeta): class DuckCriterion(nn.Module, metaclass=DuckMeta):
pass pass
...@@ -112,7 +152,8 @@ def single_model_factory(model_name, C): ...@@ -112,7 +152,8 @@ def single_model_factory(model_name, C):
def single_optim_factory(optim_name, params, C): def single_optim_factory(optim_name, params, C):
name = optim_name.strip().upper() optim_name = optim_name.strip()
name = optim_name.upper()
if name == 'ADAM': if name == 'ADAM':
return torch.optim.Adam( return torch.optim.Adam(
params, params,
...@@ -133,6 +174,7 @@ def single_optim_factory(optim_name, params, C): ...@@ -133,6 +174,7 @@ def single_optim_factory(optim_name, params, C):
def single_critn_factory(critn_name, C): def single_critn_factory(critn_name, C):
import losses import losses
critn_name = critn_name.strip()
try: try:
criterion, params = { criterion, params = {
'L1': (nn.L1Loss, ()), 'L1': (nn.L1Loss, ()),
...@@ -145,6 +187,23 @@ def single_critn_factory(critn_name, C): ...@@ -145,6 +187,23 @@ def single_critn_factory(critn_name, C):
raise NotImplementedError("{} is not a supported criterion type".format(critn_name)) raise NotImplementedError("{} is not a supported criterion type".format(critn_name))
def _get_basic_configs(ds_name, C):
if ds_name == 'OSCD':
return dict(
root = constants.IMDB_OSCD
)
elif ds_name.startswith('AC'):
return dict(
root = constants.IMDB_AirChange
)
elif ds_name.startswith('Lebedev'):
return dict(
root = constants.IMDB_LEBEDEV
)
else:
return dict()
def single_train_ds_factory(ds_name, C): def single_train_ds_factory(ds_name, C):
from data.augmentation import Compose, Crop, Flip from data.augmentation import Compose, Crop, Flip
ds_name = ds_name.strip() ds_name = ds_name.strip()
...@@ -155,22 +214,14 @@ def single_train_ds_factory(ds_name, C): ...@@ -155,22 +214,14 @@ def single_train_ds_factory(ds_name, C):
transforms=(Compose(Crop(C.crop_size), Flip()), None, None), transforms=(Compose(Crop(C.crop_size), Flip()), None, None),
repeats=C.repeats repeats=C.repeats
) )
if ds_name == 'OSCD':
configs.update( # Update some common configurations
dict( configs.update(_get_basic_configs(ds_name, C))
root = constants.IMDB_OSCD
) # Set phase-specific ones
) if ds_name == 'Lebedev':
elif ds_name.startswith('AC'):
configs.update(
dict(
root = constants.IMDB_AIRCHANGE
)
)
elif ds_name == 'Lebedev':
configs.update( configs.update(
dict( dict(
root = constants.IMDB_LEBEDEV,
subsets = ('real',) subsets = ('real',)
) )
) )
...@@ -197,22 +248,14 @@ def single_val_ds_factory(ds_name, C): ...@@ -197,22 +248,14 @@ def single_val_ds_factory(ds_name, C):
transforms=(None, None, None), transforms=(None, None, None),
repeats=1 repeats=1
) )
if ds_name == 'OSCD':
configs.update( # Update some common configurations
dict( configs.update(_get_basic_configs(ds_name, C))
root = constants.IMDB_OSCD
) # Set phase-specific ones
) if ds_name == 'Lebedev':
elif ds_name.startswith('AC'):
configs.update(
dict(
root = constants.IMDB_AIRCHANGE
)
)
elif ds_name == 'Lebedev':
configs.update( configs.update(
dict( dict(
root = constants.IMDB_LEBEDEV,
subsets = ('real',) subsets = ('real',)
) )
) )
...@@ -243,12 +286,24 @@ def model_factory(model_names, C): ...@@ -243,12 +286,24 @@ def model_factory(model_names, C):
return single_model_factory(model_names, C) return single_model_factory(model_names, C)
def optim_factory(optim_names, params, C): def optim_factory(optim_names, models, C):
name_list = _parse_input_names(optim_names) name_list = _parse_input_names(optim_names)
if len(name_list) > 1: num_models = len(models) if isinstance(models, DuckModel) else 1
return DuckOptimizer(*(single_optim_factory(name, params, C) for name in name_list)) if len(name_list) != num_models:
raise ValueError("the number of optimizers does not match the number of models")
if num_models > 1:
optims = []
for name, model in zip(name_list, models):
param_groups = [{'params': module.parameters(), 'name': module_name} for module_name, module in model.named_children()]
optims.append(single_optim_factory(name, param_groups, C))
return DuckOptimizer(*optims)
else: else:
return single_optim_factory(optim_names, params, C) return single_optim_factory(
optim_names,
[{'params': module.parameters(), 'name': module_name} for module_name, module in models.named_children()],
C
)
def critn_factory(critn_names, C): def critn_factory(critn_names, C):
......
...@@ -33,8 +33,8 @@ class Trainer: ...@@ -33,8 +33,8 @@ class Trainer:
self.lr = float(context.lr) self.lr = float(context.lr)
self.save = context.save_on or context.out_dir self.save = context.save_on or context.out_dir
self.out_dir = context.out_dir self.out_dir = context.out_dir
self.trace_freq = context.trace_freq self.trace_freq = int(context.trace_freq)
self.device = context.device self.device = torch.device(context.device)
self.suffix_off = context.suffix_off self.suffix_off = context.suffix_off
for k, v in sorted(self.ctx.items()): for k, v in sorted(self.ctx.items()):
...@@ -44,7 +44,7 @@ class Trainer: ...@@ -44,7 +44,7 @@ class Trainer:
self.model.to(self.device) self.model.to(self.device)
self.criterion = critn_factory(criterion, context) self.criterion = critn_factory(criterion, context)
self.criterion.to(self.device) self.criterion.to(self.device)
self.optimizer = optim_factory(optimizer, self.model.parameters(), context) self.optimizer = optim_factory(optimizer, self.model, context)
self.metrics = metric_factory(context.metrics, context) self.metrics = metric_factory(context.metrics, context)
self.train_loader = data_factory(dataset, 'train', context) self.train_loader = data_factory(dataset, 'train', context)
...@@ -74,6 +74,10 @@ class Trainer: ...@@ -74,6 +74,10 @@ class Trainer:
# Train for one epoch # Train for one epoch
self.train_epoch() self.train_epoch()
# Clear the history of metric objects
for m in self.metrics:
m.reset()
# Evaluate the model on validation set # Evaluate the model on validation set
self.logger.show_nl("Validate") self.logger.show_nl("Validate")
acc = self.validate_epoch(epoch=epoch, store=self.save) acc = self.validate_epoch(epoch=epoch, store=self.save)
...@@ -255,7 +259,7 @@ class CDTrainer(Trainer): ...@@ -255,7 +259,7 @@ class CDTrainer(Trainer):
losses.update(loss.item(), n=self.batch_size) losses.update(loss.item(), n=self.batch_size)
# Convert to numpy arrays # Convert to numpy arrays
CM = to_array(torch.argmax(prob, 1)).astype('uint8') CM = to_array(torch.argmax(prob[0], 0)).astype('uint8')
label = to_array(label[0]).astype('uint8') label = to_array(label[0]).astype('uint8')
for m in self.metrics: for m in self.metrics:
m.update(CM, label) m.update(CM, label)
...@@ -272,6 +276,6 @@ class CDTrainer(Trainer): ...@@ -272,6 +276,6 @@ class CDTrainer(Trainer):
self.logger.dump(desc) self.logger.dump(desc)
if store: if store:
self.save_image(name[0], (CM*255).squeeze(-1), epoch) self.save_image(name[0], CM*255, epoch)
return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc) return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc)
\ No newline at end of file
from functools import partial
import numpy as np
from sklearn import metrics from sklearn import metrics
class AverageMeter: class AverageMeter:
def __init__(self, callback=None): def __init__(self, callback=None):
super().__init__() super().__init__()
self.callback = callback if callback is not None:
self.compute = callback
self.reset() self.reset()
def compute(self, *args): def compute(self, *args):
if self.callback is not None: if len(args) == 1:
return self.callback(*args)
elif len(args) == 1:
return args[0] return args[0]
else: else:
raise NotImplementedError raise NotImplementedError
def reset(self): def reset(self):
self.val = 0.0 self.val = 0
self.avg = 0.0 self.avg = 0
self.sum = 0.0 self.sum = 0
self.count = 0 self.count = 0
def update(self, *args, n=1): def update(self, *args, n=1):
...@@ -27,36 +29,75 @@ class AverageMeter: ...@@ -27,36 +29,75 @@ class AverageMeter:
self.count += n self.count += n
self.avg = self.sum / self.count self.avg = self.sum / self.count
def __repr__(self):
return 'val: {} avg: {} cnt: {}'.format(self.val, self.avg, self.count)
# These metrics only for numpy arrays
class Metric(AverageMeter): class Metric(AverageMeter):
__name__ = 'Metric' __name__ = 'Metric'
def __init__(self, callback, **configs): def __init__(self, n_classes=2, mode='accum', reduction='binary'):
super().__init__(callback) super().__init__(None)
self.configs = configs self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes)))
assert mode in ('accum', 'separ')
self.mode = mode
assert reduction in ('mean', 'none', 'binary')
if reduction == 'binary' and n_classes != 2:
raise ValueError("binary reduction only works in 2-class cases")
self.reduction = reduction
def _compute(self, cm):
raise NotImplementedError
def compute(self, cm):
if self.reduction == 'none':
# Do not reduce size
return self._compute(cm)
elif self.reduction == 'mean':
# Micro averaging
return self._compute(cm).mean()
else:
# The pos_class be 1
return self._compute(cm)[1]
def update(self, pred, true, n=1):
# Note that this is no thread-safe
self._cm.update(true.ravel(), pred.ravel())
if self.mode == 'accum':
cm = self._cm.sum
elif self.mode == 'separ':
cm = self._cm.val
else:
raise NotImplementedError
super().update(cm, n=n)
def compute(self, pred, true): def __repr__(self):
return self.callback(true.ravel(), pred.ravel(), **self.configs) return self.__name__+' '+super().__repr__()
class Precision(Metric): class Precision(Metric):
__name__ = 'Prec.' __name__ = 'Prec.'
def __init__(self, **configs): def _compute(self, cm):
super().__init__(metrics.precision_score, **configs) return np.nan_to_num(np.diag(cm)/cm.sum(axis=0))
class Recall(Metric): class Recall(Metric):
__name__ = 'Recall' __name__ = 'Recall'
def __init__(self, **configs): def _compute(self, cm):
super().__init__(metrics.recall_score, **configs) return np.nan_to_num(np.diag(cm)/cm.sum(axis=1))
class Accuracy(Metric): class Accuracy(Metric):
__name__ = 'OA' __name__ = 'OA'
def __init__(self, **configs): def __init__(self, n_classes=2, mode='accum'):
super().__init__(metrics.accuracy_score, **configs) super().__init__(n_classes=n_classes, mode=mode, reduction='none')
def _compute(self, cm):
return np.nan_to_num(np.diag(cm).sum()/cm.sum())
class F1Score(Metric): class F1Score(Metric):
__name__ = 'F1' __name__ = 'F1'
def __init__(self, **configs): def _compute(self, cm):
super().__init__(metrics.f1_score, **configs) prec = np.nan_to_num(np.diag(cm)/cm.sum(axis=0))
\ No newline at end of file recall = np.nan_to_num(np.diag(cm)/cm.sum(axis=1))
return np.nan_to_num(2*(prec*recall) / (prec+recall))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment