waveorder 2.0.0rc3__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,17 +1,19 @@
1
1
  from typing import Literal
2
2
 
3
+ import numpy as np
3
4
  import torch
4
5
  from torch import Tensor
5
6
 
6
- from waveorder import optics, util
7
+ from waveorder import optics, sampling, util
8
+ from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
7
9
 
8
10
 
9
11
  def generate_test_phantom(
10
- zyx_shape,
11
- yx_pixel_size,
12
- z_pixel_size,
13
- sphere_radius,
14
- ):
12
+ zyx_shape: tuple[int, int, int],
13
+ yx_pixel_size: float,
14
+ z_pixel_size: float,
15
+ sphere_radius: float,
16
+ ) -> Tensor:
15
17
  sphere, _, _ = util.generate_sphere_target(
16
18
  zyx_shape, yx_pixel_size, z_pixel_size, sphere_radius
17
19
  )
@@ -20,14 +22,57 @@ def generate_test_phantom(
20
22
 
21
23
 
22
24
  def calculate_transfer_function(
23
- zyx_shape,
24
- yx_pixel_size,
25
- z_pixel_size,
26
- wavelength_emission,
27
- z_padding,
28
- index_of_refraction_media,
29
- numerical_aperture_detection,
30
- ):
25
+ zyx_shape: tuple[int, int, int],
26
+ yx_pixel_size: float,
27
+ z_pixel_size: float,
28
+ wavelength_emission: float,
29
+ z_padding: int,
30
+ index_of_refraction_media: float,
31
+ numerical_aperture_detection: float,
32
+ ) -> Tensor:
33
+
34
+ transverse_nyquist = sampling.transverse_nyquist(
35
+ wavelength_emission,
36
+ numerical_aperture_detection, # ill = det for fluorescence
37
+ numerical_aperture_detection,
38
+ )
39
+ axial_nyquist = sampling.axial_nyquist(
40
+ wavelength_emission,
41
+ numerical_aperture_detection,
42
+ index_of_refraction_media,
43
+ )
44
+
45
+ yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
46
+ z_factor = int(np.ceil(z_pixel_size / axial_nyquist))
47
+
48
+ optical_transfer_function = _calculate_wrap_unsafe_transfer_function(
49
+ (
50
+ zyx_shape[0] * z_factor,
51
+ zyx_shape[1] * yx_factor,
52
+ zyx_shape[2] * yx_factor,
53
+ ),
54
+ yx_pixel_size / yx_factor,
55
+ z_pixel_size / z_factor,
56
+ wavelength_emission,
57
+ z_padding,
58
+ index_of_refraction_media,
59
+ numerical_aperture_detection,
60
+ )
61
+ zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:]
62
+ return sampling.nd_fourier_central_cuboid(
63
+ optical_transfer_function, zyx_out_shape
64
+ )
65
+
66
+
67
+ def _calculate_wrap_unsafe_transfer_function(
68
+ zyx_shape: tuple[int, int, int],
69
+ yx_pixel_size: float,
70
+ z_pixel_size: float,
71
+ wavelength_emission: float,
72
+ z_padding: int,
73
+ index_of_refraction_media: float,
74
+ numerical_aperture_detection: float,
75
+ ) -> Tensor:
31
76
  radial_frequencies = util.generate_radial_frequencies(
32
77
  zyx_shape[1:], yx_pixel_size
33
78
  )
@@ -63,25 +108,33 @@ def calculate_transfer_function(
63
108
  return optical_transfer_function
64
109
 
65
110
 
66
- def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale):
67
- arrays = [
68
- (torch.imag(optical_transfer_function), "Im(OTF)"),
69
- (torch.real(optical_transfer_function), "Re(OTF)"),
70
- ]
71
-
72
- for array in arrays:
73
- lim = 0.1 * torch.max(torch.abs(array[0]))
74
- viewer.add_image(
75
- torch.fft.ifftshift(array[0]).cpu().numpy(),
76
- name=array[1],
77
- colormap="bwr",
78
- contrast_limits=(-lim, lim),
79
- scale=1 / zyx_scale,
80
- )
81
- viewer.dims.order = (0, 1, 2)
111
+ def visualize_transfer_function(viewer, optical_transfer_function: Tensor, zyx_scale: tuple[float, float, float]) -> None:
112
+ add_transfer_function_to_viewer(
113
+ viewer,
114
+ torch.real(optical_transfer_function),
115
+ zyx_scale,
116
+ clim_factor=0.05,
117
+ )
82
118
 
