waveorder 0.2.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of waveorder might be problematic. Click here for more details.

@@ -0,0 +1,159 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ from waveorder import correction, stokes, util
8
+
9
+
10
def generate_test_phantom(yx_shape):
    """Generate a star-target phantom for the polarization model.

    Parameters
    ----------
    yx_shape
        Shape forwarded to ``util.generate_star_target``.

    Returns
    -------
    tuple
        (retardance, orientation, transmittance, depolarization)
        retardance is 0.25 times the star pattern, orientation is the star's
        angle map modulo pi and zeroed where the star is not above 1e-3,
        transmittance and depolarization are uniformly 0.9.
    """
    star_pattern, angle_map, _ = util.generate_star_target(
        yx_shape, blur_px=0.1
    )
    retardance = 0.25 * star_pattern

    # Only keep an orientation where the star pattern is present
    on_target = star_pattern > 1e-3
    orientation = (angle_map % np.pi) * on_target

    uniform = torch.ones_like(retardance)
    return retardance, orientation, 0.9 * uniform, 0.9 * uniform
17
+
18
+
19
def calculate_transfer_function(
    swing,
    scheme,
):
    """Return the intensity-to-Stokes matrix for a swing and scheme.

    Both arguments are forwarded unchanged to
    ``stokes.calculate_intensity_to_stokes_matrix``.
    """
    intensity_to_stokes_matrix = stokes.calculate_intensity_to_stokes_matrix(
        swing, scheme=scheme
    )
    return intensity_to_stokes_matrix
24
+
25
+
26
def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
    """Add the intensity-to-Stokes matrix to ``viewer`` as an image layer.

    The tensor is moved to the CPU and converted to a NumPy array before it
    is handed to ``viewer.add_image``.
    """
    matrix_on_host = intensity_to_stokes_matrix.cpu().numpy()
    viewer.add_image(matrix_on_host, name="Intensity to stokes matrix")
31
+
32
+
33
def apply_transfer_function(
    retardance,
    orientation,
    transmittance,
    depolarization,
    intensity_to_stokes_matrix,
):
    """Simulate intensity data from sample parameters.

    The Stokes parameters produced by ``stokes.stokes_after_adr`` are mapped
    to intensities with the pseudo-inverse of ``intensity_to_stokes_matrix``.

    Returns
    -------
    Simulated intensities with a singleton axis inserted at position 1 and a
    constant offset of 0.1 added.
    """
    sample_stokes = stokes.stokes_after_adr(
        retardance, orientation, transmittance, depolarization
    )

    # The forward (Stokes -> intensity) operator is the pseudo-inverse of the
    # intensity -> Stokes matrix
    forward_matrix = torch.linalg.pinv(intensity_to_stokes_matrix)
    cyx = stokes.mmul(forward_matrix, torch.stack(sample_stokes))

    # Return in czyx shape
    # TODO: make this simulation more realistic with defocussed data
    czyx = cyx[:, None]
    return czyx + 0.1
