waveorder 2.2.1b0__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. waveorder/_version.py +16 -3
  2. waveorder/acq/__init__.py +0 -0
  3. waveorder/acq/acq_functions.py +166 -0
  4. waveorder/assets/HSV_legend.png +0 -0
  5. waveorder/assets/JCh_legend.png +0 -0
  6. waveorder/assets/waveorder_plugin_logo.png +0 -0
  7. waveorder/calib/Calibration.py +1512 -0
  8. waveorder/calib/Optimization.py +470 -0
  9. waveorder/calib/__init__.py +0 -0
  10. waveorder/calib/calibration_workers.py +464 -0
  11. waveorder/cli/apply_inverse_models.py +328 -0
  12. waveorder/cli/apply_inverse_transfer_function.py +379 -0
  13. waveorder/cli/compute_transfer_function.py +432 -0
  14. waveorder/cli/gui_widget.py +58 -0
  15. waveorder/cli/main.py +39 -0
  16. waveorder/cli/monitor.py +163 -0
  17. waveorder/cli/option_eat_all.py +47 -0
  18. waveorder/cli/parsing.py +122 -0
  19. waveorder/cli/printing.py +16 -0
  20. waveorder/cli/reconstruct.py +67 -0
  21. waveorder/cli/settings.py +187 -0
  22. waveorder/cli/utils.py +175 -0
  23. waveorder/filter.py +1 -2
  24. waveorder/focus.py +136 -25
  25. waveorder/io/__init__.py +0 -0
  26. waveorder/io/_reader.py +61 -0
  27. waveorder/io/core_functions.py +272 -0
  28. waveorder/io/metadata_reader.py +195 -0
  29. waveorder/io/utils.py +175 -0
  30. waveorder/io/visualization.py +160 -0
  31. waveorder/models/inplane_oriented_thick_pol3d_vector.py +3 -3
  32. waveorder/models/isotropic_fluorescent_thick_3d.py +92 -0
  33. waveorder/models/isotropic_fluorescent_thin_3d.py +331 -0
  34. waveorder/models/isotropic_thin_3d.py +73 -72
  35. waveorder/models/phase_thick_3d.py +103 -4
  36. waveorder/napari.yaml +36 -0
  37. waveorder/plugin/__init__.py +9 -0
  38. waveorder/plugin/gui.py +1094 -0
  39. waveorder/plugin/gui.ui +1440 -0
  40. waveorder/plugin/job_manager.py +42 -0
  41. waveorder/plugin/main_widget.py +1605 -0
  42. waveorder/plugin/tab_recon.py +3294 -0
  43. waveorder/scripts/__init__.py +0 -0
  44. waveorder/scripts/launch_napari.py +13 -0
  45. waveorder/scripts/repeat-cal-acq-rec.py +147 -0
  46. waveorder/scripts/repeat-calibration.py +31 -0
  47. waveorder/scripts/samples.py +85 -0
  48. waveorder/scripts/simulate_zarr_acq.py +204 -0
  49. waveorder/util.py +1 -1
  50. waveorder/visuals/napari_visuals.py +1 -1
  51. waveorder-3.0.0.dist-info/METADATA +350 -0
  52. waveorder-3.0.0.dist-info/RECORD +69 -0
  53. {waveorder-2.2.1b0.dist-info → waveorder-3.0.0.dist-info}/WHEEL +1 -1
  54. waveorder-3.0.0.dist-info/entry_points.txt +5 -0
  55. {waveorder-2.2.1b0.dist-info → waveorder-3.0.0.dist-info/licenses}/LICENSE +13 -1
  56. waveorder-2.2.1b0.dist-info/METADATA +0 -187
  57. waveorder-2.2.1b0.dist-info/RECORD +0 -27
  58. {waveorder-2.2.1b0.dist-info → waveorder-3.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,331 @@
1
+ from typing import Literal, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ from waveorder import optics, sampling, util
8
+ from waveorder.filter import apply_filter_bank
9
+
10
+
11
+ def generate_test_phantom(
12
+ yx_shape: tuple[int, int],
13
+ yx_pixel_size: float,
14
+ sphere_radius: float,
15
+ ) -> Tensor:
16
+ """Generate a test phantom for fluorescent thin object.
17
+
18
+ Parameters
19
+ ----------
20
+ yx_shape : tuple[int, int]
21
+ Shape of YX dimensions
22
+ yx_pixel_size : float
23
+ Pixel size in YX plane
24
+ wavelength_emission : float
25
+ Emission wavelength
26
+ sphere_radius : float
27
+ Radius of spherical phantom
28
+
29
+ Returns
30
+ -------
31
+ Tensor
32
+ YX fluorescence density map
33
+ """
34
+ sphere, _, _ = util.generate_sphere_target(
35
+ (3,) + yx_shape,
36
+ yx_pixel_size,
37
+ z_pixel_size=1.0,
38
+ radius=sphere_radius,
39
+ blur_size=2 * yx_pixel_size,
40
+ )
41
+
42
+ # Use middle slice as thin fluorescent object
43
+ yx_fluorescence_density = sphere[1]
44
+
45
+ return yx_fluorescence_density
46
+
47
+
48
+ def calculate_transfer_function(
49
+ yx_shape: tuple[int, int],
50
+ yx_pixel_size: float,
51
+ z_position_list: list,
52
+ wavelength_emission: float,
53
+ index_of_refraction_media: float,
54
+ numerical_aperture_detection: float,
55
+ confocal_pinhole_diameter: float | None = None,
56
+ ) -> Tensor:
57
+ """Calculate transfer function for fluorescent thin object imaging.
58
+
59
+ Parameters
60
+ ----------
61
+ yx_shape : tuple[int, int]
62
+ Shape of YX dimensions
63
+ yx_pixel_size : float
64
+ Pixel size in YX plane
65
+ z_position_list : list
66
+ List of Z positions for defocus stack
67
+ wavelength_emission : float
68
+ Emission wavelength
69
+ index_of_refraction_media : float
70
+ Refractive index of imaging medium
71
+ numerical_aperture_detection : float
72
+ Numerical aperture of detection objective
73
+ confocal_pinhole_diameter : float | None, optional
74
+ Diameter of confocal pinhole. Not implemented for 2D fluorescence.
75
+
76
+ Returns
77
+ -------
78
+ Tensor
79
+ Fluorescent 2D-to-3D transfer function
80
+
81
+ Raises
82
+ ------
83
+ NotImplementedError
84
+ If confocal_pinhole_diameter is not None
85
+ """
86
+ if confocal_pinhole_diameter is not None:
87
+ raise NotImplementedError(
88
+ "Confocal reconstruction is not implemented for 2D fluorescence"
89
+ )
90
+
91
+ transverse_nyquist = sampling.transverse_nyquist(
92
+ wavelength_emission,
93
+ numerical_aperture_detection, # ill = det for fluorescence
94
+ numerical_aperture_detection,
95
+ )
96
+ yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
97
+
98
+ fluorescent_2d_to_3d_transfer_function = (
99
+ _calculate_wrap_unsafe_transfer_function(
100
+ (
101
+ yx_shape[0] * yx_factor,
102
+ yx_shape[1] * yx_factor,
103
+ ),
104
+ yx_pixel_size / yx_factor,
105
+ z_position_list,
106
+ wavelength_emission,
107
+ index_of_refraction_media,
108
+ numerical_aperture_detection,
109
+ )
110
+ )
111
+
112
+ fluorescent_2d_to_3d_transfer_function_out = torch.zeros(
113
+ (len(z_position_list),) + tuple(yx_shape), dtype=torch.complex64
114
+ )
115
+
116
+ for z in range(len(z_position_list)):
117
+ fluorescent_2d_to_3d_transfer_function_out[z] = (
118
+ sampling.nd_fourier_central_cuboid(
119
+ fluorescent_2d_to_3d_transfer_function[z], yx_shape
120
+ )
121
+ )
122
+
123
+ return fluorescent_2d_to_3d_transfer_function_out
124
+
125
+
126
+ def calculate_singular_system(
127
+ fluorescent_2d_to_3d_transfer_function: Tensor,
128
+ ) -> Tuple[Tensor, Tensor, Tensor]:
129
+ """Calculates the singular system of the fluorescent transfer function.
130
+
131
+ The transfer function has shape (Z, Vy, Vx), where (Z,) is the data-space
132
+ dimension, and (Vy, Vx) are the spatial frequency dimensions.
133
+
134
+ The SVD is computed over the (Z,) dimension.
135
+
136
+ Parameters
137
+ ----------
138
+ fluorescent_2d_to_3d_transfer_function : Tensor
139
+ ZYX transfer function for fluorescence
140
+
141
+ Returns
142
+ -------
143
+ Tuple[Tensor, Tensor, Tensor]
144
+ U, S, Vh components of the SVD
145
+ """
146
+ # For fluorescence, we have only one object property (fluorescence density)
147
+ # Input shape: (Z, Vy, Vx)
148
+
149
+ # We need to create the format: (1, Z, Vy, Vx) where 1 represents single object type
150
+ sfYX_transfer_function = fluorescent_2d_to_3d_transfer_function[None]
151
+
152
+ # Permute to: (Vy, Vx, 1, Z) for SVD
153
+ YXsf_transfer_function = sfYX_transfer_function.permute(2, 3, 0, 1)
154
+ Up, Sp, Vhp = torch.linalg.svd(YXsf_transfer_function, full_matrices=False)
155
+ # SVD gives us: Up: (Vy, Vx, 1, min(1,Z)), Sp: (Vy, Vx, min(1,Z)), Vhp: (Vy, Vx, min(1,Z), Z)
156
+
157
+ # Permute back to match expected format:
158
+ U = Up.permute(2, 3, 0, 1) # (1, min(1,Z), Vy, Vx) -> (1, Z, Vy, Vx)
159
+ S = Sp.permute(2, 0, 1) # (min(1,Z), Vy, Vx) -> (1, Vy, Vx)
160
+ Vh = Vhp.permute(2, 3, 0, 1) # (min(1,Z), Z, Vy, Vx) -> (1, 1, Vy, Vx)
161
+ return U, S, Vh
162
+
163
+
164
+ def _calculate_wrap_unsafe_transfer_function(
165
+ yx_shape: tuple[int, int],
166
+ yx_pixel_size: float,
167
+ z_position_list: list,
168
+ wavelength_emission: float,
169
+ index_of_refraction_media: float,
170
+ numerical_aperture_detection: float,
171
+ ) -> Tensor:
172
+ """Calculate wrap-unsafe transfer function for fluorescent imaging."""
173
+ radial_frequencies = util.generate_radial_frequencies(
174
+ yx_shape, yx_pixel_size
175
+ )
176
+
177
+ det_pupil = optics.generate_pupil(
178
+ radial_frequencies,
179
+ numerical_aperture_detection,
180
+ wavelength_emission,
181
+ )
182
+
183
+ propagation_kernel = optics.generate_propagation_kernel(
184
+ radial_frequencies,
185
+ det_pupil,
186
+ wavelength_emission / index_of_refraction_media,
187
+ torch.tensor(z_position_list),
188
+ )
189
+
190
+ zyx_shape = (len(z_position_list),) + tuple(yx_shape)
191
+ fluorescent_2d_to_3d_transfer_function = torch.zeros(
192
+ zyx_shape, dtype=torch.complex64
193
+ )
194
+
195
+ for z in range(len(z_position_list)):
196
+ # For fluorescent imaging, the transfer function is the squared magnitude
197
+ # of the coherent transfer function (incoherent imaging)
198
+ point_spread_function = (
199
+ torch.abs(torch.fft.ifft2(propagation_kernel[z], dim=(0, 1))) ** 2
200
+ )
201
+ fluorescent_2d_to_3d_transfer_function[z] = torch.fft.fft2(
202
+ point_spread_function
203
+ )
204
+
205
+ # Normalize
206
+ max_val = torch.max(torch.abs(fluorescent_2d_to_3d_transfer_function))
207
+ if max_val > 0:
208
+ fluorescent_2d_to_3d_transfer_function /= max_val
209
+
210
+ return fluorescent_2d_to_3d_transfer_function
211
+
212
+
213
+ def visualize_transfer_function(
214
+ viewer,
215
+ fluorescent_2d_to_3d_transfer_function: Tensor,
216
+ zyx_scale: tuple[float, float, float],
217
+ ) -> None:
218
+ """Visualize the fluorescent transfer function in napari."""
219
+ arrays = [
220
+ (
221
+ torch.imag(fluorescent_2d_to_3d_transfer_function),
222
+ "Im(fluorescent TF)",
223
+ ),
224
+ (
225
+ torch.real(fluorescent_2d_to_3d_transfer_function),
226
+ "Re(fluorescent TF)",
227
+ ),
228
+ ]
229
+
230
+ for array in arrays:
231
+ lim = (0.5 * torch.max(torch.abs(array[0]))).item()
232
+ viewer.add_image(
233
+ torch.fft.ifftshift(array[0], dim=(1, 2)).cpu().numpy(),
234
+ name=array[1],
235
+ colormap="bwr",
236
+ contrast_limits=(-lim, lim),
237
+ scale=zyx_scale,
238
+ )
239
+ viewer.dims.order = (2, 0, 1)
240
+
241
+
242
+ def apply_transfer_function(
243
+ yx_fluorescence_density: Tensor,
244
+ fluorescent_2d_to_3d_transfer_function: Tensor,
245
+ background: int = 10,
246
+ ) -> Tensor:
247
+ """Simulate fluorescent imaging by applying the transfer function.
248
+
249
+ Parameters
250
+ ----------
251
+ yx_fluorescence_density : Tensor
252
+ 2D fluorescence density map
253
+ fluorescent_2d_to_3d_transfer_function : Tensor
254
+ 3D transfer function
255
+ background : int, optional
256
+ Background counts, by default 10
257
+
258
+ Returns
259
+ -------
260
+ Tensor
261
+ Simulated 3D fluorescent data stack
262
+ """
263
+ # Simulate fluorescent object imaging
264
+ yx_fluorescence_hat = torch.fft.fftn(yx_fluorescence_density)
265
+ zyx_fluorescence_data_hat = yx_fluorescence_hat[None] * torch.real(
266
+ fluorescent_2d_to_3d_transfer_function
267
+ )
268
+ zyx_fluorescence_data = torch.real(
269
+ torch.fft.ifftn(zyx_fluorescence_data_hat, dim=(1, 2))
270
+ )
271
+
272
+ # Add background
273
+ data = zyx_fluorescence_data + background
274
+ return data
275
+
276
+
277
+ def apply_inverse_transfer_function(
278
+ zyx_data: Tensor,
279
+ singular_system: Tuple[Tensor, Tensor, Tensor],
280
+ reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
281
+ regularization_strength: float = 1e-3,
282
+ TV_rho_strength: float = 1e-3,
283
+ TV_iterations: int = 10,
284
+ ) -> Tensor:
285
+ """Reconstruct fluorescence density from zyx_data and singular system.
286
+
287
+ Parameters
288
+ ----------
289
+ zyx_data : Tensor
290
+ 3D raw data, fluorescence defocus stack
291
+ singular_system : Tuple[Tensor, Tensor, Tensor]
292
+ Singular system of the fluorescent transfer function
293
+ reconstruction_algorithm : Literal["Tikhonov", "TV"], optional
294
+ Reconstruction algorithm, by default "Tikhonov"
295
+ "TV" is not implemented
296
+ regularization_strength : float, optional
297
+ Regularization parameter, by default 1e-3
298
+ TV_rho_strength : float, optional
299
+ TV-specific regularization parameter, by default 1e-3
300
+ "TV" is not implemented
301
+ TV_iterations : int, optional
302
+ TV-specific number of iterations, by default 10
303
+ "TV" is not implemented
304
+
305
+ Returns
306
+ -------
307
+ Tensor
308
+ YX fluorescence density reconstruction
309
+
310
+ Raises
311
+ ------
312
+ NotImplementedError
313
+ TV is not implemented
314
+ """
315
+ if reconstruction_algorithm == "Tikhonov":
316
+ print("Computing inverse filter")
317
+ U, S, Vh = singular_system
318
+ S_reg = S / (S**2 + regularization_strength)
319
+ sfyx_inverse_filter = torch.einsum(
320
+ "sj...,j...,jf...->fs...", U, S_reg, Vh
321
+ )
322
+
323
+ # Apply filter bank - returns tuple but we only have one object type
324
+ yx_fluorescence_density = apply_filter_bank(
325
+ sfyx_inverse_filter, zyx_data
326
+ )[0]
327
+
328
+ elif reconstruction_algorithm == "TV":
329
+ raise NotImplementedError("TV reconstruction is not implemented")
330
+
331
+ return yx_fluorescence_density
@@ -5,6 +5,7 @@ import torch
5
5
  from torch import Tensor
6
6
 
7
7
  from waveorder import optics, sampling, util
8
+ from waveorder.filter import apply_filter_bank
8
9
 
9
10
 
10
11
  def generate_test_phantom(
@@ -29,7 +30,7 @@ def generate_test_phantom(
29
30
  / wavelength_illumination
30
31
  ) # phase in radians
31
32
 
32
- yx_absorption = 0.02 * sphere[1]
33
+ yx_absorption = torch.clone(yx_phase)
33
34
 
34
35
  return yx_absorption, yx_phase
35
36
 
@@ -103,9 +104,17 @@ def _calculate_wrap_unsafe_transfer_function(
103
104
  numerical_aperture_detection: float,
104
105
  invert_phase_contrast: bool = False,
105
106
  ) -> Tuple[Tensor, Tensor]:
106
- if invert_phase_contrast:
107
- z_position_list = torch.flip(torch.tensor(z_position_list), dims=(0,))
107
+ if numerical_aperture_illumination >= numerical_aperture_detection:
108
+ print(
109
+ "Warning: numerical_aperture_illumination is >= "
110
+ "numerical_aperture_detection. Setting "
111
+ "numerical_aperture_illumination to 0.9 * "
112
+ "numerical_aperture_detection to avoid singularities."
113
+ )
114
+ numerical_aperture_illumination = 0.9 * numerical_aperture_detection
108
115
 
116
+ if invert_phase_contrast:
117
+ z_position_list = [-1 * x for x in z_position_list]
109
118
  radial_frequencies = util.generate_radial_frequencies(
110
119
  yx_shape, yx_pixel_size
111
120
  )
@@ -148,6 +157,45 @@ def _calculate_wrap_unsafe_transfer_function(
148
157
  )
149
158
 
150
159
 
160
+ def calculate_singular_system(
161
+ absorption_2d_to_3d_transfer_function: Tensor,
162
+ phase_2d_to_3d_transfer_function: Tensor,
163
+ ) -> Tuple[Tensor, Tensor, Tensor]:
164
+ """Calculates the singular system of the absoprtion and phase transfer
165
+ functions.
166
+
167
+ Together, the transfer functions form a (2, Z, Vy, Vx) tensor, where
168
+ (2,) is the object-space dimension (abs, phase), (Z,) is the data-space
169
+ dimension, and (Vy, Vx) are the spatial frequency dimensions.
170
+
171
+ The SVD is computed over the (2, Z) dimensions.
172
+
173
+ Parameters
174
+ ----------
175
+ absorption_2d_to_3d_transfer_function : Tensor
176
+ ZYX transfer function for absorption
177
+ phase_2d_to_3d_transfer_function : Tensor
178
+ ZYX transfer function for phase
179
+
180
+ Returns
181
+ -------
182
+ Tuple[Tensor, Tensor, Tensor]
183
+ """
184
+ sfYX_transfer_function = torch.stack(
185
+ (
186
+ absorption_2d_to_3d_transfer_function,
187
+ phase_2d_to_3d_transfer_function,
188
+ ),
189
+ dim=0,
190
+ )
191
+ YXsf_transfer_function = sfYX_transfer_function.permute(2, 3, 0, 1)
192
+ Up, Sp, Vhp = torch.linalg.svd(YXsf_transfer_function, full_matrices=False)
193
+ U = Up.permute(2, 3, 0, 1)
194
+ S = Sp.permute(2, 0, 1)
195
+ Vh = Vhp.permute(2, 3, 0, 1)
196
+ return U, S, Vh
197
+
198
+
151
199
  def visualize_transfer_function(
152
200
  viewer,
153
201
  absorption_2d_to_3d_transfer_function: Tensor,
@@ -166,7 +214,7 @@ def visualize_transfer_function(
166
214
  ]
167
215
 
168
216
  for array in arrays:
169
- lim = 0.5 * torch.max(torch.abs(array[0]))
217
+ lim = (0.5 * torch.max(torch.abs(array[0]))).item()
170
218
  viewer.add_image(
171
219
  torch.fft.ifftshift(array[0], dim=(1, 2)).cpu().numpy(),
172
220
  name=array[1],
@@ -188,7 +236,7 @@ def visualize_point_spread_function(
188
236
  ]
189
237
 
190
238
  for array in arrays:
191
- lim = 0.5 * torch.max(torch.abs(array[0]))
239
+ lim = (0.5 * torch.max(torch.abs(array[0]))).item()
192
240
  viewer.add_image(
193
241
  torch.fft.ifftshift(array[0], dim=(1, 2)).cpu().numpy(),
194
242
  name=array[1],
@@ -202,8 +250,8 @@ def visualize_point_spread_function(
202
250
  def apply_transfer_function(
203
251
  yx_absorption: Tensor,
204
252
  yx_phase: Tensor,
205
- phase_2d_to_3d_transfer_function: Tensor,
206
253
  absorption_2d_to_3d_transfer_function: Tensor,
254
+ phase_2d_to_3d_transfer_function: Tensor,
207
255
  ) -> Tensor:
208
256
  # Very simple simulation, consider adding noise and bkg knobs
209
257
 
@@ -233,14 +281,13 @@ def apply_transfer_function(
233
281
 
234
282
  def apply_inverse_transfer_function(
235
283
  zyx_data: Tensor,
236
- absorption_2d_to_3d_transfer_function: Tensor,
237
- phase_2d_to_3d_transfer_function: Tensor,
284
+ singular_system: Tuple[Tensor, Tensor, Tensor],
238
285
  reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
239
- regularization_strength: float = 1e-6,
286
+ regularization_strength: float = 1e-3,
240
287
  reg_p: float = 1e-6, # TODO: use this parameter
241
288
  TV_rho_strength: float = 1e-3,
242
289
  TV_iterations: int = 10,
243
- bg_filter: bool = True,
290
+ bg_filter: bool = False,
244
291
  ) -> Tuple[Tensor, Tensor]:
245
292
  """Reconstructs absorption and phase from zyx_data and a pair of
246
293
  3D-to-2D transfer functions named absorption_2d_to_3d_transfer_function and
@@ -251,15 +298,13 @@ def apply_inverse_transfer_function(
251
298
  ----------
252
299
  zyx_data : Tensor
253
300
  3D raw data, label-free defocus stack
254
- absorption_2d_to_3d_transfer_function : Tensor
255
- 3D-to-2D absorption transfer function, see calculate_transfer_function above
256
- phase_2d_to_3d_transfer_function : Tensor
257
- 3D-to-2D phase transfer function, see calculate_transfer_function above
258
- reconstruction_algorithm : Literal["Tikhonov", "TV"], optional
301
+ singular_system : Tuple[Tensor, Tensor, Tensor]
302
+ singular system of the transfer function bank
303
+ reconstruction_algorithm : Literal["Tikhonov", "TV"], optional
259
304
  "Tikhonov" or "TV", by default "Tikhonov"
260
305
  "TV" is not implemented.
261
306
  regularization_strength : float, optional
262
- regularization parameter, by default 1e-6
307
+ regularization parameter, by default 1e-3
263
308
  reg_p : float, optional
264
309
  TV-specific phase regularization parameter, by default 1e-6
265
310
  "TV" is not implemented.
@@ -268,7 +313,7 @@ def apply_inverse_transfer_function(
268
313
  "TV" is not implemented.
269
314
  bg_filter : bool, optional
270
315
  option for slow-varying 2D background normalization with uniform filter
271
- by default True
316
+ by default False
272
317
 
273
318
  Returns
274
319
  -------
@@ -281,66 +326,22 @@ def apply_inverse_transfer_function(
281
326
  NotImplementedError
282
327
  TV is not implemented
283
328
  """
284
- zyx_data_normalized = util.inten_normalization(
285
- zyx_data, bg_filter=bg_filter
286
- )
329
+ # Normalize
330
+ zyx = util.inten_normalization(zyx_data, bg_filter=bg_filter)
287
331
 
288
- zyx_data_hat = torch.fft.fft2(zyx_data_normalized, dim=(1, 2))
289
-
290
- # TODO AHA and b_vec calculations should be moved into tikhonov/tv calculations
291
- # TODO Reformulate to use filter.apply_filter_bank
292
- AHA = [
293
- torch.sum(torch.abs(absorption_2d_to_3d_transfer_function) ** 2, dim=0)
294
- + regularization_strength,
295
- torch.sum(
296
- torch.conj(absorption_2d_to_3d_transfer_function)
297
- * phase_2d_to_3d_transfer_function,
298
- dim=0,
299
- ),
300
- torch.sum(
301
- torch.conj(
302
- phase_2d_to_3d_transfer_function,
303
- )
304
- * absorption_2d_to_3d_transfer_function,
305
- dim=0,
306
- ),
307
- torch.sum(
308
- torch.abs(
309
- phase_2d_to_3d_transfer_function,
310
- )
311
- ** 2,
312
- dim=0,
313
- )
314
- + reg_p,
315
- ]
316
-
317
- b_vec = [
318
- torch.sum(
319
- torch.conj(absorption_2d_to_3d_transfer_function) * zyx_data_hat,
320
- dim=0,
321
- ),
322
- torch.sum(
323
- torch.conj(
324
- phase_2d_to_3d_transfer_function,
325
- )
326
- * zyx_data_hat,
327
- dim=0,
328
- ),
329
- ]
330
-
331
- # Deconvolution with Tikhonov regularization
332
+ # TODO Consider refactoring with vectorial transfer function SVD
332
333
  if reconstruction_algorithm == "Tikhonov":
333
- absorption, phase = util.dual_variable_tikhonov_deconvolution_2d(
334
- AHA, b_vec
334
+ print("Computing inverse filter")
335
+ U, S, Vh = singular_system
336
+ S_reg = S / (S**2 + regularization_strength)
337
+ sfyx_inverse_filter = torch.einsum(
338
+ "sj...,j...,jf...->fs...", U, S_reg, Vh
335
339
  )
336
340
 
341
+ absorption_yx, phase_yx = apply_filter_bank(sfyx_inverse_filter, zyx)
342
+
337
343
  # ADMM deconvolution with anisotropic TV regularization
338
344
  elif reconstruction_algorithm == "TV":
339
345
  raise NotImplementedError
340
- absorption, phase = util.dual_variable_admm_tv_deconv_2d(
341
- AHA, b_vec, rho=TV_rho_strength, itr=TV_iterations
342
- )
343
-
344
- phase -= torch.mean(phase)
345
346
 
346
- return absorption, phase
347
+ return absorption_yx, phase_yx
@@ -10,15 +10,95 @@ from waveorder.models import isotropic_fluorescent_thick_3d
10
10
  from waveorder.reconstruct import tikhonov_regularized_inverse_filter
11
11
  from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
12
12
 
13
+ """
14
+ Phase Thick 3D Model - Units and Conventions
15
+ =============================================
16
+
17
+ This module implements phase-from-defocus optical diffraction tomography (ODT)
18
+ for thick phase objects using the weak object transfer function (first Born
19
+ approximation).
20
+
21
+ Units Convention
22
+ ----------------
23
+ This model uses "cycles" as the fundamental unit for phase:
24
+ - 1 cycle = 2π radians = 1 wavelength of optical path difference
25
+
26
+ Phantom (input):
27
+ Phase in cycles per voxel = (Δn × z_pixel_size) / λ_medium
28
+ where:
29
+ - Δn = n_sample - n_media (refractive index difference)
30
+ - z_pixel_size = voxel thickness
31
+ - λ_medium = λ_vacuum / n_media (wavelength in medium)
32
+
33
+ Reconstruction (output):
34
+ Phase in cycles per voxel (same units as phantom)
35
+
36
+ Converting Between Units
37
+ ------------------------
38
+ From cycles to radians:
39
+ phase_radians = 2 * np.pi * phase_cycles
40
+
41
+ From cycles to refractive index difference:
42
+ wavelength_medium = wavelength_vacuum / n_media
43
+ delta_n = phase_cycles * wavelength_medium / z_pixel_size
44
+
45
+ From cycles to optical path length:
46
+ optical_path_length = phase_cycles * wavelength_medium
47
+
48
+ Physics Background
49
+ ------------------
50
+ The weak object approximation (first Born approximation) assumes:
51
+ 1. Small refractive index variations: |Δn| << n_media
52
+ 2. Weak scattering: no multiple scattering
53
+ 3. Linear relationship between object and measured intensity
54
+
55
+ Reference
56
+ ---------
57
+ J. M. Soto, J. A. Rodrigo, and T. Alieva, "Label-free quantitative 3D
58
+ tomographic imaging for partially coherent light microscopy,"
59
+ Opt. Express 25, 15699-15712 (2017)
60
+ """
61
+
13
62
 
14
63
  def generate_test_phantom(
15
64
  zyx_shape: tuple[int, int, int],
16
65
  yx_pixel_size: float,
17
66
  z_pixel_size: float,
67
+ wavelength_illumination: float,
18
68
  index_of_refraction_media: float,
19
69
  index_of_refraction_sample: float,
20
70
  sphere_radius: float,
21
71
  ) -> np.ndarray:
72
+ """
73
+ Generate a spherical phantom with phase in cycles per voxel.
74
+
75
+ Parameters
76
+ ----------
77
+ zyx_shape : tuple[int, int, int]
78
+ Shape of the 3D volume (Z, Y, X)
79
+ yx_pixel_size : float
80
+ Pixel size in transverse (Y, X) dimensions (length)
81
+ z_pixel_size : float
82
+ Pixel size in axial (Z) dimension (length)
83
+ wavelength_illumination : float
84
+ Wavelength of illumination light (length, same units as pixel sizes)
85
+ index_of_refraction_media : float
86
+ Refractive index of the surrounding medium
87
+ index_of_refraction_sample : float
88
+ Refractive index of the sphere
89
+ sphere_radius : float
90
+ Radius of the sphere (length, same units as pixel sizes)
91
+
92
+ Returns
93
+ -------
94
+ np.ndarray
95
+ 3D array of phase in cycles per voxel.
96
+ Units: (n_sample - n_media) × z_pixel_size / λ_medium [cycles/voxel]
97
+
98
+ Each voxel value represents the phase shift (in cycles) that light
99
+ acquires when passing through that voxel. This matches the units
100
+ returned by apply_inverse_transfer_function().
101
+ """
22
102
  sphere, _, _ = util.generate_sphere_target(
23
103
  zyx_shape,
24
104
  yx_pixel_size,
@@ -26,9 +106,13 @@ def generate_test_phantom(
26
106
  radius=sphere_radius,
27
107
  blur_size=2 * yx_pixel_size,
28
108
  )
29
- zyx_phase = sphere * (
30
- index_of_refraction_sample - index_of_refraction_media
31
- ) # refractive index increment
109
+
110
+ # Compute refractive index difference
111
+ delta_n = sphere * (index_of_refraction_sample - index_of_refraction_media)
112
+
113
+ # Convert to phase in cycles per voxel
114
+ wavelength_medium = wavelength_illumination / index_of_refraction_media
115
+ zyx_phase = delta_n * z_pixel_size / wavelength_medium
32
116
 
33
117
  return zyx_phase
34
118
 
@@ -234,7 +318,22 @@ def apply_inverse_transfer_function(
234
318
  Returns
235
319
  -------
236
320
  Tensor
237
- zyx_phase (radians)
321
+ zyx_phase : Phase in cycles per voxel
322
+ Units: (Δn × z_pixel_size) / λ_medium [cycles/voxel]
323
+
324
+ Each voxel represents the phase shift (in cycles) that light acquires
325
+ when passing through that voxel. This matches the units of the input
326
+ phantom from generate_test_phantom().
327
+
328
+ To convert to phase in radians:
329
+ phase_radians = 2 * np.pi * zyx_phase
330
+
331
+ To convert to refractive index difference:
332
+ wavelength_medium = wavelength_illumination / index_of_refraction_media
333
+ delta_n = zyx_phase * wavelength_medium / z_pixel_size
334
+
335
+ Note: One cycle corresponds to 2π radians of phase shift, or one
336
+ wavelength of optical path length difference.
238
337
 
239
338
  Raises
240
339
  ------