octopi 1.0-py3-none-any.whl → 1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of octopi might be problematic.

Files changed (48)
  1. octopi/__init__.py +1 -0
  2. octopi/datasets/cached_datset.py +1 -1
  3. octopi/datasets/generators.py +1 -1
  4. octopi/datasets/io.py +200 -0
  5. octopi/datasets/multi_config_generator.py +1 -1
  6. octopi/entry_points/common.py +9 -9
  7. octopi/entry_points/create_slurm_submission.py +16 -8
  8. octopi/entry_points/run_create_targets.py +6 -6
  9. octopi/entry_points/run_evaluate.py +4 -3
  10. octopi/entry_points/run_extract_mb_picks.py +22 -45
  11. octopi/entry_points/run_localize.py +37 -54
  12. octopi/entry_points/run_optuna.py +7 -7
  13. octopi/entry_points/run_segment_predict.py +4 -4
  14. octopi/entry_points/run_train.py +7 -8
  15. octopi/extract/localize.py +19 -12
  16. octopi/extract/membranebound_extract.py +11 -10
  17. octopi/extract/midpoint_extract.py +3 -3
  18. octopi/main.py +1 -1
  19. octopi/models/common.py +1 -1
  20. octopi/processing/create_targets_from_picks.py +11 -5
  21. octopi/processing/downsample.py +6 -10
  22. octopi/processing/evaluate.py +24 -11
  23. octopi/processing/importers.py +4 -4
  24. octopi/pytorch/hyper_search.py +2 -3
  25. octopi/pytorch/model_search_submitter.py +15 -15
  26. octopi/pytorch/segmentation.py +147 -192
  27. octopi/pytorch/segmentation_multigpu.py +162 -0
  28. octopi/pytorch/trainer.py +9 -3
  29. octopi/utils/__init__.py +0 -0
  30. octopi/utils/config.py +57 -0
  31. octopi/utils/io.py +128 -0
  32. octopi/{utils.py → utils/parsers.py} +10 -84
  33. octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
  34. octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
  35. octopi/workflows.py +236 -0
  36. octopi-1.2.0.dist-info/METADATA +120 -0
  37. octopi-1.2.0.dist-info/RECORD +62 -0
  38. {octopi-1.0.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
  39. octopi-1.2.0.dist-info/entry_points.txt +3 -0
  40. {octopi-1.0.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
  41. octopi/io.py +0 -457
  42. octopi/processing/my_metrics.py +0 -26
  43. octopi/processing/writers.py +0 -102
  44. octopi-1.0.dist-info/METADATA +0 -209
  45. octopi-1.0.dist-info/RECORD +0 -59
  46. octopi-1.0.dist-info/entry_points.txt +0 -4
  47. /octopi/{losses.py → utils/losses.py} +0 -0
  48. /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
octopi/__init__.py CHANGED
@@ -0,0 +1 @@
+ __version__ = "1.2.0"
octopi/datasets/cached_datset.py CHANGED
@@ -1,7 +1,7 @@
  from typing import List, Tuple, Callable, Optional, Dict, Any
  from monai.transforms import Compose
  from monai.data import CacheDataset
- from octopi import io
+ from octopi.datasets import io
  from tqdm import tqdm
  import os, sys

octopi/datasets/generators.py CHANGED
@@ -1,7 +1,7 @@
  from octopi.datasets import dataset, augment, cached_datset
  from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
  from typing import List, Optional
- from octopi import io
+ from octopi.datasets import io
  import torch, os, random, gc
  import multiprocess as mp

octopi/datasets/io.py ADDED
@@ -0,0 +1,200 @@
+ """
+ Data loading, processing, and dataset operations for the datasets module.
+ """
+
+ from monai.data import DataLoader, CacheDataset, Dataset
+ from monai.transforms import (
+     Compose,
+     NormalizeIntensityd,
+     EnsureChannelFirstd,
+ )
+ from sklearn.model_selection import train_test_split
+ from collections import defaultdict
+ from copick_utils.io import readers
+ import copick, torch, os, random
+ from typing import List
+ from tqdm import tqdm
+
+
+ def load_training_data(root,
+                        runIDs: List[str],
+                        voxel_spacing: float,
+                        tomo_algorithm: str,
+                        segmenation_name: str,
+                        segmentation_session_id: str = None,
+                        segmentation_user_id: str = None,
+                        progress_update: bool = True):
+     """
+     Load training data from CoPick runs.
+     """
+     data_dicts = []
+     # Use tqdm for progress tracking only if progress_update is True
+     iterable = tqdm(runIDs, desc="Loading Training Data") if progress_update else runIDs
+     for runID in iterable:
+         run = root.get_run(str(runID))
+         tomogram = readers.tomogram(run, voxel_spacing, tomo_algorithm)
+         segmentation = readers.segmentation(run,
+                                             voxel_spacing,
+                                             segmenation_name,
+                                             segmentation_session_id,
+                                             segmentation_user_id)
+         data_dicts.append({"image": tomogram, "label": segmentation})
+
+     return data_dicts
+
+
+ def load_predict_data(root,
+                       runIDs: List[str],
+                       voxel_spacing: float,
+                       tomo_algorithm: str):
+     """
+     Load prediction data from CoPick runs.
+     """
+     data_dicts = []
+     for runID in tqdm(runIDs):
+         run = root.get_run(str(runID))
+         tomogram = readers.tomogram(run, voxel_spacing, tomo_algorithm)
+         data_dicts.append({"image": tomogram})
+
+     return data_dicts
+
+
+ def create_predict_dataloader(
+     root,
+     voxel_spacing: float,
+     tomo_algorithm: str,
+     runIDs: str = None,
+ ):
+     """
+     Create a dataloader for prediction data.
+     """
+     # define pre transforms
+     pre_transforms = Compose(
+         [ EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
+           NormalizeIntensityd(keys=["image"]),
+         ])
+
+     # Split trainRunIDs, validateRunIDs, testRunIDs
+     if runIDs is None:
+         runIDs = [run.name for run in root.runs]
+     test_files = load_predict_data(root, runIDs, voxel_spacing, tomo_algorithm)
+
+     bs = min( len(test_files), 4)
+     test_ds = CacheDataset(data=test_files, transform=pre_transforms)
+     test_loader = DataLoader(test_ds,
+                              batch_size=bs,
+                              shuffle=False,
+                              num_workers=4,
+                              pin_memory=torch.cuda.is_available())
+     return test_loader, test_ds
+
+
+ def adjust_to_multiple(value, multiple = 16):
+     """
+     Adjust a value to be a multiple of the specified number.
+     """
+     return int((value // multiple) * multiple)
+
+
+ def get_input_dimensions(dataset, crop_size: int):
+     """
+     Get input dimensions for the dataset.
+     """
+     nx = dataset[0]['image'].shape[1]
+     if crop_size > nx:
+         first_dim = adjust_to_multiple(nx/2)
+         return first_dim, crop_size, crop_size
+     else:
+         return crop_size, crop_size, crop_size
+
+
+ def get_num_classes(copick_config_path: str):
+     """
+     Get the number of classes from a CoPick configuration.
+     """
+     root = copick.from_file(copick_config_path)
+     return len(root.pickable_objects) + 1
+
+
+ def split_multiclass_dataset(runIDs,
+                              train_ratio: float = 0.7,
+                              val_ratio: float = 0.15,
+                              test_ratio: float = 0.15,
+                              return_test_dataset: bool = True,
+                              random_state: int = 42):
+     """
+     Splits a given dataset into three subsets: training, validation, and testing. If the dataset
+     has categories (as tuples), splits are balanced across all categories. If the dataset is a 1D
+     list, it is split without categorization.
+
+     Parameters:
+     - runIDs: A list of items to split. It can be a 1D list or a list of tuples (category, value).
+     - train_ratio: Proportion of the dataset for training.
+     - val_ratio: Proportion of the dataset for validation.
+     - test_ratio: Proportion of the dataset for testing.
+     - return_test_dataset: Whether to return the test dataset.
+     - random_state: Random state for reproducibility.
+
+     Returns:
+     - trainRunIDs: Training subset.
+     - valRunIDs: Validation subset.
+     - testRunIDs: Testing subset (if return_test_dataset is True, otherwise None).
+     """
+
+     # Ensure the ratios add up to 1
+     assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must add up to 1.0"
+
+     # Check if the dataset has categories
+     if isinstance(runIDs[0], tuple) and len(runIDs[0]) == 2:
+         # Group by category
+         grouped = defaultdict(list)
+         for item in runIDs:
+             grouped[item[0]].append(item)
+
+         # Split each category
+         trainRunIDs, valRunIDs, testRunIDs = [], [], []
+         for category, items in grouped.items():
+             # Shuffle for randomness
+             random.shuffle(items)
+             # Split into train and remaining
+             train_items, remaining = train_test_split(items, test_size=(1 - train_ratio), random_state=random_state)
+             trainRunIDs.extend(train_items)
+
+             if return_test_dataset:
+                 # Split remaining into validation and test
+                 val_items, test_items = train_test_split(
+                     remaining,
+                     test_size=(test_ratio / (val_ratio + test_ratio)),
+                     random_state=random_state,
+                 )
+                 valRunIDs.extend(val_items)
+                 testRunIDs.extend(test_items)
+             else:
+                 valRunIDs.extend(remaining)
+                 testRunIDs = []
+     else:
+         # If no categories, split as a 1D list
+         trainRunIDs, remaining = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)
+         if return_test_dataset:
+             valRunIDs, testRunIDs = train_test_split(
+                 remaining,
+                 test_size=(test_ratio / (val_ratio + test_ratio)),
+                 random_state=random_state,
+             )
+         else:
+             valRunIDs = remaining
+             testRunIDs = []
+
+     return trainRunIDs, valRunIDs, testRunIDs
+
+
+ def load_copick_config(path: str):
+     """
+     Load a CoPick configuration from file.
+     """
+     if os.path.isfile(path):
+         root = copick.from_file(path)
+     else:
+         raise FileNotFoundError(f"Copick Config Path does not exist: {path}")
+
+     return root
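
The new octopi/datasets/io.py gathers the CoPick loading, run-splitting, and MONAI dataloader helpers that other modules now import via "from octopi.datasets import io". A minimal usage sketch, not taken from the package documentation; the config path and segmentation name below are placeholders:

from octopi.datasets import io

# Load the CoPick project and collect run names
root = io.load_copick_config("copick_config.json")        # placeholder path
run_ids = [run.name for run in root.runs]

# 70/15/15 split into train/validation/test subsets (the defaults shown above)
train_ids, val_ids, test_ids = io.split_multiclass_dataset(run_ids)

# Image/label dictionaries for training; note the parameter is spelled
# "segmenation_name" in the source, and "targets" is a placeholder name
train_dicts = io.load_training_data(
    root, train_ids,
    voxel_spacing=10, tomo_algorithm="wbp",
    segmenation_name="targets",
)

# Batched CacheDataset/DataLoader for inference over all runs
predict_loader, predict_ds = io.create_predict_dataloader(
    root, voxel_spacing=10, tomo_algorithm="wbp",
)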
octopi/datasets/multi_config_generator.py CHANGED
@@ -1,7 +1,7 @@
  from octopi.datasets import dataset, augment, cached_datset
  from octopi.datasets.generators import TrainLoaderManager
  from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
- from octopi import io
+ from octopi.datasets import io
  import multiprocess as mp
  from typing import List
  from tqdm import tqdm
octopi/entry_points/common.py CHANGED
@@ -1,4 +1,4 @@
- from octopi import utils
+ from octopi.utils import parsers
  import argparse

  def add_model_parameters(parser, octopi = False):
@@ -8,9 +8,9 @@ def add_model_parameters(parser, octopi = False):

      # Add U-Net model parameters
      parser.add_argument("--Nclass", type=int, required=False, default=3, help="Number of prediction classes in the model")
-     parser.add_argument("--channels", type=utils.parse_int_list, required=False, default='32,64,128,128', help="List of channel sizes")
-     parser.add_argument("--strides", type=utils.parse_int_list, required=False, default='2,2,1', help="List of stride sizes")
-     parser.add_argument("--res-units", type=int, required=False, default=2, help="Number of residual units in the UNet")
+     parser.add_argument("--channels", type=parsers.parse_int_list, required=False, default='32,64,96,96', help="List of channel sizes")
+     parser.add_argument("--strides", type=parsers.parse_int_list, required=False, default='2,2,1', help="List of stride sizes")
+     parser.add_argument("--res-units", type=int, required=False, default=1, help="Number of residual units in the UNet")
      parser.add_argument("--dim-in", type=int, required=False, default=96, help="Input dimension for the UNet model")

  def inference_model_parameters(parser):
@@ -24,7 +24,7 @@ def add_train_parameters(parser, octopi = False):
      """
      Add training parameters to the parser.
      """
-     parser.add_argument("--num-epochs", type=int, required=False, default=100, help="Number of training epochs")
+     parser.add_argument("--num-epochs", type=int, required=False, default=1000, help="Number of training epochs")
      parser.add_argument("--val-interval", type=int, required=False, default=10, help="Interval for validation metric calculations")
      parser.add_argument("--tomo-batch-size", type=int, required=False, default=15, help="Number of tomograms to load per epoch for training")
      parser.add_argument("--best-metric", type=str, default='avg_f1', required=False, help="Metric to Monitor for Determining Best Model. To track fBetaN, use fBetaN with N as the beta-value.")
@@ -32,8 +32,8 @@ def add_train_parameters(parser, octopi = False):
      if not octopi:
          parser.add_argument("--num-tomo-crops", type=int, required=False, default=16, help="Number of tomogram crops to use per patch")
          parser.add_argument("--lr", type=float, required=False, default=1e-3, help="Learning rate for the optimizer")
-         parser.add_argument("--tversky-alpha", type=float, required=False, default=0.5, help="Alpha parameter for the Tversky loss")
-         parser.add_argument("--model-save-path", required=True, help="Path to model save directory")
+         parser.add_argument("--tversky-alpha", type=float, required=False, default=0.3, help="Alpha parameter for the Tversky loss")
+         parser.add_argument("--model-save-path", required=False, default='results', help="Path to model save directory")
      else:
          parser.add_argument("--num-trials", type=int, default=10, required=False, help="Number of trials for architecture search (default: 10).")

@@ -52,11 +52,11 @@ def add_inference_parameters(parser):

      parser.add_argument("--tomo-alg", required=False, default = 'wbp',
                          help="Tomogram algorithm used for produces segmentation prediction masks.")
-     parser.add_argument("--seg-info", type=utils.parse_target, required=False,
+     parser.add_argument("--seg-info", type=parsers.parse_target, required=False,
                          default='predict,octopi,1', help='Information Query to save Segmentation predictions under, e.g., (e.g., "name" or "name,user_id,session_id" - Default UserID is octopi and SessionID is 1')
      parser.add_argument("--tomo-batch-size", type=int, default=25, required=False,
                          help="Batch size for tomogram processing.")
-     parser.add_argument("--run-ids", type=utils.parse_list, default=None, required=False,
+     parser.add_argument("--run-ids", type=parsers.parse_list, default=None, required=False,
                          help="List of run IDs for prediction, e.g., run1,run2 or [run1,run2]. If not provided, all available runs will be processed.")

  def add_localize_parameters(parser):
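
These entry points now pull their argparse type converters from octopi/utils/parsers.py, which is not included in this diff. A hypothetical sketch of what the three helpers do, inferred only from the help strings and defaults above (the real implementations may differ):

def parse_int_list(value):
    # "32,64,96,96" or "[32,64,96,96]" -> [32, 64, 96, 96]
    return [int(v) for v in value.strip("[]").split(",") if v.strip()]

def parse_list(value):
    # "run1,run2" or "[run1,run2]" -> ["run1", "run2"]
    return [v.strip() for v in value.strip("[]").split(",") if v.strip()]

def parse_target(value):
    # "name" or "name,user_id,session_id" -> (name, user_id, session_id)
    parts = [v.strip() for v in value.split(",")]
    name = parts[0]
    user_id = parts[1] if len(parts) > 1 else None
    session_id = parts[2] if len(parts) > 2 else None
    return name, user_id, session_id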
octopi/entry_points/create_slurm_submission.py CHANGED
@@ -1,5 +1,5 @@
  from octopi.entry_points import run_train, run_segment_predict, run_localize, run_optuna
- from octopi.submit_slurm import create_shellsubmit, create_multiconfig_shellsubmit
+ from octopi.utils.submit_slurm import create_shellsubmit, create_multiconfig_shellsubmit
  from octopi.processing.importers import cli_mrcs_parser, cli_dataportal_parser
  from octopi.entry_points import common
  from octopi import utils
@@ -16,19 +16,27 @@ def create_train_script(args):

      command = f"""
      octopi train \\
+     {strconfigs} \\
      --model-save-path {args.model_save_path} \\
-     --target-info {args.target_info} \\
-     --voxel-size {args.voxel_size} --tomo-algorithm {args.tomo_algorithm} --Nclass {args.Nclass} \\
+     --target-info {','.join(args.target_info)} \\
+     --voxel-size {args.voxel_size} --tomo-alg {args.tomo_alg} --Nclass {args.Nclass} \\
      --tomo-batch-size {args.tomo_batch_size} --num-tomo-crops {args.num_tomo_crops} \\
-     {strconfigs}
-     """
+     --best-metric {args.best_metric} --num-epochs {args.num_epochs} --val-interval {args.val_interval} \\
+     """

      # If a model config is provided, use it to build the model
      if args.model_config is not None:
          command += f" --model-config {args.model_config}"
      else:
-         command += f" --tversky-alpha {args.tversky_alpha} --channels {args.channels} --strides {args.strides} --dim-in {args.dim_in} --res-units {args.res_units}"
+         channels = ",".join(map(str, args.channels))
+         strides = ",".join(map(str, args.strides))
+         command += (
+             f" --tversky-alpha {args.tversky_alpha}"
+             f" --channels {channels}"
+             f" --strides {strides}"
+             f" --dim-in {args.dim_in}"
+             f" --res-units {args.res_units}"
+         )

      # If Model Weights are provided, use them to initialize the model
      if args.model_weights is not None and args.model_config is not None:
@@ -240,4 +248,4 @@ def download_dataportal_slurm():
      """
      parser_description = "Create a SLURM script for downloading tomograms from the Dataportal"
      args = cli_dataportal_parser(parser_description, add_slurm=True)
-     create_download_dataportal_script(args)
+     create_download_dataportal_script(args)
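
Because --channels and --strides are now parsed into Python lists, the SLURM script builder has to serialize them back into the comma-separated form the octopi train CLI expects. A standalone illustration of that round trip (values are examples only):

channels = [32, 64, 96, 96]   # as produced by a parse_int_list-style converter
strides = [2, 2, 1]
command = (
    f" --channels {','.join(map(str, channels))}"
    f" --strides {','.join(map(str, strides))}"
)
print(command)   # " --channels 32,64,96,96 --strides 2,2,1"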
octopi/entry_points/run_create_targets.py CHANGED
@@ -1,8 +1,8 @@
  import octopi.processing.create_targets_from_picks as create_targets
  from typing import List, Tuple, Union
+ from octopi.utils import io, parsers
  from collections import defaultdict
  import argparse, copick, yaml, os
- from octopi import utils, io
  from tqdm import tqdm
  import numpy as np

@@ -160,16 +160,16 @@ def parse_args():

      input_group = parser.add_argument_group("Input Arguments")
      input_group.add_argument("--config", type=str, required=True, help="Path to the CoPick configuration file.")
-     input_group.add_argument("--target", type=utils.parse_target, action="append", default=None, help='Target specifications: "name" or "name,user_id,session_id".')
+     input_group.add_argument("--target", type=parsers.parse_target, action="append", default=None, help='Target specifications: "name" or "name,user_id,session_id".')
      input_group.add_argument("--picks-session-id", type=str, default=None, help="Session ID for the picks.")
      input_group.add_argument("--picks-user-id", type=str, default=None, help="User ID associated with the picks.")
-     input_group.add_argument("--seg-target", type=utils.parse_target, action="append", default=[], help='Segmentation targets: "name" or "name,user_id,session_id".')
-     input_group.add_argument("--run-ids", type=utils.parse_list, default=None, help="List of run IDs.")
+     input_group.add_argument("--seg-target", type=parsers.parse_target, action="append", default=[], help='Segmentation targets: "name" or "name,user_id,session_id".')
+     input_group.add_argument("--run-ids", type=parsers.parse_list, default=None, help="List of run IDs.")

      # Parameters
      parameters_group = parser.add_argument_group("Parameters")
      parameters_group.add_argument("--tomo-alg", type=str, default="wbp", help="Tomogram reconstruction algorithm.")
-     parameters_group.add_argument("--radius-scale", type=float, default=0.8, help="Scale factor for object radius.")
+     parameters_group.add_argument("--radius-scale", type=float, default=0.7, help="Scale factor for object radius.")
      parameters_group.add_argument("--voxel-size", type=float, default=10, help="Voxel size for tomogram reconstruction.")

      output_group = parser.add_argument_group("Output Arguments")
@@ -275,7 +275,7 @@ def save_parameters(args, output_path: str):
          existing_data[input_key] = new_entry[input_key]

      # Save back to the YAML file
-     utils.save_parameters_yaml(existing_data, output_path)
+     io.save_parameters_yaml(existing_data, output_path)

  if __name__ == "__main__":
      cli()
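
save_parameters_yaml now comes from octopi.utils.io (the new octopi/utils/io.py, which this diff lists but does not show). A hypothetical stand-in for what such a helper typically does, for orientation only:

import os, yaml

def save_parameters_yaml(params: dict, output_path: str):
    # Dump a parameter dictionary to YAML, creating the parent folder if needed
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(output_path, "w") as f:
        yaml.safe_dump(params, f, default_flow_style=False, sort_keys=False)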
octopi/entry_points/run_evaluate.py CHANGED
@@ -1,5 +1,5 @@
  import octopi.processing.evaluate as evaluate
- import octopi.utils as utils
+ from octopi.utils import parsers
  from typing import List
  import argparse

@@ -31,6 +31,7 @@ def cli():
      """
      CLI entry point for running evaluation.
      """
+
      parser = argparse.ArgumentParser(
          description='Run evaluation on pick and place predictions.',
          formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -43,8 +44,8 @@ def cli():
      parser.add_argument('--predict-session-id', type=str, required=False, default= None, help='Session ID for prediction data')
      parser.add_argument('--save-path', type=str, required=False, default= None, help='Path to save evaluation results')
      parser.add_argument('--distance-threshold-scale', type=float, required=False, default = 0.8, help='Compute Distance Threshold Based on Particle Radius')
-     parser.add_argument('--object-names', type=utils.parse_list, default=None, required=False, help='Optional list of object names to evaluate, e.g., ribosome,apoferritin or [ribosome,apoferritin].')
-     parser.add_argument('--run-ids', type=utils.parse_list, default=None, required=False, help='Optional list of run IDs to evaluate, e.g., run1,run2,run3 or [run1,run2,run3].')
+     parser.add_argument('--object-names', type=parsers.parse_list, default=None, required=False, help='Optional list of object names to evaluate, e.g., ribosome,apoferritin or [ribosome,apoferritin].')
+     parser.add_argument('--run-ids', type=parsers.parse_list, default=None, required=False, help='Optional list of run IDs to evaluate, e.g., run1,run2,run3 or [run1,run2,run3].')

      args = parser.parse_args()

octopi/entry_points/run_extract_mb_picks.py CHANGED
@@ -1,5 +1,5 @@
  from octopi.extract import membranebound_extract as extract
- from octopi import utils, io
+ from octopi.utils import parsers
  import argparse, json, pprint, copick, json
  from typing import List, Tuple, Optional
  import multiprocess as mp
@@ -30,46 +30,23 @@ def extract_membrane_bound_picks(
      if n_procs is None:
          n_procs = min(mp.cpu_count(), n_run_ids)
      print(f"Using {n_procs} processes to parallelize across {n_run_ids} run IDs.")
-
-     # Initialize tqdm progress bar
-     with tqdm(total=n_run_ids, desc="Membrane-Protein Isolation", unit="run") as pbar:
-         for _iz in range(0, n_run_ids, n_procs):
-
-             start_idx = _iz
-             end_idx = min(_iz + n_procs, n_run_ids) # Ensure end_idx does not exceed n_run_ids
-             print(f"\nProcessing runIDs from {start_idx} -> {end_idx } (out of {n_run_ids})")
-
-             processes = []
-             for _in in range(n_procs):
-                 _iz_this = _iz + _in
-                 if _iz_this >= n_run_ids:
-                     break
-                 run_id = run_ids[_iz_this]
-                 run = root.get_run(run_id)
-                 p = mp.Process(
-                     target=extract.process_membrane_bound_extract,
-                     args=(run,
-                           voxel_size,
-                           picks_info,
-                           membrane_info,
-                           organelle_info,
-                           save_user_id,
-                           save_session_id,
-                           distance_threshold),
-                 )
-                 processes.append(p)
-
-             for p in processes:
-                 p.start()
-
-             for p in processes:
-                 p.join()
-
-             for p in processes:
-                 p.close()
-
-             # Update tqdm progress bar
-             pbar.update(len(processes))
+
+     # Run Membrane-Protein Isolation - Main Parallelization Loop
+     with mp.Pool(processes=n_procs) as pool:
+         with tqdm(total=n_run_ids, desc="Membrane-Protein Isolation", unit="run") as pbar:
+             worker_func = lambda run_id: extract.process_membrane_bound_extract(
+                 root.get_run(run_id),
+                 voxel_size,
+                 picks_info,
+                 membrane_info,
+                 organelle_info,
+                 save_user_id,
+                 save_session_id,
+                 distance_threshold
+             )
+
+             for _ in pool.imap_unordered(worker_func, run_ids, chunksize=1):
+                 pbar.update(1)

      print('Extraction of Membrane-Bound Proteins Complete!')

@@ -81,12 +58,12 @@ def cli():
      parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')
      parser.add_argument('--voxel-size', type=float, required=False, default=10, help='Voxel size.')
      parser.add_argument('--distance-threshold', type=float, required=False, default=10, help='Distance threshold.')
-     parser.add_argument('--picks-info', type=utils.parse_target, required=True, help='Query for the picks (e.g., "name" or "name,user_id,session_id".).')
-     parser.add_argument('--membrane-info', type=utils.parse_target, required=False, help='Query for the membrane segmentation (e.g., "name" or "name,user_id,session_id".).')
-     parser.add_argument('--organelle-info', type=utils.parse_target, required=False, help='Query for the organelles segmentations (e.g., "name" or "name,user_id,session_id".).')
+     parser.add_argument('--picks-info', type=parsers.parse_target, required=True, help='Query for the picks (e.g., "name" or "name,user_id,session_id".).')
+     parser.add_argument('--membrane-info', type=parsers.parse_target, required=False, help='Query for the membrane segmentation (e.g., "name" or "name,user_id,session_id".).')
+     parser.add_argument('--organelle-info', type=parsers.parse_target, required=False, help='Query for the organelles segmentations (e.g., "name" or "name,user_id,session_id".).')
      parser.add_argument('--save-user-id', type=str, required=False, default=None, help='User ID to save the new picks.')
      parser.add_argument('--save-session-id', type=str, required=True, help='Session ID to save the new picks.')
-     parser.add_argument('--runIDs', type=utils.parse_list, required=False, help='List of run IDs to process.')
+     parser.add_argument('--runIDs', type=parsers.parse_list, required=False, help='List of run IDs to process.')
      parser.add_argument('--n-procs', type=int, required=False, default=None, help='Number of processes to use.')

      args = parser.parse_args()
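
Both this entry point and run_localize.py below replace the hand-rolled mp.Process batching with a multiprocess.Pool driven by imap_unordered, so the progress bar advances as each run finishes instead of once per batch. A minimal standalone sketch of the pattern, with process_one standing in for the per-run worker:

import multiprocess as mp   # dill-based fork of multiprocessing, so lambdas/closures can be pickled
from tqdm import tqdm

def process_one(run_id):
    # stand-in for the per-run worker (e.g. an extraction or localization call)
    return run_id

if __name__ == "__main__":
    run_ids = [f"run_{i}" for i in range(20)]
    n_procs = min(mp.cpu_count(), 8, len(run_ids))

    with mp.Pool(processes=n_procs) as pool:
        with tqdm(total=len(run_ids), desc="Processing", unit="run") as pbar:
            for _ in pool.imap_unordered(process_one, run_ids, chunksize=1):
                pbar.update(1)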
octopi/entry_points/run_localize.py CHANGED
@@ -1,6 +1,6 @@
  from octopi.entry_points import common
+ from octopi.utils import parsers, io
  from octopi.extract import localize
- from octopi import utils
  import copick, argparse, pprint
  from typing import List, Tuple
  import multiprocess as mp
@@ -40,56 +40,39 @@ def pick_particles(
      print(', '.join([f'{obj[0]} (Label: {obj[1]})' for obj in objects]) + '\n')

      # Either Specify Input RunIDs or Run on All RunIDs
-     if runIDs: print('Running Localization on the Following RunIDs: ' + ', '.join(runIDs) + '\n')
-     run_ids = runIDs if runIDs else [run.name for run in root.runs]
+     if runIDs:
+         print('Running Localization on the Following RunIDs: ' + ', '.join(runIDs) + '\n')
+         run_ids = runIDs
+     else:
+         run_ids = [run.name for run in root.runs if run.get_voxel_spacing(voxel_size) is not None]
+         skipped_run_ids = [run.name for run in root.runs if run.get_voxel_spacing(voxel_size) is None]
+
+         if skipped_run_ids:
+             print(f"Warning: skipping runs with no voxel spacing {voxel_size}: {skipped_run_ids}")
+
+     # Nprocesses shouldnt exceed computation resource or number of available runs
      n_run_ids = len(run_ids)
+     n_procs = min(mp.cpu_count(), n_procs, n_run_ids)

-     # Determine the number of processes to use
-     if n_procs is None:
-         n_procs = min(int(mp.cpu_count()//4), n_run_ids)
+     # Run Localization - Main Parallelization Loop
      print(f"Using {n_procs} processes to parallelize across {n_run_ids} run IDs.")
-
-     # Initialize tqdm progress bar
-     with tqdm(total=n_run_ids, desc="Localization", unit="run") as pbar:
-         for _iz in range(0, n_run_ids, n_procs):
-
-             start_idx = _iz
-             end_idx = min(_iz + n_procs, n_run_ids) # Ensure end_idx does not exceed n_run_ids
-             print(f"\nProcessing runIDs from {start_idx} -> {end_idx } (out of {n_run_ids})")
-
-             processes = []
-             for _in in range(n_procs):
-                 _iz_this = _iz + _in
-                 if _iz_this >= n_run_ids:
-                     break
-                 run_id = run_ids[_iz_this]
-                 run = root.get_run(run_id)
-                 p = mp.Process(
-                     target=localize.processs_localization,
-                     args=(run,
-                           objects,
-                           seg_info,
-                           method,
-                           voxel_size,
-                           filter_size,
-                           radius_min_scale,
-                           radius_max_scale,
-                           pick_session_id,
-                           pick_user_id),
-                 )
-                 processes.append(p)
-
-             for p in processes:
-                 p.start()
-
-             for p in processes:
-                 p.join()
-
-             for p in processes:
-                 p.close()
-
-             # Update tqdm progress bar
-             pbar.update(len(processes))
+     with mp.Pool(processes=n_procs) as pool:
+         with tqdm(total=n_run_ids, desc="Localization", unit="run") as pbar:
+             worker_func = lambda run_id: localize.process_localization(
+                 root.get_run(run_id),
+                 objects,
+                 seg_info,
+                 method,
+                 voxel_size,
+                 filter_size,
+                 radius_min_scale,
+                 radius_max_scale,
+                 pick_session_id,
+                 pick_user_id
+             )
+
+             for _ in pool.imap_unordered(worker_func, run_ids, chunksize=1):
+                 pbar.update(1)

      print('Localization Complete!')

@@ -101,20 +84,20 @@ def localize_parser(parser_description, add_slurm: bool = False):
      input_group = parser.add_argument_group("Input Arguments")
      input_group.add_argument("--config", type=str, required=True, help="Path to the CoPick configuration file.")
      input_group.add_argument("--method", type=str, choices=['watershed', 'com'], default='watershed', required=False, help="Localization method to use.")
-     input_group.add_argument('--seg-info', type=utils.parse_target, required=True, help='Query for the organelles segmentations (e.g., "name" or "name,user_id,session_id".).')
+     input_group.add_argument('--seg-info', type=parsers.parse_target, required=False, default='predict,octopi,1', help='Query for the organelles segmentations (e.g., "name" or "name,user_id,session_id".).')
      input_group.add_argument("--voxel-size", type=float, default=10, required=False, help="Voxel size for localization.")
-     input_group.add_argument("--runIDs", type=utils.parse_list, default = None, required=False, help="List of runIDs to run inference on, e.g., run1,run2,run3 or [run1,run2,run3].")
+     input_group.add_argument("--runIDs", type=parsers.parse_list, default = None, required=False, help="List of runIDs to run inference on, e.g., run1,run2,run3 or [run1,run2,run3].")

      localize_group = parser.add_argument_group("Localize Arguments")
      localize_group.add_argument("--radius-min-scale", type=float, default=0.5, required=False, help="Minimum radius scale for particles.")
      localize_group.add_argument("--radius-max-scale", type=float, default=1.0, required=False, help="Maximum radius scale for particles.")
      localize_group.add_argument("--filter-size", type=int, default=10, required=False, help="Filter size for localization.")
-     localize_group.add_argument("--pick-objects", type=utils.parse_list, default=None, required=False, help="Specific Objects to Find Picks for.")
-     localize_group.add_argument("--n-procs", type=int, default=None, required=False, help="Number of CPU processes to parallelize runs across. Defaults to the max number of cores available or available runs.")
+     localize_group.add_argument("--pick-objects", type=parsers.parse_list, default=None, required=False, help="Specific Objects to Find Picks for.")
+     localize_group.add_argument("--n-procs", type=int, default=8, required=False, help="Number of CPU processes to parallelize runs across. Defaults to the max number of cores available or available runs.")

      output_group = parser.add_argument_group("Output Arguments")
      output_group.add_argument("--pick-session-id", type=str, default='1', required=False, help="Session ID for the particle picks.")
-     output_group.add_argument("--pick-user-id", type=str, default='monai', required=False, help="User ID for the particle picks.")
+     output_group.add_argument("--pick-user-id", type=str, default='octopi', required=False, help="User ID for the particle picks.")

      if add_slurm:
          slurm_group = parser.add_argument_group("SLURM Arguments")
@@ -181,7 +164,7 @@ def save_parameters(args: argparse.Namespace,
      pprint.pprint(params); print()

      # Save to YAML file
-     utils.save_parameters_yaml(params, output_path)
+     io.save_parameters_yaml(params, output_path)

  if __name__ == "__main__":
      cli()