diff --git a/src/core/data.py b/src/core/data.py index b1fcbcde851b31e441a21404260db3e208c647db..13a2ef4b30f0aa9590c0b7701ad56772eb467849 100644 --- a/src/core/data.py +++ b/src/core/data.py @@ -53,11 +53,13 @@ class DatasetBase(data.Dataset, metaclass=ABCMeta): raise FileNotFoundError # phase stands for the working mode, # 'train' for training and 'eval' for validating or testing. - assert phase in ('train', 'eval') + if phase not in ('train', 'eval'): + raise ValueError("Invalid phase") # subset is the sub-dataset to use. # For some datasets there are three subsets, # while for others there are only train and test(val). - assert subset in ('train', 'val', 'test') + if subset not in ('train', 'val', 'test'): + raise ValueError("Invalid subset") self.phase = phase self.transforms = transforms self.repeats = int(repeats) diff --git a/src/core/factories.py b/src/core/factories.py index c39e4268a448f5a590cf49d6cc4604df28452856..a0ff531d013c06e98c27f2d3be86b8646aae0492 100644 --- a/src/core/factories.py +++ b/src/core/factories.py @@ -57,7 +57,8 @@ class Duck(Sequence, ABC): class DuckMeta(ABCMeta): def __new__(cls, name, bases, attrs): - assert len(bases) == 1 # Multiple inheritance is not yet supported. + if len(bases) > 1: + raise NotImplementedError("Multiple inheritance is not yet supported.") members = dict(getmembers(bases[0])) # Trade space for time for k in attrs['__ava__']: diff --git a/src/core/misc.py b/src/core/misc.py index 20122ae9fa123695878477e01b261088c317d312..ae4b4287cbd241749e8c3fc3f1f2981ebff1f74d 100644 --- a/src/core/misc.py +++ b/src/core/misc.py @@ -225,7 +225,8 @@ class _Tree: r""" This is different from a travasal in that this search allows early stop. """ - assert mode in ('name', 'path', 'val') + if mode not in ('name', 'path', 'val'): + raise NotImplementedError("Invalid mode") if mode == 'path': nodes = self.parse_path(tar) root = self.root diff --git a/src/data/augmentations.py b/src/data/augmentations.py index 473372326f94ef3ef1a01ebbf6d9aea32cd9e92f..826182944dea5bc6dc6d4a7cc48d151269177c46 100644 --- a/src/data/augmentations.py +++ b/src/data/augmentations.py @@ -61,7 +61,8 @@ class Transform: class Compose: def __init__(self, *tfs): - assert len(tfs) > 0 + if len(tfs) == 0: + raise ValueError("The transformation sequence should contain at least one element.") self.tfs = tfs def __call__(self, *x): @@ -80,7 +81,8 @@ class Compose: class Choose: def __init__(self, *tfs): - assert len(tfs) > 1 + if len(tfs) < 2: + raise ValueError("The transformation sequence should contain at least two elements.") self.tfs = tfs def __call__(self, *x): @@ -94,7 +96,8 @@ class Scale(Transform): def __init__(self, scale=(0.5, 1.0), prob_apply=1.0): super(Scale, self).__init__(rand_state=_isseq(scale), prob_apply=prob_apply) if _isseq(scale): - assert len(scale) == 2 + if len(scale) != 2: + raise ValueError self.scale = tuple(scale) else: self.scale = float(scale) @@ -136,7 +139,8 @@ class FlipRotate(Transform): def __init__(self, direction=None, prob_apply=1.0): super(FlipRotate, self).__init__(rand_state=(direction is None), prob_apply=prob_apply) if direction is not None: - assert direction in self._DIRECTIONS + if direction not in self._DIRECTIONS: + raise ValueError("Invalid direction") self.direction = direction def _transform(self, x, params): @@ -206,7 +210,8 @@ class Crop(Transform): _no_bounds = (bounds is None) super(Crop, self).__init__(rand_state=_no_bounds, prob_apply=prob_apply) if _no_bounds: - assert crop_size is not None + if crop_size is None: + raise TypeError("crop_size should be specified if bounds is set to None.") else: if not((_isseq(bounds) and len(bounds)==4) or (isinstance(bounds, str) and bounds in self._INNER_BOUNDS)): raise ValueError("Invalid bounds") @@ -237,7 +242,8 @@ class Crop(Transform): left, top, right, lower = bounds return x[top:lower, left:right] else: - assert self.crop_size <= (h, w) + if self.crop_size > (h, w): + raise ValueError("Image size is smaller than cropping size.") ch, cw = self.crop_size if (ch,cw) == (h,w): return x @@ -262,7 +268,8 @@ class CenterCrop(Transform): ch, cw = self.crop_size - assert ch<=h and cw<=w + if ch>h or cw>w: + raise ValueError("Image size is smaller than cropping size.") offset_up = (h-ch)//2 offset_left = (w-cw)//2 diff --git a/src/utils/metrics.py b/src/utils/metrics.py index f4668b254334a30d53ff92f7a2a8e280cbb709fe..291151e62ff723f8f7524bbef94a0288e45fb7b9 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -8,11 +8,11 @@ class AverageMeter: def __init__(self, callback=None, calc_avg=True): super().__init__() if callback is not None: - self.compute = callback - self.calc_avg = calc_avg + self.calculate = callback + self.calc_avg = bool(calc_avg) self.reset() - def compute(self, *args): + def calculate(self, *args): if len(args) == 1: return args[0] else: @@ -25,58 +25,59 @@ class AverageMeter: if self.calc_avg: self.avg = 0 - for attr in filter(lambda a: not a.startswith('__'), dir(self)): - obj = getattr(self, attr) - if isinstance(obj, AverageMeter): - AverageMeter.reset(obj) - def update(self, *args, n=1): - self.val = self.compute(*args) + self.val = self.calculate(*args) self.sum += self.val * n self.count += n if self.calc_avg: self.avg = self.sum / self.count def __repr__(self): - return "val: {} avg: {} cnt: {}".format(self.val, self.avg, self.count) + if self.calc_avg: + return "val: {} avg: {} cnt: {}".format(self.val, self.avg, self.count) + else: + return "val: {} cnt: {}".format(self.val, self.count) # These metrics only for numpy arrays class Metric(AverageMeter): __name__ = 'Metric' def __init__(self, n_classes=2, mode='separ', reduction='binary'): - assert mode in ('accum', 'separ') - assert reduction in ('mean', 'none', 'binary') - super().__init__(None, mode!='accum') + if mode not in ('accum', 'separ'): + raise ValueError("Invalid working mode") + if reduction not in ('mean', 'none', 'binary'): + raise ValueError("Invalid reduction type") self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes)), False) self.mode = mode if reduction == 'binary' and n_classes != 2: raise ValueError("Binary reduction only works in 2-class cases.") self.reduction = reduction + super().__init__(None, mode!='accum') - def _compute(self, cm): + def _calculate_metric(self, cm): raise NotImplementedError - def compute(self, cm): + def calculate(self, pred, true, n=1): + self._cm.update(true.ravel(), pred.ravel()) + if self.mode == 'accum': + cm = self._cm.sum + elif self.mode == 'separ': + cm = self._cm.val + if self.reduction == 'none': # Do not reduce size - return self._compute(cm) + return self._calculate_metric(cm) elif self.reduction == 'mean': # Micro averaging - return self._compute(cm).mean() - else: + return self._calculate_metric(cm).mean() + elif self.reduction == 'binary': # The pos_class be 1 - return self._compute(cm)[1] + return self._calculate_metric(cm)[1] - def update(self, pred, true, n=1): - self._cm.update(true.ravel(), pred.ravel()) - if self.mode == 'accum': - cm = self._cm.sum - elif self.mode == 'separ': - cm = self._cm.val - else: - raise NotImplementedError - super().update(cm, n=n) + def reset(self): + super().reset() + # Reset the confusion matrix + self._cm.reset() def __repr__(self): return self.__name__+" "+super().__repr__() @@ -84,13 +85,13 @@ class Metric(AverageMeter): class Precision(Metric): __name__ = 'Prec.' - def _compute(self, cm): + def _calculate_metric(self, cm): return np.nan_to_num(np.diag(cm)/cm.sum(axis=0)) class Recall(Metric): __name__ = 'Recall' - def _compute(self, cm): + def _calculate_metric(self, cm): return np.nan_to_num(np.diag(cm)/cm.sum(axis=1)) @@ -98,13 +99,14 @@ class Accuracy(Metric): __name__ = 'OA' def __init__(self, n_classes=2, mode='separ'): super().__init__(n_classes=n_classes, mode=mode, reduction='none') - def _compute(self, cm): + + def _calculate_metric(self, cm): return np.nan_to_num(np.diag(cm).sum()/cm.sum()) class F1Score(Metric): __name__ = 'F1' - def _compute(self, cm): + def _calculate_metric(self, cm): prec = np.nan_to_num(np.diag(cm)/cm.sum(axis=0)) recall = np.nan_to_num(np.diag(cm)/cm.sum(axis=1)) return np.nan_to_num(2*(prec*recall) / (prec+recall)) \ No newline at end of file diff --git a/src/utils/utils.py b/src/utils/utils.py index eb08225cd052702a012c2c663488af88b52691ec..8e0a7c77880c4f6993f44861ee7b1bdf5249ecc7 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -19,7 +19,7 @@ def mod_crop(blob, N): nh = h - h % N nw = w - w % N return blob[..., :nh, :nw] - + class HookHelper: def __init__(self, model, fetch_dict, out_dict, hook_type='forward_out'): @@ -38,7 +38,8 @@ class HookHelper: self.out_dict = weakref.WeakValueDictionary(out_dict) self._handles = [] - assert hook_type in ('forward_in', 'forward_out', 'backward_out') + if hook_type not in ('forward_in', 'forward_out', 'backward_out'): + raise NotImplementedError("Hook type is not implemented.") def _proto_hook(x, entry): # x should be a tensor or a tuple