83
119
 
84
- def apply_transfer_function(zyx_object, optical_transfer_function, z_padding):
120
+ def apply_transfer_function(
121
+ zyx_object: Tensor, optical_transfer_function: Tensor, z_padding: int, background: int = 10
122
+ ) -> Tensor:
123
+ """Simulate imaging by applying a transfer function
124
+
125
+ Parameters
126
+ ----------
127
+ zyx_object : torch.Tensor
128
+ optical_transfer_function : torch.Tensor
129
+ z_padding : int
130
+ background : int, optional
131
+ constant background counts added to each voxel, by default 10
132
+
133
+ Returns
134
+ -------
135
+ Simulated data : torch.Tensor
136
+
137
+ """
85
138
  if (
86
139
  zyx_object.shape[0] + 2 * z_padding
87
140
  != optical_transfer_function.shape[0]
@@ -99,7 +152,7 @@ def apply_transfer_function(zyx_object, optical_transfer_function, z_padding):
99
152
  zyx_data = zyx_obj_hat * optical_transfer_function
100
153
  data = torch.real(torch.fft.ifftn(zyx_data))
101
154
 
102
- data += 10 # Add a direct background
155
+ data += background # Add a direct background
103
156
  return data
104
157
 
105
158
 
@@ -111,7 +164,7 @@ def apply_inverse_transfer_function(
111
164
  regularization_strength: float = 1e-3,
112
165
  TV_rho_strength: float = 1e-3,
113
166
  TV_iterations: int = 10,
114
- ):
167
+ ) -> Tensor:
115
168
  """Reconstructs fluorescence density from zyx_data and
116
169
  an optical_transfer_function, providing options for z padding and
117
170
  reconstruction algorithms.
@@ -1,19 +1,19 @@
1
1
  from typing import Literal, Tuple
2
2
 
3
+ import numpy as np
3
4
  import torch
4
5
  from torch import Tensor
5
6
 
6
- from waveorder import optics, util
7
-
7
+ from waveorder import optics, sampling, util
8
8
 
9
9
  def generate_test_phantom(
10
- yx_shape,
11
- yx_pixel_size,
12
- wavelength_illumination,
13
- index_of_refraction_media,
14
- index_of_refraction_sample,
15
- sphere_radius,
16
- ):
10
+ yx_shape: Tuple[int, int],
11
+ yx_pixel_size: float,
12
+ wavelength_illumination: float,
13
+ index_of_refraction_media: float,
14
+ index_of_refraction_sample: float,
15
+ sphere_radius: float,
16
+ ) -> Tuple[Tensor, Tensor]:
17
17
  sphere, _, _ = util.generate_sphere_target(
18
18
  (3,) + yx_shape,
19
19
  yx_pixel_size,
@@ -28,21 +28,79 @@ def generate_test_phantom(
28
28
  / wavelength_illumination
29
29
  ) # phase in radians
30
30
 
31
- yx_absorption = 0.99 * sphere[1]
31
+ yx_absorption = 0.02 * sphere[1]
32
32
 
33
33
  return yx_absorption, yx_phase
34
34
 
35
35
 
36
36
  def calculate_transfer_function(
37
- yx_shape,
38
- yx_pixel_size,
39
- z_position_list,
40
- wavelength_illumination,
41
- index_of_refraction_media,
42
- numerical_aperture_illumination,
43
- numerical_aperture_detection,
44
- invert_phase_contrast=False,
45
- ):
37
+ yx_shape: Tuple[int, int],
38
+ yx_pixel_size: float,
39
+ z_position_list: list,
40
+ wavelength_illumination: float,
41
+ index_of_refraction_media: float,
42
+ numerical_aperture_illumination: float,
43
+ numerical_aperture_detection: float,
44
+ invert_phase_contrast: bool = False,
45
+ ) -> Tuple[Tensor, Tensor]:
46
+ transverse_nyquist = sampling.transverse_nyquist(
47
+ wavelength_illumination,
48
+ numerical_aperture_illumination,
49
+ numerical_aperture_detection,
50
+ )
51
+ yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
52
+
53
+ absorption_2d_to_3d_transfer_function, phase_2d_to_3d_transfer_function = (
54
+ _calculate_wrap_unsafe_transfer_function(
55
+ (
56
+ yx_shape[0] * yx_factor,
57
+ yx_shape[1] * yx_factor,
58
+ ),
59
+ yx_pixel_size / yx_factor,
60
+ z_position_list,
61
+ wavelength_illumination,
62
+ index_of_refraction_media,
63
+ numerical_aperture_illumination,
64
+ numerical_aperture_detection,
65
+ invert_phase_contrast=invert_phase_contrast,
66
+ )
67
+ )
68
+
69
+ absorption_2d_to_3d_transfer_function_out = torch.zeros(
70
+ (len(z_position_list),) + tuple(yx_shape), dtype=torch.complex64
71
+ )
72
+ phase_2d_to_3d_transfer_function_out = torch.zeros(
73
+ (len(z_position_list),) + tuple(yx_shape), dtype=torch.complex64
74
+ )
75
+
76
+ for z in range(len(z_position_list)):
77
+ absorption_2d_to_3d_transfer_function_out[z] = (
78
+ sampling.nd_fourier_central_cuboid(
79
+ absorption_2d_to_3d_transfer_function[z], yx_shape
80
+ )
81
+ )
82
+ phase_2d_to_3d_transfer_function_out[z] = (
83
+ sampling.nd_fourier_central_cuboid(
84
+ phase_2d_to_3d_transfer_function[z], yx_shape
85
+ )
86
+ )
87
+
88
+ return (
89
+ absorption_2d_to_3d_transfer_function_out,
90
+ phase_2d_to_3d_transfer_function_out,
91
+ )
92
+
93
+
94
+ def _calculate_wrap_unsafe_transfer_function(
95
+ yx_shape: Tuple[int, int],
96
+ yx_pixel_size: float,
97
+ z_position_list: list,
98
+ wavelength_illumination: float,
99
+ index_of_refraction_media: float,
100
+ numerical_aperture_illumination: float,
101
+ numerical_aperture_detection: float,
102
+ invert_phase_contrast: bool = False,
103
+ ) -> Tuple[Tensor, Tensor]:
46
104
  if invert_phase_contrast:
47
105
  z_position_list = torch.flip(torch.tensor(z_position_list), dims=(0,))
48
106
 
@@ -90,10 +148,14 @@ def calculate_transfer_function(
90
148
 
91
149
  def visualize_transfer_function(
92
150
  viewer,
93
- absorption_2d_to_3d_transfer_function,
94
- phase_2d_to_3d_transfer_function,
95
- ):
96
- # TODO: consider generalizing w/ phase_thick_3d.visualize_transfer_function
151
+ absorption_2d_to_3d_transfer_function: Tensor,
152
+ phase_2d_to_3d_transfer_function: Tensor,
153
+ ) -> None:
154
+ """Note: unlike other `visualize_transfer_function` calls, this transfer
155
+ function is a mixed 3D-to-2D transfer function, so it cannot reuse
156
+ util.add_transfer_function_to_viewer. If more 3D-to-2D transfer functions
157
+ are added, consider refactoring.
158
+ """
97
159
  arrays = [
98
160
  (torch.imag(absorption_2d_to_3d_transfer_function), "Im(absorb TF)"),
99
161
  (torch.real(absorption_2d_to_3d_transfer_function), "Re(absorb TF)"),
@@ -101,6 +163,28 @@ def visualize_transfer_function(
101
163
  (torch.real(phase_2d_to_3d_transfer_function), "Re(phase TF)"),
102
164
  ]
103
165
 
166
+ for array in arrays:
167
+ lim = 0.5 * torch.max(torch.abs(array[0]))
168
+ viewer.add_image(
169
+ torch.fft.ifftshift(array[0], dim=(1, 2)).cpu().numpy(),
170
+ name=array[1],
171
+ colormap="bwr",
172
+ contrast_limits=(-lim, lim),
173
+ scale=(1, 1, 1),
174
+ )
175
+ viewer.dims.order = (2, 0, 1)
176
+
177
+
178
+ def visualize_point_spread_function(
179
+ viewer,
180
+ absorption_2d_to_3d_transfer_function: Tensor,
181
+ phase_2d_to_3d_transfer_function: Tensor,
182
+ ) -> None:
183
+ arrays = [
184
+ (torch.fft.ifftn(absorption_2d_to_3d_transfer_function), "absorb PSF"),
185
+ (torch.fft.ifftn(phase_2d_to_3d_transfer_function), "phase PSF"),
186
+ ]
187
+
104
188
  for array in arrays:
105
189
  lim = 0.5 * torch.max(torch.abs(array[0]))
106
190
  viewer.add_image(
@@ -114,11 +198,11 @@ def visualize_transfer_function(
114
198
 
115
199
 
116
200
  def apply_transfer_function(
117
- yx_absorption,
118
- yx_phase,
119
- phase_2d_to_3d_transfer_function,
120
- absorption_2d_to_3d_transfer_function,
121
- ):
201
+ yx_absorption: Tensor,
202
+ yx_phase: Tensor,
203
+ phase_2d_to_3d_transfer_function: Tensor,
204
+ absorption_2d_to_3d_transfer_function: Tensor,
205
+ ) -> Tensor:
122
206
  # Very simple simulation, consider adding noise and bkg knobs
123
207
 
124
208
  # simulate absorbing object
@@ -155,7 +239,7 @@ def apply_inverse_transfer_function(
155
239
  TV_rho_strength: float = 1e-3,
156
240
  TV_iterations: int = 10,
157
241
  bg_filter: bool = True,
158
- ) -> Tuple[Tensor]:
242
+ ) -> Tuple[Tensor, Tensor]:
159
243
  """Reconstructs absorption and phase from zyx_data and a pair of
