octopi-1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. octopi/__init__.py +7 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +83 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +458 -0
  7. octopi/datasets/io.py +200 -0
  8. octopi/datasets/mixup.py +49 -0
  9. octopi/datasets/multi_config_generator.py +252 -0
  10. octopi/entry_points/__init__.py +0 -0
  11. octopi/entry_points/common.py +119 -0
  12. octopi/entry_points/create_slurm_submission.py +251 -0
  13. octopi/entry_points/groups.py +152 -0
  14. octopi/entry_points/run_create_targets.py +234 -0
  15. octopi/entry_points/run_evaluate.py +99 -0
  16. octopi/entry_points/run_extract_mb_picks.py +191 -0
  17. octopi/entry_points/run_extract_midpoint.py +143 -0
  18. octopi/entry_points/run_localize.py +176 -0
  19. octopi/entry_points/run_optuna.py +161 -0
  20. octopi/entry_points/run_segment.py +154 -0
  21. octopi/entry_points/run_train.py +189 -0
  22. octopi/extract/__init__.py +0 -0
  23. octopi/extract/localize.py +217 -0
  24. octopi/extract/membranebound_extract.py +263 -0
  25. octopi/extract/midpoint_extract.py +193 -0
  26. octopi/main.py +33 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +72 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +224 -0
  37. octopi/processing/downloader.py +138 -0
  38. octopi/processing/downsample.py +125 -0
  39. octopi/processing/evaluate.py +302 -0
  40. octopi/processing/importers.py +116 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/pytorch/__init__.py +0 -0
  43. octopi/pytorch/hyper_search.py +244 -0
  44. octopi/pytorch/model_search_submitter.py +291 -0
  45. octopi/pytorch/segmentation.py +363 -0
  46. octopi/pytorch/segmentation_multigpu.py +162 -0
  47. octopi/pytorch/trainer.py +465 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/utils/__init__.py +0 -0
  52. octopi/utils/config.py +57 -0
  53. octopi/utils/io.py +215 -0
  54. octopi/utils/losses.py +86 -0
  55. octopi/utils/parsers.py +162 -0
  56. octopi/utils/progress.py +78 -0
  57. octopi/utils/stopping_criteria.py +143 -0
  58. octopi/utils/submit_slurm.py +95 -0
  59. octopi/utils/visualization_tools.py +290 -0
  60. octopi/workflows.py +262 -0
  61. octopi-1.4.0.dist-info/METADATA +119 -0
  62. octopi-1.4.0.dist-info/RECORD +65 -0
  63. octopi-1.4.0.dist-info/WHEEL +4 -0
  64. octopi-1.4.0.dist-info/entry_points.txt +3 -0
  65. octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/main.py ADDED
@@ -0,0 +1,33 @@
+ import rich_click as click
+ from octopi import cli_context
+ from octopi.entry_points import groups
+ from octopi.processing.downloader import cli as download
+ from octopi.processing.importers import cli as import_tomograms
+ from octopi.entry_points.run_train import cli as train_model
+ from octopi.entry_points.run_optuna import cli as model_explore
+ from octopi.entry_points.run_create_targets import cli as create_targets
+ from octopi.entry_points.run_segment import cli as inference
+ from octopi.entry_points.run_localize import cli as localize
+ from octopi.entry_points.run_evaluate import cli as evaluate
+ from octopi.entry_points.run_extract_mb_picks import cli as mb_extract
+
+ @click.group(context_settings=cli_context)
+ def routines():
+     """Octopi 🐙: 🛠️ Tools for Finding Proteins in 🧊 cryo-ET data"""
+     pass
+
+ routines.add_command(download)
+ routines.add_command(import_tomograms)
+ routines.add_command(train_model)
+ routines.add_command(create_targets)
+ routines.add_command(inference)
+ routines.add_command(localize)
+ routines.add_command(model_explore)
+ routines.add_command(evaluate)
+ routines.add_command(mb_extract)
+
+ @click.group(context_settings=cli_context)
+ def slurm_routines():
+     """Slurm-Octopi 🐙: 🛠️ Tools for Finding Proteins in 🧊 cryo-ET data"""
+     pass
+
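Usage sketch (not part of the wheel; assumes the console script in entry_points.txt maps to octopi.main:routines): the group can be exercised in-process with click's test runner.

    from click.testing import CliRunner
    from octopi.main import routines

    runner = CliRunner()
    result = runner.invoke(routines, ["--help"])
    print(result.output)  # lists download, import, train, segment, localize, ...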
octopi/models/AttentionUnet.py ADDED
@@ -0,0 +1,56 @@
+ from monai.networks.nets import AttentionUnet
+ import torch.nn as nn
+ import torch
+
+ class myAttentionUnet:
+     def __init__(self):
+         # Placeholder for the model and config
+         self.model = None
+         self.config = None
+
+     def build_model(self, config: dict):
+         """Creates the AttentionUnet model based on provided parameters."""
+
+         # Store the config so get_model_parameters() can return it
+         # (it was never recorded in the original)
+         self.config = config
+         self.model = AttentionUnet(
+             spatial_dims=3,
+             in_channels=1,
+             out_channels=config['num_classes'],
+             channels=config['channels'],
+             strides=config['strides'],
+             dropout=config['dropout']
+         )
+
+         return self.model
+
+     def bayesian_search(self, trial, num_classes: int):
+         """Defines the Bayesian optimization search space and builds the model with suggested parameters."""
+
+         # Define the search space
+         num_layers = trial.suggest_int("num_layers", 3, 5)
+         hidden_layers = trial.suggest_int("hidden_layers", 1, 3)
+         base_channel = trial.suggest_categorical("base_channel", [8, 16, 32])
+
+         # Create channel sizes and strides
+         downsampling_channels = [base_channel * (2 ** i) for i in range(num_layers)]
+         hidden_channels = [downsampling_channels[-1]] * hidden_layers
+         channels = downsampling_channels + hidden_channels
+         strides = [2] * (num_layers - 1) + [1] * hidden_layers
+         dropout = trial.suggest_float("dropout", 0.0, 0.2)
+
+         config = {
+             'architecture': 'AttentionUnet',
+             'num_classes': num_classes,
+             'channels': channels,
+             'strides': strides,
+             'dropout': dropout
+         }
+
+         return self.build_model(config)
+
+     def get_model_parameters(self):
+         """Retrieve stored model parameters."""
+         if self.model is None:
+             raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")
+
+         return self.config
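Usage sketch (not from the wheel; values are illustrative): building the wrapper directly from a config dict with the keys build_model() reads.

    import torch
    from octopi.models.AttentionUnet import myAttentionUnet

    wrapper = myAttentionUnet()
    model = wrapper.build_model({
        'num_classes': 3,              # e.g. background + two particle classes
        'channels': [16, 32, 64, 64],  # 3 downsampling stages + 1 hidden stage
        'strides': [2, 2, 1],          # must have len(channels) - 1 entries
        'dropout': 0.1,
    })
    out = model(torch.randn(1, 1, 96, 96, 96))  # (batch, channel, D, H, W)
    print(out.shape)  # expected: torch.Size([1, 3, 96, 96, 96])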
octopi/models/MedNeXt.py ADDED
@@ -0,0 +1,111 @@
+ from monai.networks.nets import MedNeXt
+ import torch.nn as nn
+ import torch
+
+ class myMedNeXt:
+     def __init__(self, num_classes: int = 2, device: str = "cpu"):
+         # num_classes and device were referenced throughout the class but never
+         # set; they are taken as arguments here (the defaults keep no-argument
+         # construction, as in common.get_model(), working)
+         self.num_classes = num_classes
+         self.device = device
+
+         # Placeholder for the model and config
+         self.model = None
+         self.config = None
+
+     def build_model(
+         self,
+         init_filters=32,
+         encoder_expansion_ratio=2,
+         decoder_expansion_ratio=2,
+         bottleneck_expansion_ratio=2,
+         kernel_size=7,
+         deep_supervision=False,
+         use_residual_connection=False,
+         norm_type="group",
+         global_resp_norm=False,
+         blocks_down=(2, 2, 2, 2),
+         blocks_bottleneck=2,
+         blocks_up=(2, 2, 2, 2),
+     ):
+         """
+         Create the MedNeXt model with the specified hyperparameters.
+         Note: For cryoET with small objects, a shallower network might help preserve details.
+         """
+         # Record the parameters so get_model_parameters() can return them
+         self.config = {
+             'architecture': 'MedNeXt',
+             'num_classes': self.num_classes,
+             'init_filters': init_filters,
+             'encoder_expansion_ratio': encoder_expansion_ratio,
+             'decoder_expansion_ratio': decoder_expansion_ratio,
+             'bottleneck_expansion_ratio': bottleneck_expansion_ratio,
+             'kernel_size': kernel_size,
+             'deep_supervision': deep_supervision,
+             'use_residual_connection': use_residual_connection,
+             'norm_type': norm_type,
+             'global_resp_norm': global_resp_norm,
+         }
+         self.model = MedNeXt(
+             spatial_dims=3,
+             in_channels=1,
+             out_channels=self.num_classes,
+             init_filters=init_filters,
+             encoder_expansion_ratio=encoder_expansion_ratio,
+             decoder_expansion_ratio=decoder_expansion_ratio,
+             bottleneck_expansion_ratio=bottleneck_expansion_ratio,
+             kernel_size=kernel_size,
+             deep_supervision=deep_supervision,
+             use_residual_connection=use_residual_connection,
+             norm_type=norm_type,
+             global_resp_norm=global_resp_norm,
+             blocks_down=blocks_down,
+             blocks_bottleneck=blocks_bottleneck,
+             blocks_up=blocks_up,
+         ).to(self.device)
+         return self.model
+
+     def bayesian_search(self, trial):
+         """
+         Defines the Bayesian optimization search space and builds the model with suggested parameters.
+         The search space has been adapted for cryoET applications:
+         - Small kernel sizes (3 or 5) to capture fine details.
+         - Choice of a shallower vs. deeper architecture to balance resolution and feature extraction.
+         - Robust normalization options for low signal-to-noise data.
+         """
+         # Core hyperparameters
+         init_filters = trial.suggest_categorical("init_filters", [16, 32])
+         encoder_expansion_ratio = trial.suggest_int("encoder_expansion_ratio", 1, 3)
+         decoder_expansion_ratio = trial.suggest_int("decoder_expansion_ratio", 1, 3)
+         bottleneck_expansion_ratio = trial.suggest_int("bottleneck_expansion_ratio", 1, 4)
+         kernel_size = trial.suggest_categorical("kernel_size", [3, 5])
+         deep_supervision = trial.suggest_categorical("deep_supervision", [True, False])
+         norm_type = trial.suggest_categorical("norm_type", ["group", "instance"])
+         # For extremely low SNR, global response normalization stays disabled
+         global_resp_norm = trial.suggest_categorical("global_resp_norm", [False])
+
+         # Architecture: shallow vs. deep.
+         # For small objects, a shallower network (fewer downsampling stages) may preserve spatial detail.
+         architecture = trial.suggest_categorical("architecture", ["shallow", "deep"])
+         if architecture == "shallow":
+             blocks_down = (2, 2, 2)  # 3 downsampling stages
+             blocks_bottleneck = 2
+             blocks_up = (2, 2, 2)
+         else:
+             blocks_down = (2, 2, 2, 2)  # 4 downsampling stages
+             blocks_bottleneck = 2
+             blocks_up = (2, 2, 2, 2)
+
+         return self.build_model(
+             init_filters=init_filters,
+             encoder_expansion_ratio=encoder_expansion_ratio,
+             decoder_expansion_ratio=decoder_expansion_ratio,
+             bottleneck_expansion_ratio=bottleneck_expansion_ratio,
+             kernel_size=kernel_size,
+             deep_supervision=deep_supervision,
+             use_residual_connection=True,
+             norm_type=norm_type,
+             global_resp_norm=global_resp_norm,
+             blocks_down=blocks_down,
+             blocks_bottleneck=blocks_bottleneck,
+             blocks_up=blocks_up,
+         )
+
+     def get_model_parameters(self):
+         """Retrieve stored model parameters."""
+         if self.model is None:
+             raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")
+
+         # The original read these off self.model, but MONAI's MedNeXt does not
+         # expose them as attributes; return the recorded config instead.
+         return self.config
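Sketch of driving bayesian_search() with Optuna, assuming the constructor arguments added above. The dummy objective just counts parameters; a real one would train and return a validation metric.

    import optuna
    from octopi.models.MedNeXt import myMedNeXt

    def objective(trial):
        wrapper = myMedNeXt(num_classes=3, device="cpu")
        model = wrapper.bayesian_search(trial)
        return sum(p.numel() for p in model.parameters())  # placeholder metric

    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=5)
    print(study.best_params)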
octopi/models/ModelTemplate.py ADDED
@@ -0,0 +1,36 @@
+ import torch.nn as nn
+ import torch
+
+ class myModelTemplate:
+     def __init__(self):
+         """
+         Initialize the model template.
+         """
+         # Placeholder for the model and config
+         self.model = None
+         self.config = None
+
+     def build_model(self, config: dict):
+         """
+         Build the model based on provided parameters in a config dictionary.
+         """
+         pass
+
+     def bayesian_search(self, trial, num_classes: int):
+         """
+         Define the hyperparameter search space for Bayesian optimization and build the model.
+
+         The search space is just an example and can be customized.
+
+         Args:
+             trial (optuna.trial.Trial): Optuna trial object.
+             num_classes (int): Number of classes in the dataset.
+         """
+         pass
+
+     def get_model_parameters(self):
+         """Retrieve stored model parameters."""
+         if self.model is None:
+             raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")
+
+         return self.config
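A minimal concrete fill-in of the template (toy example, not from the wheel); a real implementation would construct an actual segmentation network.

    import torch.nn as nn
    from octopi.models.ModelTemplate import myModelTemplate

    class myToyModel(myModelTemplate):
        def build_model(self, config: dict):
            self.config = config
            # Stand-in "network": a single 3D convolution producing class logits
            self.model = nn.Conv3d(1, config['num_classes'], kernel_size=3, padding=1)
            return self.model

        def bayesian_search(self, trial, num_classes: int):
            # Suggest hyperparameters here, then delegate to build_model()
            return self.build_model({'architecture': 'Toy', 'num_classes': num_classes})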
octopi/models/SegResNet.py ADDED
@@ -0,0 +1,92 @@
+ from monai.networks.nets import SegResNetDS
+ import torch.nn as nn
+ import torch
+
+ class mySegResNet:
+     def __init__(self):
+         # Placeholder for the model and config
+         self.model = None
+         self.config = None
+
+     def build_model(self, config: dict):
+         """
+         Creates the SegResNetDS model based on provided parameters.
+
+         Expected config keys:
+             num_classes (int): Number of output classes.
+             init_filters (int): Number of output channels for the initial convolution.
+             blocks_down (tuple): Number of blocks at each downsampling stage.
+             dsdepth (int): Depth for deep supervision (number of output scales).
+             act (str): Activation type.
+             norm (str): Normalization type.
+             blocks_up (tuple or None): Number of upsample blocks (None uses the default behavior).
+             upsample_mode (str): Upsampling mode, e.g. 'deconv' or 'trilinear'.
+             resolution (optional): If provided, adjusts non-isotropic kernels to isotropic spacing.
+             preprocess (callable or None): Optional preprocessing function for the input.
+
+         Returns:
+             torch.nn.Module: The instantiated SegResNetDS model.
+         """
+         self.config = config
+         self.model = SegResNetDS(
+             spatial_dims=3,
+             init_filters=config['init_filters'],
+             in_channels=1,
+             out_channels=config['num_classes'],
+             act=config['act'],
+             norm=config['norm'],
+             blocks_down=config['blocks_down'],
+             blocks_up=config['blocks_up'],
+             dsdepth=config['dsdepth'],
+             preprocess=config['preprocess'],
+             upsample_mode=config['upsample_mode'],
+             resolution=config['resolution']
+         )
+         # Device placement is left to the caller, as in the other model wrappers
+         # (the original called .to(self.device) with no device attribute set)
+         return self.model
+
+     def bayesian_search(self, trial, num_classes: int):
+         """
+         Defines the Bayesian optimization search space and builds the model with suggested parameters.
+
+         Args:
+             trial (optuna.trial.Trial): An Optuna trial object.
+             num_classes (int): Number of classes in the dataset.
+
+         Returns:
+             torch.nn.Module: The model built with hyperparameters suggested by the trial.
+         """
+         # Define search space parameters
+         init_filters = trial.suggest_categorical("init_filters", [16, 32, 64])
+         dsdepth = trial.suggest_int("dsdepth", 1, 3)
+         blocks_down = trial.suggest_categorical("blocks_down", [(1, 2, 2, 4), (1, 2, 2, 2), (1, 1, 2, 2)])
+         act = trial.suggest_categorical("act", ['relu', 'leaky_relu', "LeakyReLU", "PReLU", "GELU", "ELU"])
+         norm = trial.suggest_categorical("norm", ['batch', 'instance'])
+         upsample_mode = trial.suggest_categorical("upsample_mode", ['deconv', 'trilinear'])
+
+         self.config = {
+             # num_classes was missing from the original config, but build_model() requires it
+             'num_classes': num_classes,
+             'init_filters': init_filters,
+             'blocks_down': blocks_down,
+             'dsdepth': dsdepth,
+             'act': act,
+             'norm': norm,
+             'blocks_up': None,  # using default upsampling blocks
+             'upsample_mode': upsample_mode,
+             'resolution': None,
+             'preprocess': None
+         }
+         return self.build_model(self.config)
+
+     def get_model_parameters(self):
+         """
+         Retrieve stored model parameters.
+
+         Returns:
+             dict: A dictionary of key model parameters.
+
+         Raises:
+             ValueError: If the model has not been built yet.
+         """
+         if self.model is None:
+             raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")
+
+         return self.config
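Usage sketch (illustrative values): a direct build with every key build_model() expects.

    import torch
    from octopi.models.SegResNet import mySegResNet

    wrapper = mySegResNet()
    model = wrapper.build_model({
        'num_classes': 3,
        'init_filters': 16,
        'act': 'relu',
        'norm': 'instance',
        'blocks_down': (1, 2, 2, 2),
        'blocks_up': None,   # default symmetric decoder
        'dsdepth': 1,        # single output scale, so forward returns one tensor
        'preprocess': None,
        'upsample_mode': 'deconv',
        'resolution': None,
    })
    out = model(torch.randn(1, 1, 96, 96, 96))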
octopi/models/Unet.py ADDED
@@ -0,0 +1,59 @@
+ from monai.networks.nets import UNet
+ import torch.nn as nn
+ import torch
+
+ class myUNet:
+     def __init__(self):
+         # Placeholder for the model and config
+         self.model = None
+         self.config = None
+
+     def build_model(self, config: dict):
+         """Creates the Unet model based on provided parameters."""
+
+         self.config = config
+         self.model = UNet(
+             spatial_dims=3,
+             in_channels=1,
+             out_channels=config['num_classes'],
+             channels=config['channels'],
+             strides=config['strides'],
+             num_res_units=config['num_res_units'],
+             dropout=config['dropout']
+         )
+         return self.model
+
+     def bayesian_search(self, trial, num_classes: int):
+         """Defines the Bayesian optimization search space and builds the model with suggested parameters."""
+
+         # Define the search space
+         num_layers = trial.suggest_int("num_layers", 3, 5)
+         hidden_layers = trial.suggest_int("hidden_layers", 1, 3)
+         base_channel = trial.suggest_categorical("base_channel", [8, 16, 32])
+         num_res_units = trial.suggest_int("num_res_units", 1, 3)
+         dropout = trial.suggest_float("dropout", 0.0, 0.3)
+
+         # Create channel sizes and strides
+         downsampling_channels = [base_channel * (2 ** i) for i in range(num_layers)]
+         hidden_channels = [downsampling_channels[-1]] * hidden_layers
+         channels = downsampling_channels + hidden_channels
+         strides = [2] * (num_layers - 1) + [1] * hidden_layers
+
+         # Create config dictionary
+         self.config = {
+             'architecture': 'Unet',
+             'num_classes': num_classes,
+             'channels': channels,
+             'strides': strides,
+             'num_res_units': num_res_units,
+             'dropout': dropout
+         }
+
+         return self.build_model(self.config)
+
+     def get_model_parameters(self):
+         """Retrieve stored model parameters."""
+         if self.model is None:
+             raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")
+
+         return self.config
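Usage sketch tying this wrapper to the defaults in octopi/models/common.py further down; the 80-cubed patch matches its dim_in of 80.

    import torch
    from octopi.models.Unet import myUNet

    wrapper = myUNet()
    model = wrapper.build_model({
        'num_classes': 3,
        'channels': [48, 64, 80, 80],  # as in get_default_unet_params() below
        'strides': [2, 2, 1],
        'num_res_units': 1,
        'dropout': 0.0,
    })
    patch = torch.randn(1, 1, 80, 80, 80)
    print(model(patch).shape)  # expected: torch.Size([1, 3, 80, 80, 80])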
octopi/models/UnetPlusPlus.py ADDED
@@ -0,0 +1,47 @@
+ from monai.networks.nets import BasicUNetPlusPlus
+
+ class myUNetPlusPlus:
+     def __init__(self, num_classes, device):
+         self.device = device
+         self.num_classes = num_classes
+
+     def build_model(
+         self,
+         features=(16, 16, 32, 64, 128, 16),
+         dropout=0.2,
+         upsample='deconv',
+         activation='relu'
+     ):
+         # Record the parameters so get_model_parameters() can report them
+         self.features = features
+         self.dropout = dropout
+         self.upsample = upsample
+         self.activation = activation
+
+         model = BasicUNetPlusPlus(
+             spatial_dims=3,
+             in_channels=1,
+             out_channels=self.num_classes,  # was an undefined `n_classes`
+             deep_supervision=True,
+             features=features,
+             dropout=dropout,
+             upsample=upsample,
+             act=activation
+         )
+
+         return model.to(self.device)
+
+     def bayesian_search(self, trial):
+
+         act = trial.suggest_categorical("activation", ["LeakyReLU", "PReLU", "GELU", "ELU"])
+         dropout_prob = trial.suggest_float("dropout", 0.0, 0.5)
+         upsample = trial.suggest_categorical("upsample", ["deconv", "pixelshuffle", "nontrainable"])
+
+         # `features` was previously undefined here; use the build_model() default
+         features = (16, 16, 32, 64, 128, 16)
+         model = self.build_model(features, dropout_prob, upsample, act)
+         return model
+
+     def get_model_parameters(self):
+         return {
+             'model_name': 'UnetPlusPlus',
+             'features': self.features,
+             'dropout': self.dropout,
+             'upsample': self.upsample,
+             'activation': self.activation
+         }
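One behavioral note, sketched under the fixes above: with deep_supervision=True, MONAI's BasicUNetPlusPlus returns a list of logits (one per supervision head), so training loops must handle a list rather than a single tensor.

    import torch
    from octopi.models.UnetPlusPlus import myUNetPlusPlus

    wrapper = myUNetPlusPlus(num_classes=3, device='cpu')
    model = wrapper.build_model()
    outputs = model(torch.randn(1, 1, 64, 64, 64))  # list of tensors
    print(len(outputs), outputs[0].shape)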
octopi/models/__init__.py ADDED
File without changes
octopi/models/common.py ADDED
@@ -0,0 +1,72 @@
+ from monai.losses import FocalLoss, TverskyLoss
+ from octopi.utils import losses
+ from octopi.models import (
+     Unet, AttentionUnet, MedNeXt, SegResNet
+ )
+
+ def get_model(architecture):
+
+     # Initialize model based on architecture
+     if architecture == "Unet":
+         model = Unet.myUNet()
+     elif architecture == "AttentionUnet":
+         model = AttentionUnet.myAttentionUnet()
+     elif architecture == "MedNeXt":
+         model = MedNeXt.myMedNeXt()
+     elif architecture == "SegResNet":
+         model = SegResNet.mySegResNet()
+     else:
+         raise ValueError(f"Model type {architecture} not supported!\nPlease use one of the following: Unet, AttentionUnet, MedNeXt, SegResNet")
+
+     return model
+
+ def get_loss_function(trial, loss_name=None):
+
+     # Loss function selection
+     if loss_name is None:
+         loss_name = trial.suggest_categorical(
+             "loss_function",
+             ["FocalLoss", "WeightedFocalTverskyLoss", 'FocalTverskyLoss'])
+
+     if loss_name == "FocalLoss":
+         gamma = round(trial.suggest_float("gamma", 0.1, 2), 3)
+         loss_function = FocalLoss(include_background=True, to_onehot_y=True, use_softmax=True, gamma=gamma)
+
+     elif loss_name == "TverskyLoss":
+         alpha = round(trial.suggest_float("alpha", 0.1, 0.5), 3)
+         beta = 1.0 - alpha
+         loss_function = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True, alpha=alpha, beta=beta)
+
+     elif loss_name == 'WeightedFocalTverskyLoss':
+         gamma = round(trial.suggest_float("gamma", 0.1, 2), 3)
+         alpha = round(trial.suggest_float("alpha", 0.1, 0.5), 3)
+         beta = 1.0 - alpha
+         weight_tversky = round(trial.suggest_float("weight_tversky", 0.1, 0.9), 3)
+         weight_focal = 1.0 - weight_tversky
+         loss_function = losses.WeightedFocalTverskyLoss(
+             gamma=gamma, alpha=alpha, beta=beta,
+             weight_tversky=weight_tversky, weight_focal=weight_focal
+         )
+
+     elif loss_name == 'FocalTverskyLoss':
+         gamma = round(trial.suggest_float("gamma", 0.1, 2), 3)
+         alpha = round(trial.suggest_float("alpha", 0.1, 0.5), 3)
+         beta = 1.0 - alpha
+         loss_function = losses.FocalTverskyLoss(gamma=gamma, alpha=alpha, beta=beta)
+
+     else:
+         # Guard against silently returning an unbound name for unknown losses
+         raise ValueError(f"Loss function {loss_name} not supported!")
+
+     return loss_function
+
+ def get_default_unet_params():
+
+     model_config = {
+         'architecture': 'Unet',
+         'dim_in': 80,
+         'strides': [2, 2, 1],
+         'channels': [48, 64, 80, 80],
+         'dropout': 0.0,
+         'num_res_units': 1,
+     }
+     return model_config
+
+ #### TODO: Models to try adding?
+ # 1. Swin UNETR
+ # 2. Swin-Conv-UNet
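Usage sketch: even with loss_name fixed, get_loss_function() still draws its scalar knobs from the trial, so outside a study an optuna FixedTrial can stand in.

    import optuna
    from octopi.models import common

    wrapper = common.get_model("Unet")
    model = wrapper.build_model({**common.get_default_unet_params(), 'num_classes': 3})

    trial = optuna.trial.FixedTrial({'gamma': 1.0})
    loss_fn = common.get_loss_function(trial, loss_name="FocalLoss")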
octopi/processing/__init__.py ADDED
File without changes