waveorder 0.2.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of waveorder might be problematic. Click here for more details.
- waveorder/__init__.py +0 -0
- waveorder/_version.py +16 -0
- waveorder/background_estimator.py +319 -0
- waveorder/correction.py +107 -0
- waveorder/focus.py +198 -0
- waveorder/models/inplane_oriented_thick_pol3d.py +159 -0
- waveorder/models/isotropic_fluorescent_thick_3d.py +192 -0
- waveorder/models/isotropic_thin_3d.py +281 -0
- waveorder/models/phase_thick_3d.py +219 -0
- waveorder/optics.py +1196 -0
- waveorder/stokes.py +458 -0
- waveorder/util.py +2241 -0
- waveorder/visual.py +1931 -0
- waveorder/waveorder_reconstructor.py +4031 -0
- waveorder/waveorder_simulator.py +1217 -0
- waveorder-0.2.2rc0.dist-info/LICENSE +28 -0
- waveorder-0.2.2rc0.dist-info/METADATA +147 -0
- waveorder-0.2.2rc0.dist-info/RECORD +20 -0
- waveorder-0.2.2rc0.dist-info/WHEEL +5 -0
- waveorder-0.2.2rc0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from waveorder import correction, stokes, util
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def generate_test_phantom(yx_shape):
    """Generate a star-target phantom for the in-plane oriented model.

    Parameters
    ----------
    yx_shape
        Shape forwarded unchanged to ``util.generate_star_target``.

    Returns
    -------
    tuple
        ``(retardance, orientation, transmittance, depolarization)``.
        Retardance is the star target scaled by 0.25, orientation is the
        second output of ``util.generate_star_target`` wrapped modulo pi and
        zeroed wherever the star target is <= 1e-3, and transmittance and
        depolarization are uniform maps equal to 0.9.
    """
    star_target, azimuth, _ = util.generate_star_target(yx_shape, blur_px=0.1)

    retardance = 0.25 * star_target

    # Keep an orientation value only where the star target is non-negligible.
    inside_star = star_target > 1e-3
    orientation = (azimuth % np.pi) * inside_star

    # Spatially uniform maps with the same shape/dtype as the retardance.
    transmittance, depolarization = (
        0.9 * torch.ones_like(retardance),
        0.9 * torch.ones_like(retardance),
    )
    return retardance, orientation, transmittance, depolarization
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def calculate_transfer_function(
    swing,
    scheme,
):
    """Compute the intensity-to-Stokes matrix used as this model's forward model.

    Thin wrapper that forwards ``swing`` and ``scheme`` to
    ``stokes.calculate_intensity_to_stokes_matrix`` and returns its result
    unchanged.
    """
    intensity_to_stokes_matrix = stokes.calculate_intensity_to_stokes_matrix(
        swing, scheme=scheme
    )
    return intensity_to_stokes_matrix
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
    """Add the intensity-to-Stokes matrix to ``viewer`` as an image layer.

    The tensor is moved to the CPU and converted to a NumPy array before it
    is handed to ``viewer.add_image``.
    """
    matrix_on_host = intensity_to_stokes_matrix.cpu().numpy()
    viewer.add_image(matrix_on_host, name="Intensity to stokes matrix")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def apply_transfer_function(
    retardance,
    orientation,
    transmittance,
    depolarization,
    intensity_to_stokes_matrix,
):
    """Simulate raw intensities from sample parameters.

    The Stokes parameters produced by ``stokes.stokes_after_adr`` are mapped
    to intensities with the pseudo-inverse of ``intensity_to_stokes_matrix``.
    A singleton axis is inserted after the first axis and a constant offset
    of 0.1 is added to every value.
    """
    sample_stokes = torch.stack(
        stokes.stokes_after_adr(
            retardance, orientation, transmittance, depolarization
        )
    )

    # The forward matrix maps intensities -> Stokes; invert it to simulate.
    stokes_to_intensity_matrix = torch.linalg.pinv(intensity_to_stokes_matrix)
    cyx_intensities = stokes.mmul(stokes_to_intensity_matrix, sample_stokes)

    # Return in czyx shape
    # TODO: make this simulation more realistic with defocussed data
    czyx_intensities = cyx_intensities[:, None, ...]
    return czyx_intensities + 0.1
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def apply_inverse_transfer_function(
    czyx_data: Tensor,
    intensity_to_stokes_matrix: Tensor,
    cyx_no_sample_data: Optional[Tensor] = None,
    remove_estimated_background: bool = False,
    project_stokes_to_2d: bool = False,
    flip_orientation: bool = False,
    rotate_orientation: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Reconstructs retardance, orientation, transmittance, and depolarization
    from czyx_data and an intensity_to_stokes_matrix, providing options for
    background correction, projection, and orientation transformations.

    Parameters
    ----------
    czyx_data : Tensor
        4D raw data, first dimension is the polarization dimension, remaining
        dimensions are spatial
    intensity_to_stokes_matrix : Tensor
        Forward model, see calculate_transfer_function above
    cyx_no_sample_data : Tensor, optional
        3D raw background data, by default None
        First dimension is the polarization dimension, remaining dimensions
        are spatial.
        cyx shape must match in this parameter and czyx_data
        If provided, this background will be removed.
        If None, no background will be removed.
    remove_estimated_background : bool, optional
        Estimate a background from the data and remove it, by default False
    project_stokes_to_2d : bool, optional
        Project stokes to 2D for SNR improvement in thin samples, by default
        False
    flip_orientation : bool, optional
        Flip the reconstructed orientation about the x axis, by default False
    rotate_orientation : bool, optional
        Add 90 degrees to the reconstructed orientation, by default False

    Notes
    -----
    cyx_no_sample_data and remove_estimated_background provide background
    correction options. Both may be enabled; the measured correction is
    applied first and the estimated correction second.

    flip_orientation and rotate_orientation modify the reconstructed
    orientation. We recommend using these parameters when a test target with
    a known orientation is available.

    Returns
    -------
    Tuple[Tensor, Tensor, Tensor, Tensor]
        zyx_retardance (radians)
        zyx_orientation (radians)
        zyx_transmittance (unitless)
        zyx_depolarization (unitless)
    """
    data_stokes = stokes.mmul(intensity_to_stokes_matrix, czyx_data)

    # Apply a "Measured" background correction
    if cyx_no_sample_data is None:
        background_corrected_stokes = data_stokes
    else:
        # Find the no-sample Stokes parameters from the background data
        measured_no_sample_stokes = stokes.mmul(
            intensity_to_stokes_matrix, cyx_no_sample_data
        )
        # Estimate the attenuating, depolarizing, retarder's inverse Mueller
        # matrix that caused this background data
        inverse_background_mueller = stokes.mueller_from_stokes(
            *measured_no_sample_stokes, model="adr", direction="inverse"
        )
        # Apply this background-correction Mueller matrix to the data to
        # remove the background contribution
        background_corrected_stokes = stokes.mmul(
            inverse_background_mueller, data_stokes
        )

    # Apply an "Estimated" background correction
    if remove_estimated_background:
        for stokes_index in range(background_corrected_stokes.shape[0]):
            # Project to 2D
            z_projection = torch.mean(
                background_corrected_stokes[stokes_index], dim=0
            )
            # Estimate the background from the projection and subtract it
            # from every z slice of this Stokes channel (in place).
            background_corrected_stokes[
                stokes_index
            ] -= correction.estimate_background(
                z_projection, order=2, block_size=32
            )

    # Project to 2D (typically for SNR reasons); keepdim retains a singleton
    # z axis so the result stays 4D.
    if project_stokes_to_2d:
        background_corrected_stokes = torch.mean(
            background_corrected_stokes, dim=1, keepdim=True
        )

    # Estimate an attenuating, depolarizing, retarder's parameters,
    # i.e. (retardance, orientation, transmittance, depolarization)
    # from the background-corrected Stokes values
    adr_parameters = stokes.estimate_adr_from_stokes(
        *background_corrected_stokes
    )

    # Apply orientation transformations
    orientation = stokes.apply_orientation_offset(
        adr_parameters[1], rotate=rotate_orientation, flip=flip_orientation
    )

    # Return (retardance, orientation, transmittance, depolarization)
    return adr_parameters[0], orientation, adr_parameters[2], adr_parameters[3]
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from waveorder import optics, util
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def generate_test_phantom(
    zyx_shape,
    yx_pixel_size,
    z_pixel_size,
    sphere_radius,
):
    """Generate a sphere phantom for the fluorescence model.

    All arguments are forwarded positionally to
    ``util.generate_sphere_target``; only the first of its three outputs is
    returned.
    """
    target_outputs = util.generate_sphere_target(
        zyx_shape, yx_pixel_size, z_pixel_size, sphere_radius
    )
    # Only the first output (the sphere itself) is needed here.
    sphere, _, _ = target_outputs
    return sphere
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def calculate_transfer_function(
    zyx_shape,
    yx_pixel_size,
    z_pixel_size,
    wavelength_emission,
    z_padding,
    index_of_refraction_media,
    numerical_aperture_detection,
):
    """Compute a normalized 3D optical transfer function.

    The detection pupil and a propagation kernel (from ``optics``) are turned
    into an intensity point spread function, which is Fourier transformed
    over all three axes and scaled so that its largest magnitude is 1.

    Returns
    -------
    torch.Tensor
        Complex OTF with ``zyx_shape[0] + 2 * z_padding`` planes along the
        first axis.
    """
    yx_frequencies = util.generate_radial_frequencies(
        zyx_shape[1:], yx_pixel_size
    )

    # Axial positions centered on zero, then reordered with ifftshift.
    n_planes = zyx_shape[0] + 2 * z_padding
    centered_indices = torch.arange(n_planes) - n_planes // 2
    z_positions = torch.fft.ifftshift(centered_indices * z_pixel_size)

    detection_pupil = optics.generate_pupil(
        yx_frequencies,
        numerical_aperture_detection,
        wavelength_emission,
    )

    # The kernel is given the wavelength inside the medium.
    wavelength_in_media = wavelength_emission / index_of_refraction_media
    kernel = optics.generate_propagation_kernel(
        yx_frequencies,
        detection_pupil,
        wavelength_in_media,
        z_positions,
    )

    # Intensity PSF: squared magnitude of the per-plane inverse 2D transform.
    amplitude = torch.fft.ifft2(kernel, dim=(1, 2))
    point_spread_function = torch.abs(amplitude) ** 2

    otf = torch.fft.fftn(point_spread_function, dim=(0, 1, 2))
    peak_magnitude = torch.max(torch.abs(otf))
    otf /= peak_magnitude  # normalize
    return otf
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale):
    """Add the imaginary and real parts of the OTF to ``viewer``.

    Each part is ifftshift-ed over all axes, shown with the "bwr" colormap
    and symmetric contrast limits at 10% of its peak magnitude, and scaled by
    ``1 / zyx_scale``.
    """
    imaginary_part = torch.imag(optical_transfer_function)
    real_part = torch.real(optical_transfer_function)

    for component, layer_name in (
        (imaginary_part, "Im(OTF)"),
        (real_part, "Re(OTF)"),
    ):
        lim = 0.1 * torch.max(torch.abs(component))
        centered = torch.fft.ifftshift(component)
        viewer.add_image(
            centered.cpu().numpy(),
            name=layer_name,
            colormap="bwr",
            contrast_limits=(-lim, lim),
            scale=1 / zyx_scale,
        )

    viewer.dims.order = (0, 1, 2)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def apply_transfer_function(
    zyx_object, optical_transfer_function, z_padding, background=10
):
    """Simulate imaging by applying a transfer function

    The object is Fourier transformed over all axes, multiplied by the
    (z-cropped) transfer function, inverse transformed, and the real part is
    returned with a constant background added.

    Parameters
    ----------
    zyx_object : torch.Tensor
        Object to image.
    optical_transfer_function : torch.Tensor
        Transfer function whose first axis has length
        ``zyx_object.shape[0] + 2 * z_padding``.
    z_padding : int
        Number of planes cropped from each end of the transfer function's
        first axis when positive.
    background : int, optional
        constant background counts added to each voxel, by default 10

    Returns
    -------
    Simulated data : torch.Tensor

    Raises
    ------
    ValueError
        If the first-axis lengths of the object and transfer function are
        inconsistent with ``z_padding``.
    """
    expected_planes = zyx_object.shape[0] + 2 * z_padding
    actual_planes = optical_transfer_function.shape[0]
    if expected_planes != actual_planes:
        # Name the real parameters and report the offending sizes so the
        # caller can see which argument is wrong.
        raise ValueError(
            "Please check padding: zyx_object.shape[0] + 2 * z_padding "
            f"({expected_planes}) != optical_transfer_function.shape[0] "
            f"({actual_planes})"
        )
    if z_padding > 0:
        optical_transfer_function = optical_transfer_function[
            z_padding:-z_padding
        ]

    # Very simple simulation, consider adding noise and bkg knobs
    zyx_obj_hat = torch.fft.fftn(zyx_object)
    zyx_data_hat = zyx_obj_hat * optical_transfer_function
    data = torch.real(torch.fft.ifftn(zyx_data_hat))

    # Add a direct background
    return data + background
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def apply_inverse_transfer_function(
    zyx_data: Tensor,
    optical_transfer_function: Tensor,
    z_padding: int,
    reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
    regularization_strength: float = 1e-3,
    TV_rho_strength: float = 1e-3,
    TV_iterations: int = 10,
) -> Tensor:
    """Reconstructs fluorescence density from zyx_data and
    an optical_transfer_function, providing options for z padding and
    reconstruction algorithms.

    Parameters
    ----------
    zyx_data : Tensor
        3D raw data, fluorescence defocus stack
    optical_transfer_function : Tensor
        3D optical transfer function, see calculate_transfer_function above
    z_padding : int
        Padding for axial dimension. Use zero for defocus stacks that
        extend ~3 PSF widths beyond the sample. Pad by ~3 PSF widths otherwise.
    reconstruction_algorithm : str, optional
        "Tikhonov" or "TV", by default "Tikhonov"
        "TV" is not implemented.
    regularization_strength : float, optional
        regularization parameter, by default 1e-3
    TV_rho_strength : float, optional
        TV-specific regularization parameter, by default 1e-3
        "TV" is not implemented, so this value is currently unused.
    TV_iterations : int, optional
        TV-specific number of iterations, by default 10
        "TV" is not implemented, so this value is currently unused.

    Returns
    -------
    Tensor
        zyx_fluorescence_density (fluorophores per volumes)

    Raises
    ------
    NotImplementedError
        TV is not implemented
    ValueError
        reconstruction_algorithm is neither "Tikhonov" nor "TV"
    """
    # Validate the algorithm before doing any work so an unsupported choice
    # fails immediately instead of surfacing later as an unbound variable.
    if reconstruction_algorithm == "TV":
        raise NotImplementedError(
            'reconstruction_algorithm="TV" is not implemented'
        )
    if reconstruction_algorithm != "Tikhonov":
        raise ValueError(
            'reconstruction_algorithm must be "Tikhonov" or "TV", got '
            f"{reconstruction_algorithm!r}"
        )

    # Handle padding
    zyx_padded = util.pad_zyx_along_z(zyx_data, z_padding)

    # Reconstruct
    f_real = util.single_variable_tikhonov_deconvolution_3D(
        zyx_padded,
        optical_transfer_function,
        reg_re=regularization_strength,
    )

    # Unpad
    if z_padding != 0:
        f_real = f_real[z_padding:-z_padding]

    return f_real
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
from typing import Literal, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from waveorder import optics, util
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def generate_test_phantom(
    yx_shape,
    yx_pixel_size,
    wavelength_illumination,
    index_of_refraction_media,
    index_of_refraction_sample,
    sphere_radius,
):
    """Generate a thin absorbing and phase phantom from a blurred sphere.

    A three-plane sphere target is generated with ``util.generate_sphere_target``
    and its middle plane (index 1) is used for both outputs.

    Parameters
    ----------
    yx_shape
        Two-element shape; any sequence (tuple, list, ...) is accepted.
    yx_pixel_size
        Forwarded to the sphere generator; twice this value is used as the
        blur size.
    wavelength_illumination
        Divides the phase term.
    index_of_refraction_media, index_of_refraction_sample
        Their difference (sample minus media) scales the phase term.
    sphere_radius
        Radius forwarded to the sphere generator.

    Returns
    -------
    tuple
        ``(yx_absorption, yx_phase)``
    """
    sphere, _, _ = util.generate_sphere_target(
        # tuple() lets callers pass a list, matching the tuple(yx_shape)
        # handling in calculate_transfer_function of this module.
        (3,) + tuple(yx_shape),
        yx_pixel_size,
        z_pixel_size=1.0,
        radius=sphere_radius,
        blur_size=2 * yx_pixel_size,
    )
    middle_plane = sphere[1]

    yx_phase = (
        middle_plane
        * (index_of_refraction_sample - index_of_refraction_media)
        * 0.1
        / wavelength_illumination
    )  # phase in radians

    yx_absorption = 0.02 * middle_plane

    return yx_absorption, yx_phase
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def calculate_transfer_function(
    yx_shape,
    yx_pixel_size,
    z_position_list,
    wavelength_illumination,
    index_of_refraction_media,
    numerical_aperture_illumination,
    numerical_aperture_detection,
    invert_phase_contrast=False,
):
    """Compute absorption and phase 2D-to-3D transfer functions.

    For every position in ``z_position_list`` a pair of 2D weak-object
    transfer functions is computed with
    ``optics.compute_weak_object_transfer_function_2d`` and stored in the
    matching plane of the outputs.

    Parameters
    ----------
    yx_shape
        Two-element shape of each plane.
    yx_pixel_size
        Pixel size forwarded to ``util.generate_radial_frequencies``.
    z_position_list
        Defocus positions; a list, NumPy array, or tensor.
    wavelength_illumination
        Illumination wavelength; divided by ``index_of_refraction_media``
        before being given to the propagation kernel.
    index_of_refraction_media
        Refractive index of the medium.
    numerical_aperture_illumination, numerical_aperture_detection
        Numerical apertures used for the illumination and detection pupils.
    invert_phase_contrast : bool, optional
        If True the order of the z positions is reversed, by default False

    Returns
    -------
    tuple
        ``(absorption_2d_to_3d_transfer_function,
        phase_2d_to_3d_transfer_function)``, both complex64 tensors of shape
        ``(len(z_position_list),) + tuple(yx_shape)``.
    """
    # Convert once. The original wrapped an already-built tensor in
    # torch.tensor() again (after the flip, or for tensor input), which
    # emits PyTorch's copy-construct warning.
    z_positions = torch.as_tensor(z_position_list)
    if invert_phase_contrast:
        z_positions = torch.flip(z_positions, dims=(0,))

    radial_frequencies = util.generate_radial_frequencies(
        yx_shape, yx_pixel_size
    )

    illumination_pupil = optics.generate_pupil(
        radial_frequencies,
        numerical_aperture_illumination,
        wavelength_illumination,
    )
    detection_pupil = optics.generate_pupil(
        radial_frequencies,
        numerical_aperture_detection,
        wavelength_illumination,
    )
    propagation_kernel = optics.generate_propagation_kernel(
        radial_frequencies,
        detection_pupil,
        wavelength_illumination / index_of_refraction_media,
        z_positions,
    )

    zyx_shape = (len(z_positions),) + tuple(yx_shape)
    absorption_2d_to_3d_transfer_function = torch.zeros(
        zyx_shape, dtype=torch.complex64
    )
    phase_2d_to_3d_transfer_function = torch.zeros(
        zyx_shape, dtype=torch.complex64
    )
    for z in range(len(z_positions)):
        # Each plane uses the detection pupil modulated by that plane's
        # propagation kernel.
        (
            absorption_2d_to_3d_transfer_function[z],
            phase_2d_to_3d_transfer_function[z],
        ) = optics.compute_weak_object_transfer_function_2d(
            illumination_pupil, detection_pupil * propagation_kernel[z]
        )

    return (
        absorption_2d_to_3d_transfer_function,
        phase_2d_to_3d_transfer_function,
    )
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def visualize_transfer_function(
    viewer,
    absorption_2d_to_3d_transfer_function,
    phase_2d_to_3d_transfer_function,
):
    """Add imaginary and real parts of both transfer functions to ``viewer``.

    Four layers are added in the order Im(absorb), Re(absorb), Im(phase),
    Re(phase). Each is ifftshift-ed over axes 1 and 2 only, shown with the
    "bwr" colormap, and given symmetric contrast limits at 50% of its peak
    magnitude.
    """
    # TODO: consider generalizing w/ phase_thick_3d.visualize_transfer_function
    layers = (
        (torch.imag(absorption_2d_to_3d_transfer_function), "Im(absorb TF)"),
        (torch.real(absorption_2d_to_3d_transfer_function), "Re(absorb TF)"),
        (torch.imag(phase_2d_to_3d_transfer_function), "Im(phase TF)"),
        (torch.real(phase_2d_to_3d_transfer_function), "Re(phase TF)"),
    )

    for component, layer_name in layers:
        lim = 0.5 * torch.max(torch.abs(component))
        centered = torch.fft.ifftshift(component, dim=(1, 2))
        viewer.add_image(
            centered.cpu().numpy(),
            name=layer_name,
            colormap="bwr",
            contrast_limits=(-lim, lim),
            scale=(1, 1, 1),
        )

    viewer.dims.order = (0, 1, 2)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def visualize_point_spread_function(
|
|
117
|
+
viewer,
|
|
118
|
+
absorption_2d_to_3d_transfer_function,
|
|
119
|
+
phase_2d_to_3d_transfer_function,
|
|
120
|
+
):
|
|
121
|
+
arrays = [
|
|
122
|
+
(torch.fft.ifftn(absorption_2d_to_3d_transfer_function), "absorb PSF"),
|
|
123
|
+
(torch.fft.ifftn(phase_2d_to_3d_transfer_function), "phase PSF"),
|
|
124
|
+
]
|
|
125
|
+
|
|
126
|
+
for array in arrays:
|
|
127
|
+
lim = 0.5 * torch.max(torch.abs(array[0]))
|
|
128
|
+
viewer.add_image(
|
|
129
|
+
torch.fft.ifftshift(array[0], dim=(1, 2)).cpu().numpy(),
|
|
130
|
+
name=array[1],
|
|
131
|
+
colormap="bwr",
|
|
132
|
+
contrast_limits=(-lim, lim),
|
|
133
|
+
scale=(1, 1, 1),
|
|
134
|
+
)
|
|
135
|
+
viewer.dims.order = (0, 1, 2)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def apply_transfer_function(
    yx_absorption,
    yx_phase,
    phase_2d_to_3d_transfer_function,
    absorption_2d_to_3d_transfer_function,
    background=10,
):
    """Simulate a defocus stack from a thin absorbing and phase object.

    Each 2D object is Fourier transformed, multiplied by the real part of its
    transfer function (broadcast along the first axis), and inverse
    transformed over axes 1 and 2. The two real-valued stacks are summed and
    a constant background is added.

    Parameters
    ----------
    yx_absorption : torch.Tensor
        2D absorption object.
    yx_phase : torch.Tensor
        2D phase object.
    phase_2d_to_3d_transfer_function : torch.Tensor
        Phase transfer function. Note that it precedes the absorption
        transfer function in the argument order.
    absorption_2d_to_3d_transfer_function : torch.Tensor
        Absorption transfer function.
    background : int, optional
        constant background counts added to each voxel, by default 10

    Returns
    -------
    Simulated data : torch.Tensor
    """
    # Very simple simulation, consider adding noise and bkg knobs

    # simulate absorbing object
    yx_absorption_hat = torch.fft.fftn(yx_absorption)
    zyx_absorption_data_hat = yx_absorption_hat[None, ...] * torch.real(
        absorption_2d_to_3d_transfer_function
    )
    zyx_absorption_data = torch.real(
        torch.fft.ifftn(zyx_absorption_data_hat, dim=(1, 2))
    )

    # simulate phase object
    yx_phase_hat = torch.fft.fftn(yx_phase)
    zyx_phase_data_hat = yx_phase_hat[None, ...] * torch.real(
        phase_2d_to_3d_transfer_function
    )
    zyx_phase_data = torch.real(
        torch.fft.ifftn(zyx_phase_data_hat, dim=(1, 2))
    )

    # sum and add a direct background
    return zyx_absorption_data + zyx_phase_data + background
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def apply_inverse_transfer_function(
    zyx_data: Tensor,
    absorption_2d_to_3d_transfer_function: Tensor,
    phase_2d_to_3d_transfer_function: Tensor,
    reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
    regularization_strength: float = 1e-6,
    reg_p: float = 1e-6,
    TV_rho_strength: float = 1e-3,
    TV_iterations: int = 10,
    bg_filter: bool = True,
) -> Tuple[Tensor, Tensor]:
    """Reconstructs absorption and phase from zyx_data and a pair of
    3D-to-2D transfer functions named absorption_2d_to_3d_transfer_function and
    phase_2d_to_3d_transfer_function, providing options for reconstruction
    algorithms.

    Parameters
    ----------
    zyx_data : Tensor
        3D raw data, label-free defocus stack
    absorption_2d_to_3d_transfer_function : Tensor
        3D-to-2D absorption transfer function, see calculate_transfer_function above
    phase_2d_to_3d_transfer_function : Tensor
        3D-to-2D phase transfer function, see calculate_transfer_function above
    reconstruction_algorithm : Literal["Tikhonov", "TV"], optional
        "Tikhonov" or "TV", by default "Tikhonov"
        "TV" is not implemented.
    regularization_strength : float, optional
        regularization parameter added to the absorption diagonal term of the
        normal equations, by default 1e-6
    reg_p : float, optional
        regularization parameter added to the phase diagonal term of the
        normal equations, by default 1e-6
    TV_rho_strength : float, optional
        TV-specific regularization parameter, by default 1e-3
        "TV" is not implemented, so this value is currently unused.
    TV_iterations : int, optional
        TV-specific number of iterations, by default 10
        "TV" is not implemented, so this value is currently unused.
    bg_filter : bool, optional
        option for slow-varying 2D background normalization with uniform filter
        by default True

    Returns
    -------
    Tuple[Tensor, Tensor]
        yx_absorption (unitless)
        yx_phase (radians)

    Raises
    ------
    NotImplementedError
        TV is not implemented
    ValueError
        reconstruction_algorithm is neither "Tikhonov" nor "TV"
    """
    # Validate the algorithm before doing any work so an unsupported choice
    # fails immediately instead of surfacing later as an unbound variable.
    if reconstruction_algorithm == "TV":
        raise NotImplementedError(
            'reconstruction_algorithm="TV" is not implemented'
        )
    if reconstruction_algorithm != "Tikhonov":
        raise ValueError(
            'reconstruction_algorithm must be "Tikhonov" or "TV", got '
            f"{reconstruction_algorithm!r}"
        )

    zyx_data_normalized = util.inten_normalization(
        zyx_data, bg_filter=bg_filter
    )

    zyx_data_hat = torch.fft.fft2(zyx_data_normalized, dim=(1, 2))

    absorption_tf = absorption_2d_to_3d_transfer_function
    phase_tf = phase_2d_to_3d_transfer_function

    # TODO AHA and b_vec calculations should be moved into tikhonov/tv calculations
    # Entries of the 2x2 normal-equation matrix, summed over the first axis:
    # [absorb-absorb, absorb-phase, phase-absorb, phase-phase].
    AHA = [
        torch.sum(torch.abs(absorption_tf) ** 2, dim=0)
        + regularization_strength,
        torch.sum(torch.conj(absorption_tf) * phase_tf, dim=0),
        torch.sum(torch.conj(phase_tf) * absorption_tf, dim=0),
        torch.sum(torch.abs(phase_tf) ** 2, dim=0) + reg_p,
    ]

    # Right-hand side: each conjugated transfer function applied to the data.
    b_vec = [
        torch.sum(torch.conj(absorption_tf) * zyx_data_hat, dim=0),
        torch.sum(torch.conj(phase_tf) * zyx_data_hat, dim=0),
    ]

    # Deconvolution with Tikhonov regularization
    absorption, phase = util.dual_variable_tikhonov_deconvolution_2d(
        AHA, b_vec
    )

    # Remove the mean so the reconstructed phase is zero-centered.
    phase = phase - torch.mean(phase)

    return absorption, phase
|