160
244
  3D-to-2D transfer functions named absorption_2d_to_3d_transfer_function and
161
245
  phase_2d_to_3d_transfer_function, providing options for reconstruction
@@ -4,19 +4,19 @@ import numpy as np
4
4
  import torch
5
5
  from torch import Tensor
6
6
 
7
- from waveorder import optics, util
7
+ from waveorder import optics, sampling, util
8
8
  from waveorder.models import isotropic_fluorescent_thick_3d
9
+ from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
9
10
 
10
11
 
11
12
  def generate_test_phantom(
12
- zyx_shape,
13
- yx_pixel_size,
14
- z_pixel_size,
15
- wavelength_illumination,
16
- index_of_refraction_media,
17
- index_of_refraction_sample,
18
- sphere_radius,
19
- ):
13
+ zyx_shape: tuple[int, int, int],
14
+ yx_pixel_size: float,
15
+ z_pixel_size: float,
16
+ index_of_refraction_media: float,
17
+ index_of_refraction_sample: float,
18
+ sphere_radius: float,
19
+ ) -> np.ndarray:
20
20
  sphere, _, _ = util.generate_sphere_target(
21
21
  zyx_shape,
22
22
  yx_pixel_size,
@@ -24,27 +24,78 @@ def generate_test_phantom(
24
24
  radius=sphere_radius,
25
25
  blur_size=2 * yx_pixel_size,
26
26
  )
27
- zyx_phase = (
28
- sphere
29
- * (index_of_refraction_sample - index_of_refraction_media)
30
- * z_pixel_size
31
- / wavelength_illumination
32
- ) # phase in radians
27
+ zyx_phase = sphere * (
28
+ index_of_refraction_sample - index_of_refraction_media
29
+ ) # refractive index increment
33
30
 
34
31
  return zyx_phase
35
32
 
36
33
 
37
34
  def calculate_transfer_function(
38
- zyx_shape,
39
- yx_pixel_size,
40
- z_pixel_size,
41
- wavelength_illumination,
42
- z_padding,
43
- index_of_refraction_media,
44
- numerical_aperture_illumination,
45
- numerical_aperture_detection,
46
- invert_phase_contrast=False,
47
- ):
35
+ zyx_shape: tuple[int, int, int],
36
+ yx_pixel_size: float,
37
+ z_pixel_size: float,
38
+ wavelength_illumination: float,
39
+ z_padding: int,
40
+ index_of_refraction_media: float,
41
+ numerical_aperture_illumination: float,
42
+ numerical_aperture_detection: float,
43
+ invert_phase_contrast: bool = False,
44
+ ) -> tuple[np.ndarray, np.ndarray]:
45
+ transverse_nyquist = sampling.transverse_nyquist(
46
+ wavelength_illumination,
47
+ numerical_aperture_illumination,
48
+ numerical_aperture_detection,
49
+ )
50
+ axial_nyquist = sampling.axial_nyquist(
51
+ wavelength_illumination,
52
+ numerical_aperture_detection,
53
+ index_of_refraction_media,
54
+ )
55
+
56
+ yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
57
+ z_factor = int(np.ceil(z_pixel_size / axial_nyquist))
58
+
59
+ real_potential_transfer_function, imag_potential_transfer_function = (
60
+ _calculate_wrap_unsafe_transfer_function(
61
+ (
62
+ zyx_shape[0] * z_factor,
63
+ zyx_shape[1] * yx_factor,
64
+ zyx_shape[2] * yx_factor,
65
+ ),
66
+ yx_pixel_size / yx_factor,
67
+ z_pixel_size / z_factor,
68
+ wavelength_illumination,
69
+ z_padding,
70
+ index_of_refraction_media,
71
+ numerical_aperture_illumination,
72
+ numerical_aperture_detection,
73
+ invert_phase_contrast=invert_phase_contrast,
74
+ )
75
+ )
76
+
77
+ zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:]
78
+ return (
79
+ sampling.nd_fourier_central_cuboid(
80
+ real_potential_transfer_function, zyx_out_shape
81
+ ),
82
+ sampling.nd_fourier_central_cuboid(
83
+ imag_potential_transfer_function, zyx_out_shape
84
+ ),
85
+ )
86
+
87
+
88
+ def _calculate_wrap_unsafe_transfer_function(
89
+ zyx_shape: tuple[int, int, int],
90
+ yx_pixel_size: float,
91
+ z_pixel_size: float,
92
+ wavelength_illumination: float,
93
+ z_padding: int,
94
+ index_of_refraction_media: float,
95
+ numerical_aperture_illumination: float,
96
+ numerical_aperture_detection: float,
97
+ invert_phase_contrast: bool = False,
98
+ ) -> tuple[np.ndarray, np.ndarray]:
48
99
  radial_frequencies = util.generate_radial_frequencies(
49
100
  zyx_shape[1:], yx_pixel_size
50
101
  )
