waveorder 2.0.0rc3__py3-none-any.whl → 2.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- waveorder/_version.py +14 -2
- waveorder/correction.py +107 -0
- waveorder/focus.py +36 -18
- waveorder/models/inplane_oriented_thick_pol3d.py +15 -17
- waveorder/models/inplane_oriented_thick_pol3d_vector.py +351 -0
- waveorder/models/isotropic_fluorescent_thick_3d.py +86 -33
- waveorder/models/isotropic_thin_3d.py +113 -29
- waveorder/models/phase_thick_3d.py +107 -63
- waveorder/optics.py +243 -29
- waveorder/sampling.py +94 -0
- waveorder/stokes.py +2 -2
- waveorder/util.py +54 -2
- waveorder/{visual.py → visuals/jupyter_visuals.py} +2 -6
- waveorder/visuals/matplotlib_visuals.py +335 -0
- waveorder/visuals/napari_visuals.py +77 -0
- waveorder/visuals/utils.py +31 -0
- waveorder/waveorder_reconstructor.py +8 -7
- waveorder-2.2.0.dist-info/METADATA +186 -0
- waveorder-2.2.0.dist-info/RECORD +25 -0
- {waveorder-2.0.0rc3.dist-info → waveorder-2.2.0.dist-info}/WHEEL +1 -1
- waveorder-2.0.0rc3.dist-info/METADATA +0 -129
- waveorder-2.0.0rc3.dist-info/RECORD +0 -19
- {waveorder-2.0.0rc3.dist-info → waveorder-2.2.0.dist-info}/LICENSE +0 -0
- {waveorder-2.0.0rc3.dist-info → waveorder-2.2.0.dist-info}/top_level.txt +0 -0
waveorder/_version.py
CHANGED
|
@@ -1,4 +1,16 @@
|
|
|
1
1
|
# file generated by setuptools_scm
|
|
2
2
|
# don't change, don't track in version control
|
|
3
|
-
|
|
4
|
-
|
|
3
|
+
TYPE_CHECKING = False
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from typing import Tuple, Union
|
|
6
|
+
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
|
7
|
+
else:
|
|
8
|
+
VERSION_TUPLE = object
|
|
9
|
+
|
|
10
|
+
version: str
|
|
11
|
+
__version__: str
|
|
12
|
+
__version_tuple__: VERSION_TUPLE
|
|
13
|
+
version_tuple: VERSION_TUPLE
|
|
14
|
+
|
|
15
|
+
__version__ = version = '2.2.0'
|
|
16
|
+
__version_tuple__ = version_tuple = (2, 2, 0)
|
waveorder/correction.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""Background correction methods"""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
from torch import Tensor, Size
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _sample_block_medians(image: Tensor, block_size: int) -> Tensor:
    """
    Sample densely tiled square blocks from a 2D image and return their medians.
    Incomplete blocks (overhangs) will be ignored.

    Parameters
    ----------
    image : Tensor
        2D image. Images with a non-floating-point dtype are converted to
        ``torch.float`` before sampling.
    block_size : int
        Width and height of the blocks

    Returns
    -------
    Tensor
        Median intensity values for each block, flattened in row-major block
        order (as produced by ``F.unfold``). For blocks with an even number of
        pixels, torch returns the lower of the two middle values.
    """
    if not image.dtype.is_floating_point:
        # ``Tensor.to`` returns a new tensor; it must be re-bound to take effect
        image = image.to(torch.float)
    # (1, 1, H, W) -> (block_size**2, n_blocks): one column per complete block
    blocks = F.unfold(image[None, None], block_size, stride=block_size)[0]
    return blocks.median(0)[0]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _grid_coordinates(image: Tensor, block_size: int) -> Tensor:
    """
    Build image coordinates from the center points of square blocks.

    Blocks are densely tiled from the origin with a stride of ``block_size``;
    trailing partial blocks are not included.

    Parameters
    ----------
    image : Tensor
        2D image; only its shape and device are used
    block_size : int
        Width and height of the blocks

    Returns
    -------
    Tensor
        (n_blocks, 2) tensor of (row, column) block-center coordinates,
        enumerated in row-major block order
    """
    axis_centers = [
        torch.arange(
            block_size / 2,
            boundary - block_size / 2 + 1,
            block_size,
            device=image.device,
        )
        for boundary in image.shape
    ]
    # Explicit "ij" indexing (torch's current default) avoids the meshgrid
    # deprecation warning and keeps the (row, column) ordering.
    coords = torch.meshgrid(axis_centers, indexing="ij")
    return torch.stack(coords, dim=-1).reshape(-1, 2)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _fit_2d_polynomial_surface(
    coords: Tensor, values: Tensor, order: int, surface_shape: Size
) -> Tensor:
    """Fit a 2D polynomial to a set of coordinates and their values,
    and return the surface evaluated at every point.

    Parameters
    ----------
    coords : Tensor
        (N, 2) sample coordinates as (row, column). They are cast to
        ``values.dtype`` so the least-squares system has a single dtype.
    values : Tensor
        (N,) sampled values; expected to be floating point
    order : int
        Maximum total degree of the polynomial (terms y**i * x**j, i + j <= order)
    surface_shape : Size
        (height, width) of the dense grid on which the fit is evaluated

    Returns
    -------
    Tensor
        Fitted surface with shape ``surface_shape`` and dtype ``values.dtype``

    Raises
    ------
    ValueError
        If there are not more samples than polynomial coefficients
    """
    # number of monomials y**i * x**j with i + j <= order
    n_coeffs = (order + 1) * (order + 2) // 2
    if n_coeffs >= len(values):
        raise ValueError(
            f"Cannot fit a {order} degree 2D polynomial "
            f"with {len(values)} sampled values"
        )
    # lstsq requires identical dtypes; block-center coordinates are created in
    # the default float dtype, which may differ from (e.g. float64) values.
    coords = coords.to(dtype=values.dtype)

    orders = torch.arange(order + 1, device=coords.device)
    order_pairs = torch.stack(torch.meshgrid(orders, orders, indexing="ij"), -1)
    order_pairs = order_pairs[order_pairs.sum(-1) <= order].reshape(-1, 2).tolist()

    terms = torch.stack(
        [coords[:, 0] ** i * coords[:, 1] ** j for i, j in order_pairs], -1
    )
    # use "gels" driver for precision and GPU consistency
    coeffs = torch.linalg.lstsq(terms, values, driver="gels").solution

    dense_coords = torch.meshgrid(
        [
            torch.arange(s, dtype=values.dtype, device=values.device)
            for s in surface_shape
        ],
        indexing="ij",
    )
    dense_terms = torch.stack(
        [dense_coords[0] ** i * dense_coords[1] ** j for i, j in order_pairs],
        -1,
    )
    return torch.matmul(dense_terms, coeffs)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def estimate_background(
    image: Tensor, order: int = 2, block_size: int = 32
) -> Tensor:
    """
    Combine sampling and polynomial surface fit for background estimation.
    To background correct an image, divide it by the background.

    The image is tiled into ``block_size`` x ``block_size`` blocks, the median
    of each complete block is taken, and a 2D polynomial of total degree
    ``order`` is least-squares fitted to those medians at the block centers.

    Parameters
    ----------
    image : Tensor
        2D image
    order : int, optional
        Order of polynomial, by default 2
    block_size : int, optional
        Width and height of the blocks, by default 32

    Returns
    -------
    Tensor
        Background image with the same shape as ``image``

    Raises
    ------
    ValueError
        If ``image`` is not 2D, ``block_size`` is not positive, or
        ``block_size`` exceeds the image height or width
    """
    if image.ndim != 2:
        raise ValueError(f"Image must be 2D, got shape {image.shape}")
    if block_size < 1:
        raise ValueError(f"Block size must be positive, got {block_size}")
    height, width = image.shape
    if block_size > height:
        raise ValueError("Block size larger than image height")
    if block_size > width:
        raise ValueError("Block size larger than image width")
    medians = _sample_block_medians(image, block_size)
    coords = _grid_coordinates(image, block_size)
    return _fit_2d_polynomial_surface(coords, medians, order, image.shape)
|
waveorder/focus.py
CHANGED
|
@@ -3,6 +3,7 @@ from typing import Literal, Optional
|
|
|
3
3
|
from waveorder import util
|
|
4
4
|
import matplotlib.pyplot as plt
|
|
5
5
|
import numpy as np
|
|
6
|
+
import warnings
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
def focus_from_transverse_band(
|
|
@@ -60,10 +61,19 @@ def focus_from_transverse_band(
|
|
|
60
61
|
>>> slice = focus_from_transverse_band(zyx_array, NA_det=0.55, lambda_ill=0.532, pixel_size=6.5/20)
|
|
61
62
|
>>> in_focus_data = data[slice,:,:]
|
|
62
63
|
"""
|
|
63
|
-
minmaxfunc =
|
|
64
|
-
|
|
64
|
+
minmaxfunc = _mode_to_minmaxfunc(mode)
|
|
65
|
+
|
|
66
|
+
_check_focus_inputs(
|
|
67
|
+
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
|
|
65
68
|
)
|
|
66
69
|
|
|
70
|
+
# Check for single slice
|
|
71
|
+
if zyx_array.shape[0] == 1:
|
|
72
|
+
warnings.warn(
|
|
73
|
+
"The dataset only contained a single slice. Returning trivial slice index = 0."
|
|
74
|
+
)
|
|
75
|
+
return 0
|
|
76
|
+
|
|
67
77
|
# Calculate coordinates
|
|
68
78
|
_, Y, X = zyx_array.shape
|
|
69
79
|
_, _, fxx, fyy = util.gen_coordinate((Y, X), pixel_size)
|
|
@@ -94,25 +104,35 @@ def focus_from_transverse_band(
|
|
|
94
104
|
# Plot
|
|
95
105
|
if plot_path is not None:
|
|
96
106
|
_plot_focus_metric(
|
|
97
|
-
plot_path,
|
|
107
|
+
plot_path,
|
|
108
|
+
midband_sum,
|
|
109
|
+
peak_index,
|
|
110
|
+
in_focus_index,
|
|
111
|
+
peak_results,
|
|
112
|
+
threshold_FWHM,
|
|
98
113
|
)
|
|
99
114
|
|
|
100
115
|
return in_focus_index
|
|
101
116
|
|
|
102
117
|
|
|
118
|
+
def _mode_to_minmaxfunc(mode):
    """Translate a focus-search ``mode`` into the NumPy arg-reduction to use.

    Parameters
    ----------
    mode : str
        ``"min"`` selects ``np.argmin``; ``"max"`` selects ``np.argmax``.

    Returns
    -------
    callable
        ``np.argmin`` or ``np.argmax``.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``"min"`` nor ``"max"``.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be either `min` or `max`")
    return np.argmin if mode == "min" else np.argmax
|
|
126
|
+
|
|
127
|
+
|
|
103
128
|
def _check_focus_inputs(
|
|
104
|
-
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
|
|
129
|
+
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
|
|
105
130
|
):
|
|
106
131
|
N = len(zyx_array.shape)
|
|
107
132
|
if N != 3:
|
|
108
133
|
raise ValueError(
|
|
109
134
|
f"{N}D array supplied. `focus_from_transverse_band` only accepts 3D arrays."
|
|
110
135
|
)
|
|
111
|
-
if zyx_array.shape[0] == 1:
|
|
112
|
-
print(
|
|
113
|
-
"WARNING: The dataset only contained a single slice. Returning trivial slice index = 0."
|
|
114
|
-
)
|
|
115
|
-
return 0
|
|
116
136
|
|
|
117
137
|
if NA_det < 0:
|
|
118
138
|
raise ValueError("NA must be > 0")
|
|
@@ -121,7 +141,7 @@ def _check_focus_inputs(
|
|
|
121
141
|
if pixel_size < 0:
|
|
122
142
|
raise ValueError("pixel_size must be > 0")
|
|
123
143
|
if not 0.4 < lambda_ill / pixel_size < 10:
|
|
124
|
-
|
|
144
|
+
warnings.warn(
|
|
125
145
|
f"WARNING: lambda_ill/pixel_size = {lambda_ill/pixel_size}."
|
|
126
146
|
f"Did you use the same units?"
|
|
127
147
|
f"Did you enter the pixel size in (demagnified) object-space units?"
|
|
@@ -134,17 +154,15 @@ def _check_focus_inputs(
|
|
|
134
154
|
raise ValueError("midband_fractions[0] must be between 0 and 1")
|
|
135
155
|
if not (0 <= midband_fractions[1] <= 1):
|
|
136
156
|
raise ValueError("midband_fractions[1] must be between 0 and 1")
|
|
137
|
-
if mode == "min":
|
|
138
|
-
minmaxfunc = np.argmin
|
|
139
|
-
elif mode == "max":
|
|
140
|
-
minmaxfunc = np.argmax
|
|
141
|
-
else:
|
|
142
|
-
raise ValueError("mode must be either `min` or `max`")
|
|
143
|
-
return minmaxfunc
|
|
144
157
|
|
|
145
158
|
|
|
146
159
|
def _plot_focus_metric(
|
|
147
|
-
plot_path,
|
|
160
|
+
plot_path,
|
|
161
|
+
midband_sum,
|
|
162
|
+
peak_index,
|
|
163
|
+
in_focus_index,
|
|
164
|
+
peak_results,
|
|
165
|
+
threshold_FWHM,
|
|
148
166
|
):
|
|
149
167
|
_, ax = plt.subplots(1, 1, figsize=(4, 4))
|
|
150
168
|
ax.plot(midband_sum, "-k")
|
|
@@ -4,10 +4,10 @@ import numpy as np
|
|
|
4
4
|
import torch
|
|
5
5
|
from torch import Tensor
|
|
6
6
|
|
|
7
|
-
from waveorder import
|
|
7
|
+
from waveorder import correction, stokes, util
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
def generate_test_phantom(yx_shape):
|
|
10
|
+
def generate_test_phantom(yx_shape: Tuple[int, int]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
|
11
11
|
star, theta, _ = util.generate_star_target(yx_shape, blur_px=0.1)
|
|
12
12
|
retardance = 0.25 * star
|
|
13
13
|
orientation = (theta % np.pi) * (star > 1e-3)
|
|
@@ -17,13 +17,13 @@ def generate_test_phantom(yx_shape):
|
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
def calculate_transfer_function(
|
|
20
|
-
swing,
|
|
21
|
-
scheme,
|
|
22
|
-
):
|
|
20
|
+
swing: float,
|
|
21
|
+
scheme: str,
|
|
22
|
+
) -> Tensor:
|
|
23
23
|
return stokes.calculate_intensity_to_stokes_matrix(swing, scheme=scheme)
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
|
|
26
|
+
def visualize_transfer_function(viewer, intensity_to_stokes_matrix: Tensor) -> None:
|
|
27
27
|
viewer.add_image(
|
|
28
28
|
intensity_to_stokes_matrix.cpu().numpy(),
|
|
29
29
|
name="Intensity to stokes matrix",
|
|
@@ -31,12 +31,12 @@ def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
|
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
def apply_transfer_function(
|
|
34
|
-
retardance,
|
|
35
|
-
orientation,
|
|
36
|
-
transmittance,
|
|
37
|
-
depolarization,
|
|
38
|
-
intensity_to_stokes_matrix,
|
|
39
|
-
):
|
|
34
|
+
retardance: Tensor,
|
|
35
|
+
orientation: Tensor,
|
|
36
|
+
transmittance: Tensor,
|
|
37
|
+
depolarization: Tensor,
|
|
38
|
+
intensity_to_stokes_matrix: Tensor,
|
|
39
|
+
) -> Tensor:
|
|
40
40
|
stokes_params = stokes.stokes_after_adr(
|
|
41
41
|
retardance, orientation, transmittance, depolarization
|
|
42
42
|
)
|
|
@@ -59,7 +59,7 @@ def apply_inverse_transfer_function(
|
|
|
59
59
|
project_stokes_to_2d: bool = False,
|
|
60
60
|
flip_orientation: bool = False,
|
|
61
61
|
rotate_orientation: bool = False,
|
|
62
|
-
) -> Tuple[Tensor]:
|
|
62
|
+
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
|
63
63
|
"""Reconstructs retardance, orientation, transmittance, and depolarization
|
|
64
64
|
from czyx_data and an intensity_to_stokes_matrix, providing options for
|
|
65
65
|
background correction, projection, and orientation transformations.
|
|
@@ -125,7 +125,6 @@ def apply_inverse_transfer_function(
|
|
|
125
125
|
|
|
126
126
|
# Apply an "Estimated" background correction
|
|
127
127
|
if remove_estimated_background:
|
|
128
|
-
estimator = background_estimator.BackgroundEstimator2D()
|
|
129
128
|
for stokes_index in range(background_corrected_stokes.shape[0]):
|
|
130
129
|
# Project to 2D
|
|
131
130
|
z_projection = torch.mean(
|
|
@@ -134,9 +133,8 @@ def apply_inverse_transfer_function(
|
|
|
134
133
|
# Estimate the background and subtract
|
|
135
134
|
background_corrected_stokes[
|
|
136
135
|
stokes_index
|
|
137
|
-
] -=
|
|
138
|
-
z_projection,
|
|
139
|
-
normalize=False,
|
|
136
|
+
] -= correction.estimate_background(
|
|
137
|
+
z_projection, order=2, block_size=32
|
|
140
138
|
)
|
|
141
139
|
|
|
142
140
|
# Project to 2D (typically for SNR reasons)
|
|
@@ -0,0 +1,351 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
from typing import Literal
|
|
6
|
+
from torch.nn.functional import avg_pool3d, interpolate
|
|
7
|
+
from waveorder import optics, sampling, stokes, util
|
|
8
|
+
from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def generate_test_phantom(zyx_shape: tuple[int, int, int]) -> torch.Tensor:
    """Simulate an in-plane oriented star-target phantom.

    A 2D star target is placed in the central z slice (``zyx_shape[0] // 2``)
    of an otherwise zero 3D volume.

    Parameters
    ----------
    zyx_shape : tuple[int, int, int]
        Shape of the simulated volume.

    Returns
    -------
    torch.Tensor
        Object of shape ``(3, *zyx_shape)`` whose leading axis holds the
        ``c00``, ``c2_2`` and ``c22`` components, in that order.
    """
    yx_star, yx_theta, _ = util.generate_star_target(
        yx_shape=zyx_shape[1:],
        blur_px=1,
        margin=50,
    )
    c00 = yx_star
    c2_2 = -torch.sin(2 * yx_theta) * yx_star
    c22 = -torch.cos(2 * yx_theta) * yx_star

    # Place the components in the center slice of a 3D object
    center_slice_object = torch.stack((c00, c2_2, c22), dim=0)
    # named fzyx_object (not `object`) to avoid shadowing the builtin
    fzyx_object = torch.zeros((3, *zyx_shape))
    fzyx_object[:, zyx_shape[0] // 2, ...] = center_slice_object
    return fzyx_object
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def calculate_transfer_function(
    swing: float,
    scheme: str,
    zyx_shape: tuple[int, int, int],
    yx_pixel_size: float,
    z_pixel_size: float,
    wavelength_illumination: float,
    z_padding: int,
    index_of_refraction_media: float,
    numerical_aperture_illumination: float,
    numerical_aperture_detection: float,
    invert_phase_contrast: bool = False,
    fourier_oversample_factor: int = 1,
    transverse_downsample_factor: int = 1,
) -> tuple[
    torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]
]:
    """Compute the vector transfer function and its per-voxel singular system.

    The transfer function is evaluated on a finer grid, using pixel sizes
    divided by ``ceil(pixel_size / nyquist)`` factors from ``sampling``. That
    grid is further enlarged by ``fourier_oversample_factor`` and, transversely,
    divided by ``transverse_downsample_factor``. The result is then average-pooled
    by ``fourier_oversample_factor``, Fourier-cropped, SVD-decomposed, and finally
    interpolated to ``zyx_shape``.

    Returns
    -------
    tuple
        ``(transfer_function, intensity_to_stokes_matrix, (U, S, Vh))``. The
        transfer function and the singular system are interpolated to
        ``zyx_shape``.

    Raises
    ------
    NotImplementedError
        If ``z_padding`` is non-zero.
    """
    if z_padding != 0:
        raise NotImplementedError("Padding not implemented for this model")

    # Integer refinement factors derived from the Nyquist helpers
    transverse_nyquist = sampling.transverse_nyquist(
        wavelength_illumination,
        numerical_aperture_illumination,
        numerical_aperture_detection,
    )
    axial_nyquist = sampling.axial_nyquist(
        wavelength_illumination,
        numerical_aperture_detection,
        index_of_refraction_media,
    )
    yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
    z_factor = int(np.ceil(z_pixel_size / axial_nyquist))

    print("YX factor:", yx_factor)
    print("Z factor:", z_factor)

    def refined_yx_length(length):
        # transverse grid length on the refined (and oversampled) grid
        return int(
            np.ceil(
                length
                * yx_factor
                * fourier_oversample_factor
                / transverse_downsample_factor
            )
        )

    calculation_shape = (
        zyx_shape[0] * z_factor * fourier_oversample_factor,
        refined_yx_length(zyx_shape[1]),
        refined_yx_length(zyx_shape[2]),
    )

    fine_transfer_function, intensity_to_stokes_matrix = (
        _calculate_wrap_unsafe_transfer_function(
            swing,
            scheme,
            calculation_shape,
            yx_pixel_size / yx_factor,
            z_pixel_size / z_factor,
            wavelength_illumination,
            z_padding,
            index_of_refraction_media,
            numerical_aperture_illumination,
            numerical_aperture_detection,
            invert_phase_contrast=invert_phase_contrast,
        )
    )

    def on_complex_parts(operation, tensor):
        # avg_pool3d / interpolate reject complex input: act on each part
        return operation(tensor.real) + 1j * operation(tensor.imag)

    pool_kernel = (fourier_oversample_factor,) * 3

    def pool(part):
        return avg_pool3d(part, pool_kernel)

    pooled = on_complex_parts(pool, fine_transfer_function)

    # Crop to original size
    target_shape = (
        pooled.shape[0],
        pooled.shape[1],
        zyx_shape[0] + 2 * z_padding,
    ) + zyx_shape[1:]
    cropped = sampling.nd_fourier_central_cuboid(pooled, target_shape)

    # Compute singular system on cropped and downsampled
    U, S, Vh = calculate_singular_system(cropped)

    def resize(part):
        return interpolate(part, size=zyx_shape)

    return (
        on_complex_parts(resize, cropped),
        intensity_to_stokes_matrix,
        (
            on_complex_parts(resize, U),
            resize(S[None])[0],  # S is real
            on_complex_parts(resize, Vh),
        ),
    )
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _calculate_wrap_unsafe_transfer_function(
    swing,
    scheme,
    zyx_shape,
    yx_pixel_size,
    z_pixel_size,
    wavelength_illumination,
    z_padding,
    index_of_refraction_media,
    numerical_aperture_illumination,
    numerical_aperture_detection,
    invert_phase_contrast=False,
):
    """Compute the vector transfer function directly on the given grid.

    NOTE(review): "wrap unsafe" presumably means this grid is not guaranteed
    to be sampled finely enough to avoid wrap-around. ``calculate_transfer_function``
    calls it on a refined grid and pools/crops the result — confirm.

    Parameters
    ----------
    swing, scheme
        Passed to ``stokes.calculate_intensity_to_stokes_matrix``.
    zyx_shape : tuple
        Calculation grid shape; z is extended by ``2 * z_padding``.
    yx_pixel_size, z_pixel_size : float
        Grid spacing.
    wavelength_illumination : float
        Vacuum wavelength; divided by ``index_of_refraction_media`` for
        propagation and the Green's tensor.
    z_padding : int
        Extra z samples on each side.
    index_of_refraction_media : float
        Medium index; also used as the NA of the widest pupil.
    numerical_aperture_illumination, numerical_aperture_detection : float
        Pupil NAs.
    invert_phase_contrast : bool
        If False, the z positions are flipped.

    Returns
    -------
    tuple
        ``(sfZYX_transfer_function, intensity_to_stokes_matrix)``. The transfer
        function is indexed (s, l, z, y, x), where s runs over the selected
        ``util.pauli()`` entries and l over the three selected ``util.gellmann()``
        entries.
    """
    print("Computing transfer function")
    intensity_to_stokes_matrix = stokes.calculate_intensity_to_stokes_matrix(
        swing, scheme=scheme
    )

    input_jones = torch.tensor([0.0 - 1.0j, 1.0 + 0j])  # circular
    # input_jones = torch.tensor([0 + 0j, 1 + 0j])  # linear

    # Calculate frequencies
    y_frequencies, x_frequencies = util.generate_frequencies(
        zyx_shape[1:], yx_pixel_size
    )
    radial_frequencies = torch.sqrt(x_frequencies**2 + y_frequencies**2)

    # z positions centered on zero, reordered into FFT (ifftshift) order
    z_total = zyx_shape[0] + 2 * z_padding
    z_position_list = torch.fft.ifftshift(
        (torch.arange(z_total) - z_total // 2) * z_pixel_size
    )
    if (
        not invert_phase_contrast
    ):  # opposite sign of direct phase reconstruction
        z_position_list = torch.flip(z_position_list, dims=(0,))
    # NOTE(review): z_frequencies is not used below
    z_frequencies = torch.fft.fftfreq(z_total, d=z_pixel_size)

    # 2D pupils
    print("\tCalculating pupils...")
    ill_pupil = optics.generate_pupil(
        radial_frequencies,
        numerical_aperture_illumination,
        wavelength_illumination,
    )
    det_pupil = optics.generate_pupil(
        radial_frequencies,
        numerical_aperture_detection,
        wavelength_illumination,
    )
    pupil = optics.generate_pupil(
        radial_frequencies,
        index_of_refraction_media,  # largest possible NA
        wavelength_illumination,
    )

    # Defocus pupils
    defocus_pupil = optics.generate_propagation_kernel(
        radial_frequencies,
        pupil,
        wavelength_illumination / index_of_refraction_media,
        z_position_list,
    )

    # Calculate vector defocus pupils
    S = optics.generate_vector_source_defocus_pupil(
        x_frequencies,
        y_frequencies,
        z_position_list,
        defocus_pupil,
        input_jones,
        ill_pupil,
        wavelength_illumination / index_of_refraction_media,
    )

    # Simplified scalar pupil
    P = optics.generate_propagation_kernel(
        radial_frequencies,
        det_pupil,
        wavelength_illumination / index_of_refraction_media,
        z_position_list,
    )

    # 1D inverse FFT along the z axis (dim -3); P_3D keeps only the magnitude
    P_3D = torch.abs(torch.fft.ifft(P, dim=-3)).type(torch.complex64)
    S_3D = torch.fft.ifft(S, dim=-3)

    print("\tCalculating greens tensor spectrum...")
    G_3D = optics.generate_greens_tensor_spectrum(
        zyx_shape=(z_total, zyx_shape[1], zyx_shape[2]),
        zyx_pixel_size=(z_pixel_size, yx_pixel_size, yx_pixel_size),
        wavelength=wavelength_illumination / index_of_refraction_media,
    )

    # Main part
    PG_3D = torch.einsum("zyx,ipzyx->ipzyx", P_3D, G_3D)
    PS_3D = torch.einsum("zyx,jzyx,kzyx->jkzyx", P_3D, S_3D, torch.conj(S_3D))

    # free large intermediates before the next allocations
    del P_3D, G_3D, S_3D

    print("\tComputing pg and ps...")
    pg = torch.fft.fftn(PG_3D, dim=(-3, -2, -1))
    ps = torch.fft.fftn(PS_3D, dim=(-3, -2, -1))

    del PG_3D, PS_3D

    # H1 and H2: products of pg and ps with each other's conjugate,
    # inverse-transformed over (z, y, x)
    print("\tComputing H1 and H2...")
    H1 = torch.fft.ifftn(
        torch.einsum("ipzyx,jkzyx->ijpkzyx", pg, torch.conj(ps)),
        dim=(-3, -2, -1),
    )

    H2 = torch.fft.ifftn(
        torch.einsum("ikzyx,jpzyx->ijpkzyx", ps, torch.conj(pg)),
        dim=(-3, -2, -1),
    )

    H_re = H1[1:, 1:] + H2[1:, 1:]  # drop data-side z components
    # H_im = 1j * (H1 - H2) # ignore absorptive terms

    del H1, H2

    # normalize so that the largest |H_re| entry is 1
    H_re /= torch.amax(torch.abs(H_re))

    # indices 0-3 of util.pauli(). NOTE(review): the earlier comment here said
    # "s0, s1, and s2", but four entries are selected — confirm the intent
    s = util.pauli()[[0, 1, 2, 3]]
    Y = util.gellmann()[[0, 4, 8]]
    # select phase f00 and transverse linear isotropic terms 2-2, and f22

    print("\tComputing final transfer function...")
    sfZYX_transfer_function = torch.einsum(
        "sik,ikpjzyx,lpj->slzyx", s, H_re, Y
    )
    return (
        sfZYX_transfer_function,
        intensity_to_stokes_matrix,
    )
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def calculate_singular_system(sfZYX_transfer_function):
    """Take the SVD of the (s, f) transfer-function matrix at every voxel.

    Parameters
    ----------
    sfZYX_transfer_function : torch.Tensor
        Transfer function indexed as (s, f, Z, Y, X).

    Returns
    -------
    tuple of torch.Tensor
        ``(U, S, Vh)`` with the voxel axes placed last:
        U is (s, k, Z, Y, X), S is (k, Z, Y, X) and Vh is (k, f, Z, Y, X),
        where k = min(s, f).
    """
    print("Computing SVD")
    # torch.linalg.svd batches over leading axes, so move (s, f) to the end
    voxel_batched = sfZYX_transfer_function.permute(2, 3, 4, 0, 1)
    left, singular_values, right_h = torch.linalg.svd(
        voxel_batched, full_matrices=False
    )
    # move the three voxel axes (now leading) back behind the matrix axes
    return (
        torch.movedim(left, (0, 1, 2), (2, 3, 4)),
        torch.movedim(singular_values, (0, 1, 2), (1, 2, 3)),
        torch.movedim(right_h, (0, 1, 2), (2, 3, 4)),
    )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def visualize_transfer_function(
    viewer: "napari.Viewer",
    sfZYX_transfer_function: torch.Tensor,
    zyx_scale: tuple[float, float, float],
) -> None:
    """Add the transfer function to a napari viewer as a complex-RGB layer.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer that receives the layer.
    sfZYX_transfer_function : torch.Tensor
        Transfer function to display.
    zyx_scale : tuple[float, float, float]
        Voxel scale forwarded to the viewer layer.
    """
    display_options = {
        "zyx_scale": zyx_scale,
        "layer_name": "Transfer Function",
        "complex_rgb": True,
        "clim_factor": 0.5,
    }
    add_transfer_function_to_viewer(
        viewer, sfZYX_transfer_function, **display_options
    )
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def apply_transfer_function(
    fzyx_object: torch.Tensor,
    sfZYX_transfer_function: torch.Tensor,
    intensity_to_stokes_matrix: torch.Tensor,  # TODO use this to simulate intensities
) -> torch.Tensor:
    """Simulate Stokes data by multiplying the object spectrum with the
    transfer function.

    Parameters
    ----------
    fzyx_object : torch.Tensor
        Object indexed as (f, z, y, x).
    sfZYX_transfer_function : torch.Tensor
        Transfer function indexed as (s, f, Z, Y, X).
    intensity_to_stokes_matrix : torch.Tensor
        Currently unused.

    Returns
    -------
    torch.Tensor
        Complex data indexed as (s, z, y, x), scaled by a fixed factor of 50.
    """
    object_spectrum = torch.fft.fftn(fzyx_object, dim=(1, 2, 3))
    # sum over object components f for every data channel s and frequency
    data_spectrum = torch.einsum(
        "fzyx,sfzyx->szyx", object_spectrum, sfZYX_transfer_function
    )
    simulated = torch.fft.ifftn(data_spectrum, dim=(1, 2, 3))
    # fixed gain; additive noise is currently disabled
    return 50 * simulated
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def apply_inverse_transfer_function(
    szyx_data: Tensor,
    singular_system: tuple[Tensor],
    intensity_to_stokes_matrix: Tensor,
    reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
    regularization_strength: float = 1e-3,
    TV_rho_strength: float = 1e-3,
    TV_iterations: int = 10,
):
    """Reconstruct the object with a Tikhonov-regularized inverse filter.

    At every voxel the transfer function matrix is ``H = U diag(S) Vh``
    (s x f). Its regularized inverse is
    ``V diag(S / (S**2 + regularization_strength)) U^H`` (f x s), so the
    conjugates of ``U`` and ``Vh`` are required when the system is complex.

    Parameters
    ----------
    szyx_data : Tensor
        Data indexed as (s, z, y, x).
    singular_system : tuple[Tensor]
        ``(U, S, Vh)`` with U (s, k, Z, Y, X), S (k, Z, Y, X) and
        Vh (k, f, Z, Y, X), as returned by ``calculate_singular_system``.
    intensity_to_stokes_matrix : Tensor
        Currently unused.
    reconstruction_algorithm : {"Tikhonov", "TV"}
        Only "Tikhonov" is implemented. "TV" issues a warning and falls back
        to Tikhonov.
    regularization_strength : float
        Tikhonov regularization added to ``S**2``.
    TV_rho_strength, TV_iterations
        Currently unused (reserved for TV).

    Returns
    -------
    Tensor
        Real part of the reconstruction, indexed as (f, z, y, x).

    Raises
    ------
    ValueError
        If ``reconstruction_algorithm`` is neither "Tikhonov" nor "TV".
    """
    import warnings

    if reconstruction_algorithm == "TV":
        warnings.warn(
            "TV reconstruction is not implemented for this model; "
            "falling back to Tikhonov regularization."
        )
    elif reconstruction_algorithm != "Tikhonov":
        raise ValueError(
            "reconstruction_algorithm must be 'Tikhonov' or 'TV', "
            f"got {reconstruction_algorithm!r}"
        )

    sZYX_data = torch.fft.fftn(szyx_data, dim=(1, 2, 3))

    # Key computation
    print("Computing inverse filter")
    U, S, Vh = singular_system
    S_reg = S / (S**2 + regularization_strength)

    # Entry [s, f] equals the (f, s) entry of V diag(S_reg) U^H, stored in the
    # (s, f) layout so it contracts with the data over s below.
    sfZYX_inverse_filter = torch.einsum(
        "sjzyx,jzyx,jfzyx->sfzyx", U.conj(), S_reg, Vh.conj()
    )

    # Apply inverse filter
    fZYX_reconstructed = torch.einsum(
        "szyx,sfzyx->fzyx", sZYX_data, sfZYX_inverse_filter
    )

    return torch.real(torch.fft.ifftn(fZYX_reconstructed, dim=(1, 2, 3)))
|