octopi 1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of octopi might be problematic.

Files changed (45)
  1. octopi/__init__.py +1 -0
  2. octopi/datasets/cached_datset.py +1 -1
  3. octopi/datasets/generators.py +1 -1
  4. octopi/datasets/io.py +200 -0
  5. octopi/datasets/multi_config_generator.py +1 -1
  6. octopi/entry_points/common.py +5 -5
  7. octopi/entry_points/create_slurm_submission.py +1 -1
  8. octopi/entry_points/run_create_targets.py +6 -6
  9. octopi/entry_points/run_evaluate.py +4 -3
  10. octopi/entry_points/run_extract_mb_picks.py +5 -5
  11. octopi/entry_points/run_localize.py +8 -9
  12. octopi/entry_points/run_optuna.py +7 -7
  13. octopi/entry_points/run_segment_predict.py +4 -4
  14. octopi/entry_points/run_train.py +7 -8
  15. octopi/extract/localize.py +11 -19
  16. octopi/extract/membranebound_extract.py +11 -10
  17. octopi/extract/midpoint_extract.py +3 -3
  18. octopi/models/common.py +1 -1
  19. octopi/processing/create_targets_from_picks.py +3 -4
  20. octopi/processing/evaluate.py +24 -11
  21. octopi/processing/importers.py +4 -4
  22. octopi/pytorch/hyper_search.py +2 -3
  23. octopi/pytorch/model_search_submitter.py +4 -4
  24. octopi/pytorch/segmentation.py +141 -190
  25. octopi/pytorch/segmentation_multigpu.py +162 -0
  26. octopi/pytorch/trainer.py +2 -2
  27. octopi/utils/__init__.py +0 -0
  28. octopi/utils/config.py +57 -0
  29. octopi/utils/io.py +128 -0
  30. octopi/{utils.py → utils/parsers.py} +10 -84
  31. octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
  32. octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
  33. octopi/workflows.py +236 -0
  34. {octopi-1.1.dist-info → octopi-1.2.0.dist-info}/METADATA +41 -29
  35. octopi-1.2.0.dist-info/RECORD +62 -0
  36. {octopi-1.1.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
  37. octopi-1.2.0.dist-info/entry_points.txt +3 -0
  38. {octopi-1.1.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
  39. octopi/io.py +0 -457
  40. octopi/processing/my_metrics.py +0 -26
  41. octopi/processing/writers.py +0 -102
  42. octopi-1.1.dist-info/RECORD +0 -59
  43. octopi-1.1.dist-info/entry_points.txt +0 -4
  44. /octopi/{losses.py → utils/losses.py} +0 -0
  45. /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
octopi/__init__.py CHANGED
@@ -0,0 +1 @@
+ __version__ = "1.2.0"
octopi/datasets/cached_datset.py CHANGED
@@ -1,7 +1,7 @@
  from typing import List, Tuple, Callable, Optional, Dict, Any
  from monai.transforms import Compose
  from monai.data import CacheDataset
- from octopi import io
+ from octopi.datasets import io
  from tqdm import tqdm
  import os, sys

octopi/datasets/generators.py CHANGED
@@ -1,7 +1,7 @@
  from octopi.datasets import dataset, augment, cached_datset
  from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
  from typing import List, Optional
- from octopi import io
+ from octopi.datasets import io
  import torch, os, random, gc
  import multiprocess as mp

octopi/datasets/io.py ADDED
@@ -0,0 +1,200 @@
+ """
+ Data loading, processing, and dataset operations for the datasets module.
+ """
+
+ from monai.data import DataLoader, CacheDataset, Dataset
+ from monai.transforms import (
+     Compose,
+     NormalizeIntensityd,
+     EnsureChannelFirstd,
+ )
+ from sklearn.model_selection import train_test_split
+ from collections import defaultdict
+ from copick_utils.io import readers
+ import copick, torch, os, random
+ from typing import List
+ from tqdm import tqdm
+
+
+ def load_training_data(root,
+                        runIDs: List[str],
+                        voxel_spacing: float,
+                        tomo_algorithm: str,
+                        segmenation_name: str,
+                        segmentation_session_id: str = None,
+                        segmentation_user_id: str = None,
+                        progress_update: bool = True):
+     """
+     Load training data from CoPick runs.
+     """
+     data_dicts = []
+     # Use tqdm for progress tracking only if progress_update is True
+     iterable = tqdm(runIDs, desc="Loading Training Data") if progress_update else runIDs
+     for runID in iterable:
+         run = root.get_run(str(runID))
+         tomogram = readers.tomogram(run, voxel_spacing, tomo_algorithm)
+         segmentation = readers.segmentation(run,
+                                             voxel_spacing,
+                                             segmenation_name,
+                                             segmentation_session_id,
+                                             segmentation_user_id)
+         data_dicts.append({"image": tomogram, "label": segmentation})
+
+     return data_dicts
+
+
+ def load_predict_data(root,
+                       runIDs: List[str],
+                       voxel_spacing: float,
+                       tomo_algorithm: str):
+     """
+     Load prediction data from CoPick runs.
+     """
+     data_dicts = []
+     for runID in tqdm(runIDs):
+         run = root.get_run(str(runID))
+         tomogram = readers.tomogram(run, voxel_spacing, tomo_algorithm)
+         data_dicts.append({"image": tomogram})
+
+     return data_dicts
+
+
+ def create_predict_dataloader(
+     root,
+     voxel_spacing: float,
+     tomo_algorithm: str,
+     runIDs: str = None,
+ ):
+     """
+     Create a dataloader for prediction data.
+     """
+     # define pre transforms
+     pre_transforms = Compose(
+         [ EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
+           NormalizeIntensityd(keys=["image"]),
+         ])
+
+     # Split trainRunIDs, validateRunIDs, testRunIDs
+     if runIDs is None:
+         runIDs = [run.name for run in root.runs]
+     test_files = load_predict_data(root, runIDs, voxel_spacing, tomo_algorithm)
+
+     bs = min( len(test_files), 4)
+     test_ds = CacheDataset(data=test_files, transform=pre_transforms)
+     test_loader = DataLoader(test_ds,
+                              batch_size=bs,
+                              shuffle=False,
+                              num_workers=4,
+                              pin_memory=torch.cuda.is_available())
+     return test_loader, test_ds
+
+
+ def adjust_to_multiple(value, multiple = 16):
+     """
+     Adjust a value to be a multiple of the specified number.
+     """
+     return int((value // multiple) * multiple)
+
+
+ def get_input_dimensions(dataset, crop_size: int):
+     """
+     Get input dimensions for the dataset.
+     """
+     nx = dataset[0]['image'].shape[1]
+     if crop_size > nx:
+         first_dim = adjust_to_multiple(nx/2)
+         return first_dim, crop_size, crop_size
+     else:
+         return crop_size, crop_size, crop_size
+
+
+ def get_num_classes(copick_config_path: str):
+     """
+     Get the number of classes from a CoPick configuration.
+     """
+     root = copick.from_file(copick_config_path)
+     return len(root.pickable_objects) + 1
+
+
+ def split_multiclass_dataset(runIDs,
+                              train_ratio: float = 0.7,
+                              val_ratio: float = 0.15,
+                              test_ratio: float = 0.15,
+                              return_test_dataset: bool = True,
+                              random_state: int = 42):
+     """
+     Splits a given dataset into three subsets: training, validation, and testing. If the dataset
+     has categories (as tuples), splits are balanced across all categories. If the dataset is a 1D
+     list, it is split without categorization.
+
+     Parameters:
+     - runIDs: A list of items to split. It can be a 1D list or a list of tuples (category, value).
+     - train_ratio: Proportion of the dataset for training.
+     - val_ratio: Proportion of the dataset for validation.
+     - test_ratio: Proportion of the dataset for testing.
+     - return_test_dataset: Whether to return the test dataset.
+     - random_state: Random state for reproducibility.
+
+     Returns:
+     - trainRunIDs: Training subset.
+     - valRunIDs: Validation subset.
+     - testRunIDs: Testing subset (if return_test_dataset is True, otherwise None).
+     """
+
+     # Ensure the ratios add up to 1
+     assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must add up to 1.0"
+
+     # Check if the dataset has categories
+     if isinstance(runIDs[0], tuple) and len(runIDs[0]) == 2:
+         # Group by category
+         grouped = defaultdict(list)
+         for item in runIDs:
+             grouped[item[0]].append(item)
+
+         # Split each category
+         trainRunIDs, valRunIDs, testRunIDs = [], [], []
+         for category, items in grouped.items():
+             # Shuffle for randomness
+             random.shuffle(items)
+             # Split into train and remaining
+             train_items, remaining = train_test_split(items, test_size=(1 - train_ratio), random_state=random_state)
+             trainRunIDs.extend(train_items)
+
+             if return_test_dataset:
+                 # Split remaining into validation and test
+                 val_items, test_items = train_test_split(
+                     remaining,
+                     test_size=(test_ratio / (val_ratio + test_ratio)),
+                     random_state=random_state,
+                 )
+                 valRunIDs.extend(val_items)
+                 testRunIDs.extend(test_items)
+             else:
+                 valRunIDs.extend(remaining)
+                 testRunIDs = []
+     else:
+         # If no categories, split as a 1D list
+         trainRunIDs, remaining = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)
+         if return_test_dataset:
+             valRunIDs, testRunIDs = train_test_split(
+                 remaining,
+                 test_size=(test_ratio / (val_ratio + test_ratio)),
+                 random_state=random_state,
+             )
+         else:
+             valRunIDs = remaining
+             testRunIDs = []
+
+     return trainRunIDs, valRunIDs, testRunIDs
+
+
+ def load_copick_config(path: str):
+     """
+     Load a CoPick configuration from file.
+     """
+     if os.path.isfile(path):
+         root = copick.from_file(path)
+     else:
+         raise FileNotFoundError(f"Copick Config Path does not exist: {path}")
+
+     return root
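
Example (illustrative only, not part of the package diff): a brief usage sketch of the new octopi.datasets.io helpers, following the call signatures in the added file above. The config path and run names are placeholders.

# Illustrative sketch; paths and run names are hypothetical.
from octopi.datasets import io

root = io.load_copick_config("copick_config.json")      # hypothetical config path
n_classes = io.get_num_classes("copick_config.json")    # pickable objects + background

# Split run IDs 70/15/15 into train/val/test subsets
run_ids = [run.name for run in root.runs]
train_ids, val_ids, test_ids = io.split_multiclass_dataset(run_ids)

# Build a prediction dataloader over all runs at 10 A voxel spacing with 'wbp' tomograms
loader, dataset = io.create_predict_dataloader(root, voxel_spacing=10, tomo_algorithm="wbp")
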
octopi/datasets/multi_config_generator.py CHANGED
@@ -1,7 +1,7 @@
  from octopi.datasets import dataset, augment, cached_datset
  from octopi.datasets.generators import TrainLoaderManager
  from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
- from octopi import io
+ from octopi.datasets import io
  import multiprocess as mp
  from typing import List
  from tqdm import tqdm
octopi/entry_points/common.py CHANGED
@@ -1,4 +1,4 @@
- from octopi import utils
+ from octopi.utils import parsers
  import argparse

  def add_model_parameters(parser, octopi = False):
@@ -8,8 +8,8 @@ def add_model_parameters(parser, octopi = False):

      # Add U-Net model parameters
      parser.add_argument("--Nclass", type=int, required=False, default=3, help="Number of prediction classes in the model")
-     parser.add_argument("--channels", type=utils.parse_int_list, required=False, default='32,64,96,96', help="List of channel sizes")
-     parser.add_argument("--strides", type=utils.parse_int_list, required=False, default='2,2,1', help="List of stride sizes")
+     parser.add_argument("--channels", type=parsers.parse_int_list, required=False, default='32,64,96,96', help="List of channel sizes")
+     parser.add_argument("--strides", type=parsers.parse_int_list, required=False, default='2,2,1', help="List of stride sizes")
      parser.add_argument("--res-units", type=int, required=False, default=1, help="Number of residual units in the UNet")
      parser.add_argument("--dim-in", type=int, required=False, default=96, help="Input dimension for the UNet model")

@@ -52,11 +52,11 @@ def add_inference_parameters(parser):

      parser.add_argument("--tomo-alg", required=False, default = 'wbp',
                          help="Tomogram algorithm used for produces segmentation prediction masks.")
-     parser.add_argument("--seg-info", type=utils.parse_target, required=False,
+     parser.add_argument("--seg-info", type=parsers.parse_target, required=False,
                          default='predict,octopi,1', help='Information Query to save Segmentation predictions under, e.g., (e.g., "name" or "name,user_id,session_id" - Default UserID is octopi and SessionID is 1')
      parser.add_argument("--tomo-batch-size", type=int, default=25, required=False,
                          help="Batch size for tomogram processing.")
-     parser.add_argument("--run-ids", type=utils.parse_list, default=None, required=False,
+     parser.add_argument("--run-ids", type=parsers.parse_list, default=None, required=False,
                          help="List of run IDs for prediction, e.g., run1,run2 or [run1,run2]. If not provided, all available runs will be processed.")

  def add_localize_parameters(parser):
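
For context on the --channels/--strides options above: argparse's type= expects a callable that turns the comma-separated string into a list of ints. A minimal sketch of such a converter follows; it is illustrative only, and octopi's actual parsers.parse_int_list (not shown in this diff) may differ.

# Illustrative sketch only; the real octopi.utils.parsers.parse_int_list may differ.
def parse_int_list(value: str):
    # Accept "32,64,96,96" or "[32,64,96,96]" and return [32, 64, 96, 96]
    return [int(v) for v in value.strip("[]").split(",") if v.strip()]
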
octopi/entry_points/create_slurm_submission.py CHANGED
@@ -1,5 +1,5 @@
  from octopi.entry_points import run_train, run_segment_predict, run_localize, run_optuna
- from octopi.submit_slurm import create_shellsubmit, create_multiconfig_shellsubmit
+ from octopi.utils.submit_slurm import create_shellsubmit, create_multiconfig_shellsubmit
  from octopi.processing.importers import cli_mrcs_parser, cli_dataportal_parser
  from octopi.entry_points import common
  from octopi import utils
octopi/entry_points/run_create_targets.py CHANGED
@@ -1,8 +1,8 @@
  import octopi.processing.create_targets_from_picks as create_targets
  from typing import List, Tuple, Union
+ from octopi.utils import io, parsers
  from collections import defaultdict
  import argparse, copick, yaml, os
- from octopi import utils, io
  from tqdm import tqdm
  import numpy as np

@@ -160,16 +160,16 @@ def parse_args():

      input_group = parser.add_argument_group("Input Arguments")
      input_group.add_argument("--config", type=str, required=True, help="Path to the CoPick configuration file.")
-     input_group.add_argument("--target", type=utils.parse_target, action="append", default=None, help='Target specifications: "name" or "name,user_id,session_id".')
+     input_group.add_argument("--target", type=parsers.parse_target, action="append", default=None, help='Target specifications: "name" or "name,user_id,session_id".')
      input_group.add_argument("--picks-session-id", type=str, default=None, help="Session ID for the picks.")
      input_group.add_argument("--picks-user-id", type=str, default=None, help="User ID associated with the picks.")
-     input_group.add_argument("--seg-target", type=utils.parse_target, action="append", default=[], help='Segmentation targets: "name" or "name,user_id,session_id".')
-     input_group.add_argument("--run-ids", type=utils.parse_list, default=None, help="List of run IDs.")
+     input_group.add_argument("--seg-target", type=parsers.parse_target, action="append", default=[], help='Segmentation targets: "name" or "name,user_id,session_id".')
+     input_group.add_argument("--run-ids", type=parsers.parse_list, default=None, help="List of run IDs.")

      # Parameters
      parameters_group = parser.add_argument_group("Parameters")
      parameters_group.add_argument("--tomo-alg", type=str, default="wbp", help="Tomogram reconstruction algorithm.")
-     parameters_group.add_argument("--radius-scale", type=float, default=0.8, help="Scale factor for object radius.")
+     parameters_group.add_argument("--radius-scale", type=float, default=0.7, help="Scale factor for object radius.")
      parameters_group.add_argument("--voxel-size", type=float, default=10, help="Voxel size for tomogram reconstruction.")

      output_group = parser.add_argument_group("Output Arguments")
@@ -275,7 +275,7 @@ def save_parameters(args, output_path: str):
          existing_data[input_key] = new_entry[input_key]

      # Save back to the YAML file
-     utils.save_parameters_yaml(existing_data, output_path)
+     io.save_parameters_yaml(existing_data, output_path)

  if __name__ == "__main__":
      cli()
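
Several entry points in this release switch from utils.save_parameters_yaml/utils.load_yaml to helpers in the new octopi/utils/io.py (added with +128 lines but not shown in this diff). A minimal sketch of what such YAML helpers typically look like is given below; it is an assumption for illustration, not the package's actual implementation.

# Illustrative sketch only; octopi/utils/io.py is not shown in this diff.
import yaml

def load_yaml(path: str) -> dict:
    # Read a YAML file into a dictionary
    with open(path, "r") as f:
        return yaml.safe_load(f)

def save_parameters_yaml(params: dict, output_path: str) -> None:
    # Write a parameters dictionary to a YAML file
    with open(output_path, "w") as f:
        yaml.dump(params, f, default_flow_style=False)
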
octopi/entry_points/run_evaluate.py CHANGED
@@ -1,5 +1,5 @@
  import octopi.processing.evaluate as evaluate
- import octopi.utils as utils
+ from octopi.utils import parsers
  from typing import List
  import argparse

@@ -31,6 +31,7 @@ def cli():
      """
      CLI entry point for running evaluation.
      """
+
      parser = argparse.ArgumentParser(
          description='Run evaluation on pick and place predictions.',
          formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -43,8 +44,8 @@ def cli():
      parser.add_argument('--predict-session-id', type=str, required=False, default= None, help='Session ID for prediction data')
      parser.add_argument('--save-path', type=str, required=False, default= None, help='Path to save evaluation results')
      parser.add_argument('--distance-threshold-scale', type=float, required=False, default = 0.8, help='Compute Distance Threshold Based on Particle Radius')
-     parser.add_argument('--object-names', type=utils.parse_list, default=None, required=False, help='Optional list of object names to evaluate, e.g., ribosome,apoferritin or [ribosome,apoferritin].')
-     parser.add_argument('--run-ids', type=utils.parse_list, default=None, required=False, help='Optional list of run IDs to evaluate, e.g., run1,run2,run3 or [run1,run2,run3].')
+     parser.add_argument('--object-names', type=parsers.parse_list, default=None, required=False, help='Optional list of object names to evaluate, e.g., ribosome,apoferritin or [ribosome,apoferritin].')
+     parser.add_argument('--run-ids', type=parsers.parse_list, default=None, required=False, help='Optional list of run IDs to evaluate, e.g., run1,run2,run3 or [run1,run2,run3].')

      args = parser.parse_args()

octopi/entry_points/run_extract_mb_picks.py CHANGED
@@ -1,5 +1,5 @@
  from octopi.extract import membranebound_extract as extract
- from octopi import utils, io
+ from octopi.utils import parsers
  import argparse, json, pprint, copick, json
  from typing import List, Tuple, Optional
  import multiprocess as mp
@@ -58,12 +58,12 @@
      parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')
      parser.add_argument('--voxel-size', type=float, required=False, default=10, help='Voxel size.')
      parser.add_argument('--distance-threshold', type=float, required=False, default=10, help='Distance threshold.')
-     parser.add_argument('--picks-info', type=utils.parse_target, required=True, help='Query for the picks (e.g., "name" or "name,user_id,session_id".).')
-     parser.add_argument('--membrane-info', type=utils.parse_target, required=False, help='Query for the membrane segmentation (e.g., "name" or "name,user_id,session_id".).')
-     parser.add_argument('--organelle-info', type=utils.parse_target, required=False, help='Query for the organelles segmentations (e.g., "name" or "name,user_id,session_id".).')
+     parser.add_argument('--picks-info', type=parsers.parse_target, required=True, help='Query for the picks (e.g., "name" or "name,user_id,session_id".).')
+     parser.add_argument('--membrane-info', type=parsers.parse_target, required=False, help='Query for the membrane segmentation (e.g., "name" or "name,user_id,session_id".).')
+     parser.add_argument('--organelle-info', type=parsers.parse_target, required=False, help='Query for the organelles segmentations (e.g., "name" or "name,user_id,session_id".).')
      parser.add_argument('--save-user-id', type=str, required=False, default=None, help='User ID to save the new picks.')
      parser.add_argument('--save-session-id', type=str, required=True, help='Session ID to save the new picks.')
-     parser.add_argument('--runIDs', type=utils.parse_list, required=False, help='List of run IDs to process.')
+     parser.add_argument('--runIDs', type=parsers.parse_list, required=False, help='List of run IDs to process.')
      parser.add_argument('--n-procs', type=int, required=False, default=None, help='Number of processes to use.')

      args = parser.parse_args()
octopi/entry_points/run_localize.py CHANGED
@@ -1,11 +1,10 @@
  from octopi.entry_points import common
+ from octopi.utils import parsers, io
  from octopi.extract import localize
- from octopi import utils
  import copick, argparse, pprint
  from typing import List, Tuple
  import multiprocess as mp
  from tqdm import tqdm
- import os

  def pick_particles(
      copick_config_path: str,
@@ -53,13 +52,13 @@ def pick_particles(

      # Nprocesses shouldnt exceed computation resource or number of available runs
      n_run_ids = len(run_ids)
-     n_procs = min(mp.mp.cpu_count(), n_procs, n_run_ids)
+     n_procs = min(mp.cpu_count(), n_procs, n_run_ids)

      # Run Localization - Main Parallelization Loop
      print(f"Using {n_procs} processes to parallelize across {n_run_ids} run IDs.")
      with mp.Pool(processes=n_procs) as pool:
          with tqdm(total=n_run_ids, desc="Localization", unit="run") as pbar:
-             worker_func = lambda run_id: localize.processs_localization(
+             worker_func = lambda run_id: localize.process_localization(
                  root.get_run(run_id),
                  objects,
                  seg_info,
@@ -85,20 +84,20 @@ def localize_parser(parser_description, add_slurm: bool = False):
      input_group = parser.add_argument_group("Input Arguments")
      input_group.add_argument("--config", type=str, required=True, help="Path to the CoPick configuration file.")
      input_group.add_argument("--method", type=str, choices=['watershed', 'com'], default='watershed', required=False, help="Localization method to use.")
-     input_group.add_argument('--seg-info', type=utils.parse_target, required=True, help='Query for the organelles segmentations (e.g., "name" or "name,user_id,session_id".).')
+     input_group.add_argument('--seg-info', type=parsers.parse_target, required=False, default='predict,octopi,1', help='Query for the organelles segmentations (e.g., "name" or "name,user_id,session_id".).')
      input_group.add_argument("--voxel-size", type=float, default=10, required=False, help="Voxel size for localization.")
-     input_group.add_argument("--runIDs", type=utils.parse_list, default = None, required=False, help="List of runIDs to run inference on, e.g., run1,run2,run3 or [run1,run2,run3].")
+     input_group.add_argument("--runIDs", type=parsers.parse_list, default = None, required=False, help="List of runIDs to run inference on, e.g., run1,run2,run3 or [run1,run2,run3].")

      localize_group = parser.add_argument_group("Localize Arguments")
      localize_group.add_argument("--radius-min-scale", type=float, default=0.5, required=False, help="Minimum radius scale for particles.")
      localize_group.add_argument("--radius-max-scale", type=float, default=1.0, required=False, help="Maximum radius scale for particles.")
      localize_group.add_argument("--filter-size", type=int, default=10, required=False, help="Filter size for localization.")
-     localize_group.add_argument("--pick-objects", type=utils.parse_list, default=None, required=False, help="Specific Objects to Find Picks for.")
+     localize_group.add_argument("--pick-objects", type=parsers.parse_list, default=None, required=False, help="Specific Objects to Find Picks for.")
      localize_group.add_argument("--n-procs", type=int, default=8, required=False, help="Number of CPU processes to parallelize runs across. Defaults to the max number of cores available or available runs.")

      output_group = parser.add_argument_group("Output Arguments")
      output_group.add_argument("--pick-session-id", type=str, default='1', required=False, help="Session ID for the particle picks.")
-     output_group.add_argument("--pick-user-id", type=str, default='monai', required=False, help="User ID for the particle picks.")
+     output_group.add_argument("--pick-user-id", type=str, default='octopi', required=False, help="User ID for the particle picks.")

      if add_slurm:
          slurm_group = parser.add_argument_group("SLURM Arguments")
@@ -165,7 +164,7 @@ def save_parameters(args: argparse.Namespace,
      pprint.pprint(params); print()

      # Save to YAML file
-     utils.save_parameters_yaml(params, output_path)
+     io.save_parameters_yaml(params, output_path)

  if __name__ == "__main__":
      cli()
octopi/entry_points/run_optuna.py CHANGED
@@ -1,7 +1,7 @@
  from octopi.pytorch.model_search_submitter import ModelSearchSubmit
  from octopi.entry_points import common
+ from octopi.utils import parsers, io
  import argparse, os, pprint
- from octopi import utils

  def optuna_parser(parser_description, add_slurm: bool = False):
      """
@@ -20,22 +20,22 @@ def optuna_parser(parser_description, add_slurm: bool = False):
      # Input Arguments
      input_group = parser.add_argument_group("Input Arguments")
      common.add_config(input_group, single_config=False)
-     input_group.add_argument("--target-info", type=utils.parse_target, default="targets,octopi,1",
+     input_group.add_argument("--target-info", type=parsers.parse_target, default="targets,octopi,1",
                               help="Target information, e.g., 'name' or 'name,user_id,session_id'")
      input_group.add_argument("--tomo-alg", default='wbp',
                               help="Tomogram algorithm used for training")
      input_group.add_argument("--mlflow-experiment-name", type=str, default="model-search", required=False,
                               help="Name of the MLflow experiment (default: 'model-search').")
-     input_group.add_argument("--trainRunIDs", type=utils.parse_list, default=None, required=False,
+     input_group.add_argument("--trainRunIDs", type=parsers.parse_list, default=None, required=False,
                               help="List of training run IDs, e.g., run1,run2 or [run1,run2].")
-     input_group.add_argument("--validateRunIDs", type=utils.parse_list, default=None, required=False,
+     input_group.add_argument("--validateRunIDs", type=parsers.parse_list, default=None, required=False,
                               help="List of validation run IDs, e.g., run3,run4 or [run3,run4].")
      input_group.add_argument('--data-split', type=str, default='0.8', help="Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) "
                               "or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")

      model_group = parser.add_argument_group("Model Arguments")
      model_group.add_argument("--model-type", type=str, default='Unet', required=False,
-                              choices=['Unet', 'AttentionUnet'],
+                              choices=['Unet', 'AttentionUnet', 'MedNeXt', 'SegResNet'],
                               help="Model type to use for training")
      model_group.add_argument("--Nclass", type=int, default=3, required=False, help="Number of prediction classes in the model")

@@ -61,7 +61,7 @@ def cli():
      args = optuna_parser(description)

      # Parse the CoPick configuration paths
-     if len(args.config) > 1: copick_configs = utils.parse_copick_configs(args.config)
+     if len(args.config) > 1: copick_configs = parsers.parse_copick_configs(args.config)
      else: copick_configs = args.config[0]

      # Create the model exploration directory
@@ -133,7 +133,7 @@ def save_parameters(args: argparse.Namespace,
      pprint.pprint(params); print()

      # Save to YAML file
-     utils.save_parameters_yaml(params, output_path)
+     io.save_parameters_yaml(params, output_path)

  if __name__ == "__main__":
      cli()
octopi/entry_points/run_segment_predict.py CHANGED
@@ -1,8 +1,8 @@
+ import torch, argparse, json, pprint, yaml, os
  from octopi.pytorch import segmentation
  from octopi.entry_points import common
- import torch, argparse, json, pprint, yaml, os
- from octopi import utils
  from typing import List, Tuple
+ from octopi.utils import io

  def inference(
      copick_config_path: str,
@@ -136,7 +136,7 @@ def save_parameters(args: argparse.Namespace,
                      output_path: str):

      # Load the model config
-     model_config = utils.load_yaml(args.model_config)
+     model_config = io.load_yaml(args.model_config)

      # Create parameters dictionary
      params = {
@@ -160,7 +160,7 @@ def save_parameters(args: argparse.Namespace,
      pprint.pprint(params); print()

      # Save to YAML file
-     utils.save_parameters_yaml(params, output_path)
+     io.save_parameters_yaml(params, output_path)

  if __name__ == "__main__":
      cli()
octopi/entry_points/run_train.py CHANGED
@@ -2,12 +2,11 @@ from octopi.datasets import generators, multi_config_generator
  from monai.losses import DiceLoss, FocalLoss, TverskyLoss
  from octopi.models import common as builder
  from monai.metrics import ConfusionMatrixMetric
+ from octopi.utils import parsers, io
  from octopi.entry_points import common
  from octopi.pytorch import trainer
- from octopi import io, utils
  import torch, os, argparse
  from typing import List, Optional, Tuple
- import pprint

  def train_model(
      copick_config_path: str,
@@ -56,7 +55,7 @@ def train_model(


      # Get the data splits
-     ratios = utils.parse_data_split(data_split)
+     ratios = parsers.parse_data_split(data_split)
      data_generator.get_data_splits(
          trainRunIDs = trainRunIDs,
          validateRunIDs = validateRunIDs,
@@ -114,11 +113,11 @@ def train_model_parser(parser_description, add_slurm: bool = False):
      # Input Arguments
      input_group = parser.add_argument_group("Input Arguments")
      common.add_config(input_group, single_config=False)
-     input_group.add_argument("--target-info", type=utils.parse_target, default="targets,octopi,1",
+     input_group.add_argument("--target-info", type=parsers.parse_target, default="targets,octopi,1",
                               help="Target information, e.g., 'name' or 'name,user_id,session_id'. Default is 'targets,octopi,1'.")
      input_group.add_argument("--tomo-alg", default='wbp', help="Tomogram algorithm used for training")
-     input_group.add_argument("--trainRunIDs", type=utils.parse_list, help="List of training run IDs, e.g., run1,run2,run3")
-     input_group.add_argument("--validateRunIDs", type=utils.parse_list, help="List of validation run IDs, e.g., run4,run5,run6")
+     input_group.add_argument("--trainRunIDs", type=parsers.parse_list, help="List of training run IDs, e.g., run1,run2,run3")
+     input_group.add_argument("--validateRunIDs", type=parsers.parse_list, help="List of validation run IDs, e.g., run4,run5,run6")
      input_group.add_argument('--data-split', type=str, default='0.8', help="Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) "
                               "or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")

@@ -153,11 +152,11 @@ def cli():
      args = train_model_parser(parser_description)

      # Parse the CoPick configuration paths
-     if len(args.config) > 1: copick_configs = utils.parse_copick_configs(args.config)
+     if len(args.config) > 1: copick_configs = parsers.parse_copick_configs(args.config)
      else: copick_configs = args.config[0]

      if args.model_config:
-         model_config = utils.load_yaml(args.model_config)
+         model_config = io.load_yaml(args.model_config)
      else:
          model_config = get_model_config(args.channels, args.strides, args.res_units, args.Nclass, args.dim_in)

octopi/extract/localize.py CHANGED
@@ -3,15 +3,15 @@ from scipy.cluster.hierarchy import fcluster, linkage
  from skimage.segmentation import watershed
  from typing import List, Optional, Tuple
  from skimage.measure import regionprops
+ from copick_utils.io import readers
  from scipy.spatial import distance
  from dataclasses import dataclass
- from octopi import io
  import scipy.ndimage as ndi
  from tqdm import tqdm
  import numpy as np
  import gc

- def processs_localization(run,
+ def process_localization(run,
                            objects,
                            seg_info: Tuple[str, str, str],
                            method: str = 'com',
@@ -27,12 +27,12 @@ def processs_localization(run,
          raise ValueError(f"Invalid method '{method}'. Expected 'watershed' or 'com'.")

      # Get Segmentation
-     seg = io.get_segmentation_array(run,
-                                     voxel_size,
-                                     seg_info[0],
-                                     user_id=seg_info[1],
-                                     session_id=seg_info[2],
-                                     raise_error=False)
+     seg = readers.segmentation(
+         run, voxel_size,
+         seg_info[0],
+         user_id=seg_info[1],
+         session_id=seg_info[2],
+         raise_error=False)

      # Preprocess Segmentation
      # seg = preprocess_segmentation(seg, voxel_size, objects)
@@ -99,8 +99,8 @@ def extract_particle_centroids_via_watershed(
          max_particle_size (int): Maximum size threshold for particles.
      """

-     if maxima_filter_size is None or maxima_filter_size < 0:
-         AssertionError('Enter a Non-Zero Filter Size!')
+     if maxima_filter_size is None or maxima_filter_size <= 0:
+         raise ValueError('Enter a Non-Zero Filter Size!')

      # Calculate minimum and maximum particle volumes based on the given radii
      min_particle_size = (4 / 3) * np.pi * (min_particle_radius ** 3)
@@ -117,12 +117,8 @@
      # Structuring element for erosion and dilation
      struct_elem = ball(1)
      eroded = binary_erosion(binary_mask, struct_elem)
-     del binary_mask
-     gc.collect()

      dilated = binary_dilation(eroded, struct_elem)
-     del eroded
-     gc.collect()

      # Distance transform and local maxima detection
      distance = ndi.distance_transform_edt(dilated)
@@ -131,12 +127,11 @@
      # Watershed segmentation
      markers, _ = ndi.label(local_max)
      del local_max
-     markers = markers.astype(np.uint8)
      gc.collect()

      watershed_labels = watershed(-distance, markers, mask=dilated)
+     distance, markers, dilated = None, None, None
      del distance, markers, dilated
-     watershed_labels = watershed_labels.astype(np.uint8)
      gc.collect()

      # Extract region properties and filter based on particle size
@@ -147,9 +142,6 @@
              # Option 1: Use all centroids
              all_centroids.append(region.centroid)

-     del watershed_labels
-     gc.collect()
-
      return all_centroids

  def extract_particle_centroids_via_com(