waveorder 3.0.0a2__py3-none-any.whl → 3.0.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
waveorder/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '3.0.0a2'
32
- __version_tuple__ = version_tuple = (3, 0, 0, 'a2')
31
+ __version__ = version = '3.0.0a3'
32
+ __version_tuple__ = version_tuple = (3, 0, 0, 'a3')
33
33
 
34
34
  __commit_id__ = commit_id = None
Binary file
@@ -5,6 +5,7 @@ This module converts GUI-level reconstruction calls into library calls
5
5
  import numpy as np
6
6
  import torch
7
7
 
8
+ from waveorder.cli.settings import FluorescenceSettings, PhaseSettings
8
9
  from waveorder.models import (
9
10
  inplane_oriented_thick_pol3d,
10
11
  inplane_oriented_thick_pol3d_vector,
@@ -62,7 +63,7 @@ def birefringence(
62
63
  def phase(
63
64
  czyx_data,
64
65
  recon_dim,
65
- settings_phase,
66
+ settings_phase: PhaseSettings,
66
67
  transfer_function_dataset,
67
68
  ):
68
69
  # [phase only, 2]
@@ -83,7 +84,7 @@ def phase(
83
84
  ) = isotropic_thin_3d.apply_inverse_transfer_function(
84
85
  czyx_data[0],
85
86
  (U, S, Vh),
86
- **settings_phase.apply_inverse.dict(),
87
+ **settings_phase.apply_inverse.model_dump(),
87
88
  )
88
89
  # Stack to C1YX
89
90
  output = phase_yx[None, None]
@@ -108,7 +109,7 @@ def phase(
108
109
  real_potential_transfer_function,
109
110
  imaginary_potential_transfer_function,
110
111
  z_padding=settings_phase.transfer_function.z_padding,
111
- **settings_phase.apply_inverse.dict(),
112
+ **settings_phase.apply_inverse.model_dump(),
112
113
  )
113
114
 
114
115
  # Pad to CZYX
@@ -124,7 +125,7 @@ def birefringence_and_phase(
124
125
  wavelength_illumination,
125
126
  recon_dim,
126
127
  biref_inverse_dict,
127
- settings_phase,
128
+ settings_phase: PhaseSettings,
128
129
  transfer_function_dataset,
129
130
  ):
130
131
  # Load birefringence transfer function
@@ -174,7 +175,7 @@ def birefringence_and_phase(
174
175
  ) = isotropic_thin_3d.apply_inverse_transfer_function(
175
176
  brightfield_3d,
176
177
  (U, S, Vh),
177
- **settings_phase.apply_inverse.dict(),
178
+ **settings_phase.apply_inverse.model_dump(),
178
179
  )
179
180
 
180
181
  # Convert retardance
@@ -222,7 +223,7 @@ def birefringence_and_phase(
222
223
  real_potential_transfer_function,
223
224
  imaginary_potential_transfer_function,
224
225
  z_padding=settings_phase.transfer_function.z_padding,
225
- **settings_phase.apply_inverse.dict(),
226
+ **settings_phase.apply_inverse.model_dump(),
226
227
  )
227
228
 
228
229
  # Convert retardance
@@ -253,7 +254,7 @@ def birefringence_and_phase(
253
254
  szyx_data=stokes,
254
255
  singular_system=singular_system,
255
256
  intensity_to_stokes_matrix=None,
256
- **settings_phase.apply_inverse.dict(),
257
+ **settings_phase.apply_inverse.model_dump(),
257
258
  )
258
259
 
259
260
  new_ret = (
@@ -282,7 +283,10 @@ def birefringence_and_phase(
282
283
 
283
284
 
284
285
  def fluorescence(
285
- czyx_data, recon_dim, settings_fluorescence, transfer_function_dataset
286
+ czyx_data,
287
+ recon_dim,
288
+ settings_fluorescence: FluorescenceSettings,
289
+ transfer_function_dataset,
286
290
  ):
287
291
  # [fluo, 2]
288
292
  if recon_dim == 2:
@@ -299,7 +303,7 @@ def fluorescence(
299
303
  output = isotropic_fluorescent_thin_3d.apply_inverse_transfer_function(
300
304
  czyx_data[0],
301
305
  (U, S, Vh),
302
- **settings_fluorescence.apply_inverse.dict(),
306
+ **settings_fluorescence.apply_inverse.model_dump(),
303
307
  )
304
308
  # [fluo, 3]
305
309
  elif recon_dim == 3:
@@ -314,7 +318,7 @@ def fluorescence(
314
318
  czyx_data[0],
315
319
  optical_transfer_function,
316
320
  settings_fluorescence.transfer_function.z_padding,
317
- **settings_fluorescence.apply_inverse.dict(),
321
+ **settings_fluorescence.apply_inverse.model_dump(),
318
322
  )
319
323
  )
320
324
  # Pad to CZYX
@@ -47,7 +47,7 @@ def get_reconstruction_output_metadata(position_path: Path, config_path: Path):
47
47
  )
48
48
  plate_metadata = dict(input_plate.zattrs)
49
49
  plate_metadata.pop("plate")
50
- except RuntimeError:
50
+ except (RuntimeError, FileNotFoundError):
51
51
  warnings.warn(
52
52
  "Position is not part of a plate...no plate metadata will be copied."
53
53
  )
@@ -171,7 +171,7 @@ def apply_inverse_transfer_function_single_position(
171
171
  # so this section converts the settings to a dict and separates the
172
172
  # waveorder parameters (biref_inverse_dict) from the waveorder
173
173
  # parameters (cyx_no_sample_data, and wavelength_illumination)
174
- biref_inverse_dict = settings.birefringence.apply_inverse.dict()
174
+ biref_inverse_dict = settings.birefringence.apply_inverse.model_dump()
175
175
 
176
176
  # Resolve background path into array
177
177
  background_path = biref_inverse_dict.pop("background_path")
@@ -279,7 +279,7 @@ def apply_inverse_transfer_function_single_position(
279
279
  partial_apply_inverse_to_zyx_and_save(t_idx)
280
280
 
281
281
  # Save metadata at position level
282
- output_dataset.zattrs["settings"] = settings.dict()
282
+ output_dataset.zattrs["settings"] = settings.model_dump()
283
283
 
284
284
  echo_headline(f"Closing {output_position_dirpath}\n")
285
285
 
@@ -67,14 +67,14 @@ def generate_and_save_vector_birefringence_transfer_function(
67
67
  echo_headline(
68
68
  f"Downsampling transfer function in X and Y by {transverse_downsample_factor}x"
69
69
  )
70
- phase_settings_dict = settings.phase.transfer_function.dict()
70
+ phase_settings_dict = settings.phase.transfer_function.model_dump()
71
71
  phase_settings_dict.pop("z_focus_offset") # not used in 3D
72
72
 
73
73
  sfZYX_transfer_function, _, singular_system = (
74
74
  inplane_oriented_thick_pol3d_vector.calculate_transfer_function(
75
75
  zyx_shape=zyx_shape,
76
76
  scheme=str(len(settings.input_channel_names)) + "-State",
77
- **settings.birefringence.transfer_function.dict(),
77
+ **settings.birefringence.transfer_function.model_dump(),
78
78
  **phase_settings_dict,
79
79
  fourier_oversample_factor=int(transverse_downsample_factor),
80
80
  )
@@ -109,7 +109,9 @@ def generate_and_save_vector_birefringence_transfer_function(
109
109
  )
110
110
 
111
111
 
112
- def generate_and_save_birefringence_transfer_function(settings, dataset):
112
+ def generate_and_save_birefringence_transfer_function(
113
+ settings: ReconstructionSettings, dataset
114
+ ):
113
115
  """Generates and saves the birefringence transfer function to the dataset, based on the settings.
114
116
 
115
117
  Parameters
@@ -125,7 +127,7 @@ def generate_and_save_birefringence_transfer_function(settings, dataset):
125
127
  intensity_to_stokes_matrix = (
126
128
  inplane_oriented_thick_pol3d.calculate_transfer_function(
127
129
  scheme=str(len(settings.input_channel_names)) + "-State",
128
- **settings.birefringence.transfer_function.dict(),
130
+ **settings.birefringence.transfer_function.model_dump(),
129
131
  )
130
132
  )
131
133
  # Save
@@ -152,7 +154,7 @@ def generate_and_save_phase_transfer_function(
152
154
  echo_headline("Generating phase transfer function with settings:")
153
155
  echo_settings(settings.phase.transfer_function)
154
156
 
155
- settings_dict = settings.phase.transfer_function.dict()
157
+ settings_dict = settings.phase.transfer_function.model_dump()
156
158
  if settings.reconstruction_dimension == 2:
157
159
  # Convert zyx_shape and z_pixel_size into yx_shape and z_position_list
158
160
  settings_dict["yx_shape"] = [zyx_shape[1], zyx_shape[2]]
@@ -240,7 +242,7 @@ def generate_and_save_fluorescence_transfer_function(
240
242
  """
241
243
  echo_headline("Generating fluorescence transfer function with settings:")
242
244
  echo_settings(settings.fluorescence.transfer_function)
243
- settings_dict = settings.fluorescence.transfer_function.dict()
245
+ settings_dict = settings.fluorescence.transfer_function.model_dump()
244
246
 
245
247
  if settings.reconstruction_dimension == 2:
246
248
  # Convert zyx_shape and z_pixel_size into yx_shape and z_position_list
@@ -396,7 +398,7 @@ def compute_transfer_function_cli(
396
398
  )
397
399
 
398
400
  # Write settings to metadata
399
- output_dataset.zattrs["settings"] = settings.dict()
401
+ output_dataset.zattrs["settings"] = settings.model_dump()
400
402
 
401
403
  echo_headline(f"Closing {output_dirpath}\n")
402
404
  output_dataset.close()
waveorder/cli/printing.py CHANGED
@@ -1,10 +1,14 @@
1
1
  import click
2
2
  import yaml
3
3
 
4
+ from waveorder.cli.settings import MyBaseModel
4
5
 
5
- def echo_settings(settings):
6
+
7
+ def echo_settings(settings: MyBaseModel):
6
8
  click.echo(
7
- yaml.dump(settings.dict(), default_flow_style=False, sort_keys=False)
9
+ yaml.dump(
10
+ settings.model_dump(), default_flow_style=False, sort_keys=False
11
+ )
8
12
  )
9
13
 
10
14
 
waveorder/cli/settings.py CHANGED
@@ -3,14 +3,15 @@ import warnings
3
3
  from pathlib import Path
4
4
  from typing import List, Literal, Optional, Union
5
5
 
6
- from pydantic.v1 import (
6
+ from pydantic import (
7
7
  BaseModel,
8
+ ConfigDict,
8
9
  Extra,
9
10
  NonNegativeFloat,
10
11
  NonNegativeInt,
11
12
  PositiveFloat,
12
- root_validator,
13
- validator,
13
+ field_validator,
14
+ model_validator,
14
15
  )
15
16
 
16
17
  # This file defines the configuration settings for the CLI.
@@ -22,8 +23,8 @@ from pydantic.v1 import (
22
23
 
23
24
 
24
25
  # All settings classes inherit from MyBaseModel, which forbids extra parameters to guard against typos
25
- class MyBaseModel(BaseModel, extra=Extra.forbid):
26
- pass
26
+ class MyBaseModel(BaseModel):
27
+ model_config = ConfigDict(extra="forbid")
27
28
 
28
29
 
29
30
  # Bottom level settings
@@ -34,7 +35,8 @@ class WavelengthIllumination(MyBaseModel):
34
35
  class BirefringenceTransferFunctionSettings(MyBaseModel):
35
36
  swing: float = 0.1
36
37
 
37
- @validator("swing")
38
+ @field_validator("swing")
39
+ @classmethod
38
40
  def swing_range(cls, v):
39
41
  if v <= 0 or v >= 1.0:
40
42
  raise ValueError(f"swing = {v} should be between 0 and 1.")
@@ -43,11 +45,9 @@ class BirefringenceTransferFunctionSettings(MyBaseModel):
43
45
 
44
46
  class BirefringenceApplyInverseSettings(WavelengthIllumination):
45
47
  background_path: Union[str, Path] = ""
46
- remove_estimated_background: bool = False
47
- flip_orientation: bool = False
48
- rotate_orientation: bool = False
49
48
 
50
- @validator("background_path")
49
+ @field_validator("background_path")
50
+ @classmethod
51
51
  def check_background_path(cls, v):
52
52
  if v == "":
53
53
  return v
@@ -57,34 +57,36 @@ class BirefringenceApplyInverseSettings(WavelengthIllumination):
57
57
  raise ValueError(f"{v} is not a existing directory")
58
58
  return raw_dir
59
59
 
60
+ remove_estimated_background: bool = False
61
+ flip_orientation: bool = False
62
+ rotate_orientation: bool = False
63
+
60
64
 
61
65
  class FourierTransferFunctionSettings(MyBaseModel):
62
66
  yx_pixel_size: PositiveFloat = 6.5 / 20
63
67
  z_pixel_size: PositiveFloat = 2.0
64
68
  z_padding: NonNegativeInt = 0
65
- z_focus_offset: Union[int, Literal["auto"]] = 0
69
+ z_focus_offset: Union[float, Literal["auto"]] = 0
66
70
  index_of_refraction_media: PositiveFloat = 1.3
67
71
  numerical_aperture_detection: PositiveFloat = 1.2
68
72
 
69
- @validator("numerical_aperture_detection")
70
- def na_det(cls, v, values):
71
- n = values["index_of_refraction_media"]
72
- if v > n:
73
+ @model_validator(mode="after")
74
+ def validate_numerical_aperture_detection(self):
75
+ if self.numerical_aperture_detection > self.index_of_refraction_media:
73
76
  raise ValueError(
74
- f"numerical_aperture_detection = {v} must be less than or equal to index_of_refraction_media = {n}"
77
+ f"numerical_aperture_detection = {self.numerical_aperture_detection} must be less than or equal to index_of_refraction_media = {self.index_of_refraction_media}"
75
78
  )
76
- return v
79
+ return self
77
80
 
78
- @validator("z_pixel_size")
79
- def warn_unit_consistency(cls, v, values):
80
- yx_pixel_size = values["yx_pixel_size"]
81
- ratio = yx_pixel_size / v
81
+ @model_validator(mode="after")
82
+ def warn_unit_consistency(self):
83
+ ratio = self.yx_pixel_size / self.z_pixel_size
82
84
  if ratio < 1.0 / 20 or ratio > 20:
83
85
  warnings.warn(
84
- f"yx_pixel_size ({yx_pixel_size}) / z_pixel_size ({v}) = {ratio}. Did you use consistent units?",
86
+ f"yx_pixel_size ({self.yx_pixel_size}) / z_pixel_size ({self.z_pixel_size}) = {ratio}. Did you use consistent units?",
85
87
  UserWarning,
86
88
  )
87
- return v
89
+ return self
88
90
 
89
91
 
90
92
  class FourierApplyInverseSettings(MyBaseModel):
@@ -101,30 +103,31 @@ class PhaseTransferFunctionSettings(
101
103
  numerical_aperture_illumination: NonNegativeFloat = 0.5
102
104
  invert_phase_contrast: bool = False
103
105
 
104
- @validator("numerical_aperture_illumination")
105
- def na_ill(cls, v, values):
106
- n = values.get("index_of_refraction_media")
107
- if v > n:
106
+ @model_validator(mode="after")
107
+ def validate_numerical_aperture_illumination(self):
108
+ if (
109
+ self.numerical_aperture_illumination
110
+ > self.index_of_refraction_media
111
+ ):
108
112
  raise ValueError(
109
- f"numerical_aperture_illumination = {v} must be less than or equal to index_of_refraction_media = {n}"
113
+ f"numerical_aperture_illumination = {self.numerical_aperture_illumination} must be less than or equal to index_of_refraction_media = {self.index_of_refraction_media}"
110
114
  )
111
- return v
115
+ return self
112
116
 
113
117
 
114
118
  class FluorescenceTransferFunctionSettings(FourierTransferFunctionSettings):
115
119
  wavelength_emission: PositiveFloat = 0.507
116
120
  confocal_pinhole_diameter: Optional[PositiveFloat] = None
117
121
 
118
- @validator("wavelength_emission")
119
- def warn_unit_consistency(cls, v, values):
120
- yx_pixel_size = values.get("yx_pixel_size")
121
- ratio = yx_pixel_size / v
122
+ @model_validator(mode="after")
123
+ def warn_unit_consistency(self):
124
+ ratio = self.yx_pixel_size / self.wavelength_emission
122
125
  if ratio < 1.0 / 20 or ratio > 20:
123
126
  warnings.warn(
124
- f"yx_pixel_size ({yx_pixel_size}) / wavelength_illumination ({v}) = {ratio}. Did you use consistent units?",
127
+ f"yx_pixel_size ({self.yx_pixel_size}) / wavelength_illumination ({self.wavelength_emission}) = {ratio}. Did you use consistent units?",
125
128
  UserWarning,
126
129
  )
127
- return v
130
+ return self
128
131
 
129
132
 
130
133
  # Second level settings
@@ -158,24 +161,21 @@ class ReconstructionSettings(MyBaseModel):
158
161
  NonNegativeInt, List[NonNegativeInt], Literal["all"]
159
162
  ] = "all"
160
163
  reconstruction_dimension: Literal[2, 3] = 3
161
- birefringence: Optional[BirefringenceSettings]
162
- phase: Optional[PhaseSettings]
163
- fluorescence: Optional[FluorescenceSettings]
164
-
165
- @root_validator(pre=False)
166
- def validate_reconstruction_types(cls, values):
167
- if (values.get("birefringence") or values.get("phase")) and values.get(
168
- "fluorescence"
169
- ) is not None:
164
+ birefringence: Optional[BirefringenceSettings] = None
165
+ phase: Optional[PhaseSettings] = None
166
+ fluorescence: Optional[FluorescenceSettings] = None
167
+
168
+ @model_validator(mode="after")
169
+ def validate_reconstruction_types(self):
170
+ if (
171
+ self.birefringence or self.phase
172
+ ) and self.fluorescence is not None:
170
173
  raise ValueError(
171
174
  '"fluorescence" cannot be present alongside "birefringence" or "phase". Please use one configuration file for a "fluorescence" reconstruction and another configuration file for a "birefringence" and/or "phase" reconstructions.'
172
175
  )
173
- num_channel_names = len(values.get("input_channel_names"))
174
- if values.get("birefringence") is None:
175
- if (
176
- values.get("phase") is None
177
- and values.get("fluorescence") is None
178
- ):
176
+ num_channel_names = len(self.input_channel_names)
177
+ if self.birefringence is None:
178
+ if self.phase is None and self.fluorescence is None:
179
179
  raise ValueError(
180
180
  "Provide settings for either birefringence, phase, birefringence + phase, or fluorescence."
181
181
  )
@@ -184,4 +184,4 @@ class ReconstructionSettings(MyBaseModel):
184
184
  f"{num_channel_names} channels names provided. Please provide a single channel for fluorescence/phase reconstructions."
185
185
  )
186
186
 
187
- return values
187
+ return self
waveorder/cli/utils.py CHANGED
@@ -37,7 +37,7 @@ def is_single_position_store(position_path: Path) -> bool:
37
37
  # Try to open as HCS plate 3 levels up
38
38
  open_ome_zarr(position_path.parent.parent.parent, mode="r")
39
39
  return False # Successfully opened as plate
40
- except RuntimeError:
40
+ except (RuntimeError, FileNotFoundError):
41
41
  return True # Not a plate structure
42
42
 
43
43
 
waveorder/focus.py CHANGED
@@ -60,6 +60,8 @@ def focus_from_transverse_band(
60
60
  polynomial_fit_order: Optional[int] = None,
61
61
  plot_path: Optional[str] = None,
62
62
  threshold_FWHM: float = 0,
63
+ return_statistics: bool = False,
64
+ enable_subpixel_precision: bool = False,
63
65
  ):
64
66
  """Estimates the in-focus slice from a 3D stack by optimizing a transverse spatial frequency band.
65
67
 
@@ -91,14 +93,22 @@ def focus_from_transverse_band(
91
93
  The default value, 0, applies no threshold, and the maximum midband power is always considered in focus.
92
94
  For values > 0, the peak's FWHM must be greater than the threshold for the slice to be considered in focus.
93
95
  If the peak does not meet this threshold, the function returns None.
96
+ return_statistics: bool, optional
97
+ If True, returns a tuple (in_focus_index, peak_stats) instead of just in_focus_index.
98
+ Default is False for backward compatibility.
99
+ enable_subpixel_precision: bool, optional
100
+ If True and polynomial_fit_order is provided, enables sub-pixel precision focus detection
101
+ by finding the continuous extremum of the polynomial fit. Default is False for backward compatibility.
94
102
 
95
103
  Returns
96
- ------
97
- slice : int or None
98
- If peak's FWHM > peak_width_threshold:
99
- return the index of the in-focus slice
100
- else:
101
- return None
104
+ -------
105
+ slice : int, float, None, or tuple
106
+ If return_statistics is False (default):
107
+ Returns in_focus_index (int if enable_subpixel_precision=False,
108
+ float if enable_subpixel_precision=True and polynomial_fit_order is not None, or None).
109
+ If return_statistics is True:
110
+ Returns tuple (in_focus_index, peak_stats) where peak_stats is a dict
111
+ containing 'peak_index' and 'peak_FWHM'.
102
112
 
103
113
  Example
104
114
  ------
@@ -109,6 +119,7 @@ def focus_from_transverse_band(
109
119
  >>> in_focus_data = data[slice,:,:]
110
120
  """
111
121
  minmaxfunc = _mode_to_minmaxfunc(mode)
122
+ peak_stats = {"peak_index": None, "peak_FWHM": None}
112
123
 
113
124
  _check_focus_inputs(
114
125
  zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
@@ -119,6 +130,8 @@ def focus_from_transverse_band(
119
130
  warnings.warn(
120
131
  "The dataset only contained a single slice. Returning trivial slice index = 0."
121
132
  )
133
+ if return_statistics:
134
+ return 0, peak_stats
122
135
  return 0
123
136
 
124
137
  # Calculate midband power for each slice
@@ -140,10 +153,48 @@ def focus_from_transverse_band(
140
153
  else:
141
154
  x = np.arange(len(midband_sum))
142
155
  coeffs = np.polyfit(x, midband_sum, polynomial_fit_order)
143
- peak_index = minmaxfunc(np.poly1d(coeffs)(x))
156
+ poly_func = np.poly1d(coeffs)
157
+
158
+ if enable_subpixel_precision:
159
+ # Find the continuous extremum using derivative
160
+ poly_deriv = np.polyder(coeffs)
161
+ # Find roots of the derivative (critical points)
162
+ critical_points = np.roots(poly_deriv)
144
163
 
145
- peak_results = peak_widths(midband_sum, [peak_index])
164
+ # Filter for real roots within the data range
165
+ real_critical_points = []
166
+ for cp in critical_points:
167
+ if np.isreal(cp) and 0 <= cp.real < len(midband_sum):
168
+ real_critical_points.append(cp.real)
169
+
170
+ if real_critical_points:
171
+ # Evaluate the polynomial at critical points to find extremum
172
+ critical_values = [
173
+ poly_func(cp) for cp in real_critical_points
174
+ ]
175
+ if mode == "max":
176
+ best_idx = np.argmax(critical_values)
177
+ else: # mode == "min"
178
+ best_idx = np.argmin(critical_values)
179
+ peak_index = real_critical_points[best_idx]
180
+ else:
181
+ # Fall back to discrete maximum if no valid critical points
182
+ peak_index = float(minmaxfunc(poly_func(x)))
183
+ else:
184
+ peak_index = minmaxfunc(poly_func(x))
185
+
186
+ # For peak width calculation, use integer peak index
187
+ if enable_subpixel_precision and polynomial_fit_order is not None:
188
+ # Use the closest integer index for peak width calculation
189
+ integer_peak_index = int(np.round(peak_index))
190
+ else:
191
+ integer_peak_index = int(peak_index)
192
+
193
+ peak_results = peak_widths(midband_sum, [integer_peak_index])
146
194
  peak_FWHM = peak_results[0][0]
195
+ peak_stats.update(
196
+ {"peak_index": int(peak_index), "peak_FWHM": float(peak_FWHM)}
197
+ )
147
198
 
148
199
  if peak_FWHM >= threshold_FWHM:
149
200
  in_focus_index = peak_index
@@ -161,6 +212,9 @@ def focus_from_transverse_band(
161
212
  threshold_FWHM,
162
213
  )
163
214
 
215
+ if return_statistics:
216
+ return in_focus_index, peak_stats
217
+
164
218
  return in_focus_index
165
219
 
166
220
 
@@ -215,9 +269,19 @@ def _plot_focus_metric(
215
269
  ):
216
270
  _, ax = plt.subplots(1, 1, figsize=(4, 4))
217
271
  ax.plot(midband_sum, "-k")
272
+
273
+ # Handle floating-point peak_index for plotting
274
+ if isinstance(peak_index, float) and not peak_index.is_integer():
275
+ # Use interpolation to get the y-value at the floating-point x-position
276
+ peak_y_value = np.interp(
277
+ peak_index, np.arange(len(midband_sum)), midband_sum
278
+ )
279
+ else:
280
+ peak_y_value = midband_sum[int(peak_index)]
281
+
218
282
  ax.plot(
219
283
  peak_index,
220
- midband_sum[peak_index],
284
+ peak_y_value,
221
285
  "go" if in_focus_index is not None else "ro",
222
286
  )
223
287
  ax.hlines(*peak_results[1:], color="k", linestyles="dashed")
waveorder/io/utils.py CHANGED
@@ -7,6 +7,8 @@ import torch
7
7
  import yaml
8
8
  from iohub import open_ome_zarr
9
9
 
10
+ from waveorder.cli.settings import MyBaseModel
11
+
10
12
 
11
13
  def add_index_to_path(path: Path):
12
14
  """Takes a path to a file or folder and appends the smallest index that does
@@ -76,13 +78,13 @@ def ram_message():
76
78
  return (is_warning, message)
77
79
 
78
80
 
79
- def model_to_yaml(model, yaml_path: Path) -> None:
81
+ def model_to_yaml(model: MyBaseModel, yaml_path: Path) -> None:
80
82
  """
81
83
  Save a model's dictionary representation to a YAML file.
82
84
 
83
85
  Parameters
84
86
  ----------
85
- model : object
87
+ model : MyBaseModel
86
88
  The model object to convert to YAML.
87
89
  yaml_path : Path
88
90
  The path to the output YAML file.
@@ -110,7 +112,7 @@ def model_to_yaml(model, yaml_path: Path) -> None:
110
112
  if not hasattr(model, "dict"):
111
113
  raise TypeError("The 'model' object does not have a 'dict()' method.")
112
114
 
113
- model_dict = model.dict()
115
+ model_dict = model.model_dump()
114
116
 
115
117
  # Remove None-valued fields
116
118
  clean_model_dict = {