Commit 1cecdb36 authored by pjtka

Not error checked yet

parent 28731d97
input_size: [224,224]
random_resize: True
same_size: False
mean: [0.0,0.0,0.0]
std: [1.0,1.0,1.0]
full_rot: 180
scale: (0.8, 1.2)
shear: 10
cutout: 16
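The augmenter config above is consumed by the DataAugmentISIC_AISC class further down in this commit. A minimal loading sketch, assuming plain PyYAML (the repo's own config loader is not shown here):

``` python
import ast
import yaml

# Load the augmentation settings into the model_params dict expected by DataAugmentISIC_AISC.
with open("standard_augmenter.yaml") as f:
    model_params = yaml.safe_load(f)

# Plain YAML parses `scale: (0.8, 1.2)` as the string "(0.8, 1.2)", so it needs converting
# before it is handed to transforms.RandomAffine.
model_params["scale"] = ast.literal_eval(model_params["scale"])
```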
TripletMargin:
triplets_per_anchor: all
margin: 0.09610074859813894
sampler:
MPerClassSampler:
m: 4
Contrastive:
pos_margin: 0.26523381895861114
neg_margin: 0.5409405918690342
sampler:
MPerClassSampler:
m: 4
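A minimal sketch of realizing the loss/sampler config above with pytorch_metric_learning (the dummy labels are illustrative only):

``` python
from pytorch_metric_learning import losses, samplers

labels = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]  # dummy class labels, one per sample

# TripletMargin block: margin and triplets_per_anchor come straight from the config.
loss_func = losses.TripletMarginLoss(margin=0.09610074859813894, triplets_per_anchor="all")
# Contrastive block alternative:
# loss_func = losses.ContrastiveLoss(pos_margin=0.26523381895861114, neg_margin=0.5409405918690342)

# MPerClassSampler draws m samples per class when forming each batch.
sampler = samplers.MPerClassSampler(labels, m=4, batch_size=12)
```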
best_sub_experiment_name: cub_triplet37
best_parameters:
loss_funcs/metric_loss/TripletMarginLoss/margin: 0.09610074859813894
best_values:
mean_average_precision_at_r:
mean: 0.42442235728792294
SEM: 0.018151013308347688
best_sub_experiment_name: cars_contrastive46
best_parameters:
loss_funcs/metric_loss/ContrastiveLoss/pos_margin: 0.26523381895861114
loss_funcs/metric_loss/ContrastiveLoss/neg_margin: 0.5409405918690342
best_values:
mean_average_precision_at_r:
mean: 0.2208228997968752
SEM: 0.030657043753582813
LeastSquares:
reduction: mean
L1Loss:
reduction: mean
KendallsTau:
SomeParameter: 0
ADAM:
lr: 0.001
betas: (0.9, 0.999)
eps: 1e-08
weight_decay: 0
SGD:
lr: 0.01
momentum: 0
dampening: 0
weight_decay: 0
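These optimizer blocks map directly onto torch.optim; a brief sketch (the linear model is a placeholder for illustration):

``` python
import torch
import torch.nn as nn

model = nn.Linear(128, 1)  # placeholder module

adam = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
sgd = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0, dampening=0, weight_decay=0)
```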
TRAIN:
ENABLE: True
DATASET: AISC
BATCH_SIZE: 32
EVAL_PERIOD: 2
CHECKPOINT_PERIOD: 2
AUTO_RESUME: True
DATA:
PATH_TO_DATA: r"C:\Users\ptrkm\PycharmProjects\BachelorDeeplearning\Embeddings\New_Embeddings"
  PATH_TO_LABELS: r"C:\Users\ptrkm\PycharmProjects\BachelorDeeplearning\Embeddings\New_Embeddings"
PATH_TO_DIFFICULTIES: r"C:\Users\ptrkm\PycharmProjects\BachelorDeeplearning\Embeddings\New_Embeddings"
PATH_TO_SPLIT: r"C:\Users\ptrkm\PycharmProjects\BachelorDeeplearning\Embeddings\New_Embeddings"
NETWORK:
PATH_TO_SAVED: None
BACKBONE:
NAME: 'efficientnet-b5'
OUTPUT_DIM: 128
ALREADY_TRAINED: False
FREEZE_BATCHNORM: True
HEAD:
STRUCTURE: [128, 64, 16, 1]
ACTIVATION: sigmoid
BATCH_NORM_STRUCTURE: [False, False, False, False]
TRAINING:
BACKBONE:
MAX_EPOCH: 100
LOSS: contrastive
EARLY_STOP_PATIENCE: 3
HEAD:
MAX_EPOCH: 20
LOSS: least_squares
EARLY_STOP_PATIENCE: 2
COMBINED:
MAX_EPOCH: 10
ALPHA: 0.5
EARLY_STOP_PATIENCE: 1
SOLVER:
BASE_LR: 0.1
MOMENTUM: 0.9
WEIGHT_DECAY: 1e-4
WARMUP_START_LR: 0.01
OPTIMIZING_METHOD: ADAM
AUGMENTATION:
NAME: ngessert
CONFIG: standard_augmenter.yaml
TEST:
ENABLE: True
BATCH_SIZE: 64
DATA_LOADER:
NUM_WORKERS: 8
PIN_MEMORY: True
NUM_GPUS: 1 # Not set up to handle more currently
NUM_SHARDS: 1
RNG_SEED: 0
OUTPUT_DIR: r"C:\Users\ptrkm\PycharmProjects\BachelorDeeplearning\Embeddings\New_Embeddings"
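One plausible reading of the NETWORK.HEAD block above, sketched as a PyTorch module (an assumption: the repo's actual head builder is not part of this commit, and the hidden ReLU is a guess since the config only names the output activation):

``` python
import torch.nn as nn

def build_head(structure=(128, 64, 16, 1), activation="sigmoid",
               batch_norm=(False, False, False, False)):
    layers = []
    for i, (in_dim, out_dim) in enumerate(zip(structure[:-1], structure[1:])):
        layers.append(nn.Linear(in_dim, out_dim))
        if batch_norm[i]:  # the config lists one flag per entry; only the first len(structure)-1 are used here
            layers.append(nn.BatchNorm1d(out_dim))
        if i < len(structure) - 2:
            layers.append(nn.ReLU())  # hidden activation (assumed)
    if activation == "sigmoid":
        layers.append(nn.Sigmoid())  # squash the difficulty estimate into [0, 1]
    return nn.Sequential(*layers)

head = build_head()  # Linear(128->64) -> ReLU -> Linear(64->16) -> ReLU -> Linear(16->1) -> Sigmoid
```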
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms, utils
import math
from PIL import Image
from numba import jit
import color_constancy as cc
import pickle
from argparse import Namespace
model_params = {}
model_params['input_size'] = [224, 224, 3]
model_params['random_resize'] = True
model_params['same_size'] = False
model_params['mean'] = np.array([0.0,0.0,0.0])
model_params['std'] = np.array([1.0,1.0,1.0])
model_params['full_rot'] = 180
model_params['scale'] = (0.8,1.2)
model_params['shear'] = 10
model_params['cutout'] = 16
class DataAugmentISIC_AISC:
def __init__(self, model_params):
"""
To initialize all transformations in the correct order, subsequently applied in method "apply"
:param model_params: (Dict) of chosen hyperparameters for the data augmentation.
random_resize, same_size and input_size are the only parameters, with no default values
"""
assert model_params.get('random_resize', False) + model_params.get('same_size', False) == 1
self.random_resize = model_params.get('random_resize', False)
self.same_size = model_params.get('same_size', False)
self.input_size = model_params.get('input_size')
all_transforms = []
if self.same_size:
all_transforms.append(transforms.RandomCrop(self.input_size, padding_mode='reflect', pad_if_needed=True))
elif self.random_resize:
all_transforms.append(transforms.RandomResizedCrop(self.input_size[0], scale=(0.08, 1.0)))
all_transforms.append(cc.general_color_constancy(gaussian_differentiation=0, minkowski_norm=6, sigma=0))
all_transforms.append(transforms.RandomHorizontalFlip())
all_transforms.append(transforms.RandomVerticalFlip())
all_transforms.append(transforms.RandomChoice([transforms.RandomAffine(model_params.get('full_rot',180),
scale=model_params.get('scale', (0.8,1.2)),
shear=model_params.get('shear', 10),
interpolation=Image.NEAREST),
transforms.RandomAffine(model_params.get('full_rot',180),
scale=model_params.get('scale',(0.8,1.2)),
shear=model_params.get('shear', 10),
interpolation=Image.BICUBIC),
transforms.RandomAffine(model_params.get('full_rot',180),
scale=model_params.get('scale',(0.8,1.2)),
shear=model_params.get('shear', 10),
interpolation=Image.BILINEAR)]))
all_transforms.append(transforms.ColorJitter(brightness=32. /255., saturation=0.5))
all_transforms.append(RandomCutOut(n_holes=1, length=model_params.get('cutout',16), prob = 0.5))
all_transforms.append(transforms.ToTensor())
all_transforms.append(transforms.Normalize(np.float32(model_params.get('mean', np.array([0.0,0.0,0.0]))),
np.float32(model_params.get('std', np.array([1.0,1.0,1.0])))))
self.composed_train = transforms.Compose(all_transforms)
self.composed_eval = transforms.Compose([
cc.general_color_constancy(gaussian_differentiation=0, minkowski_norm=6, sigma = 0),
transforms.Resize(self.input_size),
transforms.ToTensor(),
transforms.Normalize(np.float32(model_params.get('mean', np.array([0.0, 0.0, 0.0]))),
np.float32(model_params.get('std', np.array([1.0, 1.0, 1.0]))))
])
def __call__(self, image, mode):
"""
Applies the composite of all transforms as seen in __init__
:param image: Image of type PIL.Image
:return: A torch.Tensor of the input image on which all augmentations have been applied
"""
if mode == 'train':
return self.composed_train(image)
else:
return self.composed_eval(image)
class RandomCutOut(object):
"""
Randomly mask out zero or more patches from an image
"""
def __init__(self, n_holes = 1, length = 16, prob = 0.5):
self.prob = prob
self.cutout = Cutout_v0(n_holes, length)
def __call__(self, img):
if np.random.uniform() < self.prob:
return self.cutout(img)
else:
return img
class Cutout_v0(object):
"""Randomly mask out one or more patches from an image.
Args:
n_holes (int): Number of patches to cut out of each image.
length (int): The length (in pixels) of each square patch.
"""
def __init__(self, n_holes, length):
self.n_holes = n_holes
self.length = length
def __call__(self, img):
"""
Args:
img (Tensor): Tensor image of size (C, H, W).
Returns:
Tensor: Image with n_holes of dimension length x length cut out of it.
"""
img = np.array(img)
#print(img.shape)
h = img.shape[0]
w = img.shape[1]
mask = np.ones((h, w), np.uint8)
for n in range(self.n_holes):
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
#mask = torch.from_numpy(mask)
#mask = mask.expand_as(img)
img = img * np.expand_dims(mask,axis=2)
img = Image.fromarray(img)
return img
DATA_AUGMENTERS = {'ngessert': DataAugmentISIC_AISC}
def get_data_augmenter(augment_params):
return DATA_AUGMENTERS[augment_params.NAME](augment_params.vals)
if __name__ == "__main__":
model_params = {}
model_params['input_size'] = [224, 224]
model_params['random_resize'] = True
model_params['same_size'] = False
model_params['mean'] = np.array([0.0, 0.0, 0.0])
model_params['std'] = np.array([1.0, 1.0, 1.0])
model_params['full_rot'] = 180
model_params['scale'] = (0.8, 1.2)
model_params['shear'] = 10
model_params['cutout'] = 16
model_params['name'] = 'ngessert'
data_aug = DataAugmentISIC_AISC(model_params)
test = Image.open(r'C:\Users\ptrkm\Downloads\3-non-polariset.jpeg')
trans = transforms.ToPILImage()
test_new = data_aug(test, 'train')
breakpoint()
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
import pickle
import pandas as pd
from Embeddings.New_Embeddings.data_augmentations import augmentations as aug
class AISC(Dataset):
def __init__(self, dataset_params):
self.path_to_data = dataset_params.PATH_TO_DATA
self.path_to_labels = dataset_params.PATH_TO_LABELS
self.path_to_difficulties = dataset_params.PATH_TO_DIFFICULTIES
self.path_to_split = dataset_params.PATH_TO_SPLIT
self.difficulties = None
self.name_to_file_label_difficulty = self.read_data_labels_and_difficulty()
self.name_to_file_label_difficulty, self.loading_order = self.split_dataset()
self.mode = 'train'
self.data_augmenter = aug.get_data_augmenter(dataset_params.data_augmentation)
def __len__(self):
return len(self.name_to_file_label_difficulty[self.mode])
def read_data_labels_and_difficulty(self):
self.difficulties = self.read_difficulties()
file_names_to_file = self.read_data()
label_names, labels = self.read_labels()
if not all(name in file_names_to_file for name in label_names):
raise ValueError("Not all names in the labels file are present in the image path")
return self.ensure_order(file_names_to_file, label_names, labels)
def ensure_order(self, file_names_to_file, label_names, labels):
"""
Function to ensure that the file order corresponds to the label order
:param file_names_to_file: (dict) image_name to full path to image
:param label_names: (list) of file names, not full path
:param labels: (np.ndarray) of size (N, C) where C is the number of classes, one-hot encoded
:return: (dict) with keys equal to label_names
"""
name_to_file_label_difficulty = dict()
for idx, name in enumerate(label_names):
name_to_file_label_difficulty[name] = {
'path': file_names_to_file[name],
'label': labels[idx],
'difficulty': self.difficulties[name],
                'has_difficulty': self.difficulties[name] != -1  # -1 marks a missing difficulty estimate
}
return name_to_file_label_difficulty
def read_data(self):
if not all(os.path.isdir(path) for path in self.path_to_data):
raise ValueError("The path to data attribute is not a directory on this device")
file_name_to_file = {}
for p in self.path_to_data:
for file in os.listdir(p):
if file not in file_name_to_file:
file_name_to_file[file] = os.path.join(p, file)
return file_name_to_file
def read_labels(self):
"""
Function to read labels assuming it is saved as csv
:return:
"""
if not os.path.isfile(self.path_to_labels):
raise ValueError("Path to labels is not a path to file on this device")
labels = pd.read_csv(self.path_to_labels)
label_names = list(labels['names'])
        labels = labels.drop('names', axis=1).values  # .values is a property, not a method
return label_names, labels
def read_difficulties(self):
"""
Function to read difficulty estimates for images
:return: (dict) with image names as keys (not full path) and difficulty as value
"""
if not os.path.isfile(self.path_to_difficulties):
raise ValueError("Chosen path to difficulties is not a file on this device")
difficulties = pickle.load(open(self.path_to_difficulties, 'rb'))
return difficulties
def split_dataset(self):
"""
Function to split the dataset into the number of splits, specified in dataset_params.path_to_split
:return: (dict) with names, labels and difficulties for the splits
"""
split = pickle.load(open(self.path_to_split, 'rb'))
temp = dict()
loading_order = dict()
for mode, names in split.items():
temp[mode] = {
name: self.name_to_file_label_difficulty[name]
for name in names
}
loading_order[mode] = names
return temp, loading_order
def __getitem__(self, item):
"""
:param item: (int) conforming to the index of names
:return: (tuple) of (torch.Tensor, torch.Tensor, torch.Tensor) of image, label and difficulty
"""
        entry = self.name_to_file_label_difficulty[self.mode][
            self.loading_order[self.mode][item]
        ]
        file, label = entry['path'], entry['label']
        difficulty, has_diff = entry['difficulty'], entry['has_difficulty']
image = Image.open(file)
image = self.data_augmenter(image, self.mode)
label = torch.tensor(label)
difficulty = torch.tensor(difficulty)
if self.mode == 'train':
return image, label, difficulty, has_diff
else:
return image, label, difficulty, file, has_diff
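A brief usage sketch for the AISC dataset, assuming dataset_params is built from the TRAIN/DATA/DATA_LOADER config earlier in this commit (the wiring below is illustrative, not the repo's actual loader, so the concrete construction is left commented out):

``` python
from torch.utils.data import DataLoader

# dataset_params must expose PATH_TO_DATA (a list of directories), PATH_TO_LABELS,
# PATH_TO_DIFFICULTIES, PATH_TO_SPLIT and data_augmentation, as read in __init__.
# dataset = AISC(dataset_params)
# loader = DataLoader(dataset, batch_size=32, num_workers=8, pin_memory=True, shuffle=True)
# images, labels, difficulties, has_diff = next(iter(loader))  # train-mode batch
```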
import os
import pickle
import numpy as np
import pandas as pd
class BaseAISC:
def __init__(self, dataset_params):
self.path_to_data = dataset_params.path_to_data
self.path_to_labels = dataset_params.path_to_label
self.path_to_difficulties = dataset_params.path_to_difficulties
self.path_to_split = dataset_params.path_to_split
self.name_to_file_label_difficulty = self.read_data_labels_and_difficulty()
self.name_to_file_label_difficulty = self.split_dataset()
self.mode = 'train'
def __len__(self):
return len(self.name_to_file_label_difficulty[self.mode])
def read_data_labels_and_difficulty(self):
        self.difficulties = self.read_difficulties()  # ensure_order reads self.difficulties below
        file_names_to_file = self.read_data()
label_names, labels = self.read_labels()
if not all(name in file_names_to_file for name in label_names):
raise ValueError("Not all names in the labels file are present in the image path")
return self.ensure_order(file_names_to_file, label_names, labels)
def ensure_order(self, file_names_to_file, label_names, labels):
"""
Function to ensure that the file order corresponds to the label order
:param file_names_to_file: (dict) image_name to full path to image
:param label_names: (list) of file names, not full path
:param labels: (np.ndarray) of size (N, C) where C is the number of classes, one-hot encoded
:return: (dict) with keys equal to label_names
"""
name_to_file_label_difficulty = dict()
for idx, name in enumerate(label_names):
name_to_file_label_difficulty[name] = {
'path': file_names_to_file[name],
'label': labels[idx],
'difficulty': self.difficulties[name]
}
return name_to_file_label_difficulty
def read_data(self):
if not os.path.isdir(self.path_to_data):
raise ValueError("The path to data attribute is not a directory on this device")
file_names_to_file = {
file: os.path.join(self.path_to_data, file) for file in os.listdir(self.path_to_data)
}
return file_names_to_file
def read_labels(self):
"""
Function to read labels assuming it is saved as csv
:return:
"""
if not os.path.isfile(self.path_to_labels):
raise ValueError("Path to labels is not a path to file on this device")
labels = pd.read_csv(self.path_to_labels)
label_names = list(labels['names'])
        labels = labels.drop('names', axis=1).values  # .values is a property, not a method
return label_names, labels
def read_difficulties(self):
"""
Function to read difficulty estimates for images
:return: (dict) with image names as keys (not full path) and difficulty as value
"""
if not os.path.isfile(self.path_to_difficulties):
raise ValueError("Chosen path to difficulties is not a file on this device")
difficulties_all = pickle.load(open(self.path_to_difficulties, 'rb'))
difficulties = dict()
for lesion_uid, val in difficulties_all.items():
if len(val['image']) > 1:
for idx, name in enumerate(val['image']):
difficulties[name] = val['diff'][idx]
return difficulties
def split_dataset(self):
"""
Function to split the dataset into the number of splits, specified in dataset_params.path_to_split
:return: (dict) with names, labels and difficulties for the splits
"""
split = pickle.load(open(self.path_to_split, 'rb'))
temp = dict()
for mode, names in split.items():
temp[mode] = {
name: self.name_to_file_label_difficulty[name]
for name in names
}
return temp
embeddings_and_difficulty/dataloaders/histogram_over_channels.png (21.9 KiB)
embeddings_and_difficulty/dataloaders/histogram_over_variance.png (5.05 KiB)
%% Cell type:code id: tags:
``` python
import numpy as np
import pandas as pd
import pickle
import os
```
%% Cell type:code id: tags:
``` python
data = pd.read_csv(r'C:\Users\ptrkm\data_aisc\training-assessments.csv', sep = ";")
data = data.dropna(subset = ['correctDiagnosisName'])
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
correct_diagnosis = data['correctDiagnosisName']
image_name = data['dermoscopicImageName']
training_id = data['trainingCaseId']
user_id = data['userId']
correct_assesments = data['assessedCorrectly']
print(len(image_name.unique()) - len(training_id.unique()))
```
%% Output
0
%% Cell type:code id: tags:
``` python
print(correct_diagnosis.value_counts())
diagnosis_aisc_isic = {
'Melanoma': 'MEL',
'Nevus': 'NV',
'Seb. keratosis/ Lentigo solaris': 'BKL',
'Actinic keratosis': 'AK',
'Dermatofibroma': 'DF',
'Basal cell carcinoma': 'BCC',
'Hemangioma': 'VASC',
'Squamous cell carcinoma': 'SCC',
'Lentigo': 'BKL',
'Lentigo solaris': 'BKL',
'Vascular/Hemorrhage': 'VASC',
'Vascular lesion': 'VASC',
"Bowen's disease": 'SCC',
'Seborrheic keratosis': 'BKL'
}
isic_label_names =['MEL', 'NV', 'BCC', 'AK', 'BKL', 'DF', 'VASC', 'SCC']
isic_idxs = dict(zip(isic_label_names, range(len(isic_label_names))))
diags = []
for diag in correct_diagnosis:
if diag in diagnosis_aisc_isic:
diags.append(isic_idxs[diagnosis_aisc_isic[diag]])
else:
diags.append(None)
```
%% Output
Nevus 37088
Melanoma 35165
Seb. keratosis/ Lentigo solaris 32615
Dermatofibroma 16307
Basal cell carcinoma 16142
Hemangioma 15508
Squamous cell carcinoma 15387
Lentigo 815
Vascular/Hemorrhage 725
Actinic keratosis 336
Other 102
Bowen's disease 100
Vascular lesion 16
Seborrheic keratosis 11
Lentigo solaris 5
Name: correctDiagnosisName, dtype: int64
%% Cell type:code id: tags:
``` python
from sklearn.metrics import accuracy_score
def calculate_difficulty(answers, labels):
return accuracy_score(labels, answers)
def run_through_all(image_ids, answers):
    test = pd.DataFrame(columns=['id', 'ans'])
    test['id'] = image_ids
    test['ans'] = answers
    print(len(test))
    difficulty = {}
    for i in test['id']:
        if i not in difficulty:
            ans = test[test['id'] == i]['ans'].values
            lab = np.ones(len(ans))  # target of all ones: difficulty is the fraction of correct assessments
            if len(ans) > 5:
                difficulty[i] = calculate_difficulty(ans, lab)
    return difficulty
difficulty = run_through_all(image_name, correct_assesments)
```
%% Output
170322
%% Cell type:code id: tags:
``` python
for name in data['dermoscopicImageName'].unique():
if name not in difficulty:
difficulty[name] = -1
```
%% Cell type:code id: tags:
``` python
specific_diagnosis_to_diagnosis = {}
for spc, diag in zip(data['correctSpecificDiagnosisName'], data['correctDiagnosisName']):
if spc not in specific_diagnosis_to_diagnosis:
specific_diagnosis_to_diagnosis[spc] = diag
all_imgs = pd.read_csv(r'C:\Users\ptrkm\data_aisc\additional-dermoscopic-images.csv', sep = ";")
cdiags = []
not_there = []
for spc, img_name in zip(all_imgs['correctSpecificDiagnosisName'], all_imgs['dermoscopicImageName']):
if spc not in specific_diagnosis_to_diagnosis:
not_there.append(img_name)
cdiags.append(-1)
else:
cdiags.append(specific_diagnosis_to_diagnosis[spc])
print(len(not_there))
all_imgs['correctDiagnosisName'] = cdiags
all_imgs = all_imgs[all_imgs['correctDiagnosisName'] != -1]
len(all_imgs['dermoscopicImageName'].unique())
```
%% Output
352
38507
%% Cell type:code id: tags:
``` python
for name in all_imgs['dermoscopicImageName']:
if name not in difficulty:
difficulty[name] = -1
```
%% Cell type:code id: tags:
``` python
with open(r'C:\Users\ptrkm\data_aisc\difficulties.pkl', 'wb') as handle:
pickle.dump(difficulty, handle, protocol=pickle.HIGHEST_PROTOCOL)
```
%% Cell type:code id: tags:
``` python
all_images_ = list(all_imgs['dermoscopicImageName']) + list(data['dermoscopicImageName'])
all_labels = list(all_imgs['correctDiagnosisName']) + list(data['correctDiagnosisName'])
diags = []
labels = []
for name, lab in zip(all_images_, all_labels):
if lab != 'Other':
diags.append(isic_idxs[diagnosis_aisc_isic[lab]])
labels.append(name)
print(len(labels))
print(len(diags))
labels_csv = pd.DataFrame()
labels_csv['names'] = labels
labels_csv['labels'] = diags
labels_csv.to_csv(r'C:\Users\ptrkm\data_aisc\labels.csv', index = None)
```
%% Output
208416
208416
%% Cell type:code id: tags:
``` python
```
(1506, 1506, 3)
(1506, 1506, 3)
(1506, 1506, 3)
(1492, 1492, 3)
(1492, 1492, 3)
(4032, 3024, 3)
(4032, 3024, 3)
(4032, 3024, 3)
(4032, 3024, 3)
Prediction for variance method [False, False, False, False, False, True, False, True, False]
prediction for point method [True, True, True, False, False, True, False, True, False]
['image (1).png', 'image (2).png', 'image (3).png', 'image (4).png', 'image (5).png', 'IMG_2062.jpeg', 'IMG_2063.jpeg', 'IMG_2057.jpeg', 'IMG_2059.jpeg']
def create_loss(losses, args):
alpha = args.ALPHA
loss_backbone = losses[0]
loss_head = losses[1]
def loss(embeddings, est_difficulties,labels, difficulties, score_keeper):
if score_keeper.is_training == 'backbone':
return loss_backbone(embeddings, labels)
if score_keeper.is_training == 'head':
return loss_head(est_difficulties, difficulties)
if score_keeper.is_training == 'combined':
return alpha * loss_backbone(embeddings, labels) + (1-alpha)*loss_head(est_difficulties, difficulties)
return loss
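A short usage sketch for create_loss; score_keeper is a stand-in here, assuming it exposes an is_training attribute set to 'backbone', 'head' or 'combined' as the branches above suggest:

``` python
import torch
from types import SimpleNamespace

args = SimpleNamespace(ALPHA=0.5)
loss_backbone = lambda embeddings, labels: torch.nn.functional.cross_entropy(embeddings, labels)  # stand-in
loss_head = torch.nn.MSELoss()  # stand-in for the head loss

combined_loss = create_loss([loss_backbone, loss_head], args)
score_keeper = SimpleNamespace(is_training='combined')

embeddings = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))
est_difficulties = torch.rand(8)
difficulties = torch.rand(8)
loss_value = combined_loss(embeddings, est_difficulties, labels, difficulties, score_keeper)
```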
import numpy as np
import torch
from pytorch_metric_learning import losses
"""
All losses are added from pytorch_metric_learning - losses
https://kevinmusgrave.github.io/pytorch-metric-learning/losses/
all functions should be in the form get_loss(args): return(losses.loss(args))
args should point at yaml file in configs folder, if new loss is added, then there should also be added a yaml file
with the same name e.g. contrastive.yaml this should correspond to the string put in "configs/general.yaml" under loss
when added it should also be added to the dictionary in the bottom named all_losses
"""
def get_contrastive(args):
return losses.ContrastiveLoss(args.pos_margin, args.neg_margin, **args.kwargs)
def get_triplet_margin(args):
return losses.TripletMarginLoss(margin=args.margin,
swap = args.swap,
smooth_loss=args.smooth_loss,
triplets_per_anchor=args.triplets_per_anchor,
**args.kwargs)
all_losses = {
'contrastive': get_contrastive,
    'triplet_margin': get_triplet_margin
}
def get_loss(loss, loss_args):
return all_losses[loss](loss_args)
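A sketch of the registration convention described in the module docstring, using a third pytorch_metric_learning loss as an example (the multi_similarity key and its configs/multi_similarity.yaml file are hypothetical; MultiSimilarityLoss itself is a real pytorch_metric_learning loss):

``` python
# configs/multi_similarity.yaml (hypothetical) would hold alpha, beta, base and kwargs.
def get_multi_similarity(args):
    return losses.MultiSimilarityLoss(alpha=args.alpha, beta=args.beta, base=args.base, **args.kwargs)

# and the registry entry to add:
# all_losses['multi_similarity'] = get_multi_similarity
```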
import numpy as np
import torch
from pytorch_metric_learning import losses
import torch.nn as nn
def get_least_squares(args):
    return nn.MSELoss(reduction=args.reduction)  # the yaml key is "reduction"
def get_l1(args):
    return nn.L1Loss(reduction=args.reduction)
class KendallsTau(nn.modules.loss._Loss):
    def __init__(self, args):
        super().__init__()
        self.args = args
    def forward(self, difficulty, values):
        # Completed sketch of Kendall's tau: 2 / (n(n-1)) * sum_{i<j} sgn(d_i - d_j) * sgn(v_i - v_j);
        # sign has zero gradient almost everywhere, so this is mostly useful as an evaluation metric.
        n = len(difficulty)
        sgn_difficulty = torch.sign(difficulty.unsqueeze(0) - difficulty.unsqueeze(1))
        sgn_values = torch.sign(values.unsqueeze(0) - values.unsqueeze(1))
        tau = 2 / (n * (n - 1)) * torch.sum(torch.triu(sgn_difficulty * sgn_values, diagonal=1))
        return -tau  # negated so that minimizing the loss maximizes rank agreement
all_losses = {
'least_squares': get_least_squares,
'L1': get_l1,
}
def get_loss(loss, loss_args):
return all_losses[loss](loss_args)
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from scipy.spatial.distance import cdist
from scipy.stats import kendalltau
def calculate_embedding_accuracy_knn(embeddings, labels):
embs = embeddings.detach().cpu().numpy()
labels = labels.detach().cpu().numpy()
knn_preds = KNeighborsClassifier(n_neighbors=10).fit(embs, labels).predict(embs)
return accuracy_score(labels, knn_preds)
def calculate_embedding_accuracy_means(embeddings, labels):
embs = embeddings.detach().cpu().numpy()
labels = labels.detach().cpu().numpy()
    unique_labels = np.array(sorted(np.unique(labels)))
    means = np.array([embs[labels == i].mean(axis=0) for i in unique_labels])  # one mean embedding per class
    dist_mat = cdist(embs, means, metric='euclidean')
    return accuracy_score(labels, unique_labels[np.argmin(dist_mat, -1)])
def calculate_difficulty_accuracy_kendall_tau(estimated_difficulty, difficulty):
estimated_difficulty = estimated_difficulty.detach().cpu().numpy()
difficulty = difficulty.detach().cpu().numpy()
corr, _ = kendalltau(estimated_difficulty, difficulty)
return corr
def calculate_difficulty_accuracy_mean_squared_error(estimated_difficulty, difficulty):
estimated_difficulty = estimated_difficulty.detach().cpu().numpy()
difficulty = difficulty.detach().cpu().numpy()
    return np.mean((estimated_difficulty - difficulty) ** 2)
accuracy_methods_embeddings = {
'knn': calculate_embedding_accuracy_knn,
'means': calculate_embedding_accuracy_means,
}
accuracy_methods_difficulties = {
'kendall_tau': calculate_difficulty_accuracy_kendall_tau,
'MSE': calculate_difficulty_accuracy_mean_squared_error
}
def get_accuracy_methods(args):
if args.EVAL_METRICS.BACKBONE in accuracy_methods_embeddings:
backbone_func = accuracy_methods_embeddings[args.EVAL_METRICS.BACKBONE]
else:
raise NotImplementedError("Accuracy calculation method for backbone is not implemented yet")
if args.EVAL_METRICS.HEAD in accuracy_methods_difficulties:
head_func = accuracy_methods_difficulties[args.EVAL_METRICS.HEAD]
else:
raise NotImplementedError("Accuracy calculation method for head is not implemented yet")
return backbone_func, head_func
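A short usage sketch for get_accuracy_methods, with SimpleNamespace standing in for the repo's config object (an assumption; the EVAL_METRICS block is not part of the yaml shown above):

``` python
import torch
from types import SimpleNamespace

args = SimpleNamespace(EVAL_METRICS=SimpleNamespace(BACKBONE='knn', HEAD='kendall_tau'))
backbone_metric, head_metric = get_accuracy_methods(args)

embeddings = torch.randn(40, 16)
labels = torch.randint(0, 4, (40,))
est_difficulty = torch.rand(40)
difficulty = torch.rand(40)
print(backbone_metric(embeddings, labels))      # kNN accuracy in embedding space
print(head_metric(est_difficulty, difficulty))  # Kendall's tau rank correlation
```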