From 878f99ca655b6f0a68ac2be17a31c87524f52190 Mon Sep 17 00:00:00 2001 From: Bobholamovic <bob1998425@hotmail.com> Date: Wed, 16 Dec 2020 16:29:06 +0800 Subject: [PATCH] Fix eval epoch --- src/core/trainer.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/core/trainer.py b/src/core/trainer.py index ae87413..c1a2611 100644 --- a/src/core/trainer.py +++ b/src/core/trainer.py @@ -159,11 +159,15 @@ class Trainer(metaclass=ABCMeta): return False else: self.logger.warn("=> {} params are to be loaded.".format(num_to_update)) - elif not self.ctx['anew'] or not self.is_training: + ckp_epoch = -1 + else: ckp_epoch = checkpoint.get('epoch', -1) - self.start_epoch = ckp_epoch+1 self._init_acc_epoch = checkpoint.get('max_acc', (0.0, ckp_epoch)) - if self.ctx['load_optim'] and self.is_training: + if not self.is_training: + self.start_epoch = ckp_epoch + elif not self.ctx['anew']: + self.start_epoch = ckp_epoch+1 + if self.ctx['load_optim']: # XXX: Note that weight decay might be modified here. self.optimizer.load_state_dict(checkpoint['optimizer']) self.logger.warn("Weight decay might have been modified.") @@ -171,11 +175,11 @@ class Trainer(metaclass=ABCMeta): state_dict.update(update_dict) self.model.load_state_dict(state_dict) - if self.start_epoch == 0: + if ckp_epoch == -1: self.logger.show("=> Loaded checkpoint '{}'".format(self.checkpoint)) else: self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {}).".format( - self.checkpoint, self.start_epoch-1, *self._init_acc_epoch + self.checkpoint, ckp_epoch, *self._init_acc_epoch )) return True -- GitLab