reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic; see the registry's release page for more details.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -128
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -280
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -223
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -1374
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +36 -36
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +523 -516
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -19
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -262
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -200
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -15
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -19
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -434
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -404
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +97 -97
  91. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  94. reflectorch-1.4.0.dist-info/RECORD +0 -88
  95. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
  96. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
- from reflectorch.inference.preprocess_exp.preprocess import (
2
- standard_preprocessing,
3
- StandardPreprocessing,
4
- )
5
- from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
6
- from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
1
+ from reflectorch.inference.preprocess_exp.preprocess import (
2
+ standard_preprocessing,
3
+ StandardPreprocessing,
4
+ )
5
+ from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
6
+ from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
7
7
  from reflectorch.inference.preprocess_exp.footprint import apply_footprint_correction
@@ -1,36 +1,36 @@
1
- import numpy as np
2
-
3
-
4
def apply_attenuation_correction(
    intensity: np.ndarray,
    attenuation: np.ndarray,
    scattering_angle: np.ndarray = None,
    correct_discontinuities: bool = True,
) -> np.ndarray:
    """Divides measured intensities by their attenuation factors.

    When ``correct_discontinuities`` is set, the attenuation-corrected curve is
    afterwards passed through ``apply_discontinuities_correction``.

    Args:
        intensity (np.ndarray): intensities of an experimental reflectivity curve
        attenuation (np.ndarray): attenuation factor of every measured point
        scattering_angle (np.ndarray, optional): scattering angles of the measured points;
            required when ``correct_discontinuities`` is True. Defaults to None.
        correct_discontinuities (bool, optional): whether to rescale the curve at repeated
            scattering angles. Defaults to True.

    Returns:
        np.ndarray: the corrected reflectivity curve

    Raises:
        ValueError: if ``correct_discontinuities`` is True but ``scattering_angle`` is None.
    """
    corrected = intensity / attenuation
    if not correct_discontinuities:
        return corrected
    if scattering_angle is None:
        raise ValueError("correct_discontinuities options requires scattering_angle, but scattering_angle is None.")
    return apply_discontinuities_correction(corrected, scattering_angle)
27
-
28
-
29
def apply_discontinuities_correction(intensity: np.ndarray, scattering_angle: np.ndarray) -> np.ndarray:
    """Stitches curve segments together at repeated scattering angles.

    Whenever two consecutive points share the same scattering angle, every point
    after the first of the pair is multiplied by ``intensity[i] / intensity[i + 1]``.
    After the rescaling, the second point of the pair matches the first one.
    Successive rescalings compound along the curve. (Such repeated angles
    presumably mark a change of attenuator; confirm against the instrument.)

    Args:
        intensity (np.ndarray): intensities of the reflectivity curve. Integer input is
            promoted to float; floating input keeps its dtype.
        scattering_angle (np.ndarray): scattering angles aligned with ``intensity``

    Returns:
        np.ndarray: a new, rescaled intensity array (the input is not modified)
    """
    corrected = np.array(intensity, copy=True)
    if not np.issubdtype(corrected.dtype, np.inexact):
        # Scaling an integer array in place by a float factor raises a casting
        # error, so promote to float first.
        corrected = corrected.astype(float)

    # Indices i for which points i and i + 1 share the same scattering angle.
    repeated = np.flatnonzero(np.diff(scattering_angle) == 0)
    for i in repeated:
        with np.errstate(divide="ignore", invalid="ignore"):
            factor = corrected[i] / corrected[i + 1]
        # A zero or non-finite factor would zero out or blow up (inf/nan) the
        # entire remaining curve, so such a step is left unstitched.
        if np.isfinite(factor) and factor != 0:
            corrected[i + 1:] *= factor
    return corrected
1
+ import numpy as np
2
+
3
+
4
def apply_attenuation_correction(
    intensity: np.ndarray,
    attenuation: np.ndarray,
    scattering_angle: np.ndarray = None,
    correct_discontinuities: bool = True,
) -> np.ndarray:
    """Divides measured intensities by their attenuation factors.

    When ``correct_discontinuities`` is set, the attenuation-corrected curve is
    afterwards passed through ``apply_discontinuities_correction``.

    Args:
        intensity (np.ndarray): intensities of an experimental reflectivity curve
        attenuation (np.ndarray): attenuation factor of every measured point
        scattering_angle (np.ndarray, optional): scattering angles of the measured points;
            required when ``correct_discontinuities`` is True. Defaults to None.
        correct_discontinuities (bool, optional): whether to rescale the curve at repeated
            scattering angles. Defaults to True.

    Returns:
        np.ndarray: the corrected reflectivity curve

    Raises:
        ValueError: if ``correct_discontinuities`` is True but ``scattering_angle`` is None.
    """
    corrected = intensity / attenuation
    if not correct_discontinuities:
        return corrected
    if scattering_angle is None:
        raise ValueError("correct_discontinuities options requires scattering_angle, but scattering_angle is None.")
    return apply_discontinuities_correction(corrected, scattering_angle)
