# Synthetic sample generator.
#
# For every background found in ./dataset/backs the script cuts
# `imgs_per_back` random target_size x target_size crops, pastes 2-4 randomly
# rotated, scaled and coloured glyph images from ./dataset/font_mask onto each
# crop and writes four aligned PNGs per sample:
#   I      crop with the glyphs painted on
#   Itegt  the untouched crop
#   Mm     union of the warped glyph-image footprints
#   Msgt   union of the glyph strokes themselves
# The first 80% of the crops of each background go to train/, the rest to val/.


def rand_color():
    """Return three random ints in [0, 255] as a tuple."""
    return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))


def rand_pos(a, b):
    """Return a random (x, y) pair with both coordinates in [a, b - 1]."""
    return (random.randint(a, b-1), random.randint(a, b-1))


def _read_image(path):
    """Read *path* with OpenCV, failing loudly instead of returning None."""
    img = cv2.imread(path)
    if img is None:
        raise IOError(f"could not read image: {path}")
    return img


target_size = 256
imgs_per_back = 30

backs = glob.glob('./dataset/backs/*.png')
fonts = glob.glob('./dataset/font_mask/*.png')

for _split in ('train', 'val'):
    for _kind in ('I', 'Itegt', 'Mm', 'Msgt'):
        os.makedirs(f'./dataset/{_split}/{_kind}', exist_ok=True)

# Continue numbering after whatever is already on disk so a rerun appends
# new samples instead of overwriting old ones.
t_idx = len(os.listdir('./dataset/train/I'))
v_idx = len(os.listdir('./dataset/val/I'))

# Samples written by this run only. The progress bar must be driven by this
# counter: t_idx + v_idx also counts pre-existing files and would run past
# `maxval` on a rerun.
generated = 0