@@ -76,6 +127,7 @@ def calculate_transfer_function(
76
127
  det_pupil,
77
128
  wavelength_illumination / index_of_refraction_media,
78
129
  z_position_list,
130
+ axially_even=False,
79
131
  )
80
132
 
81
133
  (
@@ -95,37 +147,39 @@ def calculate_transfer_function(
95
147
 
96
148
  def visualize_transfer_function(
97
149
  viewer,
98
- real_potential_transfer_function,
99
- imag_potential_transfer_function,
100
- zyx_scale,
101
- ):
102
- # TODO: consider generalizing w/ phase2Dto3D.visualize_TF
103
- arrays = [
104
- (torch.real(imag_potential_transfer_function), "Re(imag pot. TF)"),
105
- (torch.imag(imag_potential_transfer_function), "Im(imag pot. TF)"),
106
- (torch.real(real_potential_transfer_function), "Re(real pot. TF)"),
107
- (torch.imag(real_potential_transfer_function), "Im(real pot. TF)"),
108
- ]
109
-
110
- for array in arrays:
111
- lim = 0.5 * torch.max(torch.abs(array[0]))
112
- viewer.add_image(
113
- torch.fft.ifftshift(array[0]).cpu().numpy(),
114
- name=array[1],
115
- colormap="bwr",
116
- contrast_limits=(-lim, lim),
117
- scale=1 / zyx_scale,
118
- )
119
- viewer.dims.order = (0, 1, 2)
150
+ real_potential_transfer_function: np.ndarray,
151
+ imag_potential_transfer_function: np.ndarray,
152
+ zyx_scale: tuple[float, float, float],
153
+ ) -> None:
154
+ add_transfer_function_to_viewer(
155
+ viewer,
156
+ imag_potential_transfer_function,
157
+ zyx_scale,
158
+ layer_name="Imag pot. TF",
159
+ )
160
+
161
+ add_transfer_function_to_viewer(
162
+ viewer,
163
+ real_potential_transfer_function,
164
+ zyx_scale,
165
+ layer_name="Real pot. TF",
166
+ )
120
167
 
121
168
 
122
169
  def apply_transfer_function(
123
- zyx_object, real_potential_transfer_function, z_padding
124
- ):
170
+ zyx_object: np.ndarray, real_potential_transfer_function: np.ndarray, z_padding: int, brightness: float
171
+ ) -> np.ndarray:
125
172
  # This simplified forward model only handles phase, so it resuses the fluorescence forward model
