waveorder 2.2.0rc0__py3-none-any.whl → 2.2.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- waveorder/_version.py +9 -4
- waveorder/background_estimator.py +2 -2
- waveorder/correction.py +1 -1
- waveorder/filter.py +206 -0
- waveorder/focus.py +5 -3
- waveorder/models/inplane_oriented_thick_pol3d.py +16 -12
- waveorder/models/inplane_oriented_thick_pol3d_vector.py +320 -0
- waveorder/models/isotropic_fluorescent_thick_3d.py +85 -38
- waveorder/models/isotropic_thin_3d.py +96 -31
- waveorder/models/phase_thick_3d.py +106 -45
- waveorder/optics.py +233 -26
- waveorder/reconstruct.py +28 -0
- waveorder/sampling.py +94 -0
- waveorder/stokes.py +4 -3
- waveorder/util.py +61 -9
- waveorder/{visual.py → visuals/jupyter_visuals.py} +19 -26
- waveorder/visuals/matplotlib_visuals.py +335 -0
- waveorder/visuals/napari_visuals.py +76 -0
- waveorder/visuals/utils.py +30 -0
- waveorder/waveorder_reconstructor.py +18 -16
- waveorder/waveorder_simulator.py +6 -6
- waveorder-2.2.1b0.dist-info/METADATA +187 -0
- waveorder-2.2.1b0.dist-info/RECORD +27 -0
- {waveorder-2.2.0rc0.dist-info → waveorder-2.2.1b0.dist-info}/WHEEL +1 -1
- waveorder-2.2.0rc0.dist-info/METADATA +0 -147
- waveorder-2.2.0rc0.dist-info/RECORD +0 -20
- {waveorder-2.2.0rc0.dist-info → waveorder-2.2.1b0.dist-info}/LICENSE +0 -0
- {waveorder-2.2.0rc0.dist-info → waveorder-2.2.1b0.dist-info}/top_level.txt +0 -0
waveorder/_version.py
CHANGED
|
@@ -1,8 +1,13 @@
|
|
|
1
|
-
# file generated by
|
|
1
|
+
# file generated by setuptools-scm
|
|
2
2
|
# don't change, don't track in version control
|
|
3
|
+
|
|
4
|
+
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
|
|
5
|
+
|
|
3
6
|
TYPE_CHECKING = False
|
|
4
7
|
if TYPE_CHECKING:
|
|
5
|
-
from typing import Tuple
|
|
8
|
+
from typing import Tuple
|
|
9
|
+
from typing import Union
|
|
10
|
+
|
|
6
11
|
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
|
7
12
|
else:
|
|
8
13
|
VERSION_TUPLE = object
|
|
@@ -12,5 +17,5 @@ __version__: str
|
|
|
12
17
|
__version_tuple__: VERSION_TUPLE
|
|
13
18
|
version_tuple: VERSION_TUPLE
|
|
14
19
|
|
|
15
|
-
__version__ = version = '2.2.
|
|
16
|
-
__version_tuple__ = version_tuple = (2, 2,
|
|
20
|
+
__version__ = version = '2.2.1b0'
|
|
21
|
+
__version_tuple__ = version_tuple = (2, 2, 1)
|
waveorder/correction.py
CHANGED
waveorder/filter.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def apply_filter_bank(
    io_filter_bank: torch.Tensor,
    i_input_array: torch.Tensor,
) -> torch.Tensor:
    """
    Applies a filter bank to an input array.

    io_filter_bank.shape must be smaller or equal to i_input_array.shape in all
    spatial dimensions. When io_filter_bank is smaller, it is effectively
    "stretched" to apply the filter.

    io_filter_bank is in "wrapped" format, i.e., the zero frequency is the
    zeroth element.

    i_input_array and io_filter_bank must have inverse sample spacing, i.e.,
    if i_input_array contains samples spaced by dx, then io_filter_bank must
    have extent 1/dx. Note that there is no need for io_filter_bank to have
    sample spacing 1/(n*dx) because io_filter_bank will be stretched.

    Parameters
    ----------
    io_filter_bank : torch.Tensor
        The filter bank to be applied in the frequency domain.
        The spatial extent of io_filter_bank must be 1/dx, where dx is the
        sample spacing of i_input_array.

        Leading dimensions are the input and output dimensions.
        io_filter_bank.shape[:2] == (num_input_channels, num_output_channels)

        Trailing dimensions are spatial frequency dimensions.
        io_filter_bank.shape[2:] == (Z', Y', X') or (Y', X')

    i_input_array : torch.Tensor
        The real-valued input array with sample spacing dx to be filtered.

        Leading dimension is the input dimension, matching the filter bank.
        i_input_array.shape[0] == num_input_channels

        Trailing dimensions are spatial dimensions.
        i_input_array.shape[1:] == (Z, Y, X) or (Y, X)

    Returns
    -------
    torch.Tensor
        The filtered real-valued output array with shape
        (num_output_channels, Z, Y, X) or (num_output_channels, Y, X).
        Only the real part of the inverse FFT is returned.

    Raises
    ------
    ValueError
        If the number of spatial dimensions differ, if any spatial dimension
        of io_filter_bank exceeds the matching one of i_input_array, or if
        the numbers of input channels differ.
    """
    # Check ranks first: the per-dimension comparison below uses zip, which
    # would silently truncate (and compare the wrong axes) on a rank mismatch.
    if io_filter_bank.ndim - i_input_array.ndim != 1:
        raise ValueError(
            "io_filter_bank and i_input_array must have the same number of spatial dimensions."
        )

    # Ensure all spatial dimensions of the filter are <= those of the input
    if any(
        t > i
        for t, i in zip(io_filter_bank.shape[2:], i_input_array.shape[1:])
    ):
        raise ValueError(
            "All spatial dimensions of io_filter_bank must be <= i_input_array."
        )

    # Ensure the input-channel dimensions match
    if io_filter_bank.shape[0] != i_input_array.shape[0]:
        raise ValueError(
            "io_filter_bank.shape[0] and i_input_array.shape[0] must be the same."
        )

    num_input_channels, num_output_channels = io_filter_bank.shape[:2]
    filter_spatial_shape = io_filter_bank.shape[2:]
    input_spatial_shape = i_input_array.shape[1:]

    # Pad each spatial dimension of the input up to the next multiple of the
    # filter's size so stretched_multiply can tile it exactly.
    # torch.nn.functional.pad takes (before, after) pairs starting from the
    # LAST dimension, hence the reversed iteration.
    pad_sizes = [
        (0, (t - (i % t)) % t)
        for t, i in zip(filter_spatial_shape[::-1], input_spatial_shape[::-1])
    ]
    flat_pad_sizes = list(itertools.chain.from_iterable(pad_sizes))
    padded_input_array = torch.nn.functional.pad(i_input_array, flat_pad_sizes)

    # Apply the transfer function in the frequency domain
    fft_dims = list(range(1, i_input_array.ndim))
    padded_input_spectrum = torch.fft.fftn(padded_input_array, dim=fft_dims)

    # Matrix-vector multiplication over the channel axes.
    # If this is a bottleneck, consider extending `stretched_multiply` to
    # a `stretched_matrix_multiply` that uses a call like
    # torch.einsum('io..., i... -> o...', io_filter_bank, padded_input_spectrum)
    #
    # Further optimization is likely with a combination of
    # torch.baddbmm, torch.pixel_shuffle, torch.pixel_unshuffle.
    #
    # The accumulator lives on the padded INPUT grid: stretched_multiply
    # returns arrays shaped like its large (input) argument, not the filter.
    padded_output_spectrum = torch.zeros(
        (num_output_channels,) + tuple(padded_input_spectrum.shape[1:]),
        dtype=padded_input_spectrum.dtype,
        device=padded_input_spectrum.device,
    )
    for input_channel_idx in range(num_input_channels):
        for output_channel_idx in range(num_output_channels):
            padded_output_spectrum[output_channel_idx] += stretched_multiply(
                io_filter_bank[input_channel_idx, output_channel_idx],
                padded_input_spectrum[input_channel_idx],
            )

    # Cast to real, ignoring imaginary part
    padded_result = torch.real(
        torch.fft.ifftn(padded_output_spectrum, dim=fft_dims)
    )

    # Remove padding from the spatial dimensions only; every output channel
    # is kept (the channel count may differ from the input's).
    slices = (slice(None),) + tuple(slice(0, i) for i in input_spatial_shape)
    return padded_result[slices]
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def stretched_multiply(
    small_array: torch.Tensor, large_array: torch.Tensor
) -> torch.Tensor:
    """
    Multiply large_array by small_array "stretched" to large_array's shape.

    Every dimension of large_array must be an integer multiple of the matching
    dimension of small_array, and both arrays must have the same number of
    dimensions. Instead of upsampling small_array, large_array is viewed as a
    grid of equal blocks, one block per element of small_array, and each
    block is scaled by its corresponding small_array element. For example, a
    3x3 small_array and a 99x99 large_array give a 3x3 grid of 33x33 blocks.

    Works for any number of dimensions.

    Parameters
    ----------
    small_array : torch.Tensor
        Array whose elements are stretched over blocks of large_array.
    large_array : torch.Tensor
        Array split into blocks and multiplied by small_array.

    Returns
    -------
    torch.Tensor
        Product with the same shape as large_array.

    Raises
    ------
    ValueError
        If a dimension of large_array is not divisible by the matching
        dimension of small_array, or if the numbers of dimensions differ.

    Example
    -------
    small_array = torch.tensor([[1, 2],
                                [3, 4]])

    large_array = torch.tensor([[1, 2, 3, 4],
                                [5, 6, 7, 8],
                                [9, 10, 11, 12],
                                [13, 14, 15, 16]])

    stretched_multiply(small_array, large_array) returns

    [[ 1,  2,  6,  8],
     [ 5,  6, 14, 16],
     [27, 30, 44, 48],
     [39, 42, 60, 64]]
    """
    small_shape = small_array.shape
    large_shape = large_array.shape

    if not all(
        big % little == 0 for little, big in zip(small_shape, large_shape)
    ):
        raise ValueError(
            "Each dimension of large_array must be divisible by each dimension of small_array"
        )

    if small_array.ndim != large_array.ndim:
        raise ValueError(
            "small_array and large_array must have the same number of dimensions"
        )

    # Interleave axes so broadcasting scales whole blocks:
    #   large -> (n0, block0, n1, block1, ...)
    #   small -> (n0, 1,      n1, 1,      ...)
    blocked_large_shape = []
    blocked_small_shape = []
    for little, big in zip(small_shape, large_shape):
        blocked_large_shape.extend((little, big // little))
        blocked_small_shape.extend((little, 1))

    blocked_product = large_array.reshape(
        tuple(blocked_large_shape)
    ) * small_array.reshape(tuple(blocked_small_shape))

    # Collapse the interleaved axes back to the original layout.
    return blocked_product.reshape(large_shape)
|
waveorder/focus.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
|
1
|
-
|
|
1
|
+
import warnings
|
|
2
2
|
from typing import Literal, Optional
|
|
3
|
-
|
|
3
|
+
|
|
4
4
|
import matplotlib.pyplot as plt
|
|
5
5
|
import numpy as np
|
|
6
|
-
import
|
|
6
|
+
from scipy.signal import peak_widths
|
|
7
|
+
|
|
8
|
+
from waveorder import util
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
def focus_from_transverse_band(
|
|
@@ -7,7 +7,9 @@ from torch import Tensor
|
|
|
7
7
|
from waveorder import correction, stokes, util
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
def generate_test_phantom(
|
|
10
|
+
def generate_test_phantom(
|
|
11
|
+
yx_shape: Tuple[int, int],
|
|
12
|
+
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
|
11
13
|
star, theta, _ = util.generate_star_target(yx_shape, blur_px=0.1)
|
|
12
14
|
retardance = 0.25 * star
|
|
13
15
|
orientation = (theta % np.pi) * (star > 1e-3)
|
|
@@ -17,13 +19,15 @@ def generate_test_phantom(yx_shape):
|
|
|
17
19
|
|
|
18
20
|
|
|
19
21
|
def calculate_transfer_function(
    swing: float,
    scheme: str,
) -> Tensor:
    """Return the intensity-to-Stokes matrix used as this model's transfer function.

    The result is exactly the output of
    ``stokes.calculate_intensity_to_stokes_matrix`` for the given swing and
    scheme.
    """
    intensity_to_stokes_matrix = stokes.calculate_intensity_to_stokes_matrix(
        swing, scheme=scheme
    )
    return intensity_to_stokes_matrix
|
|
24
26
|
|
|
25
27
|
|
|
26
|
-
def visualize_transfer_function(
|
|
28
|
+
def visualize_transfer_function(
|
|
29
|
+
viewer, intensity_to_stokes_matrix: Tensor
|
|
30
|
+
) -> None:
|
|
27
31
|
viewer.add_image(
|
|
28
32
|
intensity_to_stokes_matrix.cpu().numpy(),
|
|
29
33
|
name="Intensity to stokes matrix",
|
|
@@ -31,12 +35,12 @@ def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
|
|
|
31
35
|
|
|
32
36
|
|
|
33
37
|
def apply_transfer_function(
|
|
34
|
-
retardance,
|
|
35
|
-
orientation,
|
|
36
|
-
transmittance,
|
|
37
|
-
depolarization,
|
|
38
|
-
intensity_to_stokes_matrix,
|
|
39
|
-
):
|
|
38
|
+
retardance: Tensor,
|
|
39
|
+
orientation: Tensor,
|
|
40
|
+
transmittance: Tensor,
|
|
41
|
+
depolarization: Tensor,
|
|
42
|
+
intensity_to_stokes_matrix: Tensor,
|
|
43
|
+
) -> Tensor:
|
|
40
44
|
stokes_params = stokes.stokes_after_adr(
|
|
41
45
|
retardance, orientation, transmittance, depolarization
|
|
42
46
|
)
|
|
@@ -59,7 +63,7 @@ def apply_inverse_transfer_function(
|
|
|
59
63
|
project_stokes_to_2d: bool = False,
|
|
60
64
|
flip_orientation: bool = False,
|
|
61
65
|
rotate_orientation: bool = False,
|
|
62
|
-
) -> Tuple[Tensor]:
|
|
66
|
+
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
|
63
67
|
"""Reconstructs retardance, orientation, transmittance, and depolarization
|
|
64
68
|
from czyx_data and an intensity_to_stokes_matrix, providing options for
|
|
65
69
|
background correction, projection, and orientation transformations.
|
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
from torch.nn.functional import avg_pool3d
|
|
7
|
+
|
|
8
|
+
from waveorder import optics, sampling, stokes, util
|
|
9
|
+
from waveorder.filter import apply_filter_bank
|
|
10
|
+
from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def generate_test_phantom(zyx_shape: tuple[int, int, int]) -> torch.Tensor:
    """Simulate a star-target test object for the vector polarization model.

    A 2D star target from ``util.generate_star_target`` is placed in the
    central z slice of an otherwise-zero volume. Three object channels are
    stacked, in this order:

    - the star itself,
    - ``-sin(2 * yx_theta) * star``,
    - ``-cos(2 * yx_theta) * star``,

    where ``yx_theta`` is the second output of ``util.generate_star_target``.

    Parameters
    ----------
    zyx_shape : tuple[int, int, int]
        Shape of the output volume. Any length-3 sequence of ints is accepted.

    Returns
    -------
    torch.Tensor
        Object with shape ``(3,) + tuple(zyx_shape)``. Every z slice except
        ``zyx_shape[0] // 2`` is zero.
    """
    # Normalize so list inputs also work with the tuple concatenation below.
    zyx_shape = tuple(zyx_shape)

    yx_star, yx_theta, _ = util.generate_star_target(
        yx_shape=zyx_shape[1:],
        blur_px=1,
        margin=50,
    )
    c00 = yx_star
    c2_2 = -torch.sin(2 * yx_theta) * yx_star
    c22 = -torch.cos(2 * yx_theta) * yx_star

    # Put the 2D object in the center slice of a 3D volume
    center_slice_object = torch.stack((c00, c2_2, c22), dim=0)
    fzyx_object = torch.zeros((3,) + zyx_shape)
    fzyx_object[:, zyx_shape[0] // 2, ...] = center_slice_object
    return fzyx_object
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def calculate_transfer_function(
    swing: float,
    scheme: str,
    zyx_shape: tuple[int, int, int],
    yx_pixel_size: float,
    z_pixel_size: float,
    wavelength_illumination: float,
    z_padding: int,
    index_of_refraction_media: float,
    numerical_aperture_illumination: float,
    numerical_aperture_detection: float,
    invert_phase_contrast: bool = False,
    fourier_oversample_factor: int = 1,
) -> tuple[
    torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]
]:
    """Calculate the vector transfer function and its singular system.

    The transfer function is first evaluated on a refined grid. Z and YX are
    refined by integer factors derived from the axial and transverse Nyquist
    limits, and every axis is further multiplied by
    ``fourier_oversample_factor``. The result is then average-pooled by
    ``fourier_oversample_factor`` and cropped in Fourier space back to
    ``zyx_shape``.

    Returns
    -------
    tuple
        ``(transfer_function, intensity_to_stokes_matrix, singular_system)``.
        ``transfer_function`` is the pooled and cropped array, and
        ``singular_system`` is ``calculate_singular_system(transfer_function)``.

    Raises
    ------
    NotImplementedError
        If ``z_padding`` is nonzero.
    """
    if z_padding != 0:
        raise NotImplementedError("Padding not implemented for this model")

    nyquist_yx = sampling.transverse_nyquist(
        wavelength_illumination,
        numerical_aperture_illumination,
        numerical_aperture_detection,
    )
    nyquist_z = sampling.axial_nyquist(
        wavelength_illumination,
        numerical_aperture_detection,
        index_of_refraction_media,
    )

    # Integer refinement factors so the calculation grid samples at or below
    # the Nyquist limits.
    yx_factor = int(np.ceil(yx_pixel_size / nyquist_yx))
    z_factor = int(np.ceil(z_pixel_size / nyquist_z))

    print("YX factor:", yx_factor)
    print("Z factor:", z_factor)

    oversample = fourier_oversample_factor
    calculation_shape = (
        zyx_shape[0] * z_factor * oversample,
        int(np.ceil(zyx_shape[1] * yx_factor * oversample)),
        int(np.ceil(zyx_shape[2] * yx_factor * oversample)),
    )

    fine_transfer_function, intensity_to_stokes_matrix = (
        _calculate_wrap_unsafe_transfer_function(
            swing,
            scheme,
            calculation_shape,
            yx_pixel_size / yx_factor,
            z_pixel_size / z_factor,
            wavelength_illumination,
            z_padding,
            index_of_refraction_media,
            numerical_aperture_illumination,
            numerical_aperture_detection,
            invert_phase_contrast=invert_phase_contrast,
        )
    )

    # avg_pool3d does not support complex numbers, so the real and imaginary
    # parts are pooled separately and recombined.
    pool_kernel = (fourier_oversample_factor,) * 3
    pooled_real, pooled_imag = (
        avg_pool3d(component, pool_kernel)
        for component in (
            fine_transfer_function.real,
            fine_transfer_function.imag,
        )
    )
    pooled_transfer_function = pooled_real + 1j * pooled_imag

    # Crop back to the requested grid; the two leading axes are kept whole.
    output_shape = (
        pooled_transfer_function.shape[0],
        pooled_transfer_function.shape[1],
        zyx_shape[0] + 2 * z_padding,
    ) + zyx_shape[1:]
    cropped_transfer_function = sampling.nd_fourier_central_cuboid(
        pooled_transfer_function, output_shape
    )

    # Singular system is computed on the cropped, downsampled array.
    return (
        cropped_transfer_function,
        intensity_to_stokes_matrix,
        calculate_singular_system(cropped_transfer_function),
    )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _calculate_wrap_unsafe_transfer_function(
    swing,
    scheme,
    zyx_shape,
    yx_pixel_size,
    z_pixel_size,
    wavelength_illumination,
    z_padding,
    index_of_refraction_media,
    numerical_aperture_illumination,
    numerical_aperture_detection,
    invert_phase_contrast=False,
):
    """Compute the vector transfer function directly on the given grid.

    No pooling or cropping happens here. ``calculate_transfer_function`` calls
    this on a refined grid and then pools and crops the result.

    The intermediate ``H_re`` term is normalized by its maximum absolute
    value before being projected onto the ``util.pauli()`` and
    ``util.gellmann()`` selections.

    Returns
    -------
    tuple
        ``(sfZYX_transfer_function, intensity_to_stokes_matrix)``.
    """
    print("Computing transfer function")
    intensity_to_stokes_matrix = stokes.calculate_intensity_to_stokes_matrix(
        swing, scheme=scheme
    )

    # Circularly polarized input field. A linear input state would be
    # torch.tensor([0 + 0j, 1 + 0j]).
    input_jones = torch.tensor([0.0 - 1.0j, 1.0 + 0j])

    # Calculate transverse frequencies
    y_frequencies, x_frequencies = util.generate_frequencies(
        zyx_shape[1:], yx_pixel_size
    )
    radial_frequencies = torch.sqrt(x_frequencies**2 + y_frequencies**2)

    # Axial positions, in wrapped (ifftshift) order
    z_total = zyx_shape[0] + 2 * z_padding
    z_position_list = torch.fft.ifftshift(
        (torch.arange(z_total) - z_total // 2) * z_pixel_size
    )
    if not invert_phase_contrast:
        # opposite sign of direct phase reconstruction
        z_position_list = torch.flip(z_position_list, dims=(0,))

    # 2D pupils
    print("\tCalculating pupils...")
    ill_pupil = optics.generate_pupil(
        radial_frequencies,
        numerical_aperture_illumination,
        wavelength_illumination,
    )
    det_pupil = optics.generate_pupil(
        radial_frequencies,
        numerical_aperture_detection,
        wavelength_illumination,
    )
    pupil = optics.generate_pupil(
        radial_frequencies,
        index_of_refraction_media,  # largest possible NA
        wavelength_illumination,
    )

    # Wavelength in the medium, shared by propagation and the Green's tensor.
    wavelength_in_medium = wavelength_illumination / index_of_refraction_media

    # Defocus pupils
    defocus_pupil = optics.generate_propagation_kernel(
        radial_frequencies,
        pupil,
        wavelength_in_medium,
        z_position_list,
    )

    # Vector defocus pupils for the input polarization state
    S = optics.generate_vector_source_defocus_pupil(
        x_frequencies,
        y_frequencies,
        z_position_list,
        defocus_pupil,
        input_jones,
        ill_pupil,
        wavelength_in_medium,
    )

    # Simplified scalar detection pupil
    P = optics.generate_propagation_kernel(
        radial_frequencies,
        det_pupil,
        wavelength_in_medium,
        z_position_list,
    )

    # 1D inverse FFT along axis -3
    P_3D = torch.abs(torch.fft.ifft(P, dim=-3)).type(torch.complex64)
    S_3D = torch.fft.ifft(S, dim=-3)

    print("\tCalculating greens tensor spectrum...")
    G_3D = optics.generate_greens_tensor_spectrum(
        zyx_shape=(z_total, zyx_shape[1], zyx_shape[2]),
        zyx_pixel_size=(z_pixel_size, yx_pixel_size, yx_pixel_size),
        wavelength=wavelength_in_medium,
    )

    # Main part
    PG_3D = torch.einsum("zyx,ipzyx->ipzyx", P_3D, G_3D)
    PS_3D = torch.einsum("zyx,jzyx,kzyx->jkzyx", P_3D, S_3D, torch.conj(S_3D))

    # Release large intermediates before the next allocation-heavy step.
    del P_3D, G_3D, S_3D

    print("\tComputing pg and ps...")
    pg = torch.fft.fftn(PG_3D, dim=(-3, -2, -1))
    ps = torch.fft.fftn(PS_3D, dim=(-3, -2, -1))

    del PG_3D, PS_3D

    print("\tComputing H1 and H2...")
    H1 = torch.fft.ifftn(
        torch.einsum("ipzyx,jkzyx->ijpkzyx", pg, torch.conj(ps)),
        dim=(-3, -2, -1),
    )
    H2 = torch.fft.ifftn(
        torch.einsum("ikzyx,jpzyx->ijpkzyx", ps, torch.conj(pg)),
        dim=(-3, -2, -1),
    )

    H_re = H1[1:, 1:] + H2[1:, 1:]  # drop data-side z components
    # The absorptive term, 1j * (H1 - H2), is intentionally ignored.

    del H1, H2

    H_re /= torch.amax(torch.abs(H_re))

    # Four entries (indices 0-3) of util.pauli() form the "s" axis.
    s = util.pauli()[[0, 1, 2, 3]]
    # select phase f00 and transverse linear isotropic terms 2-2, and f22
    Y = util.gellmann()[[0, 4, 8]]

    print("\tComputing final transfer function...")
    sfZYX_transfer_function = torch.einsum(
        "sik,ikpjzyx,lpj->slzyx", s, H_re, Y
    )
    return (
        sfZYX_transfer_function,
        intensity_to_stokes_matrix,
    )
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def calculate_singular_system(sfZYX_transfer_function):
    """Compute a per-voxel reduced SVD of a transfer function.

    Parameters
    ----------
    sfZYX_transfer_function : torch.Tensor
        5D tensor indexed ``(s, f, Z, Y, X)``. For each ``(Z, Y, X)`` the
        ``(s, f)`` matrix is decomposed.

    Returns
    -------
    tuple of torch.Tensor
        ``(U, S, Vh)`` from ``torch.linalg.svd(..., full_matrices=False)``,
        with the matrix axes moved in front of ``(Z, Y, X)``:
        ``U`` is ``(s, k, Z, Y, X)``, ``S`` is ``(k, Z, Y, X)`` and ``Vh`` is
        ``(k, f, Z, Y, X)``, where ``k = min(s, f)``.
    """
    print("Computing SVD")
    # torch.linalg.svd decomposes the trailing two axes, so move (s, f) last.
    matrices_last = sfZYX_transfer_function.permute(2, 3, 4, 0, 1)
    left_vectors, singular_values, right_vectors_h = torch.linalg.svd(
        matrices_last, full_matrices=False
    )
    # Move the matrix / singular-value axes back to the front.
    return (
        left_vectors.permute(3, 4, 0, 1, 2),
        singular_values.permute(3, 0, 1, 2),
        right_vectors_h.permute(3, 4, 0, 1, 2),
    )
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def visualize_transfer_function(
    viewer: "napari.Viewer",
    sfZYX_transfer_function: torch.Tensor,
    zyx_scale: tuple[float, float, float],
) -> None:
    """Add the transfer function to a napari viewer.

    Thin wrapper around ``add_transfer_function_to_viewer``. It uses the
    layer name "Transfer Function", complex RGB rendering and
    ``clim_factor=0.5``.
    """
    layer_options = {
        "zyx_scale": zyx_scale,
        "layer_name": "Transfer Function",
        "complex_rgb": True,
        "clim_factor": 0.5,
    }
    add_transfer_function_to_viewer(
        viewer, sfZYX_transfer_function, **layer_options
    )
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def apply_transfer_function(
    fzyx_object: torch.Tensor,
    sfZYX_transfer_function: torch.Tensor,
    intensity_to_stokes_matrix: torch.Tensor,  # TODO use this to simulate intensities
) -> torch.Tensor:
    """Simulate data by applying the transfer function to an object.

    The object is Fourier transformed over axes (1, 2, 3). Each data channel
    ``s`` is formed as ``sum_f H[s, f] * F(object_f)`` and transformed back.
    The result is scaled by a fixed gain of 50.

    The returned tensor is the (complex) inverse FFT; no real part is taken
    and no noise is added. ``intensity_to_stokes_matrix`` is currently unused.
    """
    spatial_axes = (1, 2, 3)
    object_spectrum = torch.fft.fftn(fzyx_object, dim=spatial_axes)
    data_spectrum = torch.einsum(
        "fzyx,sfzyx->szyx", object_spectrum, sfZYX_transfer_function
    )
    simulated_data = torch.fft.ifftn(data_spectrum, dim=spatial_axes)
    return 50 * simulated_data
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def apply_inverse_transfer_function(
    szyx_data: Tensor,
    singular_system: tuple[Tensor, Tensor, Tensor],
    intensity_to_stokes_matrix: Tensor,
    reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
    regularization_strength: float = 1e-3,
    TV_rho_strength: float = 1e-3,
    TV_iterations: int = 10,
) -> Tensor:
    """Reconstruct object channels from data with a Tikhonov-regularized filter.

    Parameters
    ----------
    szyx_data : Tensor
        Data with the channel axis first. It is passed to
        ``apply_filter_bank`` as the input array.
    singular_system : tuple[Tensor, Tensor, Tensor]
        ``(U, S, Vh)`` as returned by ``calculate_singular_system``.
    intensity_to_stokes_matrix : Tensor
        Currently unused.
    reconstruction_algorithm : {"Tikhonov", "TV"}
        Only Tikhonov regularization is implemented for this model. Any other
        value emits a ``UserWarning`` and falls back to Tikhonov.
    regularization_strength : float
        Tikhonov parameter; each singular value ``S`` becomes
        ``S / (S**2 + regularization_strength)``.
    TV_rho_strength, TV_iterations
        Currently unused (TV is not implemented).

    Returns
    -------
    Tensor
        Output of ``apply_filter_bank`` with the inverse filter.
    """
    import warnings

    if reconstruction_algorithm != "Tikhonov":
        # Previously this fell back to Tikhonov silently; keep the fallback
        # but make it visible.
        warnings.warn(
            f"reconstruction_algorithm={reconstruction_algorithm!r} is not "
            "implemented for this model; using Tikhonov regularization instead.",
            stacklevel=2,
        )

    # Key computation
    print("Computing inverse filter")
    U, S, Vh = singular_system
    # Regularized reciprocal of each singular value
    S_reg = S / (S**2 + regularization_strength)
    # NOTE(review): this combines U, S_reg and Vh without conjugation. For a
    # complex transfer function the textbook Tikhonov inverse is
    # conj(Vh)^T diag(S_reg) conj(U)^T, i.e. the conjugate of this filter
    # when indexed (s, f). Confirm this matches the forward model's
    # conjugation convention.
    sfzyx_inverse_filter = torch.einsum(
        "sjzyx,jzyx,jfzyx->sfzyx", U, S_reg, Vh
    )

    fzyx_recon = apply_filter_bank(sfzyx_inverse_filter, szyx_data)

    return fzyx_recon
|