nabu 2025.1.0.dev13__py3-none-any.whl → 2025.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. nabu/__init__.py +1 -1
  2. nabu/app/cast_volume.py +12 -1
  3. nabu/app/cli_configs.py +81 -4
  4. nabu/app/estimate_motion.py +54 -0
  5. nabu/app/multicor.py +2 -4
  6. nabu/app/pcaflats.py +116 -0
  7. nabu/app/reconstruct.py +1 -7
  8. nabu/app/reduce_dark_flat.py +5 -2
  9. nabu/estimation/cor.py +1 -1
  10. nabu/estimation/motion.py +557 -0
  11. nabu/estimation/tests/test_motion_estimation.py +471 -0
  12. nabu/estimation/tilt.py +1 -1
  13. nabu/estimation/translation.py +47 -1
  14. nabu/io/cast_volume.py +94 -13
  15. nabu/io/reader.py +32 -1
  16. nabu/io/tests/test_remove_volume.py +152 -0
  17. nabu/pipeline/config_validators.py +42 -43
  18. nabu/pipeline/estimators.py +255 -0
  19. nabu/pipeline/fullfield/chunked.py +67 -43
  20. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  21. nabu/pipeline/fullfield/nabu_config.py +17 -11
  22. nabu/pipeline/fullfield/processconfig.py +8 -2
  23. nabu/pipeline/fullfield/reconstruction.py +3 -0
  24. nabu/pipeline/params.py +12 -0
  25. nabu/pipeline/tests/test_estimators.py +240 -3
  26. nabu/preproc/ccd.py +53 -3
  27. nabu/preproc/flatfield.py +306 -1
  28. nabu/preproc/shift.py +3 -1
  29. nabu/preproc/tests/test_pcaflats.py +154 -0
  30. nabu/processing/rotation_cuda.py +3 -1
  31. nabu/processing/tests/test_rotation.py +4 -2
  32. nabu/reconstruction/fbp.py +7 -0
  33. nabu/reconstruction/fbp_base.py +31 -7
  34. nabu/reconstruction/fbp_opencl.py +8 -0
  35. nabu/reconstruction/filtering_opencl.py +2 -0
  36. nabu/reconstruction/mlem.py +51 -14
  37. nabu/reconstruction/tests/test_filtering.py +13 -2
  38. nabu/reconstruction/tests/test_mlem.py +91 -62
  39. nabu/resources/dataset_analyzer.py +144 -20
  40. nabu/resources/nxflatfield.py +101 -35
  41. nabu/resources/tests/test_nxflatfield.py +1 -1
  42. nabu/resources/utils.py +16 -10
  43. nabu/stitching/alignment.py +7 -7
  44. nabu/stitching/config.py +22 -20
  45. nabu/stitching/definitions.py +2 -2
  46. nabu/stitching/overlap.py +4 -4
  47. nabu/stitching/sample_normalization.py +5 -5
  48. nabu/stitching/stitcher/post_processing.py +5 -3
  49. nabu/stitching/stitcher/pre_processing.py +24 -20
  50. nabu/stitching/tests/test_config.py +3 -3
  51. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  52. nabu/stitching/tests/test_z_postprocessing_stitching.py +2 -2
  53. nabu/stitching/tests/test_z_preprocessing_stitching.py +23 -20
  54. nabu/stitching/utils/utils.py +7 -7
  55. nabu/testutils.py +1 -4
  56. nabu/utils.py +13 -0
  57. {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/METADATA +3 -4
  58. {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/RECORD +62 -57
  59. {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/WHEEL +1 -1
  60. {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/entry_points.txt +2 -1
  61. nabu/app/correct_rot.py +0 -62
  62. {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/licenses/LICENSE +0 -0
  63. {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/top_level.txt +0 -0
nabu/preproc/ccd.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import numpy as np
2
2
  from ..utils import check_supported
3
+ from scipy.ndimage import binary_dilation
3
4
  from silx.math.medianfilter import medfilt2d
4
5
 
5
6
 
@@ -13,6 +14,7 @@ class CCDFilter:
13
14
  def __init__(
14
15
  self,
15
16
  radios_shape: tuple,
17
+ kernel_size: int = 3,
16
18
  correction_type: str = "median_clip",
17
19
  median_clip_thresh: float = 0.1,
18
20
  abs_diff=False,
@@ -26,6 +28,9 @@ class CCDFilter:
26
28
  radios_shape: tuple
27
29
  A tuple describing the shape of the radios stack, in the form
28
30
  `(n_radios, n_z, n_x)`.
31
+ kernel_size: int
32
+ Size of the kernel for the median filter.
33
+ Default is 3.
29
34
  correction_type: str
30
35
  Correction type for radios ("median_clip", "sigma_clip", ...)
31
36
  median_clip_thresh: float, optional
@@ -48,6 +53,7 @@ class CCDFilter:
48
53
  then this pixel value is set to the median value.
49
54
  """
50
55
  self._set_radios_shape(radios_shape)
56
+ self.kernel_size = kernel_size
51
57
  check_supported(correction_type, self._supported_ccd_corrections, "CCD correction mode")
52
58
  self.correction_type = correction_type
53
59
  self.median_clip_thresh = median_clip_thresh
@@ -67,11 +73,11 @@ class CCDFilter:
67
73
  self.shape = (n_z, n_x)
68
74
 
69
75
  @staticmethod
70
def median_filter(img, kernel_size=3):
    """
    Median-filter a 2D image with a square window.

    Parameters
    ----------
    img: numpy.ndarray
        2D image to filter.
    kernel_size: int, optional
        Side length of the square median window. Default is 3.

    Returns
    -------
    The filtered image, computed by silx ``medfilt2d`` with "reflect" boundary handling.
    """
    window = (kernel_size,) * 2
    return medfilt2d(img, window, mode="reflect")
75
81
 
76
82
  def median_clip_mask(self, img, return_medians=False):
77
83
  """
@@ -85,7 +91,7 @@ class CCDFilter:
85
91
  return_medians: bool, optional
86
92
  Whether to return the median values additionally to the mask.
87
93
  """
88
- median_values = self.median_filter(img)
94
+ median_values = self.median_filter(img, kernel_size=self.kernel_size)
89
95
  if not self.abs_diff:
90
96
  invalid_mask = img >= median_values + self.median_clip_thresh
91
97
  else:
@@ -124,6 +130,50 @@ class CCDFilter:
124
130
 
125
131
  return output
126
132
 
133
def dezinger_correction(self, radios, dark=None, nsigma=5, output=None):
    """
    Remove zingers (isolated outlier pixels) from a radios stack.

    For each radio, the (optionally dark-subtracted) image is compared to its median-filtered
    version. Pixels whose residual exceeds ``nsigma`` times the standard deviation of the
    residual image are flagged. The mask is then grown by binary dilation, which propagates the
    invalid pixels into the vertical and horizontal directions. Flagged pixels are replaced by
    the local median.

    Parameters
    ----------
    radios: numpy.ndarray
        A radios stack.
    dark: numpy.ndarray, optional
        A dark image with the same shape as one radio. Default is None.
        If not None, it is subtracted from each radio for the zinger detection. It is added
        back to the replacement values, so the output stays in the same (non dark-subtracted)
        units as `radios`.
    nsigma: float, optional
        Number of standard deviations to use for the zinger detection.
        Default is 5.
    output: numpy.ndarray, optional
        Output array. If not provided, a copy of `radios` is allocated.

    Returns
    -------
    output: numpy.ndarray
        The corrected radios stack.

    Raises
    ------
    ValueError
        If the radios frame shape does not match the configured radios shape,
        or if the dark shape does not match the radios frame shape.
    """
    if radios.shape[1:] != self.radios_shape[1:]:
        raise ValueError(f"Expected radios shape {self.radios_shape}, got {radios.shape}")
    # Validate once, outside of the loop
    if dark is not None and dark.shape != radios.shape[1:]:
        raise ValueError("Dark image shape does not match radios shape.")

    if output is None:
        output = np.copy(radios)
    else:
        output[:] = radios[:]

    for i, radio in enumerate(radios):
        dimg = radio if dark is None else radio - dark
        med = self.median_filter(dimg, self.kernel_size)
        err = dimg - med
        zingers = err > err.std() * nsigma
        # Propagate the invalid pixels to their vertical and horizontal neighbours
        zingers = binary_dilation(zingers)
        # `med` lives in dark-subtracted units: add the dark back so that replaced and
        # untouched pixels are expressed in the same units as the input radio.
        replacement = med if dark is None else med + dark
        output[i] = np.where(zingers, replacement, radio)

    return output
176
+
127
177
 
128
178
  class Log:
129
179
  """
nabu/preproc/flatfield.py CHANGED
@@ -1,8 +1,11 @@
1
+ import os
1
2
  from multiprocessing.pool import ThreadPool
2
3
  from bisect import bisect_left
3
4
  import numpy as np
5
+ from tomoscan.io import HDF5File
4
6
  from ..io.reader import load_images_from_dataurl_dict
5
7
  from ..utils import check_supported, deprecated_class, get_num_threads
8
+ from .ccd import CCDFilter
6
9
 
7
10
 
8
11
  class FlatFieldArrays:
@@ -102,7 +105,7 @@ class FlatFieldArrays:
102
105
  self._precompute_flats_indices_weights()
103
106
  self._configure_srcurrent_normalization(radios_srcurrent, flats_srcurrent)
104
107
  self.distortion_correction = distortion_correction
105
- self.n_threads = min(1, get_num_threads(n_threads) // 2)
108
+ self.n_threads = max(1, get_num_threads(n_threads) // 2)
106
109
 
107
110
  def _set_parameters(self, radios_shape, radios_indices, interpolation, nan_value):
108
111
  self._set_radios_shape(radios_shape)
@@ -450,3 +453,305 @@ class FlatFieldDataUrls(FlatField):
450
453
  radios_srcurrent=radios_srcurrent,
451
454
  flats_srcurrent=flats_srcurrent,
452
455
  )
456
+
457
+
458
class PCAFlatsNormalizer:
    """This class implements a flatfield normalization based on a PCA of a series of acquired flatfields.
    The PCA decomposition is handled by a PCAFlatsDecomposer object.

    This implementation was proposed by Jailin C. et al in https://doi.org/10.1107/S1600577516015812.

    Code initially written by ID11 @ ESRF staff.
    Jonathan Wright - Implementation based on research paper
    Pedro D. Resende - Added saving and loading from file capabilities

    Jerome Lesaint - Integrated the solution in Nabu.

    """

    def __init__(self, components, dark, mean):
        """Initializes all variables needed to perform the flatfield normalization.

        Parameters
        -----------
        components: ndarray
            The components of the PCA decomposition.
        dark: ndarray
            The dark image. Should be one single 2D image.
        mean: ndarray
            The mean image of the series of flats.
        """
        # A constant component is prepended: it accounts for the incident intensity I0
        ones = np.ones_like(components[0], dtype=np.float32)[np.newaxis]
        self.components = np.concatenate([ones, components], axis=0)
        self.dark = dark
        self.mean = mean
        self.n_threads = max(1, get_num_threads() // 2)
        self._setmask()
        self.ccdfilter = CCDFilter(mean.shape)

    def _form_lsq_matrix(self):
        """Form the least-squares matrix: one column per masked component."""
        self.Amat = np.stack(self.g, axis=1)

    def _set_masked_components(self):
        """Extract the masked pixels of every component, then rebuild the least-squares matrix."""
        self.g = [component[self.mask] for component in self.components]
        self._form_lsq_matrix()

    def _setmask(self, prop=0.125):
        """Sets the mask to select where the model is going to be fitted.

        Parameters
        ----------
        prop: float, default: 0.125
            The proportion of the image width to take on each side of the image as a mask.

        By default it sets the strips on each side of the frame in the form:
            mask[:, :lim] = True
            mask[:, -lim:] = True
        Where lim = int(prop * flat.shape[1])

        If you need a custom mask, see update_mask() method.

        Raises
        ------
        ValueError
            If `prop` is too small for the image width, i.e. no column would be selected.
        """
        width = self.mean.shape[1]
        lim = int(prop * width)
        if lim <= 0:
            # With lim == 0, "mask[:, -lim:]" would select *every* column, i.e. fit on the whole
            # frame (sample included) instead of the side strips.
            raise ValueError(
                f"prop={prop} is too small for an image of width {width}: no column would be used for the fit."
            )
        self.mask = np.zeros(self.mean.shape, dtype=bool)
        self.mask[:, :lim] = True
        self.mask[:, width - lim :] = True
        self._set_masked_components()

    def update_mask(self, mask: np.ndarray):
        """Method to update the mask with a custom mask in the form of a boolean 2+D array.

        Parameters
        ----------
        mask: np.ndarray of bool
            The array of boolean allows the selection of the region of the image that will be used
            to fit against the components.

        It will set the mask, replacing the standard mask created with _setmask().

        Raises
        ------
        TypeError
            If `mask` is not a boolean array (the current mask is kept).
        """
        if mask.dtype != bool:
            raise TypeError("Not a boolean array. Will keep the default mask")
        self.mask = mask
        self._set_masked_components()

    def normalize_radios(self, projections, mask=None, prop=0.125):
        """Normalize `projections` in place and return them (flatfield-like API used by the pipeline)."""
        return self.correct_stack(projections, mask=mask, prop=prop)

    def correct_stack(
        self,
        projections: np.ndarray,
        mask: np.ndarray = None,
        prop: float = 0.125,
    ):
        """This function normalizes the stack of projections.

        Performs correction on a stack of projections based on the calculated decomposition.
        The normalization is done in-place. The previous projections before normalization are lost.

        Parameters
        ----------
        projections: ndarray
            Stack of projections to normalize.

        prop: float (default: {0.125})
            Fraction to mask on the horizontal field of view, assuming vertical rotation axis

        mask: np.ndarray (default: None)
            Custom mask if your data requires it.

        Returns
        -------
        corrected projections: np.ndarray
            The `projections` array itself, holding the flat field corrected images. Note that the
            projections are exp-transformed to fit the pipeline (e.g. to allow for phase retrieval).
        """
        # The dezingered copy is the one fitted; results are written back into `projections`
        self.projections = self.ccdfilter.dezinger_correction(projections, self.dark)

        if mask is not None:
            self.update_mask(mask=mask)
        else:
            self._setmask(prop=prop)

        with ThreadPool(self.n_threads) as tp:
            for i, cor, _ in tp.map(self._readcorrect1, range(len(projections))):
                projections[i] = np.exp(-cor)
        return projections

    def _readcorrect1(self, ii):
        """Correct projection `ii` of the stack; returns (ii, corrected, solution). Used for parallelization."""
        corr, s = self.correctproj(self.projections[ii])
        return ii, corr, s

    def correctproj(self, projection):
        """Performs the correction on one projection of the stack.

        Parameters
        ----------
        projection: np.ndarray, float
            Radiograph from the acquisition stack.

        Returns
        -------
        corrected, solution:
            The residual of ``mean - log(projection - dark)`` after removing the fitted
            components, and the fitted coefficients.
        """
        logp = np.log(projection.astype(np.float32) - self.dark)
        corr = self.mean - logp
        return self.fit(corr)

    def fit(self, corr):
        """Fit the (masked) projection to the (masked) components of the PCA decomposition.

        Returns the residual ``corr - sum_i(solution[i] * components[i])`` and ``solution``.
        This is done for each projection, so it is worth optimising.
        """
        y = corr[self.mask]
        solution = np.linalg.lstsq(self.Amat, y, rcond=None)[0]
        correction = np.einsum("ijk,i->jk", self.components, solution)
        return corr - correction, solution
616
+
617
+
618
class PCAFlatsDecomposer:
    """This class implements a PCA decomposition of a series of acquired flatfields.
    The PCA decomposition is used to normalize the projections through a PCAFlatsNormalizer object.

    This implementation was proposed by Jailin C. et al in https://doi.org/10.1107/S1600577516015812.

    Code initially written by ID11 @ ESRF staff.
    Jonathan Wright - Implementation based on research paper
    Pedro D. Resende - Added saving and loading from file capabilities

    Jerome Lesaint - Integrated the solution in Nabu.
    """

    def __init__(self, flats: np.ndarray, darks: np.ndarray, nsigma=3):
        """
        Compute the PCA decomposition of a stack of flats.

        Does the dezinger and log scaling, subtracts the mean (log) flat
        and does an eigenvector decomposition.

        Parameters
        -----------
        flats: np.ndarray
            A stack of dark-corrected flat field images
        darks: np.ndarray
            An image or stack of images of the dark current images of the camera.
            A stack (3D array) is reduced to its median image.
        nsigma: float, optional
            Threshold used both for the zinger detection and for the selection of the
            PCA components (see generate_pca_flats). Default is 3.
        """
        self.n_threads = max(1, get_num_threads() // 2)
        self.flats = np.empty(flats.shape, dtype=np.float32)

        darks = darks.astype(np.float32)  # astype returns a new array
        self.dark = np.median(darks, axis=0) if darks.ndim == 3 else darks

        self.nsigma = nsigma
        self.ccdfilter = CCDFilter(self.dark.shape)
        self._ccdfilter_and_log(flats)  # Log is taken here (after dezinger)

        self.mean = np.mean(self.flats, axis=0)
        self.flats -= self.mean  # center the log-flats in place
        self.cov = self.compute_correlation_matrix()
        self.compute_pca()
        self.generate_pca_flats(nsigma=nsigma)

    def __str__(self):
        return f"PCA decomposition from flat images. \nThere are {len(self.components)} components created at {self.sigma} level."

    def _ccdfilter_and_log(self, flats: np.ndarray):
        """Dezinger the flats stack and store the log of the dezingered frames in ``self.flats``.

        The dark is passed to the dezinger step; it is not subtracted here before the log.
        """
        # dezinger_correction returns the corrected stack: the result must be used,
        # otherwise the log would be taken on the raw (non-dezingered) flats.
        dezingered = self.ccdfilter.dezinger_correction(flats, self.dark, nsigma=self.nsigma)

        with ThreadPool(self.n_threads) as tp:
            for i, frame in enumerate(tp.map(np.log, dezingered)):
                self.flats[i] = frame

    @staticmethod
    def ccij(args):
        """Compute one correlation coefficient ``(imgs[i] * imgs[j]).sum() / norm``.

        It is a wrapper for threading: ``args == (i, j, norm, imgs)``.
        Returns ``(i, j, value)``.
        """
        i, j, norm, imgs = args
        return i, j, np.einsum("ij,ij", imgs[i], imgs[j]) / norm

    def compute_correlation_matrix(self):
        """Computes an (nflats x nflats) correlation matrix of the centered log-flats.

        Only the lower triangle is computed, the matrix is filled symmetrically.
        The entries are divided by the number of flats: a global scale factor,
        which does not change the eigenvectors.
        """
        n_flats = len(self.flats)
        cc = np.zeros((n_flats, n_flats), float)
        args = [(i, j, n_flats, self.flats) for i in range(n_flats) for j in range(i + 1)]
        with ThreadPool(self.n_threads) as tp:
            for i, j, result in tp.map(self.ccij, args):
                cc[i, j] = cc[j, i] = result
        return cc

    def compute_pca(self):
        """Gets eigenvectors and eigenvalues and sorts them by decreasing absolute eigenvalue."""
        # The correlation matrix is real symmetric: eigh guarantees real eigenvalues and
        # orthonormal eigenvectors (returned in ascending eigenvalue order).
        self.eigenvalues, self.eigenvectors = np.linalg.eigh(self.cov)
        order = np.argsort(abs(self.eigenvalues))[::-1]  # high to low
        self.eigenvalues = self.eigenvalues[order]
        self.eigenvectors = self.eigenvectors[:, order]

    def generate_pca_flats(self, nsigma=3):
        """Projects the eigenvectors back into image space.

        Eigenvectors whose absolute eigenvalue is larger than ``nsigma`` times the second-smallest
        one are kept, and projected as unit-norm images. Two gradient images (horizontal and
        vertical ramps) are always appended to ``self.components``.
        At least two flats are needed (``av[-2]``).

        Parameters
        ----------
        nsigma: float (default: 3)
        """
        self.sigma = nsigma
        av = abs(self.eigenvalues)  # sorted high to low by compute_pca
        # The second-smallest eigenvalue is the noise reference (presumably because centering
        # the flats makes the smallest one ~0).
        n_comps = int((av > (av[-2] * nsigma)).sum())

        def calculate(ii):
            calc = np.einsum("i,ijk->jk", self.eigenvectors[:, ii], self.flats)
            return calc / np.sqrt((calc**2).sum())

        with ThreadPool(self.n_threads) as tp:
            # ThreadPool.map preserves the input order
            self.components = tp.map(calculate, range(n_comps))

        # simple gradients. The image shape is taken from the flats, so that this also
        # works when no PCA component was selected.
        r, c = self.flats.shape[1:]
        self.components.append(np.outer(np.ones(r), np.linspace(-1 / c, 1 / c, c)))
        self.components.append(np.outer(np.linspace(-1 / r, 1 / r, r), np.ones(c)))

    def save_decomposition(self, path="PCA_flats.h5", overwrite=True, entry="entry0000"):
        """Saves the basic information of a PCA decomposition in view of the normalization of projections.

        Parameters
        ----------
        path: str (default: "PCA_flats.h5")
            Full path to the h5 file you want to save your results.
            An existing file is overwritten if `overwrite` is True. Be careful.
        overwrite: bool (default: True)
            Whether an existing file may be overwritten.
        entry: str (default: "entry0000")
            Name of the group in which the datasets are written.

        Raises
        ------
        OSError
            If the file already exists and `overwrite` is False.
        """
        if os.path.exists(path) and not overwrite:
            raise OSError(f"The file {path} already exists and you chose to NOT overwrite.")
        with HDF5File(path, "w") as hout:
            group = hout.create_group(entry)
            group["eigenvalues"] = self.eigenvalues
            group["dark"] = self.dark.astype(np.float32)
            group["p_mean"] = self.mean.astype(np.float32)
            group["p_components"] = np.array(self.components)
            hout.flush()
nabu/preproc/shift.py CHANGED
@@ -51,7 +51,7 @@ class VerticalShift:
51
51
  assert np.max(iangles) < len(self.interp_infos)
52
52
  assert len(iangles) == radios.shape[0]
53
53
 
54
- def apply_vertical_shifts(self, radios, iangles, output=None):
54
+ def apply_vertical_shifts(self, radios, iangles=None, output=None):
55
55
  """
56
56
  Parameters
57
57
  ----------
@@ -65,6 +65,8 @@ class VerticalShift:
65
65
  If given, it will be modified to contain the shifted radios.
66
66
  Must be of the same shape of `radios`.
67
67
  """
68
+ if iangles is None:
69
+ iangles = np.arange(radios.shape[0])
68
70
  self._check(radios, iangles)
69
71
 
70
72
  newradio = np.zeros_like(radios[0])
@@ -0,0 +1,154 @@
1
+ import os
2
+ import numpy as np
3
+ import pytest
4
+
5
+ import h5py
6
+ from nabu.testutils import utilstest
7
+ from nabu.preproc.flatfield import (
8
+ PCAFlatsDecomposer,
9
+ PCAFlatsNormalizer,
10
+ )
11
+
12
+
13
@pytest.fixture(scope="class")
def bootstrap_pcaflats(request):
    """Load the PCA-flats input data and reference results once per test class."""
    cls = request.cls
    # TODO: these tolerances for having the tests passed should be tighter.
    # Discrepancies between id11 code and nabu code are still mysterious.
    cls.mean_abs_tol = 1e-1
    cls.comps_abs_tol = 1e-2

    cls.projs, cls.flats, cls.darks = get_pcaflats_data("test_pcaflats.npz")
    # Flat-field correction is done in place: keep a pristine copy of the projections
    cls.raw_projs = cls.projs.copy()

    ref_data = get_pcaflats_refdata("ref_pcaflats.npz")
    # class attribute name -> key in the reference file
    ref_attributes = {
        "mean": "mean",
        "components_3": "components_3",
        "components_15": "components_15",
        "dark": "dark",
        "normalized_projs_3": "normalized_projs_3",
        "normalized_projs_15": "normalized_projs_15",
        "normalized_projs_custom_mask": "normalized_projs_custom_mask",
        "test_normalize_projs_custom_prop": "normalized_projs_custom_prop",
    }
    for attr_name, key in ref_attributes.items():
        setattr(cls, attr_name, ref_data[key])

    cls.h5_filename_3 = get_h5_pcaflats("pcaflat_3.h5")
    cls.h5_filename_15 = get_h5_pcaflats("pcaflat_15.h5")
+
35
+
36
def get_pcaflats_data(*dataset_path):
    """
    Get the PCA-flats input dataset (a .npz file) from silx.org/pub/nabu/data.

    dataset_path describes a nested folder structure, ex.
    ("path", "to", "my", "dataset.npz")

    Returns
    -------
    projs, flats, darks: numpy.ndarray
        The "projs", "flats" and "darks" arrays of the file, cast to float32.
    """
    dataset_relpath = os.path.join(*dataset_path)
    dataset_downloaded_path = utilstest.getfile(dataset_relpath)
    # np.load on a .npz returns a lazy NpzFile holding an open file handle: close it once read
    with np.load(dataset_downloaded_path) as data:
        projs = data["projs"].astype(np.float32)
        flats = data["flats"].astype(np.float32)
        darks = data["darks"].astype(np.float32)

    return projs, flats, darks
50
+
51
+
52
def get_h5_pcaflats(*dataset_path):
    """
    Download (or get from cache) a file from silx.org/pub/nabu/data and return its local path.

    dataset_path describes a nested folder structure, ex.
    ("path", "to", "my", "dataset.h5")
    """
    return utilstest.getfile(os.path.join(*dataset_path))
62
+
63
+
64
def get_pcaflats_refdata(*dataset_path):
    """
    Get the PCA-flats reference results (a .npz file) from silx.org/pub/nabu/data.

    dataset_path describes a nested folder structure, ex.
    ("path", "to", "my", "reference.npz")

    Returns
    -------
    dict
        Mapping from array name to the loaded numpy array.
    """
    dataset_relpath = os.path.join(*dataset_path)
    dataset_downloaded_path = utilstest.getfile(dataset_relpath)
    # Load every array eagerly so that the underlying file handle can be closed
    with np.load(dataset_downloaded_path) as data:
        return {name: data[name] for name in data.files}
75
+
76
+
77
def get_decomposition(filename):
    """Read the PCA components, mean flat and dark stored under "entry0000" of an HDF5 file."""
    with h5py.File(filename, "r") as h5file:
        group = h5file["entry0000"]
        p_comps, p_mean, dark = (group[name][()] for name in ("p_components", "p_mean", "dark"))
    return p_comps, p_mean, dark
84
+
85
+
86
@pytest.mark.usefixtures("bootstrap_pcaflats")
class TestPCAFlatsDecomposer:
    def test_decompose_flats(self):
        # Build 3-sigma basis
        pca = PCAFlatsDecomposer(self.flats, self.darks, nsigma=3)
        message = "Found a discrepency between computed mean flat and reference."
        assert np.allclose(self.mean, pca.mean, atol=self.mean_abs_tol), message
        message = "Found a discrepency between computed components and reference ones if nsigma=3."
        assert np.allclose(self.components_3, np.array(pca.components), atol=self.comps_abs_tol), message

        # Build 1.5-sigma basis
        pca = PCAFlatsDecomposer(self.flats, self.darks, nsigma=1.5)
        message = "Found a discrepency between computed components and reference ones, if nsigma=1.5."
        assert np.allclose(self.components_15, np.array(pca.components), atol=self.comps_abs_tol), message

    def test_save_load_decomposition(self, tmp_path):
        pca = PCAFlatsDecomposer(self.flats, self.darks, nsigma=3)
        # tmp_path is a per-test directory managed by pytest: nothing is written next to the
        # cached reference data, and no manual cleanup (skipped on a failed assert) is needed.
        h5_path = str(tmp_path / "PCA_Flats.h5")
        pca.save_decomposition(path=h5_path)
        p_comps, p_mean, dark = get_decomposition(h5_path)
        message = "Found a discrepency between saved and loaded mean flat."
        assert np.allclose(self.mean, p_mean, atol=self.mean_abs_tol), message
        message = "Found a discrepency between saved and loaded components if nsigma=3."
        assert np.allclose(self.components_3, p_comps, atol=self.comps_abs_tol), message
        message = "Found a discrepency between saved and loaded dark."
        assert np.allclose(self.dark, dark, atol=self.comps_abs_tol), message
115
+
116
+
117
@pytest.mark.usefixtures("bootstrap_pcaflats")
class TestPCAFlatsNormalizer:
    def _normalize_raw_projs(self, h5_filename, **normalize_kwargs):
        """Normalize a fresh copy of the raw projections with the decomposition stored in `h5_filename`."""
        p_comps, p_mean, dark = get_decomposition(h5_filename)
        normalizer = PCAFlatsNormalizer(p_comps, dark, p_mean)
        projs = self.raw_projs.copy()  # normalization is done in place
        normalizer.normalize_radios(projs, **normalize_kwargs)
        return projs

    def test_load_pcaflats(self):
        """Tests that the structure of the output PCAFlat h5 file is correct."""
        p_comps, p_mean, dark = get_decomposition(self.h5_filename_3)
        # Check the shape of the loaded data
        frame_shape = p_comps.shape[1:]
        assert frame_shape == p_mean.shape
        assert frame_shape == dark.shape

    def test_normalize_projs(self):
        cases = (
            (self.h5_filename_3, self.normalized_projs_3),
            (self.h5_filename_15, self.normalized_projs_15),
        )
        for h5_filename, reference in cases:
            projs = self._normalize_raw_projs(h5_filename)
            assert np.allclose(projs, reference, atol=1e-2)

    def test_use_custom_mask(self):
        mask = np.zeros(self.mean.shape, dtype=bool)
        mask[:, :10] = True
        mask[:, -10:] = True
        projs = self._normalize_raw_projs(self.h5_filename_3, mask=mask)
        assert np.allclose(projs, self.normalized_projs_custom_mask, atol=1e-2)

    def test_change_mask_prop(self):
        projs = self._normalize_raw_projs(self.h5_filename_3, prop=0.05)
        assert np.allclose(projs, self.test_normalize_projs_custom_prop, atol=1e-2)
@@ -1,7 +1,7 @@
1
1
  import numpy as np
2
2
  from .rotation import Rotation
3
3
  from ..utils import get_cuda_srcfile, updiv
4
- from ..cuda.utils import __has_pycuda__, copy_array
4
+ from ..cuda.utils import __has_pycuda__, copy_array, check_textures_availability
5
5
  from ..cuda.processing import CudaProcessing
6
6
 
7
7
  if __has_pycuda__:
@@ -11,6 +11,8 @@ if __has_pycuda__:
11
11
 
12
12
  class CudaRotation(Rotation):
13
13
  def __init__(self, shape, angle, center=None, mode="edge", reshape=False, cuda_options=None, **sk_kwargs):
14
+ if not (check_textures_availability()):
15
+ raise RuntimeError("Need cuda textures for this class")
14
16
  if center is None:
15
17
  center = ((shape[1] - 1) / 2.0, (shape[0] - 1) / 2.0)
16
18
  super().__init__(shape, angle, center=center, mode=mode, reshape=reshape, **sk_kwargs)
@@ -3,7 +3,7 @@ import pytest
3
3
  from nabu.testutils import generate_tests_scenarios
4
4
  from nabu.processing.rotation_cuda import Rotation
5
5
  from nabu.processing.rotation import __have__skimage__
6
- from nabu.cuda.utils import __has_pycuda__, get_cuda_context
6
+ from nabu.cuda.utils import __has_pycuda__, get_cuda_context, check_textures_availability
7
7
 
8
8
  if __have__skimage__:
9
9
  from skimage.transform import rotate
@@ -68,7 +68,9 @@ class TestRotation:
68
68
  res = R(self.image)
69
69
  self._check_result(res, config, 1e-6)
70
70
 
71
- @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda rotation")
71
+ @pytest.mark.skipif(
72
+ not (__has_pycuda__) or not (check_textures_availability()), reason="Need cuda rotation (and textures)"
73
+ )
72
74
  @pytest.mark.parametrize("config", scenarios)
73
75
  def test_cuda_rotation(self, config):
74
76
  R = CudaRotation(
@@ -86,6 +86,13 @@ class CudaBackprojector(BackprojectorBase):
86
86
  self.sino_mult = CudaSinoMult(self.sino_shape, self.rot_center, ctx=self._processing.ctx)
87
87
  self._prepare_textures() # has to be done after compilation for Cuda (to bind texture to built kernel)
88
88
 
89
+ def _get_filter_init_extra_options(self):
90
+ return {
91
+ "cuda_options": {
92
+ "ctx": self._processing.ctx,
93
+ },
94
+ }
95
+
89
96
  def _transfer_to_texture(self, sino, do_checks=True):
90
97
  if do_checks and not (sino.flags.c_contiguous):
91
98
  raise ValueError("Expected C-Contiguous array")