waveorder 2.2.0rc0__py3-none-any.whl → 2.2.1b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,335 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+
4
+
5
def plot_5d_ortho(
    rcCzyx_data: np.ndarray,
    filename: str,
    voxel_size: tuple[float, float, float],
    zyx_slice: tuple[int, int, int],
    color_funcs: list[list[callable]],
    row_labels: list[str] = None,
    column_labels: list[str] = None,
    rose_path: str = None,
    inches_per_column: float = 1.5,
    label_size: int = 1,
    ortho_line_width: float = 0.5,
    row_column_line_width: float = 0.5,
    xyz_labels: bool = True,
    background_color: str = "white",
    **kwargs: dict,
) -> None:
    """
    Plot 6D multi-channel data as a grid of ortho-slice views.

    Input data is a 6D array with (row, column, channel, Z, Y, X) dimensions.
    Every (row, column) grid element is drawn as three orthogonal slices
    through ``zyx_slice``: a YX panel, a YZ panel to its right, and a ZX
    panel below it. The first row and the first column of the figure hold
    the column and row labels.

    `color_funcs` permits different RGB color maps for each row and column.

    Parameters
    ----------
    rcCzyx_data : numpy.ndarray
        6D array with shape (R, C, Ch, Z, Y, X) containing the data to plot.
        [r]ows and [c]olumns form a grid,
        [C]hannels are combined into one image by the matching color function,
        [ZYX] contain 3D volumes.
    filename : str
        Path to save the output plot (written as PDF at 400 dpi).
    voxel_size : tuple[float, float, float]
        Size of each voxel in (Z, Y, X) dimensions.
    zyx_slice : tuple[int, int, int]
        Indices of the ortho-slices to plot in (Z, Y, X) order.
    color_funcs : list[list[callable]]
        A list of lists of callables, one for each element of the plot grid,
        with ``len(color_funcs) == R`` and ``len(color_funcs[r]) == C``.
        A row given as a single callable is reused for all C columns.
        Each callable receives the Ch channel slices as positional arguments
        (plus ``**kwargs``) and returns an image accepted by ``imshow``,
        enabling different RGB color maps for each member of the grid.
    row_labels : list[str], optional
        Labels for the rows (length R). If None, rows are left unlabeled.
    column_labels : list[str], optional
        Labels for the columns (length C). If None, columns are left
        unlabeled.
    rose_path : str, optional
        Path to an image to display in the top-left corner, by default None.
    inches_per_column : float, optional
        Width in inches of one unit of the grid width ratios (i.e. of each
        YX panel), by default 1.5.
    label_size : int, optional
        Relative size of the label row/column and of the label font
        (``10 * label_size`` points), by default 1.
    ortho_line_width : float, optional
        Width of the ortho-slice position lines, by default 0.5.
    row_column_line_width : float, optional
        Width of the lines between rows and columns, by default 0.5.
    xyz_labels : bool, optional
        Whether to display XYZ axis labels on the first grid element,
        by default True.
    background_color : str, optional
        Background color of the plot, by default "white".
    **kwargs : dict
        Additional keyword arguments passed to each color function.
    """
    R, C, Ch, Z, Y, X = rcCzyx_data.shape

    # Physical extent of the volume along each axis
    dZ, dY, dX = Z * voxel_size[0], Y * voxel_size[1], X * voxel_size[2]

    if row_labels is not None:
        assert R == len(row_labels), "need one row label per grid row"
    if column_labels is not None:
        assert C == len(column_labels), "need one column label per grid column"
    assert (
        0 <= zyx_slice[0] < Z and 0 <= zyx_slice[1] < Y and 0 <= zyx_slice[2] < X
    ), "zyx_slice is out of bounds"

    # Normalize color_funcs to an R x C nested list; a bare callable in place
    # of a row is broadcast across all C columns.
    assert R == len(color_funcs), "need one color_funcs row per grid row"
    color_func_grid = []
    for color_func_row in color_funcs:
        if callable(color_func_row):
            color_func_row = [color_func_row] * C
        assert len(color_func_row) == C, "need one color function per grid column"
        color_func_grid.append(list(color_func_row))

    # One label row/column plus two panel rows/columns per grid element.
    n_rows = 1 + (2 * R)
    n_cols = 1 + (2 * C)

    # Panel sizes are proportional to physical extents, normalized by dX:
    # YX panels are 1 wide and dY/dX tall, YZ panels are dZ/dX wide and
    # ZX panels are dZ/dX tall.
    width_ratios = [label_size] + C * [1, dZ / dX]
    height_ratios = [label_size] + R * [dY / dX, dZ / dX]

    fig_width = np.array(width_ratios).sum() * inches_per_column
    fig_height = np.array(height_ratios).sum() * inches_per_column

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(fig_width, fig_height),
        gridspec_kw={
            "wspace": 0.05,
            "hspace": 0.05,
            "width_ratios": width_ratios,
            "height_ratios": height_ratios,
        },
    )
    fig.patch.set_facecolor(background_color)
    for ax in axes.flat:
        ax.set_facecolor(background_color)

    if rose_path is not None:
        axes[0, 0].imshow(plt.imread(rose_path))

    # Row / column labels live in the first figure row and column.
    label_kwargs = {
        "horizontalalignment": "center",
        "verticalalignment": "center",
        "fontsize": 10 * label_size,
        "color": "black",
    }
    if column_labels is not None:
        for c, label in enumerate(column_labels):
            axes[0, 2 * c + 1].text(0.5, 0.5, label, **label_kwargs)
    if row_labels is not None:
        for r, label in enumerate(row_labels):
            axes[2 * r + 1, 0].text(0.5, 0.5, label, **label_kwargs)

    # Data panels: for grid element (r, c), YX sits at (2r+1, 2c+1), YZ to its
    # right and ZX below it; the cell diagonal to YX stays empty.
    z_idx, y_idx, x_idx = zyx_slice[0], zyx_slice[1], zyx_slice[2]
    for r in range(R):
        for c in range(C):
            color_func = color_func_grid[r][c]
            Czyx_data = rcCzyx_data[r, c]  # (Ch, Z, Y, X)

            Cyx_data = Czyx_data[:, z_idx]  # (Ch, Y, X)
            Cyz_data = Czyx_data[:, :, :, x_idx].transpose(0, 2, 1)  # (Ch, Y, Z)
            Czx_data = Czyx_data[:, :, y_idx]  # (Ch, Z, X)

            axes[2 * r + 1, 2 * c + 1].imshow(
                color_func(*Cyx_data, **kwargs),
                aspect=voxel_size[1] / voxel_size[2],
            )
            axes[2 * r + 1, 2 * c + 2].imshow(
                color_func(*Cyz_data, **kwargs),
                aspect=voxel_size[1] / voxel_size[0],
            )
            axes[2 * r + 2, 2 * c + 1].imshow(
                color_func(*Czx_data, **kwargs),
                aspect=voxel_size[0] / voxel_size[2],
            )

    # Draw lines between rows and cols, in figure coordinates, midway through
    # the gap between neighboring axes. The bounds depend only on the corner
    # axes (the rose panel and the always-empty bottom-right cell), so they
    # are computed once.
    top = axes[0, 0].get_position().y1
    bottom = axes[-1, -1].get_position().y0
    left = axes[0, 0].get_position().x0
    right = axes[-1, -1].get_position().x1
    for c in range(C):
        j = 2 * c + 1
        left_edge = (
            axes[0, j].get_position().x0 + axes[0, j - 1].get_position().x1
        ) / 2
        fig.add_artist(
            plt.Line2D(
                [left_edge, left_edge],
                [bottom, top],
                transform=fig.transFigure,
                color="black",
                lw=row_column_line_width,
            )
        )
    for r in range(R):
        i = 2 * r + 1
        top_edge = (
            axes[i, 0].get_position().y1 + axes[i - 1, 0].get_position().y0
        ) / 2
        fig.add_artist(
            plt.Line2D(
                [left, right],
                [top_edge, top_edge],
                transform=fig.transFigure,
                color="black",
                lw=row_column_line_width,
            )
        )

    # Remove ticks and spines
    for ax in axes.flat:
        ax.tick_params(
            left=False, bottom=False, labelleft=False, labelbottom=False
        )
        for side in ("top", "right", "bottom", "left"):
            ax.spines[side].set_visible(False)

    yx_slice_color = "green"
    yz_slice_color = "red"
    zx_slice_color = "blue"

    # Label orthogonal slices (first grid element only)
    add_ortho_lines_to_axis(
        axes[1, 1],
        (zyx_slice[1], zyx_slice[2]),
        ("y", "x") if xyz_labels else ("", ""),
        yx_slice_color,
        yz_slice_color,
        zx_slice_color,
        ortho_line_width,
    )  # YX axis

    add_ortho_lines_to_axis(
        axes[2, 1],
        (zyx_slice[0], zyx_slice[2]),
        ("z", "x") if xyz_labels else ("", ""),
        zx_slice_color,
        yz_slice_color,
        yx_slice_color,
        ortho_line_width,
    )  # ZX axis

    add_ortho_lines_to_axis(
        axes[1, 2],
        (zyx_slice[1], zyx_slice[0]),
        ("y", "z") if xyz_labels else ("", ""),
        yz_slice_color,
        yx_slice_color,
        zx_slice_color,
        ortho_line_width,
    )  # YZ axis

    print(f"Saving {filename}")
    fig.savefig(filename, dpi=400, format="pdf", bbox_inches="tight")
