reflectorch 1.0.0 (reflectorch-1.0.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of reflectorch might be problematic.

Files changed (83)
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0
reflectorch/inference/multilayer_fitter.py
@@ -0,0 +1,171 @@
+ from tqdm import trange
+
+ from time import perf_counter
+
+ import torch
+ from torch import nn, Tensor
+
+ from reflectorch.data_generation.reflectivity import kinematical_approximation
+ from reflectorch.data_generation.priors.multilayer_structures import SimpleMultilayerSampler
+
+
+ class MultilayerFit(object):
+     def __init__(self,
+                  q: Tensor,
+                  exp_curve: Tensor,
+                  params: Tensor,
+                  convert_func,
+                  scale_curve_func,
+                  min_bounds: Tensor,
+                  max_bounds: Tensor,
+                  optim_cls=None,
+                  lr: float = 5e-2,
+                  ):
+
+         self.q = q
+         self.scale_curve_func = scale_curve_func
+         self.scaled_exp_curve = self.scale_curve_func(torch.atleast_2d(exp_curve))
+         self.params_to_fit = nn.Parameter(params.clone())
+         self.convert_func = convert_func
+
+         self.min_bounds = min_bounds
+         self.max_bounds = max_bounds
+
+         optim_cls = optim_cls or torch.optim.Adam
+         self.optim = optim_cls([self.params_to_fit], lr)
+
+         self.best_loss = float('inf')
+         self.best_params, self.best_curve = None, None
+
+         self.get_best_solution()
+         self.losses = []
+
+     @classmethod
+     def from_prior_sampler(
+             cls,
+             q: Tensor,
+             exp_curve: Tensor,
+             prior_sampler: SimpleMultilayerSampler,
+             batch_size: int = 2 ** 13,
+             **kwargs
+     ):
+         _, scaled_params = prior_sampler.optimized_sample(batch_size)
+         params = prior_sampler.restore_params2parametrized(scaled_params)
+
+         return cls(
+             q.float(), exp_curve.float(), params.float(),
+             convert_func=get_convert_func(prior_sampler.multilayer_model),
+             scale_curve_func=_save_log, min_bounds=prior_sampler.min_bounds,
+             max_bounds=prior_sampler.max_bounds,
+             **kwargs
+         )
+
+     @classmethod
+     def from_prediction(
+             cls,
+             param: Tensor,
+             prior_sampler,
+             q: Tensor,
+             exp_curve: Tensor,
+             batch_size: int = 2 ** 13,
+             rel_bounds: float = 0.3,
+             **kwargs
+     ):
+
+         num_params = param.shape[-1]
+
+         deltas = prior_sampler.restore_params2parametrized(
+             (torch.rand(batch_size, num_params, device=param.device) * 2 - 1) * rel_bounds
+         ) - prior_sampler.min_bounds
+
+         deltas[0] = 0.
+
+         params = torch.atleast_2d(param) + deltas
+
+         return cls(
+             q.float(), exp_curve.float(), params.float(),
+             convert_func=get_convert_func(prior_sampler.multilayer_model),
+             scale_curve_func=_save_log, min_bounds=prior_sampler.min_bounds,
+             max_bounds=prior_sampler.max_bounds,
+             **kwargs
+         )
+
+     def get_clipped_params(self):
+         return torch.clamp(self.params_to_fit.detach().clone(), self.min_bounds, self.max_bounds)
+
+     def calc_loss(self, reduce=True, clip: bool = False):
+         if clip:
+             params = self.get_clipped_params()
+         else:
+             params = self.params_to_fit
+         curves = self.get_curves(params=params)
+         losses = ((self.scale_curve_func(curves) - self.scaled_exp_curve) ** 2)
+
+         if reduce:
+             return losses.sum()
+         else:
+             return losses.sum(-1)
+
+     def get_curves(self, params: Tensor = None):
+         if params is None:
+             params = self.params_to_fit
+         curves = kinematical_approximation(self.q, **self.convert_func(params))
+         return torch.atleast_2d(curves)
+
+     def run(self, num_iterations: int = 500, disable_tqdm: bool = False):
+         pbar = trange(num_iterations, disable=disable_tqdm)
+
+         for _ in pbar:
+             self.optim.zero_grad()
+             loss = self.calc_loss()
+             loss.backward()
+             self.optim.step()
+             self.losses.append(loss.item())
+             pbar.set_description(f'Loss = {loss.item():.2e}')
+
+     def clear(self):
+         self.losses.clear()
+
+     def run_fixed_time(self, time_limit: float = 2.):
+
+         start = perf_counter()
+
+         while True:
+             self.optim.zero_grad()
+             loss = self.calc_loss()
+             loss.backward()
+             self.optim.step()
+             self.losses.append(loss.item())
+             time_spent = perf_counter() - start
+
+             if time_spent > time_limit:
+                 break
+
+     @torch.no_grad()
+     def get_best_solution(self, clip: bool = True):
+         losses = self.calc_loss(clip=clip, reduce=False)
+         idx = torch.argmin(losses)
+         best_loss = losses[idx].item()
+
+         if best_loss < self.best_loss:
+             best_curve = self.get_curves()[idx]
+             best_params = self.params_to_fit.detach()[idx].clone()
+             self.best_params, self.best_curve, self.best_loss = best_params, best_curve, best_loss
+
+         return self.best_params
+
+
+ def get_convert_func(multilayer_model):
+     def func(params):
+         params = multilayer_model.to_standard_params(params)
+         return {
+             'thickness': params['thicknesses'],  # very useful transformation indeed ...
+             'roughness': params['roughnesses'],
+             'sld': params['slds'],
+         }
+
+     return func
+
+
+ def _save_log(curves, eps: float = 1e-10):
+     return torch.log10(curves + eps)
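
For orientation, a minimal usage sketch of MultilayerFit (not part of the package diff; it assumes `prior_sampler` is an already configured SimpleMultilayerSampler and that `q` and `exp_curve` are 1D torch tensors on the same device):

from reflectorch.inference.multilayer_fitter import MultilayerFit

# Build a batch of candidate parameter sets from the prior and fit them jointly.
fit = MultilayerFit.from_prior_sampler(q, exp_curve, prior_sampler, batch_size=2 ** 12, lr=5e-2)
fit.run(num_iterations=500)            # Adam steps on the summed squared log10-residuals
# fit.run_fixed_time(time_limit=2.)    # alternative: stop after a wall-clock budget
best_params = fit.get_best_solution()  # best candidate in the batch by per-candidate loss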
reflectorch/inference/multilayer_inference_model.py
@@ -0,0 +1,193 @@
+ from typing import Tuple
+
+ import torch
+ from torch import Tensor
+ import numpy as np
+
+ from reflectorch.inference.inference_model import (
+     InferenceModel,
+ )
+ from reflectorch.data_generation.reflectivity import kinematical_approximation_np, abeles_np
+
+ from reflectorch.data_generation.priors import (
+     MultilayerStructureParams,
+     SimpleMultilayerSampler,
+ )
+ from reflectorch.inference.record_time import print_time
+ from reflectorch.inference.scipy_fitter import standard_refl_fit
+ from reflectorch.inference.multilayer_fitter import MultilayerFit
+
+
+ class MultilayerInferenceModel(InferenceModel):
+     def predict(self,
+                 intensity: np.ndarray,
+                 scattering_angle: np.ndarray,
+                 attenuation: np.ndarray,
+                 priors: np.ndarray = None,
+                 preprocessing_parameters: dict = None,
+                 polish: bool = True,
+                 use_raw_q: bool = False,
+                 **kwargs
+                 ) -> dict:
+
+         with print_time("everything"):
+             with print_time("preprocess"):
+                 preprocessed_dict = self.preprocess(
+                     intensity, scattering_angle, attenuation, **(preprocessing_parameters or {})
+                 )
+
+             preprocessed_curve = preprocessed_dict["curve_interp"]
+
+             raw_curve, raw_q = preprocessed_dict["curve"], preprocessed_dict["q_values"]
+
+             with print_time("predict_from_preprocessed_curve"):
+                 preprocessed_dict.update(self.predict_from_preprocessed_curve(
+                     preprocessed_curve, priors,
+                     raw_curve=(raw_curve if use_raw_q else None),
+                     raw_q=raw_q,
+                     polish=polish,
+                     use_raw_q=use_raw_q,
+                     **kwargs
+                 ))
+
+         return preprocessed_dict
+
+     def predict_from_preprocessed_curve(self,
+                                         curve: np.ndarray,
+                                         priors: np.ndarray = None, *,  # ignore the priors so far
+                                         polish: bool = True,
+                                         raw_curve: np.ndarray = None,
+                                         raw_q: np.ndarray = None,
+                                         clip_prediction: bool = True,
+                                         use_raw_q: bool = False,
+                                         use_sampler: bool = False,
+                                         fitted_time_limit: float = 3.,
+                                         sampler_rel_bounds: float = 0.3,
+                                         polish_with_abeles: bool = False,
+                                         **kwargs
+                                         ) -> dict:
+
+         scaled_curve = self._scale_curve(curve)
+
+         predicted_params, parametrized = self._simple_prediction(scaled_curve)
+
+         if use_sampler:
+             parametrized: Tensor = self._sampler_solution(
+                 curve, parametrized,
+                 time_limit=fitted_time_limit,
+                 rel_bounds=sampler_rel_bounds,
+             )
+
+         init_raw_q = raw_q
+
+         if raw_curve is None:
+             raw_curve = curve
+             raw_q = self.q.squeeze().cpu().numpy()
+             raw_q_t = self.q
+         else:
+             raw_q_t = torch.from_numpy(raw_q).to(self.q)
+
+         # if q_ratio != 1.:
+         #     predicted_params.scale_with_q(q_ratio)
+         #     raw_q = raw_q * q_ratio
+         #     raw_q_t = raw_q_t * q_ratio
+
+         prediction_dict = {
+             "params": parametrized.squeeze().cpu().numpy(),
+             "param_names": list(self._prior_sampler.multilayer_model.PARAMETER_NAMES),
+             "curve_predicted": predicted_params.reflectivity(raw_q_t).squeeze().cpu().numpy()
+         }
+
+         # sld_x_axis, sld_profile, _ = get_density_profiles(
+         #     predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, num=1024,
+         # )
+         #
+         # prediction_dict['sld_profile'] = sld_profile.squeeze().cpu().numpy()
+         # prediction_dict['sld_x_axis'] = sld_x_axis.squeeze().cpu().numpy()
+
+         if polish:
+             prediction_dict.update(self._polish_prediction(
+                 raw_q, raw_curve, parametrized, q_values=init_raw_q, use_kinematical=True
+             ))
+         if polish_with_abeles:
+             prediction_dict.update(self._polish_prediction(
+                 raw_q, raw_curve, parametrized, q_values=init_raw_q, use_kinematical=False
+             ))
+
+         return prediction_dict
+
+     def _simple_prediction(self, scaled_curve) -> Tuple[MultilayerStructureParams, Tensor]:
+         with torch.no_grad():
+             self.trainer.model.eval()
+             scaled_params = self.trainer.model(scaled_curve)
+
+         predicted_params, parametrized = self._restore_predicted_params(scaled_params)
+         return predicted_params, parametrized
+
+     def _restore_predicted_params(self, scaled_params: Tensor) -> Tuple[MultilayerStructureParams, Tensor]:
+         parametrized = self._prior_sampler.restore_params2parametrized(scaled_params)
+         predicted_params: MultilayerStructureParams = self._prior_sampler.restore_params(scaled_params)
+         return predicted_params, parametrized
+
+     @print_time
+     def _sampler_solution(
+             self,
+             curve: Tensor or np.ndarray,
+             predicted_params: Tensor,
+             batch_size: int = 2 ** 13,
+             time_limit: float = 3.,
+             rel_bounds: float = 0.3,
+     ) -> Tensor:
+
+         fit_obj = MultilayerFit.from_prediction(
+             predicted_params, self._prior_sampler, self.q, torch.as_tensor(curve).to(self.q),
+             batch_size=batch_size, rel_bounds=rel_bounds,
+         )
+
+         fit_obj.run_fixed_time(time_limit)
+
+         best_params = fit_obj.get_best_solution()
+
+         return best_params
+
+     @property
+     def _prior_sampler(self) -> SimpleMultilayerSampler:
+         return self.trainer.loader.prior_sampler
+
+     @print_time
+     def _polish_prediction(self,
+                            q: np.ndarray,
+                            curve: np.ndarray,
+                            predicted_params: Tensor,
+                            q_values: np.ndarray,
+                            use_kinematical: bool = True,
+                            ) -> dict:
+
+         params = predicted_params.squeeze().cpu().numpy()
+         polished_params_dict = {}
+
+         if use_kinematical:
+             refl_generator = kinematical_approximation_np
+         else:
+             refl_generator = abeles_np
+
+         try:
+             polished_params_arr, curve_polished = standard_refl_fit(
+                 q, curve, params, restore_params_func=self._prior_sampler.restore_np_params,
+                 refl_generator=refl_generator,
+                 bounds=self._prior_sampler.get_np_bounds(),
+             )
+             params = self._prior_sampler.restore_np_params(polished_params_arr)
+             if q_values is None:
+                 q_values = q
+             curve_polished = abeles_np(q_values, **params)
+
+         except Exception as err:
+             self.log.exception(err)
+             polished_params_arr = params
+             curve_polished = np.zeros_like(q)
+
+         polished_params_dict['params_polished'] = polished_params_arr
+         polished_params_dict['curve_polished'] = curve_polished
+
+         return polished_params_dict
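
As a rough sketch of the intended call flow (illustrative only; it assumes `model` is an already constructed and trained MultilayerInferenceModel, and the accepted preprocessing keywords are defined by InferenceModel.preprocess, which lies outside this hunk):

result = model.predict(
    intensity, scattering_angle, attenuation,
    polish=True,        # least-squares polishing via standard_refl_fit (kinematical model)
    use_sampler=False,  # True would refine the prediction with the batched MultilayerFit first
)
result["params"]            # parametrization predicted by the network
result["param_names"]       # names from the multilayer model
result["curve_predicted"]   # reflectivity computed from the predicted parameters
result["params_polished"], result["curve_polished"]  # added by the polishing step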
reflectorch/inference/preprocess_exp/__init__.py
@@ -0,0 +1,7 @@
+ from reflectorch.inference.preprocess_exp.preprocess import (
+     standard_preprocessing,
+     StandardPreprocessing,
+ )
+ from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
+ from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
+ from reflectorch.inference.preprocess_exp.footprint import apply_footprint_correction
reflectorch/inference/preprocess_exp/attenuation.py
@@ -0,0 +1,36 @@
+ import numpy as np
+
+
+ def apply_attenuation_correction(
+         intensity: np.ndarray,
+         attenuation: np.ndarray,
+         scattering_angle: np.ndarray = None,
+         correct_discontinuities: bool = True
+ ) -> np.ndarray:
+     """Applies attenuation correction to experimental reflectivity curves
+
+     Args:
+         intensity (np.ndarray): intensities of an experimental reflectivity curve
+         attenuation (np.ndarray): attenuation factors for each measured point
+         scattering_angle (np.ndarray, optional): scattering angles of the measured points. Defaults to None.
+         correct_discontinuities (bool, optional): whether to correct discontinuities in the measured curves. Defaults to True.
+
+     Returns:
+         np.ndarray: the corrected reflectivity curve
+     """
+     intensity = intensity / attenuation
+     if correct_discontinuities:
+         if scattering_angle is None:
+             raise ValueError("The correct_discontinuities option requires scattering_angle, but scattering_angle is None.")
+         intensity = apply_discontinuities_correction(intensity, scattering_angle)
+     return intensity
+
+
+ def apply_discontinuities_correction(intensity: np.ndarray, scattering_angle: np.ndarray) -> np.ndarray:
+     intensity = intensity.copy()
+     diff_angle = np.diff(scattering_angle)
+     for i in range(len(diff_angle)):
+         if diff_angle[i] == 0:
+             factor = intensity[i] / intensity[i + 1]
+             intensity[(i + 1):] *= factor
+     return intensity
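
A small illustrative call (toy numbers, not from the package): the angle 0.4 is measured twice, once before and once after inserting an attenuator, and the discontinuity correction rescales the tail so the two segments join.

import numpy as np

from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction

scattering_angle = np.array([0.2, 0.3, 0.4, 0.4, 0.5, 0.6])
attenuation = np.array([1., 1., 1., 10., 10., 10.])      # attenuation factor per point
intensity = np.array([1e6, 5e5, 2e5, 1.9e4, 8e3, 3e3])   # raw detector counts
corrected = apply_attenuation_correction(intensity, attenuation, scattering_angle)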
reflectorch/inference/preprocess_exp/cut_with_q_ratio.py
@@ -0,0 +1,31 @@
+ import numpy as np
+
+ from reflectorch.utils import angle_to_q
+
+
+ def cut_curve(q: np.ndarray, curve: np.ndarray, max_q: float, max_angle: float, wavelength: float):
+     """Cuts an experimental reflectivity curve at a maximum q position
+
+     Args:
+         q (np.ndarray): the array of q points
+         curve (np.ndarray): the experimental reflectivity curve
+         max_q (float): the maximum q value at which the curve is cut
+         max_angle (float): the maximum scattering angle at which the curve is cut; only used if max_q is not provided
+         wavelength (float): the wavelength of the beam
+
+     Returns:
+         tuple: the q array after cutting, the reflectivity curve after cutting, and the ratio between the maximum q after cutting and before cutting
+     """
+     if max_angle is None and max_q is None:
+         q_ratio = 1.
+     else:
+         if max_q is None:
+             max_q = angle_to_q(max_angle, wavelength)
+
+         q_ratio = max_q / q.max()
+
+         if q_ratio < 1.:
+             idx = np.argmax(q > max_q)
+             q = q[:idx] / q_ratio
+             curve = curve[:idx]
+     return q, curve, q_ratio
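
A quick sketch with hypothetical values; note that the returned q array is rescaled by 1 / q_ratio, so downstream code has to account for the returned ratio.

import numpy as np

from reflectorch.inference.preprocess_exp.cut_with_q_ratio import cut_curve

q = np.linspace(0.01, 0.3, 200)
curve = np.exp(-30 * q)                  # stand-in for a measured curve
q_cut, curve_cut, q_ratio = cut_curve(q, curve, max_q=0.15, max_angle=None, wavelength=1.0)
# q_ratio == 0.15 / 0.3 == 0.5; points with q > 0.15 are dropped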
reflectorch/inference/preprocess_exp/footprint.py
@@ -0,0 +1,81 @@
+ try:
+     from typing import Literal
+ except ImportError:
+     from typing_extensions import Literal
+
+ import numpy as np
+ from scipy.special import erf
+
+ __all__ = [
+     "apply_footprint_correction",
+     "remove_footprint_correction",
+     "BEAM_SHAPE",
+ ]
+
+
+ BEAM_SHAPE = Literal["gauss", "box"]
+
+
+ def apply_footprint_correction(
+         intensity: np.ndarray,
+         scattering_angle: np.ndarray,
+         beam_width: float,
+         sample_length: float,
+         beam_shape: BEAM_SHAPE = "gauss",
+ ) -> np.ndarray:
+     """Applies footprint correction to an experimental reflectivity curve
+
+     Args:
+         intensity (np.ndarray): reflectivity curve
+         scattering_angle (np.ndarray): array of scattering angles
+         beam_width (float): the beam width
+         sample_length (float): the sample length
+         beam_shape (BEAM_SHAPE, optional): the shape of the beam, either "gauss" or "box". Defaults to "gauss".
+
+     Returns:
+         np.ndarray: the footprint corrected reflectivity curve
+     """
+     factors = _get_factors_by_beam_shape(
+         scattering_angle, beam_width, sample_length, beam_shape
+     )
+     return intensity.copy() * factors
+
+
+ def remove_footprint_correction(
+         intensity: np.ndarray,
+         scattering_angle: np.ndarray,
+         beam_width: float,
+         sample_length: float,
+         beam_shape: BEAM_SHAPE = "gauss",
+ ):
+     factors = _get_factors_by_beam_shape(
+         scattering_angle, beam_width, sample_length, beam_shape
+     )
+     return intensity.copy() / factors
+
+
+ def _get_factors_by_beam_shape(
+         scattering_angle: np.ndarray, beam_width: float, sample_length: float, beam_shape: BEAM_SHAPE
+ ):
+     if beam_shape == "gauss":
+         return gaussian_factors(scattering_angle, beam_width, sample_length)
+     elif beam_shape == "box":
+         return box_factors(scattering_angle, beam_width, sample_length)
+     else:
+         raise ValueError("invalid beam shape")
+
+
+ def box_factors(scattering_angle, beam_width, sample_length):
+     max_angle = 2 * np.arcsin(beam_width / sample_length) / np.pi * 180
+     ratios = beam_footprint_ratio(scattering_angle, beam_width, sample_length)
+     ones = np.ones_like(scattering_angle)
+     return np.where(scattering_angle < max_angle, ones * ratios, ones)
+
+
+ def gaussian_factors(scattering_angle, beam_width, sample_length):
+     ratio = beam_footprint_ratio(scattering_angle, beam_width, sample_length)
+     return 1 / erf(np.sqrt(np.log(2)) / ratio)
+
+
+ def beam_footprint_ratio(scattering_angle, beam_width, sample_length):
+     return beam_width / sample_length / np.sin(scattering_angle / 2 * np.pi / 180)
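
An illustrative call with placeholder geometry (beam_width and sample_length in the same length unit); the scattering angles are in degrees and are halved internally to obtain the incidence angle.

import numpy as np

from reflectorch.inference.preprocess_exp.footprint import apply_footprint_correction

scattering_angle = np.linspace(0.05, 2.0, 300)   # degrees
intensity = np.ones_like(scattering_angle)       # stand-in for measured counts
corrected = apply_footprint_correction(
    intensity, scattering_angle, beam_width=0.1, sample_length=10.0, beam_shape="gauss"
)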
reflectorch/inference/preprocess_exp/interpolation.py
@@ -0,0 +1,16 @@
+ import numpy as np
+
+
+ def interp_reflectivity(q_interp, q, reflectivity, min_value: float = 1e-10):
+     """Interpolate data on a base 10 logarithmic scale
+
+     Args:
+         q_interp (array-like): reciprocal space points used for the interpolation
+         q (array-like): reciprocal space points of the measured reflectivity curve
+         reflectivity (array-like): reflectivity curve measured at the points ``q``
+         min_value (float, optional): minimum intensity of the reflectivity curve. Defaults to 1e-10.
+
+     Returns:
+         array-like: interpolated reflectivity curve
+     """
+     return 10 ** np.interp(q_interp, q, np.log10(np.clip(reflectivity, min_value, None)))
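
The interpolation happens on log10(R) and the result is exponentiated back, which keeps the interpolated curve positive. A minimal sketch with made-up grids:

import numpy as np

from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity

q = np.linspace(0.02, 0.2, 150)          # measured q grid
reflectivity = 10.0 ** (-40 * q)         # stand-in for measured data
q_interp = np.linspace(0.02, 0.2, 128)   # fixed grid expected downstream
curve_interp = interp_reflectivity(q_interp, q, reflectivity)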
reflectorch/inference/preprocess_exp/normalize.py
@@ -0,0 +1,21 @@
+ try:
+     from typing import Literal
+ except ImportError:
+     from typing_extensions import Literal
+
+ import numpy as np
+ from numpy import ndarray
+
+ NORMALIZE_MODE = Literal["first", "max", "incoming_intensity"]
+
+
+ def intensity2reflectivity(intensity: ndarray, mode: NORMALIZE_MODE, incoming_intensity=None) -> np.ndarray:
+     if mode == "first":
+         return intensity / intensity[0]
+     if mode == "max":
+         return intensity / intensity.max()
+     if mode == "incoming_intensity":
+         if incoming_intensity is None:
+             raise ValueError("incoming_intensity is None")
+         return intensity / incoming_intensity
+     raise ValueError(f"Unknown mode {mode}")
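
The three normalization modes on toy numbers (illustrative only):

import numpy as np

from reflectorch.inference.preprocess_exp.normalize import intensity2reflectivity

intensity = np.array([9.0e5, 1.0e6, 2.5e5, 4.0e4])
r_first = intensity2reflectivity(intensity, "first")                    # divide by the first point
r_max = intensity2reflectivity(intensity, "max")                        # divide by the maximum
r_inc = intensity2reflectivity(intensity, "incoming_intensity", 1.2e6)  # divide by a measured incoming intensity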
reflectorch/inference/preprocess_exp/preprocess.py
@@ -0,0 +1,121 @@
+ from dataclasses import dataclass
+
+ import numpy as np
+
+ from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
+ from reflectorch.inference.preprocess_exp.footprint import apply_footprint_correction, BEAM_SHAPE
+ from reflectorch.inference.preprocess_exp.normalize import intensity2reflectivity, NORMALIZE_MODE
+ from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
+ from reflectorch.inference.preprocess_exp.cut_with_q_ratio import cut_curve
+ from reflectorch.utils import angle_to_q
+
+
+ def standard_preprocessing(
+         intensity: np.ndarray,
+         scattering_angle: np.ndarray,
+         attenuation: np.ndarray,
+         q_interp: np.ndarray,
+         wavelength: float,
+         beam_width: float,
+         sample_length: float,
+         min_intensity: float = 1e-10,
+         beam_shape: BEAM_SHAPE = "gauss",
+         normalize_mode: NORMALIZE_MODE = "max",
+         incoming_intensity: float = None,
+         max_q: float = None,  # if provided, max_angle is ignored
+         max_angle: float = None,
+ ) -> dict:
+     """Preprocesses a raw experimental reflectivity curve by applying attenuation correction, footprint correction, cutting at a maximum q value and interpolation
+
+     Args:
+         intensity (np.ndarray): array of intensities of the reflectivity curve
+         scattering_angle (np.ndarray): array of scattering angles
+         attenuation (np.ndarray): attenuation factors for each measured point
+         q_interp (np.ndarray): reciprocal space points used for the interpolation
+         wavelength (float): the wavelength of the beam
+         beam_width (float): the beam width
+         sample_length (float): the sample length
+         min_intensity (float, optional): intensities lower than this value are removed. Defaults to 1e-10.
+         beam_shape (BEAM_SHAPE, optional): the shape of the beam, either "gauss" or "box". Defaults to "gauss".
+         normalize_mode (NORMALIZE_MODE, optional): normalization mode, either "first", "max" or "incoming_intensity". Defaults to "max".
+         incoming_intensity (float, optional): array of intensities for the "incoming_intensity" normalization. Defaults to None.
+         max_q (float, optional): the maximum q value at which the curve is cut. Defaults to None.
+         max_angle (float, optional): the maximum scattering angle at which the curve is cut; only used if max_q is not provided. Defaults to None.
+
+     Returns:
+         dict: dictionary containing the interpolated reflectivity curve, the curve before interpolation, the q values before interpolation, the q values after interpolation and the q ratio of the cutting
+     """
+     intensity = apply_attenuation_correction(
+         intensity,
+         attenuation,
+         scattering_angle,
+     )
+
+     intensity = apply_footprint_correction(
+         intensity, scattering_angle, beam_width=beam_width, sample_length=sample_length, beam_shape=beam_shape
+     )
+
+     curve = intensity2reflectivity(intensity, normalize_mode, incoming_intensity)
+
+     curve, scattering_angle = remove_low_statistics(curve, scattering_angle, thresh=min_intensity)
+
+     q = angle_to_q(scattering_angle, wavelength)
+
+     q, curve, q_ratio = cut_curve(q, curve, max_q, max_angle, wavelength)
+
+     curve_interp = interp_reflectivity(q_interp, q, curve)
+
+     assert np.all(np.isfinite(curve_interp))
+     assert np.all(np.isfinite(curve))
+     assert np.all(np.isfinite(q))
+     assert np.all(np.isfinite(q_interp))
+     assert np.all(curve > 0.)
+     assert np.all(curve_interp > 0.)
+
+     return {
+         "curve_interp": curve_interp, "curve": curve, "q_values": q, "q_interp": q_interp, "q_ratio": q_ratio,
+     }
+
+
+ def remove_low_statistics(curve, scattering_angle, thresh: float = 1e-7):
+     indices = (curve > thresh) & np.isfinite(curve)
+     return curve[indices], scattering_angle[indices]
+
+
+ @dataclass
+ class StandardPreprocessing:
+     q_interp: np.ndarray = None
+     wavelength: float = 1.
+     beam_width: float = None
+     sample_length: float = None
+     beam_shape: BEAM_SHAPE = "gauss"
+     normalize_mode: NORMALIZE_MODE = "max"
+     incoming_intensity: float = None
+
+     def preprocess(self,
+                    intensity: np.ndarray,
+                    scattering_angle: np.ndarray,
+                    attenuation: np.ndarray,
+                    **kwargs
+                    ) -> dict:
+         attrs = self._get_updated_attrs(**kwargs)
+         return standard_preprocessing(
+             intensity,
+             scattering_angle,
+             attenuation,
+             **attrs
+         )
+
+     __call__ = preprocess
+
+     def set_parameters(self, **kwargs) -> None:
+         for k, v in kwargs.items():
+             if k in self.__annotations__:
+                 setattr(self, k, v)
+             else:
+                 raise KeyError(f'Unknown parameter {k}.')
+
+     def _get_updated_attrs(self, **kwargs):
+         current_attrs = {k: getattr(self, k) for k in self.__annotations__.keys()}
+         current_attrs.update(kwargs)
+         return current_attrs
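
A hypothetical end-to-end preprocessing call; the instrument numbers are placeholders and `intensity`, `scattering_angle`, and `attenuation` are assumed to be raw measurement arrays:

import numpy as np

from reflectorch.inference.preprocess_exp import StandardPreprocessing

prep = StandardPreprocessing(
    q_interp=np.linspace(0.02, 0.15, 128),  # grid the downstream model expects
    wavelength=1.0,
    beam_width=0.2,
    sample_length=10.0,
    normalize_mode="max",
)
out = prep(intensity, scattering_angle, attenuation, max_q=0.15)  # extra kwargs are forwarded
out["curve_interp"]                            # corrected curve on q_interp
out["q_values"], out["curve"], out["q_ratio"]  # un-interpolated curve and the cut ratio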