27
+
28
+
29
def apply_discontinuities_correction(intensity: np.ndarray, scattering_angle: np.ndarray) -> np.ndarray:
    """Stitches curve segments together at repeated scattering angles.

    Whenever two consecutive points share the same scattering angle, every point
    after the first of the pair is multiplied by ``intensity[i] / intensity[i + 1]``.
    After the rescaling, the second point of the pair matches the first one.
    Successive rescalings compound along the curve. (Such repeated angles
    presumably mark a change of attenuator; confirm against the instrument.)

    Args:
        intensity (np.ndarray): intensities of the reflectivity curve. Integer input is
            promoted to float; floating input keeps its dtype.
        scattering_angle (np.ndarray): scattering angles aligned with ``intensity``

    Returns:
        np.ndarray: a new, rescaled intensity array (the input is not modified)
    """
    corrected = np.array(intensity, copy=True)
    if not np.issubdtype(corrected.dtype, np.inexact):
        # Scaling an integer array in place by a float factor raises a casting
        # error, so promote to float first.
        corrected = corrected.astype(float)

    # Indices i for which points i and i + 1 share the same scattering angle.
    repeated = np.flatnonzero(np.diff(scattering_angle) == 0)
    for i in repeated:
        with np.errstate(divide="ignore", invalid="ignore"):
            factor = corrected[i] / corrected[i + 1]
        # A zero or non-finite factor would zero out or blow up (inf/nan) the
        # entire remaining curve, so such a step is left unstitched.
        if np.isfinite(factor) and factor != 0:
            corrected[i + 1:] *= factor
    return corrected
@@ -1,31 +1,31 @@
1
- import numpy as np
2
-
3
- from reflectorch.utils import angle_to_q
4
-
5
-
6
def cut_curve(q: np.ndarray, curve: np.ndarray, max_q: float, max_angle: float, wavelength: float):
    """Cuts an experimental reflectivity curve at a maximum q position.

    The curve is truncated just before the first point whose q exceeds the
    cut-off. The remaining q values are then divided by ``q_ratio``, which
    stretches them back over the original q range.

    Args:
        q (np.ndarray): the array of q points
        curve (np.ndarray): the experimental reflectivity curve
        max_q (float): the maximum q value at which the curve is cut; takes precedence over ``max_angle``
        max_angle (float): the maximum scattering angle at which the curve is cut; only used if ``max_q`` is None
        wavelength (float): the wavelength of the beam, used to convert ``max_angle`` to q

    Returns:
        tuple: ``(q, curve, q_ratio)``, where ``q_ratio`` is the cut-off divided by ``q.max()``.
        ``q_ratio`` is 1.0 when neither ``max_q`` nor ``max_angle`` is given. The arrays are
        only cut and rescaled when ``q_ratio < 1``; otherwise they are returned unchanged.
    """
    if max_q is None and max_angle is None:
        return q, curve, 1.

    cutoff = max_q if max_q is not None else angle_to_q(max_angle, wavelength)
    q_ratio = cutoff / q.max()
    if not (q_ratio < 1.):
        return q, curve, q_ratio

    n_kept = np.argmax(q > cutoff)
    return q[:n_kept] / q_ratio, curve[:n_kept], q_ratio
1
+ import numpy as np
2
+
3
+ from reflectorch.utils import angle_to_q
4
+
5
+
6
def cut_curve(q: np.ndarray, curve: np.ndarray, max_q: float, max_angle: float, wavelength: float):
    """Cuts an experimental reflectivity curve at a maximum q position.

    The curve is truncated just before the first point whose q exceeds the
    cut-off. The remaining q values are then divided by ``q_ratio``, which
    stretches them back over the original q range.

    Args:
        q (np.ndarray): the array of q points
        curve (np.ndarray): the experimental reflectivity curve
        max_q (float): the maximum q value at which the curve is cut; takes precedence over ``max_angle``
        max_angle (float): the maximum scattering angle at which the curve is cut; only used if ``max_q`` is None
        wavelength (float): the wavelength of the beam, used to convert ``max_angle`` to q

    Returns:
        tuple: ``(q, curve, q_ratio)``, where ``q_ratio`` is the cut-off divided by ``q.max()``.
        ``q_ratio`` is 1.0 when neither ``max_q`` nor ``max_angle`` is given. The arrays are
        only cut and rescaled when ``q_ratio < 1``; otherwise they are returned unchanged.
    """
    if max_q is None and max_angle is None:
        return q, curve, 1.

    cutoff = max_q if max_q is not None else angle_to_q(max_angle, wavelength)
    q_ratio = cutoff / q.max()
    if not (q_ratio < 1.):
        return q, curve, q_ratio

    n_kept = np.argmax(q > cutoff)
    return q[:n_kept] / q_ratio, curve[:n_kept], q_ratio
