diff --git a/src/data/OSCD.py b/src/data/OSCD.py
index 02bf2b0be4a3992300bbfbe03f2a37aba9bc1278..a471ae7db56992ae1e0b7f1e5268961766c9d1a4 100644
--- a/src/data/OSCD.py
+++ b/src/data/OSCD.py
@@ -1,3 +1,4 @@
+import os
 from glob import glob
 from os.path import join, basename
 from multiprocessing import Manager
@@ -17,13 +18,14 @@ class OSCDDataset(CDDataset):
         root, phase='train', 
         transforms=(None, None, None), 
         repeats=1,
-        cache_labels=True
+        cache_level=1
     ):
         super().__init__(root, phase, transforms, repeats)
-        self.cache_on = cache_labels
-        if self.cache_on:
+        # 0 for no cache, 1 for caching labels only, 2 and higher for caching all
+        self.cache_level = int(cache_level)
+        if self.cache_level > 0:
             self._manager = Manager()
-            self.label_pool = self._manager.dict()
+            self._pool = self._manager.dict()
 
     def _read_file_paths(self):
         image_dir = join(self.root, 'Onera Satellite Change Detection dataset - Images')
@@ -38,62 +40,34 @@ class OSCDDataset(CDDataset):
         else:
             # For validation, use the remaining 3 pairs
             cities = cities[-3:]
-        # t1_list, t2_list = [], []
-        # for city in cities:
-        #     t1s = glob(join(image_dir, city, 'imgs_1', '*_B??.tif'))
-        #     t1_list.append(t1s) # Populate t1_list
-        #     # Recognize t2 from t1
-        #     prefix = glob(join(image_dir, city, 'imgs_2/*_B01.tif'))[0][:-5]
-        #     t2_list.append([prefix+t1[-5:] for t1 in t1s])
-        #
+            
         # Use resampled images
         t1_list = [[join(image_dir, city, 'imgs_1_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
         t2_list = [[join(image_dir, city, 'imgs_2_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
         label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities]
-        
-        
-        
-        #准备数据
-        print('preparing %s data ... \n'%self.phase)
-        pb = tqdm(list(range(len(t1_list))))
-        self.t1_imgs = []
-        self.t2_imgs = []
-        for i in pb:
-            self.t1_imgs.append(self.fetch_image(t1_list[i]))
-            self.t2_imgs.append(self.fetch_image(t2_list[i]))
-
 
         return t1_list, t2_list, label_list
-    
-    #重写该方法
-    def __getitem__(self, index):
-        if index >= len(self):
-            raise IndexError
-        index = index % self.len
-        
-        t1 = self.t1_imgs[index]
-        t2 = self.t2_imgs[index]
-        label = self.fetch_label(self.label_list[index])
-        t1, t2, label = self.preprocess(t1, t2, label)
-        if self.phase == 'train':
-            return t1, t2, label
-        else:
-            return self.get_name(index), t1, t2, label
-    
 
     def fetch_image(self, image_paths):
-        return np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32)
+        key = '-'.join(image_paths[0].split(os.sep)[-3:-1])
+        if self.cache_level >= 2:
+            image = self._pool.get(key, None)
+            if image is not None:
+                return image
+        image = np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32)
+        if self.cache_level >= 2:
+            self._pool[key] = image
+        return image
 
     def fetch_label(self, label_path):
-        if self.cache_on:
-            label = self.label_pool.get(label_path, None)
+        key = basename(label_path)
+        if self.cache_level >= 1:
+            label = self._pool.get(key, None)
             if label is not None:
                 return label
         # In the tif labels, 1 for NC and 2 for C
         # Thus a -1 offset is needed
         label = default_loader(label_path) - 1
-        if self.cache_on:
-            self.label_pool[label_path] = label
+        if self.cache_level >= 1:
+            self._pool[key] = label
         return label
-
-