octopi 1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of octopi might be problematic; see the registry's advisory page for this release for more details.

Files changed (48)
  1. octopi/__init__.py +1 -0
  2. octopi/datasets/cached_datset.py +1 -1
  3. octopi/datasets/generators.py +1 -1
  4. octopi/datasets/io.py +200 -0
  5. octopi/datasets/multi_config_generator.py +1 -1
  6. octopi/entry_points/common.py +9 -9
  7. octopi/entry_points/create_slurm_submission.py +16 -8
  8. octopi/entry_points/run_create_targets.py +6 -6
  9. octopi/entry_points/run_evaluate.py +4 -3
  10. octopi/entry_points/run_extract_mb_picks.py +22 -45
  11. octopi/entry_points/run_localize.py +37 -54
  12. octopi/entry_points/run_optuna.py +7 -7
  13. octopi/entry_points/run_segment_predict.py +4 -4
  14. octopi/entry_points/run_train.py +7 -8
  15. octopi/extract/localize.py +19 -12
  16. octopi/extract/membranebound_extract.py +11 -10
  17. octopi/extract/midpoint_extract.py +3 -3
  18. octopi/main.py +1 -1
  19. octopi/models/common.py +1 -1
  20. octopi/processing/create_targets_from_picks.py +11 -5
  21. octopi/processing/downsample.py +6 -10
  22. octopi/processing/evaluate.py +24 -11
  23. octopi/processing/importers.py +4 -4
  24. octopi/pytorch/hyper_search.py +2 -3
  25. octopi/pytorch/model_search_submitter.py +15 -15
  26. octopi/pytorch/segmentation.py +147 -192
  27. octopi/pytorch/segmentation_multigpu.py +162 -0
  28. octopi/pytorch/trainer.py +9 -3
  29. octopi/utils/__init__.py +0 -0
  30. octopi/utils/config.py +57 -0
  31. octopi/utils/io.py +128 -0
  32. octopi/{utils.py → utils/parsers.py} +10 -84
  33. octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
  34. octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
  35. octopi/workflows.py +236 -0
  36. octopi-1.2.0.dist-info/METADATA +120 -0
  37. octopi-1.2.0.dist-info/RECORD +62 -0
  38. {octopi-1.0.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
  39. octopi-1.2.0.dist-info/entry_points.txt +3 -0
  40. {octopi-1.0.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
  41. octopi/io.py +0 -457
  42. octopi/processing/my_metrics.py +0 -26
  43. octopi/processing/writers.py +0 -102
  44. octopi-1.0.dist-info/METADATA +0 -209
  45. octopi-1.0.dist-info/RECORD +0 -59
  46. octopi-1.0.dist-info/entry_points.txt +0 -4
  47. /octopi/{losses.py → utils/losses.py} +0 -0
  48. /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
@@ -1,7 +1,7 @@
1
1
  from octopi.pytorch.model_search_submitter import ModelSearchSubmit
2
2
  from octopi.entry_points import common
3
+ from octopi.utils import parsers, io
3
4
  import argparse, os, pprint
4
- from octopi import utils
5
5
 
6
6
  def optuna_parser(parser_description, add_slurm: bool = False):
7
7
  """
@@ -20,22 +20,22 @@ def optuna_parser(parser_description, add_slurm: bool = False):
20
20
  # Input Arguments
21
21
  input_group = parser.add_argument_group("Input Arguments")
22
22
  common.add_config(input_group, single_config=False)
23
- input_group.add_argument("--target-info", type=utils.parse_target, default="targets,octopi,1",
23
+ input_group.add_argument("--target-info", type=parsers.parse_target, default="targets,octopi,1",
24
24
  help="Target information, e.g., 'name' or 'name,user_id,session_id'")
25
25
  input_group.add_argument("--tomo-alg", default='wbp',
26
26
  help="Tomogram algorithm used for training")
27
27
  input_group.add_argument("--mlflow-experiment-name", type=str, default="model-search", required=False,
28
28
  help="Name of the MLflow experiment (default: 'model-search').")
29
- input_group.add_argument("--trainRunIDs", type=utils.parse_list, default=None, required=False,
29
+ input_group.add_argument("--trainRunIDs", type=parsers.parse_list, default=None, required=False,
30
30
  help="List of training run IDs, e.g., run1,run2 or [run1,run2].")
31
- input_group.add_argument("--validateRunIDs", type=utils.parse_list, default=None, required=False,
31
+ input_group.add_argument("--validateRunIDs", type=parsers.parse_list, default=None, required=False,
32
32
  help="List of validation run IDs, e.g., run3,run4 or [run3,run4].")
33
33
  input_group.add_argument('--data-split', type=str, default='0.8', help="Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) "
34
34
  "or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")
35
35
 
36
36
  model_group = parser.add_argument_group("Model Arguments")
37
37
  model_group.add_argument("--model-type", type=str, default='Unet', required=False,
38
- choices=['Unet', 'AttentionUnet'],
38
+ choices=['Unet', 'AttentionUnet', 'MedNeXt', 'SegResNet'],
39
39
  help="Model type to use for training")
40
40
  model_group.add_argument("--Nclass", type=int, default=3, required=False, help="Number of prediction classes in the model")
41
41
 
@@ -61,7 +61,7 @@ def cli():
61
61
  args = optuna_parser(description)
62
62
 
63
63
  # Parse the CoPick configuration paths
64
- if len(args.config) > 1: copick_configs = utils.parse_copick_configs(args.config)
64
+ if len(args.config) > 1: copick_configs = parsers.parse_copick_configs(args.config)
65
65
  else: copick_configs = args.config[0]
66
66
 
67
67
  # Create the model exploration directory
@@ -133,7 +133,7 @@ def save_parameters(args: argparse.Namespace,
133
133
  pprint.pprint(params); print()
134
134
 
135
135
  # Save to YAML file
136
- utils.save_parameters_yaml(params, output_path)
136
+ io.save_parameters_yaml(params, output_path)
137
137
 
138
138
  if __name__ == "__main__":
139
139
  cli()
@@ -1,8 +1,8 @@
1
+ import torch, argparse, json, pprint, yaml, os
1
2
  from octopi.pytorch import segmentation
2
3
  from octopi.entry_points import common
3
- import torch, argparse, json, pprint, yaml, os
4
- from octopi import utils
5
4
  from typing import List, Tuple
5
+ from octopi.utils import io
6
6
 
7
7
  def inference(
8
8
  copick_config_path: str,
@@ -136,7 +136,7 @@ def save_parameters(args: argparse.Namespace,
136
136
  output_path: str):
137
137
 
138
138
  # Load the model config
139
- model_config = utils.load_yaml(args.model_config)
139
+ model_config = io.load_yaml(args.model_config)
140
140
 
141
141
  # Create parameters dictionary
142
142
  params = {
@@ -160,7 +160,7 @@ def save_parameters(args: argparse.Namespace,
160
160
  pprint.pprint(params); print()
161
161
 
162
162
  # Save to YAML file
163
- utils.save_parameters_yaml(params, output_path)
163
+ io.save_parameters_yaml(params, output_path)
164
164
 
165
165
  if __name__ == "__main__":
166
166
  cli()
@@ -2,12 +2,11 @@ from octopi.datasets import generators, multi_config_generator
2
2
  from monai.losses import DiceLoss, FocalLoss, TverskyLoss
3
3
  from octopi.models import common as builder
4
4
  from monai.metrics import ConfusionMatrixMetric
5
+ from octopi.utils import parsers, io
5
6
  from octopi.entry_points import common
6
7
  from octopi.pytorch import trainer
7
- from octopi import io, utils
8
8
  import torch, os, argparse
9
9
  from typing import List, Optional, Tuple
10
- import pprint
11
10
 
12
11
  def train_model(
13
12
  copick_config_path: str,
@@ -56,7 +55,7 @@ def train_model(
56
55
 
57
56
 
58
57
  # Get the data splits
59
- ratios = utils.parse_data_split(data_split)
58
+ ratios = parsers.parse_data_split(data_split)
60
59
  data_generator.get_data_splits(
61
60
  trainRunIDs = trainRunIDs,
62
61
  validateRunIDs = validateRunIDs,
@@ -114,11 +113,11 @@ def train_model_parser(parser_description, add_slurm: bool = False):
114
113
  # Input Arguments
115
114
  input_group = parser.add_argument_group("Input Arguments")
116
115
  common.add_config(input_group, single_config=False)
117
- input_group.add_argument("--target-info", type=utils.parse_target, default="targets,octopi,1",
116
+ input_group.add_argument("--target-info", type=parsers.parse_target, default="targets,octopi,1",
118
117
  help="Target information, e.g., 'name' or 'name,user_id,session_id'. Default is 'targets,octopi,1'.")
119
118
  input_group.add_argument("--tomo-alg", default='wbp', help="Tomogram algorithm used for training")
120
- input_group.add_argument("--trainRunIDs", type=utils.parse_list, help="List of training run IDs, e.g., run1,run2,run3")
121
- input_group.add_argument("--validateRunIDs", type=utils.parse_list, help="List of validation run IDs, e.g., run4,run5,run6")
119
+ input_group.add_argument("--trainRunIDs", type=parsers.parse_list, help="List of training run IDs, e.g., run1,run2,run3")
120
+ input_group.add_argument("--validateRunIDs", type=parsers.parse_list, help="List of validation run IDs, e.g., run4,run5,run6")
122
121
  input_group.add_argument('--data-split', type=str, default='0.8', help="Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) "
123
122
  "or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")
124
123
 
@@ -153,11 +152,11 @@ def cli():
153
152
  args = train_model_parser(parser_description)
154
153
 
155
154
  # Parse the CoPick configuration paths
156
- if len(args.config) > 1: copick_configs = utils.parse_copick_configs(args.config)
155
+ if len(args.config) > 1: copick_configs = parsers.parse_copick_configs(args.config)
157
156
  else: copick_configs = args.config[0]
158
157
 
159
158
  if args.model_config:
160
- model_config = utils.load_yaml(args.model_config)
159
+ model_config = io.load_yaml(args.model_config)
161
160
  else:
162
161
  model_config = get_model_config(args.channels, args.strides, args.res_units, args.Nclass, args.dim_in)
163
162
 
@@ -3,15 +3,15 @@ from scipy.cluster.hierarchy import fcluster, linkage
3
3
  from skimage.segmentation import watershed
4
4
  from typing import List, Optional, Tuple
5
5
  from skimage.measure import regionprops
6
+ from copick_utils.io import readers
6
7
  from scipy.spatial import distance
7
8
  from dataclasses import dataclass
8
- from octopi import io
9
9
  import scipy.ndimage as ndi
10
10
  from tqdm import tqdm
11
11
  import numpy as np
12
- import math
12
+ import gc
13
13
 
14
- def processs_localization(run,
14
+ def process_localization(run,
15
15
  objects,
16
16
  seg_info: Tuple[str, str, str],
17
17
  method: str = 'com',
@@ -27,12 +27,12 @@ def processs_localization(run,
27
27
  raise ValueError(f"Invalid method '{method}'. Expected 'watershed' or 'com'.")
28
28
 
29
29
  # Get Segmentation
30
- seg = io.get_segmentation_array(run,
31
- voxel_size,
32
- seg_info[0],
33
- user_id=seg_info[1],
34
- session_id=seg_info[2],
35
- raise_error=False)
30
+ seg = readers.segmentation(
31
+ run, voxel_size,
32
+ seg_info[0],
33
+ user_id=seg_info[1],
34
+ session_id=seg_info[2],
35
+ raise_error=False)
36
36
 
37
37
  # Preprocess Segmentation
38
38
  # seg = preprocess_segmentation(seg, voxel_size, objects)
@@ -99,15 +99,15 @@ def extract_particle_centroids_via_watershed(
99
99
  max_particle_size (int): Maximum size threshold for particles.
100
100
  """
101
101
 
102
- if maxima_filter_size is None or maxima_filter_size < 0:
103
- AssertionError('Enter a Non-Zero Filter Size!')
102
+ if maxima_filter_size is None or maxima_filter_size <= 0:
103
+ raise ValueError('Enter a Non-Zero Filter Size!')
104
104
 
105
105
  # Calculate minimum and maximum particle volumes based on the given radii
106
106
  min_particle_size = (4 / 3) * np.pi * (min_particle_radius ** 3)
107
107
  max_particle_size = (4 / 3) * np.pi * (max_particle_radius ** 3)
108
108
 
109
109
  # Create a binary mask for the specific segmentation label
110
- binary_mask = (segmentation == segmentation_idx).astype(int)
110
+ binary_mask = (segmentation == segmentation_idx).astype(np.uint8)
111
111
 
112
112
  # Skip if the segmentation label is not present
113
113
  if np.sum(binary_mask) == 0:
@@ -117,6 +117,7 @@ def extract_particle_centroids_via_watershed(
117
117
  # Structuring element for erosion and dilation
118
118
  struct_elem = ball(1)
119
119
  eroded = binary_erosion(binary_mask, struct_elem)
120
+
120
121
  dilated = binary_dilation(eroded, struct_elem)
121
122
 
122
123
  # Distance transform and local maxima detection
@@ -125,7 +126,13 @@ def extract_particle_centroids_via_watershed(
125
126
 
126
127
  # Watershed segmentation
127
128
  markers, _ = ndi.label(local_max)
129
+ del local_max
130
+ gc.collect()
131
+
128
132
  watershed_labels = watershed(-distance, markers, mask=dilated)
133
+ distance, markers, dilated = None, None, None
134
+ del distance, markers, dilated
135
+ gc.collect()
129
136
 
130
137
  # Extract region properties and filter based on particle size
131
138
  all_centroids = []
@@ -1,5 +1,5 @@
1
1
  from scipy.spatial.transform import Rotation as R
2
- from octopi import utils, io
2
+ from copick_utils.io import readers
3
3
  import scipy.ndimage as ndi
4
4
  from typing import Tuple
5
5
  import numpy as np
@@ -36,7 +36,7 @@ def process_membrane_bound_extract(run,
36
36
  new_session_id = str(int(save_session_id) + 1) # Convert to string after increment
37
37
 
38
38
  # Need Better Error Handing for Missing Picks
39
- coordinates = io.get_copick_coordinates(
39
+ coordinates = readers.coordinates(
40
40
  run,
41
41
  picks_info[0], picks_info[1], picks_info[2],
42
42
  voxel_size,
@@ -54,12 +54,13 @@ def process_membrane_bound_extract(run,
54
54
  if membrane_info is None:
55
55
  # Flag to distinguish between organelle and membrane segmentation
56
56
  membranes_provided = False
57
- seg = io.get_segmentation_array(run,
58
- voxel_size,
59
- organelle_info[0],
60
- user_id=organelle_info[1],
61
- session_id=organelle_info[2],
62
- raise_error=False)
57
+ seg = readers.segmentation(
58
+ run,
59
+ voxel_size,
60
+ organelle_info[0],
61
+ user_id=organelle_info[1],
62
+ session_id=organelle_info[2],
63
+ raise_error=False)
63
64
  # If No Segmentation is Found, Return
64
65
  if seg is None: return
65
66
  elif nPoints == 0 or np.unique(seg).max() == 0:
@@ -68,7 +69,7 @@ def process_membrane_bound_extract(run,
68
69
  else:
69
70
  # Read both Organelle and Membrane Segmentations
70
71
  membranes_provided = True
71
- seg = io.get_segmentation_array(
72
+ seg = readers.segmentation(
72
73
  run,
73
74
  voxel_size,
74
75
  membrane_info[0],
@@ -76,7 +77,7 @@ def process_membrane_bound_extract(run,
76
77
  session_id=membrane_info[2],
77
78
  raise_error=False)
78
79
 
79
- organelle_seg = io.get_segmentation_array(
80
+ organelle_seg = readers.segmentation(
80
81
  run,
81
82
  voxel_size,
82
83
  organelle_info[0],
@@ -1,6 +1,6 @@
1
1
  from octopi.extract import membranebound_extract as extract
2
2
  from scipy.spatial.transform import Rotation as R
3
- from octopi import io
3
+ from copick_utils.io import readers
4
4
  from scipy.spatial import cKDTree
5
5
  from typing import Tuple
6
6
  import numpy as np
@@ -28,7 +28,7 @@ def process_midpoint_extract(
28
28
  """
29
29
 
30
30
  # Pull Picks that Are used for Midpoint Extraction
31
- coordinates = io.get_copick_coordinates(
31
+ coordinates = readers.coordinates(
32
32
  run,
33
33
  picks_info[0], picks_info[1], picks_info[2],
34
34
  voxel_size
@@ -40,7 +40,7 @@ def process_midpoint_extract(
40
40
  save_picks_info[2] = save_session_id
41
41
 
42
42
  # Get Organelle Segmentation
43
- seg = io.get_segmentation_array(
43
+ seg = readers.segmentation(
44
44
  run,
45
45
  voxel_size,
46
46
  organelle_info[0],
octopi/main.py CHANGED
@@ -33,7 +33,7 @@ def cli_main():
33
33
  "create-targets": (create_targets, "Generate segmentation targets from coordinates."),
34
34
  "train": (train_model, "Train a single U-Net model."),
35
35
  "model-explore": (model_explore, "Explore model architectures with Optuna / Bayesian Optimization."),
36
- "inference": (inference, "Perform segmentation inference on tomograms."),
36
+ "segment": (inference, "Perform segmentation inference on tomograms."),
37
37
  "localize": (localize, "Perform localization of particles in tomograms."),
38
38
  "extract-mb-picks": (extract_mb_picks, "Extract MB Picks from tomograms."),
39
39
  "evaluate": (evaluate, "Evaluate the performance of a model."),
octopi/models/common.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from monai.losses import FocalLoss, TverskyLoss
2
- from octopi import losses
2
+ from octopi.utils import losses
3
3
  from octopi.models import (
4
4
  Unet, AttentionUnet, MedNeXt, SegResNet
5
5
  )
@@ -1,6 +1,5 @@
1
1
  from octopi.processing.segmentation_from_picks import from_picks
2
- import octopi.processing.writers as write
3
- from octopi import io
2
+ from copick_utils.io import readers, writers
4
3
  from typing import List
5
4
  from tqdm import tqdm
6
5
  import numpy as np
@@ -42,7 +41,11 @@ def generate_targets(
42
41
 
43
42
  # If runIDs are not provided, load all runs
44
43
  if run_ids is None:
45
- run_ids = [run.name for run in root.runs]
44
+ run_ids = [run.name for run in root.runs if run.get_voxel_spacing(voxel_size) is not None]
45
+ skipped_run_ids = [run.name for run in root.runs if run.get_voxel_spacing(voxel_size) is None]
46
+
47
+ if skipped_run_ids:
48
+ print(f"Warning: skipping runs with no voxel spacing {voxel_size}: {skipped_run_ids}")
46
49
 
47
50
  # Iterate Over All Runs
48
51
  for runID in tqdm(run_ids):
@@ -52,7 +55,7 @@ def generate_targets(
52
55
  run = root.get_run(runID)
53
56
 
54
57
  # Get Tomogram
55
- tomo = io.get_tomogram_array(run, voxel_size, tomo_algorithm)
58
+ tomo = readers.tomogram(run, voxel_size, tomo_algorithm)
56
59
 
57
60
  # Initialize Target Volume
58
61
  target = np.zeros(tomo.shape, dtype=np.uint8)
@@ -87,6 +90,9 @@ def generate_targets(
87
90
  session_id=train_targets[target_name]["session_id"],
88
91
  )
89
92
 
93
+ # Filter out empty picks
94
+ query = [pick for pick in query if pick.points is not None]
95
+
90
96
  # Add Picks to Target
91
97
  for pick in query:
92
98
  numPicks += len(pick.points)
@@ -100,7 +106,7 @@ def generate_targets(
100
106
  # Write Segmentation for non-empty targets
101
107
  if target.max() > 0 and numPicks > 0:
102
108
  tqdm.write(f'Annotating {numPicks} picks in {runID}...')
103
- write.segmentation(run, target, target_user_name,
109
+ writers.segmentation(run, target, target_user_name,
104
110
  name = target_segmentation_name, session_id= target_session_id,
105
111
  voxel_size = voxel_size)
106
112
  print('Creation of targets complete!')
@@ -102,11 +102,6 @@ class FourierRescale:
102
102
  """
103
103
  in_depth, in_height, in_width = volume.shape[-3:]
104
104
 
105
- # Check if dimensions are odd
106
- d_is_odd = in_depth % 2
107
- h_is_odd = in_height % 2
108
- w_is_odd = in_width % 2
109
-
110
105
  # Calculate new dimensions
111
106
  extent_depth = in_depth * self.input_voxel_size[0]
112
107
  extent_height = in_height * self.input_voxel_size[1]
@@ -121,9 +116,10 @@ class FourierRescale:
121
116
  new_height = new_height - (new_height % 2)
122
117
  new_width = new_width - (new_width % 2)
123
118
 
124
- # Calculate starting points with odd/even correction
125
- start_d = (in_depth - new_depth) // 2 + (d_is_odd)
126
- start_h = (in_height - new_height) // 2 + (h_is_odd)
127
- start_w = (in_width - new_width) // 2 + (w_is_odd)
119
+ # Calculate starting points - properly centered around DC component
120
+ # No odd/even correction needed - just center the crop
121
+ start_d = (in_depth - new_depth) // 2
122
+ start_h = (in_height - new_height) // 2
123
+ start_w = (in_width - new_width) // 2
128
124
 
129
- return start_d, start_h, start_w, new_depth, new_height, new_width
125
+ return start_d, start_h, start_w, new_depth, new_height, new_width
@@ -1,7 +1,7 @@
1
- from octopi import utils, io
1
+ from copick_utils.io import readers
2
2
  from scipy.spatial import distance
3
+ import copick, json, os, yaml
3
4
  from typing import List
4
- import copick, json, os
5
5
  import numpy as np
6
6
 
7
7
  class evaluator:
@@ -95,12 +95,12 @@ class evaluator:
95
95
  for name, radius in self.objects:
96
96
 
97
97
  # Get Ground Truth and Predicted Coordinates
98
- gt_coordinates = io.get_copick_coordinates(
98
+ gt_coordinates = readers.coordinates(
99
99
  run, name,
100
100
  self.ground_truth_user_id, self.ground_truth_session_id,
101
101
  self.voxel_size, raise_error=False
102
102
  )
103
- pred_coordinates = io.get_copick_coordinates(
103
+ pred_coordinates = readers.coordinates(
104
104
  run, name,
105
105
  self.prediction_user_id, self.predict_session_id,
106
106
  self.voxel_size, raise_error=False
@@ -202,14 +202,27 @@ class evaluator:
202
202
  }
203
203
 
204
204
  os.makedirs(save_path, exist_ok=True)
205
- summary_metrics = { "input": self.input_params, "parameters": self.parameters,
206
- "summary_metrics": final_summary_metrics }
207
- with open(os.path.join(save_path, 'average_metrics.json'), 'w') as f:
208
- json.dump(summary_metrics, f, indent=4)
209
- print(f'\nAverage Metrics saved to {os.path.join(save_path, "average_metrics.json")}')
205
+ summary_metrics = { "input": self.input_params,
206
+ "final_fbeta_score": final_fbeta,
207
+ "aggregated_particle_scores": { # Optionally add per-particle details
208
+ name: {
209
+ "tp": counts['total_tp'],
210
+ "fp": counts['total_fp'],
211
+ "fn": counts['total_fn'],
212
+ "weight": self.weights.get(name, 1)
213
+ } for name, counts in aggregated_counts.items()
214
+ },
215
+ "summary_metrics": final_summary_metrics,
216
+ "parameters": self.parameters, }
217
+
218
+ # Save average metrics to YAML file
219
+ with open(os.path.join(save_path, 'average_metrics.yaml'), 'w') as f:
220
+ yaml.dump(summary_metrics, f, indent=4, default_flow_style=False, sort_keys=False)
221
+ print(f'\nAverage Metrics saved to {os.path.join(save_path, "average_metrics.yaml")}')
210
222
 
211
- detailed_metrics = { "input": self.input_params, "parameters": self.parameters,
212
- "metrics": metrics }
223
+ detailed_metrics = { "input": self.input_params,
224
+ "metrics": metrics,
225
+ "parameters": self.parameters, }
213
226
  with open(os.path.join(save_path, 'metrics.json'), 'w') as f:
214
227
  json.dump(detailed_metrics, f, indent=4)
215
228
  print(f'Metrics saved to {os.path.join(save_path, "metrics.json")}')
@@ -1,7 +1,7 @@
1
1
  from octopi.processing.downsample import FourierRescale
2
2
  import copick, argparse, mrcfile, glob, os
3
- import octopi.processing.writers as write
4
3
  from octopi.entry_points import common
4
+ from copick_utils.io import writers
5
5
  from tqdm import tqdm
6
6
 
7
7
  def from_dataportal(
@@ -57,10 +57,10 @@ def from_dataportal(
57
57
 
58
58
  # If we want to save the tomograms at a different voxel size, we need to rescale the tomograms
59
59
  if output_voxel_size is None:
60
- write.tomogram(run, vol, input_voxel_size, target_tomo_type)
60
+ writers.tomogram(run, vol, input_voxel_size, target_tomo_type)
61
61
  else:
62
62
  vol = rescale.run(vol)
63
- write.tomogram(run, vol, output_voxel_size, target_tomo_type)
63
+ writers.tomogram(run, vol, output_voxel_size, target_tomo_type)
64
64
 
65
65
  print(f'Downloading Complete!! Downloaded {len(root.runs)} runs')
66
66
 
@@ -168,7 +168,7 @@ def from_mrcs(
168
168
  voxel_size_to_write = input_voxel_size
169
169
 
170
170
  # Write the tomogram
171
- write.tomogram(run, vol, voxel_size_to_write, target_tomo_type)
171
+ writers.tomogram(run, vol, voxel_size_to_write, target_tomo_type)
172
172
  print(f"Processed {len(mrc_files)} files from {mrcs_path}")
173
173
 
174
174
 
@@ -1,10 +1,9 @@
1
- from monai.losses import FocalLoss, TverskyLoss
2
1
  from monai.metrics import ConfusionMatrixMetric
3
2
  from octopi.pytorch import trainer
4
3
  from mlflow.tracking import MlflowClient
5
4
  from octopi.models import common
6
- from octopi import io, losses
7
5
  import torch, mlflow, optuna, gc
6
+ from octopi.utils import io
8
7
 
9
8
  class BayesianModelSearch:
10
9
 
@@ -207,7 +206,7 @@ class BayesianModelSearch:
207
206
  if score > best_score_so_far:
208
207
  torch.save(model_trainer.model_weights, f'{self.results_dir}/best_model.pth')
209
208
  io.save_parameters_to_yaml(self.model_builder, model_trainer, self.data_generator,
210
- f'{self.results_dir}/best_model_config.yaml')
209
+ f'{self.results_dir}/model_config.yaml')
211
210
 
212
211
  def get_best_score(self, trial):
213
212
  """Retrieve the best score from the trial."""
@@ -1,7 +1,7 @@
1
1
  from octopi.datasets import generators, multi_config_generator
2
+ from octopi.utils import config, parsers
2
3
  from octopi.pytorch import hyper_search
3
4
  import torch, mlflow, optuna
4
- from octopi import utils
5
5
  from typing import List
6
6
  import pandas as pd
7
7
 
@@ -16,16 +16,16 @@ class ModelSearchSubmit:
16
16
  voxel_size: float,
17
17
  Nclass: int,
18
18
  model_type: str,
19
- mlflow_experiment_name: str,
20
- random_seed: int,
21
- num_epochs: int,
22
- num_trials: int,
23
- tomo_batch_size: int,
24
- best_metric: str,
25
- val_interval: int,
26
- trainRunIDs: List[str],
27
- validateRunIDs: List[str],
28
- data_split: str
19
+ best_metric: str = 'avg_f1',
20
+ num_epochs: int = 1000,
21
+ num_trials: int = 100,
22
+ data_split: str = 0.8,
23
+ random_seed: int = 42,
24
+ val_interval: int = 10,
25
+ tomo_batch_size: int = 15,
26
+ trainRunIDs: List[str] = None,
27
+ validateRunIDs: List[str] = None,
28
+ mlflow_experiment_name: str = 'explore',
29
29
  ):
30
30
  """
31
31
  Initialize the ModelSearch class for architecture search with Optuna.
@@ -75,7 +75,7 @@ class ModelSearchSubmit:
75
75
  self.data_generator = None
76
76
 
77
77
  # Set random seed for reproducibility
78
- utils.set_seed(self.random_seed)
78
+ config.set_seed(self.random_seed)
79
79
 
80
80
  # Initialize dataset generator
81
81
  self._initialize_data_generator()
@@ -108,7 +108,7 @@ class ModelSearchSubmit:
108
108
  )
109
109
 
110
110
  # Split datasets into training and validation
111
- ratios = utils.parse_data_split(self.data_split)
111
+ ratios = parsers.parse_data_split(self.data_split)
112
112
  self.data_generator.get_data_splits(
113
113
  trainRunIDs=self.trainRunIDs,
114
114
  validateRunIDs=self.validateRunIDs,
@@ -134,7 +134,7 @@ class ModelSearchSubmit:
134
134
 
135
135
  # Set up MLflow tracking
136
136
  try:
137
- tracking_uri = utils.mlflow_setup()
137
+ tracking_uri = config.mlflow_setup()
138
138
  mlflow.set_tracking_uri(tracking_uri)
139
139
  except Exception as e:
140
140
  print(f'Failed to set up MLflow tracking: {e}')
@@ -207,7 +207,7 @@ class ModelSearchSubmit:
207
207
  # Run multi-GPU optimization
208
208
  study = self.get_optuna_study()
209
209
  study.optimize(
210
- lambda trial: BayesianModelSearch(self.data_generator, self.model_type).multi_gpu_objective(
210
+ lambda trial: hyper_search.BayesianModelSearch(self.data_generator, self.model_type).multi_gpu_objective(
211
211
  parent_run, trial,
212
212
  self.num_epochs,
213
213
  best_metric=self.best_metric,