waveorder 2.2.0rc0__py3-none-any.whl → 2.2.1b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,17 +1,21 @@
1
1
  from typing import Literal
2
2
 
3
+ import numpy as np
3
4
  import torch
4
5
  from torch import Tensor
5
6
 
6
- from waveorder import optics, util
7
+ from waveorder import optics, sampling, util
8
+ from waveorder.filter import apply_filter_bank
9
+ from waveorder.reconstruct import tikhonov_regularized_inverse_filter
10
+ from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
7
11
 
8
12
 
9
13
  def generate_test_phantom(
10
- zyx_shape,
11
- yx_pixel_size,
12
- z_pixel_size,
13
- sphere_radius,
14
- ):
14
+ zyx_shape: tuple[int, int, int],
15
+ yx_pixel_size: float,
16
+ z_pixel_size: float,
17
+ sphere_radius: float,
18
+ ) -> Tensor:
15
19
  sphere, _, _ = util.generate_sphere_target(
16
20
  zyx_shape, yx_pixel_size, z_pixel_size, sphere_radius
17
21
  )
@@ -20,14 +24,56 @@ def generate_test_phantom(
20
24
 
21
25
 
22
26
  def calculate_transfer_function(
23
- zyx_shape,
24
- yx_pixel_size,
25
- z_pixel_size,
26
- wavelength_emission,
27
- z_padding,
28
- index_of_refraction_media,
29
- numerical_aperture_detection,
30
- ):
27
+ zyx_shape: tuple[int, int, int],
28
+ yx_pixel_size: float,
29
+ z_pixel_size: float,
30
+ wavelength_emission: float,
31
+ z_padding: int,
32
+ index_of_refraction_media: float,
33
+ numerical_aperture_detection: float,
34
+ ) -> Tensor:
35
+ transverse_nyquist = sampling.transverse_nyquist(
36
+ wavelength_emission,
37
+ numerical_aperture_detection, # ill = det for fluorescence
38
+ numerical_aperture_detection,
39
+ )
40
+ axial_nyquist = sampling.axial_nyquist(
41
+ wavelength_emission,
42
+ numerical_aperture_detection,
43
+ index_of_refraction_media,
44
+ )
45
+
46
+ yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
47
+ z_factor = int(np.ceil(z_pixel_size / axial_nyquist))
48
+
49
+ optical_transfer_function = _calculate_wrap_unsafe_transfer_function(
50
+ (
51
+ zyx_shape[0] * z_factor,
52
+ zyx_shape[1] * yx_factor,
53
+ zyx_shape[2] * yx_factor,
54
+ ),
55
+ yx_pixel_size / yx_factor,
56
+ z_pixel_size / z_factor,
57
+ wavelength_emission,
58
+ z_padding,
59
+ index_of_refraction_media,
60
+ numerical_aperture_detection,
61
+ )
62
+ zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:]
63
+ return sampling.nd_fourier_central_cuboid(
64
+ optical_transfer_function, zyx_out_shape
65
+ )
66
+
67
+
68
+ def _calculate_wrap_unsafe_transfer_function(
69
+ zyx_shape: tuple[int, int, int],
70
+ yx_pixel_size: float,
71
+ z_pixel_size: float,
72
+ wavelength_emission: float,
73
+ z_padding: int,
74
+ index_of_refraction_media: float,
75
+ numerical_aperture_detection: float,
76
+ ) -> Tensor:
31
77
  radial_frequencies = util.generate_radial_frequencies(
32
78
  zyx_shape[1:], yx_pixel_size
33
79
  )