@@ -1,81 +1,81 @@
1
- try:
2
- from typing import Literal
3
- except ImportError:
4
- from typing_extensions import Literal
5
-
6
- import numpy as np
7
- from scipy.special import erf
8
-
9
__all__ = [
    "apply_footprint_correction",
    "remove_footprint_correction",
    "BEAM_SHAPE",
]


# Beam profiles supported by the footprint correction.
BEAM_SHAPE = Literal["gauss", "box"]


def apply_footprint_correction(
    intensity: np.ndarray,
    scattering_angle: np.ndarray,
    beam_width: float,
    sample_length: float,
    beam_shape: BEAM_SHAPE = "gauss",
) -> np.ndarray:
    """Applies footprint correction to an experimental reflectivity curve.

    The intensities are multiplied by factors derived from the ratio between the
    beam footprint (``beam_width / sin(scattering_angle / 2)``) and ``sample_length``.

    Args:
        intensity (np.ndarray): reflectivity curve
        scattering_angle (np.ndarray): scattering angles in degrees
        beam_width (float): the beam width, in the same length unit as ``sample_length``
        sample_length (float): the sample length
        beam_shape (BEAM_SHAPE, optional): the shape of the beam, either "gauss" or "box". Defaults to "gauss".

    Returns:
        np.ndarray: the footprint corrected reflectivity curve

    Raises:
        ValueError: if ``beam_shape`` is neither "gauss" nor "box".
    """
    factors = _get_factors_by_beam_shape(
        scattering_angle, beam_width, sample_length, beam_shape
    )
    return intensity.copy() * factors


def remove_footprint_correction(
    intensity: np.ndarray,
    scattering_angle: np.ndarray,
    beam_width: float,
    sample_length: float,
    beam_shape: BEAM_SHAPE = "gauss",
):
    """Reverts ``apply_footprint_correction`` by dividing by the same factors.

    Takes the same arguments as ``apply_footprint_correction`` and raises
    ValueError for an unknown ``beam_shape``.
    """
    factors = _get_factors_by_beam_shape(
        scattering_angle, beam_width, sample_length, beam_shape
    )
    return intensity.copy() / factors


def _get_factors_by_beam_shape(
    scattering_angle: np.ndarray, beam_width: float, sample_length: float, beam_shape: BEAM_SHAPE
):
    """Dispatches to the factor function matching ``beam_shape``."""
    if beam_shape == "gauss":
        return gaussian_factors(scattering_angle, beam_width, sample_length)
    elif beam_shape == "box":
        return box_factors(scattering_angle, beam_width, sample_length)
    else:
        raise ValueError("invalid beam shape")


def box_factors(scattering_angle, beam_width, sample_length):
    """Footprint correction factors for a rectangular (box) beam profile.

    Below the angle at which the footprint equals the sample length, the factor is
    the footprint ratio. At or above that angle, the factor is 1.
    """
    # sin(angle / 2) = beam_width / sample_length is where the footprint equals the
    # sample length. For beam_width > sample_length, np.arcsin returns nan; every
    # comparison with nan is False, so the correction used to be skipped silently.
    # Clipping to 1 gives 180 degrees: the footprint then exceeds the sample at every
    # angle below 180 degrees, which matches what the ratio says in that case.
    sin_limit = np.minimum(beam_width / sample_length, 1.)
    max_angle = 2 * np.arcsin(sin_limit) / np.pi * 180
    ratios = beam_footprint_ratio(scattering_angle, beam_width, sample_length)
    ones = np.ones_like(scattering_angle)
    return np.where(scattering_angle < max_angle, ones * ratios, ones)


def gaussian_factors(scattering_angle, beam_width, sample_length):
    """Footprint correction factors for a Gaussian beam profile.

    Returns ``1 / erf(sqrt(ln 2) / ratio)``, where ``ratio`` comes from
    ``beam_footprint_ratio``. The ``sqrt(ln 2)`` term suggests ``beam_width``
    is the beam FWHM (assumption; confirm against the data source).
    """
    ratio = beam_footprint_ratio(scattering_angle, beam_width, sample_length)
    return 1 / erf(np.sqrt(np.log(2)) / ratio)


def beam_footprint_ratio(scattering_angle, beam_width, sample_length):
    """Returns ``beam_width / sin(scattering_angle / 2)`` divided by ``sample_length``.

    ``scattering_angle`` is given in degrees.
    """
    return beam_width / sample_length / np.sin(scattering_angle / 2 * np.pi / 180)
1
+ try:
2
+ from typing import Literal
3
+ except ImportError:
4
+ from typing_extensions import Literal
5
+
6
+ import numpy as np
7
+ from scipy.special import erf
8
+
9
__all__ = [
    "apply_footprint_correction",
    "remove_footprint_correction",
    "BEAM_SHAPE",
]


# Beam profiles supported by the footprint correction.
BEAM_SHAPE = Literal["gauss", "box"]


def apply_footprint_correction(
    intensity: np.ndarray,
    scattering_angle: np.ndarray,
    beam_width: float,
    sample_length: float,
    beam_shape: BEAM_SHAPE = "gauss",
) -> np.ndarray:
    """Applies footprint correction to an experimental reflectivity curve.

    The intensities are multiplied by factors derived from the ratio between the
    beam footprint (``beam_width / sin(scattering_angle / 2)``) and ``sample_length``.

    Args:
        intensity (np.ndarray): reflectivity curve
        scattering_angle (np.ndarray): scattering angles in degrees
        beam_width (float): the beam width, in the same length unit as ``sample_length``
        sample_length (float): the sample length
        beam_shape (BEAM_SHAPE, optional): the shape of the beam, either "gauss" or "box". Defaults to "gauss".

    Returns:
        np.ndarray: the footprint corrected reflectivity curve

    Raises:
        ValueError: if ``beam_shape`` is neither "gauss" nor "box".
    """
    factors = _get_factors_by_beam_shape(
        scattering_angle, beam_width, sample_length, beam_shape
    )
    return intensity.copy() * factors


def remove_footprint_correction(
    intensity: np.ndarray,
    scattering_angle: np.ndarray,
    beam_width: float,
    sample_length: float,
    beam_shape: BEAM_SHAPE = "gauss",
):
    """Reverts ``apply_footprint_correction`` by dividing by the same factors.

    Takes the same arguments as ``apply_footprint_correction`` and raises
    ValueError for an unknown ``beam_shape``.
    """
    factors = _get_factors_by_beam_shape(
        scattering_angle, beam_width, sample_length, beam_shape
    )
    return intensity.copy() / factors


def _get_factors_by_beam_shape(
    scattering_angle: np.ndarray, beam_width: float, sample_length: float, beam_shape: BEAM_SHAPE
):
    """Dispatches to the factor function matching ``beam_shape``."""
    if beam_shape == "gauss":
        return gaussian_factors(scattering_angle, beam_width, sample_length)
    elif beam_shape == "box":
        return box_factors(scattering_angle, beam_width, sample_length)
    else:
        raise ValueError("invalid beam shape")


def box_factors(scattering_angle, beam_width, sample_length):
    """Footprint correction factors for a rectangular (box) beam profile.

    Below the angle at which the footprint equals the sample length, the factor is
    the footprint ratio. At or above that angle, the factor is 1.
    """
    # sin(angle / 2) = beam_width / sample_length is where the footprint equals the
    # sample length. For beam_width > sample_length, np.arcsin returns nan; every
    # comparison with nan is False, so the correction used to be skipped silently.
    # Clipping to 1 gives 180 degrees: the footprint then exceeds the sample at every
    # angle below 180 degrees, which matches what the ratio says in that case.
    sin_limit = np.minimum(beam_width / sample_length, 1.)
    max_angle = 2 * np.arcsin(sin_limit) / np.pi * 180
    ratios = beam_footprint_ratio(scattering_angle, beam_width, sample_length)
    ones = np.ones_like(scattering_angle)
    return np.where(scattering_angle < max_angle, ones * ratios, ones)


def gaussian_factors(scattering_angle, beam_width, sample_length):
    """Footprint correction factors for a Gaussian beam profile.

    Returns ``1 / erf(sqrt(ln 2) / ratio)``, where ``ratio`` comes from
    ``beam_footprint_ratio``. The ``sqrt(ln 2)`` term suggests ``beam_width``
    is the beam FWHM (assumption; confirm against the data source).
    """
    ratio = beam_footprint_ratio(scattering_angle, beam_width, sample_length)
    return 1 / erf(np.sqrt(np.log(2)) / ratio)


def beam_footprint_ratio(scattering_angle, beam_width, sample_length):
    """Returns ``beam_width / sin(scattering_angle / 2)`` divided by ``sample_length``.

    ``scattering_angle`` is given in degrees.
    """
    return beam_width / sample_length / np.sin(scattering_angle / 2 * np.pi / 180)
@@ -1,19 +1,19 @@
1
- import numpy as np
2
-
3
-
4
def interp_reflectivity(q_interp, q, reflectivity, min_value: float = 1e-10, logspace=False):
    """Interpolates a reflectivity curve on a base-10 logarithmic intensity scale.

    Args:
        q_interp (array-like): reciprocal space points at which the curve is evaluated
        q (array-like): reciprocal space points of the measured reflectivity curve
        reflectivity (array-like): reflectivity curve measured at the points ``q``
        min_value (float, optional): values below this are clipped to it before taking
            the logarithm. Defaults to 1e-10.
        logspace (bool, optional): if True, the q axis is interpolated in log10(q) as well,
            so ``q`` and ``q_interp`` must be positive. Defaults to False.

    Returns:
        array-like: interpolated reflectivity curve
    """
    log_reflectivity = np.log10(np.clip(reflectivity, min_value, None))
    if logspace:
        x_new, x_known = np.log10(q_interp), np.log10(q)
    else:
        x_new, x_known = q_interp, q
    return 10 ** np.interp(x_new, x_known, log_reflectivity)
1
+ import numpy as np
2
+
3
+
4
def interp_reflectivity(q_interp, q, reflectivity, min_value: float = 1e-10, logspace=False):
    """Interpolates a reflectivity curve on a base-10 logarithmic intensity scale.

    Args:
        q_interp (array-like): reciprocal space points at which the curve is evaluated
        q (array-like): reciprocal space points of the measured reflectivity curve
        reflectivity (array-like): reflectivity curve measured at the points ``q``
        min_value (float, optional): values below this are clipped to it before taking
            the logarithm. Defaults to 1e-10.
        logspace (bool, optional): if True, the q axis is interpolated in log10(q) as well,
            so ``q`` and ``q_interp`` must be positive. Defaults to False.

    Returns:
        array-like: interpolated reflectivity curve
    """
    log_reflectivity = np.log10(np.clip(reflectivity, min_value, None))
    if logspace:
        x_new, x_known = np.log10(q_interp), np.log10(q)
    else:
        x_new, x_known = q_interp, q
    return 10 ** np.interp(x_new, x_known, log_reflectivity)
@@ -1,21 +1,21 @@
1
- try:
2
- from typing import Literal
3
- except ImportError:
4
- from typing_extensions import Literal
5
-
6
- import numpy as np
7
- from numpy import ndarray
8
-
9
# Accepted values for the ``mode`` argument of ``intensity2reflectivity``.
NORMALIZE_MODE = Literal["first", "max", "incoming_intensity"]


def intensity2reflectivity(intensity: ndarray, mode: NORMALIZE_MODE, incoming_intensity=None) -> np.ndarray:
    """Normalizes intensities to a reflectivity curve.

    Args:
        intensity (ndarray): measured intensities
        mode (NORMALIZE_MODE): "first" divides by the first point, "max" by the maximum,
            and "incoming_intensity" by ``incoming_intensity``
        incoming_intensity (optional): divisor used by the "incoming_intensity" mode

    Returns:
        np.ndarray: the normalized curve

    Raises:
        ValueError: for an unknown ``mode``, or when ``mode`` is "incoming_intensity"
            and ``incoming_intensity`` is None.
    """
    if mode == "first":
        reference = intensity[0]
    elif mode == "max":
        reference = intensity.max()
    elif mode == "incoming_intensity":
        if incoming_intensity is None:
            raise ValueError("incoming_intensity is None")
        reference = incoming_intensity
    else:
        raise ValueError(f"Unknown mode {mode}")
    return intensity / reference
1
+ try:
2
+ from typing import Literal
3
+ except ImportError:
4
+ from typing_extensions import Literal
5
+
6
+ import numpy as np
7
+ from numpy import ndarray
8
+
9
# Accepted values for the ``mode`` argument of ``intensity2reflectivity``.
NORMALIZE_MODE = Literal["first", "max", "incoming_intensity"]


