Skip to content
Snippets Groups Projects
Unverified Commit 7d951e6d authored by Larry Zheng's avatar Larry Zheng Committed by GitHub
Browse files

Update OSCD.py

parent 29041640
No related branches found
No related tags found
1 merge request!2Update outdated code
This commit is part of merge request !2. Comments created here will be created in the context of that merge request.
......@@ -50,8 +50,37 @@ class OSCDDataset(CDDataset):
t1_list = [[join(image_dir, city, 'imgs_1_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
t2_list = [[join(image_dir, city, 'imgs_2_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities]
#准备数据
print('preparing %s data ... \n'%self.phase)
pb = tqdm(list(range(len(t1_list))))
self.t1_imgs = []
self.t2_imgs = []
for i in pb:
self.t1_imgs.append(self.fetch_image(t1_list[i]))
self.t2_imgs.append(self.fetch_image(t2_list[i]))
return t1_list, t2_list, label_list
#重写该方法
def __getitem__(self, index):
if index >= len(self):
raise IndexError
index = index % self.len
t1 = self.t1_imgs[index]
t2 = self.t2_imgs[index]
label = self.fetch_label(self.label_list[index])
t1, t2, label = self.preprocess(t1, t2, label)
if self.phase == 'train':
return t1, t2, label
else:
return self.get_name(index), t1, t2, label
def fetch_image(self, image_paths):
return np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment