octopi 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of octopi might be problematic. Click here for more details.

Files changed (59) hide show
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,222 @@
1
+ from octopi.entry_points import common
2
+ from octopi.extract import localize
3
+ from octopi import utils
4
+ import copick, argparse, pprint
5
+ from typing import List, Tuple
6
+ import multiprocess as mp
7
+ from tqdm import tqdm
8
+
9
def pick_particles(
    copick_config_path: str,
    method: str,
    seg_info: Tuple[str, str, str],
    voxel_size: float,
    pick_session_id: str,
    pick_user_id: str,
    radius_min_scale: float,
    radius_max_scale: float,
    filter_size: float,
    pick_objects: List[str],
    runIDs: List[str],
    n_procs: int,
    ):
    """
    Run particle localization over a set of runs, one worker process per run.

    Runs are handled in batches of at most ``n_procs`` processes; every batch is
    started, joined and closed before the next batch begins.

    Args:
        copick_config_path: Path to the Copick configuration file.
        method: Localization method, forwarded to the worker.
        seg_info: Segmentation query, forwarded unchanged to the worker.
        voxel_size: Voxel size, forwarded to the worker.
        pick_session_id: Session ID for the picks, forwarded to the worker.
        pick_user_id: User ID for the picks, forwarded to the worker.
        radius_min_scale: Minimum radius scale, forwarded to the worker.
        radius_max_scale: Maximum radius scale, forwarded to the worker.
        filter_size: Filter size, forwarded to the worker.
        pick_objects: Names of the objects to localize. ``None`` keeps every
            pickable particle object of the project.
        runIDs: Run IDs to process. ``None`` or an empty list selects every
            run of the project.
        n_procs: Number of worker processes per batch. ``None`` selects a
            quarter of the CPU count, capped at the number of runs. The value
            is always clamped to at least 1.

    Raises:
        ValueError: If a pickable particle object has a non-numeric radius.
    """

    # Load the Copick Project
    root = copick.from_file(copick_config_path)

    # Get objects that can be Picked
    objects = [(obj.name, obj.label, obj.radius) for obj in root.pickable_objects if obj.is_particle]

    # Verify each object has a usable radius
    for obj in objects:
        if not isinstance(obj[2], (float, int)):
            raise ValueError(f"Invalid object format: {obj}. Expected a tuple with (name, label, radius).")

    # Filter elements
    if pick_objects is not None:
        objects = [obj for obj in objects if obj[0] in pick_objects]

    print('Running Localization on the Following Objects: ')
    print(', '.join([f'{obj[0]} (Label: {obj[1]})' for obj in objects]) + '\n')

    # Either Specify Input RunIDs or Run on All RunIDs
    if runIDs:
        print('Running Localization on the Following RunIDs: ' + ', '.join(runIDs) + '\n')
    run_ids = runIDs if runIDs else [run.name for run in root.runs]
    n_run_ids = len(run_ids)

    # Determine the number of processes to use
    if n_procs is None:
        n_procs = min(mp.cpu_count() // 4, n_run_ids)
    # A batch size of 0 (fewer than 4 cores, or no runs at all) would be an
    # invalid range() step below, so always keep at least one process.
    n_procs = max(1, n_procs)
    print(f"Using {n_procs} processes to parallelize across {n_run_ids} run IDs.")

    failed_runs = []

    # Initialize tqdm progress bar
    with tqdm(total=n_run_ids, desc="Localization", unit="run") as pbar:
        for start_idx in range(0, n_run_ids, n_procs):

            end_idx = min(start_idx + n_procs, n_run_ids)  # Ensure end_idx does not exceed n_run_ids
            print(f"\nProcessing runIDs from {start_idx} -> {end_idx } (out of {n_run_ids})")

            batch = []
            for run_id in run_ids[start_idx:end_idx]:
                run = root.get_run(run_id)
                # NOTE(review): 'processs_localization' (three s) is kept as written;
                # confirm it matches the name defined in octopi.extract.localize.
                p = mp.Process(
                    target=localize.processs_localization,
                    args=(run,
                          objects,
                          seg_info,
                          method,
                          voxel_size,
                          filter_size,
                          radius_min_scale,
                          radius_max_scale,
                          pick_session_id,
                          pick_user_id),
                )
                batch.append((run_id, p))

            for _, p in batch:
                p.start()

            for _, p in batch:
                p.join()

            for run_id, p in batch:
                # The exit code must be read before close() releases the process
                if p.exitcode != 0:
                    failed_runs.append(run_id)
                p.close()

            # Update tqdm progress bar
            pbar.update(len(batch))

    # Report worker failures instead of silently claiming success
    if failed_runs:
        print('Warning: localization worker exited with an error for runIDs: '
              + ', '.join(str(run_id) for run_id in failed_runs))

    print('Localization Complete!')
95
+
96
def localize_parser(parser_description, add_slurm: bool = False):
    """
    Build the command-line parser for particle localization and parse the arguments.

    Args:
        parser_description (str): Text shown as the description in the help output.
        add_slurm (bool): When True, a "SLURM Arguments" group is added through
            common.add_slurm_parameters.

    Returns:
        argparse.Namespace: The result of ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(
        description=parser_description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    input_group = parser.add_argument_group("Input Arguments")
    input_group.add_argument("--config", type=str, required=True, help="Path to the CoPick configuration file.")
    input_group.add_argument("--method", type=str, choices=['watershed', 'com'], default='watershed', required=False, help="Localization method to use.")
    input_group.add_argument('--seg-info', type=utils.parse_target, required=True, help='Query for the organelles segmentations (e.g., "name" or "name,user_id,session_id".).')
    input_group.add_argument("--voxel-size", type=float, default=10, required=False, help="Voxel size for localization.")
    input_group.add_argument("--runIDs", type=utils.parse_list, default = None, required=False, help="List of runIDs to run inference on, e.g., run1,run2,run3 or [run1,run2,run3].")

    localize_group = parser.add_argument_group("Localize Arguments")
    localize_group.add_argument("--radius-min-scale", type=float, default=0.5, required=False, help="Minimum radius scale for particles.")
    localize_group.add_argument("--radius-max-scale", type=float, default=1.0, required=False, help="Maximum radius scale for particles.")
    localize_group.add_argument("--filter-size", type=int, default=10, required=False, help="Filter size for localization.")
    localize_group.add_argument("--pick-objects", type=utils.parse_list, default=None, required=False, help="Specific Objects to Find Picks for.")
    # The help text mirrors what pick_particles does when n_procs is None
    localize_group.add_argument("--n-procs", type=int, default=None, required=False, help="Number of CPU processes to parallelize runs across. Defaults to a quarter of the available CPU cores, capped at the number of runs.")

    output_group = parser.add_argument_group("Output Arguments")
    output_group.add_argument("--pick-session-id", type=str, default='1', required=False, help="Session ID for the particle picks.")
    output_group.add_argument("--pick-user-id", type=str, default='monai', required=False, help="User ID for the particle picks.")

    if add_slurm:
        slurm_group = parser.add_argument_group("SLURM Arguments")
        common.add_slurm_parameters(slurm_group, 'localize', gpus = 0)

    args = parser.parse_args()
    return args
125
+
126
+ # Entry point with argparse
127
def cli():
    """
    Command-line entry point: parse the arguments, record them, then localize particles.
    """

    args = localize_parser("Localized particles in tomograms using multiprocessing.")

    # Record the parameters in a YAML file named after the pick user and session
    save_parameters(args, f'localize_{args.pick_user_id}_{args.pick_session_id}.yaml')

    # Worker processes are created with the "spawn" start method
    mp.set_start_method("spawn")

    # Map the parsed arguments onto the keyword arguments of pick_particles
    localization_kwargs = dict(
        copick_config_path=args.config,
        method=args.method,
        seg_info=args.seg_info,
        voxel_size=args.voxel_size,
        pick_session_id=args.pick_session_id,
        pick_user_id=args.pick_user_id,
        radius_min_scale=args.radius_min_scale,
        radius_max_scale=args.radius_max_scale,
        filter_size=args.filter_size,
        runIDs=args.runIDs,
        pick_objects=args.pick_objects,
        n_procs=args.n_procs,
    )
    pick_particles(**localization_kwargs)
153
+
154
def save_parameters(args: argparse.Namespace,
                    output_path: str):
    """
    Print the localization parameters and hand them to utils.save_parameters_yaml.

    Args:
        args: Parsed command-line arguments for localization.
        output_path: Destination passed on to utils.save_parameters_yaml.
    """

    # Where the data comes from
    input_section = {
        "config": args.config,
        "seg_name": args.seg_info[0],
        "seg_user_id": args.seg_info[1],
        "seg_session_id": args.seg_info[2],
        "voxel_size": args.voxel_size,
    }

    # Identity under which the picks are stored
    output_section = {
        "pick_session_id": args.pick_session_id,
        "pick_user_id": args.pick_user_id,
    }

    # Settings of the localization itself
    parameter_section = {
        "method": args.method,
        "radius_min_scale": args.radius_min_scale,
        "radius_max_scale": args.radius_max_scale,
        "filter_size": args.filter_size,
        "runIDs": args.runIDs,
    }

    params = {
        "input": input_section,
        "output": output_section,
        "parameters": parameter_section,
    }

    # Show the parameters, followed by an empty line
    print("\nParameters for Localization:")
    pprint.pprint(params)
    print()

    # Write the YAML file
    utils.save_parameters_yaml(params, output_path)
185
+
186
+ if __name__ == "__main__":  # run the CLI only when this file is executed as a script
187
+ cli()
188
+
189
+ # def time_pick_particles():
190
+ # import json, time
191
+
192
+ # # Set multiprocessing start method
193
+ # mp.set_start_method("spawn")
194
+
195
+ # copick_config_path = "/mnt/simulations/ml_challenge/ml_config.json" # Replace with your actual path
196
+ # n_procs_list = [1, 4, 8, 16, 32] # Adjust based on your needs
197
+ # n_procs_list = [32, 16, 8, 4, 1]
198
+ # timing_results = {}
199
+
200
+ # session_id = 1
201
+ # for n_procs in n_procs_list:
202
+ # print(f"Testing with {n_procs} processes...")
203
+ # start_time = time.time()
204
+ # pick_particles(
205
+ # copick_config_path=copick_config_path,
206
+ # pick_session_id=str(session_id),
207
+ # n_procs=n_procs
208
+ # )
209
+ # elapsed_time = time.time() - start_time
210
+ # timing_results[n_procs] = elapsed_time
211
+ # print(f"Elapsed time with {n_procs} processes: {elapsed_time:.2f} seconds")
212
+
213
+ # session_id +=1
214
+
215
+ # # Save timing results to a JSON file
216
+ # with open("timing_results.json", "w") as f:
217
+ # json.dump(timing_results, f, indent=4)
218
+
219
+ # print("Timing results saved to 'timing_results.json'")
220
+
221
+ # if __name__ == "__main__":
222
+ # time_pick_particles()
@@ -0,0 +1,139 @@
1
+ from octopi.pytorch.model_search_submitter import ModelSearchSubmit
2
+ from octopi.entry_points import common
3
+ import argparse, os, pprint
4
+ from octopi import utils
5
+
6
def optuna_parser(parser_description, add_slurm: bool = False):
    """
    Build the command-line parser for the Optuna model architecture search and parse the arguments.

    Args:
        parser_description (str): Text shown as the description in the help output.
        add_slurm (bool): When True, a "SLURM Arguments" group is added as well.

    Returns:
        argparse.Namespace: The result of ``parser.parse_args()``.
    """

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=parser_description,
    )

    # Input Arguments: the shared config option first, then the search-specific ones
    inputs = parser.add_argument_group("Input Arguments")
    common.add_config(inputs, single_config=False)
    input_options = (
        ("--target-info", dict(type=utils.parse_target, default="targets,octopi,1",
                               help="Target information, e.g., 'name' or 'name,user_id,session_id'")),
        ("--tomo-alg", dict(default='wbp',
                            help="Tomogram algorithm used for training")),
        ("--mlflow-experiment-name", dict(type=str, default="model-search", required=False,
                                          help="Name of the MLflow experiment (default: 'model-search').")),
        ("--trainRunIDs", dict(type=utils.parse_list, default=None, required=False,
                               help="List of training run IDs, e.g., run1,run2 or [run1,run2].")),
        ("--validateRunIDs", dict(type=utils.parse_list, default=None, required=False,
                                  help="List of validation run IDs, e.g., run3,run4 or [run3,run4].")),
        ('--data-split', dict(type=str, default='0.8',
                              help="Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) "
                                   "or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")),
    )
    for flag, options in input_options:
        inputs.add_argument(flag, **options)

    # Model Arguments
    models = parser.add_argument_group("Model Arguments")
    model_options = (
        ("--model-type", dict(type=str, default='Unet', required=False,
                              choices=['Unet', 'AttentionUnet'],
                              help="Model type to use for training")),
        ("--Nclass", dict(type=int, default=3, required=False,
                          help="Number of prediction classes in the model")),
    )
    for flag, options in model_options:
        models.add_argument(flag, **options)

    # Training Arguments: the shared training options first, then the seed
    training = parser.add_argument_group("Training Arguments")
    common.add_train_parameters(training, octopi = True)
    training.add_argument("--random-seed", type=int, default=42, required=False,
                          help="Random seed for reproducibility (default: 42).")

    # Optional SLURM Arguments
    if add_slurm:
        common.add_slurm_parameters(parser.add_argument_group("SLURM Arguments"), 'optuna')

    return parser.parse_args()
53
+
54
+ # Entry point with argparse
55
def cli():
    """
    CLI entry point for running the Optuna model architecture search.
    """

    args = optuna_parser("Perform model architecture search with Optuna and MLflow integration.")

    # More than one config entry goes through utils.parse_copick_configs;
    # a single entry is used directly
    has_multiple_configs = len(args.config) > 1
    copick_configs = utils.parse_copick_configs(args.config) if has_multiple_configs else args.config[0]

    # Create the model exploration directory
    os.makedirs(f'explore_results_{args.model_type}', exist_ok=True)

    # Record the search parameters inside that directory
    save_parameters(args, f'explore_results_{args.model_type}/octopi.yaml')

    # Collect the settings of the search, then build the submitter from them
    search_settings = dict(
        copick_config=copick_configs,
        target_name=args.target_info[0],
        target_user_id=args.target_info[1],
        target_session_id=args.target_info[2],
        tomo_algorithm=args.tomo_alg,
        voxel_size=args.voxel_size,
        Nclass=args.Nclass,
        model_type=args.model_type,
        mlflow_experiment_name=args.mlflow_experiment_name,
        random_seed=args.random_seed,
        num_epochs=args.num_epochs,
        num_trials=args.num_trials,
        trainRunIDs=args.trainRunIDs,
        validateRunIDs=args.validateRunIDs,
        tomo_batch_size=args.tomo_batch_size,
        best_metric=args.best_metric,
        val_interval=args.val_interval,
        data_split=args.data_split,
    )
    search = ModelSearchSubmit(**search_settings)

    # Run the model search
    search.run_model_search()
97
+
98
def save_parameters(args: argparse.Namespace,
                    output_path: str):
    """
    Print the Optuna search parameters and hand them to utils.save_parameters_yaml.

    Args:
        args: Parsed arguments from argparse.
        output_path: Destination passed on to utils.save_parameters_yaml.
    """
    # Data the search runs on
    input_section = {
        "copick_config": args.config,
        "target_info": args.target_info,
        "tomo_algorithm": args.tomo_alg,
        "voxel_size": args.voxel_size,
        "Nclass": args.Nclass,
    }

    # Settings of the search itself
    optimization_section = {
        "model_type": args.model_type,
        "mlflow_experiment_name": args.mlflow_experiment_name,
        "random_seed": args.random_seed,
        "num_trials": args.num_trials,
        "best_metric": args.best_metric,
    }

    # Settings applied to each training
    training_section = {
        "num_epochs": args.num_epochs,
        "tomo_batch_size": args.tomo_batch_size,
        "trainRunIDs": args.trainRunIDs,
        "validateRunIDs": args.validateRunIDs,
        "data_split": args.data_split,
    }

    params = {
        "input": input_section,
        "optimization": optimization_section,
        "training": training_section,
    }

    # Show the parameters, followed by an empty line
    print("\nParameters for Model Architecture Search:")
    pprint.pprint(params)
    print()

    # Write the YAML file
    utils.save_parameters_yaml(params, output_path)
137
+
138
+ if __name__ == "__main__":  # run the CLI only when this file is executed as a script
139
+ cli()
@@ -0,0 +1,166 @@
1
+ from octopi.pytorch import segmentation
2
+ from octopi.entry_points import common
3
+ import torch, argparse, json, pprint, yaml, os
4
+ from octopi import utils
5
+ from typing import List, Tuple
6
+
7
def inference(
    copick_config_path: str,
    model_weights: str,
    model_config: str,
    seg_info: Tuple[str,str,str],
    voxel_size: float,
    tomo_algorithm: str,
    tomo_batch_size: int,
    run_ids: List[str],
    ):
    """
    Run segmentation inference, choosing the predictor from the number of visible GPUs.

    With more than one GPU the multi-GPU predictor is used; otherwise the
    single predictor runs a batch prediction.

    Args:
        copick_config_path (str): Path to CoPick configuration file.
        model_weights (str): Model weights, passed to the predictor constructor.
        model_config (str): Model configuration, passed to the predictor constructor.
        seg_info (Tuple[str,str,str]): Name, user ID and session ID under which
            the segmentation is written.
        voxel_size (float): Passed to the predictor as ``voxel_spacing``.
        tomo_algorithm (str): Tomogram algorithm passed to the predictor.
        tomo_batch_size (int): Number of tomograms per batch; only used by the
            single predictor.
        run_ids (List[str]): Run IDs passed to the predictor as ``runIDs``.
    """

    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs available: {gpu_count}")

    # Pick the predictor class first; both take the same constructor arguments
    use_multi_gpu = gpu_count > 1
    if use_multi_gpu:
        print("Using Multi-GPU Predictor.")
        predictor_class = segmentation.MultiGPUPredictor
    else:
        print("Using Single-GPU Predictor.")
        predictor_class = segmentation.Predictor
    predict = predictor_class(copick_config_path, model_config, model_weights)

    # Settings that both prediction calls share
    shared_settings = dict(
        runIDs=run_ids,
        tomo_algorithm=tomo_algorithm,
        voxel_spacing=voxel_size,
        segmentation_name=seg_info[0],
        segmentation_user_id=seg_info[1],
        segmentation_session_id=seg_info[2],
    )

    if use_multi_gpu:
        # Run Multi-GPU inference
        predict.multi_gpu_inference(save=True, **shared_settings)
    else:
        # Run batch prediction
        predict.batch_predict(num_tomos_per_batch=tomo_batch_size, **shared_settings)

    print("Inference completed successfully.")
76
+
77
def inference_parser(parser_description, add_slurm: bool = False):
    """
    Build the command-line parser for segmentation inference and parse the arguments.

    Args:
        parser_description (str): Text shown as the description in the help output.
        add_slurm (bool): When True, a "SLURM Arguments" group is added as well.

    Returns:
        argparse.Namespace: The result of ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=parser_description,
    )

    # Each group is created and immediately filled by its helper in octopi.entry_points.common
    common.add_config(parser.add_argument_group("Input Arguments"), single_config=True)
    common.inference_model_parameters(parser.add_argument_group("Model Arguments"))
    common.add_inference_parameters(parser.add_argument_group("Inference Arguments"))

    if add_slurm:
        common.add_slurm_parameters(parser.add_argument_group("SLURM Arguments"), 'segment_predict', gpus = 2)

    return parser.parse_args()