@@ -63,27 +109,25 @@ def calculate_transfer_function(
63
109
  return optical_transfer_function
64
110
 
65
111
 
66
- def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale):
67
- arrays = [
68
- (torch.imag(optical_transfer_function), "Im(OTF)"),
69
- (torch.real(optical_transfer_function), "Re(OTF)"),
70
- ]
71
-
72
- for array in arrays:
73
- lim = 0.1 * torch.max(torch.abs(array[0]))
74
- viewer.add_image(
75
- torch.fft.ifftshift(array[0]).cpu().numpy(),
76
- name=array[1],
77
- colormap="bwr",
78
- contrast_limits=(-lim, lim),
79
- scale=1 / zyx_scale,
80
- )
81
- viewer.dims.order = (0, 1, 2)
112
+ def visualize_transfer_function(
113
+ viewer,
114
+ optical_transfer_function: Tensor,
115
+ zyx_scale: tuple[float, float, float],
116
+ ) -> None:
117
+ add_transfer_function_to_viewer(
118
+ viewer,
119
+ torch.real(optical_transfer_function),
120
+ zyx_scale,
121
+ clim_factor=0.05,
122
+ )
82
123
 
83
124
 
84
125
  def apply_transfer_function(
85
- zyx_object, optical_transfer_function, z_padding, background=10
86
- ):
126
+ zyx_object: Tensor,
127
+ optical_transfer_function: Tensor,
128
+ z_padding: int,
129
+ background: int = 10,
130
+ ) -> Tensor:
87
131
  """Simulate imaging by applying a transfer function
88
132
 
89
133
  Parameters
@@ -97,7 +141,7 @@ def apply_transfer_function(
97
141
  Returns
98
142
  -------
99
143
  Simulated data : torch.Tensor
100
-
144
+
101
145
  """
