waveorder 2.1.0__py3-none-any.whl → 2.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- waveorder/_version.py +2 -2
- waveorder/focus.py +36 -18
- waveorder/models/inplane_oriented_thick_pol3d.py +12 -12
- waveorder/models/inplane_oriented_thick_pol3d_vector.py +351 -0
- waveorder/models/isotropic_fluorescent_thick_3d.py +86 -33
- waveorder/models/isotropic_thin_3d.py +94 -32
- waveorder/models/phase_thick_3d.py +107 -63
- waveorder/optics.py +242 -28
- waveorder/sampling.py +94 -0
- waveorder/util.py +54 -2
- waveorder/{visual.py → visuals/jupyter_visuals.py} +2 -6
- waveorder/visuals/matplotlib_visuals.py +335 -0
- waveorder/visuals/napari_visuals.py +77 -0
- waveorder/visuals/utils.py +31 -0
- waveorder/waveorder_reconstructor.py +8 -7
- waveorder-2.2.0.dist-info/METADATA +186 -0
- waveorder-2.2.0.dist-info/RECORD +25 -0
- {waveorder-2.1.0.dist-info → waveorder-2.2.0.dist-info}/WHEEL +1 -1
- waveorder-2.1.0.dist-info/METADATA +0 -124
- waveorder-2.1.0.dist-info/RECORD +0 -20
- {waveorder-2.1.0.dist-info → waveorder-2.2.0.dist-info}/LICENSE +0 -0
- {waveorder-2.1.0.dist-info → waveorder-2.2.0.dist-info}/top_level.txt +0 -0
waveorder/_version.py
CHANGED
waveorder/focus.py
CHANGED
|
@@ -3,6 +3,7 @@ from typing import Literal, Optional
|
|
|
3
3
|
from waveorder import util
|
|
4
4
|
import matplotlib.pyplot as plt
|
|
5
5
|
import numpy as np
|
|
6
|
+
import warnings
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
def focus_from_transverse_band(
|
|
@@ -60,10 +61,19 @@ def focus_from_transverse_band(
|
|
|
60
61
|
>>> slice = focus_from_transverse_band(zyx_array, NA_det=0.55, lambda_ill=0.532, pixel_size=6.5/20)
|
|
61
62
|
>>> in_focus_data = data[slice,:,:]
|
|
62
63
|
"""
|
|
63
|
-
minmaxfunc =
|
|
64
|
-
|
|
64
|
+
minmaxfunc = _mode_to_minmaxfunc(mode)
|
|
65
|
+
|
|
66
|
+
_check_focus_inputs(
|
|
67
|
+
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
|
|
65
68
|
)
|
|
66
69
|
|
|
70
|
+
# Check for single slice
|
|
71
|
+
if zyx_array.shape[0] == 1:
|
|
72
|
+
warnings.warn(
|
|
73
|
+
"The dataset only contained a single slice. Returning trivial slice index = 0."
|
|
74
|
+
)
|
|
75
|
+
return 0
|
|
76
|
+
|
|
67
77
|
# Calculate coordinates
|
|
68
78
|
_, Y, X = zyx_array.shape
|
|
69
79
|
_, _, fxx, fyy = util.gen_coordinate((Y, X), pixel_size)
|
|
@@ -94,25 +104,35 @@ def focus_from_transverse_band(
|
|
|
94
104
|
# Plot
|
|
95
105
|
if plot_path is not None:
|
|
96
106
|
_plot_focus_metric(
|
|
97
|
-
plot_path,
|
|
107
|
+
plot_path,
|
|
108
|
+
midband_sum,
|
|
109
|
+
peak_index,
|
|
110
|
+
in_focus_index,
|
|
111
|
+
peak_results,
|
|
112
|
+
threshold_FWHM,
|
|
98
113
|
)
|
|
99
114
|
|
|
100
115
|
return in_focus_index
|
|
101
116
|
|
|
102
117
|
|
|
118
|
+
def _mode_to_minmaxfunc(mode):
|
|
119
|
+
if mode == "min":
|
|
120
|
+
minmaxfunc = np.argmin
|
|
121
|
+
elif mode == "max":
|
|
122
|
+
minmaxfunc = np.argmax
|
|
123
|
+
else:
|
|
124
|
+
raise ValueError("mode must be either `min` or `max`")
|
|
125
|
+
return minmaxfunc
|
|
126
|
+
|
|
127
|
+
|
|
103
128
|
def _check_focus_inputs(
|
|
104
|
-
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
|
|
129
|
+
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
|
|
105
130
|
):
|
|
106
131
|
N = len(zyx_array.shape)
|
|
107
132
|
if N != 3:
|
|
108
133
|
raise ValueError(
|
|
109
134
|
f"{N}D array supplied. `focus_from_transverse_band` only accepts 3D arrays."
|
|
110
135
|
)
|
|
111
|
-
if zyx_array.shape[0] == 1:
|
|
112
|
-
print(
|
|
113
|
-
"WARNING: The dataset only contained a single slice. Returning trivial slice index = 0."
|
|
114
|
-
)
|
|
115
|
-
return 0
|
|
116
136
|
|
|
117
137
|
if NA_det < 0:
|
|
118
138
|
raise ValueError("NA must be > 0")
|
|
@@ -121,7 +141,7 @@ def _check_focus_inputs(
|
|
|
121
141
|
if pixel_size < 0:
|
|
122
142
|
raise ValueError("pixel_size must be > 0")
|
|
123
143
|
if not 0.4 < lambda_ill / pixel_size < 10:
|
|
124
|
-
|
|
144
|
+
warnings.warn(
|
|
125
145
|
f"WARNING: lambda_ill/pixel_size = {lambda_ill/pixel_size}."
|
|
126
146
|
f"Did you use the same units?"
|
|
127
147
|
f"Did you enter the pixel size in (demagnified) object-space units?"
|
|
@@ -134,17 +154,15 @@ def _check_focus_inputs(
|
|
|
134
154
|
raise ValueError("midband_fractions[0] must be between 0 and 1")
|
|
135
155
|
if not (0 <= midband_fractions[1] <= 1):
|
|
136
156
|
raise ValueError("midband_fractions[1] must be between 0 and 1")
|
|
137
|
-
if mode == "min":
|
|
138
|
-
minmaxfunc = np.argmin
|
|
139
|
-
elif mode == "max":
|
|
140
|
-
minmaxfunc = np.argmax
|
|
141
|
-
else:
|
|
142
|
-
raise ValueError("mode must be either `min` or `max`")
|
|
143
|
-
return minmaxfunc
|
|
144
157
|
|
|
145
158
|
|
|
146
159
|
def _plot_focus_metric(
|
|
147
|
-
plot_path,
|
|
160
|
+
plot_path,
|
|
161
|
+
midband_sum,
|
|
162
|
+
peak_index,
|
|
163
|
+
in_focus_index,
|
|
164
|
+
peak_results,
|
|
165
|
+
threshold_FWHM,
|
|
148
166
|
):
|
|
149
167
|
_, ax = plt.subplots(1, 1, figsize=(4, 4))
|
|
150
168
|
ax.plot(midband_sum, "-k")
|
|
@@ -7,7 +7,7 @@ from torch import Tensor
|
|
|
7
7
|
from waveorder import correction, stokes, util
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
def generate_test_phantom(yx_shape):
|
|
10
|
+
def generate_test_phantom(yx_shape: Tuple[int, int]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
|
11
11
|
star, theta, _ = util.generate_star_target(yx_shape, blur_px=0.1)
|
|
12
12
|
retardance = 0.25 * star
|
|
13
13
|
orientation = (theta % np.pi) * (star > 1e-3)
|
|
@@ -17,13 +17,13 @@ def generate_test_phantom(yx_shape):
|
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
def calculate_transfer_function(
|
|
20
|
-
swing,
|
|
21
|
-
scheme,
|
|
22
|
-
):
|
|
20
|
+
swing: float,
|
|
21
|
+
scheme: str,
|
|
22
|
+
) -> Tensor:
|
|
23
23
|
return stokes.calculate_intensity_to_stokes_matrix(swing, scheme=scheme)
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
|
|
26
|
+
def visualize_transfer_function(viewer, intensity_to_stokes_matrix: Tensor) -> None:
|
|
27
27
|
viewer.add_image(
|
|
28
28
|
intensity_to_stokes_matrix.cpu().numpy(),
|
|
29
29
|
name="Intensity to stokes matrix",
|
|
@@ -31,12 +31,12 @@ def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
|
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
def apply_transfer_function(
|
|
34
|
-
retardance,
|
|
35
|
-
orientation,
|
|
36
|
-
transmittance,
|
|
37
|
-
depolarization,
|
|
38
|
-
intensity_to_stokes_matrix,
|
|
39
|
-
):
|
|
34
|
+
retardance: Tensor,
|
|
35
|
+
orientation: Tensor,
|
|
36
|
+
transmittance: Tensor,
|
|
37
|
+
depolarization: Tensor,
|
|
38
|
+
intensity_to_stokes_matrix: Tensor,
|
|
39
|
+
) -> Tensor:
|
|
40
40
|
stokes_params = stokes.stokes_after_adr(
|
|
41
41
|
retardance, orientation, transmittance, depolarization
|
|
42
42
|
)
|
|
@@ -59,7 +59,7 @@ def apply_inverse_transfer_function(
|
|
|
59
59
|
project_stokes_to_2d: bool = False,
|
|
60
60
|
flip_orientation: bool = False,
|
|
61
61
|
rotate_orientation: bool = False,
|
|
62
|
-
) -> Tuple[Tensor]:
|
|
62
|
+
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
|
63
63
|
"""Reconstructs retardance, orientation, transmittance, and depolarization
|
|
64
64
|
from czyx_data and an intensity_to_stokes_matrix, providing options for
|
|
65
65
|
background correction, projection, and orientation transformations.
|
|
@@ -0,0 +1,351 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
from typing import Literal
|
|
6
|
+
from torch.nn.functional import avg_pool3d, interpolate
|
|
7
|
+
from waveorder import optics, sampling, stokes, util
|
|
8
|
+
from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def generate_test_phantom(zyx_shape: tuple[int, int, int]) -> torch.Tensor:
|
|
12
|
+
# Simulate
|
|
13
|
+
yx_star, yx_theta, _ = util.generate_star_target(
|
|
14
|
+
yx_shape=zyx_shape[1:],
|
|
15
|
+
blur_px=1,
|
|
16
|
+
margin=50,
|
|
17
|
+
)
|
|
18
|
+
c00 = yx_star
|
|
19
|
+
c2_2 = -torch.sin(2 * yx_theta) * yx_star # torch.zeros_like(c00)
|
|
20
|
+
c22 = -torch.cos(2 * yx_theta) * yx_star # torch.zeros_like(c00) #
|
|
21
|
+
|
|
22
|
+
# Put in a center slices of a 3D object
|
|
23
|
+
center_slice_object = torch.stack((c00, c2_2, c22), dim=0)
|
|
24
|
+
object = torch.zeros((3,) + zyx_shape)
|
|
25
|
+
object[:, zyx_shape[0] // 2, ...] = center_slice_object
|
|
26
|
+
return object
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def calculate_transfer_function(
|
|
30
|
+
swing: float,
|
|
31
|
+
scheme: str,
|
|
32
|
+
zyx_shape: tuple[int, int, int],
|
|
33
|
+
yx_pixel_size: float,
|
|
34
|
+
z_pixel_size: float,
|
|
35
|
+
wavelength_illumination: float,
|
|
36
|
+
z_padding: int,
|
|
37
|
+
index_of_refraction_media: float,
|
|
38
|
+
numerical_aperture_illumination: float,
|
|
39
|
+
numerical_aperture_detection: float,
|
|
40
|
+
invert_phase_contrast: bool = False,
|
|
41
|
+
fourier_oversample_factor: int = 1,
|
|
42
|
+
transverse_downsample_factor: int = 1,
|
|
43
|
+
) -> tuple[
|
|
44
|
+
torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
|
45
|
+
]:
|
|
46
|
+
if z_padding != 0:
|
|
47
|
+
raise NotImplementedError("Padding not implemented for this model")
|
|
48
|
+
|
|
49
|
+
transverse_nyquist = sampling.transverse_nyquist(
|
|
50
|
+
wavelength_illumination,
|
|
51
|
+
numerical_aperture_illumination,
|
|
52
|
+
numerical_aperture_detection,
|
|
53
|
+
)
|
|
54
|
+
axial_nyquist = sampling.axial_nyquist(
|
|
55
|
+
wavelength_illumination,
|
|
56
|
+
numerical_aperture_detection,
|
|
57
|
+
index_of_refraction_media,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
|
|
61
|
+
z_factor = int(np.ceil(z_pixel_size / axial_nyquist))
|
|
62
|
+
|
|
63
|
+
print("YX factor:", yx_factor)
|
|
64
|
+
print("Z factor:", z_factor)
|
|
65
|
+
|
|
66
|
+
tf_calculation_shape = (
|
|
67
|
+
zyx_shape[0] * z_factor * fourier_oversample_factor,
|
|
68
|
+
int(
|
|
69
|
+
np.ceil(
|
|
70
|
+
zyx_shape[1]
|
|
71
|
+
* yx_factor
|
|
72
|
+
* fourier_oversample_factor
|
|
73
|
+
/ transverse_downsample_factor
|
|
74
|
+
)
|
|
75
|
+
),
|
|
76
|
+
int(
|
|
77
|
+
np.ceil(
|
|
78
|
+
zyx_shape[2]
|
|
79
|
+
* yx_factor
|
|
80
|
+
* fourier_oversample_factor
|
|
81
|
+
/ transverse_downsample_factor
|
|
82
|
+
)
|
|
83
|
+
),
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
sfZYX_transfer_function, intensity_to_stokes_matrix = (
|
|
87
|
+
_calculate_wrap_unsafe_transfer_function(
|
|
88
|
+
swing,
|
|
89
|
+
scheme,
|
|
90
|
+
tf_calculation_shape,
|
|
91
|
+
yx_pixel_size / yx_factor,
|
|
92
|
+
z_pixel_size / z_factor,
|
|
93
|
+
wavelength_illumination,
|
|
94
|
+
z_padding,
|
|
95
|
+
index_of_refraction_media,
|
|
96
|
+
numerical_aperture_illumination,
|
|
97
|
+
numerical_aperture_detection,
|
|
98
|
+
invert_phase_contrast=invert_phase_contrast,
|
|
99
|
+
)
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
# avg_pool3d does not support complex numbers
|
|
103
|
+
pooled_sfZYX_transfer_function_real = avg_pool3d(
|
|
104
|
+
sfZYX_transfer_function.real, (fourier_oversample_factor,) * 3
|
|
105
|
+
)
|
|
106
|
+
pooled_sfZYX_transfer_function_imag = avg_pool3d(
|
|
107
|
+
sfZYX_transfer_function.imag, (fourier_oversample_factor,) * 3
|
|
108
|
+
)
|
|
109
|
+
pooled_sfZYX_transfer_function = (
|
|
110
|
+
pooled_sfZYX_transfer_function_real
|
|
111
|
+
+ 1j * pooled_sfZYX_transfer_function_imag
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
# Crop to original size
|
|
115
|
+
sfzyx_out_shape = (
|
|
116
|
+
pooled_sfZYX_transfer_function.shape[0],
|
|
117
|
+
pooled_sfZYX_transfer_function.shape[1],
|
|
118
|
+
zyx_shape[0] + 2 * z_padding,
|
|
119
|
+
) + zyx_shape[1:]
|
|
120
|
+
|
|
121
|
+
cropped = sampling.nd_fourier_central_cuboid(
|
|
122
|
+
pooled_sfZYX_transfer_function, sfzyx_out_shape
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
# Compute singular system on cropped and downsampled
|
|
126
|
+
U, S, Vh = calculate_singular_system(cropped)
|
|
127
|
+
|
|
128
|
+
# Interpolate to final size in YX
|
|
129
|
+
def complex_interpolate(
|
|
130
|
+
tensor: torch.Tensor, zyx_shape: tuple[int, int, int]
|
|
131
|
+
) -> torch.Tensor:
|
|
132
|
+
interpolated_real = interpolate(tensor.real, size=zyx_shape)
|
|
133
|
+
interpolated_imag = interpolate(tensor.imag, size=zyx_shape)
|
|
134
|
+
return interpolated_real + 1j * interpolated_imag
|
|
135
|
+
|
|
136
|
+
full_cropped = complex_interpolate(cropped, zyx_shape)
|
|
137
|
+
full_U = complex_interpolate(U, zyx_shape)
|
|
138
|
+
full_S = interpolate(S[None], size=zyx_shape)[0] # S is real
|
|
139
|
+
full_Vh = complex_interpolate(Vh, zyx_shape)
|
|
140
|
+
|
|
141
|
+
return (
|
|
142
|
+
full_cropped,
|
|
143
|
+
intensity_to_stokes_matrix,
|
|
144
|
+
(full_U, full_S, full_Vh),
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _calculate_wrap_unsafe_transfer_function(
|
|
149
|
+
swing,
|
|
150
|
+
scheme,
|
|
151
|
+
zyx_shape,
|
|
152
|
+
yx_pixel_size,
|
|
153
|
+
z_pixel_size,
|
|
154
|
+
wavelength_illumination,
|
|
155
|
+
z_padding,
|
|
156
|
+
index_of_refraction_media,
|
|
157
|
+
numerical_aperture_illumination,
|
|
158
|
+
numerical_aperture_detection,
|
|
159
|
+
invert_phase_contrast=False,
|
|
160
|
+
):
|
|
161
|
+
print("Computing transfer function")
|
|
162
|
+
intensity_to_stokes_matrix = stokes.calculate_intensity_to_stokes_matrix(
|
|
163
|
+
swing, scheme=scheme
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
input_jones = torch.tensor([0.0 - 1.0j, 1.0 + 0j]) # circular
|
|
167
|
+
# input_jones = torch.tensor([0 + 0j, 1 + 0j]) # linear
|
|
168
|
+
|
|
169
|
+
# Calculate frequencies
|
|
170
|
+
y_frequencies, x_frequencies = util.generate_frequencies(
|
|
171
|
+
zyx_shape[1:], yx_pixel_size
|
|
172
|
+
)
|
|
173
|
+
radial_frequencies = torch.sqrt(x_frequencies**2 + y_frequencies**2)
|
|
174
|
+
|
|
175
|
+
z_total = zyx_shape[0] + 2 * z_padding
|
|
176
|
+
z_position_list = torch.fft.ifftshift(
|
|
177
|
+
(torch.arange(z_total) - z_total // 2) * z_pixel_size
|
|
178
|
+
)
|
|
179
|
+
if (
|
|
180
|
+
not invert_phase_contrast
|
|
181
|
+
): # opposite sign of direct phase reconstruction
|
|
182
|
+
z_position_list = torch.flip(z_position_list, dims=(0,))
|
|
183
|
+
z_frequencies = torch.fft.fftfreq(z_total, d=z_pixel_size)
|
|
184
|
+
|
|
185
|
+
# 2D pupils
|
|
186
|
+
print("\tCalculating pupils...")
|
|
187
|
+
ill_pupil = optics.generate_pupil(
|
|
188
|
+
radial_frequencies,
|
|
189
|
+
numerical_aperture_illumination,
|
|
190
|
+
wavelength_illumination,
|
|
191
|
+
)
|
|
192
|
+
det_pupil = optics.generate_pupil(
|
|
193
|
+
radial_frequencies,
|
|
194
|
+
numerical_aperture_detection,
|
|
195
|
+
wavelength_illumination,
|
|
196
|
+
)
|
|
197
|
+
pupil = optics.generate_pupil(
|
|
198
|
+
radial_frequencies,
|
|
199
|
+
index_of_refraction_media, # largest possible NA
|
|
200
|
+
wavelength_illumination,
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
# Defocus pupils
|
|
204
|
+
defocus_pupil = optics.generate_propagation_kernel(
|
|
205
|
+
radial_frequencies,
|
|
206
|
+
pupil,
|
|
207
|
+
wavelength_illumination / index_of_refraction_media,
|
|
208
|
+
z_position_list,
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
# Calculate vector defocus pupils
|
|
212
|
+
S = optics.generate_vector_source_defocus_pupil(
|
|
213
|
+
x_frequencies,
|
|
214
|
+
y_frequencies,
|
|
215
|
+
z_position_list,
|
|
216
|
+
defocus_pupil,
|
|
217
|
+
input_jones,
|
|
218
|
+
ill_pupil,
|
|
219
|
+
wavelength_illumination / index_of_refraction_media,
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
# Simplified scalar pupil
|
|
223
|
+
P = optics.generate_propagation_kernel(
|
|
224
|
+
radial_frequencies,
|
|
225
|
+
det_pupil,
|
|
226
|
+
wavelength_illumination / index_of_refraction_media,
|
|
227
|
+
z_position_list,
|
|
228
|
+
)
|
|
229
|
+
|
|
230
|
+
P_3D = torch.abs(torch.fft.ifft(P, dim=-3)).type(torch.complex64)
|
|
231
|
+
S_3D = torch.fft.ifft(S, dim=-3)
|
|
232
|
+
|
|
233
|
+
print("\tCalculating greens tensor spectrum...")
|
|
234
|
+
G_3D = optics.generate_greens_tensor_spectrum(
|
|
235
|
+
zyx_shape=(z_total, zyx_shape[1], zyx_shape[2]),
|
|
236
|
+
zyx_pixel_size=(z_pixel_size, yx_pixel_size, yx_pixel_size),
|
|
237
|
+
wavelength=wavelength_illumination / index_of_refraction_media,
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
# Main part
|
|
241
|
+
PG_3D = torch.einsum("zyx,ipzyx->ipzyx", P_3D, G_3D)
|
|
242
|
+
PS_3D = torch.einsum("zyx,jzyx,kzyx->jkzyx", P_3D, S_3D, torch.conj(S_3D))
|
|
243
|
+
|
|
244
|
+
del P_3D, G_3D, S_3D
|
|
245
|
+
|
|
246
|
+
print("\tComputing pg and ps...")
|
|
247
|
+
pg = torch.fft.fftn(PG_3D, dim=(-3, -2, -1))
|
|
248
|
+
ps = torch.fft.fftn(PS_3D, dim=(-3, -2, -1))
|
|
249
|
+
|
|
250
|
+
del PG_3D, PS_3D
|
|
251
|
+
|
|
252
|
+
print("\tComputing H1 and H2...")
|
|
253
|
+
H1 = torch.fft.ifftn(
|
|
254
|
+
torch.einsum("ipzyx,jkzyx->ijpkzyx", pg, torch.conj(ps)),
|
|
255
|
+
dim=(-3, -2, -1),
|
|
256
|
+
)
|
|
257
|
+
|
|
258
|
+
H2 = torch.fft.ifftn(
|
|
259
|
+
torch.einsum("ikzyx,jpzyx->ijpkzyx", ps, torch.conj(pg)),
|
|
260
|
+
dim=(-3, -2, -1),
|
|
261
|
+
)
|
|
262
|
+
|
|
263
|
+
H_re = H1[1:, 1:] + H2[1:, 1:] # drop data-side z components
|
|
264
|
+
# H_im = 1j * (H1 - H2) # ignore absorptive terms
|
|
265
|
+
|
|
266
|
+
del H1, H2
|
|
267
|
+
|
|
268
|
+
H_re /= torch.amax(torch.abs(H_re))
|
|
269
|
+
|
|
270
|
+
s = util.pauli()[[0, 1, 2, 3]] # select s0, s1, and s2
|
|
271
|
+
Y = util.gellmann()[[0, 4, 8]]
|
|
272
|
+
# select phase f00 and transverse linear isotropic terms 2-2, and f22
|
|
273
|
+
|
|
274
|
+
print("\tComputing final transfer function...")
|
|
275
|
+
sfZYX_transfer_function = torch.einsum(
|
|
276
|
+
"sik,ikpjzyx,lpj->slzyx", s, H_re, Y
|
|
277
|
+
)
|
|
278
|
+
return (
|
|
279
|
+
sfZYX_transfer_function,
|
|
280
|
+
intensity_to_stokes_matrix,
|
|
281
|
+
)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def calculate_singular_system(sfZYX_transfer_function):
|
|
285
|
+
# Compute regularized inverse filter
|
|
286
|
+
print("Computing SVD")
|
|
287
|
+
ZYXsf_transfer_function = sfZYX_transfer_function.permute(2, 3, 4, 0, 1)
|
|
288
|
+
U, S, Vh = torch.linalg.svd(ZYXsf_transfer_function, full_matrices=False)
|
|
289
|
+
singular_system = (
|
|
290
|
+
U.permute(3, 4, 0, 1, 2),
|
|
291
|
+
S.permute(3, 0, 1, 2),
|
|
292
|
+
Vh.permute(3, 4, 0, 1, 2),
|
|
293
|
+
)
|
|
294
|
+
return singular_system
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def visualize_transfer_function(
|
|
298
|
+
viewer: "napari.Viewer",
|
|
299
|
+
sfZYX_transfer_function: torch.Tensor,
|
|
300
|
+
zyx_scale: tuple[float, float, float],
|
|
301
|
+
) -> None:
|
|
302
|
+
add_transfer_function_to_viewer(
|
|
303
|
+
viewer,
|
|
304
|
+
sfZYX_transfer_function,
|
|
305
|
+
zyx_scale=zyx_scale,
|
|
306
|
+
layer_name="Transfer Function",
|
|
307
|
+
complex_rgb=True,
|
|
308
|
+
clim_factor=0.5,
|
|
309
|
+
)
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def apply_transfer_function(
|
|
313
|
+
fzyx_object: torch.Tensor,
|
|
314
|
+
sfZYX_transfer_function: torch.Tensor,
|
|
315
|
+
intensity_to_stokes_matrix: torch.Tensor, # TODO use this to simulate intensities
|
|
316
|
+
) -> torch.Tensor:
|
|
317
|
+
fZYX_object = torch.fft.fftn(fzyx_object, dim=(1, 2, 3))
|
|
318
|
+
sZYX_data = torch.einsum(
|
|
319
|
+
"fzyx,sfzyx->szyx", fZYX_object, sfZYX_transfer_function
|
|
320
|
+
)
|
|
321
|
+
szyx_data = torch.fft.ifftn(sZYX_data, dim=(1, 2, 3))
|
|
322
|
+
|
|
323
|
+
return 50 * szyx_data # + 0.1 * torch.randn(szyx_data.shape)
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def apply_inverse_transfer_function(
|
|
327
|
+
szyx_data: Tensor,
|
|
328
|
+
singular_system: tuple[Tensor],
|
|
329
|
+
intensity_to_stokes_matrix: Tensor,
|
|
330
|
+
reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
|
|
331
|
+
regularization_strength: float = 1e-3,
|
|
332
|
+
TV_rho_strength: float = 1e-3,
|
|
333
|
+
TV_iterations: int = 10,
|
|
334
|
+
):
|
|
335
|
+
sZYX_data = torch.fft.fftn(szyx_data, dim=(1, 2, 3))
|
|
336
|
+
|
|
337
|
+
# Key computation
|
|
338
|
+
print("Computing inverse filter")
|
|
339
|
+
U, S, Vh = singular_system
|
|
340
|
+
S_reg = S / (S**2 + regularization_strength)
|
|
341
|
+
|
|
342
|
+
ZYXsf_inverse_filter = torch.einsum(
|
|
343
|
+
"sjzyx,jzyx,jfzyx->sfzyx", U, S_reg, Vh
|
|
344
|
+
)
|
|
345
|
+
|
|
346
|
+
# Apply inverse filter
|
|
347
|
+
fZYX_reconstructed = torch.einsum(
|
|
348
|
+
"szyx,sfzyx->fzyx", sZYX_data, ZYXsf_inverse_filter
|
|
349
|
+
)
|
|
350
|
+
|
|
351
|
+
return torch.real(torch.fft.ifftn(fZYX_reconstructed, dim=(1, 2, 3)))
|
|
@@ -1,17 +1,19 @@
|
|
|
1
1
|
from typing import Literal
|
|
2
2
|
|
|
3
|
+
import numpy as np
|
|
3
4
|
import torch
|
|
4
5
|
from torch import Tensor
|
|
5
6
|
|
|
6
|
-
from waveorder import optics, util
|
|
7
|
+
from waveorder import optics, sampling, util
|
|
8
|
+
from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
def generate_test_phantom(
|
|
10
|
-
zyx_shape,
|
|
11
|
-
yx_pixel_size,
|
|
12
|
-
z_pixel_size,
|
|
13
|
-
sphere_radius,
|
|
14
|
-
):
|
|
12
|
+
zyx_shape: tuple[int, int, int],
|
|
13
|
+
yx_pixel_size: float,
|
|
14
|
+
z_pixel_size: float,
|
|
15
|
+
sphere_radius: float,
|
|
16
|
+
) -> Tensor:
|
|
15
17
|
sphere, _, _ = util.generate_sphere_target(
|
|
16
18
|
zyx_shape, yx_pixel_size, z_pixel_size, sphere_radius
|
|
17
19
|
)
|
|
@@ -20,14 +22,57 @@ def generate_test_phantom(
|
|
|
20
22
|
|
|
21
23
|
|
|
22
24
|
def calculate_transfer_function(
|
|
23
|
-
zyx_shape,
|
|
24
|
-
yx_pixel_size,
|
|
25
|
-
z_pixel_size,
|
|
26
|
-
wavelength_emission,
|
|
27
|
-
z_padding,
|
|
28
|
-
index_of_refraction_media,
|
|
29
|
-
numerical_aperture_detection,
|
|
30
|
-
):
|
|
25
|
+
zyx_shape: tuple[int, int, int],
|
|
26
|
+
yx_pixel_size: float,
|
|
27
|
+
z_pixel_size: float,
|
|
28
|
+
wavelength_emission: float,
|
|
29
|
+
z_padding: int,
|
|
30
|
+
index_of_refraction_media: float,
|
|
31
|
+
numerical_aperture_detection: float,
|
|
32
|
+
) -> Tensor:
|
|
33
|
+
|
|
34
|
+
transverse_nyquist = sampling.transverse_nyquist(
|
|
35
|
+
wavelength_emission,
|
|
36
|
+
numerical_aperture_detection, # ill = det for fluorescence
|
|
37
|
+
numerical_aperture_detection,
|
|
38
|
+
)
|
|
39
|
+
axial_nyquist = sampling.axial_nyquist(
|
|
40
|
+
wavelength_emission,
|
|
41
|
+
numerical_aperture_detection,
|
|
42
|
+
index_of_refraction_media,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
|
|
46
|
+
z_factor = int(np.ceil(z_pixel_size / axial_nyquist))
|
|
47
|
+
|
|
48
|
+
optical_transfer_function = _calculate_wrap_unsafe_transfer_function(
|
|
49
|
+
(
|
|
50
|
+
zyx_shape[0] * z_factor,
|
|
51
|
+
zyx_shape[1] * yx_factor,
|
|
52
|
+
zyx_shape[2] * yx_factor,
|
|
53
|
+
),
|
|
54
|
+
yx_pixel_size / yx_factor,
|
|
55
|
+
z_pixel_size / z_factor,
|
|
56
|
+
wavelength_emission,
|
|
57
|
+
z_padding,
|
|
58
|
+
index_of_refraction_media,
|
|
59
|
+
numerical_aperture_detection,
|
|
60
|
+
)
|
|
61
|
+
zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:]
|
|
62
|
+
return sampling.nd_fourier_central_cuboid(
|
|
63
|
+
optical_transfer_function, zyx_out_shape
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _calculate_wrap_unsafe_transfer_function(
|
|
68
|
+
zyx_shape: tuple[int, int, int],
|
|
69
|
+
yx_pixel_size: float,
|
|
70
|
+
z_pixel_size: float,
|
|
71
|
+
wavelength_emission: float,
|
|
72
|
+
z_padding: int,
|
|
73
|
+
index_of_refraction_media: float,
|
|
74
|
+
numerical_aperture_detection: float,
|
|
75
|
+
) -> Tensor:
|
|
31
76
|
radial_frequencies = util.generate_radial_frequencies(
|
|
32
77
|
zyx_shape[1:], yx_pixel_size
|
|
33
78
|
)
|
|
@@ -63,25 +108,33 @@ def calculate_transfer_function(
|
|
|
63
108
|
return optical_transfer_function
|
|
64
109
|
|
|
65
110
|
|
|
66
|
-
def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale):
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
lim = 0.1 * torch.max(torch.abs(array[0]))
|
|
74
|
-
viewer.add_image(
|
|
75
|
-
torch.fft.ifftshift(array[0]).cpu().numpy(),
|
|
76
|
-
name=array[1],
|
|
77
|
-
colormap="bwr",
|
|
78
|
-
contrast_limits=(-lim, lim),
|
|
79
|
-
scale=1 / zyx_scale,
|
|
80
|
-
)
|
|
81
|
-
viewer.dims.order = (0, 1, 2)
|
|
111
|
+
def visualize_transfer_function(viewer, optical_transfer_function: Tensor, zyx_scale: tuple[float, float, float]) -> None:
|
|
112
|
+
add_transfer_function_to_viewer(
|
|
113
|
+
viewer,
|
|
114
|
+
torch.real(optical_transfer_function),
|
|
115
|
+
zyx_scale,
|
|
116
|
+
clim_factor=0.05,
|
|
117
|
+
)
|
|
82
118
|
|
|
83
119
|
|
|
84
|
-
def apply_transfer_function(
|
|
120
|
+
def apply_transfer_function(
|
|
121
|
+
zyx_object: Tensor, optical_transfer_function: Tensor, z_padding: int, background: int = 10
|
|
122
|
+
) -> Tensor:
|
|
123
|
+
"""Simulate imaging by applying a transfer function
|
|
124
|
+
|
|
125
|
+
Parameters
|
|
126
|
+
----------
|
|
127
|
+
zyx_object : torch.Tensor
|
|
128
|
+
optical_transfer_function : torch.Tensor
|
|
129
|
+
z_padding : int
|
|
130
|
+
background : int, optional
|
|
131
|
+
constant background counts added to each voxel, by default 10
|
|
132
|
+
|
|
133
|
+
Returns
|
|
134
|
+
-------
|
|
135
|
+
Simulated data : torch.Tensor
|
|
136
|
+
|
|
137
|
+
"""
|
|
85
138
|
if (
|
|
86
139
|
zyx_object.shape[0] + 2 * z_padding
|
|
87
140
|
!= optical_transfer_function.shape[0]
|
|
@@ -99,7 +152,7 @@ def apply_transfer_function(zyx_object, optical_transfer_function, z_padding):
|
|
|
99
152
|
zyx_data = zyx_obj_hat * optical_transfer_function
|
|
100
153
|
data = torch.real(torch.fft.ifftn(zyx_data))
|
|
101
154
|
|
|
102
|
-
data +=
|
|
155
|
+
data += background # Add a direct background
|
|
103
156
|
return data
|
|
104
157
|
|
|
105
158
|
|
|
@@ -111,7 +164,7 @@ def apply_inverse_transfer_function(
|
|
|
111
164
|
regularization_strength: float = 1e-3,
|
|
112
165
|
TV_rho_strength: float = 1e-3,
|
|
113
166
|
TV_iterations: int = 10,
|
|
114
|
-
):
|
|
167
|
+
) -> Tensor:
|
|
115
168
|
"""Reconstructs fluorescence density from zyx_data and
|
|
116
169
|
an optical_transfer_function, providing options for z padding and
|
|
117
170
|
reconstruction algorithms.
|