octopi 1.1-py3-none-any.whl → 1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of octopi might be problematic.

Files changed (45)
  1. octopi/__init__.py +1 -0
  2. octopi/datasets/cached_datset.py +1 -1
  3. octopi/datasets/generators.py +1 -1
  4. octopi/datasets/io.py +200 -0
  5. octopi/datasets/multi_config_generator.py +1 -1
  6. octopi/entry_points/common.py +5 -5
  7. octopi/entry_points/create_slurm_submission.py +1 -1
  8. octopi/entry_points/run_create_targets.py +6 -6
  9. octopi/entry_points/run_evaluate.py +4 -3
  10. octopi/entry_points/run_extract_mb_picks.py +5 -5
  11. octopi/entry_points/run_localize.py +8 -9
  12. octopi/entry_points/run_optuna.py +7 -7
  13. octopi/entry_points/run_segment_predict.py +4 -4
  14. octopi/entry_points/run_train.py +7 -8
  15. octopi/extract/localize.py +11 -19
  16. octopi/extract/membranebound_extract.py +11 -10
  17. octopi/extract/midpoint_extract.py +3 -3
  18. octopi/models/common.py +1 -1
  19. octopi/processing/create_targets_from_picks.py +3 -4
  20. octopi/processing/evaluate.py +24 -11
  21. octopi/processing/importers.py +4 -4
  22. octopi/pytorch/hyper_search.py +2 -3
  23. octopi/pytorch/model_search_submitter.py +4 -4
  24. octopi/pytorch/segmentation.py +141 -190
  25. octopi/pytorch/segmentation_multigpu.py +162 -0
  26. octopi/pytorch/trainer.py +2 -2
  27. octopi/utils/__init__.py +0 -0
  28. octopi/utils/config.py +57 -0
  29. octopi/utils/io.py +128 -0
  30. octopi/{utils.py → utils/parsers.py} +10 -84
  31. octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
  32. octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
  33. octopi/workflows.py +236 -0
  34. {octopi-1.1.dist-info → octopi-1.2.0.dist-info}/METADATA +41 -29
  35. octopi-1.2.0.dist-info/RECORD +62 -0
  36. {octopi-1.1.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
  37. octopi-1.2.0.dist-info/entry_points.txt +3 -0
  38. {octopi-1.1.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
  39. octopi/io.py +0 -457
  40. octopi/processing/my_metrics.py +0 -26
  41. octopi/processing/writers.py +0 -102
  42. octopi-1.1.dist-info/RECORD +0 -59
  43. octopi-1.1.dist-info/entry_points.txt +0 -4
  44. /octopi/{losses.py → utils/losses.py} +0 -0
  45. /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
octopi/utils/io.py ADDED
@@ -0,0 +1,128 @@
+ """
+ File I/O utilities for YAML and JSON operations.
+ """
+
+ import os, json, yaml
+
+
+ # Create a custom dumper that uses flow style for lists only.
+ class InlineListDumper(yaml.SafeDumper):
+     def represent_list(self, data):
+         node = super().represent_list(data)
+         node.flow_style = True # Use inline style for lists
+         return node
+
+
+ def save_parameters_yaml(params: dict, output_path: str):
+     """
+     Save parameters to a YAML file.
+     """
+     InlineListDumper.add_representer(list, InlineListDumper.represent_list)
+     with open(output_path, 'w') as f:
+         yaml.dump(params, f, Dumper=InlineListDumper, default_flow_style=False, sort_keys=False)
+
+
+ def load_yaml(path: str) -> dict:
+     """
+     Load a YAML file and return the contents as a dictionary.
+     """
+     if os.path.exists(path):
+         with open(path, 'r') as f:
+             return yaml.safe_load(f)
+     else:
+         raise FileNotFoundError(f"File not found: {path}")
+
+
+ def save_results_to_json(results, filename: str):
+     """
+     Save training results to a JSON file.
+     """
+     results = prepare_inline_results_json(results)
+     with open(os.path.join(filename), "w") as json_file:
+         json.dump( results, json_file, indent=4 )
+     print(f"Training Results saved to {filename}")
+
+
+ def prepare_inline_results_json(results):
+     """
+     Prepare results for inline JSON formatting.
+     """
+     # Traverse the dictionary and format lists of lists as inline JSON
+     for key, value in results.items():
+         # Check if the value is a list of lists (like [[epoch, value], ...])
+         if isinstance(value, list) and all(isinstance(item, list) and len(item) == 2 for item in value):
+             # Format the list of lists as a single-line JSON string
+             results[key] = json.dumps(value)
+     return results
+
+ def get_optimizer_parameters(trainer):
+     """
+     Extract optimizer parameters from a trainer object.
+     """
+     optimizer_parameters = {
+         'my_num_samples': trainer.num_samples,
+         'val_interval': trainer.val_interval,
+         'lr': trainer.optimizer.param_groups[0]['lr'],
+         'optimizer': trainer.optimizer.__class__.__name__,
+         'metrics_function': trainer.metrics_function.__class__.__name__,
+         'loss_function': trainer.loss_function.__class__.__name__,
+     }
+
+     # Log Tversky Loss Parameters
+     if trainer.loss_function.__class__.__name__ == 'TverskyLoss':
+         optimizer_parameters['alpha'] = trainer.loss_function.alpha
+     elif trainer.loss_function.__class__.__name__ == 'FocalLoss':
+         optimizer_parameters['gamma'] = trainer.loss_function.gamma
+     elif trainer.loss_function.__class__.__name__ == 'WeightedFocalTverskyLoss':
+         optimizer_parameters['alpha'] = trainer.loss_function.alpha
+         optimizer_parameters['gamma'] = trainer.loss_function.gamma
+         optimizer_parameters['weight_tversky'] = trainer.loss_function.weight_tversky
+     elif trainer.loss_function.__class__.__name__ == 'FocalTverskyLoss':
+         optimizer_parameters['alpha'] = trainer.loss_function.alpha
+         optimizer_parameters['gamma'] = trainer.loss_function.gamma
+
+     return optimizer_parameters
+
+
+ def save_parameters_to_yaml(model, trainer, dataloader, filename: str):
+     """
+     Save training parameters to a YAML file.
+     """
+
+     parameters = {
+         'model': model.get_model_parameters(),
+         'optimizer': get_optimizer_parameters(trainer),
+         'dataloader': dataloader.get_dataloader_parameters()
+     }
+
+     save_parameters_yaml(parameters, filename)
+     print(f"Training Parameters saved to {filename}")
+
+ def flatten_params(params, parent_key=''):
+     """
+     Helper function to flatten and serialize nested parameters.
+     """
+     flattened = {}
+     for key, value in params.items():
+         new_key = f"{parent_key}.{key}" if parent_key else key
+         if isinstance(value, dict):
+             flattened.update(flatten_params(value, new_key))
+         elif isinstance(value, list):
+             flattened[new_key] = ', '.join(map(str, value)) # Convert list to a comma-separated string
+         else:
+             flattened[new_key] = value
+     return flattened
+
+
+ def prepare_for_inline_json(data):
+     """
+     Manually join specific lists into strings for inline display.
+     """
+     for key in ["trainRunIDs", "valRunIDs", "testRunIDs"]:
+         if key in data['dataloader']:
+             data['dataloader'][key] = f"[{', '.join(map(repr, data['dataloader'][key]))}]"
+
+     for key in ['channels', 'strides']:
+         if key in data['model']:
+             data['model'][key] = f"[{', '.join(map(repr, data['model'][key]))}]"
+     return data
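For orientation, a minimal sketch of how these new helpers might be combined. It assumes only the functions shown in the hunk above; the parameter values are purely illustrative.

```python
# Hypothetical round-trip through the helpers added in octopi/utils/io.py.
from octopi.utils import io

params = {
    "model": {"architecture": "Unet", "channels": [48, 64, 80, 80]},  # illustrative values
    "optimizer": {"lr": 1e-3},
}

# Lists are emitted in YAML flow style by InlineListDumper, e.g. "channels: [48, 64, 80, 80]".
io.save_parameters_yaml(params, "model_config.yaml")

# load_yaml raises FileNotFoundError if the path does not exist.
loaded = io.load_yaml("model_config.yaml")
assert loaded["model"]["channels"] == [48, 64, 80, 80]

# Lists of [epoch, value] pairs are collapsed to single-line JSON strings before writing.
io.save_results_to_json({"val_loss": [[1, 0.9], [2, 0.7]]}, "results.json")
```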
octopi/{utils.py → utils/parsers.py} RENAMED
@@ -1,58 +1,12 @@
- from monai.networks.nets import UNet, AttentionUnet
+ """
+ Argument parsing and configuration utilities.
+ """
+
+ import argparse, os, random
+ import torch, numpy as np
  from typing import List, Tuple, Union
  from dotenv import load_dotenv
- import argparse, octopi
- import torch, random, os, yaml
- from typing import List
- import numpy as np
-
- ##############################################################################################################################
-
- def mlflow_setup():
-
-     module_root = os.path.dirname(octopi.__file__)
-     dotenv_path = module_root.replace('src/octopi','') + '.env'
-     load_dotenv(dotenv_path=dotenv_path)
-
-     # MLflow setup
-     username = os.getenv('MLFLOW_TRACKING_USERNAME')
-     password = os.getenv('MLFLOW_TRACKING_PASSWORD')
-     if not password or not username:
-         print("Password not found in environment, loading from .env file...")
-         load_dotenv() # Loads environment variables from a .env file
-         username = os.getenv('MLFLOW_TRACKING_USERNAME')
-         password = os.getenv('MLFLOW_TRACKING_PASSWORD')
-
-     # Check again after loading .env file
-     if not password:
-         raise ValueError("Password is not set in environment variables or .env file!")
-     else:
-         print("Password loaded successfully")
-     os.environ['MLFLOW_TRACKING_USERNAME'] = username
-     os.environ['MLFLOW_TRACKING_PASSWORD'] = password
-
-     return os.getenv('MLFLOW_TRACKING_URI')
-
- ##############################################################################################################################
-
- def set_seed(seed):
-     # Set the seed for Python's random module
-     random.seed(seed)
-
-     # Set the seed for NumPy
-     np.random.seed(seed)
-
-     # Set the seed for PyTorch (both CPU and GPU)
-     torch.manual_seed(seed)
-     if torch.cuda.is_available():
-         torch.cuda.manual_seed(seed)
-         torch.cuda.manual_seed_all(seed) # If using multi-GPU
-
-     # Ensure reproducibility of operations by disabling certain optimizations
-     torch.backends.cudnn.deterministic = True
-     torch.backends.cudnn.benchmark = False
-
- ###############################################################################################################################
+ import octopi
 
  def parse_list(value: str) -> List[str]:
      """
@@ -62,7 +16,6 @@ def parse_list(value: str) -> List[str]:
      value = value.strip("[]") # Remove brackets if present
      return [x.strip() for x in value.split(",")]
 
- ###############################################################################################################################
 
  def parse_int_list(value: str) -> List[int]:
      """
@@ -71,7 +24,6 @@ def parse_int_list(value: str) -> List[int]:
      """
      return [int(x) for x in parse_list(value)]
 
- ###############################################################################################################################
 
  def string2bool(value: str):
      """
@@ -86,7 +38,6 @@ def string2bool(value: str):
      else:
          raise argparse.ArgumentTypeError(f"Invalid boolean value: {value}")
 
- ###############################################################################################################################
 
  def parse_target(value: str) -> Tuple[str, Union[str, None], Union[str, None]]:
      """
@@ -130,6 +81,7 @@ def parse_seg_target(value: str) -> List[Tuple[str, Union[str, None], Union[str,
          )
      return targets
 
+
  def parse_copick_configs(config_entries: List[str]):
      """
      Parse a string representing a list of CoPick configuration file paths.
@@ -172,6 +124,7 @@ def parse_copick_configs(config_entries: List[str]):
 
      return copick_configs
 
+
  def parse_data_split(value: str) -> Tuple[float, float, float]:
      """
      Parse data split ratios from string input.
@@ -208,31 +161,4 @@ def parse_data_split(value: str) -> Tuple[float, float, float]:
      if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
          raise ValueError(f"Ratios must sum to 1.0, got {train_ratio + val_ratio + test_ratio}")
 
-     return round(train_ratio, 2), round(val_ratio, 2), round(test_ratio, 2)
-
- ##############################################################################################################################
-
- # Create a custom dumper that uses flow style for lists only.
- class InlineListDumper(yaml.SafeDumper):
-     def represent_list(self, data):
-         node = super().represent_list(data)
-         node.flow_style = True # Use inline style for lists
-         return node
-
- def save_parameters_yaml(params: dict, output_path: str):
-     """
-     Save parameters to a YAML file.
-     """
-     InlineListDumper.add_representer(list, InlineListDumper.represent_list)
-     with open(output_path, 'w') as f:
-         yaml.dump(params, f, Dumper=InlineListDumper, default_flow_style=False, sort_keys=False)
-
- def load_yaml(path: str) -> dict:
-     """
-     Load a YAML file and return the contents as a dictionary.
-     """
-     if os.path.exists(path):
-         with open(path, 'r') as f:
-             return yaml.safe_load(f)
-     else:
-         raise FileNotFoundError(f"File not found: {path}")
+     return round(train_ratio, 2), round(val_ratio, 2), round(test_ratio, 2)
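The helpers that remain in the renamed module are pure parsing utilities. Below is a small sketch of their behavior based only on the bodies visible in the hunks above; the exact input format accepted by parse_data_split is not fully shown here, so the comma-separated triple is an assumption.

```python
from octopi.utils import parsers

# parse_list strips surrounding brackets and splits on commas.
parsers.parse_list("[TS_001, TS_002, TS_003]")   # ['TS_001', 'TS_002', 'TS_003']

# parse_int_list applies int() to each element.
parsers.parse_int_list("2,2,1")                  # [2, 2, 1]

# parse_data_split validates that the ratios sum to 1.0 (raising ValueError otherwise)
# and returns them rounded to two decimals.
parsers.parse_data_split("0.8,0.1,0.1")          # assumed input format -> (0.8, 0.1, 0.1)
```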
octopi/{stopping_criteria.py → utils/stopping_criteria.py} RENAMED
@@ -7,9 +7,9 @@ class EarlyStoppingChecker:
 
      def __init__(self,
                   max_nan_epochs=15,
-                  plateau_patience=20,
-                  plateau_min_delta=0.001,
-                  stagnation_patience=50,
+                  plateau_patience=10,
+                  plateau_min_delta=0.005,
+                  stagnation_patience=30,
                   convergence_window=5,
                   convergence_threshold=0.005,
                   val_interval=15,
octopi/{visualization_tools.py → utils/visualization_tools.py} RENAMED
@@ -1,7 +1,7 @@
+ from copick_utils.io import readers
  import matplotlib.colors as mcolors
  from typing import Optional, List
  import matplotlib.pyplot as plt
- from octopi import io
  import numpy as np
 
  # Define the plotting function
@@ -76,7 +76,7 @@ def show_tomo_points(tomo, run, objects, user_id, vol_slice, session_id = None,
 
      for name,_,_ in objects:
          try:
-             coordinates = io.get_copick_coordinates(run, name=name, user_id=user_id, session_id=session_id)
+             coordinates = readers.coordinates(run, name=name, user_id=user_id, session_id=session_id)
              close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
              plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
          except:
@@ -94,7 +94,7 @@ def compare_tomo_points(tomo, run, objects, vol_slice, user_id1, user_id2,
 
      for name,_,_ in objects:
          try:
-             coordinates = io.get_copick_coordinates(run, name=name, user_id=user_id1, session_id=session_id1)
+             coordinates = readers.coordinates(run, name=name, user_id=user_id1, session_id=session_id1)
              close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
              plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
          except:
@@ -106,7 +106,7 @@ def compare_tomo_points(tomo, run, objects, vol_slice, user_id1, user_id2,
 
      for name,_,_ in objects:
          try:
-             coordinates = io.get_copick_coordinates(run, name=name, user_id=user_id2, session_id=session_id2)
+             coordinates = readers.coordinates(run, name=name, user_id=user_id2, session_id=session_id2)
              close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
              plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
          except:
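The only functional change in this file is swapping the removed `octopi.io.get_copick_coordinates` for `copick_utils.io.readers.coordinates`. A hedged sketch of that call in isolation, with the config path, run ID, object name, and user/session IDs as placeholders:

```python
import copick
import numpy as np
from copick_utils.io import readers

root = copick.from_file("copick_config.json")   # placeholder config path
run = root.get_run("TS_001")                    # placeholder run ID

# Same call signature used in show_tomo_points / compare_tomo_points above.
coordinates = readers.coordinates(run, name="ribosome", user_id="octopi", session_id="1")

# Keep only points near a given z-slice, mirroring the plotting code.
vol_slice, slice_proximity_threshold = 100, 5
close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
```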
octopi/workflows.py ADDED
@@ -0,0 +1,236 @@
+ from octopi.extract.localize import process_localization
+ import octopi.processing.evaluate as octopi_evaluate
+ from monai.metrics import ConfusionMatrixMetric
+ from octopi.models import common as builder
+ from octopi.pytorch import segmentation
+ from octopi.datasets import generators
+ from octopi.pytorch import trainer
+ import multiprocess as mp
+ import copick, torch, os
+ from octopi.utils import io
+ from tqdm import tqdm
+
+ def train(config, target_info, tomo_algorithm, voxel_size, loss_function,
+           model_config = None, model_weights = None, trainRunIDs = None, validateRunIDs = None,
+           model_save_path = 'results', best_metric = 'fBeta2', num_epochs = 1000, use_ema = True):
+     """
+     Train a UNet Model for Segmentation
+
+     Args:
+         config (str): Path to the Copick Config File
+         target_info (list): List containing the target user ID, target session ID, and target algorithm
+         tomo_algorithm (str): The tomographic algorithm to use for segmentation
+         voxel_size (float): The voxel size of the data
+         loss_function (str): The loss function to use for training
+         model_config (dict): The model configuration
+         model_weights (str): The path to the model weights
+         trainRunIDs (list): The list of run IDs to use for training
+         validateRunIDs (list): The list of run IDs to use for validation
+         model_save_path (str): The path to save the model
+         best_metric (str): The metric to use for early stopping
+         num_epochs (int): The number of epochs to train for
+     """
+
+     # If No Model Configuration is Provided, Use the Default Configuration
+     if model_config is None:
+         root = copick.from_file(config)
+         model_config = {
+             'architecture': 'Unet',
+             'num_classes': root.pickable_objects[-1].label + 1,
+             'dim_in': 80,
+             'strides': [2, 2, 1],
+             'channels': [48, 64, 80, 80],
+             'dropout': 0.0, 'num_res_units': 1,
+         }
+         print('No Model Configuration Provided, Using Default Configuration')
+         print(model_config)
+
+     data_generator = generators.TrainLoaderManager(
+         config,
+         target_info[0],
+         target_session_id = target_info[2],
+         target_user_id = target_info[1],
+         tomo_algorithm = tomo_algorithm,
+         voxel_size = voxel_size,
+         Nclasses = model_config['num_classes'],
+         tomo_batch_size = 15 )
+
+     data_generator.get_data_splits(
+         trainRunIDs = trainRunIDs,
+         validateRunIDs = validateRunIDs,
+         train_ratio = 0.9, val_ratio = 0.1, test_ratio = 0.0,
+         create_test_dataset = False)
+
+     # Get the reload frequency
+     data_generator.get_reload_frequency(num_epochs)
+
+     # Monai Functions
+     metrics_function = ConfusionMatrixMetric(include_background=False, metric_name=["recall",'precision','f1 score'], reduction="none")
+
+     # Build the Model
+     model_builder = builder.get_model(model_config['architecture'])
+     model = model_builder.build_model(model_config)
+
+     # Load the Model Weights if Provided
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     if model_weights:
+         state_dict = torch.load(model_weights, map_location=device, weights_only=True)
+         model.load_state_dict(state_dict)
+     model.to(device)
+
+     # Optimizer
+     optimizer = torch.optim.Adam(model.parameters(), 1e-3)
+
+     # Create UNet-Trainer
+     model_trainer = trainer.ModelTrainer(
+         model, device, loss_function, metrics_function, optimizer,
+         use_ema = use_ema
+     )
+
+     results = model_trainer.train(
+         data_generator, model_save_path, max_epochs=num_epochs,
+         crop_size=model_config['dim_in'], my_num_samples=16,
+         val_interval=10, best_metric=best_metric, verbose=True
+     )
+
+     # Save parameters and results
+     parameters_save_name = os.path.join(model_save_path, "model_config.yaml")
+     io.save_parameters_to_yaml(model_builder, model_trainer, data_generator, parameters_save_name)
+
+     # TODO: Write Results to Zarr or Another File Format?
+     results_save_name = os.path.join(model_save_path, "results.json")
+     io.save_results_to_json(results, results_save_name)
+
+ def segment(config, tomo_algorithm, voxel_size, model_weights, model_config,
+             seg_info = ['predict', 'octopi', '1'], use_tta = False, run_ids = None):
+     """
+     Segment a Dataset using a Trained Model or Ensemble of Models
+
+     Args:
+         config (str): Path to the Copick Config File
+         tomo_algorithm (str): The tomographic algorithm to use for segmentation
+         voxel_size (float): The voxel size of the data
+         model_weights (str, list): The path to the model weights or a list of paths to the model weights
+         model_config (str, list): The model configuration or a list of model configurations
+         seg_info (list): The segmentation information
+         use_tta (bool): Whether to use test time augmentation
+         run_ids (list): The list of run IDs to use for segmentation
+     """
+
+     # Initialize the Predictor
+     predict = segmentation.Predictor(
+         config,
+         model_config,
+         model_weights,
+         apply_tta = use_tta
+     )
+
+     # Run batch prediction
+     predict.batch_predict(
+         runIDs=run_ids,
+         num_tomos_per_batch=15,
+         tomo_algorithm=tomo_algorithm,
+         voxel_spacing=voxel_size,
+         segmentation_name=seg_info[0],
+         segmentation_user_id=seg_info[1],
+         segmentation_session_id=seg_info[2]
+     )
+
+ def localize(config, voxel_size, seg_info, pick_user_id, pick_session_id, n_procs = 16,
+              method = 'watershed', filter_size = 10, radius_min_scale = 0.4, radius_max_scale = 1.0,
+              run_ids = None):
+     """
+     Extract 3D Coordinates from the Segmentation Maps
+
+     Args:
+         config (str): Path to the Copick Config File
+         voxel_size (float): The voxel size of the data
+         seg_info (list): The segmentation information
+         pick_user_id (str): The user ID of the pick
+         pick_session_id (str): The session ID of the pick
+         n_procs (int): The number of processes to use for parallelization
+         method (str): The method to use for localization
+         filter_size (int): The filter size to use for localization
+         radius_min_scale (float): The minimum radius scale to use for localization
+         radius_max_scale (float): The maximum radius scale to use for localization
+         run_ids (list): The list of run IDs to use for localization
+     """
+
+     # Load the Copick Config
+     root = copick.from_file(config)
+
+     # Get objects that can be Picked
+     objects = [(obj.name, obj.label, obj.radius) for obj in root.pickable_objects if obj.is_particle]
+
+     # Get all RunIDs
+     if run_ids is None:
+         run_ids = [run.name for run in root.runs]
+     n_run_ids = len(run_ids)
+
+     # Run Localization - Main Parallelization Loop
+     print(f"Using {n_procs} processes to parallelize across {n_run_ids} run IDs.")
+     with mp.Pool(processes=n_procs) as pool:
+         with tqdm(total=n_run_ids, desc="Localization", unit="run") as pbar:
+             worker_func = lambda run_id: process_localization(
+                 root.get_run(run_id),
+                 objects,
+                 seg_info,
+                 method,
+                 voxel_size,
+                 filter_size,
+                 radius_min_scale,
+                 radius_max_scale,
+                 pick_session_id,
+                 pick_user_id
+             )
+
+             for _ in pool.imap_unordered(worker_func, run_ids, chunksize=1):
+                 pbar.update(1)
+
+     print('Localization Complete!')
+
+
+ def evaluate(config,
+              gt_user_id, gt_session_id,
+              pred_user_id, pred_session_id,
+              run_ids = None, distance_threshold = 0.5, save_path = None):
+     """
+     Evaluate the Localization on a Dataset
+
+     Args:
+         config (str): Path to the Copick Config File
+         gt_user_id (str): The user ID of the ground truth
+         gt_session_id (str): The session ID of the ground truth
+         pred_user_id (str): The user ID of the predicted coordinates
+         pred_session_id (str): The session ID of the predicted coordinates
+         run_ids (list): The list of run IDs to use for evaluation
+         distance_threshold (float): The distance threshold to use for evaluation
+         save_path (str): The path to save the evaluation results
+     """
+
+     print('Running Evaluation on the Following Query:')
+     print(f'Distance Threshold: {distance_threshold}')
+     print(f'GT User ID: {gt_user_id}, GT Session ID: {gt_session_id}')
+     print(f'Pred User ID: {pred_user_id}, Pred Session ID: {pred_session_id}')
+     print(f'Run IDs: {run_ids}')
+
+     # Load the Copick Config
+     root = copick.from_file(config)
+
+     # For Now Lets Assume Object Names are None..
+     object_names = None
+
+     # Run Evaluation
+     eval = octopi_evaluate.evaluator(
+         config,
+         gt_user_id,
+         gt_session_id,
+         pred_user_id,
+         pred_session_id,
+         object_names=object_names
+     )
+
+     eval.run(
+         distance_threshold_scale=distance_threshold,
+         runIDs=run_ids, save_path=save_path
+     )
{octopi-1.1.dist-info → octopi-1.2.0.dist-info}/METADATA RENAMED
@@ -1,10 +1,13 @@
- Metadata-Version: 2.3
+ Metadata-Version: 2.4
  Name: octopi
- Version: 1.1
+ Version: 1.2.0
  Summary: Model architecture exploration for cryoET particle picking
+ Project-URL: Homepage, https://github.com/chanzuckerberg/octopi
+ Project-URL: Documentation, https://chanzuckerberg.github.io/octopi/
+ Project-URL: Issues, https://github.com/chanzuckerberg/octopi/issues
+ Author: Jonathan Schwartz, Kevin Zhao, Daniel Ji, Utz Ermel
  License: MIT
- Author: Jonathan Schwartz
- Requires-Python: >=3.9,<4.0
+ License-File: LICENSE
  Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.9
@@ -12,28 +15,41 @@ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Classifier: Programming Language :: Python :: 3.13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Scientific/Engineering :: Image Processing
+ Classifier: Topic :: Scientific/Engineering :: Image Recognition
+ Requires-Python: >=3.9
  Requires-Dist: copick
+ Requires-Dist: copick-utils
  Requires-Dist: ipywidgets
  Requires-Dist: kaleido
  Requires-Dist: matplotlib
- Requires-Dist: mlflow (==2.17.0)
- Requires-Dist: monai-weekly (==1.5.dev2448)
+ Requires-Dist: mlflow
+ Requires-Dist: monai
  Requires-Dist: mrcfile
  Requires-Dist: multiprocess
  Requires-Dist: nibabel
- Requires-Dist: optuna (==4.0.0)
+ Requires-Dist: optuna
  Requires-Dist: optuna-integration[botorch,pytorch-lightning]
  Requires-Dist: pandas
- Requires-Dist: plotly
  Requires-Dist: python-dotenv
- Requires-Dist: pytorch-lightning (==2.4.0)
- Requires-Dist: requests (>=2.25.1,<3.0.0)
- Requires-Dist: seaborn
+ Requires-Dist: requests
  Requires-Dist: torch-ema
  Requires-Dist: tqdm
- Project-URL: Documentation, https://chanzuckerberg.github.io/octopi/
- Project-URL: Homepage, https://github.com/chanzuckerberg/octopi
- Project-URL: Issues, https://github.com/chanzuckerberg/octopi/issues
+ Provides-Extra: dev
+ Requires-Dist: black>=24.8.0; extra == 'dev'
+ Requires-Dist: pre-commit>=3.8.0; extra == 'dev'
+ Requires-Dist: pytest>=6.2.3; extra == 'dev'
+ Requires-Dist: ruff>=0.6.4; extra == 'dev'
+ Provides-Extra: docs
+ Requires-Dist: mkdocs; extra == 'docs'
+ Requires-Dist: mkdocs-awesome-pages-plugin; extra == 'docs'
+ Requires-Dist: mkdocs-git-authors-plugin; extra == 'docs'
+ Requires-Dist: mkdocs-git-committers-plugin-2; extra == 'docs'
+ Requires-Dist: mkdocs-git-revision-date-localized-plugin; extra == 'docs'
+ Requires-Dist: mkdocs-material; extra == 'docs'
+ Requires-Dist: mkdocs-minify-plugin; extra == 'docs'
+ Requires-Dist: mkdocs-redirects; extra == 'docs'
  Description-Content-Type: text/markdown
 
  # OCTOPI 🐙🐙🐙
@@ -63,10 +79,16 @@ Our deep learning-based pipeline streamlines the training and execution of 3D au
 
  ### Installation
 
+ Octopi is available on PyPI and can be installed using pip:
  ```bash
  pip install octopi
  ```
 
+ ⚠️ **Note**: One of the current dependencies does not work with pip 25. To temporarily downgrade pip, run:
+ ```bash
+ pip install --upgrade "pip<25"
+ ```
+
  ### Basic Usage
 
  octopi provides two main command-line interfaces:
@@ -74,35 +96,25 @@ octopi provides two main command-line interfaces:
  ```bash
  # Main CLI for training, inference, and data processing
  octopi --help
- ```
 
- The main `octopi` command provides subcommands for:
- - Data import and preprocessing
- - Training label preparation
- - Model training and exploration
- - Inference and particle localization
-
- ```bash
  # HPC-specific CLI for submitting jobs to SLURM clusters
  octopi-slurm --help
  ```
 
- The `octopi-slurm` command provides utilities for:
- - Submitting training jobs to SLURM clusters
- - Managing distributed inference tasks
- - Handling batch processing on HPC systems
-
  ## 📚 Documentation
 
  For detailed documentation, tutorials, CLI and API reference, visit our [documentation](https://chanzuckerberg.github.io/octopi/).
 
  ## 🤝 Contributing
 
- This project adheres to the Contributor Covenant code of conduct. By participating, you are expected to uphold this code. Please report unacceptable behavior to opensource@chanzuckerberg.com.
+ ## Code of Conduct
+
+ This project adheres to the Contributor Covenant [code of conduct](https://github.com/chanzuckerberg/.github/blob/master/CODE_OF_CONDUCT.md).
+ By participating, you are expected to uphold this code.
+ Please report unacceptable behavior to [opensource@chanzuckerberg.com](mailto:opensource@chanzuckerberg.com).
 
  ## 🔒 Security
 
  If you believe you have found a security issue, please responsibly disclose by contacting us at security@chanzuckerberg.com.
 
 
-