octopi 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +7 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +83 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +458 -0
- octopi/datasets/io.py +200 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +252 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +119 -0
- octopi/entry_points/create_slurm_submission.py +251 -0
- octopi/entry_points/groups.py +152 -0
- octopi/entry_points/run_create_targets.py +234 -0
- octopi/entry_points/run_evaluate.py +99 -0
- octopi/entry_points/run_extract_mb_picks.py +191 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +176 -0
- octopi/entry_points/run_optuna.py +161 -0
- octopi/entry_points/run_segment.py +154 -0
- octopi/entry_points/run_train.py +189 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +217 -0
- octopi/extract/membranebound_extract.py +263 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/main.py +33 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +72 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +224 -0
- octopi/processing/downloader.py +138 -0
- octopi/processing/downsample.py +125 -0
- octopi/processing/evaluate.py +302 -0
- octopi/processing/importers.py +116 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +244 -0
- octopi/pytorch/model_search_submitter.py +291 -0
- octopi/pytorch/segmentation.py +363 -0
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +465 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +215 -0
- octopi/utils/losses.py +86 -0
- octopi/utils/parsers.py +162 -0
- octopi/utils/progress.py +78 -0
- octopi/utils/stopping_criteria.py +143 -0
- octopi/utils/submit_slurm.py +95 -0
- octopi/utils/visualization_tools.py +290 -0
- octopi/workflows.py +262 -0
- octopi-1.4.0.dist-info/METADATA +119 -0
- octopi-1.4.0.dist-info/RECORD +65 -0
- octopi-1.4.0.dist-info/WHEEL +4 -0
- octopi-1.4.0.dist-info/entry_points.txt +3 -0
- octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/main.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import rich_click as click
|
|
2
|
+
from octopi import cli_context
|
|
3
|
+
from octopi.entry_points import groups
|
|
4
|
+
from octopi.processing.downloader import cli as download
|
|
5
|
+
from octopi.processing.importers import cli as import_tomograms
|
|
6
|
+
from octopi.entry_points.run_train import cli as train_model
|
|
7
|
+
from octopi.entry_points.run_optuna import cli as model_explore
|
|
8
|
+
from octopi.entry_points.run_create_targets import cli as create_targets
|
|
9
|
+
from octopi.entry_points.run_segment import cli as inference
|
|
10
|
+
from octopi.entry_points.run_localize import cli as localize
|
|
11
|
+
from octopi.entry_points.run_evaluate import cli as evaluate
|
|
12
|
+
from octopi.entry_points.run_extract_mb_picks import cli as mb_extract
|
|
13
|
+
|
|
14
|
+
@click.group(context_settings=cli_context)
def routines():
    """Octopi 🐙: 🛠️ Tools for Finding Proteins in 🧊 cryo-ET data"""
    # Intentionally empty: the work is done by the subcommands attached below.


# Attach every imported subcommand to the main group, one after another.
for _subcommand in (
    download,
    import_tomograms,
    train_model,
    create_targets,
    inference,
    localize,
    model_explore,
    evaluate,
    mb_extract,
):
    routines.add_command(_subcommand)
del _subcommand
|
|
28
|
+
|
|
29
|
+
@click.group(context_settings=cli_context)
def slurm_routines():
    """Slurm-Octopi 🐙: 🛠️ Tools for Finding Proteins in 🧊 cryo-ET data"""
    # Intentionally empty group body.
    # NOTE(review): this module attaches no subcommands to this group —
    # confirm they are registered elsewhere.
