class FeatureContainer:
    r"""An insertion-ordered mapping that accumulates values per key.

    Unlike a plain dict, ``container[key] = val`` never overwrites: every
    assignment appends ``val`` to the list stored under ``key`` (creating the
    list on first use). Reading ``container[key]`` returns that list and
    raises ``KeyError`` for a key that was never assigned.
    """
    def __init__(self):
        self._dict = OrderedDict()

    def __setitem__(self, key, val):
        # Append instead of replace, so repeated writes to one key are all kept.
        self._dict.setdefault(key, []).append(val)

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        # Without this, ``key in container`` falls back to calling
        # ``__getitem__(0)`` and fails with ``KeyError: 0``.
        return key in self._dict

    def __iter__(self):
        # Iterate over keys in insertion order, like a dict does.
        return iter(self._dict)

    def __repr__(self):
        return repr(self._dict)

    def items(self):
        return self._dict.items()

    def keys(self):
        return self._dict.keys()

    def values(self):
        return self._dict.values()
class HookHelper:
    """Context manager that copies intermediate tensors out of a model.

    On entry, hooks are attached to the entries of ``model`` that are named
    in ``fetch_dict``. Each time a hook fires, a detached copy of the observed
    tensor is assigned into ``out_dict``. On exit every hook is removed.

    Args:
        model: Object exposing ``named_modules()`` and ``named_parameters()``.
        fetch_dict: Maps a module name (forward hooks) or a parameter name
            (``'backward'``) to the key under which the tensor is stored.
            A value may be a non-nested tuple of keys when the hooked value
            is a tuple of tensors; keys and tensors are paired positionally.
            Names that match nothing in the model are skipped silently.
        out_dict: Receives the copies through item assignment, so its
            ``__setitem__`` decides whether values are replaced (plain dict)
            or accumulated (``FeatureContainer``).
        hook_type: ``'forward_in'`` stores a module's input, ``'forward_out'``
            its output, ``'backward'`` the gradient of a parameter.

    Raises:
        NotImplementedError: If ``hook_type`` is not one of the above.
    """

    _FORWARD_HOOK_TYPES = ('forward_in', 'forward_out')
    _HOOK_TYPES = _FORWARD_HOOK_TYPES + ('backward',)

    def __init__(self, model, fetch_dict, out_dict, hook_type='forward_out'):
        if hook_type not in self._HOOK_TYPES:
            raise NotImplementedError("Hook type is not implemented.")
        # XXX: A HookHelper object should only be used as a context manager and should not
        # persist in memory since it may keep references to some very large objects.
        self.model = model
        self.fetch_dict = fetch_dict
        self.out_dict = out_dict
        self.hook_type = hook_type
        self._handles = []

    def __enter__(self):
        """Register the hooks and return ``self`` so ``with ... as h`` works."""
        try:
            if self.hook_type in self._FORWARD_HOOK_TYPES:
                self._register_forward_hooks(
                    use_input=(self.hook_type == 'forward_in')
                )
            elif self.hook_type == 'backward':
                self._register_grad_hooks()
            else:
                # Only reachable if hook_type was reassigned after construction.
                raise NotImplementedError("Hook type is not implemented.")
        except Exception:
            # __exit__ is not called when __enter__ raises, so undo any hooks
            # that were already attached before propagating the error.
            self._remove_handles()
            raise
        return self

    def __exit__(self, exc_type, exc_val, ext_tb):
        """Remove every registered hook; exceptions are not suppressed."""
        self._remove_handles()

    def _store(self, value, entry):
        """Assign detached copies of ``value`` into ``out_dict`` under ``entry``."""
        if isinstance(entry, tuple):
            for key, item in zip(entry, value):
                self.out_dict[key] = item.detach().clone()
        else:
            self.out_dict[entry] = value.detach().clone()

    def _register_forward_hooks(self, use_input):
        """Attach forward hooks to the MODULEs named in ``fetch_dict``."""
        def make_hook(entry):
            def hook(module, inputs, output):
                if use_input:
                    # inputs is always a tuple; unwrap the single-input case.
                    value = inputs[0] if len(inputs) == 1 else inputs
                else:
                    # output is a tensor or a tuple.
                    value = output
                # Returns None on purpose so the module output is unchanged.
                self._store(value, entry)
            return hook

        for name, module in self.model.named_modules():
            if name in self.fetch_dict:
                self._handles.append(
                    module.register_forward_hook(make_hook(self.fetch_dict[name]))
                )

    def _register_grad_hooks(self):
        """Attach gradient hooks to the parameter TENSORs named in ``fetch_dict``."""
        def make_hook(entry):
            def hook(grad):
                # Returns None on purpose so the gradient is unchanged.
                self._store(grad, entry)
            return hook

        for name, param in self.model.named_parameters():
            if name in self.fetch_dict:
                self._handles.append(
                    param.register_hook(make_hook(self.fetch_dict[name]))
                )

    def _remove_handles(self):
        """Remove all hooks and forget their handles so the helper can be reused."""
        while self._handles:
            self._handles.pop().remove()