waveorder 2.0.0rc3__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
waveorder/_version.py CHANGED
@@ -1,4 +1,16 @@
  # file generated by setuptools_scm
  # don't change, don't track in version control
- __version__ = version = '2.0.0rc3'
- __version_tuple__ = version_tuple = (2, 0, 0)
+ TYPE_CHECKING = False
+ if TYPE_CHECKING:
+     from typing import Tuple, Union
+     VERSION_TUPLE = Tuple[Union[int, str], ...]
+ else:
+     VERSION_TUPLE = object
+
+ version: str
+ __version__: str
+ __version_tuple__: VERSION_TUPLE
+ version_tuple: VERSION_TUPLE
+
+ __version__ = version = '2.2.0'
+ __version_tuple__ = version_tuple = (2, 2, 0)
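
The regenerated version module exposes both a string and a tuple form. A minimal sketch of reading them at runtime (assuming waveorder 2.2.0 is installed; importlib.metadata shown as an alternative):

    from importlib.metadata import version
    from waveorder._version import __version__, __version_tuple__

    assert version("waveorder") == __version__ == "2.2.0"
    # The tuple form allows simple version gating without parsing strings
    if __version_tuple__ >= (2, 2, 0):
        pass  # waveorder.correction is available (added between 2.0.0rc3 and 2.2.0 per this diff)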
waveorder/correction.py ADDED
@@ -0,0 +1,107 @@
+ """Background correction methods"""
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, Size
+
+
+ def _sample_block_medians(image: Tensor, block_size) -> Tensor:
+     """
+     Sample densely tiled square blocks from a 2D image and return their medians.
+     Incomplete blocks (overhangs) will be ignored.
+
+     Parameters
+     ----------
+     image : Tensor
+         2D image
+     block_size : int
+         Width and height of the blocks
+
+     Returns
+     -------
+     Tensor
+         Median intensity values for each block, flattened
+     """
+     if not image.dtype.is_floating_point:
+         image = image.to(torch.float)
+     blocks = F.unfold(image[None, None], block_size, stride=block_size)[0]
+     return blocks.median(0)[0]
+
+
+ def _grid_coordinates(image: Tensor, block_size: int) -> Tensor:
+     """Build image coordinates from the center points of square blocks"""
+     coords = torch.meshgrid(
+         [
+             torch.arange(
+                 0 + block_size / 2,
+                 boundary - block_size / 2 + 1,
+                 block_size,
+                 device=image.device,
+             )
+             for boundary in image.shape
+         ]
+     )
+     return torch.stack(coords, dim=-1).reshape(-1, 2)
+
+
+ def _fit_2d_polynomial_surface(
+     coords: Tensor, values: Tensor, order: int, surface_shape: Size
+ ) -> Tensor:
+     """Fit a 2D polynomial to a set of coordinates and their values,
+     and return the surface evaluated at every point."""
+     n_coeffs = int((order + 1) * (order + 2) / 2)
+     if n_coeffs >= len(values):
+         raise ValueError(
+             f"Cannot fit a degree-{order} 2D polynomial "
+             f"with {len(values)} sampled values"
+         )
+     orders = torch.arange(order + 1, device=coords.device)
+     order_pairs = torch.stack(torch.meshgrid(orders, orders), -1)
+     order_pairs = order_pairs[order_pairs.sum(-1) <= order].reshape(-1, 2)
+     terms = torch.stack(
+         [coords[:, 0] ** i * coords[:, 1] ** j for i, j in order_pairs], -1
+     )
+     # use "gels" driver for precision and GPU consistency
+     coeffs = torch.linalg.lstsq(terms, values, driver="gels").solution
+     dense_coords = torch.meshgrid(
+         [
+             torch.arange(s, dtype=values.dtype, device=values.device)
+             for s in surface_shape
+         ]
+     )
+     dense_terms = torch.stack(
+         [dense_coords[0] ** i * dense_coords[1] ** j for i, j in order_pairs],
+         -1,
+     )
+     return torch.matmul(dense_terms, coeffs)
+
+
+ def estimate_background(image: Tensor, order: int = 2, block_size: int = 32):
+     """
+     Combine sampling and polynomial surface fit for background estimation.
+     To background correct an image, divide it by the background.
+
+     Parameters
+     ----------
+     image : Tensor
+         2D image
+     order : int, optional
+         Order of polynomial, by default 2
+     block_size : int, optional
+         Width and height of the blocks, by default 32
+
+     Returns
+     -------
+     Tensor
+         Background image
+     """
+     if image.ndim != 2:
+         raise ValueError(f"Image must be 2D, got shape {image.shape}")
+     height, width = image.shape
+     if block_size > width:
+         raise ValueError("Block size larger than image width")
+     if block_size > height:
+         raise ValueError("Block size larger than image height")
+     medians = _sample_block_medians(image, block_size)
+     coords = _grid_coordinates(image, block_size)
+     return _fit_2d_polynomial_surface(coords, medians, order, image.shape)
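
A minimal usage sketch of the new estimate_background helper, following its docstring (divide by the estimated background to flat-field an image). The synthetic image and parameter values below are illustrative, not from the package:

    import torch
    from waveorder import correction

    # Synthetic 512x512 image with a smooth, low-order intensity gradient
    yy, xx = torch.meshgrid(
        torch.arange(512.0), torch.arange(512.0), indexing="ij"
    )
    shading = 1.0 + 0.001 * yy + 0.0005 * xx
    image = torch.rand(512, 512) * shading

    # Block-wise medians + 2nd-order polynomial surface fit
    background = correction.estimate_background(image, order=2, block_size=32)

    # Flat-field correction as suggested by the docstring
    corrected = image / background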
waveorder/focus.py CHANGED
@@ -3,6 +3,7 @@ from typing import Literal, Optional
  from waveorder import util
  import matplotlib.pyplot as plt
  import numpy as np
+ import warnings
 
 
  def focus_from_transverse_band(
@@ -60,10 +61,19 @@ def focus_from_transverse_band(
      >>> slice = focus_from_transverse_band(zyx_array, NA_det=0.55, lambda_ill=0.532, pixel_size=6.5/20)
      >>> in_focus_data = data[slice,:,:]
      """
-     minmaxfunc = _check_focus_inputs(
-         zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions, mode
+     minmaxfunc = _mode_to_minmaxfunc(mode)
+
+     _check_focus_inputs(
+         zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
      )
 
+     # Check for single slice
+     if zyx_array.shape[0] == 1:
+         warnings.warn(
+             "The dataset only contained a single slice. Returning trivial slice index = 0."
+         )
+         return 0
+
      # Calculate coordinates
      _, Y, X = zyx_array.shape
      _, _, fxx, fyy = util.gen_coordinate((Y, X), pixel_size)
@@ -94,25 +104,35 @@ def focus_from_transverse_band(
      # Plot
      if plot_path is not None:
          _plot_focus_metric(
-             plot_path, midband_sum, peak_index, in_focus_index, peak_results, threshold_FWHM
+             plot_path,
+             midband_sum,
+             peak_index,
+             in_focus_index,
+             peak_results,
+             threshold_FWHM,
          )
 
      return in_focus_index
 
 
+ def _mode_to_minmaxfunc(mode):
+     if mode == "min":
+         minmaxfunc = np.argmin
+     elif mode == "max":
+         minmaxfunc = np.argmax
+     else:
+         raise ValueError("mode must be either `min` or `max`")
+     return minmaxfunc
+
+
  def _check_focus_inputs(
-     zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions, mode
+     zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
  ):
      N = len(zyx_array.shape)
      if N != 3:
          raise ValueError(
              f"{N}D array supplied. `focus_from_transverse_band` only accepts 3D arrays."
          )
-     if zyx_array.shape[0] == 1:
-         print(
-             "WARNING: The dataset only contained a single slice. Returning trivial slice index = 0."
-         )
-         return 0
 
      if NA_det < 0:
          raise ValueError("NA must be > 0")
@@ -121,7 +141,7 @@ def _check_focus_inputs(
      if pixel_size < 0:
          raise ValueError("pixel_size must be > 0")
      if not 0.4 < lambda_ill / pixel_size < 10:
-         print(
+         warnings.warn(
              f"WARNING: lambda_ill/pixel_size = {lambda_ill/pixel_size}."
              f"Did you use the same units?"
              f"Did you enter the pixel size in (demagnified) object-space units?"
@@ -134,17 +154,15 @@ def _check_focus_inputs(
          raise ValueError("midband_fractions[0] must be between 0 and 1")
      if not (0 <= midband_fractions[1] <= 1):
          raise ValueError("midband_fractions[1] must be between 0 and 1")
-     if mode == "min":
-         minmaxfunc = np.argmin
-     elif mode == "max":
-         minmaxfunc = np.argmax
-     else:
-         raise ValueError("mode must be either `min` or `max`")
-     return minmaxfunc
 
 
  def _plot_focus_metric(
-     plot_path, midband_sum, peak_index, in_focus_index, peak_results, threshold_FWHM
+     plot_path,
+     midband_sum,
+     peak_index,
+     in_focus_index,
+     peak_results,
+     threshold_FWHM,
  ):
      _, ax = plt.subplots(1, 1, figsize=(4, 4))
      ax.plot(midband_sum, "-k")
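
The single-slice case is now reported through the warnings module rather than print, which makes it catchable in tests and silenceable by callers. A short sketch (the array contents and optical parameters are illustrative, taken from the docstring example above):

    import warnings
    import numpy as np
    from waveorder.focus import focus_from_transverse_band

    zyx_array = np.random.rand(21, 256, 256)  # synthetic z-stack

    # Normal use: returns the index of the in-focus slice
    slice_index = focus_from_transverse_band(
        zyx_array, NA_det=0.55, lambda_ill=0.532, pixel_size=6.5 / 20
    )

    # Single-slice input: warns and returns the trivial index 0
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        index = focus_from_transverse_band(
            zyx_array[:1], NA_det=0.55, lambda_ill=0.532, pixel_size=6.5 / 20
        )
    assert index == 0 and any("single slice" in str(w.message) for w in caught)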
@@ -4,10 +4,10 @@ import numpy as np
  import torch
  from torch import Tensor
 
- from waveorder import background_estimator, stokes, util
+ from waveorder import correction, stokes, util
 
 
- def generate_test_phantom(yx_shape):
+ def generate_test_phantom(yx_shape: Tuple[int, int]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
      star, theta, _ = util.generate_star_target(yx_shape, blur_px=0.1)
      retardance = 0.25 * star
      orientation = (theta % np.pi) * (star > 1e-3)
@@ -17,13 +17,13 @@ def generate_test_phantom(yx_shape):
 
 
  def calculate_transfer_function(
-     swing,
-     scheme,
- ):
+     swing: float,
+     scheme: str,
+ ) -> Tensor:
      return stokes.calculate_intensity_to_stokes_matrix(swing, scheme=scheme)
 
 
- def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
+ def visualize_transfer_function(viewer, intensity_to_stokes_matrix: Tensor) -> None:
      viewer.add_image(
          intensity_to_stokes_matrix.cpu().numpy(),
          name="Intensity to stokes matrix",
@@ -31,12 +31,12 @@ def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
 
 
  def apply_transfer_function(
-     retardance,
-     orientation,
-     transmittance,
-     depolarization,
-     intensity_to_stokes_matrix,
- ):
+     retardance: Tensor,
+     orientation: Tensor,
+     transmittance: Tensor,
+     depolarization: Tensor,
+     intensity_to_stokes_matrix: Tensor,
+ ) -> Tensor:
      stokes_params = stokes.stokes_after_adr(
          retardance, orientation, transmittance, depolarization
      )
@@ -59,7 +59,7 @@ def apply_inverse_transfer_function(
      project_stokes_to_2d: bool = False,
      flip_orientation: bool = False,
      rotate_orientation: bool = False,
- ) -> Tuple[Tensor]:
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
      """Reconstructs retardance, orientation, transmittance, and depolarization
      from czyx_data and an intensity_to_stokes_matrix, providing options for
      background correction, projection, and orientation transformations.
@@ -125,7 +125,6 @@
 
      # Apply an "Estimated" background correction
      if remove_estimated_background:
-         estimator = background_estimator.BackgroundEstimator2D()
          for stokes_index in range(background_corrected_stokes.shape[0]):
              # Project to 2D
              z_projection = torch.mean(
@@ -134,9 +133,8 @@
              # Estimate the background and subtract
              background_corrected_stokes[
                  stokes_index
-             ] -= estimator.get_background(
-                 z_projection,
-                 normalize=False,
+             ] -= correction.estimate_background(
+                 z_projection, order=2, block_size=32
              )
 
      # Project to 2D (typically for SNR reasons)
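
Equivalent standalone form of the replacement call: the per-channel z-projection is background-subtracted with correction.estimate_background instead of the removed BackgroundEstimator2D. The shapes and values below are illustrative:

    import torch
    from waveorder import correction

    # Hypothetical background-corrected Stokes stack: (stokes, z, y, x)
    background_corrected_stokes = torch.rand(3, 11, 128, 128)

    for stokes_index in range(background_corrected_stokes.shape[0]):
        # Project to 2D along z, as in apply_inverse_transfer_function
        z_projection = torch.mean(
            background_corrected_stokes[stokes_index], dim=0
        )
        # Fit and subtract a smooth background surface from every z-slice
        background_corrected_stokes[stokes_index] -= correction.estimate_background(
            z_projection, order=2, block_size=32
        )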
@@ -0,0 +1,351 @@
+ import torch
+ import numpy as np
+
+ from torch import Tensor
+ from typing import Literal
+ from torch.nn.functional import avg_pool3d, interpolate
+ from waveorder import optics, sampling, stokes, util
+ from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
+
+
+ def generate_test_phantom(zyx_shape: tuple[int, int, int]) -> torch.Tensor:
+     # Simulate
+     yx_star, yx_theta, _ = util.generate_star_target(
+         yx_shape=zyx_shape[1:],
+         blur_px=1,
+         margin=50,
+     )
+     c00 = yx_star
+     c2_2 = -torch.sin(2 * yx_theta) * yx_star  # torch.zeros_like(c00)
+     c22 = -torch.cos(2 * yx_theta) * yx_star  # torch.zeros_like(c00)
+
+     # Put in a center slice of a 3D object
+     center_slice_object = torch.stack((c00, c2_2, c22), dim=0)
+     object = torch.zeros((3,) + zyx_shape)
+     object[:, zyx_shape[0] // 2, ...] = center_slice_object
+     return object
+
+
+ def calculate_transfer_function(
+     swing: float,
+     scheme: str,
+     zyx_shape: tuple[int, int, int],
+     yx_pixel_size: float,
+     z_pixel_size: float,
+     wavelength_illumination: float,
+     z_padding: int,
+     index_of_refraction_media: float,
+     numerical_aperture_illumination: float,
+     numerical_aperture_detection: float,
+     invert_phase_contrast: bool = False,
+     fourier_oversample_factor: int = 1,
+     transverse_downsample_factor: int = 1,
+ ) -> tuple[
+     torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ ]:
+     if z_padding != 0:
+         raise NotImplementedError("Padding not implemented for this model")
+
+     transverse_nyquist = sampling.transverse_nyquist(
+         wavelength_illumination,
+         numerical_aperture_illumination,
+         numerical_aperture_detection,
+     )
+     axial_nyquist = sampling.axial_nyquist(
+         wavelength_illumination,
+         numerical_aperture_detection,
+         index_of_refraction_media,
+     )
+
+     yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
+     z_factor = int(np.ceil(z_pixel_size / axial_nyquist))
+
+     print("YX factor:", yx_factor)
+     print("Z factor:", z_factor)
+
+     tf_calculation_shape = (
+         zyx_shape[0] * z_factor * fourier_oversample_factor,
+         int(
+             np.ceil(
+                 zyx_shape[1]
+                 * yx_factor
+                 * fourier_oversample_factor
+                 / transverse_downsample_factor
+             )
+         ),
+         int(
+             np.ceil(
+                 zyx_shape[2]
+                 * yx_factor
+                 * fourier_oversample_factor
+                 / transverse_downsample_factor
+             )
+         ),
+     )
+
+     sfZYX_transfer_function, intensity_to_stokes_matrix = (
+         _calculate_wrap_unsafe_transfer_function(
+             swing,
+             scheme,
+             tf_calculation_shape,
+             yx_pixel_size / yx_factor,
+             z_pixel_size / z_factor,
+             wavelength_illumination,
+             z_padding,
+             index_of_refraction_media,
+             numerical_aperture_illumination,
+             numerical_aperture_detection,
+             invert_phase_contrast=invert_phase_contrast,
+         )
+     )
+
+     # avg_pool3d does not support complex numbers
+     pooled_sfZYX_transfer_function_real = avg_pool3d(
+         sfZYX_transfer_function.real, (fourier_oversample_factor,) * 3
+     )
+     pooled_sfZYX_transfer_function_imag = avg_pool3d(
+         sfZYX_transfer_function.imag, (fourier_oversample_factor,) * 3
+     )
+     pooled_sfZYX_transfer_function = (
+         pooled_sfZYX_transfer_function_real
+         + 1j * pooled_sfZYX_transfer_function_imag
+     )
+
+     # Crop to original size
+     sfzyx_out_shape = (
+         pooled_sfZYX_transfer_function.shape[0],
+         pooled_sfZYX_transfer_function.shape[1],
+         zyx_shape[0] + 2 * z_padding,
+     ) + zyx_shape[1:]
+
+     cropped = sampling.nd_fourier_central_cuboid(
+         pooled_sfZYX_transfer_function, sfzyx_out_shape
+     )
+
+     # Compute singular system on cropped and downsampled
+     U, S, Vh = calculate_singular_system(cropped)
+
+     # Interpolate to final size in YX
+     def complex_interpolate(
+         tensor: torch.Tensor, zyx_shape: tuple[int, int, int]
+     ) -> torch.Tensor:
+         interpolated_real = interpolate(tensor.real, size=zyx_shape)
+         interpolated_imag = interpolate(tensor.imag, size=zyx_shape)
+         return interpolated_real + 1j * interpolated_imag
+
+     full_cropped = complex_interpolate(cropped, zyx_shape)
+     full_U = complex_interpolate(U, zyx_shape)
+     full_S = interpolate(S[None], size=zyx_shape)[0]  # S is real
+     full_Vh = complex_interpolate(Vh, zyx_shape)
+
+     return (
+         full_cropped,
+         intensity_to_stokes_matrix,
+         (full_U, full_S, full_Vh),
+     )
+
+
+ def _calculate_wrap_unsafe_transfer_function(
+     swing,
+     scheme,
+     zyx_shape,
+     yx_pixel_size,
+     z_pixel_size,
+     wavelength_illumination,
+     z_padding,
+     index_of_refraction_media,
+     numerical_aperture_illumination,
+     numerical_aperture_detection,
+     invert_phase_contrast=False,
+ ):
+     print("Computing transfer function")
+     intensity_to_stokes_matrix = stokes.calculate_intensity_to_stokes_matrix(
+         swing, scheme=scheme
+     )
+
+     input_jones = torch.tensor([0.0 - 1.0j, 1.0 + 0j])  # circular
+     # input_jones = torch.tensor([0 + 0j, 1 + 0j])  # linear
+
+     # Calculate frequencies
+     y_frequencies, x_frequencies = util.generate_frequencies(
+         zyx_shape[1:], yx_pixel_size
+     )
+     radial_frequencies = torch.sqrt(x_frequencies**2 + y_frequencies**2)
+
+     z_total = zyx_shape[0] + 2 * z_padding
+     z_position_list = torch.fft.ifftshift(
+         (torch.arange(z_total) - z_total // 2) * z_pixel_size
+     )
+     if (
+         not invert_phase_contrast
+     ):  # opposite sign of direct phase reconstruction
+         z_position_list = torch.flip(z_position_list, dims=(0,))
+     z_frequencies = torch.fft.fftfreq(z_total, d=z_pixel_size)
+
+     # 2D pupils
+     print("\tCalculating pupils...")
+     ill_pupil = optics.generate_pupil(
+         radial_frequencies,
+         numerical_aperture_illumination,
+         wavelength_illumination,
+     )
+     det_pupil = optics.generate_pupil(
+         radial_frequencies,
+         numerical_aperture_detection,
+         wavelength_illumination,
+     )
+     pupil = optics.generate_pupil(
+         radial_frequencies,
+         index_of_refraction_media,  # largest possible NA
+         wavelength_illumination,
+     )
+
+     # Defocus pupils
+     defocus_pupil = optics.generate_propagation_kernel(
+         radial_frequencies,
+         pupil,
+         wavelength_illumination / index_of_refraction_media,
+         z_position_list,
+     )
+
+     # Calculate vector defocus pupils
+     S = optics.generate_vector_source_defocus_pupil(
+         x_frequencies,
+         y_frequencies,
+         z_position_list,
+         defocus_pupil,
+         input_jones,
+         ill_pupil,
+         wavelength_illumination / index_of_refraction_media,
+     )
+
+     # Simplified scalar pupil
+     P = optics.generate_propagation_kernel(
+         radial_frequencies,
+         det_pupil,
+         wavelength_illumination / index_of_refraction_media,
+         z_position_list,
+     )
+
+     P_3D = torch.abs(torch.fft.ifft(P, dim=-3)).type(torch.complex64)
+     S_3D = torch.fft.ifft(S, dim=-3)
+
+     print("\tCalculating greens tensor spectrum...")
+     G_3D = optics.generate_greens_tensor_spectrum(
+         zyx_shape=(z_total, zyx_shape[1], zyx_shape[2]),
+         zyx_pixel_size=(z_pixel_size, yx_pixel_size, yx_pixel_size),
+         wavelength=wavelength_illumination / index_of_refraction_media,
+     )
+
+     # Main part
+     PG_3D = torch.einsum("zyx,ipzyx->ipzyx", P_3D, G_3D)
+     PS_3D = torch.einsum("zyx,jzyx,kzyx->jkzyx", P_3D, S_3D, torch.conj(S_3D))
+
+     del P_3D, G_3D, S_3D
+
+     print("\tComputing pg and ps...")
+     pg = torch.fft.fftn(PG_3D, dim=(-3, -2, -1))
+     ps = torch.fft.fftn(PS_3D, dim=(-3, -2, -1))
+
+     del PG_3D, PS_3D
+
+     print("\tComputing H1 and H2...")
+     H1 = torch.fft.ifftn(
+         torch.einsum("ipzyx,jkzyx->ijpkzyx", pg, torch.conj(ps)),
+         dim=(-3, -2, -1),
+     )
+
+     H2 = torch.fft.ifftn(
+         torch.einsum("ikzyx,jpzyx->ijpkzyx", ps, torch.conj(pg)),
+         dim=(-3, -2, -1),
+     )
+
+     H_re = H1[1:, 1:] + H2[1:, 1:]  # drop data-side z components
+     # H_im = 1j * (H1 - H2)  # ignore absorptive terms
+
+     del H1, H2
+
+     H_re /= torch.amax(torch.abs(H_re))
+
+     s = util.pauli()[[0, 1, 2, 3]]  # select s0, s1, and s2
+     Y = util.gellmann()[[0, 4, 8]]
+     # select phase f00 and transverse linear isotropic terms f2-2 and f22
+
+     print("\tComputing final transfer function...")
+     sfZYX_transfer_function = torch.einsum(
+         "sik,ikpjzyx,lpj->slzyx", s, H_re, Y
+     )
+     return (
+         sfZYX_transfer_function,
+         intensity_to_stokes_matrix,
+     )
+
+
+ def calculate_singular_system(sfZYX_transfer_function):
+     # Compute regularized inverse filter
+     print("Computing SVD")
+     ZYXsf_transfer_function = sfZYX_transfer_function.permute(2, 3, 4, 0, 1)
+     U, S, Vh = torch.linalg.svd(ZYXsf_transfer_function, full_matrices=False)
+     singular_system = (
+         U.permute(3, 4, 0, 1, 2),
+         S.permute(3, 0, 1, 2),
+         Vh.permute(3, 4, 0, 1, 2),
+     )
+     return singular_system
+
+
+ def visualize_transfer_function(
+     viewer: "napari.Viewer",
+     sfZYX_transfer_function: torch.Tensor,
+     zyx_scale: tuple[float, float, float],
+ ) -> None:
+     add_transfer_function_to_viewer(
+         viewer,
+         sfZYX_transfer_function,
+         zyx_scale=zyx_scale,
+         layer_name="Transfer Function",
+         complex_rgb=True,
+         clim_factor=0.5,
+     )
+
+
+ def apply_transfer_function(
+     fzyx_object: torch.Tensor,
+     sfZYX_transfer_function: torch.Tensor,
+     intensity_to_stokes_matrix: torch.Tensor,  # TODO use this to simulate intensities
+ ) -> torch.Tensor:
+     fZYX_object = torch.fft.fftn(fzyx_object, dim=(1, 2, 3))
+     sZYX_data = torch.einsum(
+         "fzyx,sfzyx->szyx", fZYX_object, sfZYX_transfer_function
+     )
+     szyx_data = torch.fft.ifftn(sZYX_data, dim=(1, 2, 3))
+
+     return 50 * szyx_data  # + 0.1 * torch.randn(szyx_data.shape)
+
+
+ def apply_inverse_transfer_function(
+     szyx_data: Tensor,
+     singular_system: tuple[Tensor],
+     intensity_to_stokes_matrix: Tensor,
+     reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
+     regularization_strength: float = 1e-3,
+     TV_rho_strength: float = 1e-3,
+     TV_iterations: int = 10,
+ ):
+     sZYX_data = torch.fft.fftn(szyx_data, dim=(1, 2, 3))
+
+     # Key computation
+     print("Computing inverse filter")
+     U, S, Vh = singular_system
+     S_reg = S / (S**2 + regularization_strength)
+
+     ZYXsf_inverse_filter = torch.einsum(
+         "sjzyx,jzyx,jfzyx->sfzyx", U, S_reg, Vh
+     )
+
+     # Apply inverse filter
+     fZYX_reconstructed = torch.einsum(
+         "szyx,sfzyx->fzyx", sZYX_data, ZYXsf_inverse_filter
+     )
+
+     return torch.real(torch.fft.ifftn(fZYX_reconstructed, dim=(1, 2, 3)))
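
The new model's reconstruction path is a per-spatial-frequency SVD of the Stokes-to-object transfer function followed by a Tikhonov-regularized pseudo-inverse, as chained through calculate_singular_system and apply_inverse_transfer_function above. A self-contained toy sketch of that pipeline follows; all shapes and values are illustrative, and the singular values are explicitly cast to the complex dtype before the einsum for safety:

    import torch

    # Toy Stokes-to-object transfer function: (stokes, object, z, y, x)
    S_CH, F_CH, Z, Y, X = 4, 3, 8, 32, 32
    sfZYX_tf = torch.randn(S_CH, F_CH, Z, Y, X, dtype=torch.complex64)

    # Per-frequency SVD of the small (S_CH x F_CH) matrices
    ZYXsf_tf = sfZYX_tf.permute(2, 3, 4, 0, 1)
    U, S, Vh = torch.linalg.svd(ZYXsf_tf, full_matrices=False)
    U = U.permute(3, 4, 0, 1, 2)    # (stokes, singular, z, y, x)
    S = S.permute(3, 0, 1, 2)       # (singular, z, y, x), real-valued
    Vh = Vh.permute(3, 4, 0, 1, 2)  # (singular, object, z, y, x)

    # Tikhonov-regularized inverse filter
    regularization_strength = 1e-3
    S_reg = (S / (S**2 + regularization_strength)).to(U.dtype)
    inverse_filter = torch.einsum("sjzyx,jzyx,jfzyx->sfzyx", U, S_reg, Vh)

    # Apply to (simulated) Stokes data in the Fourier domain
    szyx_data = torch.randn(S_CH, Z, Y, X, dtype=torch.complex64)
    sZYX_data = torch.fft.fftn(szyx_data, dim=(1, 2, 3))
    fZYX = torch.einsum("szyx,sfzyx->fzyx", sZYX_data, inverse_filter)
    fzyx_reconstruction = torch.real(torch.fft.ifftn(fZYX, dim=(1, 2, 3)))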