From 39d0b77640251452e48b6719857e869c4f92121e Mon Sep 17 00:00:00 2001 From: Bobholamovic <bob1998425@hotmail.com> Date: Sat, 2 May 2020 10:34:07 +0800 Subject: [PATCH] Fix direct param not updated by optimizer --- src/core/factories.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/core/factories.py b/src/core/factories.py index 98963ce..2e3cc1c 100644 --- a/src/core/factories.py +++ b/src/core/factories.py @@ -274,12 +274,16 @@ def optim_factory(optim_names, models, C): optims = [] for name, model in zip(name_list, models): param_groups = [{'params': module.parameters(), 'name': module_name} for module_name, module in model.named_children()] + if next(model.parameters(recurse=False), None) is not None: + param_groups.append({'params': model.parameters(recurse=False), 'name': '_direct'}) optims.append(single_optim_factory(name, param_groups, C)) return DuckOptimizer(*optims) else: return single_optim_factory( optim_names, - [{'params': module.parameters(), 'name': module_name} for module_name, module in models.named_children()], + [{'params': module.parameters(), 'name': module_name} for module_name, module in models.named_children()] + + ([{'params': models.parameters(recurse=False), 'name': '_direct'}] + if next(models.parameters(recurse=False), None) is not None else []), C ) -- GitLab