diff --git a/src/core/trainer.py b/src/core/trainer.py
index ae87413df97ab40c08bec14e4aa8152b5a17a92f..c1a261188d49925dd56a362581e3fe3c86620506 100644
--- a/src/core/trainer.py
+++ b/src/core/trainer.py
@@ -159,11 +159,15 @@ class Trainer(metaclass=ABCMeta):
                     return False
                 else:
                     self.logger.warn("=> {} params are to be loaded.".format(num_to_update))
-        elif not self.ctx['anew'] or not self.is_training:
+            ckp_epoch = -1
+        else:
             ckp_epoch = checkpoint.get('epoch', -1)
-            self.start_epoch = ckp_epoch+1
             self._init_acc_epoch = checkpoint.get('max_acc', (0.0, ckp_epoch))
-            if self.ctx['load_optim'] and self.is_training:
+            if not self.is_training:
+                self.start_epoch = ckp_epoch
+            elif not self.ctx['anew']:
+                self.start_epoch = ckp_epoch+1
+            if self.ctx['load_optim']:
                 # XXX: Note that weight decay might be modified here.
                 self.optimizer.load_state_dict(checkpoint['optimizer'])
                 self.logger.warn("Weight decay might have been modified.")
@@ -171,11 +175,11 @@ class Trainer(metaclass=ABCMeta):
         state_dict.update(update_dict)
         self.model.load_state_dict(state_dict)
 
-        if self.start_epoch == 0:
+        if ckp_epoch == -1:
             self.logger.show("=> Loaded checkpoint '{}'".format(self.checkpoint))
         else:
             self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {}).".format(
-                self.checkpoint, self.start_epoch-1, *self._init_acc_epoch
+                self.checkpoint, ckp_epoch, *self._init_acc_epoch
             ))
 
         return True
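
For reference, the epoch bookkeeping introduced by this patch can be summarized with a small, self-contained sketch. Everything here is hypothetical scaffolding (the resolve_start_epoch helper, the full_match flag standing in for the mismatched-checkpoint path, and the assumed default start_epoch of 0); only the branch logic mirrors the patched Trainer code above.

# Hypothetical sketch of the resume semantics after this patch. The Trainer
# internals (self.ctx, self.is_training, self.logger, ...) are stubbed out;
# only the ckp_epoch / start_epoch branching from the diff is reproduced.

def resolve_start_epoch(checkpoint, is_training, anew, full_match=True):
    """Return (ckp_epoch, start_epoch) as the patched logic would set them."""
    start_epoch = 0  # assumed default before any checkpoint is loaded
    if not full_match:
        # Mismatched checkpoint: parameters are loaded selectively and the
        # checkpoint epoch is treated as unknown.
        ckp_epoch = -1
    else:
        ckp_epoch = checkpoint.get('epoch', -1)
        if not is_training:
            # Evaluation: keep the checkpoint's own epoch for reporting.
            start_epoch = ckp_epoch
        elif not anew:
            # Resumed training continues from the epoch after the checkpoint.
            start_epoch = ckp_epoch + 1
    return ckp_epoch, start_epoch

# Resumed training from an epoch-10 checkpoint starts at epoch 11.
assert resolve_start_epoch({'epoch': 10}, is_training=True, anew=False) == (10, 11)
# Training anew leaves start_epoch at its default.
assert resolve_start_epoch({'epoch': 10}, is_training=True, anew=True) == (10, 0)
# Evaluation no longer bumps the epoch, fixing the off-by-one in the old log message.
assert resolve_start_epoch({'epoch': 10}, is_training=False, anew=False) == (10, 10)
# A mismatched checkpoint yields ckp_epoch == -1, so the short log message is used.
assert resolve_start_epoch({'epoch': 10}, True, False, full_match=False) == (-1, 0)

Note the design change the diff encodes: the "loaded checkpoint" message is now keyed on ckp_epoch rather than start_epoch, so the reported epoch stays correct in evaluation mode, where start_epoch is no longer advanced past the checkpoint epoch.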