126
173
  # TODO: extend to absorption
127
- return isotropic_fluorescent_thick_3d.apply_transfer_function(
128
- zyx_object, real_potential_transfer_function, z_padding
174
+ return (
175
+ isotropic_fluorescent_thick_3d.apply_transfer_function(
176
+ zyx_object,
177
+ real_potential_transfer_function,
178
+ z_padding,
179
+ background=0,
180
+ )
181
+ * brightness
182
+ + brightness
129
183
  )
130
184
 
131
185
 
@@ -134,14 +188,12 @@ def apply_inverse_transfer_function(
134
188
  real_potential_transfer_function: Tensor,
135
189
  imaginary_potential_transfer_function: Tensor,
136
190
  z_padding: int,
137
- z_pixel_size: float, # TODO: MOVE THIS PARAM TO OTF? (leaky param)
138
- wavelength_illumination: float, # TOOD: MOVE THIS PARAM TO OTF? (leaky param)
139
191
  absorption_ratio: float = 0.0,
140
192
  reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
141
193
  regularization_strength: float = 1e-3,
142
194
  TV_rho_strength: float = 1e-3,
143
195
  TV_iterations: int = 10,
144
- ):
196
+ ) -> Tensor:
145
197
  """Reconstructs 3D phase from labelfree defocus zyx_data and a pair of
146
198
  complex 3D transfer functions real_potential_transfer_function and
147
199
  imag_potential_transfer_function, providing options for reconstruction
@@ -158,14 +210,6 @@ def apply_inverse_transfer_function(
158
210
  z_padding : int
159
211
  Padding for axial dimension. Use zero for defocus stacks that
160
212
  extend ~3 PSF widths beyond the sample. Pad by ~3 PSF widths otherwise.
161
- z_pixel_size : float
162
- spacing between axial samples in sample space
163
- units must be consistent with wavelength_illumination
164
- TODO: move this leaky parameter to calculate_transfer_function
165
- wavelength_illumination : float,
166
- illumination wavelength
167
- units must be consistent with z_pixel_size
168
- TODO: move this leaky parameter to calculate_transfer_function
169
213
  absorption_ratio : float, optional,
170
214
  Absorption-to-phase ratio in the sample.
171
215
  Use default 0 for purely phase objects.
@@ -223,4 +267,4 @@ def apply_inverse_transfer_function(
223
267
  if z_padding != 0:
224
268
  f_real = f_real[z_padding:-z_padding]
225
269
 
226
- return f_real * z_pixel_size / 4 / np.pi * wavelength_illumination
270
+ return f_real