waveorder 2.2.0__py3-none-any.whl → 2.2.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,351 +0,0 @@
1
- import torch
2
- import numpy as np
3
-
4
- from torch import Tensor
5
- from typing import Literal
6
- from torch.nn.functional import avg_pool3d, interpolate
7
- from waveorder import optics, sampling, stokes, util
8
- from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
9
-
10
-
11
def generate_test_phantom(zyx_shape: tuple[int, int, int]) -> torch.Tensor:
    """Build a three-channel star-target phantom placed in the central z slice.

    Parameters
    ----------
    zyx_shape : tuple[int, int, int]
        Output spatial shape (Z, Y, X). Any sequence of three ints is accepted.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(3, Z, Y, X)``. Only slice ``Z // 2`` is non-zero;
        it holds, per channel:

        0. the star target itself,
        1. ``-sin(2 * theta) * star``,
        2. ``-cos(2 * theta) * star``,

        where ``theta`` is the orientation map returned by
        ``util.generate_star_target``. The channel names in the code
        (c00, c2_2, c22) presumably correspond to the object terms selected
        in the transfer-function model; verify before relying on that.
    """
    zyx_shape = tuple(zyx_shape)

    yx_star, yx_theta, _ = util.generate_star_target(
        yx_shape=zyx_shape[1:],
        blur_px=1,
        margin=50,
    )
    c00 = yx_star
    c2_2 = -torch.sin(2 * yx_theta) * yx_star
    c22 = -torch.cos(2 * yx_theta) * yx_star

    # Place the 2D channels in the central z slice of an otherwise empty volume.
    center_slice_object = torch.stack((c00, c2_2, c22), dim=0)
    fzyx_object = torch.zeros((3, *zyx_shape))
    fzyx_object[:, zyx_shape[0] // 2, ...] = center_slice_object
    return fzyx_object
27
-
28
-
29
def _complex_avg_pool3d(tensor: torch.Tensor, kernel_size: int) -> torch.Tensor:
    """Average-pool the last three dims of a complex tensor (real/imag separately).

    ``avg_pool3d`` does not support complex inputs, so the real and imaginary
    parts are pooled independently and recombined.
    """
    kernel = (kernel_size,) * 3
    pooled_real = avg_pool3d(tensor.real, kernel)
    pooled_imag = avg_pool3d(tensor.imag, kernel)
    return pooled_real + 1j * pooled_imag


def _complex_interpolate(
    tensor: torch.Tensor, zyx_shape: tuple[int, int, int]
) -> torch.Tensor:
    """Resize the last three dims of a complex tensor to ``zyx_shape``.

    Uses ``interpolate``'s default mode on the real and imaginary parts
    separately.
    """
    interpolated_real = interpolate(tensor.real, size=zyx_shape)
    interpolated_imag = interpolate(tensor.imag, size=zyx_shape)
    return interpolated_real + 1j * interpolated_imag


def calculate_transfer_function(
    swing: float,
    scheme: str,
    zyx_shape: tuple[int, int, int],
    yx_pixel_size: float,
    z_pixel_size: float,
    wavelength_illumination: float,
    z_padding: int,
    index_of_refraction_media: float,
    numerical_aperture_illumination: float,
    numerical_aperture_detection: float,
    invert_phase_contrast: bool = False,
    fourier_oversample_factor: int = 1,
    transverse_downsample_factor: int = 1,
) -> tuple[
    torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]
]:
    """Compute the transfer function and its per-voxel singular system.

    The transfer function is evaluated on a grid fine enough to meet the
    transverse and axial Nyquist criteria, so that it does not wrap. The grid
    is optionally oversampled by ``fourier_oversample_factor`` and
    transversely downsampled by ``transverse_downsample_factor``. The result
    is then average-pooled back by the oversample factor and cropped in
    Fourier space to the (downsampled) output grid. The singular system is
    computed on that cropped grid, and finally every output is resized to
    ``zyx_shape``.

    Parameters
    ----------
    swing, scheme
        Passed to ``stokes.calculate_intensity_to_stokes_matrix``.
    zyx_shape : tuple[int, int, int]
        Shape (Z, Y, X) of the data the transfer function will be applied to.
    yx_pixel_size, z_pixel_size : float
        Sample spacings, in the same units as ``wavelength_illumination``.
        This is required because they are divided by the Nyquist spacings.
    wavelength_illumination : float
    z_padding : int
        Must be 0; padding is not implemented for this model.
    index_of_refraction_media : float
    numerical_aperture_illumination, numerical_aperture_detection : float
    invert_phase_contrast : bool
        Flips the sign convention of the defocus axis.
    fourier_oversample_factor : int
        Extra oversampling of the calculation grid, undone by average pooling.
    transverse_downsample_factor : int
        Computes and decomposes the transfer function on a YX grid this many
        times coarser. The result is then interpolated back up to
        ``zyx_shape``.

    Returns
    -------
    tuple
        ``(transfer_function, intensity_to_stokes_matrix, (U, S, Vh))``.
        The transfer function, U and Vh have leading (s, f)-style matrix
        axes followed by ``zyx_shape``. S has one leading singular-value axis.

    Raises
    ------
    NotImplementedError
        If ``z_padding`` is non-zero.
    """
    if z_padding != 0:
        raise NotImplementedError("Padding not implemented for this model")

    transverse_nyquist = sampling.transverse_nyquist(
        wavelength_illumination,
        numerical_aperture_illumination,
        numerical_aperture_detection,
    )
    axial_nyquist = sampling.axial_nyquist(
        wavelength_illumination,
        numerical_aperture_detection,
        index_of_refraction_media,
    )

    # Integer upsampling factors that bring each spacing to (at least) Nyquist.
    yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
    z_factor = int(np.ceil(z_pixel_size / axial_nyquist))

    print("YX factor:", yx_factor)
    print("Z factor:", z_factor)

    tf_calculation_shape = (
        zyx_shape[0] * z_factor * fourier_oversample_factor,
    ) + tuple(
        int(
            np.ceil(
                n
                * yx_factor
                * fourier_oversample_factor
                / transverse_downsample_factor
            )
        )
        for n in zyx_shape[1:]
    )

    sfZYX_transfer_function, intensity_to_stokes_matrix = (
        _calculate_wrap_unsafe_transfer_function(
            swing,
            scheme,
            tf_calculation_shape,
            yx_pixel_size / yx_factor,
            z_pixel_size / z_factor,
            wavelength_illumination,
            z_padding,
            index_of_refraction_media,
            numerical_aperture_illumination,
            numerical_aperture_detection,
            invert_phase_contrast=invert_phase_contrast,
        )
    )

    pooled_sfZYX_transfer_function = _complex_avg_pool3d(
        sfZYX_transfer_function, fourier_oversample_factor
    )

    # Crop in Fourier space to the transversely *downsampled* output grid.
    # After pooling, the YX frequency spacing is transverse_downsample_factor
    # times coarser than the output grid's. Keeping zyx_shape[i] //
    # transverse_downsample_factor central samples therefore spans the output
    # grid's full Nyquist band. Floor division also guarantees the target
    # never exceeds the pooled size, since yx_factor >= 1. When
    # transverse_downsample_factor == 1 this is exactly zyx_shape[1:].
    downsampled_yx_shape = tuple(
        n // transverse_downsample_factor for n in zyx_shape[1:]
    )
    sfzyx_out_shape = (
        pooled_sfZYX_transfer_function.shape[0],
        pooled_sfZYX_transfer_function.shape[1],
        zyx_shape[0] + 2 * z_padding,
    ) + downsampled_yx_shape

    cropped = sampling.nd_fourier_central_cuboid(
        pooled_sfZYX_transfer_function, sfzyx_out_shape
    )

    # Decompose on the cropped (possibly downsampled) grid; cheaper than full size.
    U, S, Vh = calculate_singular_system(cropped)

    # Resize everything back to the requested output grid.
    full_cropped = _complex_interpolate(cropped, zyx_shape)
    full_U = _complex_interpolate(U, zyx_shape)
    full_S = interpolate(S[None], size=zyx_shape)[0]  # S is real
    full_Vh = _complex_interpolate(Vh, zyx_shape)

    return (
        full_cropped,
        intensity_to_stokes_matrix,
        (full_U, full_S, full_Vh),
    )