239
+
240
+
241
def add_ortho_lines_to_axis(
    axis: plt.Axes,
    yx_slice: tuple[int, int],
    axis_labels: tuple[str, str],
    outer_color: str,
    vertical_color: str,
    horizontal_color: str,
    line_width: float = 0,
    text_color: str = "white",
) -> None:
    """
    Decorate an image axis with a colored frame, crosshair lines and labels.

    The frame and both lines span the axis' current data limits, so this
    should be called after the image has been drawn on ``axis``.

    Parameters
    ----------
    axis : matplotlib.axes.Axes
        The axis to decorate.
    yx_slice : tuple[int, int]
        Data coordinates (row, column) of the horizontal and vertical lines.
    axis_labels : tuple[str, str]
        Labels for the vertical and horizontal image axes, in that order.
    outer_color : str
        Edge color of the frame around the data limits.
    vertical_color : str
        Color of the vertical line.
    horizontal_color : str
        Color of the horizontal line.
    line_width : float, optional
        Width of the frame and lines, by default 0.
    text_color : str, optional
        Color of the text labels, by default "white".
    """
    x_lo, x_hi = axis.get_xlim()
    y_lo, y_hi = axis.get_ylim()
    row_pos, col_pos = yx_slice[0], yx_slice[1]

    # Labels are placed in axes-fraction coordinates near the top-left
    # corner: the horizontal-axis label first, the vertical-axis one below.
    label_placements = (
        ((0.1, 0.975), axis_labels[1]),
        ((0.025, 0.9), axis_labels[0]),
    )
    for (x_frac, y_frac), label_text in label_placements:
        axis.text(
            x_frac,
            y_frac,
            label_text,
            horizontalalignment="left",
            verticalalignment="top",
            transform=axis.transAxes,
            fontsize=5,
            color=text_color,
        )

    # Frame drawn in data coordinates around the full current extent.
    frame = plt.Rectangle(
        (x_lo, y_lo),
        x_hi - x_lo,
        y_hi - y_lo,
        linewidth=line_width,
        edgecolor=outer_color,
        facecolor="none",
        transform=axis.transData,
        clip_on=False,
    )
    axis.add_artist(frame)

    # Crosshair: horizontal line at row_pos, then vertical line at col_pos.
    crosshair_segments = (
        ([x_lo, x_hi], [row_pos, row_pos], horizontal_color),
        ([col_pos, col_pos], [y_lo, y_hi], vertical_color),
    )
    for xs, ys, segment_color in crosshair_segments:
        axis.add_artist(
            plt.Line2D(
                xs,
                ys,
                transform=axis.transData,
                color=segment_color,
                lw=line_width,
            )
        )
@@ -0,0 +1,76 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from waveorder.visuals.utils import complex_tensor_to_rgb
5
+
6
+
7
def add_transfer_function_to_viewer(
    viewer: "napari.Viewer",
    transfer_function: torch.Tensor,
    zyx_scale: tuple[float, float, float],
    layer_name: str = "Transfer Function",
    clim_factor: float = 1.0,
    complex_rgb: bool = False,
):
    """
    Add a frequency-domain transfer function to a napari viewer.

    The last three axes of ``transfer_function`` are treated as (Z, Y, X)
    frequency axes in FFT order (zero frequency at index 0). They are
    ``fftshift``-ed so that zero frequency lands at index ``shape // 2``,
    which is also where the viewer's initial step is placed.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer to add the layer(s) to. Its dims order, rollable flags,
        current step and axis labels are overwritten.
    transfer_function : torch.Tensor
        Real or complex tensor with at least 3 dimensions; any leading axes
        are treated as non-spatial.
    zyx_scale : tuple[float, float, float]
        Real-space voxel size along (Z, Y, X). The displayed frequency step
        along each axis is ``1 / (shape * voxel_size)``.
    layer_name : str, optional
        Layer name. Separate real/imaginary layers are named
        ``"Re(<name>)"`` and ``"Im(<name>)"``.
    clim_factor : float, optional
        Fraction of the maximum absolute value used as the symmetric
        contrast limit (or as the saturation limit when ``complex_rgb``).
    complex_rgb : bool, optional
        If True, add a single RGB layer produced by ``complex_tensor_to_rgb``
        (phase as hue, magnitude as saturation). Otherwise add a real-part
        layer and, for complex input, an imaginary-part layer, both with a
        "bwr" colormap.
    """
    ndim = transfer_function.ndim
    zyx_shape = transfer_function.shape[-3:]
    # Python float so napari receives plain numbers, not a 0-d tensor.
    lim = float(torch.max(torch.abs(transfer_function)) * clim_factor)
    voxel_scale = np.array(
        [
            zyx_shape[0] * zyx_scale[0],
            zyx_shape[1] * zyx_scale[1],
            zyx_shape[2] * zyx_scale[2],
        ]
    )
    shift_dims = (-3, -2, -1)

    def _centered_numpy(tensor):
        # Move zero frequency to the center; detach/cpu so GPU tensors and
        # tensors that require grad can be converted.
        return (
            torch.fft.fftshift(tensor, dim=shift_dims).detach().cpu().numpy()
        )

    if complex_rgb:
        rgb_transfer_function = complex_tensor_to_rgb(
            _centered_numpy(transfer_function),
            saturate_clim_fraction=clim_factor,
        )
        viewer.add_image(
            rgb_transfer_function,
            scale=1 / voxel_scale,
            name=layer_name,
        )
    else:
        viewer.add_image(
            _centered_numpy(torch.real(transfer_function)),
            colormap="bwr",
            contrast_limits=(-lim, lim),
            scale=1 / voxel_scale,
            name="Re(" + layer_name + ")",
        )
        # Any complex dtype (complex64 or complex128) gets an imaginary layer.
        if torch.is_complex(transfer_function):
            viewer.add_image(
                _centered_numpy(torch.imag(transfer_function)),
                colormap="bwr",
                contrast_limits=(-lim, lim),
                scale=1 / voxel_scale,
                name="Im(" + layer_name + ")",
            )

    viewer.dims.current_step = (0,) * (ndim - 3) + (
        zyx_shape[0] // 2,
        zyx_shape[1] // 2,
        zyx_shape[2] // 2,
    )

    # Show XZ view by default, and only allow rolling between XY and XZ
    viewer.dims.order = list(range(ndim - 3)) + [
        ndim - 2,
        ndim - 3,
        ndim - 1,
    ]
    viewer.dims.rollable = (False,) * (ndim - 3) + (
        True,
        True,
        False,
    )

    # The default labels describe a 5D (DATA, OBJECT, Z, Y, X) tensor; trim
    # or pad them so the label count matches the tensor's dimensionality.
    default_labels = ("DATA", "OBJECT", "Z", "Y", "X")
    if ndim <= len(default_labels):
        axis_labels = default_labels[len(default_labels) - ndim :]
    else:
        axis_labels = (
            tuple(str(i) for i in range(ndim - len(default_labels)))
            + default_labels
        )
    viewer.dims.axis_labels = axis_labels
@@ -0,0 +1,30 @@
1
+ import matplotlib.colors as mcolors
2
+ import numpy as np
3
+
4
+
5
def complex_tensor_to_rgb(array, saturate_clim_fraction=1.0):
    """
    Convert a complex-valued array to an RGB image using an HSV color wheel.

    Phase is mapped to hue and magnitude to saturation; value is fixed at 1,
    so zero-magnitude entries render white. With this hue mapping, phase 0
    (+1) is red, +pi/2 (+i) is yellow-green, pi (-1) is cyan and
    -pi/2 (-i) is violet.

    Parameters
    ----------
    array : array_like
        Complex-valued input (a numpy array, or anything ``np.asarray``
        accepts, e.g. a CPU torch tensor).
    saturate_clim_fraction : float or None, optional
        Magnitudes at or above ``saturate_clim_fraction * max(|array|)`` are
        fully saturated, by default 1.0. If None, magnitudes are used
        directly as saturation. Saturation is clipped to [0, 1].

    Returns
    -------
    numpy.ndarray
        Array of shape ``array.shape + (3,)`` with RGB values in [0, 1].
    """
    array = np.asarray(array)
    magnitude = np.abs(array)
    phase = np.angle(array)

    # Map phase in (-pi, pi] to hue in [0, 1), with zero phase at hue 0 (red).
    hue = np.mod(phase / (2 * np.pi), 1)

    # Normalize magnitude for saturation; `initial=0` keeps empty input safe.
    if saturate_clim_fraction is not None:
        max_abs_val = np.amax(magnitude, initial=0) * saturate_clim_fraction
    else:
        max_abs_val = 1.0

    sat = magnitude / max_abs_val if max_abs_val != 0 else magnitude
    # Values above the clim saturate; hsv_to_rgb requires s in [0, 1].
    sat = np.clip(sat, 0, 1)

    # HSV array: hue, saturation, value (value fixed at 1)
    hsv = np.stack((hue, sat, np.ones_like(sat)), axis=-1)

    # Vectorized conversion of the whole HSV array to RGB
    rgb_array = mcolors.hsv_to_rgb(hsv)

    return rgb_array
