Skip to content
Snippets Groups Projects
OSCD.py 2.65 KiB
Newer Older
  • Learn to ignore specific revisions
  • import os
    
    Bobholamovic's avatar
    Bobholamovic committed
    from glob import glob
    from os.path import join, basename
    from multiprocessing import Manager
    
    import numpy as np
    
    from . import CDDataset
    from .common import default_loader
    
    class OSCDDataset(CDDataset):
        """Change-detection dataset for the Onera Satellite Change Detection (OSCD) data.

        Each sample consists of two multi-band images of the same city (read from
        the ``imgs_1_rect`` and ``imgs_2_rect`` folders) and a change-map label.
        Labels and, optionally, images can be cached in a ``multiprocessing.Manager``
        dict so repeated reads skip disk I/O.
        """

        # Band file stems read for every image, in stacking order (13 bands).
        __BAND_NAMES = (
            'B01', 'B02', 'B03', 'B04', 'B05', 'B06',
            'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'
        )

        def __init__(
            self,
            root, phase='train',
            transforms=(None, None, None),
            repeats=1,
            cache_level=1
        ):
            """
            Args:
                root: Dataset root directory containing the
                    'Onera Satellite Change Detection dataset - Images' and
                    'Onera Satellite Change Detection dataset - Train Labels' folders.
                phase: Split name. ``_read_file_paths`` treats 'train' as the
                    training split and any other value as validation (assumes
                    ``CDDataset`` stores it as ``self.phase`` -- TODO confirm).
                transforms: Passed through to ``CDDataset`` unchanged.
                repeats: Passed through to ``CDDataset`` unchanged.
                cache_level: Converted with ``int()``. Values <= 0 disable caching,
                    1 caches labels only, 2 or higher caches images and labels.
            """
            super().__init__(root, phase, transforms, repeats)

            # 0 for no cache, 1 for caching labels only, 2 and higher for caching all
            self.cache_level = int(cache_level)
            if self.cache_level > 0:
                # Manager-backed dict rather than a plain dict; presumably so the
                # cache is shared across worker processes -- verify with the caller.
                # ``self._pool`` only exists when caching is enabled; the fetch_*
                # methods only touch it under the matching cache_level check.
                self._manager = Manager()
                self._pool = self._manager.dict()

        def _read_file_paths(self):
            """Build the per-city file lists for the current phase.

            City names are read from the comma-separated ``train.txt`` in the image
            directory. The last three listed cities form the validation split; all
            others form the training split.

            Returns:
                Tuple ``(t1_list, t2_list, label_list)``. For each city,
                ``t1_list``/``t2_list`` hold the band-file paths of the first and
                second image (in ``__BAND_NAMES`` order) and ``label_list`` holds
                the path of its change-map label.
            """
            image_dir = join(self.root, 'Onera Satellite Change Detection dataset - Images')
            label_dir = join(self.root, 'Onera Satellite Change Detection dataset - Train Labels')
            txt_file = join(image_dir, 'train.txt')
            # Read cities. Empty entries (e.g. from a trailing comma) are dropped so
            # they neither produce bogus paths nor shift the train/val split below.
            with open(txt_file, 'r') as f:
                cities = [city.strip() for city in f.read().strip().split(',') if city.strip()]
            if self.phase == 'train':
                # For training, use all but the last 3 cities
                cities = cities[:-3]
            else:
                # For validation, use the last 3 cities
                cities = cities[-3:]

            # Use resampled images (the 'imgs_*_rect' folders)
            t1_list = [[join(image_dir, city, 'imgs_1_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
            t2_list = [[join(image_dir, city, 'imgs_2_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
            label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities]

            return t1_list, t2_list, label_list

        def fetch_image(self, image_paths):
            """Load every band file in ``image_paths`` and stack them into one array.

            Args:
                image_paths: Band-file paths of one image, all in the same folder
                    (as produced by ``_read_file_paths``).

            Returns:
                ``np.float32`` array with the bands stacked along the last axis.
                When ``cache_level >= 2`` the array is stored in, and served from,
                the shared cache.
            """
            # Cache key "<city>-<image folder>", e.g. "<city>-imgs_1_rect", taken from
            # the two parent directories of the first band file. os.path handles
            # every separator the platform accepts, unlike splitting on os.sep.
            img_dir = os.path.dirname(image_paths[0])
            key = '-'.join((basename(os.path.dirname(img_dir)), basename(img_dir)))
            if self.cache_level >= 2:
                image = self._pool.get(key, None)
                if image is not None:
                    return image
            image = np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32)
            if self.cache_level >= 2:
                self._pool[key] = image
            return image

        def fetch_label(self, label_path):
            """Load a change-map label and shift its values down by one.

            The source tifs encode no-change as 1 and change as 2, so after the
            shift NC -> 0 and C -> 1. When ``cache_level >= 1`` the result is stored
            in, and served from, the shared cache. The key is the file name, which
            for paths from ``_read_file_paths`` is '<city>-cm.tif' and so does not
            collide with the image keys.
            """
            key = basename(label_path)
            if self.cache_level >= 1:
                label = self._pool.get(key, None)
                if label is not None:
                    return label
            # In the tif labels, 1 for NC and 2 for C
            # Thus a -1 offset is needed
            label = default_loader(label_path) - 1
            if self.cache_level >= 1:
                self._pool[key] = label
            return label