Skip to content
Snippets Groups Projects
Commit a36e43ac authored by Bobholamovic's avatar Bobholamovic
Browse files

Refactor utils/metrics.py

parent eb48f74c
No related branches found
No related tags found
1 merge request!2Update outdated code
This commit is part of merge request !2. Comments created here will be created in the context of that merge request.
......@@ -53,11 +53,13 @@ class DatasetBase(data.Dataset, metaclass=ABCMeta):
raise FileNotFoundError
# phase stands for the working mode,
# 'train' for training and 'eval' for validating or testing.
assert phase in ('train', 'eval')
if phase not in ('train', 'eval'):
raise ValueError("Invalid phase")
# subset is the sub-dataset to use.
# For some datasets there are three subsets,
# while for others there are only train and test(val).
assert subset in ('train', 'val', 'test')
if subset not in ('train', 'val', 'test'):
raise ValueError("Invalid subset")
self.phase = phase
self.transforms = transforms
self.repeats = int(repeats)
......
......@@ -57,7 +57,8 @@ class Duck(Sequence, ABC):
class DuckMeta(ABCMeta):
def __new__(cls, name, bases, attrs):
assert len(bases) == 1 # Multiple inheritance is not yet supported.
if len(bases) > 1:
raise NotImplementedError("Multiple inheritance is not yet supported.")
members = dict(getmembers(bases[0])) # Trade space for time
for k in attrs['__ava__']:
......
......@@ -225,7 +225,8 @@ class _Tree:
r"""
This is different from a travasal in that this search allows early stop.
"""
assert mode in ('name', 'path', 'val')
if mode not in ('name', 'path', 'val'):
raise NotImplementedError("Invalid mode")
if mode == 'path':
nodes = self.parse_path(tar)
root = self.root
......
......@@ -61,7 +61,8 @@ class Transform:
class Compose:
def __init__(self, *tfs):
assert len(tfs) > 0
if len(tfs) == 0:
raise ValueError("The transformation sequence should contain at least one element.")
self.tfs = tfs
def __call__(self, *x):
......@@ -80,7 +81,8 @@ class Compose:
class Choose:
def __init__(self, *tfs):
assert len(tfs) > 1
if len(tfs) < 2:
raise ValueError("The transformation sequence should contain at least two elements.")
self.tfs = tfs
def __call__(self, *x):
......@@ -94,7 +96,8 @@ class Scale(Transform):
def __init__(self, scale=(0.5, 1.0), prob_apply=1.0):
super(Scale, self).__init__(rand_state=_isseq(scale), prob_apply=prob_apply)
if _isseq(scale):
assert len(scale) == 2
if len(scale) != 2:
raise ValueError
self.scale = tuple(scale)
else:
self.scale = float(scale)
......@@ -136,7 +139,8 @@ class FlipRotate(Transform):
def __init__(self, direction=None, prob_apply=1.0):
super(FlipRotate, self).__init__(rand_state=(direction is None), prob_apply=prob_apply)
if direction is not None:
assert direction in self._DIRECTIONS
if direction not in self._DIRECTIONS:
raise ValueError("Invalid direction")
self.direction = direction
def _transform(self, x, params):
......@@ -206,7 +210,8 @@ class Crop(Transform):
_no_bounds = (bounds is None)
super(Crop, self).__init__(rand_state=_no_bounds, prob_apply=prob_apply)
if _no_bounds:
assert crop_size is not None
if crop_size is None:
raise TypeError("crop_size should be specified if bounds is set to None.")
else:
if not((_isseq(bounds) and len(bounds)==4) or (isinstance(bounds, str) and bounds in self._INNER_BOUNDS)):
raise ValueError("Invalid bounds")
......@@ -237,7 +242,8 @@ class Crop(Transform):
left, top, right, lower = bounds
return x[top:lower, left:right]
else:
assert self.crop_size <= (h, w)
if self.crop_size > (h, w):
raise ValueError("Image size is smaller than cropping size.")
ch, cw = self.crop_size
if (ch,cw) == (h,w):
return x
......@@ -262,7 +268,8 @@ class CenterCrop(Transform):
ch, cw = self.crop_size
assert ch<=h and cw<=w
if ch>h or cw>w:
raise ValueError("Image size is smaller than cropping size.")
offset_up = (h-ch)//2
offset_left = (w-cw)//2
......
......@@ -8,11 +8,11 @@ class AverageMeter:
def __init__(self, callback=None, calc_avg=True):
super().__init__()
if callback is not None:
self.compute = callback
self.calc_avg = calc_avg
self.calculate = callback
self.calc_avg = bool(calc_avg)
self.reset()
def compute(self, *args):
def calculate(self, *args):
if len(args) == 1:
return args[0]
else:
......@@ -25,58 +25,59 @@ class AverageMeter:
if self.calc_avg:
self.avg = 0
for attr in filter(lambda a: not a.startswith('__'), dir(self)):
obj = getattr(self, attr)
if isinstance(obj, AverageMeter):
AverageMeter.reset(obj)
def update(self, *args, n=1):
self.val = self.compute(*args)
self.val = self.calculate(*args)
self.sum += self.val * n
self.count += n
if self.calc_avg:
self.avg = self.sum / self.count
def __repr__(self):
if self.calc_avg:
return "val: {} avg: {} cnt: {}".format(self.val, self.avg, self.count)
else:
return "val: {} cnt: {}".format(self.val, self.count)
# These metrics only for numpy arrays
class Metric(AverageMeter):
__name__ = 'Metric'
def __init__(self, n_classes=2, mode='separ', reduction='binary'):
assert mode in ('accum', 'separ')
assert reduction in ('mean', 'none', 'binary')
super().__init__(None, mode!='accum')
if mode not in ('accum', 'separ'):
raise ValueError("Invalid working mode")
if reduction not in ('mean', 'none', 'binary'):
raise ValueError("Invalid reduction type")
self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes)), False)
self.mode = mode
if reduction == 'binary' and n_classes != 2:
raise ValueError("Binary reduction only works in 2-class cases.")
self.reduction = reduction
super().__init__(None, mode!='accum')
def _compute(self, cm):
def _calculate_metric(self, cm):
raise NotImplementedError
def compute(self, cm):
def calculate(self, pred, true, n=1):
self._cm.update(true.ravel(), pred.ravel())
if self.mode == 'accum':
cm = self._cm.sum
elif self.mode == 'separ':
cm = self._cm.val
if self.reduction == 'none':
# Do not reduce size
return self._compute(cm)
return self._calculate_metric(cm)
elif self.reduction == 'mean':
# Micro averaging
return self._compute(cm).mean()
else:
return self._calculate_metric(cm).mean()
elif self.reduction == 'binary':
# The pos_class be 1
return self._compute(cm)[1]
return self._calculate_metric(cm)[1]
def update(self, pred, true, n=1):
self._cm.update(true.ravel(), pred.ravel())
if self.mode == 'accum':
cm = self._cm.sum
elif self.mode == 'separ':
cm = self._cm.val
else:
raise NotImplementedError
super().update(cm, n=n)
def reset(self):
super().reset()
# Reset the confusion matrix
self._cm.reset()
def __repr__(self):
return self.__name__+" "+super().__repr__()
......@@ -84,13 +85,13 @@ class Metric(AverageMeter):
class Precision(Metric):
__name__ = 'Prec.'
def _compute(self, cm):
def _calculate_metric(self, cm):
return np.nan_to_num(np.diag(cm)/cm.sum(axis=0))
class Recall(Metric):
__name__ = 'Recall'
def _compute(self, cm):
def _calculate_metric(self, cm):
return np.nan_to_num(np.diag(cm)/cm.sum(axis=1))
......@@ -98,13 +99,14 @@ class Accuracy(Metric):
__name__ = 'OA'
def __init__(self, n_classes=2, mode='separ'):
super().__init__(n_classes=n_classes, mode=mode, reduction='none')
def _compute(self, cm):
def _calculate_metric(self, cm):
return np.nan_to_num(np.diag(cm).sum()/cm.sum())
class F1Score(Metric):
__name__ = 'F1'
def _compute(self, cm):
def _calculate_metric(self, cm):
prec = np.nan_to_num(np.diag(cm)/cm.sum(axis=0))
recall = np.nan_to_num(np.diag(cm)/cm.sum(axis=1))
return np.nan_to_num(2*(prec*recall) / (prec+recall))
\ No newline at end of file
......@@ -38,7 +38,8 @@ class HookHelper:
self.out_dict = weakref.WeakValueDictionary(out_dict)
self._handles = []
assert hook_type in ('forward_in', 'forward_out', 'backward_out')
if hook_type not in ('forward_in', 'forward_out', 'backward_out'):
raise NotImplementedError("Hook type is not implemented.")
def _proto_hook(x, entry):
# x should be a tensor or a tuple
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment