diff --git a/src/core/factories.py b/src/core/factories.py index 98963cefdb0df754e418ceda650d1f7305e89b9b..2e3cc1ceb75ccf0397bf0a2bc955348860915fcc 100644 --- a/src/core/factories.py +++ b/src/core/factories.py @@ -274,12 +274,16 @@ def optim_factory(optim_names, models, C): optims = [] for name, model in zip(name_list, models): param_groups = [{'params': module.parameters(), 'name': module_name} for module_name, module in model.named_children()] + if next(model.parameters(recurse=False), None) is not None: + param_groups.append({'params': model.parameters(recurse=False), 'name': '_direct'}) optims.append(single_optim_factory(name, param_groups, C)) return DuckOptimizer(*optims) else: return single_optim_factory( optim_names, - [{'params': module.parameters(), 'name': module_name} for module_name, module in models.named_children()], + [{'params': module.parameters(), 'name': module_name} for module_name, module in models.named_children()] + + ([{'params': models.parameters(recurse=False), 'name': '_direct'}] + if next(models.parameters(recurse=False), None) is not None else []), C )