102
146
  if (
103
147
  zyx_object.shape[0] + 2 * z_padding
@@ -128,7 +172,7 @@ def apply_inverse_transfer_function(
128
172
  regularization_strength: float = 1e-3,
129
173
  TV_rho_strength: float = 1e-3,
130
174
  TV_iterations: int = 10,
131
- ):
175
+ ) -> Tensor:
132
176
  """Reconstructs fluorescence density from zyx_data and
133
177
  an optical_transfer_function, providing options for z padding and
134
178
  reconstruction algorithms.
@@ -169,12 +213,15 @@ def apply_inverse_transfer_function(
169
213
 
170
214
  # Reconstruct
171
215
  if reconstruction_algorithm == "Tikhonov":
172
- f_real = util.single_variable_tikhonov_deconvolution_3D(
173
- zyx_padded,
174
- optical_transfer_function,
175
- reg_re=regularization_strength,
216
+ inverse_filter = tikhonov_regularized_inverse_filter(
217
+ optical_transfer_function, regularization_strength
176
218
  )
177
219
 
220
+ # [None]s and [0] are for applying a 1x1 "bank" of filters.
221
+ # For further uniformity, consider returning (1, Z, Y, X)
222
+ f_real = apply_filter_bank(
223
+ inverse_filter[None, None], zyx_padded[None]
224
+ )[0]
178
225
  elif reconstruction_algorithm == "TV":
179
226
  raise NotImplementedError
180
227
  f_real = util.single_variable_admm_tv_deconvolution_3D(
@@ -1,19 +1,20 @@
1
1
  from typing import Literal, Tuple
2
2
 
3
+ import numpy as np
3
4
  import torch
4
5
  from torch import Tensor
5
6
 
6
- from waveorder import optics, util
7
+ from waveorder import optics, sampling, util
7
8
 
8
9
 
9
10
  def generate_test_phantom(
10
- yx_shape,
11
- yx_pixel_size,
12
- wavelength_illumination,
13
- index_of_refraction_media,
14
- index_of_refraction_sample,
15
- sphere_radius,
16
- ):
11
+ yx_shape: Tuple[int, int],
12
+ yx_pixel_size: float,
13
+ wavelength_illumination: float,
14
+ index_of_refraction_media: float,
15
+ index_of_refraction_sample: float,
16
+ sphere_radius: float,
17
+ ) -> Tuple[Tensor, Tensor]:
17
18
  sphere, _, _ = util.generate_sphere_target(
18
19
  (3,) + yx_shape,
19
20
  yx_pixel_size,
@@ -34,15 +35,74 @@ def generate_test_phantom(
34
35
 
35
36
 
36
37
  def calculate_transfer_function(
37
- yx_shape,
38
- yx_pixel_size,
39
- z_position_list,
40
- wavelength_illumination,
41
- index_of_refraction_media,
42
- numerical_aperture_illumination,
43
- numerical_aperture_detection,
44
- invert_phase_contrast=False,
45
- ):
38
+ yx_shape: Tuple[int, int],
39
+ yx_pixel_size: float,
40
+ z_position_list: list,
41
+ wavelength_illumination: float,
42
+ index_of_refraction_media: float,
43
+ numerical_aperture_illumination: float,
44
+ numerical_aperture_detection: float,
45
+ invert_phase_contrast: bool = False,
46
+ ) -> Tuple[Tensor, Tensor]:
47
+ transverse_nyquist = sampling.transverse_nyquist(
48
+ wavelength_illumination,
49
+ numerical_aperture_illumination,
50
+ numerical_aperture_detection,
51
+ )
52
+ yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
53
+
54
+ (
55
+ absorption_2d_to_3d_transfer_function,
56
+ phase_2d_to_3d_transfer_function,
57
+ ) = _calculate_wrap_unsafe_transfer_function(
58
+ (
59
+ yx_shape[0] * yx_factor,
60
+ yx_shape[1] * yx_factor,
61
+ ),
62
+ yx_pixel_size / yx_factor,
63
+ z_position_list,
64
+ wavelength_illumination,
65
+ index_of_refraction_media,
66
+ numerical_aperture_illumination,
67
+ numerical_aperture_detection,
68
+ invert_phase_contrast=invert_phase_contrast,
69
+ )
70
+
71
+ absorption_2d_to_3d_transfer_function_out = torch.zeros(
72
+ (len(z_position_list),) + tuple(yx_shape), dtype=torch.complex64
73
+ )
74
+ phase_2d_to_3d_transfer_function_out = torch.zeros(
75
+ (len(z_position_list),) + tuple(yx_shape), dtype=torch.complex64
76
+ )
77
+
78
+ for z in range(len(z_position_list)):
79
+ absorption_2d_to_3d_transfer_function_out[z] = (
80
+ sampling.nd_fourier_central_cuboid(
81
+ absorption_2d_to_3d_transfer_function[z], yx_shape
82
+ )
83
+ )
84
+ phase_2d_to_3d_transfer_function_out[z] = (
85
+ sampling.nd_fourier_central_cuboid(
86
+ phase_2d_to_3d_transfer_function[z], yx_shape
87
+ )
88
+ )
89
+
90
+ return (
91
+ absorption_2d_to_3d_transfer_function_out,
92
+ phase_2d_to_3d_transfer_function_out,
93
+ )
94
+
95
+
96
+ def _calculate_wrap_unsafe_transfer_function(
97
+ yx_shape: Tuple[int, int],
98
+ yx_pixel_size: float,
99
+ z_position_list: list,
100
+ wavelength_illumination: float,
101
+ index_of_refraction_media: float,
102
+ numerical_aperture_illumination: float,
103
+ numerical_aperture_detection: float,
104
+ invert_phase_contrast: bool = False,
105
+ ) -> Tuple[Tensor, Tensor]:
46
106
  if invert_phase_contrast:
47
107
  z_position_list = torch.flip(torch.tensor(z_position_list), dims=(0,))
48
108
 
@@ -90,10 +150,14 @@ def calculate_transfer_function(
90
150
 
91
151
  def visualize_transfer_function(
92
152
  viewer,
93
- absorption_2d_to_3d_transfer_function,
94
- phase_2d_to_3d_transfer_function,
95
- ):
96
- # TODO: consider generalizing w/ phase_thick_3d.visualize_transfer_function
153
+ absorption_2d_to_3d_transfer_function: Tensor,
154
+ phase_2d_to_3d_transfer_function: Tensor,
155
+ ) -> None:
156
+ """Note: unlike other `visualize_transfer_function` calls, this transfer
157
+ function is a mixed 3D-to-2D transfer function, so it cannot reuse
158
+ util.add_transfer_function_to_viewer. If more 3D-to-2D transfer functions
159
+ are added, consider refactoring.
160
+ """
97
161
  arrays = [
98
162
  (torch.imag(absorption_2d_to_3d_transfer_function), "Im(absorb TF)"),
99
163
  (torch.real(absorption_2d_to_3d_transfer_function), "Re(absorb TF)"),
@@ -110,14 +174,14 @@ def visualize_transfer_function(
110
174
  contrast_limits=(-lim, lim),
111
175
  scale=(1, 1, 1),
112
176
  )
113
- viewer.dims.order = (0, 1, 2)
177
+ viewer.dims.order = (2, 0, 1)
114
178
 
115
179
 
116
180
  def visualize_point_spread_function(
117
181
  viewer,
118
- absorption_2d_to_3d_transfer_function,
119
- phase_2d_to_3d_transfer_function,
120
- ):
182
+ absorption_2d_to_3d_transfer_function: Tensor,
183
+ phase_2d_to_3d_transfer_function: Tensor,
184
+ ) -> None:
121
185
  arrays = [
122
186
  (torch.fft.ifftn(absorption_2d_to_3d_transfer_function), "absorb PSF"),
123
187
  (torch.fft.ifftn(phase_2d_to_3d_transfer_function), "phase PSF"),
@@ -136,11 +200,11 @@ def visualize_point_spread_function(
136
200
 
137
201
 
138
202
  def apply_transfer_function(
139
- yx_absorption,
140
- yx_phase,
141
- phase_2d_to_3d_transfer_function,
142
- absorption_2d_to_3d_transfer_function,
143
- ):
203
+ yx_absorption: Tensor,
204
+ yx_phase: Tensor,
205
+ phase_2d_to_3d_transfer_function: Tensor,
206
+ absorption_2d_to_3d_transfer_function: Tensor,
207
+ ) -> Tensor:
144
208
  # Very simple simulation, consider adding noise and bkg knobs
145
209
 
146
210
  # simulate absorbing object
@@ -177,7 +241,7 @@ def apply_inverse_transfer_function(
177
241
  TV_rho_strength: float = 1e-3,
178
242
  TV_iterations: int = 10,
179
243
  bg_filter: bool = True,
180
- ) -> Tuple[Tensor]:
244
+ ) -> Tuple[Tensor, Tensor]:
181
245
  """Reconstructs absorption and phase from zyx_data and a pair of