bar = progressbar.ProgressBar(maxval=len(backs)*imgs_per_back)
bar.start()
for back in backs:
    back_img = _read_image(back)
    bh, bw, _ = back_img.shape
    if bh < target_size or bw < target_size:
        # Backgrounds smaller than one crop are stretched to exactly one crop.
        back_img = cv2.resize(back_img, (target_size, target_size), interpolation=cv2.INTER_CUBIC)
        bh, bw, _ = back_img.shape

    for bi in range(imgs_per_back):
        sx, sy = random.randint(0, bw-target_size), random.randint(0, bh-target_size)

        Itegt = back_img[sy:sy+target_size, sx:sx+target_size, :].copy()
        I = Itegt.copy()
        Mm = np.zeros_like(I)
        Msgt = np.zeros_like(I)

        # Top-left corners of the glyphs already placed on this crop.
        hist = []
        # Never ask random.sample for more glyphs than there are files.
        glyph_count = min(len(fonts), random.randint(2, 4))
        for font in random.sample(fonts, glyph_count):
            font_img = _read_image(font)
            mask_img = np.full_like(font_img, 255, dtype=np.uint8)

            height, width, _ = font_img.shape

            angle = random.randint(-30, +30)
            fs = random.randint(90, 120)
            ratio = fs / height - 0.2

            matrix = cv2.getRotationMatrix2D((width/2, height/2), angle, ratio)
            # `flags` must be passed by keyword: the fourth positional
            # parameter of cv2.warpAffine is `dst`, not the interpolation mode.
            font_rot = cv2.warpAffine(font_img, matrix, (width, height), flags=cv2.INTER_CUBIC)
            mask_rot = cv2.warpAffine(mask_img, matrix, (width, height), flags=cv2.INTER_CUBIC)

            h, w, _ = font_rot.shape
            if h > target_size or w > target_size:
                raise ValueError(
                    f"glyph image {font} ({w}x{h}) does not fit into a "
                    f"{target_size}x{target_size} crop")

            font_in_I = np.zeros_like(I)
            mask_in_I = np.zeros_like(I)

            # Rejection-sample a position that keeps some distance from the
            # glyphs placed so far. The required distance shrinks after every
            # attempt, so the loop always terminates.
            allow = 0
            while True:
                # x is bounded by the glyph width and y by its height, so the
                # paste below can never leave the crop.
                sx = random.randint(0, target_size - w)
                sy = random.randint(0, target_size - h)

                min_dist_sq = (fs * ratio)**2 - allow
                done = all((px - sx)**2 + (py - sy)**2 >= min_dist_sq
                           for px, py in hist)
                allow += 5

                if done:
                    hist.append([sx, sy])
                    break

            font_in_I[sy:sy+h, sx:sx+w, :] = font_rot
            mask_in_I[sy:sy+h, sx:sx+w, :] = mask_rot

            # Push every pixel brighter than 30 to full white.
            font_in_I[font_in_I > 30] = 255
            mask_in_I[mask_in_I > 30] = 255

            # Cut the stroke pixels out of the crop, then fill them with one
            # random colour per glyph.
            I = cv2.bitwise_and(I, 255-font_in_I)
            I = cv2.bitwise_or(I, (font_in_I // 255 * rand_color()).astype(np.uint8))

            Mm = cv2.bitwise_or(Mm, mask_in_I)
            Msgt = cv2.bitwise_or(Msgt, font_in_I)

        # First 80% of the crops of a background -> train, remainder -> val.
        if bi < imgs_per_back*0.8:
            split, idx = 'train', t_idx
            t_idx += 1
        else:
            split, idx = 'val', v_idx
            v_idx += 1

        cv2.imwrite(f'dataset/{split}/I/{idx}.png', I)
        cv2.imwrite(f'dataset/{split}/Itegt/{idx}.png', Itegt)
        cv2.imwrite(f'dataset/{split}/Mm/{idx}.png', Mm)
        cv2.imwrite(f'dataset/{split}/Msgt/{idx}.png', Msgt)

        generated += 1
        bar.update(generated)
bar.finish()
def mat_to_tensor(mat):
    """Convert an H x W x C array into a C x H x W float tensor."""
    mat = mat.transpose((2, 0, 1))
    tensor = torch.Tensor(mat)
    return tensor


def tensor_to_mat(tensor):
    """Convert a 4-D N x C x H x W tensor into an N x H x W x C numpy array.

    The tensor is detached and moved to the CPU first.
    """
    mat = tensor.detach().cpu().numpy()
    mat = mat.transpose((0, 2, 3, 1))
    return mat


def preprocess_image(img, target_shape: tuple):
    """Resize *img*, scale it by 1/255 and make sure it has a channel axis.

    Args:
        img: image array as returned by ``cv2.imread``.
        target_shape: forwarded to ``cv2.resize`` as ``dsize``, i.e. it is
            interpreted as (width, height).

    Returns:
        A float32 array of rank 3; single-channel input gets a trailing
        axis of length 1.
    """
    img = cv2.resize(img, target_shape, interpolation=cv2.INTER_CUBIC).astype(np.float32)
    img = img / 255.
    if len(img.shape) == 2:
        img = img.reshape(*img.shape, 1)

    return img


def postprocess_image(img):
    """Min-max normalise *img* to [0, 255] and return it as uint8.

    A constant image has no range to stretch; it is returned as all zeros
    instead of dividing by zero and casting NaN to uint8.
    """
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        return np.zeros_like(img, dtype=np.uint8)
    img = (img - lo) / span * 255
    return img.astype(np.uint8)


class CustomDataset(Dataset):
    """Dataset of aligned (I, Itegt, Mm, Msgt) images.

    The four images of one sample share a file name and live in the
    sub-directories ``I``, ``Itegt``, ``Mm`` and ``Msgt`` of
    ``<data_dir>/<set_name>``. The file names found in ``I`` define the
    samples.
    """

    def __init__(self,
                 data_dir,
                 set_name="train",
                 target_size=(256, 256)):

        super().__init__()

        self.root_dir = os.path.join(data_dir, set_name)
        self.target_size = target_size

        self.I_dir = os.path.join(self.root_dir, "I")
        self.Itegt_dir = os.path.join(self.root_dir, "Itegt")
        self.Mm_dir = os.path.join(self.root_dir, "Mm")
        self.Msgt_dir = os.path.join(self.root_dir, "Msgt")

        # os.listdir order is arbitrary; sort so that index -> sample is the
        # same on every run and machine.
        self.datas = sorted(os.listdir(self.I_dir))

    def __len__(self):
        return len(self.datas)

    def _load(self, directory, img_name, flags=None):
        """Read, validate and preprocess one image, returning a tensor."""
        path = os.path.join(directory, img_name)
        img = cv2.imread(path) if flags is None else cv2.imread(path, flags)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"could not read image: {path}")
        return mat_to_tensor(preprocess_image(img, self.target_size))

    def __getitem__(self, idx):
        """Return the tensors (I, Itegt, Mm, Msgt) of sample *idx*.

        I and Itegt are read with OpenCV's default flags, Mm and Msgt as
        single-channel grayscale images.
        """
        img_name = self.datas[idx]

        I = self._load(self.I_dir, img_name)
        Itegt = self._load(self.Itegt_dir, img_name)
        Mm = self._load(self.Mm_dir, img_name, cv2.IMREAD_GRAYSCALE)
        Msgt = self._load(self.Msgt_dir, img_name, cv2.IMREAD_GRAYSCALE)

        return I, Itegt, Mm, Msgt
def TSDLoss(Mgt, Ms, Ms_, r=10):
    """Return mean(|Ms - Mgt| + r * |Ms_ - Mgt|).

    Args:
        Mgt: target mask.
        Ms: first mask prediction.
        Ms_: second mask prediction; its error is weighted by ``r``.
        r: weight of the second term.
    """
    return torch.mean(torch.abs(Ms-Mgt) + r * torch.abs(Ms_-Mgt))


def TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_, rm=5, rs=5, rr=10):
    """Return the weighted L1 distance of two outputs to the target image.

    Each output is compared under its own per-pixel weight map
    ``1 + rm * Mm + rs * <mask>``: ``Ite`` uses ``Ms`` and ``Ite_`` uses
    ``Ms_``. The second term is additionally scaled by ``rr``.
    """
    Mw = torch.ones_like(Mm) + rm * Mm + rs * Ms
    Mw_ = torch.ones_like(Mm) + rm * Mm + rs * Ms_

    Ltrg = torch.mean(torch.abs(torch.mul(Ite, Mw) - torch.mul(Itegt, Mw)) +
                      rr * torch.abs(torch.mul(Ite_, Mw_) - torch.mul(Itegt, Mw_)))

    return Ltrg


# dis_conv
# (https://github.com/JiahuiYu/generative_inpainting/blob/3a5324373ba52c68c79587ca183bc10b9e57b783/inpaint_ops.py#L84)
class _dis_conv(nn.Module):
    """Spectral-normalised convolution followed by an in-place LeakyReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=5, stride=2, padding=2):
        super().__init__()

        self._conv = nn.Sequential(
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
            ),
            nn.LeakyReLU(inplace=True)
        )

        # Only the bias is touched; the weight keeps PyTorch's default init.
        def weight_init(m):
            if isinstance(m, nn.Conv2d):
                nn.init.zeros_(m.bias)

        self.apply(weight_init)

    def forward(self, x):
        return self._conv(x)


# weights are fixed to one, bias to zero
class _one_conv(nn.Module):
    """Frozen convolution whose weights are all one and whose bias is zero."""

    def __init__(self, in_channels, out_channels, kernel_size=5, stride=2, padding=2):
        super().__init__()

        self._conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        )

        # Fill with constants and exclude both tensors from training.
        def weight_init(m):
            if isinstance(m, nn.Conv2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
                m.weight.requires_grad = False
                m.bias.requires_grad = False

        self.apply(weight_init)

    def forward(self, x):
        return self._conv(x)


class _double_conv2d(nn.Module):
    """Two Conv2d -> BatchNorm2d -> ReLU stages applied back to back.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        kernel_size: kernel size of both convolutions.
        padding: padding of both convolutions.
        mid_channels: channels between the two stages; falls back to
            ``out_channels`` when not given.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, mid_channels=None):
        super().__init__()

        if not mid_channels:
            mid_channels = out_channels

        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(mid_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # Xavier-normal weights with the ReLU gain, zero bias.
        def weight_init(m):
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))
                nn.init.zeros_(m.bias)

        self.apply(weight_init)

    def forward(self, x):
        return self.double_conv(x)


class _down_conv2d(nn.Module):
    """2x2 max-pooling followed by a ``_double_conv2d``.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        kernel_size: kernel size of both convolutions. The padding is
            ``kernel_size // 2``, so odd kernel sizes keep the pooled
            resolution. With ``kernel_size=3`` this is a 3x3 kernel with
            padding 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size):

        super().__init__()

        self.seq_model = nn.Sequential(
            nn.MaxPool2d(2),
            # Forward the requested kernel size instead of silently using
            # the _double_conv2d default.
            _double_conv2d(in_channels, out_channels,
                           kernel_size=kernel_size, padding=kernel_size // 2)
        )

    def forward(self, x):
        return self.seq_model(x)
class _final_conv2d(nn.Module):
    """Output head: a single 1x1 convolution with stride 1.

    ``kernel_size`` is accepted for signature parity with the other blocks
    but is not used; the convolution is always 1x1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size):

        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, 1, 1)

    def forward(self, x):
        return self.conv(x)


# Text Stroke Detection (GD in paper)
class TSDNet(nn.Module):
    """Encoder-decoder with skip connections: 4 input channels, 1 output."""

    def __init__(self):
        super().__init__()

        self.inc = _double_conv2d(4, 16, 3)
        self.down1 = _down_conv2d(16, 32, 3)
        self.down2 = _down_conv2d(32, 64, 3)
        self.down3 = _down_conv2d(64, 128, 3)

        self.up1 = _up_conv2d(128, 64, 3)
        self.up2 = _up_conv2d(64, 32, 3)
        self.up3 = _up_conv2d(32, 16, 3)

        self.outc = _final_conv2d(16, 1, 3)

    def forward(self, Igt, M):
        """Concatenate image and mask on the channel axis and predict a mask."""
        x = torch.cat([Igt, M], dim=1)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        # Each decoder stage receives the matching encoder output.
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)

        return self.outc(x)


# Text Removal Generation (GR, GR' in paper)
class TRGNet(nn.Module):
    """Encoder-decoder with a residual bottleneck: 5 input channels, 3 output."""

    def __init__(self):
        super().__init__()

        self.inc = _double_conv2d(5, 16, 5, 2)
        self.down1 = _down_conv2d(16, 32, 3)
        self.down2 = _down_conv2d(32, 64, 3)
        self.down3 = _down_conv2d(64, 128, 3)

        self.mid_layer = _double_conv2d(128, 128, 3)

        self.up1 = _up_conv2d(128, 64, 3)
        self.up2 = _up_conv2d(64, 32, 3)
        self.up3 = _up_conv2d(32, 16, 3)

        self.outc = _final_conv2d(16, 3, 3)

    def forward(self, Igt, M, Ms):
        """Concatenate image and both masks and predict a 3-channel image."""
        x = torch.cat([Igt, M, Ms], dim=1)
        x1 = self.inc(x)

        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        # Residual connection around the bottleneck block.
        x4 = torch.add(self.mid_layer(x4), x4)

        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)

        return self.outc(x)


# Text Stroke Detection _ (G'D in paper)
class TSDNet_(nn.Module):
    """Same layout as TSDNet but with 5 input channels (image + two masks)."""

    def __init__(self):
        super().__init__()

        self.inc = _double_conv2d(5, 16, 3)
        self.down1 = _down_conv2d(16, 32, 3)
        self.down2 = _down_conv2d(32, 64, 3)
        self.down3 = _down_conv2d(64, 128, 3)

        self.up1 = _up_conv2d(128, 64, 3)
        self.up2 = _up_conv2d(64, 32, 3)
        self.up3 = _up_conv2d(32, 16, 3)

        self.outc = _final_conv2d(16, 1, 3)

    def forward(self, Ite, M, Ms):
        """Concatenate image and both masks and predict a refined mask."""
        x = torch.cat([Ite, M, Ms], dim=1)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)

        return self.outc(x)


# weighted patch based discriminator (D, Dm in paper)
# build_sn_patch_gan_discriminator
# (https://github.com/JiahuiYu/generative_inpainting/blob/master/inpaint_model.py)
class Discriminator(nn.Module):
    """Patch discriminator whose scores are weighted by a down-sampled mask.

    ``Dm`` pushes the 1-channel mask through five frozen stride-2
    convolutions, each followed by a sigmoid. ``D`` pushes the 3-channel
    image through five spectral-normalised stride-2 convolutions.
    ``forward`` multiplies the two results.
    """

    def __init__(self):
        super().__init__()

        self.Dm = nn.Sequential(
            _one_conv(1, 1, 5, 2, 2),
            nn.Sigmoid(),
            _one_conv(1, 1, 5, 2, 2),
            nn.Sigmoid(),
            _one_conv(1, 1, 5, 2, 2),
            nn.Sigmoid(),
            _one_conv(1, 1, 5, 2, 2),
            nn.Sigmoid(),
            _one_conv(1, 1, 5, 2, 2),
            nn.Sigmoid()
        )

        self.D = nn.Sequential(
            _dis_conv(3, 64, 5, 2, 2),
            _dis_conv(64, 128, 5, 2, 2),
            _dis_conv(128, 256, 5, 2, 2),
            _dis_conv(256, 256, 5, 2, 2),
            _dis_conv(256, 256, 5, 2, 2)
        )

        # Not used by forward(); kept so the parameter set and state_dict
        # keys of existing checkpoints stay the same.
        self.pool = nn.AvgPool2d(8)
        self.linear = nn.Linear(256, 1)

    def forward(self, Mm, Ite_):
        """Return the mask-weighted patch scores for image ``Ite_``."""
        mi = self.Dm(Mm)
        di = self.D(Ite_)

        # mi has one channel and broadcasts over the channels of di.
        return torch.mul(mi, di)


class STRNet(nn.Module):
    """Two-stage pipeline of mask predictors and image generators.

    Stage one predicts a mask ``Ms`` and an image ``Ite``. Stage two refines
    both into ``Ms_`` and ``Ite_``. ``discrim`` is held as a sub-module but is
    not called by ``forward``.
    """

    def __init__(self):
        super().__init__()

        self.tsdnet = TSDNet()
        self.trgnet = TRGNet()
        self.tsdnet_ = TSDNet_()
        self.trgnet_ = TRGNet()

        self.discrim = Discriminator()

    def forward(self, I, Mm):
        """Run both stages and return ``(Ms, Ite, Ms_, Ite_)``."""
        Ms = self.tsdnet(I, Mm)
        Ite = self.trgnet(I, Mm, Ms)
        Ms_ = self.tsdnet_(Ite, Mm, Ms)
        # The second generator is conditioned on the refined mask Ms_;
        # feeding it the coarse Ms would leave Ms_ unused by any later stage.
        Ite_ = self.trgnet_(Ite, Mm, Ms_)

        return Ms, Ite, Ms_, Ite_


if __name__ == "__main__":
    # Smoke test: one generator step and one discriminator step on noise.
    from torch.optim import Adam

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

    # I     : input image
    # Itegt : target image
    # Mm    : Text Region mask
    # Msgt  : target Text Stroke mask

    I = torch.randn((2, 3, 256, 256)).to(device)
    print(f"I shape\n : {I.shape}")

    Itegt = torch.randn((2, 3, 256, 256)).to(device)
    print(f"Itegt shape\n : {Itegt.shape}")

    Mm = torch.randn((2, 1, 256, 256)).to(device)
    print(f"Mm shape\n : {Mm.shape}")

    Msgt = torch.randn((2, 1, 256, 256)).to(device)
    print(f"Mgt shape\n : {Msgt.shape}")

    model = STRNet().to(device)

    # The generator optimiser must not own the discriminator's parameters,
    # otherwise minimising Lgsn also trains the discriminator to accept fakes.
    discrim_ids = {id(p) for p in model.discrim.parameters()}
    model_optim = Adam([p for p in model.parameters() if id(p) not in discrim_ids], 0.0001)
    discrim_optim = Adam(model.discrim.parameters(), 0.0001)

    Ms, Ite, Ms_, Ite_ = model(I, Mm)

    Ltsd = TSDLoss(Msgt, Ms, Ms_)
    Ltrg = TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_)
    Lgsn = -torch.mean(model.discrim(Mm, Ite_))

    total_loss = Ltsd + Ltrg + Lgsn

    model_optim.zero_grad()
    total_loss.backward()
    model_optim.step()

    # The discriminator step needs the generated image only as data.
    with torch.no_grad():
        Ite_ = model(I, Mm)[3]
    Ldsn = torch.mean(F.relu(1-model.discrim(Mm, Itegt))) + \
        torch.mean(F.relu(1+model.discrim(Mm, Ite_)))

    discrim_optim.zero_grad()
    Ldsn.backward()
    discrim_optim.step()
random_seed = 123

# Seed every random number generator the script touches and make cuDNN
# deterministic.
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)


