Skip to content
Snippets Groups Projects

Update outdated code

Open manli requested to merge github/fork/Bobholamovic/master into master
1 file
+ 75
44
Compare changes
  • Side-by-side
  • Inline
+ 75
44
import math
import math
import weakref
from collections import OrderedDict
import torch
import torch
import numpy as np
import numpy as np
@@ -21,61 +21,92 @@ def mod_crop(blob, N):
@@ -21,61 +21,92 @@ def mod_crop(blob, N):
return blob[..., :nh, :nw]
return blob[..., :nh, :nw]
 
class FeatureContainer:
 
r"""A simple wrapper for OrderedDict."""
 
def __init__(self):
 
self._dict = OrderedDict()
 
 
def __setitem__(self, key, val):
 
if key not in self._dict:
 
self._dict[key] = list()
 
self._dict[key].append(val)
 
 
def __getitem__(self, key):
 
return self._dict[key]
 
 
def __repr__(self):
 
return self._dict.__repr__()
 
 
def items(self):
 
return self._dict.items()
 
 
def keys(self):
 
return self._dict.keys()
 
 
def values(self):
 
return self._dict.values()
 
 
class HookHelper:
class HookHelper:
def __init__(self, model, fetch_dict, out_dict, hook_type='forward_out'):
def __init__(self, model, fetch_dict, out_dict, hook_type='forward_out'):
self.model = weakref.proxy(model)
# XXX: A HookHelper object should only be used as a context manager and should not
 
# persist in memory since it may keep references to some very large objects.
 
self.model = model
self.fetch_dict = fetch_dict
self.fetch_dict = fetch_dict
# Subclass the built-in list to make it weak referenceable
self.out_dict = out_dict
class _list(list):
pass
for entry in self.fetch_dict.values():
# entry is expected to be a string or a non-nested tuple
if isinstance(entry, tuple):
for key in entry:
out_dict[key] = _list()
else:
out_dict[entry] = _list()
self.out_dict = weakref.WeakValueDictionary(out_dict)
self._handles = []
self._handles = []
if hook_type not in ('forward_in', 'forward_out', 'backward_out'):
if hook_type not in ('forward_in', 'forward_out', 'backward'):
raise NotImplementedError("Hook type is not implemented.")
raise NotImplementedError("Hook type is not implemented.")
 
self.hook_type = hook_type
def _proto_hook(x, entry):
def __enter__(self):
# x should be a tensor or a tuple
def _proto_forward_hook(x, entry):
 
# x should be a tensor or a tuple;
 
# entry is expected to be a string or a non-nested tuple.
if isinstance(entry, tuple):
if isinstance(entry, tuple):
for key, f in zip(entry, x):
for key, f in zip(entry, x):
self.out_dict[key].append(f.detach().clone())
self.out_dict[key] = f.data.clone()
else:
else:
self.out_dict[entry].append(x.detach().clone())
self.out_dict[entry] = x.data.clone()
def _forward_in_hook(m, x, y, entry):
# x is a tuple
return _proto_hook(x[0] if len(x)==1 else x, entry)
def _forward_out_hook(m, x, y, entry):
# y is a tensor or a tuple
return _proto_hook(y, entry)
def _backward_out_hook(m, grad_in, grad_out, entry):
# grad_out is a tuple
return _proto_hook(grad_out[0] if len(grad_out)==1 else grad_out, entry)
self._hook_func, self._reg_func_name = {
if self.hook_type == 'forward_in':
'forward_in': (_forward_in_hook, 'register_forward_hook'),
# NOTE: Register forward hooks for MODULEs.
'forward_out': (_forward_out_hook, 'register_forward_hook'),
for name, module in self.model.named_modules():
'backward_out': (_backward_out_hook, 'register_backward_hook'),
if name in self.fetch_dict:
}[hook_type]
entry = self.fetch_dict[name]
self._handles.append(
def __enter__(self):
module.register_forward_hook(
for name, module in self.model.named_modules():
lambda m, x, y, entry=entry:
if name in self.fetch_dict:
# x is a tuple
entry = self.fetch_dict[name]
_proto_forward_hook(x[0] if len(x)==1 else x, entry)
self._handles.append(
)
getattr(module, self._reg_func_name)(
)
lambda *args, entry=entry: self._hook_func(*args, entry=entry)
elif self.hook_type == 'forward_out':
 
# NOTE: Register forward hooks for MODULEs.
 
for name, module in self.model.named_modules():
 
if name in self.fetch_dict:
 
entry = self.fetch_dict[name]
 
self._handles.append(
 
module.register_forward_hook(
 
lambda m, x, y, entry=entry:
 
# y is a tensor or a tuple
 
_proto_forward_hook(y, entry)
 
)
 
)
 
elif self.hook_type == 'backward':
 
# NOTE: Register backward hooks for TENSORs.
 
for name, param in self.model.named_parameters():
 
if name in self.fetch_dict:
 
entry = self.fetch_dict[name]
 
self._handles.append(
 
param.register_hook(
 
lambda grad, entry=entry:
 
_proto_forward_hook(grad, entry)
 
)
)
)
)
else:
 
raise NotImplementedError
def __exit__(self, exc_type, exc_val, ext_tb):
def __exit__(self, exc_type, exc_val, ext_tb):
for handle in self._handles:
for handle in self._handles:
Loading