Skip to content
Snippets Groups Projects
Commit 21b491a7 authored by ZeroAct's avatar ZeroAct
Browse files

fix train.py optimizer and readme image

parent 5eeabbba
Branches
No related tags found
No related merge requests found
doc/epoch1.PNG

124 KiB | W: | H:

doc/epoch1.PNG

139 KiB | W: | H:

doc/epoch1.PNG
doc/epoch1.PNG
doc/epoch1.PNG
doc/epoch1.PNG
  • 2-up
  • Swipe
  • Onion skin
doc/epoch10.PNG

151 KiB | W: | H:

doc/epoch10.PNG

138 KiB | W: | H:

doc/epoch10.PNG
doc/epoch10.PNG
doc/epoch10.PNG
doc/epoch10.PNG
  • 2-up
  • Swipe
  • Onion skin
doc/epoch30.PNG

117 KiB | W: | H:

doc/epoch30.PNG

134 KiB | W: | H:

doc/epoch30.PNG
doc/epoch30.PNG
doc/epoch30.PNG
doc/epoch30.PNG
  • 2-up
  • Swipe
  • Onion skin
doc/epoch5.PNG

147 KiB | W: | H:

doc/epoch5.PNG

94 KiB | W: | H:

doc/epoch5.PNG
doc/epoch5.PNG
doc/epoch5.PNG
doc/epoch5.PNG
  • 2-up
  • Swipe
  • Onion skin
...@@ -105,7 +105,6 @@ if __name__ == "__main__": ...@@ -105,7 +105,6 @@ if __name__ == "__main__":
# Train Setting # Train Setting
model_optim = Adam(model.parameters(), 0.0001, (0.5, 0.9)) model_optim = Adam(model.parameters(), 0.0001, (0.5, 0.9))
discrim_optim = Adam(model.discrim.parameters(), 0.0004)
### Train ### Train
for epoch in range(initial_epoch, epochs): for epoch in range(initial_epoch, epochs):
...@@ -138,10 +137,9 @@ if __name__ == "__main__": ...@@ -138,10 +137,9 @@ if __name__ == "__main__":
Ldsn = torch.mean(F.relu(1-model.discrim(Mm, Itegt))) + \ Ldsn = torch.mean(F.relu(1-model.discrim(Mm, Itegt))) + \
torch.mean(F.relu(1+model.discrim(Mm, Ite_))) torch.mean(F.relu(1+model.discrim(Mm, Ite_)))
discrim_optim.zero_grad() model_optim.zero_grad()
Ldsn.backward() Ldsn.backward()
discrim_optim.step() model_optim.step()
ltsd = Ltsd.detach().cpu().item() ltsd = Ltsd.detach().cpu().item()
ltrg = Ltrg.detach().cpu().item() ltrg = Ltrg.detach().cpu().item()
...@@ -180,26 +178,30 @@ if __name__ == "__main__": ...@@ -180,26 +178,30 @@ if __name__ == "__main__":
pgbar.set_postfix_str(f"loss : {sum(val_loss[-10:]) / len(val_loss[-10:]):.6f}") pgbar.set_postfix_str(f"loss : {sum(val_loss[-10:]) / len(val_loss[-10:]):.6f}")
if len(result_images) < args.show_num: if len(result_images) < args.show_num:
result_images.append([I.cpu(), Itegt.cpu(), Ite_.cpu(), Msgt.cpu(), Ms_.cpu()]) result_images.append([I.cpu(), Itegt.cpu(), Ite.cpu(), Ite_.cpu(), Msgt.cpu(), Ms.cpu(), Ms_.cpu()])
else: else:
break break
val_loss = sum(val_loss) / len(val_loss) val_loss = sum(val_loss) / len(val_loss)
## visualize ## visualize
fig, axs = plt.subplots(args.show_num, 1, figsize=(5, 2*args.show_num)) fig, axs = plt.subplots(args.show_num, 1, figsize=(10, 2*args.show_num))
fig.suptitle("Image, Gt, Gen, Stroke Gt, Stroke") fig.suptitle("I, Itegt, Ite, Ite_, Msgt, Ms, Ms_]")
for i, (I, Itegt, Ite_, Msgt, Ms_) in enumerate(result_images): for i, (I, Itegt, Ite, Ite_, Msgt, Ms, Ms_) in enumerate(result_images):
I = postprocess_image(tensor_to_mat(I))[0] I = postprocess_image(tensor_to_mat(I))[0]
Itegt = postprocess_image(tensor_to_mat(Itegt))[0] Itegt = postprocess_image(tensor_to_mat(Itegt))[0]
Ite = postprocess_image(tensor_to_mat(Ite))[0]
Ite_ = postprocess_image(tensor_to_mat(Ite_))[0] Ite_ = postprocess_image(tensor_to_mat(Ite_))[0]
Msgt = postprocess_image(tensor_to_mat(Msgt))[0] Msgt = postprocess_image(tensor_to_mat(Msgt))[0]
Ms = postprocess_image(tensor_to_mat(Ms))[0]
Ms_ = postprocess_image(tensor_to_mat(Ms_))[0] Ms_ = postprocess_image(tensor_to_mat(Ms_))[0]
Msgt = cv2.cvtColor(Msgt, cv2.COLOR_GRAY2BGR) Msgt = cv2.cvtColor(Msgt, cv2.COLOR_GRAY2BGR)
Ms = cv2.cvtColor(Ms, cv2.COLOR_GRAY2BGR)
Ms_ = cv2.cvtColor(Ms_, cv2.COLOR_GRAY2BGR) Ms_ = cv2.cvtColor(Ms_, cv2.COLOR_GRAY2BGR)
axs[i].imshow(np.hstack([I, Itegt, Ite_, Msgt, Ms_])) axs[i].imshow(np.hstack([I, Itegt, Ite, Ite_, Msgt, Ms, Ms_]))
axs[i].set_xticks([]) axs[i].set_xticks([])
axs[i].set_yticks([]) axs[i].set_yticks([])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment