octopi 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of octopi might be problematic. Click here for more details.

Files changed (59) hide show
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,56 @@
1
+ from monai.networks.nets import AttentionUnet
2
+ import torch.nn as nn
3
+ import torch
4
+
5
class myAttentionUnet:
    """Builder for a 3D MONAI ``AttentionUnet`` with a single input channel.

    ``self.model`` holds the most recently built network and ``self.config``
    the configuration dictionary it was built from.
    """

    def __init__(self):
        # Placeholder for the model and config (filled in by build_model()).
        self.model = None
        self.config = None

    def build_model(self, config: dict):
        """Create the AttentionUnet model from a configuration dictionary.

        Args:
            config: Must provide ``num_classes``, ``channels``, ``strides`` and
                ``dropout``. Other keys (e.g. ``architecture``) are kept in
                ``self.config`` but are not used to build the network.

        Returns:
            The new ``AttentionUnet`` instance, also stored in ``self.model``.
        """
        # Store the config so get_model_parameters() can report it; it was
        # previously never assigned, so get_model_parameters() returned None.
        self.config = config
        self.model = AttentionUnet(
            spatial_dims=3,
            in_channels=1,
            out_channels=config['num_classes'],
            channels=config['channels'],
            strides=config['strides'],
            dropout=config['dropout'],
        )
        return self.model

    def bayesian_search(self, trial, num_classes: int):
        """Sample an architecture from ``trial`` and build it.

        Search space: 3-5 downsampling levels whose width starts at
        ``base_channel`` (8/16/32) and doubles per level, 1-3 extra stride-1
        levels at the deepest width, and dropout in [0.0, 0.2].

        Args:
            trial: Optuna trial used to sample the hyperparameters.
            num_classes: Number of output channels of the network.

        Returns:
            The built ``AttentionUnet``.
        """
        # Define the search space
        num_layers = trial.suggest_int("num_layers", 3, 5)
        hidden_layers = trial.suggest_int("hidden_layers", 1, 3)
        base_channel = trial.suggest_categorical("base_channel", [8, 16, 32])

        # Channel width doubles at every downsampling level; the "hidden"
        # levels repeat the deepest width and use stride 1.
        downsampling_channels = [base_channel * (2 ** i) for i in range(num_layers)]
        hidden_channels = [downsampling_channels[-1]] * hidden_layers
        channels = downsampling_channels + hidden_channels
        # One stride per transition between consecutive channel entries.
        strides = [2] * (num_layers - 1) + [1] * hidden_layers
        dropout = trial.suggest_float("dropout", 0.0, 0.2)

        config = {
            'architecture': 'AttentionUnet',
            'num_classes': num_classes,
            'channels': channels,
            'strides': strides,
            'dropout': dropout,
        }
        return self.build_model(config)

    def get_model_parameters(self):
        """Return the configuration dictionary used for the last build.

        Raises:
            ValueError: If no model has been built yet.
        """
        if self.model is None:
            raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")

        return self.config
@@ -0,0 +1,111 @@
1
+ from monai.networks.nets import MedNeXt
2
+ import torch.nn as nn
3
+ import torch
4
+
5
class myMedNeXt:
    """Builder for a 3D MONAI ``MedNeXt`` with a single input channel.

    ``self.model`` holds the most recently built network and ``self.config``
    the parameters it was built with.
    """

    def __init__(self, num_classes: int = None, device=None):
        # Placeholder for the model and config (filled in by build_model()).
        self.model = None
        self.config = None
        # build_model() reads these; they were previously never assigned, so
        # every build failed with AttributeError. Both are optional so the
        # no-argument construction keeps working.
        self.num_classes = num_classes
        self.device = device

    def build_model(
        self,
        init_filters=32,
        encoder_expansion_ratio=2,
        decoder_expansion_ratio=2,
        bottleneck_expansion_ratio=2,
        kernel_size=7,
        deep_supervision=False,
        use_residual_connection=False,
        norm_type="group",
        global_resp_norm=False,
        blocks_down=(2, 2, 2, 2),
        blocks_bottleneck=2,
        blocks_up=(2, 2, 2, 2),
        num_classes: int = None,
    ):
        """
        Create the MedNeXt model with the specified hyperparameters.
        Note: For cryoET with small objects, a shallower network might help preserve details.

        A configuration dict (such as the one returned by
        ``get_model_parameters()``) may also be passed as the first positional
        argument, mirroring ``build_model(config)`` on the other octopi
        builders. In that case every parameter is taken from the dict, and its
        ``architecture`` key is ignored.

        Args:
            num_classes: Number of output channels. If omitted, the value given
                to the constructor (or a previous build) is used.

        Returns:
            The new ``MedNeXt`` instance, also stored in ``self.model``. It is
            moved to ``self.device`` when a device was set.

        Raises:
            ValueError: If the number of classes is not known.
        """
        if isinstance(init_filters, dict):
            params = dict(init_filters)
            params.pop('architecture', None)
            return self.build_model(**params)

        if num_classes is not None:
            self.num_classes = num_classes
        if self.num_classes is None:
            raise ValueError("num_classes must be provided (constructor, build_model() or bayesian_search()) before building MedNeXt.")

        model = MedNeXt(
            spatial_dims=3,
            in_channels=1,
            out_channels=self.num_classes,
            init_filters=init_filters,
            encoder_expansion_ratio=encoder_expansion_ratio,
            decoder_expansion_ratio=decoder_expansion_ratio,
            bottleneck_expansion_ratio=bottleneck_expansion_ratio,
            kernel_size=kernel_size,
            deep_supervision=deep_supervision,
            use_residual_connection=use_residual_connection,
            norm_type=norm_type,
            global_resp_norm=global_resp_norm,
            blocks_down=blocks_down,
            blocks_bottleneck=blocks_bottleneck,
            blocks_up=blocks_up,
        )
        if self.device is not None:
            model = model.to(self.device)
        self.model = model

        # Keep the build parameters ourselves instead of reading them back off
        # the MONAI module, which does not necessarily expose them.
        self.config = {
            'architecture': 'MedNeXt',
            'num_classes': self.num_classes,
            'init_filters': init_filters,
            'encoder_expansion_ratio': encoder_expansion_ratio,
            'decoder_expansion_ratio': decoder_expansion_ratio,
            'bottleneck_expansion_ratio': bottleneck_expansion_ratio,
            'kernel_size': kernel_size,
            'deep_supervision': deep_supervision,
            'use_residual_connection': use_residual_connection,
            'norm_type': norm_type,
            'global_resp_norm': global_resp_norm,
            'blocks_down': blocks_down,
            'blocks_bottleneck': blocks_bottleneck,
            'blocks_up': blocks_up,
        }
        return self.model

    def bayesian_search(self, trial, num_classes: int = None):
        """
        Defines the Bayesian optimization search space and builds the model with suggested parameters.
        The search space has been adapted for cryoET applications:
        - Small kernel sizes (3 or 5) to capture fine details.
        - Choice of a shallower vs. deeper architecture to balance resolution and feature extraction.
        - Robust normalization options for low signal-to-noise data.

        Args:
            trial: Optuna trial used to sample the hyperparameters.
            num_classes: Number of output channels. Optional if it was given to
                the constructor.

        Returns:
            The built ``MedNeXt``.

        Raises:
            ValueError: If the number of classes is not known.
        """
        # Fail before sampling anything if the output size is unknown.
        if num_classes is None and self.num_classes is None:
            raise ValueError("num_classes must be provided (constructor, build_model() or bayesian_search()) before building MedNeXt.")

        # Core hyperparameters
        init_filters = trial.suggest_categorical("init_filters", [16, 32])
        encoder_expansion_ratio = trial.suggest_int("encoder_expansion_ratio", 1, 3)
        decoder_expansion_ratio = trial.suggest_int("decoder_expansion_ratio", 1, 3)
        bottleneck_expansion_ratio = trial.suggest_int("bottleneck_expansion_ratio", 1, 4)
        kernel_size = trial.suggest_categorical("kernel_size", [3, 5])
        deep_supervision = trial.suggest_categorical("deep_supervision", [True, False])
        norm_type = trial.suggest_categorical("norm_type", ["group", "instance"])
        # For extremely low SNR, you might opt to disable global response normalization
        global_resp_norm = trial.suggest_categorical("global_resp_norm", [False])

        # Architecture: shallow vs. deep.
        # For small objects, a shallower network (fewer downsampling stages) may preserve spatial detail.
        architecture = trial.suggest_categorical("architecture", ["shallow", "deep"])
        if architecture == "shallow":
            blocks_down = (2, 2, 2)  # 3 downsampling stages
            blocks_bottleneck = 2
            blocks_up = (2, 2, 2)
        else:
            blocks_down = (2, 2, 2, 2)  # 4 downsampling stages
            blocks_bottleneck = 2
            blocks_up = (2, 2, 2, 2)

        return self.build_model(
            init_filters=init_filters,
            encoder_expansion_ratio=encoder_expansion_ratio,
            decoder_expansion_ratio=decoder_expansion_ratio,
            bottleneck_expansion_ratio=bottleneck_expansion_ratio,
            kernel_size=kernel_size,
            deep_supervision=deep_supervision,
            use_residual_connection=True,
            norm_type=norm_type,
            global_resp_norm=global_resp_norm,
            blocks_down=blocks_down,
            blocks_bottleneck=blocks_bottleneck,
            blocks_up=blocks_up,
            num_classes=num_classes,
        )

    def get_model_parameters(self):
        """Return the parameters used for the last build.

        Raises:
            ValueError: If no model has been built yet.
        """
        if self.model is None:
            raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")

        return self.config
@@ -0,0 +1,36 @@
1
+ import torch.nn as nn
2
+ import torch
3
+
4
class myModelTemplate:
    """Skeleton to copy when adding a new octopi model builder.

    Concrete builders are meant to fill in ``build_model`` and
    ``bayesian_search``. In this template both are stubs that return ``None``.
    """

    def __init__(self):
        """Start with neither a built model nor a stored config."""
        self.model, self.config = None, None

    def build_model(self, config: dict):
        """Build the network described by ``config``.

        Template stub: performs no work and returns ``None``.
        """
        return None

    def bayesian_search(self, trial, num_classes: int):
        """Sample hyperparameters from ``trial`` and build the network.

        Template stub: performs no work and returns ``None``. Replace it with a
        real search space for the new architecture.

        Args:
            trial (optuna.trial.Trial): Optuna trial object.
            num_classes (int): Number of classes in the dataset.
        """
        return None

    def get_model_parameters(self):
        """Return ``self.config``, or raise ``ValueError`` when no model exists."""
        if self.model is not None:
            return self.config
        raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")
@@ -0,0 +1,92 @@
1
+ from monai.networks.nets import SegResNetDS
2
+ import torch.nn as nn
3
+ import torch
4
+
5
class mySegResNet:
    """Builder for a 3D MONAI ``SegResNetDS`` with a single input channel.

    ``self.model`` holds the most recently built network and ``self.config``
    the configuration dictionary it was built from.
    """

    def __init__(self, num_classes: int = None, device=None):
        # Placeholder for the model and config (filled in by build_model()).
        self.model = None
        self.config = None
        # Previously never assigned even though build_model() read self.device.
        # Both are optional so the no-argument construction keeps working.
        self.num_classes = num_classes
        self.device = device

    def build_model(
        self,
        config: dict
    ):
        """
        Creates the SegResNetDS model based on provided parameters.

        Args:
            config (dict): Must contain the keys below.
                num_classes (int): Number of output channels.
                init_filters (int): Number of output channels for the initial convolution.
                blocks_down (tuple): Number of blocks at each downsampling stage.
                dsdepth (int): Depth for deep supervision (number of output scales).
                act (str): Activation type.
                norm (str): Normalization type.
                blocks_up (tuple or None): Number of upsample blocks (None uses MONAI's default).
                upsample_mode (str): Upsampling mode, e.g. 'deconv' or 'trilinear'.
                resolution (optional): Passed through to SegResNetDS unchanged.
                preprocess (callable or None): Optional preprocessing function for the input.

        Returns:
            torch.nn.Module: The SegResNetDS model, moved to ``self.device`` when one was set.
        """
        # Record the config so get_model_parameters() reflects direct builds too.
        self.config = config
        self.model = SegResNetDS(
            spatial_dims=3,
            init_filters=config['init_filters'],
            in_channels=1,
            out_channels=config['num_classes'],
            act=config['act'],
            norm=config['norm'],
            blocks_down=config['blocks_down'],
            blocks_up=config['blocks_up'],
            dsdepth=config['dsdepth'],
            preprocess=config['preprocess'],
            upsample_mode=config['upsample_mode'],
            resolution=config['resolution']
        )
        if self.device is not None:
            self.model = self.model.to(self.device)
        return self.model

    def bayesian_search(self, trial, num_classes: int = None):
        """
        Defines the Bayesian optimization search space and builds the model with suggested parameters.

        Args:
            trial (optuna.trial.Trial): An Optuna trial object.
            num_classes (int, optional): Number of output channels; falls back
                to the value given to the constructor.

        Returns:
            torch.nn.Module: The model built with hyperparameters suggested by the trial.

        Raises:
            ValueError: If the number of classes is not known.
        """
        # The config previously had no 'num_classes', so build_model() always
        # failed with KeyError. Resolve it before sampling anything.
        if num_classes is None:
            num_classes = self.num_classes
        if num_classes is None:
            raise ValueError("num_classes must be provided (constructor or bayesian_search()) before building SegResNet.")

        # Define search space parameters
        init_filters = trial.suggest_categorical("init_filters", [16, 32, 64])
        dsdepth = trial.suggest_int("dsdepth", 1, 3)
        # NOTE(review): Optuna recommends categorical choices be None/bool/int/
        # float/str; tuple choices may warn or not persist in some storages.
        blocks_down = trial.suggest_categorical("blocks_down", [(1, 2, 2, 4), (1, 2, 2, 2), (1, 1, 2, 2)])
        # NOTE(review): confirm MONAI's activation factory accepts 'leaky_relu'
        # (its registered name appears to be 'leakyrelu').
        act = trial.suggest_categorical("act", ['relu', 'leaky_relu', "LeakyReLU", "PReLU", "GELU", "ELU"])
        norm = trial.suggest_categorical("norm", ['batch', 'instance'])
        upsample_mode = trial.suggest_categorical("upsample_mode", ['deconv', 'trilinear'])

        self.config = {
            'architecture': 'SegResNet',
            'num_classes': num_classes,
            'init_filters': init_filters,
            'blocks_down': blocks_down,
            'dsdepth': dsdepth,
            'act': act,
            'norm': norm,
            'blocks_up': None,  # using default upsampling blocks
            'upsample_mode': upsample_mode,
            'resolution': None,
            'preprocess': None
        }
        return self.build_model(self.config)

    def get_model_parameters(self):
        """
        Retrieve stored model parameters.

        Returns:
            dict: The configuration dictionary used for the last build.

        Raises:
            ValueError: If the model has not been built yet.
        """
        if self.model is None:
            raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")

        return self.config
octopi/models/Unet.py ADDED
@@ -0,0 +1,59 @@
1
+ from monai.networks.nets import UNet
2
+ import torch.nn as nn
3
+ import torch
4
+
5
class myUNet:
    """Builder for a 3D MONAI ``UNet`` taking one input channel.

    The latest network is kept in ``self.model``. The dictionary it was built
    from is kept in ``self.config``.
    """

    def __init__(self):
        # Nothing is built until build_model()/bayesian_search() runs.
        self.model, self.config = None, None

    def build_model( self, config: dict ):
        """Instantiate the UNet described by ``config`` and return it.

        ``config`` must supply ``num_classes``, ``channels``, ``strides``,
        ``num_res_units`` and ``dropout``.
        """
        self.config = config
        net = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=config['num_classes'],
            channels=config['channels'],
            strides=config['strides'],
            num_res_units=config['num_res_units'],
            dropout=config['dropout'],
        )
        self.model = net
        return net

    def bayesian_search(self, trial, num_classes: int):
        """Draw a UNet configuration from ``trial``, then build and return it.

        Depth (3-5 levels), extra stride-1 levels (1-3), base width (8/16/32),
        residual units (1-3) and dropout (0.0-0.3) are all sampled.
        """
        num_layers = trial.suggest_int("num_layers", 3, 5)
        hidden_layers = trial.suggest_int("hidden_layers", 1, 3)
        base_channel = trial.suggest_categorical("base_channel", [8, 16, 32])
        num_res_units = trial.suggest_int("num_res_units", 1, 3)
        dropout = trial.suggest_float("dropout", 0.0, 0.3)

        # Widths double per downsampling level, then repeat the deepest width
        # for every stride-1 "hidden" level.
        channels = [base_channel * 2 ** level for level in range(num_layers)]
        channels.extend([channels[-1]] * hidden_layers)

        strides = [2] * (num_layers - 1)
        strides.extend([1] * hidden_layers)

        self.config = dict(
            architecture='Unet',
            num_classes=num_classes,
            channels=channels,
            strides=strides,
            num_res_units=num_res_units,
            dropout=dropout,
        )
        return self.build_model(self.config)

    def get_model_parameters(self):
        """Give back ``self.config``; raise ``ValueError`` if nothing is built yet."""
        if self.model is not None:
            return self.config
        raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")
