waveorder 2.1.0__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
waveorder/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '2.1.0'
16
- __version_tuple__ = version_tuple = (2, 1, 0)
15
+ __version__ = version = '2.2.0'
16
+ __version_tuple__ = version_tuple = (2, 2, 0)
waveorder/focus.py CHANGED
@@ -3,6 +3,7 @@ from typing import Literal, Optional
3
3
  from waveorder import util
4
4
  import matplotlib.pyplot as plt
5
5
  import numpy as np
6
+ import warnings
6
7
 
7
8
 
8
9
  def focus_from_transverse_band(
@@ -60,10 +61,19 @@ def focus_from_transverse_band(
60
61
  >>> slice = focus_from_transverse_band(zyx_array, NA_det=0.55, lambda_ill=0.532, pixel_size=6.5/20)
61
62
  >>> in_focus_data = data[slice,:,:]
62
63
  """
63
- minmaxfunc = _check_focus_inputs(
64
- zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions, mode
64
+ minmaxfunc = _mode_to_minmaxfunc(mode)
65
+
66
+ _check_focus_inputs(
67
+ zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
65
68
  )
66
69
 
70
+ # Check for single slice
71
+ if zyx_array.shape[0] == 1:
72
+ warnings.warn(
73
+ "The dataset only contained a single slice. Returning trivial slice index = 0."
74
+ )
75
+ return 0
76
+
67
77
  # Calculate coordinates
68
78
  _, Y, X = zyx_array.shape
69
79
  _, _, fxx, fyy = util.gen_coordinate((Y, X), pixel_size)
@@ -94,25 +104,35 @@ def focus_from_transverse_band(
94
104
  # Plot
95
105
  if plot_path is not None:
96
106
  _plot_focus_metric(
97
- plot_path, midband_sum, peak_index, in_focus_index, peak_results, threshold_FWHM
107
+ plot_path,
108
+ midband_sum,
109
+ peak_index,
110
+ in_focus_index,
111
+ peak_results,
112
+ threshold_FWHM,
98
113
  )
99
114
 
100
115
  return in_focus_index
101
116
 
102
117
 
118
+ def _mode_to_minmaxfunc(mode):
119
+ if mode == "min":
120
+ minmaxfunc = np.argmin
121
+ elif mode == "max":
122
+ minmaxfunc = np.argmax
123
+ else:
124
+ raise ValueError("mode must be either `min` or `max`")
125
+ return minmaxfunc
126
+
127
+
103
128
  def _check_focus_inputs(
104
- zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions, mode
129
+ zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
105
130
  ):
106
131
  N = len(zyx_array.shape)
107
132
  if N != 3:
108
133
  raise ValueError(
109
134
  f"{N}D array supplied. `focus_from_transverse_band` only accepts 3D arrays."
110
135
  )
111
- if zyx_array.shape[0] == 1:
112
- print(
113
- "WARNING: The dataset only contained a single slice. Returning trivial slice index = 0."
114
- )
115
- return 0
116
136
 
117
137
  if NA_det < 0:
118
138
  raise ValueError("NA must be > 0")
@@ -121,7 +141,7 @@ def _check_focus_inputs(
121
141
  if pixel_size < 0:
122
142
  raise ValueError("pixel_size must be > 0")
123
143
  if not 0.4 < lambda_ill / pixel_size < 10:
124
- print(
144
+ warnings.warn(
125
145
  f"WARNING: lambda_ill/pixel_size = {lambda_ill/pixel_size}."
126
146
  f"Did you use the same units?"
127
147
  f"Did you enter the pixel size in (demagnified) object-space units?"
@@ -134,17 +154,15 @@ def _check_focus_inputs(
134
154
  raise ValueError("midband_fractions[0] must be between 0 and 1")
135
155
  if not (0 <= midband_fractions[1] <= 1):
136
156
  raise ValueError("midband_fractions[1] must be between 0 and 1")
137
- if mode == "min":
138
- minmaxfunc = np.argmin
139
- elif mode == "max":
140
- minmaxfunc = np.argmax
141
- else:
142
- raise ValueError("mode must be either `min` or `max`")
143
- return minmaxfunc
144
157
 
145
158
 
146
159
  def _plot_focus_metric(
147
- plot_path, midband_sum, peak_index, in_focus_index, peak_results, threshold_FWHM
160
+ plot_path,
161
+ midband_sum,
162
+ peak_index,
163
+ in_focus_index,
164
+ peak_results,
165
+ threshold_FWHM,
148
166
  ):
149
167
  _, ax = plt.subplots(1, 1, figsize=(4, 4))
150
168
  ax.plot(midband_sum, "-k")
@@ -7,7 +7,7 @@ from torch import Tensor
7
7
  from waveorder import correction, stokes, util
8
8
 
9
9
 
10
- def generate_test_phantom(yx_shape):
10
+ def generate_test_phantom(yx_shape: Tuple[int, int]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
11
11
  star, theta, _ = util.generate_star_target(yx_shape, blur_px=0.1)
12
12
  retardance = 0.25 * star
13
13
  orientation = (theta % np.pi) * (star > 1e-3)
@@ -17,13 +17,13 @@ def generate_test_phantom(yx_shape):
17
17
 
18
18
 
19
19
  def calculate_transfer_function(
20
- swing,
21
- scheme,
22
- ):
20
+ swing: float,
21
+ scheme: str,
22
+ ) -> Tensor:
23
23
  return stokes.calculate_intensity_to_stokes_matrix(swing, scheme=scheme)
24
24
 
25
25
 
26
- def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
26
+ def visualize_transfer_function(viewer, intensity_to_stokes_matrix: Tensor) -> None:
27
27
  viewer.add_image(
28
28
  intensity_to_stokes_matrix.cpu().numpy(),
29
29
  name="Intensity to stokes matrix",
@@ -31,12 +31,12 @@ def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
31
31
 
32
32
 
33
33
  def apply_transfer_function(
34
- retardance,
35
- orientation,
36
- transmittance,
37
- depolarization,
38
- intensity_to_stokes_matrix,
39
- ):
34
+ retardance: Tensor,
35
+ orientation: Tensor,
36
+ transmittance: Tensor,
37
+ depolarization: Tensor,
38
+ intensity_to_stokes_matrix: Tensor,
39
+ ) -> Tensor:
40
40
  stokes_params = stokes.stokes_after_adr(
41
41
  retardance, orientation, transmittance, depolarization
42
42
  )
@@ -59,7 +59,7 @@ def apply_inverse_transfer_function(
59
59
  project_stokes_to_2d: bool = False,
60
60
  flip_orientation: bool = False,
61
61
  rotate_orientation: bool = False,
62
- ) -> Tuple[Tensor]:
62
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
63
63
  """Reconstructs retardance, orientation, transmittance, and depolarization
64
64
  from czyx_data and an intensity_to_stokes_matrix, providing options for
65
65
  background correction, projection, and orientation transformations.
@@ -0,0 +1,351 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ from torch import Tensor
5
+ from typing import Literal
6
+ from torch.nn.functional import avg_pool3d, interpolate
7
+ from waveorder import optics, sampling, stokes, util
8
+ from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
9
+
10
+
11
+ def generate_test_phantom(zyx_shape: tuple[int, int, int]) -> torch.Tensor:
12
+ # Simulate
13
+ yx_star, yx_theta, _ = util.generate_star_target(
14
+ yx_shape=zyx_shape[1:],
15
+ blur_px=1,
16
+ margin=50,
17
+ )
18
+ c00 = yx_star
19
+ c2_2 = -torch.sin(2 * yx_theta) * yx_star # torch.zeros_like(c00)
20
+ c22 = -torch.cos(2 * yx_theta) * yx_star # torch.zeros_like(c00) #
21
+
22
+ # Put in a center slices of a 3D object
23
+ center_slice_object = torch.stack((c00, c2_2, c22), dim=0)
24
+ object = torch.zeros((3,) + zyx_shape)
25
+ object[:, zyx_shape[0] // 2, ...] = center_slice_object
26
+ return object
27
+
28
+
29
+ def calculate_transfer_function(
30
+ swing: float,
31
+ scheme: str,
32
+ zyx_shape: tuple[int, int, int],
33
+ yx_pixel_size: float,
34
+ z_pixel_size: float,
35
+ wavelength_illumination: float,
36
+ z_padding: int,
37
+ index_of_refraction_media: float,
38
+ numerical_aperture_illumination: float,
39
+ numerical_aperture_detection: float,
40
+ invert_phase_contrast: bool = False,
41
+ fourier_oversample_factor: int = 1,
42
+ transverse_downsample_factor: int = 1,
43
+ ) -> tuple[
44
+ torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]
45
+ ]:
46
+ if z_padding != 0:
47
+ raise NotImplementedError("Padding not implemented for this model")
48
+
49
+ transverse_nyquist = sampling.transverse_nyquist(
50
+ wavelength_illumination,
51
+ numerical_aperture_illumination,
52
+ numerical_aperture_detection,
53
+ )
54
+ axial_nyquist = sampling.axial_nyquist(
55
+ wavelength_illumination,
56
+ numerical_aperture_detection,
57
+ index_of_refraction_media,
58
+ )
59
+
60
+ yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
61
+ z_factor = int(np.ceil(z_pixel_size / axial_nyquist))
62
+
63
+ print("YX factor:", yx_factor)
64
+ print("Z factor:", z_factor)
65
+
66
+ tf_calculation_shape = (
67
+ zyx_shape[0] * z_factor * fourier_oversample_factor,
68
+ int(
69
+ np.ceil(
70
+ zyx_shape[1]
71
+ * yx_factor
72
+ * fourier_oversample_factor
73
+ / transverse_downsample_factor
74
+ )
75
+ ),
76
+ int(
77
+ np.ceil(
78
+ zyx_shape[2]
79
+ * yx_factor
80
+ * fourier_oversample_factor
81
+ / transverse_downsample_factor
82
+ )
83
+ ),
84
+ )
85
+
86
+ sfZYX_transfer_function, intensity_to_stokes_matrix = (
87
+ _calculate_wrap_unsafe_transfer_function(
88
+ swing,
89
+ scheme,
90
+ tf_calculation_shape,
91
+ yx_pixel_size / yx_factor,
92
+ z_pixel_size / z_factor,
93
+ wavelength_illumination,
94
+ z_padding,
95
+ index_of_refraction_media,
96
+ numerical_aperture_illumination,
97
+ numerical_aperture_detection,
98
+ invert_phase_contrast=invert_phase_contrast,
99
+ )
100
+ )
101
+
102
+ # avg_pool3d does not support complex numbers
103
+ pooled_sfZYX_transfer_function_real = avg_pool3d(
104
+ sfZYX_transfer_function.real, (fourier_oversample_factor,) * 3
105
+ )
106
+ pooled_sfZYX_transfer_function_imag = avg_pool3d(
107
+ sfZYX_transfer_function.imag, (fourier_oversample_factor,) * 3
108
+ )
109
+ pooled_sfZYX_transfer_function = (
110
+ pooled_sfZYX_transfer_function_real
111
+ + 1j * pooled_sfZYX_transfer_function_imag
112
+ )
113
+
114
+ # Crop to original size
115
+ sfzyx_out_shape = (
116
+ pooled_sfZYX_transfer_function.shape[0],
117
+ pooled_sfZYX_transfer_function.shape[1],
118
+ zyx_shape[0] + 2 * z_padding,
119
+ ) + zyx_shape[1:]
120
+
121
+ cropped = sampling.nd_fourier_central_cuboid(
122
+ pooled_sfZYX_transfer_function, sfzyx_out_shape
123
+ )
124
+
125
+ # Compute singular system on cropped and downsampled
126
+ U, S, Vh = calculate_singular_system(cropped)
127
+
128
+ # Interpolate to final size in YX
129
+ def complex_interpolate(
130
+ tensor: torch.Tensor, zyx_shape: tuple[int, int, int]
131
+ ) -> torch.Tensor:
132
+ interpolated_real = interpolate(tensor.real, size=zyx_shape)
133
+ interpolated_imag = interpolate(tensor.imag, size=zyx_shape)
134
+ return interpolated_real + 1j * interpolated_imag
135
+
136
+ full_cropped = complex_interpolate(cropped, zyx_shape)
137
+ full_U = complex_interpolate(U, zyx_shape)
138
+ full_S = interpolate(S[None], size=zyx_shape)[0] # S is real
139
+ full_Vh = complex_interpolate(Vh, zyx_shape)
140
+
141
+ return (
142
+ full_cropped,
143
+ intensity_to_stokes_matrix,
144
+ (full_U, full_S, full_Vh),
145
+ )
146
+
147
+
148
+ def _calculate_wrap_unsafe_transfer_function(
149
+ swing,
150
+ scheme,
151
+ zyx_shape,
152
+ yx_pixel_size,
153
+ z_pixel_size,
154
+ wavelength_illumination,
155
+ z_padding,
156
+ index_of_refraction_media,
157
+ numerical_aperture_illumination,
158
+ numerical_aperture_detection,
159
+ invert_phase_contrast=False,
160
+ ):
161
+ print("Computing transfer function")
162
+ intensity_to_stokes_matrix = stokes.calculate_intensity_to_stokes_matrix(
163
+ swing, scheme=scheme
164
+ )
165
+
166
+ input_jones = torch.tensor([0.0 - 1.0j, 1.0 + 0j]) # circular
167
+ # input_jones = torch.tensor([0 + 0j, 1 + 0j]) # linear
168
+
169
+ # Calculate frequencies
170
+ y_frequencies, x_frequencies = util.generate_frequencies(
171
+ zyx_shape[1:], yx_pixel_size
172
+ )
173
+ radial_frequencies = torch.sqrt(x_frequencies**2 + y_frequencies**2)
174
+
175
+ z_total = zyx_shape[0] + 2 * z_padding
176
+ z_position_list = torch.fft.ifftshift(
177
+ (torch.arange(z_total) - z_total // 2) * z_pixel_size
178
+ )
179
+ if (
180
+ not invert_phase_contrast
181
+ ): # opposite sign of direct phase reconstruction
182
+ z_position_list = torch.flip(z_position_list, dims=(0,))
183
+ z_frequencies = torch.fft.fftfreq(z_total, d=z_pixel_size)
184
+
185
+ # 2D pupils
186
+ print("\tCalculating pupils...")
187
+ ill_pupil = optics.generate_pupil(
188
+ radial_frequencies,
189
+ numerical_aperture_illumination,
190
+ wavelength_illumination,
191
+ )
192
+ det_pupil = optics.generate_pupil(
193
+ radial_frequencies,
194
+ numerical_aperture_detection,
195
+ wavelength_illumination,
196
+ )
197
+ pupil = optics.generate_pupil(
198
+ radial_frequencies,
199
+ index_of_refraction_media, # largest possible NA
200
+ wavelength_illumination,
201
+ )
202
+
203
+ # Defocus pupils
204
+ defocus_pupil = optics.generate_propagation_kernel(
205
+ radial_frequencies,
206
+ pupil,
207
+ wavelength_illumination / index_of_refraction_media,
208
+ z_position_list,
209
+ )
210
+
211
+ # Calculate vector defocus pupils
212
+ S = optics.generate_vector_source_defocus_pupil(
213
+ x_frequencies,
214
+ y_frequencies,
215
+ z_position_list,
216
+ defocus_pupil,
217
+ input_jones,
218
+ ill_pupil,
219
+ wavelength_illumination / index_of_refraction_media,
220
+ )
221
+
222
+ # Simplified scalar pupil
223
+ P = optics.generate_propagation_kernel(
224
+ radial_frequencies,
225
+ det_pupil,
226
+ wavelength_illumination / index_of_refraction_media,
227
+ z_position_list,
228
+ )
229
+
230
+ P_3D = torch.abs(torch.fft.ifft(P, dim=-3)).type(torch.complex64)
231
+ S_3D = torch.fft.ifft(S, dim=-3)
232
+
233
+ print("\tCalculating greens tensor spectrum...")
234
+ G_3D = optics.generate_greens_tensor_spectrum(
235
+ zyx_shape=(z_total, zyx_shape[1], zyx_shape[2]),
236
+ zyx_pixel_size=(z_pixel_size, yx_pixel_size, yx_pixel_size),
237
+ wavelength=wavelength_illumination / index_of_refraction_media,
238
+ )
239
+
240
+ # Main part
241
+ PG_3D = torch.einsum("zyx,ipzyx->ipzyx", P_3D, G_3D)
242
+ PS_3D = torch.einsum("zyx,jzyx,kzyx->jkzyx", P_3D, S_3D, torch.conj(S_3D))
243
+
244
+ del P_3D, G_3D, S_3D
245
+
246
+ print("\tComputing pg and ps...")
247
+ pg = torch.fft.fftn(PG_3D, dim=(-3, -2, -1))
248
+ ps = torch.fft.fftn(PS_3D, dim=(-3, -2, -1))
249
+
250
+ del PG_3D, PS_3D
251
+
252
+ print("\tComputing H1 and H2...")
253
+ H1 = torch.fft.ifftn(
254
+ torch.einsum("ipzyx,jkzyx->ijpkzyx", pg, torch.conj(ps)),
255
+ dim=(-3, -2, -1),
256
+ )
257
+
258
+ H2 = torch.fft.ifftn(
259
+ torch.einsum("ikzyx,jpzyx->ijpkzyx", ps, torch.conj(pg)),
260
+ dim=(-3, -2, -1),
261
+ )
262
+
263
+ H_re = H1[1:, 1:] + H2[1:, 1:] # drop data-side z components
264
+ # H_im = 1j * (H1 - H2) # ignore absorptive terms
265
+
266
+ del H1, H2
267
+
268
+ H_re /= torch.amax(torch.abs(H_re))
269
+
270
+ s = util.pauli()[[0, 1, 2, 3]] # select s0, s1, and s2
271
+ Y = util.gellmann()[[0, 4, 8]]
272
+ # select phase f00 and transverse linear isotropic terms 2-2, and f22
273
+
274
+ print("\tComputing final transfer function...")
275
+ sfZYX_transfer_function = torch.einsum(
276
+ "sik,ikpjzyx,lpj->slzyx", s, H_re, Y
277
+ )
278
+ return (
279
+ sfZYX_transfer_function,
280
+ intensity_to_stokes_matrix,
281
+ )
282
+
283
+
284
+ def calculate_singular_system(sfZYX_transfer_function):
285
+ # Compute regularized inverse filter
286
+ print("Computing SVD")
287
+ ZYXsf_transfer_function = sfZYX_transfer_function.permute(2, 3, 4, 0, 1)
288
+ U, S, Vh = torch.linalg.svd(ZYXsf_transfer_function, full_matrices=False)
289
+ singular_system = (
290
+ U.permute(3, 4, 0, 1, 2),
291
+ S.permute(3, 0, 1, 2),
292
+ Vh.permute(3, 4, 0, 1, 2),
293
+ )
294
+ return singular_system
295
+
296
+
297
+ def visualize_transfer_function(
298
+ viewer: "napari.Viewer",
299
+ sfZYX_transfer_function: torch.Tensor,
300
+ zyx_scale: tuple[float, float, float],
301
+ ) -> None:
302
+ add_transfer_function_to_viewer(
303
+ viewer,
304
+ sfZYX_transfer_function,
305
+ zyx_scale=zyx_scale,
306
+ layer_name="Transfer Function",
307
+ complex_rgb=True,
308
+ clim_factor=0.5,
309
+ )
310
+
311
+
312
+ def apply_transfer_function(
313
+ fzyx_object: torch.Tensor,
314
+ sfZYX_transfer_function: torch.Tensor,
315
+ intensity_to_stokes_matrix: torch.Tensor, # TODO use this to simulate intensities
316
+ ) -> torch.Tensor:
317
+ fZYX_object = torch.fft.fftn(fzyx_object, dim=(1, 2, 3))
318
+ sZYX_data = torch.einsum(
319
+ "fzyx,sfzyx->szyx", fZYX_object, sfZYX_transfer_function
320
+ )
321
+ szyx_data = torch.fft.ifftn(sZYX_data, dim=(1, 2, 3))
322
+
323
+ return 50 * szyx_data # + 0.1 * torch.randn(szyx_data.shape)
324
+
325
+
326
+ def apply_inverse_transfer_function(
327
+ szyx_data: Tensor,
328
+ singular_system: tuple[Tensor],
329
+ intensity_to_stokes_matrix: Tensor,
330
+ reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
331
+ regularization_strength: float = 1e-3,
332
+ TV_rho_strength: float = 1e-3,
333
+ TV_iterations: int = 10,
334
+ ):
335
+ sZYX_data = torch.fft.fftn(szyx_data, dim=(1, 2, 3))
336
+
337
+ # Key computation
338
+ print("Computing inverse filter")
339
+ U, S, Vh = singular_system
340
+ S_reg = S / (S**2 + regularization_strength)
341
+
342
+ ZYXsf_inverse_filter = torch.einsum(
343
+ "sjzyx,jzyx,jfzyx->sfzyx", U, S_reg, Vh
344
+ )
345
+
346
+ # Apply inverse filter
347
+ fZYX_reconstructed = torch.einsum(
348
+ "szyx,sfzyx->fzyx", sZYX_data, ZYXsf_inverse_filter
349
+ )
350
+
351
+ return torch.real(torch.fft.ifftn(fZYX_reconstructed, dim=(1, 2, 3)))
@@ -1,17 +1,19 @@
1
1
  from typing import Literal
2
2
 
3
+ import numpy as np
3
4
  import torch
4
5
  from torch import Tensor
5
6
 
6
- from waveorder import optics, util
7
+ from waveorder import optics, sampling, util
8
+ from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
7
9
 
8
10
 
9
11
  def generate_test_phantom(
10
- zyx_shape,
11
- yx_pixel_size,
12
- z_pixel_size,
13
- sphere_radius,
14
- ):
12
+ zyx_shape: tuple[int, int, int],
13
+ yx_pixel_size: float,
14
+ z_pixel_size: float,
15
+ sphere_radius: float,
16
+ ) -> Tensor:
15
17
  sphere, _, _ = util.generate_sphere_target(
16
18
  zyx_shape, yx_pixel_size, z_pixel_size, sphere_radius
17
19
  )
@@ -20,14 +22,57 @@ def generate_test_phantom(
20
22
 
21
23
 
22
24
  def calculate_transfer_function(
23
- zyx_shape,
24
- yx_pixel_size,
25
- z_pixel_size,
26
- wavelength_emission,
27
- z_padding,
28
- index_of_refraction_media,
29
- numerical_aperture_detection,
30
- ):
25
+ zyx_shape: tuple[int, int, int],
26
+ yx_pixel_size: float,
27
+ z_pixel_size: float,
28
+ wavelength_emission: float,
29
+ z_padding: int,
30
+ index_of_refraction_media: float,
31
+ numerical_aperture_detection: float,
32
+ ) -> Tensor:
33
+
34
+ transverse_nyquist = sampling.transverse_nyquist(
35
+ wavelength_emission,
36
+ numerical_aperture_detection, # ill = det for fluorescence
37
+ numerical_aperture_detection,
38
+ )
39
+ axial_nyquist = sampling.axial_nyquist(
40
+ wavelength_emission,
41
+ numerical_aperture_detection,
42
+ index_of_refraction_media,
43
+ )
44
+
45
+ yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
46
+ z_factor = int(np.ceil(z_pixel_size / axial_nyquist))
47
+
48
+ optical_transfer_function = _calculate_wrap_unsafe_transfer_function(
49
+ (
50
+ zyx_shape[0] * z_factor,
51
+ zyx_shape[1] * yx_factor,
52
+ zyx_shape[2] * yx_factor,
53
+ ),
54
+ yx_pixel_size / yx_factor,
55
+ z_pixel_size / z_factor,
56
+ wavelength_emission,
57
+ z_padding,
58
+ index_of_refraction_media,
59
+ numerical_aperture_detection,
60
+ )
61
+ zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:]
62
+ return sampling.nd_fourier_central_cuboid(
63
+ optical_transfer_function, zyx_out_shape
64
+ )
65
+
66
+
67
+ def _calculate_wrap_unsafe_transfer_function(
68
+ zyx_shape: tuple[int, int, int],
69
+ yx_pixel_size: float,
70
+ z_pixel_size: float,
71
+ wavelength_emission: float,
72
+ z_padding: int,
73
+ index_of_refraction_media: float,
74
+ numerical_aperture_detection: float,
75
+ ) -> Tensor:
31
76
  radial_frequencies = util.generate_radial_frequencies(
32
77
  zyx_shape[1:], yx_pixel_size
33
78
  )
@@ -63,25 +108,33 @@ def calculate_transfer_function(
63
108
  return optical_transfer_function
64
109
 
65
110
 
66
- def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale):
67
- arrays = [
68
- (torch.imag(optical_transfer_function), "Im(OTF)"),
69
- (torch.real(optical_transfer_function), "Re(OTF)"),
70
- ]
71
-
72
- for array in arrays:
73
- lim = 0.1 * torch.max(torch.abs(array[0]))
74
- viewer.add_image(
75
- torch.fft.ifftshift(array[0]).cpu().numpy(),
76
- name=array[1],
77
- colormap="bwr",
78
- contrast_limits=(-lim, lim),
79
- scale=1 / zyx_scale,
80
- )
81
- viewer.dims.order = (0, 1, 2)
111
+ def visualize_transfer_function(viewer, optical_transfer_function: Tensor, zyx_scale: tuple[float, float, float]) -> None:
112
+ add_transfer_function_to_viewer(
113
+ viewer,
114
+ torch.real(optical_transfer_function),
115
+ zyx_scale,
116
+ clim_factor=0.05,
117
+ )
82
118
 
83
119
 
84
- def apply_transfer_function(zyx_object, optical_transfer_function, z_padding):
120
+ def apply_transfer_function(
121
+ zyx_object: Tensor, optical_transfer_function: Tensor, z_padding: int, background: int = 10
122
+ ) -> Tensor:
123
+ """Simulate imaging by applying a transfer function
124
+
125
+ Parameters
126
+ ----------
127
+ zyx_object : torch.Tensor
128
+ optical_transfer_function : torch.Tensor
129
+ z_padding : int
130
+ background : int, optional
131
+ constant background counts added to each voxel, by default 10
132
+
133
+ Returns
134
+ -------
135
+ Simulated data : torch.Tensor
136
+
137
+ """
85
138
  if (
86
139
  zyx_object.shape[0] + 2 * z_padding
87
140
  != optical_transfer_function.shape[0]
@@ -99,7 +152,7 @@ def apply_transfer_function(zyx_object, optical_transfer_function, z_padding):
99
152
  zyx_data = zyx_obj_hat * optical_transfer_function
100
153
  data = torch.real(torch.fft.ifftn(zyx_data))
101
154
 
102
- data += 10 # Add a direct background
155
+ data += background # Add a direct background
103
156
  return data
104
157
 
105
158
 
@@ -111,7 +164,7 @@ def apply_inverse_transfer_function(
111
164
  regularization_strength: float = 1e-3,
112
165
  TV_rho_strength: float = 1e-3,
113
166
  TV_iterations: int = 10,
114
- ):
167
+ ) -> Tensor:
115
168
  """Reconstructs fluorescence density from zyx_data and
116
169
  an optical_transfer_function, providing options for z padding and
117
170
  reconstruction algorithms.