Skip to content
Snippets Groups Projects
Commit 39d0b776 authored by Bobholamovic's avatar Bobholamovic
Browse files

Fix direct param not updated by optimizer

parent 31e99e9f
No related branches found
No related tags found
1 merge request!2Update outdated code
This commit is part of merge request !2. Comments created here will be created in the context of that merge request.
...@@ -274,12 +274,16 @@ def optim_factory(optim_names, models, C): ...@@ -274,12 +274,16 @@ def optim_factory(optim_names, models, C):
optims = [] optims = []
for name, model in zip(name_list, models): for name, model in zip(name_list, models):
param_groups = [{'params': module.parameters(), 'name': module_name} for module_name, module in model.named_children()] param_groups = [{'params': module.parameters(), 'name': module_name} for module_name, module in model.named_children()]
if next(model.parameters(recurse=False), None) is not None:
param_groups.append({'params': model.parameters(recurse=False), 'name': '_direct'})
optims.append(single_optim_factory(name, param_groups, C)) optims.append(single_optim_factory(name, param_groups, C))
return DuckOptimizer(*optims) return DuckOptimizer(*optims)
else: else:
return single_optim_factory( return single_optim_factory(
optim_names, optim_names,
[{'params': module.parameters(), 'name': module_name} for module_name, module in models.named_children()], [{'params': module.parameters(), 'name': module_name} for module_name, module in models.named_children()] +
([{'params': models.parameters(recurse=False), 'name': '_direct'}]
if next(models.parameters(recurse=False), None) is not None else []),
C C
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment