waveorder 2.2.0rc0__py3-none-any.whl → 2.2.1b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
waveorder/_version.py CHANGED
@@ -1,8 +1,13 @@
1
- # file generated by setuptools_scm
1
+ # file generated by setuptools-scm
2
2
  # don't change, don't track in version control
3
+
4
+ __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
5
+
3
6
  TYPE_CHECKING = False
4
7
  if TYPE_CHECKING:
5
- from typing import Tuple, Union
8
+ from typing import Tuple
9
+ from typing import Union
10
+
6
11
  VERSION_TUPLE = Tuple[Union[int, str], ...]
7
12
  else:
8
13
  VERSION_TUPLE = object
@@ -12,5 +17,5 @@ __version__: str
12
17
  __version_tuple__: VERSION_TUPLE
13
18
  version_tuple: VERSION_TUPLE
14
19
 
15
- __version__ = version = '2.2.0rc0'
16
- __version_tuple__ = version_tuple = (2, 2, 0)
20
+ __version__ = version = '2.2.1b0'
21
+ __version_tuple__ = version_tuple = (2, 2, 1)
@@ -1,12 +1,12 @@
1
1
  """Estimate flat field images"""
2
2
 
3
- import numpy as np
4
3
  import itertools
5
4
 
5
+ import numpy as np
6
6
 
7
7
  """
8
8
 
9
- This script is adopted from
9
+ This script is adapted from
10
10
 
11
11
  https://github.com/mehta-lab/reconstruct-order
12
12
 
waveorder/correction.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  import torch
4
4
  import torch.nn.functional as F
5
- from torch import Tensor, Size
5
+ from torch import Size, Tensor
6
6
 
7
7
 
8
8
  def _sample_block_medians(image: Tensor, block_size) -> Tensor:
waveorder/filter.py ADDED
@@ -0,0 +1,206 @@
1
+ import itertools
2
+
3
+ import torch
4
+
5
+
6
def apply_filter_bank(
    io_filter_bank: torch.Tensor,
    i_input_array: torch.Tensor,
) -> torch.Tensor:
    """
    Applies a filter bank to an input array.

    io_filter_bank.shape must be smaller or equal to i_input_array.shape in all
    dimensions. When io_filter_bank is smaller, it is effectively "stretched"
    to apply the filter.

    io_filter_bank is in "wrapped" format, i.e., the zero frequency is the
    zeroth element.

    i_input_array and io_filter_bank must have inverse sample spacing, i.e.,
    if i_input_array contains samples spaced by dx, then io_filter_bank must
    have extent 1/dx. Note that there is no need for io_filter_bank to have
    sample spacing 1/(n*dx) because io_filter_bank will be stretched.

    Parameters
    ----------
    io_filter_bank : torch.Tensor
        The filter bank to be applied in the frequency domain.
        The spatial extent of io_filter_bank must be 1/dx, where dx is the
        sample spacing of i_input_array.

        Leading dimensions are the input and output dimensions.
        io_filter_bank.shape[:2] == (num_input_channels, num_output_channels)

        Trailing dimensions are spatial frequency dimensions.
        io_filter_bank.shape[2:] == (Z', Y', X') or (Y', X')

    i_input_array : torch.Tensor
        The real-valued input array with sample spacing dx to be filtered.

        Leading dimension is the input dimension, matching the filter bank.
        i_input_array.shape[0] == num_input_channels

        Trailing dimensions are spatial dimensions.
        i_input_array.shape[1:] == (Z, Y, X) or (Y, X)

    Returns
    -------
    torch.Tensor
        The filtered real-valued output array with shape
        (num_output_channels, Z, Y, X) or (num_output_channels, Y, X).
    """
    # Ensure all spatial dimensions of io_filter_bank are <= i_input_array's
    if any(
        t > i
        for t, i in zip(io_filter_bank.shape[2:], i_input_array.shape[1:])
    ):
        raise ValueError(
            "All spatial dimensions of io_filter_bank must be <= i_input_array."
        )

    # Ensure the number of spatial dimensions match
    # (filter bank carries two leading channel dims, input carries one)
    if io_filter_bank.ndim - i_input_array.ndim != 1:
        raise ValueError(
            "io_filter_bank and i_input_array must have the same number of spatial dimensions."
        )

    # Ensure the input-channel dimensions match
    if io_filter_bank.shape[0] != i_input_array.shape[0]:
        raise ValueError(
            "io_filter_bank.shape[0] and i_input_array.shape[0] must be the same."
        )

    num_input_channels, num_output_channels = io_filter_bank.shape[:2]

    # Pad i_input_array until each spatial dimension is divisible by the
    # corresponding filter dimension so the filter can be block-stretched.
    # torch.nn.functional.pad takes sizes for the *last* dimension first,
    # hence the [::-1] reversals.
    pad_sizes = [
        (0, (t - (i % t)) % t)
        for t, i in zip(
            io_filter_bank.shape[2:][::-1], i_input_array.shape[1:][::-1]
        )
    ]
    flat_pad_sizes = list(itertools.chain(*pad_sizes))
    padded_input_array = torch.nn.functional.pad(i_input_array, flat_pad_sizes)

    # Apply the filter bank in the frequency domain (spatial dims only)
    fft_dims = [d for d in range(1, i_input_array.ndim)]
    padded_input_spectrum = torch.fft.fftn(padded_input_array, dim=fft_dims)

    # Matrix-vector multiplication over the channel dimensions.
    # If this is a bottleneck, consider extending `stretched_multiply` to
    # a `stretched_matrix_multiply` that uses a call like
    # torch.einsum('io..., i... -> o...', io_filter_bank, padded_input_spectrum)
    #
    # Further optimization is likely with a combination of
    # torch.baddbmm, torch.pixel_shuffle, torch.pixel_unshuffle.
    #
    # BUGFIX: the accumulator must match the *padded input* spatial shape,
    # not the filter's spatial shape — `stretched_multiply` returns arrays
    # of the padded-input size, so the original (filter-shaped) buffer
    # failed whenever io_filter_bank was smaller than i_input_array.
    padded_output_spectrum = torch.zeros(
        (num_output_channels,) + tuple(padded_input_array.shape[1:]),
        dtype=padded_input_spectrum.dtype,
        device=padded_input_spectrum.device,
    )
    for input_channel_idx in range(num_input_channels):
        for output_channel_idx in range(num_output_channels):
            padded_output_spectrum[output_channel_idx] += stretched_multiply(
                io_filter_bank[input_channel_idx, output_channel_idx],
                padded_input_spectrum[input_channel_idx],
            )

    # Cast to real, ignoring imaginary part
    padded_result = torch.real(
        torch.fft.ifftn(padded_output_spectrum, dim=fft_dims)
    )

    # Remove padding and return.
    # BUGFIX: keep all output channels — dim 0 has num_output_channels,
    # which may differ from i_input_array.shape[0]; only spatial dims
    # are cropped back to the unpadded input size.
    slices = (slice(None),) + tuple(
        slice(0, s) for s in i_input_array.shape[1:]
    )
    return padded_result[slices]
120
+
121
+
122
def stretched_multiply(
    small_array: torch.Tensor, large_array: torch.Tensor
) -> torch.Tensor:
    """
    Effectively "stretches" small_array onto large_array before multiplying.

    Each dimension of large_array must be divisible by each dimension of
    small_array.

    Rather than upsampling small_array, large_array is viewed as a grid of
    blocks (one block per element of small_array) and each block is scaled
    by its corresponding small_array element — a "block element-wise"
    product.

    For example, a `stretched_multiply` of a 3x3 array by a 99x99 array
    treats the 99x99 array as a 3x3 grid of 33x33 blocks and scales each
    block by the matching element of the 3x3 array.

    Works for arbitrary dimensions.

    Parameters
    ----------
    small_array : torch.Tensor
        A smaller array whose elements will be "stretched" onto blocks in
        the large array.
    large_array : torch.Tensor
        A larger array that will be divided into blocks and multiplied by
        the small array.

    Returns
    -------
    torch.Tensor
        Resulting tensor with shape matching large_array.

    Example
    -------
    small_array = torch.tensor([[1, 2],
                                [3, 4]])

    large_array = torch.tensor([[1,  2,  3,  4],
                                [5,  6,  7,  8],
                                [9, 10, 11, 12],
                                [13, 14, 15, 16]])

    stretched_multiply(small_array, large_array) returns

    [[ 1,  2,  6,  8],
     [ 5,  6, 14, 16],
     [27, 30, 44, 48],
     [39, 42, 60, 64]]
    """
    # Divisibility check: every large dimension must be a whole number of
    # small-array blocks.
    if any(l % s != 0 for s, l in zip(small_array.shape, large_array.shape)):
        raise ValueError(
            "Each dimension of large_array must be divisible by each dimension of small_array"
        )

    # Rank check: the two arrays must describe the same number of axes.
    if small_array.ndim != large_array.ndim:
        raise ValueError(
            "small_array and large_array must have the same number of dimensions"
        )

    small_shape = tuple(small_array.shape)
    large_shape = tuple(large_array.shape)

    # Build the blocked view shape (block_index, within_block) per axis and
    # a matching broadcast shape for the small array.
    blocked_shape = []
    broadcast_shape = []
    for little, big in zip(small_shape, large_shape):
        blocked_shape.extend((little, big // little))
        broadcast_shape.extend((little, 1))

    # Broadcasting the (s, 1, s, 1, ...) small array against the
    # (s, b, s, b, ...) blocked large array scales each block by one element.
    blocked_product = large_array.reshape(blocked_shape) * small_array.reshape(
        broadcast_shape
    )

    # Collapse the blocked view back to the original large shape.
    return blocked_product.reshape(large_shape)
waveorder/focus.py CHANGED
@@ -1,9 +1,11 @@
1
- from scipy.signal import peak_widths
1
+ import warnings
2
2
  from typing import Literal, Optional
3
- from waveorder import util
3
+
4
4
  import matplotlib.pyplot as plt
5
5
  import numpy as np
6
- import warnings
6
+ from scipy.signal import peak_widths
7
+
8
+ from waveorder import util
7
9
 
8
10
 
9
11
  def focus_from_transverse_band(
@@ -7,7 +7,9 @@ from torch import Tensor
7
7
  from waveorder import correction, stokes, util
8
8
 
9
9
 
10
- def generate_test_phantom(yx_shape):
10
+ def generate_test_phantom(
11
+ yx_shape: Tuple[int, int],
12
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
11
13
  star, theta, _ = util.generate_star_target(yx_shape, blur_px=0.1)
12
14
  retardance = 0.25 * star
13
15
  orientation = (theta % np.pi) * (star > 1e-3)
@@ -17,13 +19,15 @@ def generate_test_phantom(yx_shape):
17
19
 
18
20
 
19
21
  def calculate_transfer_function(
20
- swing,
21
- scheme,
22
- ):
22
+ swing: float,
23
+ scheme: str,
24
+ ) -> Tensor:
23
25
  return stokes.calculate_intensity_to_stokes_matrix(swing, scheme=scheme)
24
26
 
25
27
 
26
- def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
28
+ def visualize_transfer_function(
29
+ viewer, intensity_to_stokes_matrix: Tensor
30
+ ) -> None:
27
31
  viewer.add_image(
28
32
  intensity_to_stokes_matrix.cpu().numpy(),
29
33
  name="Intensity to stokes matrix",
@@ -31,12 +35,12 @@ def visualize_transfer_function(viewer, intensity_to_stokes_matrix):
31
35
 
32
36
 
33
37
  def apply_transfer_function(
34
- retardance,
35
- orientation,
36
- transmittance,
37
- depolarization,
38
- intensity_to_stokes_matrix,
39
- ):
38
+ retardance: Tensor,
39
+ orientation: Tensor,
40
+ transmittance: Tensor,
41
+ depolarization: Tensor,
42
+ intensity_to_stokes_matrix: Tensor,
43
+ ) -> Tensor:
40
44
  stokes_params = stokes.stokes_after_adr(
41
45
  retardance, orientation, transmittance, depolarization
42
46
  )
@@ -59,7 +63,7 @@ def apply_inverse_transfer_function(
59
63
  project_stokes_to_2d: bool = False,
60
64
  flip_orientation: bool = False,
61
65
  rotate_orientation: bool = False,
62
- ) -> Tuple[Tensor]:
66
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
63
67
  """Reconstructs retardance, orientation, transmittance, and depolarization
64
68
  from czyx_data and an intensity_to_stokes_matrix, providing options for
65
69
  background correction, projection, and orientation transformations.
@@ -0,0 +1,320 @@
1
+ from typing import Literal
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn.functional import avg_pool3d
7
+
8
+ from waveorder import optics, sampling, stokes, util
9
+ from waveorder.filter import apply_filter_bank
10
+ from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
11
+
12
+
13
def generate_test_phantom(zyx_shape: tuple[int, int, int]) -> torch.Tensor:
    """Generate a 3-channel star-target test phantom.

    Builds a 2D star target and places three derived channels (c00, c2_2,
    c22 — an isotropic term plus two orientation-modulated terms) into the
    central z-slice of an otherwise-zero 3D volume.

    Parameters
    ----------
    zyx_shape : tuple[int, int, int]
        Spatial (Z, Y, X) shape of the output volume.

    Returns
    -------
    torch.Tensor
        Phantom of shape (3, Z, Y, X); only the central z-slice is nonzero.
    """
    yx_star, yx_theta, _ = util.generate_star_target(
        yx_shape=zyx_shape[1:],
        blur_px=1,
        margin=50,
    )
    c00 = yx_star
    c2_2 = -torch.sin(2 * yx_theta) * yx_star
    c22 = -torch.cos(2 * yx_theta) * yx_star

    # Put the channels in the center slice of a 3D volume.
    # (renamed from `object`, which shadowed the builtin)
    center_slice_object = torch.stack((c00, c2_2, c22), dim=0)
    phantom = torch.zeros((3,) + zyx_shape)
    phantom[:, zyx_shape[0] // 2, ...] = center_slice_object
    return phantom
+
30
+
31
def calculate_transfer_function(
    swing: float,
    scheme: str,
    zyx_shape: tuple[int, int, int],
    yx_pixel_size: float,
    z_pixel_size: float,
    wavelength_illumination: float,
    z_padding: int,
    index_of_refraction_media: float,
    numerical_aperture_illumination: float,
    numerical_aperture_detection: float,
    invert_phase_contrast: bool = False,
    fourier_oversample_factor: int = 1,
) -> tuple[
    torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]
]:
    """Compute the vectorial transfer function, intensity-to-Stokes matrix,
    and the SVD (singular system) of the transfer function.

    The wrap-unsafe transfer function is computed on a grid upsampled to
    satisfy transverse/axial Nyquist criteria (times an optional extra
    `fourier_oversample_factor`), then average-pooled and Fourier-cropped
    back to `zyx_shape`.

    Returns
    -------
    (sfZYX_transfer_function, intensity_to_stokes_matrix, singular_system)
    """
    # Padding is accepted in the signature for API parity with other models,
    # but is not supported here.
    if z_padding != 0:
        raise NotImplementedError("Padding not implemented for this model")

    transverse_nyquist = sampling.transverse_nyquist(
        wavelength_illumination,
        numerical_aperture_illumination,
        numerical_aperture_detection,
    )
    axial_nyquist = sampling.axial_nyquist(
        wavelength_illumination,
        numerical_aperture_detection,
        index_of_refraction_media,
    )

    # Integer upsampling factors needed to reach Nyquist sampling.
    yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
    z_factor = int(np.ceil(z_pixel_size / axial_nyquist))

    print("YX factor:", yx_factor)
    print("Z factor:", z_factor)

    # Oversampled grid on which the wrap-unsafe computation runs.
    tf_calculation_shape = (
        zyx_shape[0] * z_factor * fourier_oversample_factor,
        int(np.ceil(zyx_shape[1] * yx_factor * fourier_oversample_factor)),
        int(np.ceil(zyx_shape[2] * yx_factor * fourier_oversample_factor)),
    )

    (
        sfZYX_transfer_function,
        intensity_to_stokes_matrix,
    ) = _calculate_wrap_unsafe_transfer_function(
        swing,
        scheme,
        tf_calculation_shape,
        yx_pixel_size / yx_factor,  # finer sampling on the oversampled grid
        z_pixel_size / z_factor,
        wavelength_illumination,
        z_padding,
        index_of_refraction_media,
        numerical_aperture_illumination,
        numerical_aperture_detection,
        invert_phase_contrast=invert_phase_contrast,
    )

    # avg_pool3d does not support complex numbers, so pool the real and
    # imaginary parts separately and recombine.
    pooled_sfZYX_transfer_function_real = avg_pool3d(
        sfZYX_transfer_function.real, (fourier_oversample_factor,) * 3
    )
    pooled_sfZYX_transfer_function_imag = avg_pool3d(
        sfZYX_transfer_function.imag, (fourier_oversample_factor,) * 3
    )
    pooled_sfZYX_transfer_function = (
        pooled_sfZYX_transfer_function_real
        + 1j * pooled_sfZYX_transfer_function_imag
    )

    # Crop (in Fourier space) back to the requested output size;
    # leading two dims are the Stokes/object channel dims.
    sfzyx_out_shape = (
        pooled_sfZYX_transfer_function.shape[0],
        pooled_sfZYX_transfer_function.shape[1],
        zyx_shape[0] + 2 * z_padding,
    ) + zyx_shape[1:]

    cropped = sampling.nd_fourier_central_cuboid(
        pooled_sfZYX_transfer_function, sfzyx_out_shape
    )

    # Compute singular system on the cropped and downsampled result.
    singular_system = calculate_singular_system(cropped)

    return (
        cropped,
        intensity_to_stokes_matrix,
        singular_system,
    )
121
+
122
+
123
def _calculate_wrap_unsafe_transfer_function(
    swing,
    scheme,
    zyx_shape,
    yx_pixel_size,
    z_pixel_size,
    wavelength_illumination,
    z_padding,
    index_of_refraction_media,
    numerical_aperture_illumination,
    numerical_aperture_detection,
    invert_phase_contrast=False,
):
    """Compute the vectorial transfer function on the given grid without
    any aliasing ("wrap") protection — callers are expected to oversample
    (see `calculate_transfer_function`).

    Returns (sfZYX_transfer_function, intensity_to_stokes_matrix).
    """
    print("Computing transfer function")
    intensity_to_stokes_matrix = stokes.calculate_intensity_to_stokes_matrix(
        swing, scheme=scheme
    )

    input_jones = torch.tensor([0.0 - 1.0j, 1.0 + 0j])  # circular
    # input_jones = torch.tensor([0 + 0j, 1 + 0j]) # linear

    # Calculate transverse frequency grids
    y_frequencies, x_frequencies = util.generate_frequencies(
        zyx_shape[1:], yx_pixel_size
    )
    radial_frequencies = torch.sqrt(x_frequencies**2 + y_frequencies**2)

    # Centered, ifftshifted z positions for the defocus stack
    z_total = zyx_shape[0] + 2 * z_padding
    z_position_list = torch.fft.ifftshift(
        (torch.arange(z_total) - z_total // 2) * z_pixel_size
    )
    if (
        not invert_phase_contrast
    ):  # opposite sign of direct phase reconstruction
        z_position_list = torch.flip(z_position_list, dims=(0,))
    # NOTE(review): z_frequencies is computed but never used below — confirm
    # whether it can be removed.
    z_frequencies = torch.fft.fftfreq(z_total, d=z_pixel_size)

    # 2D pupils
    print("\tCalculating pupils...")
    ill_pupil = optics.generate_pupil(
        radial_frequencies,
        numerical_aperture_illumination,
        wavelength_illumination,
    )
    det_pupil = optics.generate_pupil(
        radial_frequencies,
        numerical_aperture_detection,
        wavelength_illumination,
    )
    pupil = optics.generate_pupil(
        radial_frequencies,
        index_of_refraction_media,  # largest possible NA
        wavelength_illumination,
    )

    # Defocus pupils
    defocus_pupil = optics.generate_propagation_kernel(
        radial_frequencies,
        pupil,
        wavelength_illumination / index_of_refraction_media,
        z_position_list,
    )

    # Calculate vector defocus pupils (illumination side)
    S = optics.generate_vector_source_defocus_pupil(
        x_frequencies,
        y_frequencies,
        z_position_list,
        defocus_pupil,
        input_jones,
        ill_pupil,
        wavelength_illumination / index_of_refraction_media,
    )

    # Simplified scalar pupil (detection side)
    P = optics.generate_propagation_kernel(
        radial_frequencies,
        det_pupil,
        wavelength_illumination / index_of_refraction_media,
        z_position_list,
    )

    # Move to real space along z; P keeps only its magnitude
    P_3D = torch.abs(torch.fft.ifft(P, dim=-3)).type(torch.complex64)
    S_3D = torch.fft.ifft(S, dim=-3)

    print("\tCalculating greens tensor spectrum...")
    G_3D = optics.generate_greens_tensor_spectrum(
        zyx_shape=(z_total, zyx_shape[1], zyx_shape[2]),
        zyx_pixel_size=(z_pixel_size, yx_pixel_size, yx_pixel_size),
        wavelength=wavelength_illumination / index_of_refraction_media,
    )

    # Main part: combine detection pupil with Green's tensor and source terms
    PG_3D = torch.einsum("zyx,ipzyx->ipzyx", P_3D, G_3D)
    PS_3D = torch.einsum("zyx,jzyx,kzyx->jkzyx", P_3D, S_3D, torch.conj(S_3D))

    # Free large intermediates before the FFTs to limit peak memory
    del P_3D, G_3D, S_3D

    print("\tComputing pg and ps...")
    pg = torch.fft.fftn(PG_3D, dim=(-3, -2, -1))
    ps = torch.fft.fftn(PS_3D, dim=(-3, -2, -1))

    del PG_3D, PS_3D

    print("\tComputing H1 and H2...")
    H1 = torch.fft.ifftn(
        torch.einsum("ipzyx,jkzyx->ijpkzyx", pg, torch.conj(ps)),
        dim=(-3, -2, -1),
    )

    H2 = torch.fft.ifftn(
        torch.einsum("ikzyx,jpzyx->ijpkzyx", ps, torch.conj(pg)),
        dim=(-3, -2, -1),
    )

    H_re = H1[1:, 1:] + H2[1:, 1:]  # drop data-side z components
    # H_im = 1j * (H1 - H2) # ignore absorptive terms

    del H1, H2

    # Normalize by the peak magnitude
    H_re /= torch.amax(torch.abs(H_re))

    # select s0, s1, and s2
    # NOTE(review): the index list [0, 1, 2, 3] selects four Pauli matrices,
    # while the comment names three — confirm which is intended.
    s = util.pauli()[[0, 1, 2, 3]]
    Y = util.gellmann()[[0, 4, 8]]
    # select phase f00 and transverse linear isotropic terms 2-2, and f22

    print("\tComputing final transfer function...")
    # Contract channel indices to get the (stokes, object) transfer function
    sfZYX_transfer_function = torch.einsum(
        "sik,ikpjzyx,lpj->slzyx", s, H_re, Y
    )
    return (
        sfZYX_transfer_function,
        intensity_to_stokes_matrix,
    )
257
+
258
+
259
def calculate_singular_system(sfZYX_transfer_function):
    """Compute the voxel-wise SVD of a (s, f, Z, Y, X) transfer function.

    Returns (U, S, Vh) permuted back to channel-leading layout:
    U: (s, j, Z, Y, X), S: (j, Z, Y, X), Vh: (j, f, Z, Y, X),
    where j indexes singular values.
    """
    print("Computing SVD")
    # Move channel axes last so svd factors the (s, f) matrix at each voxel
    voxelwise_matrices = sfZYX_transfer_function.permute(2, 3, 4, 0, 1)
    left, values, right = torch.linalg.svd(
        voxelwise_matrices, full_matrices=False
    )
    # Restore channel-leading layout for each factor
    return (
        left.permute(3, 4, 0, 1, 2),
        values.permute(3, 0, 1, 2),
        right.permute(3, 4, 0, 1, 2),
    )
270
+
271
+
272
def visualize_transfer_function(
    viewer: "napari.Viewer",
    sfZYX_transfer_function: torch.Tensor,
    zyx_scale: tuple[float, float, float],
) -> None:
    """Add the complex transfer function to a napari viewer.

    Thin wrapper around `add_transfer_function_to_viewer` that fixes the
    layer name, complex-RGB rendering, and contrast-limit factor.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer the layer is added to.
    sfZYX_transfer_function : torch.Tensor
        Transfer function to display.
    zyx_scale : tuple[float, float, float]
        Physical (Z, Y, X) scale passed through for correct axis scaling.
    """
    add_transfer_function_to_viewer(
        viewer,
        sfZYX_transfer_function,
        zyx_scale=zyx_scale,
        layer_name="Transfer Function",
        complex_rgb=True,
        clim_factor=0.5,
    )
+
286
+
287
def apply_transfer_function(
    fzyx_object: torch.Tensor,
    sfZYX_transfer_function: torch.Tensor,
    intensity_to_stokes_matrix: torch.Tensor,  # TODO use this to simulate intensities
) -> torch.Tensor:
    """Simulate Stokes data from an object via the transfer function.

    Transforms the object to the frequency domain, contracts the object
    channel against the (s, f, Z, Y, X) transfer function, transforms back,
    and applies a fixed scale factor of 50.
    """
    spatial_dims = (1, 2, 3)
    object_spectrum = torch.fft.fftn(fzyx_object, dim=spatial_dims)
    data_spectrum = torch.einsum(
        "fzyx,sfzyx->szyx", object_spectrum, sfZYX_transfer_function
    )
    szyx_data = torch.fft.ifftn(data_spectrum, dim=spatial_dims)

    return 50 * szyx_data  # + 0.1 * torch.randn(szyx_data.shape)
299
+
300
+
301
def apply_inverse_transfer_function(
    szyx_data: Tensor,
    singular_system: tuple[Tensor],
    intensity_to_stokes_matrix: Tensor,
    reconstruction_algorithm: Literal["Tikhonov", "TV"] = "Tikhonov",
    regularization_strength: float = 1e-3,
    TV_rho_strength: float = 1e-3,
    TV_iterations: int = 10,
):
    """Reconstruct the object from Stokes data via a Tikhonov-regularized
    pseudo-inverse built from the transfer function's singular system.

    NOTE(review): `reconstruction_algorithm`, `TV_rho_strength`,
    `TV_iterations`, and `intensity_to_stokes_matrix` are currently
    unused — only the Tikhonov path is implemented. Confirm whether the
    TV branch is planned or the parameters should be removed.

    Parameters
    ----------
    szyx_data : Tensor
        Stokes-channel data, shape (s, Z, Y, X).
    singular_system : tuple[Tensor]
        (U, S, Vh) as returned by `calculate_singular_system`.
    regularization_strength : float
        Tikhonov damping added to the squared singular values.

    Returns
    -------
    Tensor
        Reconstructed object channels, shape (f, Z, Y, X).
    """
    # Key computation: regularized inverse filter from the SVD factors,
    # S / (S**2 + reg) being the Tikhonov-damped reciprocal singular values.
    print("Computing inverse filter")
    U, S, Vh = singular_system
    S_reg = S / (S**2 + regularization_strength)
    sfzyx_inverse_filter = torch.einsum(
        "sjzyx,jzyx,jfzyx->sfzyx", U, S_reg, Vh
    )

    fzyx_recon = apply_filter_bank(sfzyx_inverse_filter, szyx_data)

    return fzyx_recon