@@ -0,0 +1,47 @@
1
+ from monai.networks.nets import BasicUNetPlusPlus
2
+
3
class myUNetPlusPlus:
    """Builder for a 3D MONAI ``BasicUNetPlusPlus`` with deep supervision.

    ``self.model`` holds the most recently built network and ``self.config``
    the parameters it was built with.
    """

    def __init__(self, num_classes=None, device=None):
        # Both are now optional, for parity with the other octopi builders,
        # which are constructed without arguments.
        self.device = device
        self.num_classes = num_classes
        self.model = None
        self.config = None

    def build_model(
        self,
        features=(16, 16, 32, 64, 128, 16),
        dropout=0.2,
        upsample='deconv',
        activation='relu',
        num_classes=None,
    ):
        """Create the BasicUNetPlusPlus model.

        Args:
            features: Six feature sizes passed to ``BasicUNetPlusPlus``.
            dropout: Dropout probability.
            upsample: Upsampling mode passed to ``BasicUNetPlusPlus``.
            activation: Activation passed as ``act``.
            num_classes: Number of output channels; falls back to the value
                given to the constructor.

        Returns:
            The model, moved to ``self.device`` when one was set.

        Raises:
            ValueError: If the number of classes is not known.
        """
        if num_classes is not None:
            self.num_classes = num_classes
        if self.num_classes is None:
            raise ValueError("num_classes must be provided (constructor, build_model() or bayesian_search()) before building UnetPlusPlus.")

        # Previously referenced the undefined name `n_classes`.
        model = BasicUNetPlusPlus(
            spatial_dims=3,
            in_channels=1,
            out_channels=self.num_classes,
            deep_supervision=True,
            features=features,
            dropout=dropout,
            upsample=upsample,
            act=activation
        )
        if self.device is not None:
            model = model.to(self.device)
        self.model = model

        # Store the build parameters so get_model_parameters() can report them
        # (the attributes it used to read were never assigned).
        self.config = {
            'model_name': 'UnetPlusPlus',
            'architecture': 'UnetPlusPlus',
            'num_classes': self.num_classes,
            'features': features,
            'dropout': dropout,
            'upsample': upsample,
            'activation': activation,
        }
        return model

    def bayesian_search(self, trial, num_classes=None):
        """Sample activation, dropout and upsampling mode from ``trial`` and build.

        The feature sizes stay at the ``build_model`` default.

        Raises:
            ValueError: If the number of classes is not known.
        """
        if num_classes is None and self.num_classes is None:
            raise ValueError("num_classes must be provided (constructor, build_model() or bayesian_search()) before building UnetPlusPlus.")

        act = trial.suggest_categorical("activation", ["LeakyReLU", "PReLU", "GELU", "ELU"])
        dropout_prob = trial.suggest_float("dropout", 0.0, 0.5)
        upsample = trial.suggest_categorical("upsample", ["deconv", "pixelshuffle", "nontrainable"])

        # `features` was previously an undefined name here; use the default.
        return self.build_model(
            dropout=dropout_prob,
            upsample=upsample,
            activation=act,
            num_classes=num_classes,
        )

    def get_model_parameters(self):
        """Return the parameters used for the last build.

        Raises:
            ValueError: If no model has been built yet.
        """
        if self.model is None:
            raise ValueError("Model has not been initialized yet. Call build_model() or bayesian_search() first.")
        return self.config
