Skip to content
Snippets Groups Projects
Commit 9126a3ce authored by Bobholamovic's avatar Bobholamovic
Browse files

Add option to save optim state

parent 39d0b776
Branches
No related tags found
1 merge request!2Update outdated code
...@@ -22,6 +22,7 @@ batch_size: 32 ...@@ -22,6 +22,7 @@ batch_size: 32
num_epochs: 10 num_epochs: 10
resume: '' resume: ''
load_optim: True load_optim: True
save_optim: True
anew: False anew: False
trace_freq: 1 trace_freq: 1
device: cuda device: cuda
......
...@@ -22,6 +22,7 @@ batch_size: 32 ...@@ -22,6 +22,7 @@ batch_size: 32
num_epochs: 10 num_epochs: 10
resume: '' resume: ''
load_optim: True load_optim: True
save_optim: True
anew: False anew: False
trace_freq: 1 trace_freq: 1
device: cuda device: cuda
......
...@@ -22,6 +22,7 @@ batch_size: 32 ...@@ -22,6 +22,7 @@ batch_size: 32
num_epochs: 10 num_epochs: 10
resume: '' resume: ''
load_optim: True load_optim: True
save_optim: True
anew: False anew: False
trace_freq: 1 trace_freq: 1
device: cuda device: cuda
......
...@@ -22,6 +22,7 @@ batch_size: 8 ...@@ -22,6 +22,7 @@ batch_size: 8
num_epochs: 15 num_epochs: 15
resume: '' resume: ''
load_optim: True load_optim: True
save_optim: True
anew: False anew: False
trace_freq: 1 trace_freq: 1
device: cuda device: cuda
......
...@@ -22,6 +22,7 @@ batch_size: 32 ...@@ -22,6 +22,7 @@ batch_size: 32
num_epochs: 10 num_epochs: 10
resume: '' resume: ''
load_optim: True load_optim: True
save_optim: True
anew: False anew: False
trace_freq: 1 trace_freq: 1
device: cuda device: cuda
......
...@@ -22,6 +22,7 @@ batch_size: 32 ...@@ -22,6 +22,7 @@ batch_size: 32
num_epochs: 10 num_epochs: 10
resume: '' resume: ''
load_optim: True load_optim: True
save_optim: True
anew: False anew: False
trace_freq: 1 trace_freq: 1
device: cuda device: cuda
......
...@@ -22,6 +22,7 @@ batch_size: 32 ...@@ -22,6 +22,7 @@ batch_size: 32
num_epochs: 10 num_epochs: 10
resume: '' resume: ''
load_optim: True load_optim: True
save_optim: True
anew: False anew: False
trace_freq: 1 trace_freq: 1
device: cuda device: cuda
......
...@@ -22,6 +22,7 @@ batch_size: 32 ...@@ -22,6 +22,7 @@ batch_size: 32
num_epochs: 10 num_epochs: 10
resume: '' resume: ''
load_optim: True load_optim: True
save_optim: True
anew: False anew: False
trace_freq: 1 trace_freq: 1
device: cuda device: cuda
......
...@@ -22,6 +22,7 @@ batch_size: 32 ...@@ -22,6 +22,7 @@ batch_size: 32
num_epochs: 10 num_epochs: 10
resume: '' resume: ''
load_optim: True load_optim: True
save_optim: True
anew: False anew: False
trace_freq: 1 trace_freq: 1
device: cuda device: cuda
......
...@@ -22,6 +22,7 @@ batch_size: 32 ...@@ -22,6 +22,7 @@ batch_size: 32
num_epochs: 10 num_epochs: 10
resume: '' resume: ''
load_optim: True load_optim: True
save_optim: True
anew: False anew: False
trace_freq: 1 trace_freq: 1
device: cuda device: cuda
......
...@@ -106,7 +106,11 @@ class Trainer: ...@@ -106,7 +106,11 @@ class Trainer:
acc, epoch, max_acc, best_epoch)) acc, epoch, max_acc, best_epoch))
# The checkpoint saves next epoch # The checkpoint saves next epoch
self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), (max_acc, best_epoch), epoch+1, is_best) self._save_checkpoint(
self.model.state_dict(),
self.optimizer.state_dict() if self.ctx['save_optim'] else {},
(max_acc, best_epoch), epoch+1, is_best
)
def evaluate(self): def evaluate(self):
if self.checkpoint: if self.checkpoint:
...@@ -164,11 +168,8 @@ class Trainer: ...@@ -164,11 +168,8 @@ class Trainer:
else: else:
self._init_max_acc_and_epoch = max_acc_and_epoch self._init_max_acc_and_epoch = max_acc_and_epoch
if self.ctx['load_optim'] and self.is_training: if self.ctx['load_optim'] and self.is_training:
try:
# Note that weight decay might be modified here # Note that weight decay might be modified here
self.optimizer.load_state_dict(checkpoint['optimizer']) self.optimizer.load_state_dict(checkpoint['optimizer'])
except KeyError:
self.logger.warning("Warning: failed to load optimizer parameters.")
state_dict.update(update_dict) state_dict.update(update_dict)
self.model.load_state_dict(state_dict) self.model.load_state_dict(state_dict)
......
...@@ -62,6 +62,7 @@ def parse_args(): ...@@ -62,6 +62,7 @@ def parse_args():
group_train.add_argument('--num-epochs', type=int, default=1000, metavar='NE', group_train.add_argument('--num-epochs', type=int, default=1000, metavar='NE',
help='number of epochs to train (default: %(default)s)') help='number of epochs to train (default: %(default)s)')
group_train.add_argument('--load-optim', action='store_true') group_train.add_argument('--load-optim', action='store_true')
group_train.add_argument('--save-optim', action='store_true')
group_train.add_argument('--resume', default='', type=str, metavar='PATH', group_train.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint') help='path to latest checkpoint')
group_train.add_argument('--anew', action='store_true', group_train.add_argument('--anew', action='store_true',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment