diff --git a/.gitignore b/.gitignore
index b6e47617de110dea7ca47e087ff1347cc2646eda..b84a66512ee7bc0202f91e575a56c666b10b7e73 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,3 +127,8 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# generated images and model weights
+*.png
+*.jpg
+*.pth
\ No newline at end of file
diff --git a/create_dataset.py b/create_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..10054a21b84e373f7a38cf489bd7bab8baa7a3db
--- /dev/null
+++ b/create_dataset.py
@@ -0,0 +1,111 @@
+import os
+import cv2
+import glob
+import random
+import progressbar
+
+import numpy as np
+
+rand_color = lambda : (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+rand_pos   = lambda mx, my: (random.randint(0, mx), random.randint(0, my))
+
+target_size = 256
+imgs_per_back = 30
+
+backs = glob.glob('./dataset/backs/*.png')
+fonts = glob.glob('./dataset/font_mask/*.png')
+
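+# I     : background crop with synthetic text rendered onto it (network input)
+# Itegt : the same crop before text was drawn (text-erased ground truth)
+# Mm    : text-region (box-level) mask
+# Msgt  : text-stroke (pixel-level) mask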
+os.makedirs('./dataset/train/I', exist_ok=True)
+os.makedirs('./dataset/train/Itegt', exist_ok=True)
+os.makedirs('./dataset/train/Mm', exist_ok=True)
+os.makedirs('./dataset/train/Msgt', exist_ok=True)
+
+os.makedirs('./dataset/val/I', exist_ok=True)
+os.makedirs('./dataset/val/Itegt', exist_ok=True)
+os.makedirs('./dataset/val/Mm', exist_ok=True)
+os.makedirs('./dataset/val/Msgt', exist_ok=True)
+
+t_idx = len(os.listdir('./dataset/train/I'))
+v_idx = len(os.listdir('./dataset/val/I'))
+start_idx = t_idx + v_idx  # images left over from previous runs; the bar tracks only new ones
+
+bar = progressbar.ProgressBar(maxval=len(backs)*imgs_per_back)
+bar.start()
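+# every background contributes imgs_per_back random crops; the first 80% go
+# to train/ and the rest to val/ (see the split on bi below)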
+for back in backs:
+    back_img = cv2.imread(back)
+    bh, bw, _ = back_img.shape
+    if bh < target_size or bw < target_size:
+        back_img = cv2.resize(back_img, (target_size, target_size), interpolation=cv2.INTER_CUBIC)
+        bh, bw, _ = back_img.shape
+
+    for bi in range(imgs_per_back):
+        sx, sy = random.randint(0, bw-target_size), random.randint(0, bh-target_size)
+        
+        Itegt = back_img[sy:sy+target_size, sx:sx+target_size, :].copy()
+        I     = Itegt.copy()
+        Mm    = np.zeros_like(I)
+        Msgt  = np.zeros_like(I)
+        
+        hist = []
+        for font in random.sample(fonts, random.randint(2, 4)):
+            font_img = cv2.imread(font)
+            mask_img = np.ones_like(font_img, dtype=np.uint8)*255
+            
+            height, width, _ = font_img.shape
+            
+            angle = random.randint(-30, +30)
+            fs = random.randint(90, 120)
+            ratio = fs / height - 0.2  # scale factor so the glyph lands roughly at font size fs
+            
+            matrix = cv2.getRotationMatrix2D((width/2, height/2), angle, ratio)
+            # pass flags= explicitly: the fourth positional argument of
+            # cv2.warpAffine is dst, not the interpolation flag
+            font_rot = cv2.warpAffine(font_img, matrix, (width, height), flags=cv2.INTER_CUBIC)
+            mask_rot = cv2.warpAffine(mask_img, matrix, (width, height), flags=cv2.INTER_CUBIC)
+            
+            h, w, _ = font_rot.shape
+            
+            font_in_I = np.zeros_like(I)
+            mask_in_I = np.zeros_like(I)
+            
+            # rejection-sample a top-left corner at least one glyph-radius away
+            # from previously placed glyphs; 'allow' relaxes the constraint on
+            # every failed attempt so the loop always terminates
+            allow = 0
+            while True:
+                sx, sy = rand_pos(target_size - w, target_size - h)
+                
+                done = True
+                for sx_, sy_ in hist:
+                    if (sx_ - sx)**2 + (sy_ - sy)**2 < (fs * ratio)**2 - allow:
+                        done = False
+                        break
+                allow += 5
+                
+                if done:
+                    hist.append([sx, sy])
+                    break
+            
+            font_in_I[sy:sy+h, sx:sx+w, :] = font_rot
+            mask_in_I[sy:sy+h, sx:sx+w, :] = mask_rot
+            
+            font_in_I[font_in_I > 30] = 255
+            mask_in_I[mask_in_I > 30] = 255
+            
+            I = cv2.bitwise_and(I, 255-font_in_I)
+            I = cv2.bitwise_or(I, (font_in_I // 255 * rand_color()).astype(np.uint8))
+            
+            Mm = cv2.bitwise_or(Mm, mask_in_I)
+            Msgt = cv2.bitwise_or(Msgt, font_in_I)
+        
+        if bi < imgs_per_back*0.8:
+            cv2.imwrite(f'dataset/train/I/{t_idx}.png', I)
+            cv2.imwrite(f'dataset/train/Itegt/{t_idx}.png', Itegt)
+            cv2.imwrite(f'dataset/train/Mm/{t_idx}.png', Mm)
+            cv2.imwrite(f'dataset/train/Msgt/{t_idx}.png', Msgt)
+            t_idx += 1
+        else:
+            cv2.imwrite(f'dataset/val/I/{v_idx}.png', I)
+            cv2.imwrite(f'dataset/val/Itegt/{v_idx}.png', Itegt)
+            cv2.imwrite(f'dataset/val/Mm/{v_idx}.png', Mm)
+            cv2.imwrite(f'dataset/val/Msgt/{v_idx}.png', Msgt)
+            v_idx += 1
+            
+        bar.update(t_idx + v_idx - start_idx)
+bar.finish()
\ No newline at end of file
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6f7fe6926cf4e5cf46f35707484b50fced30ca2
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,75 @@
+import os, cv2
+import numpy as np
+
+import torch
+from torch.utils.data import Dataset
+
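+# OpenCV returns HWC uint8 arrays; PyTorch convolutions expect CHW float
+# tensors, so these helpers transpose between the two layouts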
+def mat_to_tensor(mat):
+    mat = mat.transpose((2, 0, 1))
+    tensor = torch.Tensor(mat)
+    return tensor
+
+def tensor_to_mat(tensor):
+    mat = tensor.detach().cpu().numpy()
+    mat = mat.transpose((0, 2, 3, 1))
+    return mat
+
+def preprocess_image(img, target_shape: tuple):
+    img = cv2.resize(img, target_shape, interpolation=cv2.INTER_CUBIC).astype(np.float32)
+    img = img / 255.
+    if len(img.shape) == 2:
+        img = img.reshape(*img.shape, 1)
+    
+    return img
+
+def postprocess_image(img):
+    # min-max normalize back to [0, 255]; the epsilon guards against
+    # division by zero on a constant image
+    img = (img - img.min()) / (img.max() - img.min() + 1e-8) * 255
+    return img.astype(np.uint8)
+
+class CustomDataset(Dataset):
+    def __init__(self,
+                 data_dir,
+                 set_name="train",
+                 target_size=(256, 256)):
+        
+        super().__init__()
+        
+        self.root_dir = os.path.join(data_dir, set_name)
+        self.target_size = target_size
+        
+        self.I_dir = os.path.join(self.root_dir, "I")
+        self.Itegt_dir = os.path.join(self.root_dir, "Itegt")
+        self.Mm_dir = os.path.join(self.root_dir, "Mm")
+        self.Msgt_dir = os.path.join(self.root_dir, "Msgt")
+        
+        self.datas = os.listdir(self.I_dir)
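+        # samples are paired by filename: I/x.png, Itegt/x.png, Mm/x.png and
+        # Msgt/x.png all describe the same scene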
+        
+    def __len__(self):
+        return len(self.datas)
+    
+    def __getitem__(self, idx):
+        img_name = self.datas[idx]
+        
+        I      = cv2.imread(os.path.join(self.I_dir, img_name))
+        Itegt  = cv2.imread(os.path.join(self.Itegt_dir, img_name))
+        Mm     = cv2.imread(os.path.join(self.Mm_dir, img_name), cv2.IMREAD_GRAYSCALE)
+        Msgt   = cv2.imread(os.path.join(self.Msgt_dir, img_name), cv2.IMREAD_GRAYSCALE)
+        
+        I      = mat_to_tensor(preprocess_image(I,     self.target_size))
+        Itegt  = mat_to_tensor(preprocess_image(Itegt, self.target_size))
+        Mm     = mat_to_tensor(preprocess_image(Mm,    self.target_size))
+        Msgt   = mat_to_tensor(preprocess_image(Msgt,  self.target_size))
+        
+        return I, Itegt, Mm, Msgt
+        
+
+if __name__ == "__main__":
+    ds = CustomDataset('dataset', 'train')
+    
+    I, Itegt, Mm, Ms = ds[0]
+    print(f"Dataset length : {len(ds)}")
+    print(f"I shape : {I.shape}")
+    print(f"Itegt shape : {Itegt.shape}")
+    print(f"Mm shape : {Mm.shape}")
+    print(f"Ms shape : {Ms.shape}")
diff --git a/losses.py b/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec12c644c7d550488942b97a29acc1e0bcb5573
--- /dev/null
+++ b/losses.py
@@ -0,0 +1,16 @@
+import torch
+
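+# Ltsd = mean(|Ms - Mgt| + r*|Ms_ - Mgt|): L1 error of both the first-pass
+# stroke mask Ms and the refined mask Ms_ against the ground truth Mgt, with
+# the refined prediction weighted r times more heavily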
+def TSDLoss(Mgt, Ms, Ms_, r=10):
+    return torch.mean(torch.abs(Ms-Mgt) + r * torch.abs(Ms_-Mgt))
+
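+# Ltrg is an L1 reconstruction loss weighted per pixel by Mw = 1 + rm*Mm + rs*Ms,
+# so errors inside the text region (Mm) and on text strokes (Ms) cost more;
+# the refined output Ite_ gets an extra factor rr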
+def TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_, rm=5, rs=5, rr=10):
+    
+    Mw  = torch.ones_like(Mm) + rm * Mm + rs * Ms
+    Mw_ = torch.ones_like(Mm) + rm * Mm + rs * Ms_
+    
+    Ltrg = torch.mean(torch.abs(torch.mul(Ite, Mw) - torch.mul(Itegt, Mw)) + \
+                     rr * torch.abs(torch.mul(Ite_, Mw_) - torch.mul(Itegt, Mw_)))
+    
+    return Ltrg
diff --git a/modules.py b/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d3e00fa3ebfe2b0a6b1a7042d3aed14ff46f4ea
--- /dev/null
+++ b/modules.py
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# dis_conv 
+# (https://github.com/JiahuiYu/generative_inpainting/blob/3a5324373ba52c68c79587ca183bc10b9e57b783/inpaint_ops.py#L84)
+class _dis_conv(nn.Module):
+    
+    def __init__(self, in_channels, out_channels, kernel_size=5, stride=2, padding=2):
+        super().__init__()
+        
+        self._conv = nn.Sequential(
+            nn.utils.spectral_norm(
+                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+            ),
+            nn.LeakyReLU(inplace=True)
+        )
+        
+        # bias initialization (conv weights keep the default init; spectral
+        # normalization is already applied above)
+        def weight_init(m):
+            if isinstance(m, nn.Conv2d):
+                nn.init.zeros_(m.bias)
+        
+        self.apply(weight_init)
+
+    def forward(self, x):
+        return self._conv(x)
+
+# weights are fixed to one, bias to zero
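+# with all-one weights each output pixel is just the sum over its local
+# window, so a stack of these layers (interleaved with sigmoids, see
+# Discriminator.Dm) turns a binary text-region mask into a soft per-patch
+# weight map without any learnable parameters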
+class _one_conv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=5, stride=2, padding=2):
+        super().__init__()
+        
+        self._conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+        )
+        
+        # weight initialization
+        def weight_init(m):
+            if isinstance(m, nn.Conv2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+                m.weight.requires_grad = False
+                m.bias.requires_grad = False
+        
+        self.apply(weight_init)
+
+    def forward(self, x):
+        return self._conv(x)
+
+class _double_conv2d(nn.Module):
+    
+    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, mid_channels=None):
+        super().__init__()
+        
+        if not mid_channels:
+            mid_channels = out_channels
+            
+        self.double_conv = nn.Sequential(
+            nn.Conv2d(in_channels, mid_channels, kernel_size=kernel_size, padding=padding),
+            nn.BatchNorm2d(mid_channels),
+            nn.ReLU(inplace=True),
+            
+            nn.Conv2d(mid_channels, out_channels, kernel_size=kernel_size, padding=padding),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+        
+        # weight initialization
+        def weight_init(m):
+            if isinstance(m, nn.Conv2d):
+                nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))
+                nn.init.zeros_(m.bias)
+        
+        self.apply(weight_init)
+
+    def forward(self, x):
+        return self.double_conv(x)
+
+
+class _down_conv2d(nn.Module):
+    
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size):
+        
+        super().__init__()
+        
+        self.seq_model = nn.Sequential(
+                nn.MaxPool2d(2),
+                _double_conv2d(in_channels, out_channels, kernel_size, kernel_size // 2)
+            )
+        
+        
+    def forward(self, x):
+        return self.seq_model(x)
+
+
+class _up_conv2d(nn.Module):
+    
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size):
+        
+        super().__init__()
+        
+        self.conv_t = nn.ConvTranspose2d(in_channels, in_channels//2, 2, 2)
+        self.conv   = _double_conv2d(in_channels, out_channels, kernel_size, kernel_size // 2)
+        
+    # x1 : input, x2 : matching down_conv2d output
+    def forward(self, x1, x2):
+        x1 = self.conv_t(x1)
+        
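+        # pad x1 so its spatial size matches the skip connection x2 (the two
+        # can differ by a pixel when an odd dimension was halved on the way down)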
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+
+        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+                        diffY // 2, diffY - diffY // 2])
+        
+        x = torch.cat([x2, x1], dim=1)
+        return self.conv(x)
+
+
+class _final_conv2d(nn.Module):
+    
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size):
+        
+        super().__init__()
+        
+        self.conv = nn.Conv2d(in_channels, out_channels, 1, 1)
+        
+    def forward(self, x):
+        return self.conv(x)
\ No newline at end of file
diff --git a/network.py b/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..8437ca36b63d7ae08ea71ac1d2b88bc384519156
--- /dev/null
+++ b/network.py
@@ -0,0 +1,292 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modules import \
+    _double_conv2d, _down_conv2d, _up_conv2d, _final_conv2d, _dis_conv, _one_conv
+
+from losses import TSDLoss, TRGLoss
+
+# Text Stroke Detection (GD in paper)
+class TSDNet(nn.Module):
+    def __init__(self):
+        super(TSDNet, self).__init__()
+
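+        # a small U-Net: three down/up-sampling stages with skip connections,
+        # taking the image plus text-region mask (3+1 channels) to a
+        # 1-channel stroke mask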
+        self.inc = _double_conv2d(4, 16, 3)
+        self.down1 = _down_conv2d(16, 32, 3)
+        self.down2 = _down_conv2d(32, 64, 3)
+        self.down3 = _down_conv2d(64, 128, 3)
+        
+        self.up1 = _up_conv2d(128, 64, 3)
+        self.up2 = _up_conv2d(64, 32, 3)
+        self.up3 = _up_conv2d(32, 16, 3)
+        
+        self.outc = _final_conv2d(16, 1, 3)
+
+    def forward(self, Igt, M):
+        x = torch.cat([Igt, M], dim=1)
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        
+        x = self.up1(x4, x3)
+        x = self.up2(x, x2)
+        x = self.up3(x, x1)
+        
+        M = self.outc(x)
+        return M
+
+# Text Removal Generation (GR, GR' in paper)
+class TRGNet(nn.Module):
+    def __init__(self):
+        super(TRGNet, self).__init__()
+
+        self.inc = _double_conv2d(5, 16, 5, 2)
+        self.down1 = _down_conv2d(16, 32, 3)
+        self.down2 = _down_conv2d(32, 64, 3)
+        self.down3 = _down_conv2d(64, 128, 3)
+        
+        self.mid_layer = _double_conv2d(128, 128, 3)
+        
+        self.up1 = _up_conv2d(128, 64, 3)
+        self.up2 = _up_conv2d(64, 32, 3)
+        self.up3 = _up_conv2d(32, 16, 3)
+        
+        self.outc = _final_conv2d(16, 3, 3)
+
+    def forward(self, Igt, M, Ms):
+        x = torch.cat([Igt, M, Ms], dim=1)
+        x1 = self.inc(x)
+        
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        
+        x4 = torch.add(self.mid_layer(x4), x4)
+        
+        x = self.up1(x4, x3)
+        x = self.up2(x, x2)
+        x = self.up3(x, x1)
+        
+        Ite = self.outc(x)
+        return Ite
+
+# Text Stroke Detection _ (G'D in paper)
+class TSDNet_(nn.Module):
+    def __init__(self):
+        super(TSDNet_, self).__init__()
+
+        self.inc = _double_conv2d(5, 16, 3)
+        self.down1 = _down_conv2d(16, 32, 3)
+        self.down2 = _down_conv2d(32, 64, 3)
+        self.down3 = _down_conv2d(64, 128, 3)
+        
+        self.up1 = _up_conv2d(128, 64, 3)
+        self.up2 = _up_conv2d(64, 32, 3)
+        self.up3 = _up_conv2d(32, 16, 3)
+        
+        self.outc = _final_conv2d(16, 1, 3)
+
+    def forward(self, Ite, M, Ms):
+        x = torch.cat([Ite, M, Ms], dim=1)
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        
+        x = self.up1(x4, x3)
+        x = self.up2(x, x2)
+        x = self.up3(x, x1)
+        
+        M = self.outc(x)
+        return M
+
+# weighted patch based discriminator (D, Dm in paper)
+# build_sn_patch_gan_discriminator 
+# (https://github.com/JiahuiYu/generative_inpainting/blob/master/inpaint_model.py)
+class Discriminator(nn.Module):
+    def __init__(self):
+        
+        super(Discriminator, self).__init__()
+        
+        self.Dm = nn.Sequential(
+                _one_conv(1, 1, 5, 2, 2),
+                nn.Sigmoid(),
+                _one_conv(1, 1, 5, 2, 2),
+                nn.Sigmoid(),
+                _one_conv(1, 1, 5, 2, 2),
+                nn.Sigmoid(),
+                _one_conv(1, 1, 5, 2, 2),
+                nn.Sigmoid(),
+                _one_conv(1, 1, 5, 2, 2),
+                nn.Sigmoid()
+            )
+        
+        self.D = nn.Sequential(
+                _dis_conv(3, 64, 5, 2, 2),
+                _dis_conv(64, 128, 5, 2, 2),
+                _dis_conv(128, 256, 5, 2, 2),
+                _dis_conv(256, 256, 5, 2, 2),
+                _dis_conv(256, 256, 5, 2, 2)
+            )
+        
+        self.pool = nn.AvgPool2d(8)
+        self.linear = nn.Linear(256, 1)
+        # self.sigmoid = nn.Sigmoid()
+        
+    def forward(self, Mm, Ite_):
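+        # mask branch: Dm downsamples the text-region mask into a per-patch
+        # weight map; image branch: D is the SN-PatchGAN feature extractor.
+        # Their elementwise product is a mask-weighted patch score map, used
+        # directly (no pooling/linear head) as in hinge-based patch GANs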
+        mi = self.Dm(Mm)
+        di = self.D(Ite_)
+        
+        y = torch.mul(mi, di)
+        # y = self.pool(y)
+        # y = self.linear(y.view(-1, 256))
+        return y
+
+
+class STRNet(nn.Module):
+    def __init__(self):
+        
+        super(STRNet, self).__init__()
+        
+        self.tsdnet  = TSDNet()
+        self.trgnet  = TRGNet()
+        self.tsdnet_ = TSDNet_()
+        self.trgnet_ = TRGNet()
+        
+        self.discrim = Discriminator()
+    
+    def forward(self, I, Mm):
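+        # two-stage pipeline: detect text strokes (GD), erase text (GR), then
+        # re-detect strokes on the erased image (G'D) and refine the erasure (GR')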
+        Ms = self.tsdnet(I, Mm)
+        Ite = self.trgnet(I, Mm, Ms)
+        Ms_ = self.tsdnet_(Ite, Mm, Ms)
+        Ite_ = self.trgnet_(Ite, Mm, Ms_)
+        
+        return Ms, Ite, Ms_, Ite_
+    
+
+if __name__ == "__main__":
+    from torch.optim import Adam
+    
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    # device='cpu'
+    print(device)
+    
+    # I    : input image (with text)
+    # Itegt: text-erased ground-truth image
+    # Mm   : text-region mask
+    # Msgt : ground-truth text-stroke mask
+    
+    I = torch.randn((2, 3, 256, 256)).to(device)
+    print(f"I shape\n : {I.shape}")
+    
+    Itegt = torch.randn((2, 3, 256, 256)).to(device)
+    print(f"Itegt shape\n : {Itegt.shape}")
+    
+    Mm = torch.randn((2, 1, 256, 256)).to(device)
+    print(f"Mm shape\n : {Mm.shape}")
+    
+    Msgt = torch.randn((2, 1, 256, 256)).to(device)
+    print(f"Mgt shape\n : {Msgt.shape}")
+    
+    One = torch.ones((2, 1)).to(device)
+    Zero = torch.zeros((2, 1)).to(device)
+    
+    model = STRNet().to(device)
+    
+    model_optim = Adam(model.parameters(), 0.0001)
+    discrim_optim = Adam(model.discrim.parameters(), 0.0001)
+    bce_loss = nn.BCEWithLogitsLoss()
+    
+    Ms, Ite, Ms_, Ite_ = model.forward(I, Mm)
+    
+    Ltsd = TSDLoss(Msgt, Ms, Ms_)
+    Ltrg = TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_)
+    # Lgsn = -bce_loss(model.discrim(Mm, Ite_), One)
+    Lgsn = -torch.mean(model.discrim(Mm, Ite_))
+    
+    total_loss = Ltsd + Ltrg + Lgsn
+    
+    model_optim.zero_grad()
+    total_loss.backward()
+    model_optim.step()
+    
+    with torch.no_grad():
+        Ms, Ite, Ms_, Ite_ = model.forward(I, Mm)
+    # Ldsn = F.relu(1-bce_loss(model.discrim(Mm, Itegt), One)) + \
+    #               F.relu(1+bce_loss(model.discrim(Mm, Ite_), Zero))
+    Ldsn = torch.mean(F.relu(1-model.discrim(Mm, Itegt))) + \
+                  torch.mean(F.relu(1+model.discrim(Mm, Ite_)))
+                  
+    discrim_optim.zero_grad()
+    Ldsn.backward()
+    discrim_optim.step()
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..8792be9e5f2cc85e5d6fc64254edf16596242b96
--- /dev/null
+++ b/train.py
@@ -0,0 +1,214 @@
+import os, argparse, time, tqdm, random, cv2
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torch.optim import Adam
+
+from dataset import CustomDataset, postprocess_image, tensor_to_mat
+from network import STRNet
+from losses import TSDLoss, TRGLoss
+
+random_seed = 123
+
+torch.manual_seed(random_seed)
+torch.cuda.manual_seed(random_seed)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+np.random.seed(random_seed)
+random.seed(random_seed)
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-d", "--data_path", default='dataset', help="data root path")
+    
+    parser.add_argument("-e", "--num_epochs", default=100, type=int, help="num epochs")
+    parser.add_argument("-b", "--batch_size", default=16, type=int, help="batch size > 1")
+    
+    parser.add_argument("-n", "--num_workers", default=8, type=int, help="num_workers for DataLoader")
+    parser.add_argument("-sn", "--show_num", default=4, type=int, help="show result images during training num")
+    
+    args = parser.parse_args()
+    
+    return args
+
+def load_weights_from_directory(model, weight_path) -> int:
+    if weight_path.endswith('.pth'):
+        wp = weight_path
+    else:
+        wps = sorted(os.listdir(weight_path), key=lambda x: int(x.split('_')[0]))
+        if not wps:
+            return 0
+        wp = os.path.join(weight_path, wps[-1])
+    
+    print(f"Loading weights from {wp}...")
+    model.load_state_dict(torch.load(wp))
+    return int(os.path.basename(wp).split('_')[0])
+
+if __name__ == "__main__":
+    
+    args = get_args()
+    
+    ### Path
+    model_path     = "results"
+    weight_path    = os.path.join(model_path, "weights")
+    show_path      = os.path.join(model_path, "show")
+    
+    os.makedirs(model_path, exist_ok=True)
+    os.makedirs(weight_path, exist_ok=True)
+    os.makedirs(show_path, exist_ok=True)
+    
+    
+    ### Hyperparameters
+    epochs         = args.num_epochs
+    batch_size     = args.batch_size
+    if batch_size <= 1:
+        raise ValueError("batch_size must be greater than 1 for batch normalization")
+    
+    num_workers    = args.num_workers
+    show_num       = args.show_num
+    
+    ### DataLoader
+    dataloader_params = {'batch_size': batch_size,
+                         'shuffle': True,
+                         'drop_last': True,
+                         'num_workers': num_workers}
+    
+    train_data = CustomDataset(args.data_path, set_name="train")
+    train_gen = DataLoader(train_data, **dataloader_params)
+    
+    dataloader_params = {'batch_size': 1,
+                         'shuffle': True,
+                         'drop_last': False,
+                         'num_workers': num_workers}
+    val_data = CustomDataset(args.data_path, set_name="val")
+    val_gen = DataLoader(val_data, **dataloader_params)
+    
+    steps_per_epoch = len(train_gen)
+    
+    ### Model
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Using {device}...")
+    
+    model = STRNet().to(device)
+    
+    # load best weight
+    initial_epoch = load_weights_from_directory(model, weight_path) + 1
+    print(f"Training start from epoch {initial_epoch}")
+    
+    # Train Setting
+    model_optim = Adam(model.parameters(), 0.0001)
+    discrim_optim = Adam(model.discrim.parameters(), 0.0004)
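+    # the discriminator uses a 4x larger learning rate than the generator
+    # (a TTUR-style two-timescale setup)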
+    
+    ### Train
+    for epoch in range(initial_epoch, epochs + 1):
+        ## training
+        train_loss = []
+        train_discrim_loss = []
+        
+        model.train()
+        pgbar = tqdm.tqdm(train_gen, total=len(train_gen))
+        pgbar.set_description(f"Epoch {epoch}/{epochs}")
+        for I, Itegt, Mm, Msgt in pgbar:
+            
+            I, Itegt, Mm, Msgt = I.to(device), Itegt.to(device), Mm.to(device), Msgt.to(device)
+            
+            # train model
+            Ms, Ite, Ms_, Ite_ = model.forward(I, Mm)
+            
+            Ltsd = TSDLoss(Msgt, Ms, Ms_)
+            Ltrg = TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_)
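+            # generator adversarial term: raise the discriminator's
+            # mask-weighted patch scores on the refined output Ite_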
+            Lgsn = -torch.mean(model.discrim(Mm, Ite_))
+            
+            total_loss = Ltsd + Ltrg + Lgsn
+            
+            model_optim.zero_grad()
+            total_loss.backward()
+            model_optim.step()
+            
+            # train discriminator with the SN-PatchGAN hinge loss: real
+            # samples (Itegt) are pushed above +1, generated ones (Ite_)
+            # below -1; the generator pass runs under no_grad so only the
+            # discriminator receives gradients here
+            with torch.no_grad():
+                Ms, Ite, Ms_, Ite_ = model.forward(I, Mm)
+            Ldsn = torch.mean(F.relu(1-model.discrim(Mm, Itegt))) + \
+                torch.mean(F.relu(1+model.discrim(Mm, Ite_)))
+                          
+            discrim_optim.zero_grad()
+            Ldsn.backward()
+            discrim_optim.step()
+            
+            
+            ltsd = Ltsd.detach().cpu().item()
+            ltrg = Ltrg.detach().cpu().item()
+            lgsn = Lgsn.detach().cpu().item()
+            train_loss.append(total_loss.detach().cpu().item())
+            train_discrim_loss.append(Ldsn.detach().cpu().item())
+            
+            pgbar.set_postfix_str(f"total loss : {train_loss[-1]:.6f} ltsd : {ltsd:.6f} ltrg : {ltrg:.6f} lgsn : {lgsn:.6f} d_loss : {train_discrim_loss[-1]:.6f}")
+        
+        train_loss = sum(train_loss)/len(train_loss)
+        
+        ## validation
+        val_loss = []
+        
+        # will saved in show directory
+        result_images = []
+        
+        model.eval()
+        pgbar = tqdm.tqdm(val_gen, total=len(val_gen))
+        pgbar.set_description("Validating...")
+        with torch.no_grad():
+            for I, Itegt, Mm, Msgt in pgbar:
+                
+                I, Itegt, Mm, Msgt = I.to(device), Itegt.to(device), Mm.to(device), Msgt.to(device)
+                
+                # forward pass only (no optimizer steps during validation)
+                Ms, Ite, Ms_, Ite_ = model.forward(I, Mm)
+                
+                Ltsd = TSDLoss(Msgt, Ms, Ms_)
+                Ltrg = TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_)
+                Lgsn = -torch.mean(model.discrim(Mm, Ite_))
+                
+                total_loss = Ltsd + Ltrg + Lgsn
+                
+                val_loss.append(total_loss.detach().cpu().item())
+                
+                pgbar.set_postfix_str(f"loss : {sum(val_loss[-10:]) / len(val_loss[-10:]):.6f}")
+                
+                if len(result_images) < args.show_num:
+                    result_images.append([I.cpu(), Itegt.cpu(), Ite_.cpu(), Msgt.cpu(), Ms_.cpu()])
+                else:
+                    # only enough batches to fill the visualization grid are evaluated
+                    break
+        
+        val_loss = sum(val_loss) / len(val_loss)
+        
+        ## visualize
+        fig, axs = plt.subplots(args.show_num, 1, figsize=(5, 2*args.show_num))
+        fig.suptitle("Image, Gt, Gen, Stroke Gt, Stroke")
+        for i, (I, Itegt, Ite_, Msgt, Ms_) in enumerate(result_images):
+            I = postprocess_image(tensor_to_mat(I))[0]
+            Itegt = postprocess_image(tensor_to_mat(Itegt))[0]
+            Ite_ = postprocess_image(tensor_to_mat(Ite_))[0]
+            Msgt = postprocess_image(tensor_to_mat(Msgt))[0]
+            Ms_ = postprocess_image(tensor_to_mat(Ms_))[0]
+            
+            Msgt = cv2.cvtColor(Msgt, cv2.COLOR_GRAY2BGR)
+            Ms_ = cv2.cvtColor(Ms_, cv2.COLOR_GRAY2BGR)
+            
+            # OpenCV images are BGR; flip the channel order for matplotlib
+            axs[i].imshow(np.hstack([I, Itegt, Ite_, Msgt, Ms_])[:, :, ::-1])
+            axs[i].set_xticks([])
+            axs[i].set_yticks([])
+        
+        fig.savefig(os.path.join(show_path, f"epoch_{epoch}.png"))
+        plt.close()
+        
+        print(f"train_loss : {train_loss}, val_loss : {val_loss}")
+        print()
+        time.sleep(0.2)
+        
+        torch.save(model.state_dict(), os.path.join(weight_path, f"{epoch}_train_{train_loss:.6f}_val_{val_loss:.6f}.pth"))
+        
\ No newline at end of file