def get_args():
    """Parse and return the command-line options of the training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--data_path", default='dataset', help="data root path")

    parser.add_argument("-e", "--num_epochs", default=100, type=int, help="num epochs")
    parser.add_argument("-b", "--batch_size", default=16, type=int, help="batch size > 1")

    parser.add_argument("-n", "--num_workers", default=8, type=int, help="num_workers for DataLoader")
    parser.add_argument("-sn", "--show_num", default=4, type=int, help="show result images during training num")

    args = parser.parse_args()

    return args


def load_weights_from_directory(model, weight_path) -> int:
    """Load a checkpoint into *model* and return the epoch encoded in its name.

    Checkpoint files are named ``<epoch>_<anything>.pth``.

    Args:
        model: module whose ``load_state_dict`` receives the checkpoint.
        weight_path: either a ``.pth`` file, which is loaded directly, or a
            directory, from which the checkpoint with the highest epoch
            prefix is loaded. Files that do not follow the naming scheme
            are ignored.

    Returns:
        The epoch prefix of the loaded file. 0 when the directory holds no
        checkpoint (nothing is loaded) or when an explicitly given file has
        no numeric prefix.
    """
    if weight_path.endswith('.pth'):
        wp = weight_path
        checkpoint = weight_path
    else:
        candidates = []
        for name in os.listdir(weight_path):
            prefix = name.split('_')[0]
            if name.endswith('.pth') and prefix.isdigit():
                candidates.append((int(prefix), name))
        if not candidates:
            return 0
        wp = max(candidates)[1]
        checkpoint = os.path.join(weight_path, wp)

    print(f"Loading weights from {wp}...")
    # map_location lets a checkpoint written on a GPU load on a CPU-only
    # machine; load_state_dict copies the values onto the model's own device.
    model.load_state_dict(torch.load(checkpoint, map_location="cpu"))

    prefix = os.path.basename(checkpoint).split('_')[0]
    return int(prefix) if prefix.isdigit() else 0


if __name__ == "__main__":

    args = get_args()

    ### Path
    model_path = "results"
    weight_path = os.path.join(model_path, "weights")
    show_path = os.path.join(model_path, "show")

    os.makedirs(model_path, exist_ok=True)
    os.makedirs(weight_path, exist_ok=True)
    os.makedirs(show_path, exist_ok=True)

    ### Hyperparameters
    epochs = args.num_epochs
    batch_size = args.batch_size
    if batch_size <= 1:
        # Only exception instances can be raised; a bare string is a TypeError.
        raise ValueError("Batch size should bigger than 1 for batch normalization")

    num_workers = args.num_workers
    show_num = args.show_num

    ### DataLoader
    dataloader_params = {'batch_size': batch_size,
                         'shuffle': True,
                         'drop_last': True,
                         'num_workers': num_workers}

    train_data = CustomDataset(args.data_path, set_name="train")
    train_gen = DataLoader(train_data, **dataloader_params)

    dataloader_params = {'batch_size': 1,
                         'shuffle': True,
                         'drop_last': False,
                         'num_workers': num_workers}
    val_data = CustomDataset(args.data_path, set_name="val")
    val_gen = DataLoader(val_data, **dataloader_params)

    ### Model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device}...")

    model = STRNet().to(device)

    # Resume after the newest checkpoint, if there is one.
    initial_epoch = load_weights_from_directory(model, weight_path) + 1
    print(f"Training start from epoch {initial_epoch}")

    # Train Setting
    # The generator optimiser must not own the discriminator's parameters:
    # minimising Lgsn with them would train the discriminator to accept fakes.
    discrim_ids = {id(p) for p in model.discrim.parameters()}
    generator_params = [p for p in model.parameters() if id(p) not in discrim_ids]
    model_optim = Adam(generator_params, 0.0001)
    discrim_optim = Adam(model.discrim.parameters(), 0.0004)

    ### Train
    for epoch in range(initial_epoch, epochs):
        ## training
        train_loss = []
        train_discrim_loss = []

        model.train()
        pgbar = tqdm.tqdm(train_gen, total=len(train_gen))
        pgbar.set_description(f"Epoch {epoch}/{epochs}")
        for I, Itegt, Mm, Msgt in pgbar:

            I, Itegt, Mm, Msgt = I.to(device), Itegt.to(device), Mm.to(device), Msgt.to(device)

            # generator step
            Ms, Ite, Ms_, Ite_ = model(I, Mm)

            Ltsd = TSDLoss(Msgt, Ms, Ms_)
            Ltrg = TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_)
            Lgsn = -torch.mean(model.discrim(Mm, Ite_))

            total_loss = Ltsd + Ltrg + Lgsn

            model_optim.zero_grad()
            total_loss.backward()
            model_optim.step()

            # discriminator step (hinge loss); the freshly generated image is
            # needed as data only, so no graph is built for the generator.
            with torch.no_grad():
                Ite_ = model(I, Mm)[3]
            Ldsn = torch.mean(F.relu(1-model.discrim(Mm, Itegt))) + \
                torch.mean(F.relu(1+model.discrim(Mm, Ite_)))

            discrim_optim.zero_grad()
            Ldsn.backward()
            discrim_optim.step()

            ltsd = Ltsd.item()
            ltrg = Ltrg.item()
            lgsn = Lgsn.item()
            train_loss.append(total_loss.item())
            train_discrim_loss.append(Ldsn.item())

            pgbar.set_postfix_str(f"total loss : {train_loss[-1]:.6f} ltsd : {ltsd:.6f} ltrg : {ltrg:.6f} lgsn : {lgsn:.6f} d_loss : {train_discrim_loss[-1]:.6f}")

        train_loss = sum(train_loss)/len(train_loss)

        ## validation
        val_loss = []

        # will saved in show directory
        result_images = []

        model.eval()
        pgbar = tqdm.tqdm(val_gen, total=len(val_gen))
        pgbar.set_description("Validating...")
        # No parameter is updated here, so skip autograd bookkeeping.
        with torch.no_grad():
            for I, Itegt, Mm, Msgt in pgbar:

                I, Itegt, Mm, Msgt = I.to(device), Itegt.to(device), Mm.to(device), Msgt.to(device)

                Ms, Ite, Ms_, Ite_ = model(I, Mm)

                Ltsd = TSDLoss(Msgt, Ms, Ms_)
                Ltrg = TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_)
                Lgsn = -torch.mean(model.discrim(Mm, Ite_))

                total_loss = Ltsd + Ltrg + Lgsn

                val_loss.append(total_loss.item())

                pgbar.set_postfix_str(f"loss : {sum(val_loss[-10:]) / len(val_loss[-10:]):.6f}")

                # NOTE: validation stops once enough previews are collected,
                # so val_loss averages at most show_num + 1 samples.
                if len(result_images) < args.show_num:
                    result_images.append([I.cpu(), Itegt.cpu(), Ite_.cpu(), Msgt.cpu(), Ms_.cpu()])
                else:
                    break

        val_loss = sum(val_loss) / len(val_loss)

        ## visualize
        # squeeze=False keeps `axs` two-dimensional even when show_num is 1.
        fig, axs = plt.subplots(args.show_num, 1, figsize=(5, 2*args.show_num), squeeze=False)
        fig.suptitle("Image, Gt, Gen, Stroke Gt, Stroke")
        for i, (I, Itegt, Ite_, Msgt, Ms_) in enumerate(result_images):
            I = postprocess_image(tensor_to_mat(I))[0]
            Itegt = postprocess_image(tensor_to_mat(Itegt))[0]
            Ite_ = postprocess_image(tensor_to_mat(Ite_))[0]
            Msgt = postprocess_image(tensor_to_mat(Msgt))[0]
            Ms_ = postprocess_image(tensor_to_mat(Ms_))[0]

            Msgt = cv2.cvtColor(Msgt, cv2.COLOR_GRAY2BGR)
            Ms_ = cv2.cvtColor(Ms_, cv2.COLOR_GRAY2BGR)

            # The images are in OpenCV's BGR order; matplotlib expects RGB.
            strip = cv2.cvtColor(np.hstack([I, Itegt, Ite_, Msgt, Ms_]), cv2.COLOR_BGR2RGB)

            ax = axs[i, 0]
            ax.imshow(strip)
            ax.set_xticks([])
            ax.set_yticks([])

        fig.savefig(os.path.join(model_path, "show", f"epoch_{epoch}.png"))
        plt.close(fig)

        print(f"train_loss : {train_loss}, val_loss : {val_loss}")
        print()
        time.sleep(0.2)

        # The leading "<epoch>_" is what load_weights_from_directory parses.
        torch.save(model.state_dict(), os.path.join(weight_path, f"{epoch}_train_{train_loss}_val_{val_loss}.pth"))