|
|
33
|
+
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from monai.networks.nets import AttentionUnet
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
class myAttentionUnet:
    """Wrapper that builds a MONAI ``AttentionUnet`` and remembers its config."""

    def __init__(self):

        # Placeholder for the model and config
        self.model = None
        self.config = None

    def build_model( self, config: dict ):
        """Creates the AttentionUnet model based on provided parameters.

        Args:
            config (dict): Must provide 'num_classes', 'channels', 'strides'
                and 'dropout'. Stored on ``self.config``.

        Returns:
            The constructed ``AttentionUnet``, also stored on ``self.model``.
        """
        # Store the config so get_model_parameters() can report it.
        self.config = config
        self.model = AttentionUnet(
            spatial_dims=3,
            in_channels=1,
            out_channels=config['num_classes'],
            channels=config['channels'],
            strides=config['strides'],
            dropout=config['dropout']
        )

        return self.model

    def bayesian_search(self, trial, num_classes: int):
        """Defines the Bayesian optimization search space and builds the model with suggested parameters.

        Args:
            trial: Optuna trial used for the ``suggest_*`` calls.
            num_classes (int): Number of output channels of the model.

        Returns:
            The model built from the suggested configuration.
        """

        # Define the search space
        num_layers = trial.suggest_int("num_layers", 3, 5)
        hidden_layers = trial.suggest_int("hidden_layers", 1, 3)
        base_channel = trial.suggest_categorical("base_channel", [8, 16, 32])

        # Channels double at each downsampling level; the extra "hidden" levels
        # repeat the deepest width with stride 1, so len(strides) == len(channels) - 1.
        downsampling_channels = [base_channel * (2 ** i) for i in range(num_layers)]
        hidden_channels = [downsampling_channels[-1]] * hidden_layers
        channels = downsampling_channels + hidden_channels
        strides = [2] * (num_layers - 1) + [1] * hidden_layers
        dropout = trial.suggest_float("dropout", 0.0, 0.2)

        self.config = {
            'architecture': 'AttentionUnet',
            'num_classes': num_classes,
            'channels': channels,
            'strides': strides,
            'dropout': dropout
        }

        return self.build_model(self.config)

    def get_model_parameters(self):
        """Retrieve stored model parameters.

        Returns:
            dict: The config used by the last ``build_model()`` call.

        Raises:
            ValueError: If no model has been built yet.
        """
        if self.model is None:
            raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")

        return self.config
