diff --git a/doc/epoch1.PNG b/doc/epoch1.PNG index 64fc635fcdee68f2a8b896aa8ddf2a9aa10c60ef..6c82ef86976f451ccab44b58a218d1bdaf72761c 100644 Binary files a/doc/epoch1.PNG and b/doc/epoch1.PNG differ diff --git a/doc/epoch10.PNG b/doc/epoch10.PNG index 06e7337ba3a1b88e576853256154d3aab0b7eba1..ee2fe98a506b97745b3fa918491db5a8a8512cba 100644 Binary files a/doc/epoch10.PNG and b/doc/epoch10.PNG differ diff --git a/doc/epoch30.PNG b/doc/epoch30.PNG index f83838178c004bd0a19046bdaf2c2657d585e6ae..3e874a5ca1e349647af9961d2f56b71aed6d8ff7 100644 Binary files a/doc/epoch30.PNG and b/doc/epoch30.PNG differ diff --git a/doc/epoch5.PNG b/doc/epoch5.PNG index 8e39f21181cd8a7d8e11c82fca2c0e6494e29f14..add5b234ccaa98e79fe16a8a255b976e759ceb88 100644 Binary files a/doc/epoch5.PNG and b/doc/epoch5.PNG differ diff --git a/train.py b/train.py index 5294b28309f5a5d56304aa502e8cbf93148da7bd..f50d18b4006ae69c55c734c164db053d958d21a3 100644 --- a/train.py +++ b/train.py @@ -105,7 +105,6 @@ if __name__ == "__main__": # Train Setting model_optim = Adam(model.parameters(), 0.0001, (0.5, 0.9)) - discrim_optim = Adam(model.discrim.parameters(), 0.0004) ### Train for epoch in range(initial_epoch, epochs): @@ -138,10 +137,9 @@ if __name__ == "__main__": Ldsn = torch.mean(F.relu(1-model.discrim(Mm, Itegt))) + \ torch.mean(F.relu(1+model.discrim(Mm, Ite_))) - discrim_optim.zero_grad() + model_optim.zero_grad() Ldsn.backward() - discrim_optim.step() - + model_optim.step() ltsd = Ltsd.detach().cpu().item() ltrg = Ltrg.detach().cpu().item() @@ -180,26 +178,30 @@ if __name__ == "__main__": pgbar.set_postfix_str(f"loss : {sum(val_loss[-10:]) / len(val_loss[-10:]):.6f}") if len(result_images) < args.show_num: - result_images.append([I.cpu(), Itegt.cpu(), Ite_.cpu(), Msgt.cpu(), Ms_.cpu()]) + result_images.append([I.cpu(), Itegt.cpu(), Ite.cpu(), Ite_.cpu(), Msgt.cpu(), Ms.cpu(), Ms_.cpu()]) else: break val_loss = sum(val_loss) / len(val_loss) ## visualize - fig, axs = plt.subplots(args.show_num, 1, figsize=(5, 2*args.show_num)) - fig.suptitle("Image, Gt, Gen, Stroke Gt, Stroke") - for i, (I, Itegt, Ite_, Msgt, Ms_) in enumerate(result_images): + fig, axs = plt.subplots(args.show_num, 1, figsize=(10, 2*args.show_num)) + fig.suptitle("I, Itegt, Ite, Ite_, Msgt, Ms, Ms_]") + for i, (I, Itegt, Ite, Ite_, Msgt, Ms, Ms_) in enumerate(result_images): + I = postprocess_image(tensor_to_mat(I))[0] Itegt = postprocess_image(tensor_to_mat(Itegt))[0] + Ite = postprocess_image(tensor_to_mat(Ite))[0] Ite_ = postprocess_image(tensor_to_mat(Ite_))[0] Msgt = postprocess_image(tensor_to_mat(Msgt))[0] + Ms = postprocess_image(tensor_to_mat(Ms))[0] Ms_ = postprocess_image(tensor_to_mat(Ms_))[0] Msgt = cv2.cvtColor(Msgt, cv2.COLOR_GRAY2BGR) + Ms = cv2.cvtColor(Ms, cv2.COLOR_GRAY2BGR) Ms_ = cv2.cvtColor(Ms_, cv2.COLOR_GRAY2BGR) - axs[i].imshow(np.hstack([I, Itegt, Ite_, Msgt, Ms_])) + axs[i].imshow(np.hstack([I, Itegt, Ite, Ite_, Msgt, Ms, Ms_])) axs[i].set_xticks([]) axs[i].set_yticks([])