octopi-1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of octopi might be problematic.
- octopi/__init__.py +0 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +84 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +429 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +253 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +80 -0
- octopi/entry_points/create_slurm_submission.py +243 -0
- octopi/entry_points/run_create_targets.py +281 -0
- octopi/entry_points/run_evaluate.py +65 -0
- octopi/entry_points/run_extract_mb_picks.py +141 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +222 -0
- octopi/entry_points/run_optuna.py +139 -0
- octopi/entry_points/run_segment_predict.py +166 -0
- octopi/entry_points/run_train.py +201 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +254 -0
- octopi/extract/membranebound_extract.py +262 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/io.py +457 -0
- octopi/losses.py +86 -0
- octopi/main.py +101 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +62 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +106 -0
- octopi/processing/downsample.py +129 -0
- octopi/processing/evaluate.py +289 -0
- octopi/processing/importers.py +213 -0
- octopi/processing/my_metrics.py +26 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/processing/writers.py +102 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +243 -0
- octopi/pytorch/model_search_submitter.py +290 -0
- octopi/pytorch/segmentation.py +317 -0
- octopi/pytorch/trainer.py +438 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/stopping_criteria.py +143 -0
- octopi/submit_slurm.py +95 -0
- octopi/utils.py +238 -0
- octopi/visualization_tools.py +201 -0
- octopi-1.0.dist-info/LICENSE +41 -0
- octopi-1.0.dist-info/METADATA +209 -0
- octopi-1.0.dist-info/RECORD +59 -0
- octopi-1.0.dist-info/WHEEL +4 -0
- octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/models/AttentionUnet.py
ADDED

@@ -0,0 +1,56 @@
from monai.networks.nets import AttentionUnet
import torch.nn as nn
import torch

class myAttentionUnet:
    def __init__(self):

        # Placeholder for the model and config
        self.model = None
        self.config = None

    def build_model(self, config: dict):
        """Creates the AttentionUnet model based on provided parameters."""

        # Store the config so get_model_parameters() can return it
        self.config = config
        self.model = AttentionUnet(
            spatial_dims=3,
            in_channels=1,
            out_channels=config['num_classes'],
            channels=config['channels'],
            strides=config['strides'],
            dropout=config['dropout']
        )

        return self.model

    def bayesian_search(self, trial, num_classes: int):
        """Defines the Bayesian optimization search space and builds the model with suggested parameters."""

        # Define the search space
        num_layers = trial.suggest_int("num_layers", 3, 5)
        hidden_layers = trial.suggest_int("hidden_layers", 1, 3)
        base_channel = trial.suggest_categorical("base_channel", [8, 16, 32])

        # Create channel sizes and strides
        downsampling_channels = [base_channel * (2 ** i) for i in range(num_layers)]
        hidden_channels = [downsampling_channels[-1]] * hidden_layers
        channels = downsampling_channels + hidden_channels
        strides = [2] * (num_layers - 1) + [1] * hidden_layers
        dropout = trial.suggest_float("dropout", 0.0, 0.2)

        config = {
            'architecture': 'AttentionUnet',
            'num_classes': num_classes,
            'channels': channels,
            'strides': strides,
            'dropout': dropout
        }

        return self.build_model(config)

    def get_model_parameters(self):
        """Retrieve stored model parameters."""
        if self.model is None:
            raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")

        return self.config
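The channel/stride construction in bayesian_search is easiest to follow with concrete numbers. Below is a minimal sketch (not part of the package) that replays one trial with hypothetical parameter values through Optuna's FixedTrial:

import optuna
from octopi.models.AttentionUnet import myAttentionUnet

# Hypothetical fixed parameters; FixedTrial feeds them back through the
# suggest_* calls so the resulting layout is easy to inspect.
trial = optuna.trial.FixedTrial({
    "num_layers": 4,     # downsampling widths: 16, 32, 64, 128
    "hidden_layers": 2,  # two extra stages at the deepest width: 128, 128
    "base_channel": 16,
    "dropout": 0.1,
})

builder = myAttentionUnet()
model = builder.bayesian_search(trial, num_classes=3)
print(builder.get_model_parameters())
# channels -> [16, 32, 64, 128, 128, 128], strides -> [2, 2, 2, 1, 1]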
octopi/models/MedNeXt.py
ADDED
@@ -0,0 +1,111 @@
from monai.networks.nets import MedNeXt
import torch.nn as nn
import torch

class myMedNeXt:
    def __init__(self):
        # Placeholder for the model and config
        self.model = None
        self.config = None

    def build_model(
        self,
        num_classes,
        init_filters=32,
        encoder_expansion_ratio=2,
        decoder_expansion_ratio=2,
        bottleneck_expansion_ratio=2,
        kernel_size=7,
        deep_supervision=False,
        use_residual_connection=False,
        norm_type="group",
        global_resp_norm=False,
        blocks_down=(2, 2, 2, 2),
        blocks_bottleneck=2,
        blocks_up=(2, 2, 2, 2),
    ):
        """
        Create the MedNeXt model with the specified hyperparameters.
        Note: For cryoET with small objects, a shallower network might help preserve details.
        """
        self.model = MedNeXt(
            spatial_dims=3,
            in_channels=1,
            out_channels=num_classes,
            init_filters=init_filters,
            encoder_expansion_ratio=encoder_expansion_ratio,
            decoder_expansion_ratio=decoder_expansion_ratio,
            bottleneck_expansion_ratio=bottleneck_expansion_ratio,
            kernel_size=kernel_size,
            deep_supervision=deep_supervision,
            use_residual_connection=use_residual_connection,
            norm_type=norm_type,
            global_resp_norm=global_resp_norm,
            blocks_down=blocks_down,
            blocks_bottleneck=blocks_bottleneck,
            blocks_up=blocks_up,
        )

        # Record the configuration so get_model_parameters() can return it
        self.config = {
            'architecture': 'MedNeXt',
            'num_classes': num_classes,
            'init_filters': init_filters,
            'encoder_expansion_ratio': encoder_expansion_ratio,
            'decoder_expansion_ratio': decoder_expansion_ratio,
            'bottleneck_expansion_ratio': bottleneck_expansion_ratio,
            'kernel_size': kernel_size,
            'deep_supervision': deep_supervision,
            'use_residual_connection': use_residual_connection,
            'norm_type': norm_type,
            'global_resp_norm': global_resp_norm,
            'blocks_down': blocks_down,
            'blocks_bottleneck': blocks_bottleneck,
            'blocks_up': blocks_up,
        }
        return self.model

    def bayesian_search(self, trial, num_classes: int):
        """
        Defines the Bayesian optimization search space and builds the model with suggested parameters.
        The search space has been adapted for cryoET applications:
        - Small kernel sizes (3 or 5) to capture fine details.
        - Choice of a shallower vs. deeper architecture to balance resolution and feature extraction.
        - Robust normalization options for low signal-to-noise data.
        """
        # Core hyperparameters
        init_filters = trial.suggest_categorical("init_filters", [16, 32])
        encoder_expansion_ratio = trial.suggest_int("encoder_expansion_ratio", 1, 3)
        decoder_expansion_ratio = trial.suggest_int("decoder_expansion_ratio", 1, 3)
        bottleneck_expansion_ratio = trial.suggest_int("bottleneck_expansion_ratio", 1, 4)
        kernel_size = trial.suggest_categorical("kernel_size", [3, 5])
        deep_supervision = trial.suggest_categorical("deep_supervision", [True, False])
        norm_type = trial.suggest_categorical("norm_type", ["group", "instance"])
        # For extremely low SNR, global response normalization is kept disabled
        global_resp_norm = trial.suggest_categorical("global_resp_norm", [False])

        # Architecture: shallow vs. deep.
        # For small objects, a shallower network (fewer downsampling stages) may preserve spatial detail.
        architecture = trial.suggest_categorical("architecture", ["shallow", "deep"])
        if architecture == "shallow":
            blocks_down = (2, 2, 2)  # 3 downsampling stages
            blocks_bottleneck = 2
            blocks_up = (2, 2, 2)
        else:
            blocks_down = (2, 2, 2, 2)  # 4 downsampling stages
            blocks_bottleneck = 2
            blocks_up = (2, 2, 2, 2)

        return self.build_model(
            num_classes,
            init_filters=init_filters,
            encoder_expansion_ratio=encoder_expansion_ratio,
            decoder_expansion_ratio=decoder_expansion_ratio,
            bottleneck_expansion_ratio=bottleneck_expansion_ratio,
            kernel_size=kernel_size,
            deep_supervision=deep_supervision,
            use_residual_connection=True,
            norm_type=norm_type,
            global_resp_norm=global_resp_norm,
            blocks_down=blocks_down,
            blocks_bottleneck=blocks_bottleneck,
            blocks_up=blocks_up,
        )

    def get_model_parameters(self):
        """Retrieve stored model parameters."""
        if self.model is None:
            raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")

        return self.config
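As a quick check of the shallow/deep switch, here is a hedged sketch (hypothetical values, assuming the bayesian_search(trial, num_classes) signature above) that replays one "shallow" trial:

import optuna
from octopi.models.MedNeXt import myMedNeXt

trial = optuna.trial.FixedTrial({
    "init_filters": 16,
    "encoder_expansion_ratio": 2,
    "decoder_expansion_ratio": 2,
    "bottleneck_expansion_ratio": 2,
    "kernel_size": 3,
    "deep_supervision": False,
    "norm_type": "group",
    "global_resp_norm": False,
    "architecture": "shallow",  # 3 downsampling stages instead of 4
})

builder = myMedNeXt()
model = builder.bayesian_search(trial, num_classes=3)
print(builder.get_model_parameters()['blocks_down'])  # (2, 2, 2)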
octopi/models/ModelTemplate.py
ADDED

@@ -0,0 +1,36 @@
import torch.nn as nn
import torch

class myModelTemplate:
    def __init__(self):
        """
        Initialize the model template.
        """
        # Placeholder for the model and config
        self.model = None
        self.config = None

    def build_model(self, config: dict):
        """
        Build the model based on provided parameters in a config dictionary.
        """
        pass

    def bayesian_search(self, trial, num_classes: int):
        """
        Define the hyperparameter search space for Bayesian optimization and build the model.

        The search space below is just an example and can be customized.

        Args:
            trial (optuna.trial.Trial): Optuna trial object.
            num_classes (int): Number of classes in the dataset.
        """
        pass

    def get_model_parameters(self):
        """Retrieve stored model parameters."""
        if self.model is None:
            raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")

        return self.config
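To make the template's contract concrete, here is an illustrative implementation of the three methods for a toy network (the class below is hypothetical and not shipped with the package):

import torch.nn as nn

class myTinyConvNet:
    def __init__(self):
        self.model = None
        self.config = None

    def build_model(self, config: dict):
        # Store the config and build a minimal 3D conv head
        self.config = config
        self.model = nn.Sequential(
            nn.Conv3d(1, config['width'], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(config['width'], config['num_classes'], kernel_size=1),
        )
        return self.model

    def bayesian_search(self, trial, num_classes: int):
        # A one-parameter search space, just to show the pattern
        width = trial.suggest_categorical("width", [8, 16, 32])
        return self.build_model({'architecture': 'TinyConvNet',
                                 'num_classes': num_classes,
                                 'width': width})

    def get_model_parameters(self):
        if self.model is None:
            raise ValueError("Model has not been initialized yet.")
        return self.config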
octopi/models/SegResNet.py
ADDED

@@ -0,0 +1,92 @@
from monai.networks.nets import SegResNetDS
import torch.nn as nn
import torch

class mySegResNet:
    def __init__(self):
        # Placeholder for the model and config
        self.model = None
        self.config = None

    def build_model(self, config: dict):
        """
        Creates the SegResNetDS model based on provided parameters.

        Args:
            config (dict): Dictionary with the following keys:
                num_classes (int): Number of output classes.
                init_filters (int): Number of output channels for the initial convolution.
                blocks_down (tuple): Number of blocks at each downsampling stage.
                dsdepth (int): Depth for deep supervision (number of output scales).
                act (str): Activation type.
                norm (str): Normalization type.
                blocks_up (tuple or None): Number of upsample blocks (if None, uses default behavior).
                upsample_mode (str): Upsampling mode, e.g. 'deconv' or 'trilinear'.
                resolution (optional): If provided, adjusts non-isotropic kernels to isotropic spacing.
                preprocess (callable or None): Optional preprocessing function for the input.

        Returns:
            torch.nn.Module: The instantiated SegResNetDS model.
        """
        self.config = config
        self.model = SegResNetDS(
            spatial_dims=3,
            init_filters=config['init_filters'],
            in_channels=1,
            out_channels=config['num_classes'],
            act=config['act'],
            norm=config['norm'],
            blocks_down=config['blocks_down'],
            blocks_up=config['blocks_up'],
            dsdepth=config['dsdepth'],
            preprocess=config['preprocess'],
            upsample_mode=config['upsample_mode'],
            resolution=config['resolution']
        )
        return self.model

    def bayesian_search(self, trial, num_classes: int):
        """
        Defines the Bayesian optimization search space and builds the model with suggested parameters.

        Args:
            trial (optuna.trial.Trial): An Optuna trial object.
            num_classes (int): Number of classes in the dataset.

        Returns:
            torch.nn.Module: The model built with hyperparameters suggested by the trial.
        """
        # Define search space parameters
        init_filters = trial.suggest_categorical("init_filters", [16, 32, 64])
        dsdepth = trial.suggest_int("dsdepth", 1, 3)
        blocks_down = trial.suggest_categorical("blocks_down", [(1, 2, 2, 4), (1, 2, 2, 2), (1, 1, 2, 2)])
        act = trial.suggest_categorical("act", ['relu', 'leaky_relu', "LeakyReLU", "PReLU", "GELU", "ELU"])
        norm = trial.suggest_categorical("norm", ['batch', 'instance'])
        upsample_mode = trial.suggest_categorical("upsample_mode", ['deconv', 'trilinear'])

        self.config = {
            'architecture': 'SegResNet',
            'num_classes': num_classes,
            'init_filters': init_filters,
            'blocks_down': blocks_down,
            'dsdepth': dsdepth,
            'act': act,
            'norm': norm,
            'blocks_up': None,  # using default upsampling blocks
            'upsample_mode': upsample_mode,
            'resolution': None,
            'preprocess': None
        }
        return self.build_model(self.config)

    def get_model_parameters(self):
        """
        Retrieve stored model parameters.

        Returns:
            dict: A dictionary of key model parameters.

        Raises:
            ValueError: If the model has not been built yet.
        """
        if self.model is None:
            raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")

        return self.config
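Since build_model consumes a single config dictionary, a direct call documents the expected keys; the values below are illustrative only:

from octopi.models.SegResNet import mySegResNet

config = {
    'num_classes': 3,
    'init_filters': 32,
    'blocks_down': (1, 2, 2, 4),
    'dsdepth': 2,
    'act': 'relu',
    'norm': 'instance',
    'blocks_up': None,        # default upsampling blocks
    'upsample_mode': 'deconv',
    'resolution': None,
    'preprocess': None,
}

model = mySegResNet().build_model(config)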
octopi/models/Unet.py
ADDED
@@ -0,0 +1,59 @@
from monai.networks.nets import UNet
import torch.nn as nn
import torch

class myUNet:
    def __init__(self):
        # Placeholder for the model and config
        self.model = None
        self.config = None

    def build_model(self, config: dict):
        """Creates the Unet model based on provided parameters."""

        self.config = config
        self.model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=config['num_classes'],
            channels=config['channels'],
            strides=config['strides'],
            num_res_units=config['num_res_units'],
            dropout=config['dropout']
        )
        return self.model

    def bayesian_search(self, trial, num_classes: int):
        """Defines the Bayesian optimization search space and builds the model with suggested parameters."""

        # Define the search space
        num_layers = trial.suggest_int("num_layers", 3, 5)
        hidden_layers = trial.suggest_int("hidden_layers", 1, 3)
        base_channel = trial.suggest_categorical("base_channel", [8, 16, 32])
        num_res_units = trial.suggest_int("num_res_units", 1, 3)
        dropout = trial.suggest_float("dropout", 0.0, 0.3)

        # Create channel sizes and strides
        downsampling_channels = [base_channel * (2 ** i) for i in range(num_layers)]
        hidden_channels = [downsampling_channels[-1]] * hidden_layers
        channels = downsampling_channels + hidden_channels
        strides = [2] * (num_layers - 1) + [1] * hidden_layers

        # Create config dictionary
        self.config = {
            'architecture': 'Unet',
            'num_classes': num_classes,
            'channels': channels,
            'strides': strides,
            'num_res_units': num_res_units,
            'dropout': dropout
        }

        return self.build_model(self.config)

    def get_model_parameters(self):
        """Retrieve stored model parameters."""
        if self.model is None:
            raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")

        return self.config
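Storing the config is what makes a searched model reproducible; here is a sketch (illustrative values) of rebuilding the same layout later from recorded parameters:

from octopi.models.Unet import myUNet

params = {
    'architecture': 'Unet',
    'num_classes': 3,
    'channels': [16, 32, 64, 64],  # last stage repeats the deepest width
    'strides': [2, 2, 1],
    'num_res_units': 2,
    'dropout': 0.1,
}

model = myUNet().build_model(params)  # same layout as the recorded trial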
octopi/models/UnetPlusPlus.py
ADDED

@@ -0,0 +1,47 @@
from monai.networks.nets import BasicUNetPlusPlus

class myUNetPlusPlus:
    def __init__(self, num_classes, device):
        self.device = device
        self.num_classes = num_classes

    def build_model(
        self,
        features=(16, 16, 32, 64, 128, 16),
        dropout=0.2,
        upsample='deconv',
        activation='relu'
    ):
        # Store the parameters so get_model_parameters() can report them
        self.features = features
        self.dropout = dropout
        self.upsample = upsample
        self.activation = activation

        model = BasicUNetPlusPlus(
            spatial_dims=3,
            in_channels=1,
            out_channels=self.num_classes,
            deep_supervision=True,
            features=features,
            dropout=dropout,
            upsample=upsample,
            act=activation
        )

        return model.to(self.device)

    def bayesian_search(self, trial):
        act = trial.suggest_categorical("activation", ["LeakyReLU", "PReLU", "GELU", "ELU"])
        dropout_prob = trial.suggest_float("dropout", 0.0, 0.5)
        upsample = trial.suggest_categorical("upsample", ["deconv", "pixelshuffle", "nontrainable"])

        # Build with the default feature widths and the suggested hyperparameters
        model = self.build_model(dropout=dropout_prob, upsample=upsample, activation=act)
        return model

    def get_model_parameters(self):
        return {
            'model_name': 'UnetPlusPlus',
            'features': self.features,
            'dropout': self.dropout,
            'upsample': self.upsample,
            'activation': self.activation
        }
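Note that deep_supervision=True changes the forward contract: to my understanding, MONAI's BasicUNetPlusPlus then returns a list of logits, one per supervision head. A small sketch (shapes illustrative):

import torch
from octopi.models.UnetPlusPlus import myUNetPlusPlus

builder = myUNetPlusPlus(num_classes=3, device='cpu')
model = builder.build_model()

# Forward pass returns a list of tensors, one per supervision head
outputs = model(torch.randn(1, 1, 64, 64, 64))
print([tuple(o.shape) for o in outputs])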
octopi/models/__init__.py
File without changes
octopi/models/common.py
ADDED

@@ -0,0 +1,62 @@
from monai.losses import FocalLoss, TverskyLoss
from octopi import losses
from octopi.models import (
    Unet, AttentionUnet, MedNeXt, SegResNet
)

def get_model(architecture):

    # Initialize model based on architecture
    if architecture == "Unet":
        model = Unet.myUNet()
    elif architecture == "AttentionUnet":
        model = AttentionUnet.myAttentionUnet()
    elif architecture == "MedNeXt":
        model = MedNeXt.myMedNeXt()
    elif architecture == "SegResNet":
        model = SegResNet.mySegResNet()
    else:
        raise ValueError(f"Model type {architecture} not supported!\nPlease use one of the following: Unet, AttentionUnet, MedNeXt, SegResNet")

    return model

def get_loss_function(trial, loss_name=None):

    # Loss function selection
    if loss_name is None:
        loss_name = trial.suggest_categorical(
            "loss_function",
            ["FocalLoss", "WeightedFocalTverskyLoss", "FocalTverskyLoss"])

    if loss_name == "FocalLoss":
        gamma = round(trial.suggest_float("gamma", 0.1, 2), 3)
        loss_function = FocalLoss(include_background=True, to_onehot_y=True, use_softmax=True, gamma=gamma)

    elif loss_name == "TverskyLoss":
        alpha = round(trial.suggest_float("alpha", 0.1, 0.5), 3)
        beta = 1.0 - alpha
        loss_function = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True, alpha=alpha, beta=beta)

    elif loss_name == "WeightedFocalTverskyLoss":
        gamma = round(trial.suggest_float("gamma", 0.1, 2), 3)
        alpha = round(trial.suggest_float("alpha", 0.1, 0.5), 3)
        beta = 1.0 - alpha
        weight_tversky = round(trial.suggest_float("weight_tversky", 0.1, 0.9), 3)
        weight_focal = 1.0 - weight_tversky
        loss_function = losses.WeightedFocalTverskyLoss(
            gamma=gamma, alpha=alpha, beta=beta,
            weight_tversky=weight_tversky, weight_focal=weight_focal
        )

    elif loss_name == "FocalTverskyLoss":
        gamma = round(trial.suggest_float("gamma", 0.1, 2), 3)
        alpha = round(trial.suggest_float("alpha", 0.1, 0.5), 3)
        beta = 1.0 - alpha
        loss_function = losses.FocalTverskyLoss(gamma=gamma, alpha=alpha, beta=beta)

    else:
        raise ValueError(f"Loss function {loss_name} not supported!")

    return loss_function


#### TODO : Models to try Adding?
# 1. Swin UNETR
# 2. Swin-Conv-UNet
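A minimal sketch of how these two helpers compose inside an Optuna objective (the study setup and returned metric are placeholders, not the package's actual training loop):

import optuna
from octopi.models import common

def objective(trial):
    # Pick a builder, then let the trial choose its hyperparameters
    model_builder = common.get_model("Unet")
    model = model_builder.bayesian_search(trial, num_classes=3)

    # The loss and its parameters are sampled from the same trial
    loss_fn = common.get_loss_function(trial)

    # ... train `model` with `loss_fn` and return the validation metric ...
    return 0.0  # placeholder

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=1)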
octopi/processing/__init__.py
File without changes
octopi/processing/create_targets_from_picks.py
ADDED

@@ -0,0 +1,106 @@
from octopi.processing.segmentation_from_picks import from_picks
import octopi.processing.writers as write
from octopi import io
from typing import List
from tqdm import tqdm
import numpy as np

def generate_targets(
    root,
    train_targets: dict,
    voxel_size: float = 10,
    tomo_algorithm: str = 'wbp',
    radius_scale: float = 0.8,
    target_segmentation_name: str = 'targets',
    target_user_name: str = 'monai',
    target_session_id: str = '1',
    run_ids: List[str] = None,
):
    """
    Generate segmentation targets from picks in a CoPick project.

    Args:
        root: CoPick project root.
        train_targets (dict): Target objects keyed by name; each entry provides
            user_id, session_id, label, radius, and is_particle_target.
        voxel_size (float): Voxel size for tomogram reconstruction.
        tomo_algorithm (str): Tomogram reconstruction algorithm.
        radius_scale (float): Scale factor for target object radius.
        target_segmentation_name (str): Name for the target segmentation.
        target_user_name (str): User name associated with the target segmentation.
        target_session_id (str): Session ID for the target segmentation.
        run_ids (List[str]): Runs to process; defaults to all runs in the project.
    """

    # Default session ID to 1 if not provided
    if target_session_id is None:
        target_session_id = '1'

    print('Creating Targets for the following objects:', ', '.join(train_targets.keys()))

    # Get Target Names
    target_names = list(train_targets.keys())

    # If runIDs are not provided, load all runs
    if run_ids is None:
        run_ids = [run.name for run in root.runs]

    # Iterate Over All Runs
    for runID in tqdm(run_ids):

        # Get Run
        numPicks = 0
        run = root.get_run(runID)

        # Get Tomogram
        tomo = io.get_tomogram_array(run, voxel_size, tomo_algorithm)

        # Initialize Target Volume
        target = np.zeros(tomo.shape, dtype=np.uint8)

        # Generate Targets
        # Applicable segmentations
        query_seg = []
        for target_name in target_names:
            if not train_targets[target_name]["is_particle_target"]:
                query_seg += run.get_segmentations(
                    name=target_name,
                    user_id=train_targets[target_name]["user_id"],
                    session_id=train_targets[target_name]["session_id"],
                    voxel_size=voxel_size
                )

        # Add Segmentations to Target
        for seg in query_seg:
            classLabel = root.get_object(seg.name).label
            segvol = seg.numpy()
            # Merge this segmentation into the target: label only the voxels
            # it covers, without overwriting labels from other segmentations
            target[segvol > 0] = classLabel

        # Applicable picks
        query = []
        for target_name in target_names:
            if train_targets[target_name]["is_particle_target"]:
                query += run.get_picks(
                    object_name=target_name,
                    user_id=train_targets[target_name]["user_id"],
                    session_id=train_targets[target_name]["session_id"],
                )

        # Add Picks to Target
        for pick in query:
            numPicks += len(pick.points)
            target = from_picks(pick,
                                target,
                                train_targets[pick.pickable_object_name]['radius'] * radius_scale,
                                train_targets[pick.pickable_object_name]['label'],
                                voxel_size)

        # Write Segmentation for non-empty targets
        if target.max() > 0 and numPicks > 0:
            tqdm.write(f'Annotating {numPicks} picks in {runID}...')
            write.segmentation(run, target, target_user_name,
                               name=target_segmentation_name, session_id=target_session_id,
                               voxel_size=voxel_size)

    print('Creation of targets complete!')
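A hedged end-to-end sketch of calling generate_targets; the copick config path and target metadata below are illustrative, and it assumes the installed copick version provides copick.from_file:

import copick
from octopi.processing.create_targets_from_picks import generate_targets

# Open the CoPick project (path is illustrative)
root = copick.from_file("copick_config.json")

train_targets = {
    "ribosome": {            # example object; names/labels depend on the project
        "user_id": "curation",
        "session_id": "0",
        "label": 1,
        "radius": 120,       # same units as the project's pick radii
        "is_particle_target": True,
    },
}

generate_targets(root, train_targets, voxel_size=10, tomo_algorithm="wbp")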