|
octopi/models/MedNeXt.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from monai.networks.nets import MedNeXt
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
class myMedNeXt:
    """Wrapper that builds a MONAI ``MedNeXt`` and records its hyperparameters."""

    # build_model() keyword arguments that may also be supplied through a
    # config dict (see build_model's docstring).
    _BUILD_KEYS = (
        'init_filters',
        'encoder_expansion_ratio',
        'decoder_expansion_ratio',
        'bottleneck_expansion_ratio',
        'kernel_size',
        'deep_supervision',
        'use_residual_connection',
        'norm_type',
        'global_resp_norm',
        'blocks_down',
        'blocks_bottleneck',
        'blocks_up',
    )

    def __init__(self, num_classes=None, device=None):
        # Placeholder for the model and config
        self.model = None
        self.config = None
        # build_model() needs the number of output classes. It can also be
        # supplied later via bayesian_search() or a config dict. When a device
        # is given, the built model is moved onto it.
        self.num_classes = num_classes
        self.device = device

    def build_model(
        self,
        init_filters=32,
        encoder_expansion_ratio=2,
        decoder_expansion_ratio=2,
        bottleneck_expansion_ratio=2,
        kernel_size=7,
        deep_supervision=False,
        use_residual_connection=False,
        norm_type="group",
        global_resp_norm=False,
        blocks_down=(2, 2, 2, 2),
        blocks_bottleneck=2,
        blocks_up=(2, 2, 2, 2),
    ):
        """
        Create the MedNeXt model with the specified hyperparameters.
        Note: For cryoET with small objects, a shallower network might help preserve details.

        For consistency with the other model wrappers, a config dict may be
        passed as the first positional argument instead. Its 'num_classes'
        (if present) sets ``self.num_classes``. Keys matching this method's
        keyword arguments override the defaults. Other keys (e.g.
        'architecture') are ignored.

        Returns:
            The constructed MedNeXt, also stored on ``self.model``. The
            settings used are stored in ``self.config``.

        Raises:
            ValueError: If the number of output classes has not been provided.
        """
        if isinstance(init_filters, dict):
            config = init_filters
            if config.get('num_classes') is not None:
                self.num_classes = config['num_classes']
            overrides = {key: config[key] for key in self._BUILD_KEYS if key in config}
            return self.build_model(**overrides)

        if self.num_classes is None:
            raise ValueError(
                "num_classes is not set. Pass it to myMedNeXt() or bayesian_search(), "
                "or include 'num_classes' in the config dict."
            )

        # Record the settings so get_model_parameters() can report them.
        self.config = {
            'architecture': 'MedNeXt',
            'num_classes': self.num_classes,
            'init_filters': init_filters,
            'encoder_expansion_ratio': encoder_expansion_ratio,
            'decoder_expansion_ratio': decoder_expansion_ratio,
            'bottleneck_expansion_ratio': bottleneck_expansion_ratio,
            'kernel_size': kernel_size,
            'deep_supervision': deep_supervision,
            'use_residual_connection': use_residual_connection,
            'norm_type': norm_type,
            'global_resp_norm': global_resp_norm,
            'blocks_down': blocks_down,
            'blocks_bottleneck': blocks_bottleneck,
            'blocks_up': blocks_up,
        }

        model = MedNeXt(
            spatial_dims=3,
            in_channels=1,
            out_channels=self.num_classes,
            init_filters=init_filters,
            encoder_expansion_ratio=encoder_expansion_ratio,
            decoder_expansion_ratio=decoder_expansion_ratio,
            bottleneck_expansion_ratio=bottleneck_expansion_ratio,
            kernel_size=kernel_size,
            deep_supervision=deep_supervision,
            use_residual_connection=use_residual_connection,
            norm_type=norm_type,
            global_resp_norm=global_resp_norm,
            blocks_down=blocks_down,
            blocks_bottleneck=blocks_bottleneck,
            blocks_up=blocks_up,
        )
        if self.device is not None:
            model = model.to(self.device)
        self.model = model
        return self.model

    def bayesian_search(self, trial, num_classes=None):
        """
        Defines the Bayesian optimization search space and builds the model with suggested parameters.
        The search space has been adapted for cryoET applications:
          - Small kernel sizes (3 or 5) to capture fine details.
          - Choice of a shallower vs. deeper architecture to balance resolution and feature extraction.
          - Robust normalization options for low signal-to-noise data.

        Args:
            trial: Optuna trial used for the ``suggest_*`` calls.
            num_classes (int, optional): Number of output channels. If omitted,
                the value given to the constructor is used.

        Returns:
            The model built from the suggested hyperparameters.
        """
        if num_classes is not None:
            self.num_classes = num_classes

        # Core hyperparameters
        init_filters = trial.suggest_categorical("init_filters", [16, 32])
        encoder_expansion_ratio = trial.suggest_int("encoder_expansion_ratio", 1, 3)
        decoder_expansion_ratio = trial.suggest_int("decoder_expansion_ratio", 1, 3)
        bottleneck_expansion_ratio = trial.suggest_int("bottleneck_expansion_ratio", 1, 4)
        kernel_size = trial.suggest_categorical("kernel_size", [3, 5])
        deep_supervision = trial.suggest_categorical("deep_supervision", [True, False])
        norm_type = trial.suggest_categorical("norm_type", ["group", "instance"])
        # For extremely low SNR, you might opt to disable global response normalization
        global_resp_norm = trial.suggest_categorical("global_resp_norm", [False])

        # Architecture: shallow vs. deep.
        # For small objects, a shallower network (fewer downsampling stages) may preserve spatial detail.
        architecture = trial.suggest_categorical("architecture", ["shallow", "deep"])
        if architecture == "shallow":
            blocks_down = (2, 2, 2)  # 3 downsampling stages
            blocks_bottleneck = 2
            blocks_up = (2, 2, 2)
        else:
            blocks_down = (2, 2, 2, 2)  # 4 downsampling stages
            blocks_bottleneck = 2
            blocks_up = (2, 2, 2, 2)

        return self.build_model(
            init_filters=init_filters,
            encoder_expansion_ratio=encoder_expansion_ratio,
            decoder_expansion_ratio=decoder_expansion_ratio,
            bottleneck_expansion_ratio=bottleneck_expansion_ratio,
            kernel_size=kernel_size,
            deep_supervision=deep_supervision,
            use_residual_connection=True,
            norm_type=norm_type,
            global_resp_norm=global_resp_norm,
            blocks_down=blocks_down,
            blocks_bottleneck=blocks_bottleneck,
            blocks_up=blocks_up,
        )

    def get_model_parameters(self):
        """Retrieve stored model parameters.

        Returns:
            dict: The settings recorded by the last ``build_model()`` call,
            including 'architecture' and 'num_classes'.

        Raises:
            ValueError: If no model has been built yet.
        """
        if self.model is None:
            raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")

        return self.config
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
class myModelTemplate:
    """Template for a model wrapper.

    ``build_model`` and ``bayesian_search`` are stubs here: they do nothing
    and return None.
    """

    def __init__(self):
        """
        Initialize the model template.
        """
        # Nothing has been built yet.
        self.config = None
        self.model = None

    def build_model(self, config: dict):
        """
        Build the model from the parameters in ``config``.

        Template stub: the config is ignored and None is returned.
        """
        return None

    def bayesian_search(self, trial, num_classes: int):
        """
        Define the hyperparameter search space for Bayesian optimization and build the model.

        The search space is meant to be customized per architecture.

        Args:
            trial (optuna.trial.Trial): Optuna trial object.
            num_classes (int): Number of classes in the dataset.

        Template stub: returns None.
        """
        return None

    def get_model_parameters(self):
        """Return the stored config, or raise ValueError if no model exists yet."""
        if self.model is not None:
            return self.config
        raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from monai.networks.nets import SegResNetDS
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
class mySegResNet:
    """Wrapper that builds a MONAI ``SegResNetDS`` and remembers its config."""

    def __init__(self, num_classes=None, device=None):
        # Placeholder for the model and config
        self.model = None
        self.config = None
        # Fallback class count when a config lacks 'num_classes'. When a
        # device is given, the built model is moved onto it.
        self.num_classes = num_classes
        self.device = device

    def build_model(
        self,
        config: dict
    ):
        """
        Creates the SegResNetDS model based on provided parameters.

        Args:
            config (dict): Must contain the keys 'init_filters', 'act', 'norm',
                'blocks_down', 'blocks_up', 'dsdepth', 'preprocess',
                'upsample_mode' and 'resolution'. These are forwarded to
                SegResNetDS:
                init_filters (int): Number of output channels for the initial convolution.
                blocks_down (tuple): Number of blocks at each downsampling stage.
                dsdepth (int): Depth for deep supervision (number of output scales).
                act (str): Activation type.
                norm (str): Normalization type.
                blocks_up (tuple or None): Number of upsample blocks (None keeps the default).
                upsample_mode (str): Upsampling mode, e.g. 'deconv' or 'trilinear'.
                resolution (optional): If provided, adjusts non-isotropic kernels to isotropic spacing.
                preprocess (callable or None): Optional preprocessing function for the input.
                'num_classes' is used when present; otherwise ``self.num_classes``.

        Returns:
            torch.nn.Module: The SegResNetDS model (moved to ``self.device``
            when one was given), also stored on ``self.model``.

        Raises:
            ValueError: If the number of output classes is unknown.
        """
        num_classes = config.get('num_classes')
        if num_classes is None:
            num_classes = self.num_classes
        if num_classes is None:
            raise ValueError(
                "num_classes is not set. Include 'num_classes' in the config, "
                "or pass it to mySegResNet() or bayesian_search()."
            )
        self.num_classes = num_classes
        # Store a copy that always carries the resolved class count.
        self.config = {**config, 'num_classes': num_classes}

        model = SegResNetDS(
            spatial_dims=3,
            init_filters=config['init_filters'],
            in_channels=1,
            out_channels=num_classes,
            act=config['act'],
            norm=config['norm'],
            blocks_down=config['blocks_down'],
            blocks_up=config['blocks_up'],
            dsdepth=config['dsdepth'],
            preprocess=config['preprocess'],
            upsample_mode=config['upsample_mode'],
            resolution=config['resolution']
        )
        if self.device is not None:
            model = model.to(self.device)
        self.model = model
        return self.model

    def bayesian_search(self, trial, num_classes=None):
        """
        Defines the Bayesian optimization search space and builds the model with suggested parameters.

        Args:
            trial (optuna.trial.Trial): An Optuna trial object.
            num_classes (int, optional): Number of output channels. If omitted,
                the value given to the constructor is used.

        Returns:
            torch.nn.Module: The model built with hyperparameters suggested by the trial.
        """
        if num_classes is not None:
            self.num_classes = num_classes

        # Define search space parameters
        init_filters = trial.suggest_categorical("init_filters", [16, 32, 64])
        dsdepth = trial.suggest_int("dsdepth", 1, 3)
        blocks_down = trial.suggest_categorical("blocks_down", [(1, 2, 2, 4), (1, 2, 2, 2), (1, 1, 2, 2)])
        # NOTE(review): confirm every name below is accepted by MONAI's
        # activation factory (e.g. 'leaky_relu' vs 'leakyrelu').
        act = trial.suggest_categorical("act", ['relu', 'leaky_relu', "LeakyReLU", "PReLU", "GELU", "ELU"])
        norm = trial.suggest_categorical("norm", ['batch', 'instance'])
        upsample_mode = trial.suggest_categorical("upsample_mode", ['deconv', 'trilinear'])

        config = {
            'architecture': 'SegResNet',
            'num_classes': self.num_classes,
            'init_filters': init_filters,
            'blocks_down': blocks_down,
            'dsdepth': dsdepth,
            'act': act,
            'norm': norm,
            'blocks_up': None,  # using default upsampling blocks
            'upsample_mode': upsample_mode,
            'resolution': None,
            'preprocess': None
        }
        return self.build_model(config)

    def get_model_parameters(self):
        """
        Retrieve stored model parameters.

        Returns:
            dict: The config used by the last ``build_model()`` call.

        Raises:
            ValueError: If the model has not been built yet.
        """
        if self.model is None:
            raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")

        return self.config