File without changes
@@ -0,0 +1,62 @@
1
+ from monai.losses import FocalLoss, TverskyLoss
2
+ from octopi import losses
3
+ from octopi.models import (
4
+ Unet, AttentionUnet, MedNeXt, SegResNet
5
+ )
6
+
7
def get_model(architecture):
    """Return a fresh, argument-less model builder for ``architecture``.

    Supported names are "Unet", "AttentionUnet", "MedNeXt" and "SegResNet".

    Raises:
        ValueError: For any other architecture name.
    """
    # Factories are lambdas so each builder class is only looked up on a match.
    builders = (
        ("Unet", lambda: Unet.myUNet()),
        ("AttentionUnet", lambda: AttentionUnet.myAttentionUnet()),
        ("MedNeXt", lambda: MedNeXt.myMedNeXt()),
        ("SegResNet", lambda: SegResNet.mySegResNet()),
    )
    for name, make_builder in builders:
        if architecture == name:
            return make_builder()

    raise ValueError(f"Model type {architecture} not supported!\nPlease use one of the following: Unet, AttentionUnet, MedNeXt, SegResNet")
22
+
23
def get_loss_function(trial, loss_name = None):
    """Build a segmentation loss whose hyperparameters are sampled from ``trial``.

    Args:
        trial: Optuna trial used to sample ``gamma``, ``alpha`` and
            ``weight_tversky`` as the chosen loss requires. When ``loss_name``
            is None it also samples the loss type.
        loss_name: One of "FocalLoss", "TverskyLoss",
            "WeightedFocalTverskyLoss" or "FocalTverskyLoss". If None, the
            trial picks among "FocalLoss", "WeightedFocalTverskyLoss" and
            "FocalTverskyLoss".

    Returns:
        The instantiated loss module.

    Raises:
        ValueError: If ``loss_name`` is not a supported loss. Previously this
            surfaced as an UnboundLocalError.
    """
    # Loss function selection
    if loss_name is None:
        loss_name = trial.suggest_categorical(
            "loss_function",
            ["FocalLoss", "WeightedFocalTverskyLoss", 'FocalTverskyLoss'])

    if loss_name == "FocalLoss":
        gamma = round(trial.suggest_float("gamma", 0.1, 2), 3)
        loss_function = FocalLoss(include_background=True, to_onehot_y=True, use_softmax=True, gamma=gamma)

    elif loss_name == "TverskyLoss":
        alpha = round(trial.suggest_float("alpha", 0.1, 0.5), 3)
        beta = 1.0 - alpha  # alpha + beta = 1 balances FP vs FN penalties
        loss_function = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True, alpha=alpha, beta=beta)

    elif loss_name == 'WeightedFocalTverskyLoss':
        gamma = round(trial.suggest_float("gamma", 0.1, 2), 3)
        alpha = round(trial.suggest_float("alpha", 0.1, 0.5), 3)
        beta = 1.0 - alpha
        # Convex combination of the two terms: the weights sum to 1.
        weight_tversky = round(trial.suggest_float("weight_tversky", 0.1, 0.9), 3)
        weight_focal = 1.0 - weight_tversky
        loss_function = losses.WeightedFocalTverskyLoss(
            gamma=gamma, alpha=alpha, beta=beta,
            weight_tversky=weight_tversky, weight_focal=weight_focal
        )

    elif loss_name == 'FocalTverskyLoss':
        gamma = round(trial.suggest_float("gamma", 0.1, 2), 3)
        alpha = round(trial.suggest_float("alpha", 0.1, 0.5), 3)
        beta = 1.0 - alpha
        loss_function = losses.FocalTverskyLoss(gamma=gamma, alpha=alpha, beta=beta)

    else:
        raise ValueError(
            f"Loss function {loss_name} not supported! Please use one of the following: "
            "FocalLoss, TverskyLoss, WeightedFocalTverskyLoss, FocalTverskyLoss"
        )

    return loss_function
