Skip to content
Snippets Groups Projects
Commit 3a95fb80 authored by Bobholamovic's avatar Bobholamovic
Browse files

Refactor framework

parent f68ed034
Branches
No related tags found
1 merge request!2Update outdated code
......@@ -53,16 +53,16 @@ class DatasetBase(data.Dataset, metaclass=ABCMeta):
raise FileNotFoundError
# phase stands for the working mode,
# 'train' for training and 'eval' for validating or testing.
if phase not in ('train', 'eval'):
raise ValueError("Invalid phase")
# if phase not in ('train', 'eval'):
# raise ValueError("Invalid phase")
# subset is the sub-dataset to use.
# For some datasets there are three subsets,
# while for others there are only train and test(val).
if subset not in ('train', 'val', 'test'):
raise ValueError("Invalid subset")
# if subset not in ('train', 'val', 'test'):
# raise ValueError("Invalid subset")
self.phase = phase
self.transforms = transforms
self.repeats = int(repeats)
self.repeats = repeats
# Use 'train' subset during training.
self.subset = 'train' if self.phase == 'train' else subset
......
......@@ -31,8 +31,8 @@ def _isseq(x): return isinstance(x, (tuple, list))
class Transform:
def __init__(self, rand_state=False, prob_apply=1.0):
self._rand_state = bool(rand_state)
self.prob_apply = float(prob_apply)
self._rand_state = rand_state
self.prob_apply = prob_apply
def _transform(self, x, params):
raise NotImplementedError
......@@ -100,7 +100,7 @@ class Scale(Transform):
raise ValueError
self.scale = tuple(scale)
else:
self.scale = float(scale)
self.scale = scale
def _transform(self, x, params):
if self._rand_state:
......@@ -138,10 +138,7 @@ class FlipRotate(Transform):
_DIRECTIONS = ('ud', 'lr', '90', '180', '270')
def __init__(self, direction=None, prob_apply=1.0):
super(FlipRotate, self).__init__(rand_state=(direction is None), prob_apply=prob_apply)
if direction is not None:
if direction not in self._DIRECTIONS:
raise ValueError("Invalid direction")
self.direction = direction
self.direction = direction
def _transform(self, x, params):
if self._rand_state:
......@@ -212,9 +209,6 @@ class Crop(Transform):
if _no_bounds:
if crop_size is None:
raise TypeError("crop_size should be specified if bounds is set to None.")
else:
if not((_isseq(bounds) and len(bounds)==4) or (isinstance(bounds, str) and bounds in self._INNER_BOUNDS)):
raise ValueError("Invalid bounds")
self.bounds = bounds
self.crop_size = crop_size if _isseq(crop_size) else (crop_size, crop_size)
......@@ -287,12 +281,12 @@ class Shift(Transform):
if _isseq(xshift):
self.xshift = tuple(xshift)
else:
self.xshift = float(xshift)
self.xshift = xshift
if _isseq(yshift):
self.yshift = tuple(yshift)
else:
self.yshift = float(yshift)
self.yshift = yshift
self.circular = circular
......@@ -368,12 +362,12 @@ class ContrastBrightScale(_ValueTransform):
if _isseq(alpha):
self.alpha = tuple(alpha)
else:
self.alpha = float(alpha)
self.alpha = alpha
if _isseq(beta):
self.beta = tuple(beta)
else:
self.beta = float(beta)
self.beta = beta
@_ValueTransform.keep_range
def _transform(self, x, params):
......@@ -406,8 +400,8 @@ class BrightnessScale(ContrastBrightScale):
class AddGaussNoise(_ValueTransform):
def __init__(self, mu=0.0, sigma=0.1, prob_apply=1.0, limit=(0, 255)):
super().__init__(True, prob_apply, limit)
self.mu = float(mu)
self.sigma = float(sigma)
self.mu = mu
self.sigma = sigma
@_ValueTransform.keep_range
def _transform(self, x, params):
......
......@@ -58,7 +58,7 @@ class CDTrainer(Trainer):
self.out_dir = self.ctx['out_dir']
self.save = (self.ctx['save_on'] or self.out_dir) and not self.debug
self.val_iters = float(self.ctx['val_iters'])
self.val_iters = self.ctx['val_iters']
def init_learning_rate(self):
# Set learning rate adjustment strategy
......
......@@ -9,7 +9,7 @@ class AverageMeter:
super().__init__()
if callback is not None:
self.calculate = callback
self.calc_avg = bool(calc_avg)
self.calc_avg = calc_avg
self.reset()
def calculate(self, *args):
......@@ -43,10 +43,6 @@ class AverageMeter:
class Metric(AverageMeter):
__name__ = 'Metric'
def __init__(self, n_classes=2, mode='separ', reduction='binary'):
if mode not in ('accum', 'separ'):
raise ValueError("Invalid working mode")
if reduction not in ('mean', 'none', 'binary'):
raise ValueError("Invalid reduction type")
self._cm = AverageMeter(partial(metrics.confusion_matrix, labels=np.arange(n_classes)), False)
self.mode = mode
if reduction == 'binary' and n_classes != 2:
......@@ -63,6 +59,8 @@ class Metric(AverageMeter):
cm = self._cm.sum
elif self.mode == 'separ':
cm = self._cm.val
else:
raise ValueError("Invalid working mode")
if self.reduction == 'none':
# Do not reduce size
......@@ -73,6 +71,8 @@ class Metric(AverageMeter):
elif self.reduction == 'binary':
# The pos_class be 1
return self._calculate_metric(cm)[1]
else:
raise ValueError("Invalid reduction type")
def reset(self):
super().reset()
......
......@@ -55,9 +55,6 @@ class HookHelper:
self.fetch_dict = fetch_dict
self.out_dict = out_dict
self._handles = []
if hook_type not in ('forward_in', 'forward_out', 'backward'):
raise NotImplementedError("Hook type is not implemented.")
self.hook_type = hook_type
def __enter__(self):
......@@ -106,7 +103,7 @@ class HookHelper:
)
)
else:
raise NotImplementedError
raise NotImplementedError("Hook type is not implemented.")
def __exit__(self, exc_type, exc_val, ext_tb):
for handle in self._handles:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment