From 878f99ca655b6f0a68ac2be17a31c87524f52190 Mon Sep 17 00:00:00 2001
From: Bobholamovic <bob1998425@hotmail.com>
Date: Wed, 16 Dec 2020 16:29:06 +0800
Subject: [PATCH] Fix eval epoch

When a checkpoint is loaded for evaluation only, start from the checkpoint's
own epoch instead of the following one, and report that epoch in the load
message. If only part of the parameters could be matched, treat the checkpoint
epoch as unknown (-1) and fall back to the short load message.
---
 src/core/trainer.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/core/trainer.py b/src/core/trainer.py
index ae87413..c1a2611 100644
--- a/src/core/trainer.py
+++ b/src/core/trainer.py
@@ -159,11 +159,15 @@ class Trainer(metaclass=ABCMeta):
                 return False
             else:
                 self.logger.warn("=> {} params are to be loaded.".format(num_to_update))
-        elif not self.ctx['anew'] or not self.is_training:
+            ckp_epoch = -1
+        else:
             ckp_epoch = checkpoint.get('epoch', -1)
-            self.start_epoch = ckp_epoch+1
             self._init_acc_epoch = checkpoint.get('max_acc', (0.0, ckp_epoch))
-            if self.ctx['load_optim'] and self.is_training:
+            if not self.is_training:
+                self.start_epoch = ckp_epoch
+            elif not self.ctx['anew']:
+                self.start_epoch = ckp_epoch+1
+            if self.ctx['load_optim']:
                 # XXX: Note that weight decay might be modified here.
                 self.optimizer.load_state_dict(checkpoint['optimizer'])
                 self.logger.warn("Weight decay might have been modified.")
@@ -171,11 +175,11 @@ class Trainer(metaclass=ABCMeta):
         state_dict.update(update_dict)
         self.model.load_state_dict(state_dict)
 
-        if self.start_epoch == 0:
+        if ckp_epoch == -1:
             self.logger.show("=> Loaded checkpoint '{}'".format(self.checkpoint))
         else:
             self.logger.show("=> Loaded checkpoint '{}' (epoch {}, max_acc {:.4f} at epoch {}).".format(
-                self.checkpoint, self.start_epoch-1, *self._init_acc_epoch
+                self.checkpoint, ckp_epoch, *self._init_acc_epoch
                 ))
         return True
         
-- 
GitLab
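
For reference, a minimal standalone sketch (not part of the commit) of the
epoch-selection logic the patch arrives at. The helper name and its bare
arguments are hypothetical; the 'epoch' key, the is_training/anew flags, and
the -1 sentinel mirror the trainer.py hunk above, and the default starting
epoch of 0 is an assumption about the trainer's initial state:

    # Sketch only: mirrors the patched branch of Trainer's checkpoint loading.
    def resolve_start_epoch(checkpoint, is_training, anew, default_epoch=0):
        # Epoch recorded in the checkpoint; -1 means unknown.
        ckp_epoch = checkpoint.get('epoch', -1)
        if not is_training:
            # Evaluation: use and report the checkpoint's own epoch.
            return ckp_epoch, ckp_epoch
        if not anew:
            # Resumed training: continue from the epoch after the checkpoint.
            return ckp_epoch, ckp_epoch + 1
        # Training anew: keep the default starting epoch (assumed to be 0).
        return ckp_epoch, default_epoch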