diff --git a/src/data/OSCD.py b/src/data/OSCD.py index e6ba2a77c7b99e375b4d394ed67b41f60008a0b1..02bf2b0be4a3992300bbfbe03f2a37aba9bc1278 100644 --- a/src/data/OSCD.py +++ b/src/data/OSCD.py @@ -50,7 +50,36 @@ class OSCDDataset(CDDataset): t1_list = [[join(image_dir, city, 'imgs_1_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities] t2_list = [[join(image_dir, city, 'imgs_2_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities] label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities] + + + + #准备数据 + print('preparing %s data ... \n'%self.phase) + pb = tqdm(list(range(len(t1_list)))) + self.t1_imgs = [] + self.t2_imgs = [] + for i in pb: + self.t1_imgs.append(self.fetch_image(t1_list[i])) + self.t2_imgs.append(self.fetch_image(t2_list[i])) + + return t1_list, t2_list, label_list + + #重写该方法 + def __getitem__(self, index): + if index >= len(self): + raise IndexError + index = index % self.len + + t1 = self.t1_imgs[index] + t2 = self.t2_imgs[index] + label = self.fetch_label(self.label_list[index]) + t1, t2, label = self.preprocess(t1, t2, label) + if self.phase == 'train': + return t1, t2, label + else: + return self.get_name(index), t1, t2, label + def fetch_image(self, image_paths): return np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32) @@ -67,4 +96,4 @@ class OSCDDataset(CDDataset): self.label_pool[label_path] = label return label - \ No newline at end of file +