|
octopi/models/Unet.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from monai.networks.nets import UNet
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
class myUNet:
    """Thin wrapper around MONAI ``UNet`` that keeps the config used to build it."""

    def __init__(self):
        # Nothing is built until build_model() or bayesian_search() runs.
        self.config = None
        self.model = None

    def build_model( self, config: dict ):
        """Build a 3-D, single-input-channel UNet from ``config``.

        ``config`` must provide 'num_classes', 'channels', 'strides',
        'num_res_units' and 'dropout'. The config is kept on ``self.config``
        and the network on ``self.model``. The network is returned.
        """
        self.config = config
        unet_kwargs = dict(
            spatial_dims=3,
            in_channels=1,
            out_channels=config['num_classes'],
            channels=config['channels'],
            strides=config['strides'],
            num_res_units=config['num_res_units'],
            dropout=config['dropout'],
        )
        self.model = UNet(**unet_kwargs)
        return self.model

    def bayesian_search(self, trial, num_classes: int):
        """Sample a UNet configuration from ``trial`` and build it.

        Returns the model produced by ``build_model``.
        """
        num_layers = trial.suggest_int("num_layers", 3, 5)
        hidden_layers = trial.suggest_int("hidden_layers", 1, 3)
        base_channel = trial.suggest_categorical("base_channel", [8, 16, 32])
        num_res_units = trial.suggest_int("num_res_units", 1, 3)
        dropout = trial.suggest_float("dropout", 0.0, 0.3)

        # Widths double per downsampling level; the stride-1 "hidden" levels
        # repeat the deepest width.
        channels = [base_channel * 2 ** level for level in range(num_layers)]
        channels += [channels[-1]] * hidden_layers
        strides = [2] * (num_layers - 1) + [1] * hidden_layers

        self.config = dict(
            architecture='Unet',
            num_classes=num_classes,
            channels=channels,
            strides=strides,
            num_res_units=num_res_units,
            dropout=dropout,
        )
        return self.build_model(self.config)

    def get_model_parameters(self):
        """Return the stored config, or raise ValueError if no model exists yet."""
        if self.model is not None:
            return self.config
        raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from monai.networks.nets import BasicUNetPlusPlus
|
|
2
|
+
|
|
3
|
+
class myUNetPlusPlus:
    """Wrapper that builds a MONAI ``BasicUNetPlusPlus`` with deep supervision enabled."""

    def __init__(self, num_classes, device):
        self.device = device
        self.num_classes = num_classes
        # Populated by build_model(); read by get_model_parameters().
        self.model = None
        self.features = None
        self.dropout = None
        self.upsample = None
        self.activation = None

    def build_model(
        self,
        features=(16, 16, 32, 64, 128, 16),
        dropout=0.2,
        upsample='deconv',
        activation='relu'
    ):
        """Create the BasicUNetPlusPlus model and remember its hyperparameters.

        Args:
            features: Feature widths forwarded as ``features``.
            dropout: Dropout probability.
            upsample: Upsampling mode forwarded as ``upsample``.
            activation: Activation forwarded as ``act``.

        Returns:
            The model (moved to ``self.device`` when one is set), also stored
            on ``self.model``.
        """
        self.features = features
        self.dropout = dropout
        self.upsample = upsample
        self.activation = activation

        model = BasicUNetPlusPlus(
            spatial_dims=3,
            in_channels=1,
            out_channels=self.num_classes,
            deep_supervision=True,
            features=features,
            dropout=dropout,
            upsample=upsample,
            act=activation
        )
        if self.device is not None:
            model = model.to(self.device)
        self.model = model
        return self.model

    def bayesian_search(self, trial):
        """Suggest activation, dropout and upsampling mode from ``trial``, then build the model.

        Feature widths are not searched; ``build_model``'s default is used.
        """
        act = trial.suggest_categorical("activation", ["LeakyReLU", "PReLU", "GELU", "ELU"])
        dropout_prob = trial.suggest_float("dropout", 0.0, 0.5)
        upsample = trial.suggest_categorical("upsample", ["deconv", "pixelshuffle", "nontrainable"])

        return self.build_model(dropout=dropout_prob, upsample=upsample, activation=act)

    def get_model_parameters(self):
        """Return the hyperparameters used by the last ``build_model()`` call.

        Raises:
            ValueError: If no model has been built yet.
        """
        if self.model is None:
            raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")

        return {
            'model_name': 'UnetPlusPlus',
            'features': self.features,
            'dropout': self.dropout,
            'upsample': self.upsample,
            'activation': self.activation
        }