52
+
53
+
54
def apply_inverse_transfer_function(
    czyx_data: Tensor,
    intensity_to_stokes_matrix: Tensor,
    cyx_no_sample_data: Optional[Tensor] = None,
    remove_estimated_background: bool = False,
    project_stokes_to_2d: bool = False,
    flip_orientation: bool = False,
    rotate_orientation: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Reconstructs retardance, orientation, transmittance, and depolarization
    from czyx_data and an intensity_to_stokes_matrix, providing options for
    background correction, projection, and orientation transformations.

    Parameters
    ----------
    czyx_data : Tensor
        4D raw data, first dimension is the polarization dimension, remaining
        dimensions are spatial
    intensity_to_stokes_matrix : Tensor
        Forward model, see calculate_transfer_function above
    cyx_no_sample_data : Tensor, optional
        3D raw background data, by default None
        First dimension is the polarization dimension, remaining dimensions are spatial.
        cyx shape must match in this parameter and czyx_data
        If provided, this background will be removed.
        If None, no background will be removed.
    remove_estimated_background : bool, optional
        Estimate a background from the data and remove it, by default False
    project_stokes_to_2d : bool, optional
        Project stokes to 2D for SNR improvement in thin samples, by default False
    flip_orientation : bool, optional
        Flip the reconstructed orientation about the x axis, by default False
    rotate_orientation : bool, optional
        Add 90 degrees to the reconstructed orientation, by default False

    Notes
    -----
    cyx_no_sample_data and remove_estimated_background provide background correction options

    flip_orientation and rotate_orientation modify the reconstructed orientation.
    We recommend using these parameters when a test target with a known orientation
    is available.

    Returns
    -------
    Tuple[Tensor, Tensor, Tensor, Tensor]
        zyx_retardance (radians)
        zyx_orientation (radians)
        zyx_transmittance (unitless)
        zyx_depolarization (unitless)
    """
    data_stokes = stokes.mmul(intensity_to_stokes_matrix, czyx_data)

    # Apply a "Measured" background correction
    if cyx_no_sample_data is None:
        background_corrected_stokes = data_stokes
    else:
        # Find the no-sample Stokes parameters from the background data
        measured_no_sample_stokes = stokes.mmul(
            intensity_to_stokes_matrix, cyx_no_sample_data
        )
        # Estimate the attenuating, depolarizing, retarder's inverse Mueller
        # matrix that caused this background data
        inverse_background_mueller = stokes.mueller_from_stokes(
            *measured_no_sample_stokes, model="adr", direction="inverse"
        )
        # Apply this background-correction Mueller matrix to the data to remove
        # the background contribution
        background_corrected_stokes = stokes.mmul(
            inverse_background_mueller, data_stokes
        )

    # Apply an "Estimated" background correction
    if remove_estimated_background:
        for stokes_index in range(background_corrected_stokes.shape[0]):
            # Project to 2D
            z_projection = torch.mean(
                background_corrected_stokes[stokes_index], dim=0
            )
            # Estimate the background from the projection and subtract it
            # in place from every z slice of this Stokes channel
            background_corrected_stokes[
                stokes_index
            ] -= correction.estimate_background(
                z_projection, order=2, block_size=32
            )

    # Project to 2D (typically for SNR reasons); the z axis is kept as a
    # singleton so the result is still czyx
    if project_stokes_to_2d:
        background_corrected_stokes = torch.mean(
            background_corrected_stokes, dim=1, keepdim=True
        )

    # Estimate an attenuating, depolarizing, retarder's parameters,
    # i.e. (retardance, orientation, transmittance, depolarization)
    # from the background-corrected Stokes values
    (
        retardance,
        raw_orientation,
        transmittance,
        depolarization,
    ) = stokes.estimate_adr_from_stokes(*background_corrected_stokes)[:4]

    # Apply orientation transformations
    orientation = stokes.apply_orientation_offset(
        raw_orientation, rotate=rotate_orientation, flip=flip_orientation
    )

    return retardance, orientation, transmittance, depolarization
@@ -0,0 +1,192 @@
1
+ from typing import Literal
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from waveorder import optics, util
7
+
8
+
9
def generate_test_phantom(
    zyx_shape,
    yx_pixel_size,
    z_pixel_size,
    sphere_radius,
):
    """Generate a sphere phantom for the fluorescence model.

    All arguments are forwarded to ``util.generate_sphere_target``; only the
    first of its three return values is kept and returned.
    """
    target = util.generate_sphere_target(
        zyx_shape, yx_pixel_size, z_pixel_size, sphere_radius
    )
    sphere, _, _ = target
    return sphere
20
+
21
+
22
def calculate_transfer_function(
    zyx_shape,
    yx_pixel_size,
    z_pixel_size,
    wavelength_emission,
    z_padding,
    index_of_refraction_media,
    numerical_aperture_detection,
):
    """Compute a normalized 3D optical transfer function.

    The propagation kernel of the detection pupil is evaluated at
    ``zyx_shape[0] + 2 * z_padding`` axial positions, turned into an
    intensity point spread function, and Fourier transformed over all three
    axes. The result is divided by its maximum magnitude.

    Returns
    -------
    torch.Tensor
        Optical transfer function whose leading axis has length
        ``zyx_shape[0] + 2 * z_padding``.
    """
    frequencies = util.generate_radial_frequencies(
        zyx_shape[1:], yx_pixel_size
    )

    # Axial positions centred on zero, reordered with ifftshift so that the
    # zero position comes first
    padded_depth = zyx_shape[0] + 2 * z_padding
    centred_indices = torch.arange(padded_depth) - padded_depth // 2
    z_position_list = torch.fft.ifftshift(centred_indices * z_pixel_size)

    detection_pupil = optics.generate_pupil(
        frequencies,
        numerical_aperture_detection,
        wavelength_emission,
    )
    wavelength_in_media = wavelength_emission / index_of_refraction_media
    kernel = optics.generate_propagation_kernel(
        frequencies,
        detection_pupil,
        wavelength_in_media,
        z_position_list,
    )

    # Intensity PSF: squared magnitude of the per-slice inverse 2D transform
    amplitude = torch.fft.ifft2(kernel, dim=(1, 2))
    psf = torch.abs(amplitude) ** 2

    otf = torch.fft.fftn(psf, dim=(0, 1, 2))
    peak = torch.max(torch.abs(otf))  # normalize
    return otf / peak
64
+
65
+
66
def visualize_transfer_function(viewer, optical_transfer_function, zyx_scale):
    """Add the imaginary and real parts of the OTF to ``viewer``.

    Each part is ifftshift-ed, shown with the "bwr" colormap, and given
    symmetric contrast limits at 10% of its maximum magnitude. Layers are
    scaled by ``1 / zyx_scale`` and the viewer's dimension order is reset to
    (0, 1, 2) afterwards.
    """
    parts = (
        (torch.imag, "Im(OTF)"),
        (torch.real, "Re(OTF)"),
    )

    for take_part, layer_name in parts:
        component = take_part(optical_transfer_function)
        lim = 0.1 * torch.max(torch.abs(component))
        shifted = torch.fft.ifftshift(component)
        viewer.add_image(
            shifted.cpu().numpy(),
            name=layer_name,
            colormap="bwr",
            contrast_limits=(-lim, lim),
            scale=1 / zyx_scale,
        )

    viewer.dims.order = (0, 1, 2)
82
+
83
+
84
def apply_transfer_function(
    zyx_object, optical_transfer_function, z_padding, background=10
):
    """Simulate imaging by filtering an object with a transfer function.

    Parameters
    ----------
    zyx_object : torch.Tensor
        Object to image.
    optical_transfer_function : torch.Tensor
        Transfer function whose first axis has length
        ``zyx_object.shape[0] + 2 * z_padding``.
    z_padding : int
        Number of slices cropped from each end of the transfer function's
        first axis before it is applied.
    background : int, optional
        constant background counts added to each voxel, by default 10

    Returns
    -------
    Simulated data : torch.Tensor

    Raises
    ------
    ValueError
        If the transfer function's first axis does not match the padded
        object depth.
    """
    padded_depth = zyx_object.shape[0] + 2 * z_padding
    if padded_depth != optical_transfer_function.shape[0]:
        raise ValueError(
            "Please check padding: ZYX_obj.shape[0] + 2 * Z_pad != H_re.shape[0]"
        )

    otf = optical_transfer_function
    if z_padding > 0:
        otf = otf[z_padding:-z_padding]

    # Very simple simulation, consider adding noise and bkg knobs
    filtered_spectrum = torch.fft.fftn(zyx_object) * otf
    simulated = torch.real(torch.fft.ifftn(filtered_spectrum))

    simulated += background  # Add a direct background
    return simulated
121
+
122
+
123
def apply_inverse_transfer_function(
    zyx_data: Tensor,
    optical_transfer_function: Tensor,
    z_padding: int,
    reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
    regularization_strength: float = 1e-3,
    TV_rho_strength: float = 1e-3,
    TV_iterations: int = 10,
) -> Tensor:
    """Reconstructs fluorescence density from zyx_data and
    an optical_transfer_function, providing options for z padding and
    reconstruction algorithms.

    Parameters
    ----------
    zyx_data : Tensor
        3D raw data, fluorescence defocus stack
    optical_transfer_function : Tensor
        3D optical transfer function, see calculate_transfer_function above
    z_padding : int
        Padding for axial dimension. Use zero for defocus stacks that
        extend ~3 PSF widths beyond the sample. Pad by ~3 PSF widths otherwise.
    reconstruction_algorithm : str, optional
        "Tikhonov" or "TV", by default "Tikhonov"
        "TV" is not implemented.
    regularization_strength : float, optional
        regularization parameter, by default 1e-3
    TV_rho_strength : float, optional
        TV-specific regularization parameter, by default 1e-3
        "TV" is not implemented.
    TV_iterations : int, optional
        TV-specific number of iterations, by default 10
        "TV" is not implemented.

    Returns
    -------
    Tensor
        zyx_fluorescence_density (fluorophores per volumes)

    Raises
    ------
    NotImplementedError
        TV is not implemented
    ValueError
        reconstruction_algorithm is neither "Tikhonov" nor "TV"
    """
    # Validate the algorithm up front so that an unsupported choice fails
    # with a clear error before any padding work is done
    if reconstruction_algorithm == "TV":
        raise NotImplementedError(
            "TV reconstruction is not implemented; use 'Tikhonov'."
        )
    if reconstruction_algorithm != "Tikhonov":
        raise ValueError(
            "reconstruction_algorithm must be 'Tikhonov' or 'TV', "
            f"got {reconstruction_algorithm!r}"
        )

    # Handle padding
    zyx_padded = util.pad_zyx_along_z(zyx_data, z_padding)

    # Reconstruct
    f_real = util.single_variable_tikhonov_deconvolution_3D(
        zyx_padded,
        optical_transfer_function,
        reg_re=regularization_strength,
    )

    # Unpad
    if z_padding != 0:
        f_real = f_real[z_padding:-z_padding]

    return f_real
@@ -0,0 +1,281 @@
1
+ from typing import Literal, Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from waveorder import optics, util
7
+
8
+
9
def generate_test_phantom(
    yx_shape,
    yx_pixel_size,
    wavelength_illumination,
    index_of_refraction_media,
    index_of_refraction_sample,
    sphere_radius,
):
    """Generate a 2D absorption/phase phantom from the middle slice of a sphere.

    A three-slice sphere target is requested from
    ``util.generate_sphere_target`` and its slice at index 1 is used for both
    outputs.

    Parameters
    ----------
    yx_shape : sequence of int
        Shape of the 2D phantom; any sequence (tuple or list) is accepted.
    yx_pixel_size
        Forwarded to ``util.generate_sphere_target``; twice this value is
        used as the blur size.
    wavelength_illumination
        Divides the phase term.
    index_of_refraction_media, index_of_refraction_sample
        Their difference scales the phase term.
    sphere_radius
        Forwarded to ``util.generate_sphere_target`` as ``radius``.

    Returns
    -------
    tuple
        (yx_absorption, yx_phase)
    """
    # tuple() lets list-like shapes work, matching how
    # calculate_transfer_function builds its zyx shape
    sphere, _, _ = util.generate_sphere_target(
        (3,) + tuple(yx_shape),
        yx_pixel_size,
        z_pixel_size=1.0,
        radius=sphere_radius,
        blur_size=2 * yx_pixel_size,
    )
    middle_slice = sphere[1]

    yx_phase = (
        middle_slice
        * (index_of_refraction_sample - index_of_refraction_media)
        * 0.1
        / wavelength_illumination
    )  # phase in radians

    yx_absorption = 0.02 * middle_slice

    return yx_absorption, yx_phase
34
+
35
+
36
def calculate_transfer_function(
    yx_shape,
    yx_pixel_size,
    z_position_list,
    wavelength_illumination,
    index_of_refraction_media,
    numerical_aperture_illumination,
    numerical_aperture_detection,
    invert_phase_contrast=False,
):
    """Compute absorption and phase 2D-to-3D transfer functions.

    For every z position, the detection pupil is multiplied by the matching
    slice of the propagation kernel and passed, together with the
    illumination pupil, to
    ``optics.compute_weak_object_transfer_function_2d``.

    Parameters
    ----------
    yx_shape : sequence of int
        Shape of the 2D object.
    yx_pixel_size
        Forwarded to ``util.generate_radial_frequencies``.
    z_position_list : sequence or Tensor
        Axial positions, one per output slice.
    wavelength_illumination
        Used for both pupils; divided by ``index_of_refraction_media`` for
        the propagation kernel.
    index_of_refraction_media
        See ``wavelength_illumination``.
    numerical_aperture_illumination, numerical_aperture_detection
        Numerical apertures of the illumination and detection pupils.
    invert_phase_contrast : bool, optional
        Reverse the order of the z positions, by default False

    Returns
    -------
    tuple of Tensor
        (absorption_2d_to_3d_transfer_function,
        phase_2d_to_3d_transfer_function), each complex64 with shape
        ``(len(z_position_list),) + tuple(yx_shape)``.
    """
    # Convert exactly once: wrapping an existing tensor in torch.tensor()
    # again (as happened after the flip) triggers a copy-construct warning
    z_positions = torch.as_tensor(z_position_list)
    if invert_phase_contrast:
        z_positions = torch.flip(z_positions, dims=(0,))

    radial_frequencies = util.generate_radial_frequencies(
        yx_shape, yx_pixel_size
    )

    illumination_pupil = optics.generate_pupil(
        radial_frequencies,
        numerical_aperture_illumination,
        wavelength_illumination,
    )
    detection_pupil = optics.generate_pupil(
        radial_frequencies,
        numerical_aperture_detection,
        wavelength_illumination,
    )
    propagation_kernel = optics.generate_propagation_kernel(
        radial_frequencies,
        detection_pupil,
        wavelength_illumination / index_of_refraction_media,
        z_positions,
    )

    z_count = len(z_positions)
    zyx_shape = (z_count,) + tuple(yx_shape)
    absorption_2d_to_3d_transfer_function = torch.zeros(
        zyx_shape, dtype=torch.complex64
    )
    phase_2d_to_3d_transfer_function = torch.zeros(
        zyx_shape, dtype=torch.complex64
    )
    for z in range(z_count):
        (
            absorption_2d_to_3d_transfer_function[z],
            phase_2d_to_3d_transfer_function[z],
        ) = optics.compute_weak_object_transfer_function_2d(
            illumination_pupil, detection_pupil * propagation_kernel[z]
        )

    return (
        absorption_2d_to_3d_transfer_function,
        phase_2d_to_3d_transfer_function,
    )
89
+
90
+
91
def visualize_transfer_function(
    viewer,
    absorption_2d_to_3d_transfer_function,
    phase_2d_to_3d_transfer_function,
):
    """Add imaginary and real parts of both transfer functions to ``viewer``.

    Each part is ifftshift-ed over axes (1, 2), shown with the "bwr"
    colormap, and given symmetric contrast limits at 50% of its maximum
    magnitude. The viewer's dimension order is reset to (0, 1, 2) afterwards.
    """
    # TODO: consider generalizing w/ phase_thick_3d.visualize_transfer_function
    layers = (
        (torch.imag(absorption_2d_to_3d_transfer_function), "Im(absorb TF)"),
        (torch.real(absorption_2d_to_3d_transfer_function), "Re(absorb TF)"),
        (torch.imag(phase_2d_to_3d_transfer_function), "Im(phase TF)"),
        (torch.real(phase_2d_to_3d_transfer_function), "Re(phase TF)"),
    )

    for component, layer_name in layers:
        lim = 0.5 * torch.max(torch.abs(component))
        centred = torch.fft.ifftshift(component, dim=(1, 2))
        viewer.add_image(
            centred.cpu().numpy(),
            name=layer_name,
            colormap="bwr",
            contrast_limits=(-lim, lim),
            scale=(1, 1, 1),
        )

    viewer.dims.order = (0, 1, 2)
114
+
115
+
116
def visualize_point_spread_function(
    viewer,
    absorption_2d_to_3d_transfer_function,
    phase_2d_to_3d_transfer_function,
):
    """Add the inverse Fourier transforms of both transfer functions to ``viewer``.

    Each transform is ifftshift-ed over axes (1, 2), shown with the "bwr"
    colormap, and given symmetric contrast limits at 50% of its maximum
    magnitude. The viewer's dimension order is reset to (0, 1, 2) afterwards.
    """
    # NOTE(review): ifftn is called without ``dim`` so it transforms every
    # axis, and the resulting complex array is passed to the viewer as-is
    # -- confirm both are intended.
    layers = (
        (absorption_2d_to_3d_transfer_function, "absorb PSF"),
        (phase_2d_to_3d_transfer_function, "phase PSF"),
    )
    psfs = [
        (torch.fft.ifftn(transfer_function), layer_name)
        for transfer_function, layer_name in layers
    ]

    for psf, layer_name in psfs:
        lim = 0.5 * torch.max(torch.abs(psf))
        centred = torch.fft.ifftshift(psf, dim=(1, 2))
        viewer.add_image(
            centred.cpu().numpy(),
            name=layer_name,
            colormap="bwr",
            contrast_limits=(-lim, lim),
            scale=(1, 1, 1),
        )

    viewer.dims.order = (0, 1, 2)
136
+
137
+
138
def apply_transfer_function(
    yx_absorption,
    yx_phase,
    phase_2d_to_3d_transfer_function,
    absorption_2d_to_3d_transfer_function,
    background=10,
):
    """Simulate a defocus stack from 2D absorption and phase objects.

    Each object is Fourier transformed, multiplied by the real part of its
    transfer function at every z slice, and transformed back over axes
    (1, 2). The two contributions are summed and a constant background is
    added.

    Parameters
    ----------
    yx_absorption : torch.Tensor
        2D absorption object.
    yx_phase : torch.Tensor
        2D phase object.
    phase_2d_to_3d_transfer_function : torch.Tensor
        Phase transfer function; note that it precedes the absorption
        transfer function in the argument order.
    absorption_2d_to_3d_transfer_function : torch.Tensor
        Absorption transfer function.
    background : int, optional
        constant background counts added to each voxel, by default 10

    Returns
    -------
    Simulated data : torch.Tensor
    """
    # Very simple simulation, consider adding noise and bkg knobs
    contributions = []
    for yx_object, transfer_function in (
        (yx_absorption, absorption_2d_to_3d_transfer_function),
        (yx_phase, phase_2d_to_3d_transfer_function),
    ):
        yx_object_hat = torch.fft.fftn(yx_object)
        # Broadcast the 2D spectrum against every z slice of the real part
        # of the transfer function
        zyx_data_hat = yx_object_hat[None, ...] * torch.real(transfer_function)
        contributions.append(
            torch.real(torch.fft.ifftn(zyx_data_hat, dim=(1, 2)))
        )

    # sum and add background
    zyx_absorption_data, zyx_phase_data = contributions
    data = zyx_absorption_data + zyx_phase_data
    data += background  # Add a direct background
    return data
168
+
169
+
170
def apply_inverse_transfer_function(
    zyx_data: Tensor,
    absorption_2d_to_3d_transfer_function: Tensor,
    phase_2d_to_3d_transfer_function: Tensor,
    reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
    regularization_strength: float = 1e-6,
    reg_p: float = 1e-6,
    TV_rho_strength: float = 1e-3,
    TV_iterations: int = 10,
    bg_filter: bool = True,
) -> Tuple[Tensor, Tensor]:
    """Reconstructs absorption and phase from zyx_data and a pair of
    2D-to-3D transfer functions named absorption_2d_to_3d_transfer_function and
    phase_2d_to_3d_transfer_function, providing options for reconstruction
    algorithms.

    Parameters
    ----------
    zyx_data : Tensor
        3D raw data, label-free defocus stack
    absorption_2d_to_3d_transfer_function : Tensor
        2D-to-3D absorption transfer function, see calculate_transfer_function above
    phase_2d_to_3d_transfer_function : Tensor
        2D-to-3D phase transfer function, see calculate_transfer_function above
    reconstruction_algorithm : Literal["Tikhonov", "TV"], optional
        "Tikhonov" or "TV", by default "Tikhonov"
        "TV" is not implemented.
    regularization_strength : float, optional
        regularization added to the absorption diagonal term of the normal
        equations, by default 1e-6
    reg_p : float, optional
        regularization added to the phase diagonal term of the normal
        equations, by default 1e-6
    TV_rho_strength : float, optional
        TV-specific regularization parameter, by default 1e-3
        "TV" is not implemented.
    TV_iterations : int, optional
        TV-specific number of iterations, by default 10
        "TV" is not implemented.
    bg_filter : bool, optional
        option for slow-varying 2D background normalization with uniform filter
        by default True

    Returns
    -------
    Tuple[Tensor, Tensor]
        yx_absorption (unitless)
        yx_phase (radians)

    Raises
    ------
    NotImplementedError
        TV is not implemented
    ValueError
        reconstruction_algorithm is neither "Tikhonov" nor "TV"
    """
    # Validate the algorithm up front so that an unsupported choice fails
    # with a clear error before any normalization or FFT work is done
    if reconstruction_algorithm == "TV":
        raise NotImplementedError(
            "TV reconstruction is not implemented; use 'Tikhonov'."
        )
    if reconstruction_algorithm != "Tikhonov":
        raise ValueError(
            "reconstruction_algorithm must be 'Tikhonov' or 'TV', "
            f"got {reconstruction_algorithm!r}"
        )

    zyx_data_normalized = util.inten_normalization(
        zyx_data, bg_filter=bg_filter
    )

    zyx_data_hat = torch.fft.fft2(zyx_data_normalized, dim=(1, 2))

    absorption_tf = absorption_2d_to_3d_transfer_function
    phase_tf = phase_2d_to_3d_transfer_function

    # TODO AHA and b_vec calculations should be moved into tikhonov/tv calculations
    # Entries of the 2x2 normal-equation matrix, summed over z; the two
    # regularization terms are added on the diagonal
    AHA = [
        torch.sum(torch.abs(absorption_tf) ** 2, dim=0)
        + regularization_strength,
        torch.sum(torch.conj(absorption_tf) * phase_tf, dim=0),
        torch.sum(torch.conj(phase_tf) * absorption_tf, dim=0),
        torch.sum(torch.abs(phase_tf) ** 2, dim=0) + reg_p,
    ]

    # Right-hand side: conjugate transfer functions applied to the data
    # spectrum, summed over z
    b_vec = [
        torch.sum(torch.conj(absorption_tf) * zyx_data_hat, dim=0),
        torch.sum(torch.conj(phase_tf) * zyx_data_hat, dim=0),
    ]

    # Deconvolution with Tikhonov regularization
    absorption, phase = util.dual_variable_tikhonov_deconvolution_2d(
        AHA, b_vec
    )

    # Remove the mean so the reconstructed phase is zero-centred
    phase -= torch.mean(phase)

    return absorption, phase