waveorder 2.1.0-py3-none-any.whl → 2.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
waveorder/optics.py CHANGED
@@ -133,7 +133,7 @@ def generate_pupil(frr, NA, lamb_in):
         numerical aperture of the pupil function (normalized by the refractive index of the immersion media)

     lamb_in : float
-        wavelength of the light (inside the immersion media)
+        wavelength of the light in free space
         in units of length (inverse of frr's units)

     Returns
@@ -225,6 +225,103 @@ def gen_sector_Pupil(fxx, fyy, NA, lamb_in, sector_angle, rotation_angle):
     return Pupil_sector


+def rotation_matrix(nu_z, nu_y, nu_x, wavelength):
+    nu_perp_squared = nu_x**2 + nu_y**2
+    nu_zz = wavelength * nu_z - 1
+
+    R_xx = (wavelength * nu_x**2 * nu_z + nu_y**2) / nu_perp_squared
+    R_yy = (wavelength * nu_y**2 * nu_z + nu_x**2) / nu_perp_squared
+    R_xy = nu_x * nu_y * nu_zz / nu_perp_squared
+
+    row0 = torch.stack((-wavelength * nu_y, -wavelength * nu_x), dim=0)
+    row1 = torch.stack((R_yy, R_xy), dim=0)
+    row2 = torch.stack((R_xy, R_xx), dim=0)
+
+    out = torch.stack((row0, row1, row2), dim=0)
+
+    # KLUDGE: fix the DC term manually, avoiding nan
+    out[..., 0, 0] = torch.tensor([[0, 0], [1, 0], [0, 1]])[..., None]
+
+    return torch.nan_to_num(out, nan=0.0)
+
+
+def generate_vector_source_defocus_pupil(
+    x_frequencies,
+    y_frequencies,
+    z_position_list,
+    defocus_pupil,
+    input_jones,
+    ill_pupil,
+    wavelength,
+):
+    ill_pupil_3d = torch.einsum(
+        "zyx,yx->zyx", torch.fft.fft(defocus_pupil, dim=0), ill_pupil
+    ).abs()  # make this real
+
+    freq_shape = z_position_list.shape + x_frequencies.shape
+
+    y_broadcast = torch.broadcast_to(y_frequencies[None, :, :], freq_shape)
+    x_broadcast = torch.broadcast_to(x_frequencies[None, :, :], freq_shape)
+    z_broadcast = np.sqrt(wavelength ** (-2) - x_broadcast**2 - y_broadcast**2)
+
+    # Calculate rotation matrix
+    rotations = rotation_matrix(
+        z_broadcast, y_broadcast, x_broadcast, wavelength
+    ).type(torch.complex64)
+
+    # TEMPORARY SIMPLIFY ROTATIONS "TURN OFF ROTATIONS"
+    # 3x2 IDENTITY MATRIX
+    rotations = torch.zeros_like(rotations)
+    rotations[1, 0, ...] = 1
+    rotations[2, 1, ...] = 1
+
+    # Main calculation in the frequency domain
+    source_pupil = torch.einsum(
+        "ijzyx,j,zyx->izyx", rotations, input_jones, ill_pupil_3d
+    )
+
+    # Convert back to defocus pupil
+    source_defocus_pupil = torch.fft.ifft(source_pupil, dim=-3)
+
+    return source_defocus_pupil
+
+
+def generate_vector_detection_defocus_pupil(
+    x_frequencies,
+    y_frequencies,
+    z_position_list,
+    det_defocus_pupil,
+    det_pupil,
+    wavelength,
+):
+    # TODO: refactor redundancy with illumination pupil
+    det_pupil_3d = torch.einsum(
+        "zyx,yx->zyx", torch.fft.ifft(det_defocus_pupil, dim=0), det_pupil
+    )
+
+    # Calculate zyx_frequency grid (inelegant)
+    z_frequencies = torch.fft.ifft(z_position_list)
+    freq_shape = z_frequencies.shape + x_frequencies.shape
+    z_broadcast = torch.broadcast_to(z_frequencies[:, None, None], freq_shape)
+    y_broadcast = torch.broadcast_to(y_frequencies[None, :, :], freq_shape)
+    x_broadcast = torch.broadcast_to(x_frequencies[None, :, :], freq_shape)
+
+    # Calculate rotation matrix
+    rotations = rotation_matrix(
+        z_broadcast, y_broadcast, x_broadcast, wavelength
+    ).type(torch.complex64)
+
+    # Main calculation in the frequency domain
+    vector_detection_pupil = torch.einsum(
+        "jizyx,zyx->ijzyx", rotations, det_pupil_3d
+    )
+
+    # Convert back to defocus pupil
+    detection_defocus_pupil = torch.fft.fft(vector_detection_pupil, dim=-3)
+
+    return detection_defocus_pupil
+
+
 def Source_subsample(Source_cont, NAx_coord, NAy_coord, subsampled_NA=0.1):
     """

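A minimal usage sketch for the new vectorial pupils (editorial, not part of the released files). All values below are placeholders chosen only to exercise tensor shapes, and the Jones vector is a hypothetical circular-polarization input.

import torch
from waveorder import optics

Z, Y, X = 8, 32, 32
wavelength = 0.532
fy = torch.fft.fftfreq(Y, 0.1)
fx = torch.fft.fftfreq(X, 0.1)
fyy, fxx = torch.meshgrid(fy, fx, indexing="ij")
z_positions = (torch.arange(Z) - Z // 2) * 0.25

defocus_pupil = torch.ones((Z, Y, X), dtype=torch.complex64)  # stand-in defocus stack
ill_pupil = (torch.sqrt(fyy**2 + fxx**2) < 1.2 / wavelength).float()  # binary pupil support
input_jones = torch.tensor([1.0, 1j]) / 2**0.5  # hypothetical circular input

source = optics.generate_vector_source_defocus_pupil(
    fxx, fyy, z_positions, defocus_pupil, input_jones, ill_pupil, wavelength
)
print(source.shape)  # torch.Size([3, 8, 32, 32]): field component, z, y, x
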
@@ -270,7 +367,7 @@ def Source_subsample(Source_cont, NAx_coord, NAy_coord, subsampled_NA=0.1):
             illu_list.append(i)
             first_idx = False
         elif (
-            np.product(
+            np.prod(
                 (NAx_list[i] - NAx_list[illu_list]) ** 2
                 + (NAy_list[i] - NAy_list[illu_list]) ** 2
                 >= subsampled_NA**2
@@ -300,7 +397,7 @@ def generate_propagation_kernel(
     wavelength : float
         wavelength of the light in the immersion media

-    z_position_list : torch.tensor or list
+    z_position_list : torch.tensor
         1D array of defocused z positions with the size of (Z)

     Returns
@@ -310,15 +407,16 @@ def generate_propagation_kernel(

     """

-    oblique_factor = (
-        (1 - wavelength**2 * radial_frequencies**2) * pupil_support
-    ) ** (1 / 2) / wavelength
+    oblique_factor = ((1 - wavelength**2 * radial_frequencies**2)) ** (
+        1 / 2
+    ) / wavelength
+    oblique_factor = torch.nan_to_num(oblique_factor, nan=0.0)

     propagation_kernel = pupil_support[None, :, :] * torch.exp(
         1j
         * 2
         * np.pi
-        * torch.tensor(z_position_list)[:, None, None]
+        * z_position_list[:, None, None]
         * oblique_factor[None, :, :]
     )

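Editorial note: the oblique factor above is the axial spatial frequency of a plane wave, sqrt(1 - lambda^2 f_r^2) / lambda = sqrt(1/lambda^2 - f_r^2). Beyond the pupil cutoff the argument is negative and the square root is nan; the old code zeroed those frequencies by multiplying pupil_support inside the root, while the new code clears them with nan_to_num. A quick numerical check of the identity, with arbitrary test values:

import torch

wavelength = 0.532  # arbitrary
f_r = torch.linspace(0.0, 3.0, 7)  # arbitrary radial frequencies, some evanescent
a = ((1 - wavelength**2 * f_r**2) ** 0.5) / wavelength
b = (wavelength**-2 - f_r**2) ** 0.5
assert torch.allclose(torch.nan_to_num(a), torch.nan_to_num(b))
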
@@ -326,7 +424,11 @@ def generate_propagation_kernel(


 def generate_greens_function_z(
-    radial_frequencies, pupil_support, wavelength_illumination, z_position_list
+    radial_frequencies,
+    pupil_support,
+    wavelength_illumination,
+    z_position_list,
+    axially_even=True,
 ):
     """

@@ -343,9 +445,14 @@ def generate_greens_function_z(
     wavelength_illumination : float
         wavelength of the light in the immersion media

-    z_position_list : torch.tensor or list
+    z_position_list : torch.tensor
         1D array of defocused z position with the size of (Z,)

+    axially_even : bool
+        For backwards compatibility with legacy phase reconstruction.
+        Ideally the legacy phase reconstruction should be unified with
+        the new reconstructions, and this parameter should be removed.
+
     Returns
     -------
     greens_function_z : torch.tensor
@@ -358,47 +465,97 @@ def generate_greens_function_z(
         * pupil_support
     ) ** (1 / 2) / wavelength_illumination

+    if axially_even:
+        z_positions = torch.abs(z_position_list[:, None, None])
+    else:
+        z_positions = z_position_list[:, None, None]
+
     greens_function_z = (
         -1j
         / 4
         / np.pi
         * pupil_support[None, :, :]
-        * torch.exp(
-            1j
-            * 2
-            * np.pi
-            * torch.tensor(z_position_list)[:, None, None]
-            * oblique_factor[None, :, :]
-        )
+        * torch.exp(1j * 2 * np.pi * z_positions * oblique_factor[None, :, :])
         / (oblique_factor[None, :, :] + 1e-15)
     )

     return greens_function_z


-def gen_dyadic_Greens_tensor_z(fxx, fyy, G_fun_z, Pupil_support, lambda_in):
+def generate_defocus_greens_tensor(
+    fxx, fyy, G_fun_z, Pupil_support, lambda_in
+):
     """

     generate forward dyadic Green's function in u_x, u_y, z space

     Parameters
     ----------
-    fxx : numpy.ndarray
+    fxx : tensor.Tensor
         x component of 2D spatial frequency array with the size of (Ny, Nx)

-    fyy : numpy.ndarray
+    fyy : tensor.Tensor
         y component of 2D spatial frequency array with the size of (Ny, Nx)

-    G_fun_z : numpy.ndarray
-        forward Green's function in u_x, u_y, z space with size of (Ny, Nx, Nz)
+    G_fun_z : tensor.Tensor
+        forward Green's function in u_x, u_y, z space with size of (Nz, Ny, Nx)

-    Pupil_support : numpy.ndarray
+    Pupil_support : tensor.Tensor
         the array that defines the support of the pupil function with the size of (Ny, Nx)

     lambda_in : float
         wavelength of the light in the immersion media

     Returns
+    -------
+    G_tensor_z : tensor.Tensor
+        forward dyadic Green's function in u_x, u_y, z space with the size of (3, 3, Nz, Ny, Nx)
+    """
+
+    fr = (fxx**2 + fyy**2) ** (1 / 2)
+    oblique_factor = ((1 - lambda_in**2 * fr**2) * Pupil_support) ** (
+        1 / 2
+    ) / lambda_in
+
+    diff_filter = torch.zeros((3,) + G_fun_z.shape, dtype=torch.complex64)
+    diff_filter[0] = (1j * 2 * np.pi * oblique_factor)[None, ...]
+    diff_filter[1] = (1j * 2 * np.pi * fyy * Pupil_support)[None, ...]
+    diff_filter[2] = (1j * 2 * np.pi * fxx * Pupil_support)[None, ...]
+
+    G_tensor_z = torch.zeros((3, 3) + G_fun_z.shape, dtype=torch.complex64)
+
+    for i in range(3):
+        for j in range(3):
+            G_tensor_z[i, j] = (
+                G_fun_z
+                * diff_filter[i]
+                * diff_filter[j]
+                / (2 * np.pi / lambda_in) ** 2
+            )
+            if i == j:
+                G_tensor_z[i, i] += G_fun_z
+
+    return G_tensor_z
+
+
+def gen_dyadic_Greens_tensor_z(fxx, fyy, G_fun_z, Pupil_support, lambda_in):
+    """
+    keeping for backwards compatibility
+
+    generate forward dyadic Green's function in u_x, u_y, z space
+    Parameters
+    ----------
+    fxx : numpy.ndarray
+        x component of 2D spatial frequency array with the size of (Ny, Nx)
+    fyy : numpy.ndarray
+        y component of 2D spatial frequency array with the size of (Ny, Nx)
+    G_fun_z : numpy.ndarray
+        forward Green's function in u_x, u_y, z space with size of (Ny, Nx, Nz)
+    Pupil_support : numpy.ndarray
+        the array that defines the support of the pupil function with the size of (Ny, Nx)
+    lambda_in : float
+        wavelength of the light in the immersion media
+    Returns
     -------
     G_tensor_z : numpy.ndarray
         forward dyadic Green's function in u_x, u_y, z space with the size of (3, 3, Ny, Nx, Nz)
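A minimal usage sketch for the renamed Green's tensor helper (editorial, not part of the released files); the optical parameters are placeholders and only the output shapes are checked:

import torch
from waveorder import optics, util

fyy, fxx = util.generate_frequencies((64, 64), 6.5 / 40)  # placeholder grid
frr = torch.sqrt(fyy**2 + fxx**2)
wavelength, det_na = 0.532, 1.2  # placeholder values
pupil = optics.generate_pupil(frr, det_na, wavelength)
z_positions = (torch.arange(32) - 16) * 0.25

G_z = optics.generate_greens_function_z(frr, pupil, wavelength, z_positions)
G_tensor = optics.generate_defocus_greens_tensor(fxx, fyy, G_z, pupil, wavelength)
print(G_z.shape)       # torch.Size([32, 64, 64])
print(G_tensor.shape)  # torch.Size([3, 3, 32, 64, 64])
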
@@ -427,7 +584,6 @@ def gen_dyadic_Greens_tensor_z(fxx, fyy, G_fun_z, Pupil_support, lambda_in):
             )
             if i == j:
                 G_tensor_z[i, i] += G_fun_z
-
     return G_tensor_z


@@ -561,6 +717,60 @@ def gen_dyadic_Greens_tensor(G_real, ps, psz, lambda_in, space="real"):
     )


+def generate_greens_tensor_spectrum(
+    zyx_shape,
+    zyx_pixel_size,
+    wavelength,
+):
+    """
+    Parameters
+    ----------
+    zyx_shape : tuple
+    zyx_pixel_size : tuple
+    wavelength : float
+        wavelength in medium
+
+    Returns
+    -------
+    torch.tensor
+        Green's tensor spectrum
+    """
+    Z, Y, X = zyx_shape
+    dZ, dY, dX = zyx_pixel_size
+
+    z_step = torch.fft.ifftshift(
+        (torch.arange(Z) - Z // 2) * dZ
+    )
+    y_step = torch.fft.ifftshift((torch.arange(Y) - Y // 2) * dY)
+    x_step = torch.fft.ifftshift((torch.arange(X) - X // 2) * dX)
+
+    zz = torch.broadcast_to(z_step[:, None, None], (Z, Y, X))
+    yy = torch.broadcast_to(y_step[None, :, None], (Z, Y, X))
+    xx = torch.broadcast_to(x_step[None, None, :], (Z, Y, X))
+
+    rr = torch.sqrt(xx**2 + yy**2 + zz**2)
+    rhat = torch.stack([zz, yy, xx], dim=0) / rr
+
+    scalar_g = torch.exp(1j * 2 * torch.pi * rr / wavelength) / (
+        4 * torch.pi * rr
+    )
+
+    eye = torch.zeros((3, 3, Z, Y, X))
+    eye[0, 0] = 1
+    eye[1, 1] = 1
+    eye[2, 2] = 1
+
+    Q = eye - torch.einsum("izyx,jzyx->ijzyx", rhat, rhat)
+    g_3d = Q * scalar_g
+    g_3d = torch.nan_to_num(g_3d)
+
+    G_3D = torch.fft.fftn(g_3d, dim=(-3, -2, -1))
+    G_3D = torch.imag(G_3D) * 1j
+    G_3D /= torch.amax(torch.abs(G_3D))
+
+    return G_3D
+
+
 def compute_weak_object_transfer_function_2d(
     illumination_pupil, detection_pupil
 ):
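A usage sketch for the new real-space-derived Green's tensor spectrum (editorial, not part of the released files); the sampling values are hypothetical:

from waveorder import optics

G = optics.generate_greens_tensor_spectrum(
    zyx_shape=(32, 64, 64),
    zyx_pixel_size=(0.25, 0.1, 0.1),  # hypothetical z, y, x spacings
    wavelength=0.4,  # hypothetical wavelength in the medium, same units
)
print(G.shape)  # torch.Size([3, 3, 32, 64, 64]); complex, normalized to unit peak magnitude
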
@@ -739,19 +949,23 @@ def compute_weak_object_transfer_function_3D(

     H1 = torch.fft.ifft2(torch.conj(SPHz_hat) * PG_hat, dim=(1, 2))
     H1 = H1 * window[:, None, None]
-    H1 = torch.fft.fft(H1, dim=0) * z_pixel_size
+    H1 = torch.fft.fft(H1, dim=0)

     H2 = torch.fft.ifft2(SPHz_hat * torch.conj(PG_hat), dim=(1, 2))
     H2 = H2 * window[:, None, None]
-    H2 = torch.fft.fft(H2, dim=0) * z_pixel_size
+    H2 = torch.fft.fft(H2, dim=0)

-    I_norm = torch.sum(
+    direct_intensity = torch.sum(
         illumination_pupil_support
         * detection_pupil
         * torch.conj(detection_pupil)
     )
-    real_potential_transfer_function = (H1 + H2) / I_norm
-    imag_potential_transfer_function = 1j * (H1 - H2) / I_norm
+    real_potential_transfer_function = (H1 + H2) / direct_intensity
+    imag_potential_transfer_function = 1j * (H1 - H2) / direct_intensity
+
+    # Discretization factor for unitless input and output
+    real_potential_transfer_function *= z_pixel_size
+    imag_potential_transfer_function *= z_pixel_size

     return real_potential_transfer_function, imag_potential_transfer_function

waveorder/sampling.py ADDED
@@ -0,0 +1,94 @@
+import numpy as np
+import torch
+
+
+def transverse_nyquist(
+    wavelength_emission,
+    numerical_aperture_illumination,
+    numerical_aperture_detection,
+):
+    """Transverse Nyquist sample spacing in `wavelength_emission` units.
+
+    For widefield label-free imaging, the transverse Nyquist sample spacing is
+    lambda / (2 * (NA_ill + NA_det)).
+
+    Perhaps surprisingly, the transverse Nyquist sample spacing for widefield
+    fluorescence is lambda / (4 * NA), which is equivalent to the above formula
+    when NA_ill = NA_det.
+
+    Parameters
+    ----------
+    wavelength_emission : float
+        Output units match these units
+    numerical_aperture_illumination : float
+        For widefield fluorescence, set to numerical_aperture_detection
+    numerical_aperture_detection : float
+
+    Returns
+    -------
+    float
+        Transverse Nyquist sample spacing
+
+    """
+    return wavelength_emission / (
+        2 * (numerical_aperture_detection + numerical_aperture_illumination)
+    )
+
+
+def axial_nyquist(
+    wavelength_emission,
+    numerical_aperture_detection,
+    index_of_refraction_media,
+):
+    """Axial Nyquist sample spacing in `wavelength_emission` units.
+
+    For widefield microscopes, the axial Nyquist cutoff frequency is:
+
+    (n/lambda) - sqrt( (n/lambda)^2 - (NA_det/lambda)^2 ),
+
+    and the axial Nyquist sample spacing is 1 / (2 * cutoff_frequency).
+
+    Perhaps surprisingly, the axial Nyquist sample spacing is independent of
+    the illumination numerical aperture.
+
+    Parameters
+    ----------
+    wavelength_emission : float
+        Output units match these units
+    numerical_aperture_detection : float
+    index_of_refraction_media: float
+
+    Returns
+    -------
+    float
+        Axial Nyquist sample spacing
+
+    """
+    n_on_lambda = index_of_refraction_media / wavelength_emission
+    cutoff_frequency = n_on_lambda - np.sqrt(
+        n_on_lambda**2
+        - (numerical_aperture_detection / wavelength_emission) ** 2
+    )
+    return 1 / (2 * cutoff_frequency)
+
+
+def nd_fourier_central_cuboid(source, target_shape):
+    """Central cuboid of an N-D Fourier transform.
+
+    Parameters
+    ----------
+    source : torch.Tensor
+        Source tensor
+    target_shape : tuple of int
+
+    Returns
+    -------
+    torch.Tensor
+        Center cuboid in Fourier space
+
+    """
+    center_slices = tuple(
+        slice((s - o) // 2, (s - o) // 2 + o)
+        for s, o in zip(source.shape, target_shape)
+    )
+    return torch.fft.ifftshift(torch.fft.fftshift(source)[center_slices])
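A usage sketch for the new module (editorial, not part of the released files); the optical parameters are hypothetical and the commented values follow from the docstring formulas:

import torch
from waveorder import sampling

wavelength = 0.532  # um, emission
dxy = sampling.transverse_nyquist(wavelength, 0.5, 1.2)  # 0.532 / (2 * 1.7) ≈ 0.156 um
dz = sampling.axial_nyquist(wavelength, 1.2, 1.33)       # ≈ 0.352 um

spectrum = torch.fft.fftn(torch.rand(8, 64, 64))
cropped = sampling.nd_fourier_central_cuboid(spectrum, (8, 32, 32))
print(cropped.shape)  # torch.Size([8, 32, 32])
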
waveorder/util.py CHANGED
@@ -331,12 +331,15 @@ def gen_coordinate(img_dim, ps):
     return (xx, yy, fxx, fyy)


-def generate_radial_frequencies(img_dim, ps):
+def generate_frequencies(img_dim, ps):
     fy = torch.fft.fftfreq(img_dim[0], ps)
     fx = torch.fft.fftfreq(img_dim[1], ps)
-
     fyy, fxx = torch.meshgrid(fy, fx, indexing="ij")
+    return fyy, fxx

+
+def generate_radial_frequencies(img_dim, ps):
+    fyy, fxx = generate_frequencies(img_dim, ps)
     return torch.sqrt(fyy**2 + fxx**2)


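A quick consistency check for the refactored frequency helpers (editorial, not part of the released files); grid size and pixel spacing are arbitrary:

import torch
from waveorder import util

fyy, fxx = util.generate_frequencies((4, 6), 0.5)
frr = util.generate_radial_frequencies((4, 6), 0.5)
assert torch.allclose(frr, torch.sqrt(fyy**2 + fxx**2))
print(fyy.shape, fxx.shape, frr.shape)  # all torch.Size([4, 6])
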
@@ -2239,3 +2242,52 @@ def orientation_3D_continuity_map(
     retardance_pr_avg /= np.max(retardance_pr_avg)

     return retardance_pr_avg
+
+
+def pauli():
+    # yx order
+    # trace-orthogonal normalization
+    # torch.einsum("kij,lji->kl", pauli(), pauli()) == torch.eye(4)
+
+    # intensity, x-y, +45-(-45), LCP-RCP
+    # yx
+    # yx
+    a = 2**-0.5
+    sigma = torch.tensor(
+        [
+            [[a, 0], [0, a]],
+            [[-a, 0], [0, a]],
+            [[0, a], [a, 0]],
+            [[0, 1j * a], [-1j * a, 0]],
+        ]
+    )
+    return sigma
+
+
+def gellmann():
+    # zyx order
+    # trace-orthogonal normalization
+    # torch.einsum("kij,lji->kl", gellmann(), gellmann()) == torch.eye(9)
+    #
+    # lexicographical order of the Gell-Mann matrices
+    # 00, 1-1, 10, 11, 2-2, 2-1, 20, 21, 22
+    #
+    # zyx
+    # zyx
+    a = 3**-0.5
+    c = 2**-0.5
+    d = -(6**-0.5)
+    e = 2 * (6**-0.5)
+    return torch.tensor(
+        [
+            [[a, 0, 0], [0, a, 0], [0, 0, a]],
+            [[0, 0, -c], [0, 0, 0], [c, 0, 0]],
+            [[0, 0, 0], [0, 0, -c], [0, c, 0]],
+            [[0, -c, 0], [c, 0, 0], [0, 0, 0]],
+            [[0, 0, 0], [0, 0, c], [0, c, 0]], #
+            [[0, c, 0], [c, 0, 0], [0, 0, 0]],
+            [[e, 0, 0], [0, d, 0], [0, 0, d]],
+            [[0, 0, c], [0, 0, 0], [c, 0, 0]],
+            [[0, 0, 0], [0, -c, 0], [0, 0, c]], #
+        ], dtype=torch.complex64
+    )
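The trace-orthogonal normalization stated in the pauli() comments can be checked directly (editorial sketch, not part of the released files):

import torch
from waveorder import util

sigma = util.pauli()
gram = torch.einsum("kij,lji->kl", sigma, sigma)  # tr(sigma_k @ sigma_l)
assert torch.allclose(gram, torch.eye(4, dtype=gram.dtype))
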
@@ -8,11 +8,7 @@ from ipywidgets import (
     Image,
     Layout,
     interact,
-    interactive,
-    fixed,
-    interact_manual,
     HBox,
-    VBox,
 )
 from matplotlib.colors import hsv_to_rgb
 from matplotlib.colors import Normalize
@@ -176,7 +172,7 @@ def image_stack_viewer_fast(
     else:
         raise ValueError('origin can only be either "upper" or "lower"')

-    im_wgt = Image(
+    im_wgt = Image(
         value=im_dict[0],
         layout=Layout(height=str(size[0]) + "px", width=str(size[1]) + "px"),
     )
@@ -1928,4 +1924,4 @@ def orientation_3D_hist(
     if colorbar:
         fig.colorbar(img, ax=ax[row_idx, col_idx])

-    return fig, ax
+    return fig, ax