Skip to content
Snippets Groups Projects
Commit 9f2bafd0 authored by blia's avatar blia
Browse files

create bash script to reproduce figure

parent 9c3f7e85
No related branches found
No related tags found
No related merge requests found
# Default ignored files
/workspace.xml
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="GOOGLE" />
<option name="myDocStringFormat" value="Google" />
</component>
</module>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Region_Active_Learning.iml" filepath="$PROJECT_DIR$/.idea/Region_Active_Learning.iml" />
</modules>
</component>
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>
\ No newline at end of file
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
1. Clone and enter this repo: 1. Clone and enter this repo:
```bash ```bash
git clone https:... git clone https://lab.compute.dtu.dk/s161488/region_active_learning.git
cd Region_Active_Learning cd Region_Active_Learning
``` ```
...@@ -58,12 +58,6 @@ Args: ...@@ -58,12 +58,6 @@ Args:
In order to reproduce the figures in the paper, run In order to reproduce the figures in the paper, run
```bash ```bash
# Download the statistics ./produce_figure.sh
cd Exp_Stat
wget https:
tar -xzvf calibration_score.tar.gz
cd ..
cd eval_calibration
python3 visualize_calibration_score.py --save False
``` ```
...@@ -11,7 +11,14 @@ import os ...@@ -11,7 +11,14 @@ import os
import numpy as np import numpy as np
import scipy.io as sio import scipy.io as sio
data_home = "/home/blia/Region_Active_Learning/DATA/Data/" # NOTE, NEED TO BE MANUALLY DEFINED
print("--------------------------------------------------------------")
print("---------------DEFINE YOUR TRAINING DATA PATH-----------------")
print("--------------------------------------------------------------")
data_home = 'DATA/Data/'
print("--------------------------------------------------------------")
print("---------------DEFINE YOUR TRAINING DATA PATH-----------------")
print("--------------------------------------------------------------")
def running_test_for_all_acquisition_steps(path_input, version_use, start_step, pool_or_test): def running_test_for_all_acquisition_steps(path_input, version_use, start_step, pool_or_test):
......
...@@ -13,10 +13,17 @@ from sklearn.utils import shuffle ...@@ -13,10 +13,17 @@ from sklearn.utils import shuffle
import numpy as np import numpy as np
import os import os
training_data_path = "/home/blia/Region_Active_Learning/DATA/Data/glanddata.npy" # NOTE, NEED TO BE MANUALLY DEFINED
test_data_path = "/home/blia/Region_Active_Learning/DATA/Data/glanddata_testb.npy" # NOTE, NEED TO BE MANUALLY DEFINED print("--------------------------------------------------------------")
resnet_dir = "/home/blia/Region_Active_Learning/pretrain_model/" print("---------------DEFINE YOUR TRAINING DATA PATH-----------------")
exp_dir = "/scratch/Act_Learn_Desperate_V8/" # NOTE, NEED TO BE MANUALLY DEFINED print("--------------------------------------------------------------")
training_data_path = "DATA/Data/glanddata.npy" # NOTE, NEED TO BE MANUALLY DEFINED
test_data_path = "DATA/Data/glanddata_testb.npy" # NOTE, NEED TO BE MANUALLY DEFINED
resnet_dir = "pretrain_model/"
exp_dir = "Exp_Stat/" # NOTE, NEED TO BE MANUALLY DEFINED
print("--------------------------------------------------------------")
print("---------------DEFINE YOUR TRAINING DATA PATH-----------------")
print("--------------------------------------------------------------")
def running_train_use_all_data(version_space): def running_train_use_all_data(version_space):
......
...@@ -15,18 +15,26 @@ from sklearn.utils import shuffle ...@@ -15,18 +15,26 @@ from sklearn.utils import shuffle
from select_regions import selection as SPR_Region_Im from select_regions import selection as SPR_Region_Im
import pickle import pickle
training_data_path = "/home/blia/Region_Active_Learning/DATA/Data/glanddata.npy" # NOTE, NEED TO BE MANUALLY DEFINED
test_data_path = "/home/blia/Region_Active_Learning/DATA/Data/glanddata_testb.npy" # NOTE, NEED TO BE MANUALLY DEFINED
resnet_dir = "/home/blia/Region_Active_Learning/pretrain_model/"
exp_dir = "/scratch/Act_Learn_Desperate_V8/" # USER_DEFINE
ckpt_dir_init = "/home/blia/Exp_Data/initial_model/" # USER_DEFINE
print("--------------------------------------------------------------")
print("---------------DEFINE YOUR TRAINING DATA PATH-----------------")
print("--------------------------------------------------------------")
training_data_path = "DATA/Data/glanddata.npy" # NOTE, NEED TO BE MANUALLY DEFINED
test_data_path = "DATA/Data/glanddata_testb.npy" # NOTE, NEED TO BE MANUALLY DEFINED
resnet_dir = "pretrain_model/"
exp_dir = "Exp_Stat/" # NOTE, NEED TO BE MANUALLY DEFINED
ckpt_dir_init = "Exp_Stat/initial_model/"
print("-------THE PATH FOR THE INITIAL MODEL NEEDS TO BE USER DEFINED", ckpt_dir_init)
print("--------------------------------------------------------------")
print("---------------DEFINE YOUR TRAINING DATA PATH-----------------")
print("--------------------------------------------------------------")
def run_loop_active_learning_region(stage, round_number=[0, 1, 2, 3]):
def run_loop_active_learning_region(stage, round_number=np.array([0, 1, 2, 3])):
"""This function is used to train the active learning framework with region specific annotation. """This function is used to train the active learning framework with region specific annotation.
Args: Args:
stage: int, 0--> random selection, 1--> VarRatio, 2--> entropy, 3--> BALD stage: int, 0--> random selection, 1--> VarRatio, 2--> entropy, 3--> BALD
round_number: [int], repeat experiments in order to get confidence interval round_number: list, [int], repeat experiments in order to get confidence interval
Ops: Ops:
1. this script can only be run given the model that is trained with the initial training data (10)!!! 1. this script can only be run given the model that is trained with the initial training data (10)!!!
2. in each acquisition step, the experiment is repeated # times to avoid bad local optimal 2. in each acquisition step, the experiment is repeated # times to avoid bad local optimal
......
# -*- coding: utf-8 -*-
"""
Created on Wed May 16 14:27:21 2018
@author: s161488
"""
import numpy as np
import tensorflow as tf
def Prepare_Active_Data(path, num_tr, combine=True):
"""
choose_index_tr: worst 16+best 16
or middle 32
this num_tr should be 1/2*total_number_of_training_images_at_inital_step
I have tried it for 32, then I am going to check 16
"""
val_num_im = 96
tot_numeric_index = np.arange(900)
if combine is True:
tr_select_numeric_index = np.concatenate([tot_numeric_index[:num_tr], tot_numeric_index[-num_tr:]], axis=0)
else:
tr_select_numeric_index = tot_numeric_index[340:(340 + num_tr * 2)]
val_select_numeric_index = tot_numeric_index[500:(500 + val_num_im)]
pool_numeric_index = np.delete(tot_numeric_index,
np.concatenate([tr_select_numeric_index, val_select_numeric_index], axis=0))
im_seg_score = np.load('/home/s161488/Exp_Stat/Skin_Lesion/init_segment_score.npy')
sorted_index = np.argsort(im_seg_score)
data_set = np.load(path, encoding='latin1').item()
images = np.array(data_set['image'])
labels = np.array(data_set['label'])
edges = np.array(data_set['edge'])
labels = np.expand_dims(labels, axis=-1)
edges = np.expand_dims(edges, axis=-1)
tr_select_image_index = np.sort(sorted_index[tr_select_numeric_index])
val_select_image_index = np.sort(sorted_index[val_select_numeric_index])
pl_select_image_index = np.sort(sorted_index[pool_numeric_index])
X_image_tr, Y_label_tr, Y_edge_tr = Extract_Diff_Data(images, labels, edges, tr_select_image_index)
X_image_pl, Y_label_pl, Y_edge_pl = Extract_Diff_Data(images, labels, edges, pl_select_image_index)
X_image_val, Y_label_val, Y_edge_val = Extract_Diff_Data(images, labels, edges, val_select_image_index)
print("-------------There are %d training images %d validation images %d pool images" % (
np.shape(X_image_tr)[0], np.shape(X_image_val)[0], np.shape(X_image_pl)[0]))
Data_Train = [X_image_tr, Y_label_tr, Y_edge_tr]
Data_Pool = [X_image_pl, Y_label_pl, Y_edge_pl]
Data_Val = [X_image_val, Y_label_val, Y_edge_val]
return Data_Train, Data_Pool, Data_Val
def Generate_Batch(X_image_tr, Y_label_tr, Y_edge_tr, Y_binary_mask_tr, batch_index, batch_size):
"""The data augmentation include: rotation, random_brightness, crop_and_pad
Args:
X_image_tr, Y_label_tr, and Y_edge_tr are already shuffled
Return:
[X_image_aug, Y_label_aug, Y_edge_aug]
"""
X_image_batch = X_image_tr[batch_index:(batch_size + batch_index), :, :, :]
Y_label_batch = Y_label_tr[batch_index:(batch_size + batch_index), :, :, :]
Y_edge_batch = Y_edge_tr[batch_index:(batch_size + batch_index), :, :, :]
Y_binary_mask_batch = Y_binary_mask_tr[batch_index:(batch_size + batch_index), :, :, :]
batch_index = batch_index + batch_size
return X_image_batch, Y_label_batch, Y_edge_batch, Y_binary_mask_batch, batch_index
def Extract_Diff_Data(image, labels, edge, choose_index):
X_image = []
Y_label = []
Y_edge = []
for i in choose_index:
X_image.append(image[i])
Y_label.append(labels[i])
Y_edge.append(edge[i])
return np.array(X_image), np.array(Y_label), np.array(Y_edge)
def Aug_Train_Data(image, label, edge, binary_mask, batch_size, aug, IMAGE_SHAPE):
"""This function is used for performing data augmentation.
image: placeholder. shape: [Batch_Size, im_h, im_w, 3], tf.float32
label: placeholder. shape: [Batch_Size, im_h, im_w, 1], tf.int64
edge: placeholder. shape: [Batch_Size, im_h, im_w, 1], tf.int64
binary_mask: placeholder. shape: [Batch_Size, im_h, im_w, 1], tf.int64
Outputs:
image: [Batch_Size, targ_h, targ_w, 3]
label: [Batch_Size, targ_h, targ_w, 1]
edge: [Batch_Size, targ_h, targ_w, 1]
binary_mask: [Batch_Size, targ_h, targ_w, 1]
"""
image = tf.cast(image, tf.int64)
BigMatrix = tf.concat([image, label, edge, binary_mask], axis=3)
target_height = IMAGE_SHAPE[0].astype('int32')
target_width = IMAGE_SHAPE[1].astype('int32')
if aug is True:
BigMatrix_crop = tf.random_crop(BigMatrix, size=[batch_size, target_height, target_width, 6])
k = tf.random_uniform(shape=[batch_size], minval=0, maxval=6.5, dtype=tf.float32)
BigMatrix_rot = tf.contrib.image.rotate(BigMatrix_crop, angles=k)
image_aug = tf.cast(BigMatrix_rot[:, :, :, 0:3], tf.float32)
label_aug = BigMatrix_rot[:, :, :, 3]
edge_aug = BigMatrix_rot[:, :, :, 4]
binary_mask_aug = BigMatrix_rot[:, :, :, 5]
else:
BigMatrix_rot = tf.image.resize_image_with_crop_or_pad(BigMatrix, target_height, target_width)
image_aug = tf.cast(tf.cast(BigMatrix_rot[:, :, :, 0:3], tf.uint8), tf.float32)
label_aug = tf.cast(BigMatrix_rot[:, :, :, 3], tf.int64)
edge_aug = tf.cast(BigMatrix_rot[:, :, :, 4], tf.int64)
binary_mask_aug = tf.cast(BigMatrix_rot[:, :, :, 5], tf.int64)
return image_aug, tf.expand_dims(label_aug, -1), tf.expand_dims(edge_aug, -1), tf.expand_dims(binary_mask_aug, -1)
#!/bin/bash
cd Exp_Stat
filename=calibration_score
if [ -d "$filename" ]; then
echo "$filename exists"
echo "-----next step, reproducing the figures...................."
else
echo "$filename does not exist"
echo "Download the file-------------------"
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /t$
echo "-----unziping the datafile"
tar -xzvf calibration_score.tar.gz
mv use_stat calibration_score
rm calibration_score.tar.gz
fi
cd ..
python3 visualize_calibration_score.py --save True --path Exp_Stat/calibration_score/
...@@ -24,15 +24,10 @@ def give_args(): ...@@ -24,15 +24,10 @@ def give_args():
"""This function is used to give the argument""" """This function is used to give the argument"""
parser = argparse.ArgumentParser(description='Reproducing figure') parser = argparse.ArgumentParser(description='Reproducing figure')
parser.add_argument('--save', type=str2bool, default=False, metavar='SAVE') parser.add_argument('--save', type=str2bool, default=False, metavar='SAVE')
parser.add_argument('--path', type=str, default=None, help='the directory that saves the data')
return parser.parse_args() return parser.parse_args()
path = '/home/blia/Region_Active_Learning/Exp_Stat/calibration_score/'
save_fig_path = path+'save_figure/'
if not os.path.exists(save_fig_path):
os.makedirs(save_fig_path)
def ax_global_get(fig): def ax_global_get(fig):
ax_global = fig.add_subplot(111, frameon=False) ax_global = fig.add_subplot(111, frameon=False)
ax_global.spines['top'].set_color('none') ax_global.spines['top'].set_color('none')
...@@ -43,14 +38,14 @@ def ax_global_get(fig): ...@@ -43,14 +38,14 @@ def ax_global_get(fig):
return ax_global return ax_global
def give_score_path(): def give_score_path(path_use):
str_group = ["_B_", "_C_", "_D_"] str_group = ["_B_", "_C_", "_D_"]
region_path = path + 'Act_Learn_Desperate_V6/' region_path = path_use + 'Act_Learn_Desperate_V6/'
region_group = [[] for _ in range(3)] region_group = [[] for _ in range(3)]
for iterr, single_str in enumerate(str_group): for iterr, single_str in enumerate(str_group):
select_folder = [region_path + v for v in os.listdir(region_path) if single_str in v and '.obj' in v] select_folder = [region_path + v for v in os.listdir(region_path) if single_str in v and '.obj' in v]
region_group[iterr] = select_folder region_group[iterr] = select_folder
full_path = [path + 'Act_Learn_Desperate_V7/', path + 'Act_Learn_Desperate_V8/'] full_path = [path_use + 'Act_Learn_Desperate_V7/', path_use + 'Act_Learn_Desperate_V8/']
full_group = [[] for _ in range(3)] full_group = [[] for _ in range(3)]
for iterr, single_str in enumerate(str_group): for iterr, single_str in enumerate(str_group):
folder_select = [v + q for v in full_path for q in os.listdir(v) if single_str in q and '.obj' in q] folder_select = [v + q for v in full_path for q in os.listdir(v) if single_str in q and '.obj' in q]
...@@ -593,14 +588,34 @@ def get_region_uncert(return_stat=False): ...@@ -593,14 +588,34 @@ def get_region_uncert(return_stat=False):
if __name__ == '__main__': if __name__ == '__main__':
args = give_args() args = give_args()
reg_group, ful_group = give_score_path() path = args.path
save_fig_path = path + 'save_figure/'
if not os.path.exists(save_fig_path):
os.makedirs(save_fig_path)
print("--------------------------------")
print("---The data files are saved in the directory", path)
print("---The figures are going to be saved in ", save_fig_path)
reg_group, ful_group = give_score_path(path)
print("----------------------------------")
print("-----creating the first figure----") print("-----creating the first figure----")
print("----------------------------------")
give_first_figure(reg_group, ful_group, args.save) give_first_figure(reg_group, ful_group, args.save)
print("----------------------------------")
print("-----creating figure 4 and figure E1---") print("-----creating figure 4 and figure E1---")
print("----------------------------------")
give_figure_4_and_e1(ful_group, False, args.save) give_figure_4_and_e1(ful_group, False, args.save)
print("----------------------------------")
print("-----creating figure 5----------------") print("-----creating figure 5----------------")
print("----------------------------------")
give_figure_5(reg_group, ful_group, args.save) give_figure_5(reg_group, ful_group, args.save)
print("----------------------------------")
print("-----creating figure e2---------------") print("-----creating figure e2---------------")
print("----------------------------------")
give_figure_e2(reg_group, ful_group, args.save) give_figure_e2(reg_group, ful_group, args.save)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment