From 7d951e6d23034891687f3072e2d8ca934a82fb93 Mon Sep 17 00:00:00 2001
From: Larry Zheng <46273741+Larry-Zheng@users.noreply.github.com>
Date: Mon, 20 Jul 2020 17:46:17 +0800
Subject: [PATCH] Update OSCD.py

---
 src/data/OSCD.py | 31 ++++++++++++++++++++++++++++++-
 1 file changed, 30 insertions(+), 1 deletion(-)

diff --git a/src/data/OSCD.py b/src/data/OSCD.py
index e6ba2a7..02bf2b0 100644
--- a/src/data/OSCD.py
+++ b/src/data/OSCD.py
@@ -50,7 +50,36 @@ class OSCDDataset(CDDataset):
         t1_list = [[join(image_dir, city, 'imgs_1_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
         t2_list = [[join(image_dir, city, 'imgs_2_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
         label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities]
+        
+        
+        
+        #准备数据
+        print('preparing %s data ... \n'%self.phase)
+        pb = tqdm(list(range(len(t1_list))))
+        self.t1_imgs = []
+        self.t2_imgs = []
+        for i in pb:
+            self.t1_imgs.append(self.fetch_image(t1_list[i]))
+            self.t2_imgs.append(self.fetch_image(t2_list[i]))
+
+
         return t1_list, t2_list, label_list
+    
+    #重写该方法
+    def __getitem__(self, index):
+        if index >= len(self):
+            raise IndexError
+        index = index % self.len
+        
+        t1 = self.t1_imgs[index]
+        t2 = self.t2_imgs[index]
+        label = self.fetch_label(self.label_list[index])
+        t1, t2, label = self.preprocess(t1, t2, label)
+        if self.phase == 'train':
+            return t1, t2, label
+        else:
+            return self.get_name(index), t1, t2, label
+    
 
     def fetch_image(self, image_paths):
         return np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32)
@@ -67,4 +96,4 @@ class OSCDDataset(CDDataset):
             self.label_pool[label_path] = label
         return label
 
-        
\ No newline at end of file
+        
-- 
GitLab