|
|
File without changes
|
octopi/models/common.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from monai.losses import FocalLoss, TverskyLoss
|
|
2
|
+
from octopi.utils import losses
|
|
3
|
+
from octopi.models import (
|
|
4
|
+
Unet, AttentionUnet, MedNeXt, SegResNet
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
def get_model(architecture):
    """Return a fresh, not-yet-built model wrapper for ``architecture``.

    Supported names: "Unet", "AttentionUnet", "MedNeXt", "SegResNet".

    Raises:
        ValueError: If ``architecture`` is not one of the supported names.
    """
    # Each constructor is wrapped in a lambda so only the requested wrapper
    # is instantiated.
    candidates = (
        ("Unet", lambda: Unet.myUNet()),
        ("AttentionUnet", lambda: AttentionUnet.myAttentionUnet()),
        ("MedNeXt", lambda: MedNeXt.myMedNeXt()),
        ("SegResNet", lambda: SegResNet.mySegResNet()),
    )
    for name, make in candidates:
        if architecture == name:
            return make()
    raise ValueError(f"Model type {architecture} not supported!\nPlease use one of the following: Unet, AttentionUnet, MedNeXt, SegResNet")
|
|
22
|
+
|
|
23
|
+
def get_loss_function(trial, loss_name = None):
    """Build a segmentation loss whose hyperparameters are drawn from an Optuna trial.

    Args:
        trial: Optuna trial used to suggest the loss hyperparameters, and the
            loss type itself when ``loss_name`` is None.
        loss_name (str, optional): One of "FocalLoss", "TverskyLoss",
            "WeightedFocalTverskyLoss" or "FocalTverskyLoss". When None, the
            trial chooses among "FocalLoss", "WeightedFocalTverskyLoss" and
            "FocalTverskyLoss".

    Returns:
        The instantiated loss object.

    Raises:
        ValueError: If ``loss_name`` is not a supported loss name. The trial
            is not consulted in that case.
    """
    # Loss function selection
    if loss_name is None:
        loss_name = trial.suggest_categorical(
            "loss_function",
            ["FocalLoss", "WeightedFocalTverskyLoss", 'FocalTverskyLoss'])

    if loss_name == "FocalLoss":
        gamma = round(trial.suggest_float("gamma", 0.1, 2), 3)
        loss_function = FocalLoss(include_background=True, to_onehot_y=True, use_softmax=True, gamma=gamma)

    elif loss_name == "TverskyLoss":
        # beta is tied to alpha so only one weight is searched.
        alpha = round(trial.suggest_float("alpha", 0.1, 0.5), 3)
        beta = 1.0 - alpha
        loss_function = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True, alpha=alpha, beta=beta)

    elif loss_name == 'WeightedFocalTverskyLoss':
        gamma = round(trial.suggest_float("gamma", 0.1, 2), 3)
        alpha = round(trial.suggest_float("alpha", 0.1, 0.5), 3)
        beta = 1.0 - alpha
        # The two component weights are constrained to sum to 1.
        weight_tversky = round(trial.suggest_float("weight_tversky", 0.1, 0.9), 3)
        weight_focal = 1.0 - weight_tversky
        loss_function = losses.WeightedFocalTverskyLoss(
            gamma=gamma, alpha=alpha, beta=beta,
            weight_tversky=weight_tversky, weight_focal=weight_focal
        )

    elif loss_name == 'FocalTverskyLoss':
        gamma = round(trial.suggest_float("gamma", 0.1, 2), 3)
        alpha = round(trial.suggest_float("alpha", 0.1, 0.5), 3)
        beta = 1.0 - alpha
        loss_function = losses.FocalTverskyLoss(gamma=gamma, alpha=alpha, beta=beta)

    else:
        # Fail loudly instead of reaching `return` with loss_function unbound.
        raise ValueError(
            f"Loss function {loss_name} not supported!\n"
            "Please use one of the following: FocalLoss, TverskyLoss, "
            "WeightedFocalTverskyLoss, FocalTverskyLoss"
        )

    return loss_function
|
|
58
|
+
|
|
59
|
+
def get_default_unet_params():
    """Return a new dict holding the default UNet hyperparameters.

    The dict has no 'num_classes' key. ``myUNet.build_model`` reads one, so
    presumably the caller adds it before building.
    """
    return dict(
        architecture='Unet',
        dim_in=80,
        strides=[2, 2, 1],
        channels=[48, 64, 80, 80],
        dropout=0.0,
        num_res_units=1,
    )
|
|
69
|
+
|
|
70
|
+
#### TODO: Models to consider adding
|
|
71
|
+
# 1. Swin UNETR
|
|
72
|
+
# 2. Swin-Conv-UNet
|
|
File without changes
|