182
246
  3D-to-2D transfer functions named absorption_2d_to_3d_transfer_function and
183
247
  phase_2d_to_3d_transfer_function, providing options for reconstruction
@@ -224,6 +288,7 @@ def apply_inverse_transfer_function(
224
288
  zyx_data_hat = torch.fft.fft2(zyx_data_normalized, dim=(1, 2))
225
289
 
226
290
  # TODO AHA and b_vec calculations should be moved into tikhonov/tv calculations
291
+ # TODO Reformulate to use filter.apply_filter_bank
227
292
  AHA = [
228
293
  torch.sum(torch.abs(absorption_2d_to_3d_transfer_function) ** 2, dim=0)
229
294
  + regularization_strength,
@@ -4,18 +4,21 @@ import numpy as np
4
4
  import torch
5
5
  from torch import Tensor
6
6
 
7
- from waveorder import optics, util
7
+ from waveorder import optics, sampling, util
8
+ from waveorder.filter import apply_filter_bank
8
9
  from waveorder.models import isotropic_fluorescent_thick_3d
10
+ from waveorder.reconstruct import tikhonov_regularized_inverse_filter
11
+ from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
9
12
 
10
13
 
11
14
  def generate_test_phantom(
12
- zyx_shape,
13
- yx_pixel_size,
14
- z_pixel_size,
15
- index_of_refraction_media,
16
- index_of_refraction_sample,
17
- sphere_radius,
18
- ):
15
+ zyx_shape: tuple[int, int, int],
16
+ yx_pixel_size: float,
17
+ z_pixel_size: float,
18
+ index_of_refraction_media: float,
19
+ index_of_refraction_sample: float,
20
+ sphere_radius: float,
21
+ ) -> np.ndarray:
19
22
  sphere, _, _ = util.generate_sphere_target(
20
23
  zyx_shape,
21
24
  yx_pixel_size,
@@ -31,16 +34,71 @@ def generate_test_phantom(
31
34
 
32
35
 
33
36
  def calculate_transfer_function(
34
- zyx_shape,
35
- yx_pixel_size,
36
- z_pixel_size,
37
- wavelength_illumination,
38
- z_padding,
39
- index_of_refraction_media,
40
- numerical_aperture_illumination,
41
- numerical_aperture_detection,
42
- invert_phase_contrast=False,
43
- ):
37
+ zyx_shape: tuple[int, int, int],
38
+ yx_pixel_size: float,
39
+ z_pixel_size: float,
40
+ wavelength_illumination: float,
41
+ z_padding: int,
42
+ index_of_refraction_media: float,
43
+ numerical_aperture_illumination: float,
44
+ numerical_aperture_detection: float,
45
+ invert_phase_contrast: bool = False,
46
+ ) -> tuple[np.ndarray, np.ndarray]:
47
+ transverse_nyquist = sampling.transverse_nyquist(
48
+ wavelength_illumination,
49
+ numerical_aperture_illumination,
50
+ numerical_aperture_detection,
51
+ )
52
+ axial_nyquist = sampling.axial_nyquist(
53
+ wavelength_illumination,
54
+ numerical_aperture_detection,
55
+ index_of_refraction_media,
56
+ )
57
+
58
+ yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
59
+ z_factor = int(np.ceil(z_pixel_size / axial_nyquist))
60
+
61
+ (
62
+ real_potential_transfer_function,
63
+ imag_potential_transfer_function,
64
+ ) = _calculate_wrap_unsafe_transfer_function(
65
+ (
66
+ zyx_shape[0] * z_factor,
67
+ zyx_shape[1] * yx_factor,
68
+ zyx_shape[2] * yx_factor,
69
+ ),
70
+ yx_pixel_size / yx_factor,
71
+ z_pixel_size / z_factor,
72
+ wavelength_illumination,
73
+ z_padding,
74
+ index_of_refraction_media,
75
+ numerical_aperture_illumination,
76
+ numerical_aperture_detection,
77
+ invert_phase_contrast=invert_phase_contrast,
78
+ )
79
+
80
+ zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:]
81
+ return (
82
+ sampling.nd_fourier_central_cuboid(
83
+ real_potential_transfer_function, zyx_out_shape
84
+ ),
85
+ sampling.nd_fourier_central_cuboid(
86
+ imag_potential_transfer_function, zyx_out_shape
87
+ ),
88
+ )
89
+
90
+
91
+ def _calculate_wrap_unsafe_transfer_function(
92
+ zyx_shape: tuple[int, int, int],
93
+ yx_pixel_size: float,
94
+ z_pixel_size: float,
95
+ wavelength_illumination: float,
96
+ z_padding: int,
97
+ index_of_refraction_media: float,
98
+ numerical_aperture_illumination: float,
99
+ numerical_aperture_detection: float,
100
+ invert_phase_contrast: bool = False,
101
+ ) -> tuple[np.ndarray, np.ndarray]:
44
102
  radial_frequencies = util.generate_radial_frequencies(
45
103
  zyx_shape[1:], yx_pixel_size
46
104
  )
@@ -72,6 +130,7 @@ def calculate_transfer_function(
72
130
  det_pupil,
73
131
  wavelength_illumination / index_of_refraction_media,
74
132
  z_position_list,
133
+ axially_even=False,
75
134
  )
76
135
 
77
136
  (
@@ -91,33 +150,31 @@ def calculate_transfer_function(
91
150
 
92
151
  def visualize_transfer_function(
93
152
  viewer,
94
- real_potential_transfer_function,
95
- imag_potential_transfer_function,
96
- zyx_scale,
97
- ):
98
- # TODO: consider generalizing w/ phase2Dto3D.visualize_TF
99
- arrays = [
100
- (torch.real(imag_potential_transfer_function), "Re(imag pot. TF)"),
101
- (torch.imag(imag_potential_transfer_function), "Im(imag pot. TF)"),
102
- (torch.real(real_potential_transfer_function), "Re(real pot. TF)"),
103
- (torch.imag(real_potential_transfer_function), "Im(real pot. TF)"),
104
- ]
105
-
106
- for array in arrays:
107
- lim = 0.5 * torch.max(torch.abs(array[0]))
108
- viewer.add_image(
109
- torch.fft.ifftshift(array[0]).cpu().numpy(),
110
- name=array[1],
111
- colormap="bwr",
112
- contrast_limits=(-lim, lim),
113
- scale=1 / zyx_scale,
114
- )
115
- viewer.dims.order = (0, 1, 2)
153
+ real_potential_transfer_function: np.ndarray,
154
+ imag_potential_transfer_function: np.ndarray,
155
+ zyx_scale: tuple[float, float, float],
156
+ ) -> None:
157
+ add_transfer_function_to_viewer(
158
+ viewer,
159
+ imag_potential_transfer_function,
160
+ zyx_scale,
161
+ layer_name="Imag pot. TF",
162
+ )
163
+
164
+ add_transfer_function_to_viewer(
165
+ viewer,
166
+ real_potential_transfer_function,
167
+ zyx_scale,
168
+ layer_name="Real pot. TF",
169
+ )
116
170
 
117
171
 
118
172
  def apply_transfer_function(
119
- zyx_object, real_potential_transfer_function, z_padding, brightness
120
- ):
173
+ zyx_object: np.ndarray,
174
+ real_potential_transfer_function: np.ndarray,
175
+ z_padding: int,
176
+ brightness: float,
177
+ ) -> np.ndarray:
121
178
  # This simplified forward model only handles phase, so it resuses the fluorescence forward model
122
179
  # TODO: extend to absorption
123
180
  return (
@@ -142,7 +199,7 @@ def apply_inverse_transfer_function(
142
199
  regularization_strength: float = 1e-3,
143
200
  TV_rho_strength: float = 1e-3,
144
201
  TV_iterations: int = 10,
145
- ):
202
+ ) -> Tensor:
146
203
  """Reconstructs 3D phase from labelfree defocus zyx_data and a pair of
147
204
  complex 3D transfer functions real_potential_transfer_function and
148
205
  imag_potential_transfer_function, providing options for reconstruction
@@ -198,10 +255,14 @@ def apply_inverse_transfer_function(
198
255
 
199
256
  # Reconstruct
200
257
  if reconstruction_algorithm == "Tikhonov":
201
- f_real = util.single_variable_tikhonov_deconvolution_3D(
202
- zyx, effective_transfer_function, reg_re=regularization_strength
258
+ inverse_filter = tikhonov_regularized_inverse_filter(
259
+ effective_transfer_function, regularization_strength
203
260
  )
204
261
 
262
+ # [None]s and [0] are for applying a 1x1 "bank" of filters.
263
+ # For further uniformity, consider returning (1, Z, Y, X)
264
+ f_real = apply_filter_bank(inverse_filter[None, None], zyx[None])[0]
265
+
205
266
  elif reconstruction_algorithm == "TV":
206
267
  raise NotImplementedError
207
268
  f_real = util.single_variable_admm_tv_deconvolution_3D(