From d3d59c95ba92a40bee1bbd8a24d3214bebc7764b Mon Sep 17 00:00:00 2001 From: Bobholamovic <bob1998425@hotmail.com> Date: Wed, 29 Jul 2020 18:48:24 +0800 Subject: [PATCH] Add multi-level cache for OSCD --- src/data/OSCD.py | 68 +++++++++++++++--------------------------------- 1 file changed, 21 insertions(+), 47 deletions(-) diff --git a/src/data/OSCD.py b/src/data/OSCD.py index 02bf2b0..a471ae7 100644 --- a/src/data/OSCD.py +++ b/src/data/OSCD.py @@ -1,3 +1,4 @@ +import os from glob import glob from os.path import join, basename from multiprocessing import Manager @@ -17,13 +18,14 @@ class OSCDDataset(CDDataset): root, phase='train', transforms=(None, None, None), repeats=1, - cache_labels=True + cache_level=1 ): super().__init__(root, phase, transforms, repeats) - self.cache_on = cache_labels - if self.cache_on: + # 0 for no cache, 1 for caching labels only, 2 and higher for caching all + self.cache_level = int(cache_level) + if self.cache_level > 0: self._manager = Manager() - self.label_pool = self._manager.dict() + self._pool = self._manager.dict() def _read_file_paths(self): image_dir = join(self.root, 'Onera Satellite Change Detection dataset - Images') @@ -38,62 +40,34 @@ class OSCDDataset(CDDataset): else: # For validation, use the remaining 3 pairs cities = cities[-3:] - # t1_list, t2_list = [], [] - # for city in cities: - # t1s = glob(join(image_dir, city, 'imgs_1', '*_B??.tif')) - # t1_list.append(t1s) # Populate t1_list - # # Recognize t2 from t1 - # prefix = glob(join(image_dir, city, 'imgs_2/*_B01.tif'))[0][:-5] - # t2_list.append([prefix+t1[-5:] for t1 in t1s]) - # + # Use resampled images t1_list = [[join(image_dir, city, 'imgs_1_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities] t2_list = [[join(image_dir, city, 'imgs_2_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities] label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities] - - - - #准备数据 - print('preparing %s data ... \n'%self.phase) - pb = tqdm(list(range(len(t1_list)))) - self.t1_imgs = [] - self.t2_imgs = [] - for i in pb: - self.t1_imgs.append(self.fetch_image(t1_list[i])) - self.t2_imgs.append(self.fetch_image(t2_list[i])) - return t1_list, t2_list, label_list - - #重写该方法 - def __getitem__(self, index): - if index >= len(self): - raise IndexError - index = index % self.len - - t1 = self.t1_imgs[index] - t2 = self.t2_imgs[index] - label = self.fetch_label(self.label_list[index]) - t1, t2, label = self.preprocess(t1, t2, label) - if self.phase == 'train': - return t1, t2, label - else: - return self.get_name(index), t1, t2, label - def fetch_image(self, image_paths): - return np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32) + key = '-'.join(image_paths[0].split(os.sep)[-3:-1]) + if self.cache_level >= 2: + image = self._pool.get(key, None) + if image is not None: + return image + image = np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32) + if self.cache_level >= 2: + self._pool[key] = image + return image def fetch_label(self, label_path): - if self.cache_on: - label = self.label_pool.get(label_path, None) + key = basename(label_path) + if self.cache_level >= 1: + label = self._pool.get(key, None) if label is not None: return label # In the tif labels, 1 for NC and 2 for C # Thus a -1 offset is needed label = default_loader(label_path) - 1 - if self.cache_on: - self.label_pool[label_path] = label + if self.cache_level >= 1: + self._pool[key] = label return label - - -- GitLab