def intensity2reflectivity(intensity: ndarray, mode: NORMALIZE_MODE, incoming_intensity=None) -> np.ndarray:
    """Normalizes intensities to a reflectivity curve.

    Args:
        intensity (ndarray): measured intensities
        mode (NORMALIZE_MODE): "first" divides by the first point, "max" by the maximum,
            and "incoming_intensity" by ``incoming_intensity``
        incoming_intensity (optional): divisor used by the "incoming_intensity" mode

    Returns:
        np.ndarray: the normalized curve

    Raises:
        ValueError: for an unknown ``mode``, or when ``mode`` is "incoming_intensity"
            and ``incoming_intensity`` is None.
    """
    if mode == "first":
        reference = intensity[0]
    elif mode == "max":
        reference = intensity.max()
    elif mode == "incoming_intensity":
        if incoming_intensity is None:
            raise ValueError("incoming_intensity is None")
        reference = incoming_intensity
    else:
        raise ValueError(f"Unknown mode {mode}")
    return intensity / reference
@@ -1,121 +1,121 @@
1
- from dataclasses import dataclass
2
-
3
- import numpy as np
4
-
5
- from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
6
- from reflectorch.inference.preprocess_exp.footprint import apply_footprint_correction, BEAM_SHAPE
7
- from reflectorch.inference.preprocess_exp.normalize import intensity2reflectivity, NORMALIZE_MODE
8
- from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
9
- from reflectorch.inference.preprocess_exp.cut_with_q_ratio import cut_curve
10
- from reflectorch.utils import angle_to_q
11
-
12
-
13
def standard_preprocessing(
    intensity: np.ndarray,
    scattering_angle: np.ndarray,
    attenuation: np.ndarray,
    q_interp: np.ndarray,
    wavelength: float,
    beam_width: float,
    sample_length: float,
    min_intensity: float = 1e-10,
    beam_shape: BEAM_SHAPE = "gauss",
    normalize_mode: NORMALIZE_MODE = "max",
    incoming_intensity: float = None,
    max_q: float = None,  # takes precedence over max_angle when given
    max_angle: float = None,
) -> dict:
    """Turns a raw experimental reflectivity curve into a curve ready for inference.

    The steps are: attenuation correction, footprint correction, normalization,
    removal of low-statistics points, conversion to q, cutting at a maximum q, and
    interpolation onto ``q_interp``.

    Args:
        intensity (np.ndarray): array of intensities of the reflectivity curve
        scattering_angle (np.ndarray): array of scattering angles
        attenuation (np.ndarray): attenuation factors for each measured point
        q_interp (np.ndarray): reciprocal space points used for the interpolation
        wavelength (float): the wavelength of the beam
        beam_width (float): the beam width
        sample_length (float): the sample length
        min_intensity (float, optional): points of the normalized curve that are not strictly
            above this value, or are non-finite, are removed. Defaults to 1e-10.
        beam_shape (BEAM_SHAPE, optional): the shape of the beam, either "gauss" or "box". Defaults to "gauss".
        normalize_mode (NORMALIZE_MODE, optional): normalization mode, either "first", "max" or
            "incoming_intensity". Defaults to "max".
        incoming_intensity (float, optional): value(s) the intensities are divided by in the
            "incoming_intensity" normalization mode. Defaults to None.
        max_q (float, optional): the maximum q value at which the curve is cut. Defaults to None.
        max_angle (float, optional): the maximum scattering angle at which the curve is cut;
            only used if max_q is not provided. Defaults to None.

    Returns:
        dict: with keys "curve_interp" (interpolated curve), "curve" (curve before interpolation),
        "q_values" (q before interpolation), "q_interp" (q after interpolation) and
        "q_ratio" (ratio returned by the cutting step)
    """
    # Correct the raw counts: attenuators first, then the beam footprint.
    corrected = apply_attenuation_correction(intensity, attenuation, scattering_angle)
    corrected = apply_footprint_correction(
        corrected,
        scattering_angle,
        beam_width=beam_width,
        sample_length=sample_length,
        beam_shape=beam_shape,
    )

    # Normalize, then drop non-finite points and points not above min_intensity.
    reflectivity = intensity2reflectivity(corrected, normalize_mode, incoming_intensity)
    reflectivity, kept_angles = remove_low_statistics(reflectivity, scattering_angle, thresh=min_intensity)

    # Convert to q, optionally cut (and rescale) at max_q / max_angle, then interpolate.
    q_values = angle_to_q(kept_angles, wavelength)
    q_values, reflectivity, q_ratio = cut_curve(q_values, reflectivity, max_q, max_angle, wavelength)
    reflectivity_interp = interp_reflectivity(q_interp, q_values, reflectivity)

    # Sanity checks on the outputs; note that these are skipped under `python -O`.
    assert np.all(np.isfinite(reflectivity_interp))
    assert np.all(np.isfinite(reflectivity))
    assert np.all(np.isfinite(q_values))
    assert np.all(np.isfinite(q_interp))
    assert np.all(reflectivity > 0.)
    assert np.all(reflectivity_interp > 0.)

    return {
        "curve_interp": reflectivity_interp,
        "curve": reflectivity,
        "q_values": q_values,
        "q_interp": q_interp,
        "q_ratio": q_ratio,
    }
