Skip to content
Snippets Groups Projects

Progress UNet July Oskar

Merged
ofhkrrequested to merge
Progress_UNet_July into main
All threads resolved!
6 files
+ 228
67
Compare changes
  • Side-by-side
  • Inline

Files

%% Cell type:code id:be66055b-8ee9-46be-ad9d-f15edf2654a4 tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id:0c61dd11-5a2b-44ff-b0e5-989360bbb677 tags:
``` python
from os.path import join
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import qim3d
%matplotlib inline
```
%% Cell type:code id:cd6bb832-1297-462f-8d35-1738a9c37ffd tags:
``` python
# Define function for getting dataset path from string
def get_dataset_path(name: str, datasets):
assert name in datasets, 'Dataset name must be ' + ' or '.join(datasets)
dataset_idx = datasets.index(name)
datasets_path = [
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
'/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary'
]
return datasets_path[dataset_idx]
```
%% Cell type:code id:b8f54ff7-b3f6-4c32-87a4-b500dd877f17 tags:
``` python
my_aug = qim3d.qim3d.utils.Augmentation(resize = 256)
aug_train = my_aug.augment('heavy')
aug_val_test = my_aug.augment(None)
```
%% Cell type:code id:9a3b9c3c-4bbb-4a19-9685-f68c437e8bee tags:
``` python
datasets = ['belialev2020_side', 'gaudez2022_3d', 'guo2023_2d',
'stan2020_2d', 'reichardt2021_2d', 'testcircles_2dbinary']
dataset = datasets[3]
dataset = datasets[-1]
# should not use gaudez2022: 3d image
# reichardt2021: multiclass segmentation
train_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset,datasets),transform=aug_train)
val_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset,datasets),transform=aug_val_test)
test_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset,datasets),split='test',transform=aug_val_test)
```
%% Output
images are all the same size!
Not implemented yet: process images that are not powers of 2.
images are all the same size!
Not implemented yet: process images that are not powers of 2.
images are all the same size!
Not implemented yet: process images that are not powers of 2.
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
(256, 256)
%% Cell type:code id:4ab857f1-4595-4a21-b07b-9bdef0117c78 tags:
``` python
VAL_FRACTION = 0.3
split_idx = int(np.floor(VAL_FRACTION * len(train_set)))
indices = torch.randperm(len(train_set))
train_set = torch.utils.data.Subset(train_set, indices[split_idx:])
val_set = torch.utils.data.Subset(val_set, indices[:split_idx])
```
%% Cell type:code id:d94c5521-f934-4fae-b3aa-ed22546ced88 tags:
``` python
# Define batch sizes
TRAIN_BATCH_SIZE = 3
VAL_BATCH_SIZE = 3
TEST_BATCH_SIZE = 3
# Define dataloaders
train_loader = DataLoader(dataset=train_set, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(dataset=val_set, batch_size=VAL_BATCH_SIZE, num_workers=8, pin_memory=True)
test_loader = DataLoader(dataset=test_set, batch_size=TEST_BATCH_SIZE, num_workers=8, pin_memory=True)
```
%% Cell type:code id:ce64ae65-01fb-45a9-bdcb-a3806de8469e tags:
``` python
# choosing the device to be run on
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# defining model
qim_unet = qim3d.qim3d.utils.qim_UNet(size = 'medium')
qim_unet = qim3d.qim3d.models.UNet(size = 'medium')
model = qim_unet()
model.to(device)
# model hyperparameters
qim_hyper = qim3d.qim3d.utils.qim_hyperparameters(model, n_epochs=50, learning_rate = 1e-3)
qim_hyper = qim3d.qim3d.models.Hyperparameters(model, n_epochs=50, learning_rate = 1e-3)
hyper_dict = qim_hyper()
# training model:
train_loss, val_loss = qim3d.qim3d.utils.train_model(model, hyper_dict, train_loader, val_loader)
```
%% Output
Epoch 0, train loss: 0.3357, val loss: 0.2101
Epoch 5, train loss: 0.0333, val loss: 0.0487
Epoch 10, train loss: 0.0307, val loss: 0.0244
Epoch 15, train loss: 0.0301, val loss: 0.0315
Epoch 20, train loss: 0.0171, val loss: 0.0170
Epoch 25, train loss: 0.0138, val loss: 0.0139
Epoch 30, train loss: 0.0113, val loss: 0.0096
Epoch 35, train loss: 0.0119, val loss: 0.0119
Epoch 40, train loss: 0.0077, val loss: 0.0082
Epoch 45, train loss: 0.0087, val loss: 0.0083
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In[8], line 14
11 hyper_dict = qim_hyper()
13 # training model:
---> 14 train_loss, val_loss = qim3d.qim3d.utils.train_model(model, hyper_dict, train_loader, val_loader)
File ~/qim3d/qim3d/utils/models.py:69, in train_model(model, qim_hyperparameters, train_loader, val_loader, eval_every, print_every)
66 targets = targets.to(device).type(torch.cuda.FloatTensor).unsqueeze(1)
68 optimizer.zero_grad()
---> 69 outputs = model(inputs)
70 loss = criterion(outputs, targets)
72 # Backpropagation
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/miniconda3/lib/python3.10/site-packages/monai/networks/nets/unet.py:303, in UNet.forward(self, x)
302 def forward(self, x: torch.Tensor) -> torch.Tensor:
--> 303 x = self.model(x)
304 return x
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
215 def forward(self, input):
216 for module in self:
--> 217 input = module(input)
218 return input
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
215 def forward(self, input):
216 for module in self:
--> 217 input = module(input)
218 return input
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:463, in Conv2d.forward(self, input)
462 def forward(self, input: Tensor) -> Tensor:
--> 463 return self._conv_forward(input, self.weight, self.bias)
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:459, in Conv2d._conv_forward(self, input, weight, bias)
455 if self.padding_mode != 'zeros':
456 return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
457 weight, bias, self.stride,
458 _pair(0), self.dilation, self.groups)
--> 459 return F.conv2d(input, weight, bias, self.stride,
460 self.padding, self.dilation, self.groups)
KeyboardInterrupt:
%% Cell type:code id:27fb326d-2771-470d-8288-a91d71e38fce tags:
``` python
plt.figure(figsize=(16, 6))
qim3d.qim3d.viz.plot_metrics(train_loss, label = 'Train')
qim3d.qim3d.viz.plot_metrics(val_loss,color = 'orange', label = 'Valid.')
plt.show()
```
%% Output
%% Cell type:code id:f8684cb0-5673-4409-8d22-f00b7d099ca4 tags:
``` python
in_targ_preds_test = qim3d.utils.inference(test_set,model)
qim3d.qim3d.viz.grid_pred(in_targ_preds_test,alpha=1)
```
%% Output
Loading