@@ -1,14 +1,15 @@
1
- import numpy as np
2
- import matplotlib.pyplot as plt
3
1
  import itertools
4
2
  import time
5
- import os
6
- from numpy.fft import fft, ifft, fft2, ifft2, fftn, ifftn, fftshift, ifftshift
3
+ import warnings
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
7
  from IPython import display
8
- from scipy.ndimage import uniform_filter
9
- from .util import *
10
- from .optics import *
8
+ from numpy.fft import fft2, fftn, fftshift, ifft, ifft2, ifftn, ifftshift
9
+
11
10
  from .background_estimator import *
11
+ from .optics import *
12
+ from .util import *
12
13
 
13
14
 
14
15
  def intensity_mapping(img_stack):
@@ -160,8 +161,8 @@ def instrument_matrix_calibration(I_cali_norm, I_meas):
160
161
 
161
162
 
162
163
  class waveorder_microscopy:
163
-
164
164
  """
165
+ DEPRECATED: Please see `waveorder.models` for maintained alternatives.
165
166
 
166
167
  waveorder_microscopy contains reconstruction algorithms for label-free
167
168
  microscopy with various types of dataset:
@@ -367,6 +368,10 @@ class waveorder_microscopy:
367
368
  initialize the system parameters for phase and orders microscopy
368
369
 
369
370
  """
371
+ warnings.warn(
372
+ "Please see `waveorder.models` for maintained alternatives.",
373
+ category=DeprecationWarning,
374
+ )
370
375
 
371
376
  t0 = time.time()
372
377
 
@@ -732,9 +737,7 @@ class waveorder_microscopy:
732
737
  wave_vec_norm_x = self.lambda_illu * self.fxx
733
738
  wave_vec_norm_y = self.lambda_illu * self.fyy
734
739
  wave_vec_norm_z = (
735
- np.maximum(
736
- 0, 1 - wave_vec_norm_x**2 - wave_vec_norm_y**2
737
- )
740
+ np.maximum(0, 1 - wave_vec_norm_x**2 - wave_vec_norm_y**2)
738
741
  ) ** (0.5)
739
742
 
740
743
  incident_theta = np.arctan2(
@@ -1471,8 +1474,9 @@ class waveorder_microscopy:
1471
1474
  torch.tensor(z.astype("complex64").transpose((2, 1, 0))),
1472
1475
  torch.tensor(self.psz),
1473
1476
  )
1474
- return H_re.numpy().transpose((1, 2, 0)), H_im.numpy().transpose(
1475
- (1, 2, 0)
1477
+ return (
1478
+ H_re.numpy().transpose((1, 2, 0)),
1479
+ H_im.numpy().transpose((1, 2, 0)),
1476
1480
  )
1477
1481
 
1478
1482
  for i in range(self.N_pattern):
@@ -4017,9 +4021,7 @@ class fluorescence_microscopy:
4017
4021
  S1_stack = cp.array(S1_stack)
4018
4022
  S2_stack = cp.array(S2_stack)
4019
4023
 
4020
- anisotropy = cp.asnumpy(
4021
- 0.5 * cp.sqrt(S1_stack**2 + S2_stack**2)
4022
- )
4024
+ anisotropy = cp.asnumpy(0.5 * cp.sqrt(S1_stack**2 + S2_stack**2))
4023
4025
  orientation = cp.asnumpy(
4024
4026
  (0.5 * cp.arctan2(S2_stack, S1_stack)) % np.pi
4025
4027
  )
@@ -1,13 +1,13 @@
1
- import numpy as np
2
- import matplotlib.pyplot as plt
3
1
  import itertools
4
2
  import time
5
- import os
6
- import torch
7
- from numpy.fft import fft, ifft, fft2, ifft2, fftn, ifftn, fftshift, ifftshift
8
3
  from concurrent.futures import ProcessPoolExecutor
9
- from .util import *
4
+
5
+ import numpy as np
6
+ import torch
7
+ from numpy.fft import fft2, fftn, fftshift, ifft2, ifftn, ifftshift
8
+
10
9
  from .optics import *
10
+ from .util import *
11
11
 
12
12
 
13
13
  def Jones_PC_forward_model(