78
-
79
-
80
def remove_low_statistics(curve, scattering_angle, thresh: float = 1e-7):
    """Drops points whose value is non-finite or not strictly greater than ``thresh``.

    Returns:
        tuple: the filtered curve and the matching scattering angles
    """
    keep = np.isfinite(curve) & (curve > thresh)
    return curve[keep], scattering_angle[keep]
83
-
84
-
85
@dataclass
class StandardPreprocessing:
    """Stores default keyword arguments for ``standard_preprocessing``.

    Any field can be overridden per call, as can any other keyword accepted by
    ``standard_preprocessing`` (e.g. ``max_q``); see ``preprocess``.
    """
    q_interp: np.ndarray = None
    wavelength: float = 1.
    beam_width: float = None
    sample_length: float = None
    beam_shape: BEAM_SHAPE = "gauss"
    normalize_mode: NORMALIZE_MODE = "max"
    incoming_intensity: float = None

    def preprocess(self,
                   intensity: np.ndarray,
                   scattering_angle: np.ndarray,
                   attenuation: np.ndarray,
                   **kwargs
                   ) -> dict:
        """Runs ``standard_preprocessing`` with the stored fields, overridden by ``kwargs``."""
        attrs = self._get_updated_attrs(**kwargs)
        return standard_preprocessing(
            intensity,
            scattering_angle,
            attenuation,
            **attrs
        )

    __call__ = preprocess

    def set_parameters(self, **kwargs) -> None:
        """Updates stored fields.

        Raises:
            KeyError: for a name that is not a field; fields set before it remain updated.
        """
        known = self._field_names()
        for k, v in kwargs.items():
            if k in known:
                setattr(self, k, v)
            else:
                raise KeyError(f'Unknown parameter {k}.')

    def _get_updated_attrs(self, **kwargs):
        """Returns all field values as a dict, updated with ``kwargs``."""
        current_attrs = {k: getattr(self, k) for k in self._field_names()}
        current_attrs.update(kwargs)
        return current_attrs

    def _field_names(self) -> list:
        """Names of all dataclass fields, including those inherited by subclasses."""
        # `self.__annotations__` resolves to the most-derived class's own annotations
        # only, so a dataclass subclass adding a field would hide every inherited
        # field; dataclasses.fields() walks the whole hierarchy.
        from dataclasses import fields
        return [f.name for f in fields(self)]
1
+ from dataclasses import dataclass
2
+
3
+ import numpy as np
4
+
5
+ from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
6
+ from reflectorch.inference.preprocess_exp.footprint import apply_footprint_correction, BEAM_SHAPE
7
+ from reflectorch.inference.preprocess_exp.normalize import intensity2reflectivity, NORMALIZE_MODE
8
+ from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
9
+ from reflectorch.inference.preprocess_exp.cut_with_q_ratio import cut_curve
10
+ from reflectorch.utils import angle_to_q
11
+
12
+
13
def standard_preprocessing(
    intensity: np.ndarray,
    scattering_angle: np.ndarray,
    attenuation: np.ndarray,
    q_interp: np.ndarray,
    wavelength: float,
    beam_width: float,
    sample_length: float,
    min_intensity: float = 1e-10,
    beam_shape: BEAM_SHAPE = "gauss",
    normalize_mode: NORMALIZE_MODE = "max",
    incoming_intensity: float = None,
    max_q: float = None,  # takes precedence over max_angle when given
    max_angle: float = None,
) -> dict:
    """Turns a raw experimental reflectivity curve into a curve ready for inference.

    The steps are: attenuation correction, footprint correction, normalization,
    removal of low-statistics points, conversion to q, cutting at a maximum q, and
    interpolation onto ``q_interp``.

    Args:
        intensity (np.ndarray): array of intensities of the reflectivity curve
        scattering_angle (np.ndarray): array of scattering angles
        attenuation (np.ndarray): attenuation factors for each measured point
        q_interp (np.ndarray): reciprocal space points used for the interpolation
        wavelength (float): the wavelength of the beam
        beam_width (float): the beam width
        sample_length (float): the sample length
        min_intensity (float, optional): points of the normalized curve that are not strictly
            above this value, or are non-finite, are removed. Defaults to 1e-10.
        beam_shape (BEAM_SHAPE, optional): the shape of the beam, either "gauss" or "box". Defaults to "gauss".
        normalize_mode (NORMALIZE_MODE, optional): normalization mode, either "first", "max" or
            "incoming_intensity". Defaults to "max".
        incoming_intensity (float, optional): value(s) the intensities are divided by in the
            "incoming_intensity" normalization mode. Defaults to None.
        max_q (float, optional): the maximum q value at which the curve is cut. Defaults to None.
        max_angle (float, optional): the maximum scattering angle at which the curve is cut;
            only used if max_q is not provided. Defaults to None.

    Returns:
        dict: with keys "curve_interp" (interpolated curve), "curve" (curve before interpolation),
        "q_values" (q before interpolation), "q_interp" (q after interpolation) and
        "q_ratio" (ratio returned by the cutting step)
    """
    # Correct the raw counts: attenuators first, then the beam footprint.
    corrected = apply_attenuation_correction(intensity, attenuation, scattering_angle)
    corrected = apply_footprint_correction(
        corrected,
        scattering_angle,
        beam_width=beam_width,
        sample_length=sample_length,
        beam_shape=beam_shape,
    )

    # Normalize, then drop non-finite points and points not above min_intensity.
    reflectivity = intensity2reflectivity(corrected, normalize_mode, incoming_intensity)
    reflectivity, kept_angles = remove_low_statistics(reflectivity, scattering_angle, thresh=min_intensity)

    # Convert to q, optionally cut (and rescale) at max_q / max_angle, then interpolate.
    q_values = angle_to_q(kept_angles, wavelength)
    q_values, reflectivity, q_ratio = cut_curve(q_values, reflectivity, max_q, max_angle, wavelength)
    reflectivity_interp = interp_reflectivity(q_interp, q_values, reflectivity)

    # Sanity checks on the outputs; note that these are skipped under `python -O`.
    assert np.all(np.isfinite(reflectivity_interp))
    assert np.all(np.isfinite(reflectivity))
    assert np.all(np.isfinite(q_values))
    assert np.all(np.isfinite(q_interp))
    assert np.all(reflectivity > 0.)
    assert np.all(reflectivity_interp > 0.)

    return {
        "curve_interp": reflectivity_interp,
        "curve": reflectivity,
        "q_values": q_values,
        "q_interp": q_interp,
        "q_ratio": q_ratio,
    }
78
+
79
+
80
def remove_low_statistics(curve, scattering_angle, thresh: float = 1e-7):
    """Drops points whose value is non-finite or not strictly greater than ``thresh``.

    Returns:
        tuple: the filtered curve and the matching scattering angles
    """
    keep = np.isfinite(curve) & (curve > thresh)
    return curve[keep], scattering_angle[keep]
83
+
84
+
85
@dataclass
class StandardPreprocessing:
    """Stores default keyword arguments for ``standard_preprocessing``.

    Any field can be overridden per call, as can any other keyword accepted by
    ``standard_preprocessing`` (e.g. ``max_q``); see ``preprocess``.
    """
    q_interp: np.ndarray = None
    wavelength: float = 1.
    beam_width: float = None
    sample_length: float = None
    beam_shape: BEAM_SHAPE = "gauss"
    normalize_mode: NORMALIZE_MODE = "max"
    incoming_intensity: float = None

    def preprocess(self,
                   intensity: np.ndarray,
                   scattering_angle: np.ndarray,
                   attenuation: np.ndarray,
                   **kwargs
                   ) -> dict:
        """Runs ``standard_preprocessing`` with the stored fields, overridden by ``kwargs``."""
        attrs = self._get_updated_attrs(**kwargs)
        return standard_preprocessing(
            intensity,
            scattering_angle,
            attenuation,
            **attrs
        )

    __call__ = preprocess

    def set_parameters(self, **kwargs) -> None:
        """Updates stored fields.

        Raises:
            KeyError: for a name that is not a field; fields set before it remain updated.
        """
        known = self._field_names()
        for k, v in kwargs.items():
            if k in known:
                setattr(self, k, v)
            else:
                raise KeyError(f'Unknown parameter {k}.')

    def _get_updated_attrs(self, **kwargs):
        """Returns all field values as a dict, updated with ``kwargs``."""
        current_attrs = {k: getattr(self, k) for k in self._field_names()}
        current_attrs.update(kwargs)
        return current_attrs

    def _field_names(self) -> list:
        """Names of all dataclass fields, including those inherited by subclasses."""
        # `self.__annotations__` resolves to the most-derived class's own annotations
        # only, so a dataclass subclass adding a field would hide every inherited
        # field; dataclasses.fields() walks the whole hierarchy.
        from dataclasses import fields
        return [f.name for f in fields(self)]