58
+
59
+
60
+ #### TODO: Candidate models to add:
61
+ # 1. Swin UNETR
62
+ # 2. Swin-Conv-UNet
File without changes
@@ -0,0 +1,106 @@
1
+ from octopi.processing.segmentation_from_picks import from_picks
2
+ import octopi.processing.writers as write
3
+ from octopi import io
4
+ from typing import List
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+
8
def generate_targets(
    root,
    train_targets: dict,
    voxel_size: float = 10,
    tomo_algorithm: str = 'wbp',
    radius_scale: float = 0.8,
    target_segmentation_name: str = 'targets',
    target_user_name: str = 'monai',
    target_session_id: str = '1',
    run_ids: List[str] = None,
):
    """
    Generate segmentation targets from picks and segmentations in a CoPick project.

    For each run, a ``uint8`` label volume with the shape of the run's
    tomogram is built:
      - Non-particle targets are painted from existing segmentations. Their
        foreground voxels get ``root.get_object(name).label``.
      - Particle targets are painted from picks via ``from_picks``, using
        ``train_targets[name]['radius'] * radius_scale`` and
        ``train_targets[name]['label']``.
    Runs whose volume contains any label are written back as a segmentation.

    Args:
        root: CoPick root; ``root.runs``, ``root.get_run`` and ``root.get_object`` are used.
        train_targets (dict): Maps object name -> dict with ``is_particle_target``,
            ``user_id`` and ``session_id``. Particle targets also need ``radius`` and ``label``.
        voxel_size (float): Voxel size of the tomogram and segmentations to use.
        tomo_algorithm (str): Tomogram reconstruction algorithm.
        radius_scale (float): Scale factor for target object radius.
        target_segmentation_name (str): Name for the target segmentation.
        target_user_name (str): User name associated with target segmentation.
        target_session_id (str): Session ID for the target segmentation (None -> '1').
        run_ids (List[str]): Runs to process; None processes every run in ``root``.
    """

    # Default session ID to 1 if not provided
    if target_session_id is None:
        target_session_id = '1'

    print('Creating Targets for the following objects:', ', '.join(train_targets.keys()))

    # Get Target Names
    target_names = list(train_targets.keys())

    # If runIDs are not provided, load all runs
    if run_ids is None:
        run_ids = [run.name for run in root.runs]

    # Iterate Over All Runs
    for runID in tqdm(run_ids):

        # Get Run
        numPicks = 0
        run = root.get_run(runID)

        # The tomogram is only used for its shape, so the target matches it voxel-for-voxel.
        tomo = io.get_tomogram_array(run, voxel_size, tomo_algorithm)

        # Initialize Target Volume
        target = np.zeros(tomo.shape, dtype=np.uint8)

        # Applicable segmentations
        query_seg = []
        for target_name in target_names:
            if not train_targets[target_name]["is_particle_target"]:
                query_seg += run.get_segmentations(
                    name=target_name,
                    user_id=train_targets[target_name]["user_id"],
                    session_id=train_targets[target_name]["session_id"],
                    voxel_size=voxel_size
                )

        # Paint each segmentation's foreground with its class label. The old
        # `target[:] = segvol` replaced the whole volume, so every segmentation
        # erased the labels painted before it.
        for seg in query_seg:
            classLabel = root.get_object(seg.name).label
            segvol = seg.numpy()
            target[segvol > 0] = classLabel

        # Applicable picks
        query = []
        for target_name in target_names:
            if train_targets[target_name]["is_particle_target"]:
                query += run.get_picks(
                    object_name=target_name,
                    user_id=train_targets[target_name]["user_id"],
                    session_id=train_targets[target_name]["session_id"],
                )

        # Add Picks to Target
        for pick in query:
            numPicks += len(pick.points)
            target = from_picks(pick,
                                target,
                                train_targets[pick.pickable_object_name]['radius'] * radius_scale,
                                train_targets[pick.pickable_object_name]['label'],
                                voxel_size
                                )

        # Write any run that received labels. Previously runs without picks were
        # skipped, so segmentation-only targets were never saved. `.any()` also
        # avoids the ValueError `.max()` raises on an empty array.
        if target.any():
            tqdm.write(f'Annotating {numPicks} picks in {runID}...')
            write.segmentation(run, target, target_user_name,
                               name = target_segmentation_name, session_id= target_session_id,
                               voxel_size = voxel_size)
    print('Creation of targets complete!')