146
-
147
-
148
def _calculate_wrap_unsafe_transfer_function(
    swing,
    scheme,
    zyx_shape,
    yx_pixel_size,
    z_pixel_size,
    wavelength_illumination,
    z_padding,
    index_of_refraction_media,
    numerical_aperture_illumination,
    numerical_aperture_detection,
    invert_phase_contrast=False,
):
    """Compute the vector transfer function directly on the given grid.

    No resampling is done here. ``calculate_transfer_function`` calls this
    with a grid whose spacings were reduced by Nyquist-derived factors.
    NOTE(review): "wrap unsafe" presumably means the caller is responsible
    for choosing a grid fine enough to avoid wrap-around; confirm.

    Parameters
    ----------
    swing, scheme
        Forwarded to ``stokes.calculate_intensity_to_stokes_matrix``.
    zyx_shape : tuple of int
        Grid shape (Z, Y, X).
    yx_pixel_size, z_pixel_size : float
        Grid spacings.
    wavelength_illumination : float
    z_padding : int
        Added to both ends of the z axis (the public caller only allows 0).
    index_of_refraction_media : float
        Used both as the maximal pupil NA and to convert the vacuum
        wavelength to the in-medium wavelength.
    numerical_aperture_illumination, numerical_aperture_detection : float
    invert_phase_contrast : bool
        When False, the z-position list is flipped.

    Returns
    -------
    tuple
        ``(sfZYX_transfer_function, intensity_to_stokes_matrix)``. The
        transfer function comes from the final einsum ``"...->slzyx"``. Its
        ``s`` axis indexes the four selected Pauli matrices and its ``l``
        axis the three selected Gell-Mann matrices.
    """
    print("Computing transfer function")
    intensity_to_stokes_matrix = stokes.calculate_intensity_to_stokes_matrix(
        swing, scheme=scheme
    )

    # Jones vector of the illumination: circular polarization.
    input_jones = torch.tensor([0.0 - 1.0j, 1.0 + 0j])  # circular
    # Alternative linear-polarization input, kept for experimentation:
    # input_jones = torch.tensor([0 + 0j, 1 + 0j]) # linear

    # Transverse spatial frequencies of the YX grid
    y_frequencies, x_frequencies = util.generate_frequencies(
        zyx_shape[1:], yx_pixel_size
    )
    radial_frequencies = torch.sqrt(x_frequencies**2 + y_frequencies**2)

    # z positions centered on 0, then ifftshift-ed into FFT ordering.
    z_total = zyx_shape[0] + 2 * z_padding
    z_position_list = torch.fft.ifftshift(
        (torch.arange(z_total) - z_total // 2) * z_pixel_size
    )
    if (
        not invert_phase_contrast
    ):  # opposite sign of direct phase reconstruction
        z_position_list = torch.flip(z_position_list, dims=(0,))
    # NOTE(review): z_frequencies is computed but not used below.
    z_frequencies = torch.fft.fftfreq(z_total, d=z_pixel_size)

    # 2D pupils: illumination, detection, and a maximal pupil (NA = n).
    print("\tCalculating pupils...")
    ill_pupil = optics.generate_pupil(
        radial_frequencies,
        numerical_aperture_illumination,
        wavelength_illumination,
    )
    det_pupil = optics.generate_pupil(
        radial_frequencies,
        numerical_aperture_detection,
        wavelength_illumination,
    )
    pupil = optics.generate_pupil(
        radial_frequencies,
        index_of_refraction_media,  # largest possible NA
        wavelength_illumination,
    )

    # Propagation kernel of the maximal pupil over every z position, using
    # the in-medium wavelength.
    defocus_pupil = optics.generate_propagation_kernel(
        radial_frequencies,
        pupil,
        wavelength_illumination / index_of_refraction_media,
        z_position_list,
    )

    # Vector source pupils for the chosen input polarization
    S = optics.generate_vector_source_defocus_pupil(
        x_frequencies,
        y_frequencies,
        z_position_list,
        defocus_pupil,
        input_jones,
        ill_pupil,
        wavelength_illumination / index_of_refraction_media,
    )

    # Simplified scalar detection pupil, propagated over every z position
    P = optics.generate_propagation_kernel(
        radial_frequencies,
        det_pupil,
        wavelength_illumination / index_of_refraction_media,
        z_position_list,
    )

    # Inverse FFT along z (axis -3). For P only the magnitude is kept: the
    # phase is discarded by torch.abs, then the result is cast back to
    # complex64 for the complex einsums below.
    P_3D = torch.abs(torch.fft.ifft(P, dim=-3)).type(torch.complex64)
    S_3D = torch.fft.ifft(S, dim=-3)

    print("\tCalculating greens tensor spectrum...")
    G_3D = optics.generate_greens_tensor_spectrum(
        zyx_shape=(z_total, zyx_shape[1], zyx_shape[2]),
        zyx_pixel_size=(z_pixel_size, yx_pixel_size, yx_pixel_size),
        wavelength=wavelength_illumination / index_of_refraction_media,
    )

    # Main part: weight the Green's tensor and the source outer product
    # S (x) conj(S) by the scalar pupil P_3D, voxel by voxel.
    PG_3D = torch.einsum("zyx,ipzyx->ipzyx", P_3D, G_3D)
    PS_3D = torch.einsum("zyx,jzyx,kzyx->jkzyx", P_3D, S_3D, torch.conj(S_3D))

    # Free large intermediates early to limit peak memory.
    del P_3D, G_3D, S_3D

    print("\tComputing pg and ps...")
    pg = torch.fft.fftn(PG_3D, dim=(-3, -2, -1))
    ps = torch.fft.fftn(PS_3D, dim=(-3, -2, -1))

    del PG_3D, PS_3D

    # H1 and H2 combine pg and ps in two conjugate orderings; each product is
    # brought back with an inverse 3D FFT.
    print("\tComputing H1 and H2...")
    H1 = torch.fft.ifftn(
        torch.einsum("ipzyx,jkzyx->ijpkzyx", pg, torch.conj(ps)),
        dim=(-3, -2, -1),
    )

    H2 = torch.fft.ifftn(
        torch.einsum("ikzyx,jpzyx->ijpkzyx", ps, torch.conj(pg)),
        dim=(-3, -2, -1),
    )

    # Keep indices 1: of the first two axes (drops the data-side z components).
    H_re = H1[1:, 1:] + H2[1:, 1:]  # drop data-side z components
    # Absorptive terms are intentionally ignored:
    # H_im = 1j * (H1 - H2) # ignore absorptive terms

    del H1, H2

    # Normalize so the largest magnitude is 1.
    # NOTE(review): divides by zero if H_re is identically zero.
    H_re /= torch.amax(torch.abs(H_re))

    # Data-side basis: indices 0-3 of util.pauli().
    # NOTE(review): the inline comment below says "s0, s1, and s2", but four
    # indices are selected; confirm which is intended.
    s = util.pauli()[[0, 1, 2, 3]]  # select s0, s1, and s2
    Y = util.gellmann()[[0, 4, 8]]
    # select phase f00 and transverse linear isotropic terms 2-2, and f22

    # Project H_re onto both bases. Output axes are (s, l, z, y, x).
    print("\tComputing final transfer function...")
    sfZYX_transfer_function = torch.einsum(
        "sik,ikpjzyx,lpj->slzyx", s, H_re, Y
    )
    return (
        sfZYX_transfer_function,
        intensity_to_stokes_matrix,
    )
282
-
283
-
284
def calculate_singular_system(sfZYX_transfer_function):
    """Decompose a 5D transfer function with one SVD per (Z, Y, X) voxel.

    Parameters
    ----------
    sfZYX_transfer_function : torch.Tensor
        Tensor laid out as (s, f, Z, Y, X). Each voxel holds an s-by-f matrix.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        ``(U, S, Vh)`` from a reduced SVD (``full_matrices=False``), with the
        matrix axes moved back to the front:

        - U is (s, k, Z, Y, X),
        - S is (k, Z, Y, X),
        - Vh is (k, f, Z, Y, X),

        where ``k = min(s, f)``.
    """
    print("Computing SVD")

    # Voxel-major layout: torch.linalg.svd batches over all leading axes.
    voxel_major_order = (2, 3, 4, 0, 1)
    batched_matrices = sfZYX_transfer_function.permute(*voxel_major_order)
    left_vectors, singular_values, right_vectors_h = torch.linalg.svd(
        batched_matrices, full_matrices=False
    )

    # Restore matrix-major layout so the outputs match the sfZYX convention.
    matrix_major_order = (3, 4, 0, 1, 2)
    return (
        left_vectors.permute(*matrix_major_order),
        singular_values.permute(3, 0, 1, 2),
        right_vectors_h.permute(*matrix_major_order),
    )
295
-
296
-
297
def visualize_transfer_function(
    viewer: "napari.Viewer",
    sfZYX_transfer_function: torch.Tensor,
    zyx_scale: tuple[float, float, float],
) -> None:
    """Add the transfer function to a napari viewer as a complex-RGB layer.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer that receives the new layer.
    sfZYX_transfer_function : torch.Tensor
        Transfer function to display.
    zyx_scale : tuple[float, float, float]
        Physical scale of the (Z, Y, X) axes.
    """
    layer_options = {
        "zyx_scale": zyx_scale,
        "layer_name": "Transfer Function",
        "complex_rgb": True,
        "clim_factor": 0.5,
    }
    add_transfer_function_to_viewer(
        viewer, sfZYX_transfer_function, **layer_options
    )
310
-
311
-
312
def apply_transfer_function(
    fzyx_object: torch.Tensor,
    sfZYX_transfer_function: torch.Tensor,
    intensity_to_stokes_matrix: torch.Tensor,  # TODO use this to simulate intensities
    *,
    scale: float = 50,
) -> torch.Tensor:
    """Simulate Stokes data by applying the transfer function in Fourier space.

    Parameters
    ----------
    fzyx_object : torch.Tensor
        Object with a leading channel axis ``f`` followed by (z, y, x).
    sfZYX_transfer_function : torch.Tensor
        Transfer function of shape (s, f, Z, Y, X) matching the object's
        spatial shape.
    intensity_to_stokes_matrix : torch.Tensor
        Currently unused; accepted for interface compatibility.
    scale : float, keyword-only
        Gain applied to the simulated data. Defaults to 50, the previously
        hard-coded value.

    Returns
    -------
    torch.Tensor
        Complex tensor of shape (s, z, y, x): ``scale * ifftn(sum_f H[s, f] *
        fftn(object[f]))``. Callers that need real data should take ``.real``.
    """
    fZYX_object = torch.fft.fftn(fzyx_object, dim=(1, 2, 3))
    # Voxel-wise matrix-vector product: data[s] = sum_f H[s, f] * object[f]
    sZYX_data = torch.einsum(
        "fzyx,sfzyx->szyx", fZYX_object, sfZYX_transfer_function
    )
    szyx_data = torch.fft.ifftn(sZYX_data, dim=(1, 2, 3))

    return scale * szyx_data
324
-
325
-
326
def apply_inverse_transfer_function(
    szyx_data: Tensor,
    singular_system: tuple[Tensor, Tensor, Tensor],
    intensity_to_stokes_matrix: Tensor,
    reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
    regularization_strength: float = 1e-3,
    TV_rho_strength: float = 1e-3,
    TV_iterations: int = 10,
):
    """Reconstruct the object from data with a Tikhonov-regularized inverse.

    For each voxel the forward model is ``H = U diag(S) Vh``, an s-by-f
    matrix. The Tikhonov solution ``(H^H H + lambda I)^-1 H^H`` equals
    ``Vh^H diag(S / (S**2 + lambda)) U^H``. Both singular-vector factors are
    therefore conjugated below; without that the filter is wrong for
    complex-valued transfer functions.

    Parameters
    ----------
    szyx_data : Tensor
        Data with a leading Stokes axis ``s`` followed by (z, y, x).
    singular_system : tuple[Tensor, Tensor, Tensor]
        ``(U, S, Vh)`` laid out as (s, j, Z, Y, X), (j, Z, Y, X) and
        (j, f, Z, Y, X), as returned by ``calculate_singular_system``.
    intensity_to_stokes_matrix : Tensor
        Currently unused; accepted for interface compatibility.
    reconstruction_algorithm : {"Tikhonov", "TV"}
        Only "Tikhonov" is implemented. Any other value emits a warning and
        falls back to Tikhonov.
    regularization_strength : float
        Tikhonov parameter ``lambda``.
    TV_rho_strength, TV_iterations
        Currently unused.

    Returns
    -------
    Tensor
        Real-valued reconstruction of shape (f, z, y, x).
    """
    if reconstruction_algorithm != "Tikhonov":
        import warnings

        warnings.warn(
            f"reconstruction_algorithm={reconstruction_algorithm!r} is not "
            "implemented for this model; falling back to Tikhonov.",
            stacklevel=2,
        )

    sZYX_data = torch.fft.fftn(szyx_data, dim=(1, 2, 3))

    # Key computation
    print("Computing inverse filter")
    U, S, Vh = singular_system
    # Regularized reciprocal singular values; cast so the einsum is single-dtype.
    S_reg = (S / (S**2 + regularization_strength)).to(U.dtype)

    # inverse[s, f] = sum_j conj(U[s, j]) * S_reg[j] * conj(Vh[j, f]),
    # i.e. the (f, s) entry of Vh^H diag(S_reg) U^H stored in (s, f) order.
    ZYXsf_inverse_filter = torch.einsum(
        "sjzyx,jzyx,jfzyx->sfzyx", torch.conj(U), S_reg, torch.conj(Vh)
    )

    # Apply inverse filter: object[f] = sum_s data[s] * inverse[s, f]
    fZYX_reconstructed = torch.einsum(
        "szyx,sfzyx->fzyx", sZYX_data, ZYXsf_inverse_filter
    )

    return torch.real(torch.fft.ifftn(fZYX_reconstructed, dim=(1, 2, 3)))
waveorder/sampling.py DELETED
@@ -1,94 +0,0 @@
1
- import numpy as np
2
- import torch
3
-
4
-
5
def transverse_nyquist(
    wavelength_emission,
    numerical_aperture_illumination,
    numerical_aperture_detection,
):
    """Return the transverse Nyquist sample spacing.

    For widefield label-free imaging the spacing is
    ``lambda / (2 * (NA_ill + NA_det))``. Widefield fluorescence is the
    special case ``NA_ill = NA_det = NA``, which gives ``lambda / (4 * NA)``.

    Parameters
    ----------
    wavelength_emission : float
        Wavelength; the result is expressed in the same units.
    numerical_aperture_illumination : float
        Illumination NA. For widefield fluorescence pass the detection NA.
    numerical_aperture_detection : float
        Detection NA.

    Returns
    -------
    float
        Largest transverse sample spacing that satisfies Nyquist.
    """
    combined_aperture = (
        numerical_aperture_detection + numerical_aperture_illumination
    )
    return wavelength_emission / (2 * combined_aperture)
36
-
37
-
38
def axial_nyquist(
    wavelength_emission,
    numerical_aperture_detection,
    index_of_refraction_media,
):
    """Axial Nyquist sample spacing in `wavelength_emission` units.

    For widefield microscopes, the axial cutoff frequency is::

        (n/lambda) - sqrt((n/lambda)^2 - (NA_det/lambda)^2)

    and the axial Nyquist sample spacing is ``1 / (2 * cutoff_frequency)``.
    Perhaps surprisingly, it does not depend on the illumination NA.

    The cutoff is evaluated in the algebraically equivalent form
    ``(NA_det/lambda)^2 / (n/lambda + sqrt((n/lambda)^2 - (NA_det/lambda)^2))``.
    This avoids catastrophic cancellation when ``NA_det << n``; the direct
    form can round the cutoff to exactly zero and return ``inf``.

    Parameters
    ----------
    wavelength_emission : float
        Output units match these units.
    numerical_aperture_detection : float
    index_of_refraction_media : float

    Returns
    -------
    float
        Axial Nyquist sample spacing. A detection NA of 0 gives an infinite
        spacing.

    Raises
    ------
    ValueError
        If ``numerical_aperture_detection`` exceeds
        ``index_of_refraction_media``. The square root would then be of a
        negative number.
    """
    if numerical_aperture_detection > index_of_refraction_media:
        raise ValueError(
            "numerical_aperture_detection must not exceed "
            "index_of_refraction_media "
            f"({numerical_aperture_detection} > {index_of_refraction_media})"
        )

    n_on_lambda = index_of_refraction_media / wavelength_emission
    na_on_lambda = numerical_aperture_detection / wavelength_emission
    # Rationalized form of n/lambda - sqrt((n/lambda)^2 - (NA/lambda)^2).
    cutoff_frequency = na_on_lambda**2 / (
        n_on_lambda + np.sqrt(n_on_lambda**2 - na_on_lambda**2)
    )
    return 1 / (2 * cutoff_frequency)
73
-
74
-
75
def nd_fourier_central_cuboid(source, target_shape):
    """Central cuboid of an N-D Fourier transform.

    ``source`` is assumed to be in standard FFT ordering (DC at index 0 on
    every axis). The function keeps the ``target_shape[i]`` lowest-frequency
    samples along each leading axis and returns them, again in FFT ordering.
    Axes beyond ``len(target_shape)`` are left untouched.

    The window along each axis starts at ``s // 2 - o // 2``. This places the
    DC term (which ``fftshift`` moves to ``s // 2``) at ``o // 2`` in the
    window, exactly where ``ifftshift`` expects it. The naive
    ``(s - o) // 2`` start is off by one when ``s`` is even and ``o`` is odd.

    Parameters
    ----------
    source : torch.Tensor
        Source tensor in FFT ordering.
    target_shape : tuple of int
        Desired size of each leading axis. Each entry must not exceed the
        corresponding source size.

    Returns
    -------
    torch.Tensor
        Center cuboid in Fourier space, in FFT ordering.

    Raises
    ------
    ValueError
        If any target size exceeds the matching source size.
    """
    center_slices = []
    for axis, (source_size, target_size) in enumerate(
        zip(source.shape, target_shape)
    ):
        if target_size > source_size:
            raise ValueError(
                f"target size {target_size} exceeds source size "
                f"{source_size} along axis {axis}"
            )
        start = source_size // 2 - target_size // 2
        center_slices.append(slice(start, start + target_size))

    # Axes kept at full length are unaffected: ifftshift undoes fftshift.
    centered = torch.fft.fftshift(source)[tuple(center_slices)]
    return torch.fft.ifftshift(centered)