100
+
101
+ # Entry point with argparse
102
def cli():
    """
    CLI entry point: parse the arguments, record them, then run inference.
    """

    # Parse the arguments
    args = inference_parser(
        "Run segmentation predictions with a specified model and configuration on CryoET Tomograms."
    )

    # Work on a mutable copy so that a missing user ID / session ID can be filled in
    seg_info = list(args.seg_info)
    for position, fallback in ((1, "octopi"), (2, "1")):
        if seg_info[position] is None:
            seg_info[position] = fallback
    args.seg_info = seg_info

    # Record the parameters in a YAML file named after user ID, session ID and name
    parameter_file = f'segment-predict_{seg_info[1]}_{seg_info[2]}_{seg_info[0]}.yaml'
    save_parameters(args, parameter_file)

    # Map the parsed arguments onto the keyword arguments of inference
    inference_kwargs = dict(
        copick_config_path=args.config,
        model_weights=args.model_weights,
        model_config=args.model_config,
        seg_info=args.seg_info,
        voxel_size=args.voxel_size,
        tomo_algorithm=args.tomo_alg,
        tomo_batch_size=args.tomo_batch_size,
        run_ids=args.run_ids,
    )
    inference(**inference_kwargs)
134
+
135
def save_parameters(args: argparse.Namespace,
                    output_path: str):
    """
    Print the inference parameters and hand them to utils.save_parameters_yaml.

    The model configuration is loaded with utils.load_yaml and its 'model'
    entry is stored next to the inputs and outputs.

    Args:
        args: Parsed command-line arguments for inference.
        output_path: Destination passed on to utils.save_parameters_yaml.
    """

    # Load the model config
    model_config = utils.load_yaml(args.model_config)

    # What the prediction reads
    input_section = {
        "config": args.config,
        "model_config": args.model_config,
        "model_weights": args.model_weights,
        "tomo_algorithm": args.tomo_alg,
        "voxel_size": args.voxel_size,
    }

    # Identity under which the segmentation is written
    output_section = {
        "segmentation_name": args.seg_info[0],
        "segmentation_user_id": args.seg_info[1],
        "segmentation_session_id": args.seg_info[2],
    }

    params = {
        "inputs": input_section,
        "outputs": output_section,
        'model': model_config['model'],
    }

    # Show the parameters, followed by an empty line
    print("\nParameters for Inference (Segment Prediction):")
    pprint.pprint(params)
    print()

    # Write the YAML file
    utils.save_parameters_yaml(params, output_path)
164
+
165
+ if __name__ == "__main__":  # run the CLI only when this file is executed as a script
166
+ cli()