reflectorch 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,239 @@
1
+ from reflectorch.data_generation.priors.parametric_models import NuisanceParamsWrapper
2
+
3
class Layer:
    """A single slab layer of the film, described by prior bounds on its parameters.

    Real-SLD bounds are always required. Imaginary-SLD bounds are optional and are only
    meaningful for models that include absorption.

    Args:
        thickness_bounds (Tuple[float, float]): (min, max) layer thickness in Å.
        roughness_bounds (Tuple[float, float]): (min, max) roughness of the interface at the top of this layer, in Å.
        sld_bounds (Tuple[float, float]): (min, max) real SLD of this layer, in 10⁻⁶ Å⁻².
        imag_sld_bounds (Tuple[float, float], optional): (min, max) imaginary SLD of this layer, in 10⁻⁶ Å⁻². Defaults to None.
    """

    def __init__(self, thickness_bounds, roughness_bounds, sld_bounds, imag_sld_bounds=None):
        # The bounds are stored exactly as given; no validation happens at construction time.
        (self.thickness_bounds,
         self.roughness_bounds,
         self.sld_bounds,
         self.imag_sld_bounds) = (thickness_bounds, roughness_bounds, sld_bounds, imag_sld_bounds)
19
+
20
class Backing:
    """The semi-infinite backing medium (substrate) at the bottom of the structure.

    A backing has no thickness. Its ``thickness_bounds`` attribute is fixed to ``None`` so
    that it exposes the same attributes as :class:`Layer` and can be placed in the same list.

    Args:
        roughness_bounds (Tuple[float, float]): (min, max) roughness of the interface at the top of the backing, in Å.
        sld_bounds (Tuple[float, float]): (min, max) real SLD of the backing, in 10⁻⁶ Å⁻².
        imag_sld_bounds (Tuple[float, float], optional): (min, max) imaginary SLD of the backing, in 10⁻⁶ Å⁻². Defaults to None.
    """

    def __init__(self, roughness_bounds, sld_bounds, imag_sld_bounds=None):
        # A None thickness marks this entry as the backing when the structure is flattened.
        (self.thickness_bounds,
         self.roughness_bounds,
         self.sld_bounds,
         self.imag_sld_bounds) = (None, roughness_bounds, sld_bounds, imag_sld_bounds)
36
+
37
class Structure():
    """Defines a multilayer structure and its parameter bounds in a layer-wise manner.

    This class allows the user to define the prior bounds for the full structure (film + backing) in a layer-wise format. It automatically constructs the
    flattened list of parameter bounds compatible with the inference model's expected input format.

    Args:
        layers (List[Union[Layer, Backing]]): Ordered list of layers defining the structure, from the ambient side to the backing. The last element
            must be a :class:`Backing` instance. Note that the fronting medium (ambient) is not part of this list (since it is not a predicted parameter),
            and is treated by default as being 0 (air). For different fronting media one can use the ``ambient_sld`` argument of the prediction method.
        q_shift_bounds (Tuple[float, float], optional): Bounds for the global ``q_shift`` nuisance parameter. Defaults to None.
        r_scale_bounds (Tuple[float, float], optional): Bounds for the global reflectivity scale factor ``r_scale``. Defaults to None.
        log10_background_bounds (Tuple[float, float], optional): Bounds for the background term expressed as log10(background). Defaults to None.

    Attributes:
        thicknesses_bounds (List[Tuple[float, float]]): Bounds for all thicknesses (excluding backing).
        roughnesses_bounds (List[Tuple[float, float]]): Bounds for all roughnesses (including backing).
        slds_bounds (List[Tuple[float, float]]): Bounds for all real SLDs (including backing).
        imag_slds_bounds (List[Tuple[float, float]]): Bounds for all imaginary SLDs (empty if none were provided).
        prior_bounds (List[Tuple[float, float]]): Flattened list of all parameter bounds in the order expected by the model: thicknesses,
            roughnesses, real SLDs, imaginary SLDs (if present), followed by nuisance parameters (q_shift, r_scale, log10_background; only those provided).

    Raises:
        ValueError: If imaginary SLD bounds are given for some but not all layers (the flattened bounds would be misaligned).

    Example:
        >>> layer1 = Layer(thickness_bounds=[1, 100], roughness_bounds=[0, 10], sld_bounds=[-2, 2])
        >>> backing = Backing(roughness_bounds=[0, 15], sld_bounds=[0, 3])
        >>> structure = Structure(layers=[layer1, backing], r_scale_bounds=[0.9, 1.1])
        >>> structure.prior_bounds
        [[1, 100], [0, 10], [0, 15], [-2, 2], [0, 3], [0.9, 1.1]]
    """
    def __init__(self, layers, q_shift_bounds=None, r_scale_bounds=None, log10_background_bounds=None):
        self.layers = layers
        self.q_shift_bounds = q_shift_bounds
        self.r_scale_bounds = r_scale_bounds
        self.log10_background_bounds = log10_background_bounds
        self.thicknesses_bounds = []
        self.roughnesses_bounds = []
        self.slds_bounds = []
        self.imag_slds_bounds = []

        for layer in layers:
            # The backing has no thickness parameter (its thickness_bounds is None).
            if layer.thickness_bounds is not None:
                self.thicknesses_bounds.append(layer.thickness_bounds)
            self.roughnesses_bounds.append(layer.roughness_bounds)
            self.slds_bounds.append(layer.sld_bounds)
            if layer.imag_sld_bounds is not None:
                self.imag_slds_bounds.append(layer.imag_sld_bounds)

        # Imaginary SLDs form one block with an entry per layer. A partial block would shift the
        # entries that follow it in the flattened prior bounds out of alignment with the model.
        if self.imag_slds_bounds and len(self.imag_slds_bounds) != len(layers):
            raise ValueError(
                'Imaginary SLD bounds must be provided either for all layers (including the backing) or for none of them.'
            )

        self.prior_bounds = self.thicknesses_bounds + self.roughnesses_bounds + self.slds_bounds + self.imag_slds_bounds

        # Nuisance parameters follow the structural ones, in this fixed order.
        for nuisance_bounds in (q_shift_bounds, r_scale_bounds, log10_background_bounds):
            if nuisance_bounds is not None:
                self.prior_bounds.append(nuisance_bounds)

    def validate_parameters_and_ranges(self, inference_model):
        """Validate that all layer bounds and nuisance parameters match the model's configuration.

        This method checks that:
            * The number of layers matches the model's expected number.
            * Each layer's thickness, roughness, and SLD bounds are within the
              model's training ranges.
            * The SLD bound width does not exceed the maximum training width.
            * Nuisance parameter bounds (q_shift, r_scale, log10_background) are only
              provided if the model supports them, and are within training bounds.
            * Any nuisance parameters expected by the model are provided.

        Args:
            inference_model (InferenceModel): A loaded instance of :class:`InferenceModel` used to access the model's metadata.

        Raises:
            ValueError: If the number of layers, parameter ranges, or nuisance configurations are inconsistent with the model.
        """
        prior_sampler = inference_model.trainer.loader.prior_sampler
        max_num_layers = prior_sampler.max_num_layers

        if len(self.layers) - 1 != max_num_layers:
            raise ValueError(f'Number of layers mismatch: this model expects {max_num_layers} layers (backing not included) but you provided {len(self.layers) - 1}')

        thickness_ranges = prior_sampler.param_ranges['thicknesses']
        roughness_ranges = prior_sampler.param_ranges['roughnesses']
        sld_ranges = prior_sampler.param_ranges['slds']
        max_sld_bounds_width = prior_sampler.bound_width_ranges['slds'][1]

        def layer_name(i):
            # After the layer-count check above, the last index is the backing.
            if i == max_num_layers:
                return 'the backing medium'
            return f'layer {i+1}'

        for i, layer in enumerate(self.layers):
            if layer.thickness_bounds is not None:
                if layer.thickness_bounds[0] < thickness_ranges[0] or layer.thickness_bounds[1] > thickness_ranges[1]:
                    raise ValueError(f"The provided prior bounds for the thickness of layer {i+1} are outside the training range of the network: {thickness_ranges}")
            if layer.roughness_bounds[0] < roughness_ranges[0] or layer.roughness_bounds[1] > roughness_ranges[1]:
                raise ValueError(f"The provided prior bounds for the roughness of {layer_name(i)} are outside the training range of the network: {roughness_ranges}")
            if layer.sld_bounds[0] < sld_ranges[0] or layer.sld_bounds[1] > sld_ranges[1]:
                raise ValueError(f"The provided prior bounds for the (real) SLD of {layer_name(i)} are outside the training range of the network: {sld_ranges}")

            if layer.sld_bounds[1] - layer.sld_bounds[0] > max_sld_bounds_width:
                raise ValueError(f"The provided prior bounds for the (real) SLD of {layer_name(i)} have a width (max - min) exceeding the maximum width used for training: {max_sld_bounds_width}")

        if isinstance(prior_sampler.param_model, NuisanceParamsWrapper):
            nuisance_params_config = prior_sampler.shift_param_config
        else:
            # A model without the nuisance wrapper predicts no nuisance parameters. Any such bounds
            # would add extra entries to prior_bounds, so they are reported as unsupported below.
            nuisance_params_config = {}

        nuisance_bounds = {
            'q_shift': self.q_shift_bounds,
            'r_scale': self.r_scale_bounds,
            'log10_background': self.log10_background_bounds,
        }

        # First pass: provided bounds must be supported and within the training range.
        for name, bounds in nuisance_bounds.items():
            if bounds is None:
                continue
            if name not in nuisance_params_config:
                raise ValueError(f'Prior bounds for the {name} parameter were provided but this parameter is not supported by this model.')
            param_range = prior_sampler.param_ranges[name]
            if bounds[0] < param_range[0] or bounds[1] > param_range[1]:
                raise ValueError(f"The provided prior bounds for the {name} are outside the training range of the network: {param_range}")

        # Second pass: every nuisance parameter the model expects must have bounds.
        for name, bounds in nuisance_bounds.items():
            if name in nuisance_params_config and bounds is None:
                raise ValueError(f'Prior bounds for the {name} parameter are expected by this model but were not provided.')

        print("All checks passed.")

    def get_huggingface_filtering_query(self):
        """Constructs a metadata query for selecting compatible pretrained models from Huggingface. Currently it only supports the older (research style)
        layout of Huggingface repositories (such as 'valentinsingularity/reflectivity'), but not the newer layout (such as `reflectorch-ILL`).

        Each parameter range in the query is the envelope (smallest lower bound, largest upper bound)
        over all layers. The thickness range is omitted when the structure has no film layers.

        Returns:
            dict: A dictionary describing the structure and parameter bounds, suitable for filtering available model configurations
            in a Huggingface repository using :class:`HuggingfaceQueryMatcher`.

        Example:
            >>> structure = Structure([...])
            >>> query = structure.get_huggingface_filtering_query()
            >>> matcher = HuggingfaceQueryMatcher(repo_id='valentinsingularity/reflectivity')
            >>> configs = matcher.get_matching_configs(query)
        """
        def envelope(bounds_list):
            # [min over lower bounds, max over upper bounds]
            return [min(b[0] for b in bounds_list), max(b[1] for b in bounds_list)]

        query = {'dset.prior_sampler.kwargs.max_num_layers': len(self.layers) - 1}

        # A substrate-only structure has no thicknesses; min()/max() of nothing would raise.
        if self.thicknesses_bounds:
            query['dset.prior_sampler.kwargs.param_ranges.thicknesses'] = envelope(self.thicknesses_bounds)
        query['dset.prior_sampler.kwargs.param_ranges.roughnesses'] = envelope(self.roughnesses_bounds)
        query['dset.prior_sampler.kwargs.param_ranges.slds'] = envelope(self.slds_bounds)

        if self.imag_slds_bounds:
            query['dset.prior_sampler.kwargs.model_name'] = 'model_with_absorption'
            query['dset.prior_sampler.kwargs.param_ranges.islds'] = envelope(self.imag_slds_bounds)
        else:
            query['dset.prior_sampler.kwargs.model_name'] = 'standard_model'

        if self.q_shift_bounds is not None:
            query['dset.prior_sampler.kwargs.shift_param_config.q_shift'] = True
            query['dset.prior_sampler.kwargs.param_ranges.q_shift'] = self.q_shift_bounds

        if self.r_scale_bounds is not None:
            query['dset.prior_sampler.kwargs.shift_param_config.r_scale'] = True
            query['dset.prior_sampler.kwargs.param_ranges.r_scale'] = self.r_scale_bounds

        if self.log10_background_bounds is not None:
            query['dset.prior_sampler.kwargs.shift_param_config.log10_background'] = True
            query['dset.prior_sampler.kwargs.param_ranges.log10_background'] = self.log10_background_bounds

        return query
212
+
213
+
214
if __name__ == '__main__':
    # Demo: define a two-layer structure, look up compatible pretrained configurations on
    # Huggingface, and validate the structure against the first matching model.
    from reflectorch.inference.query_matcher import HuggingfaceQueryMatcher
    from reflectorch import EasyInferenceModel

    layer1 = Layer(thickness_bounds=[1, 1000], roughness_bounds=[0, 60], sld_bounds=[-2, 2])
    layer2 = Layer(thickness_bounds=[1, 50], roughness_bounds=[0, 10], sld_bounds=[1, 4])
    backing = Backing(roughness_bounds=[0, 15], sld_bounds=[0, 3])

    structure = Structure(
        layers=[layer1, layer2, backing],
        r_scale_bounds=[0.9, 1.1],
        log10_background_bounds=[-8, -5],
    )

    print(structure.prior_bounds)

    query_matcher = HuggingfaceQueryMatcher(repo_id='valentinsingularity/reflectivity')
    filtering_query = structure.get_huggingface_filtering_query()
    print(filtering_query)

    matching_configs = query_matcher.get_matching_configs(filtering_query)
    print(f'Matching configs: {matching_configs}')

    # Indexing an empty result would fail with an uninformative IndexError.
    if not matching_configs:
        print('No pretrained configuration matches this structure; skipping validation.')
    else:
        inference_model = EasyInferenceModel(config_name=matching_configs[0])
        structure.validate_parameters_and_ranges(inference_model)
@@ -0,0 +1,55 @@
1
+ from pathlib import Path
2
+ import numpy as np
3
+
4
def load_mft_data(filepath):
    """
    Load q, reflectivity, reflectivity error, and q-resolution from an .mft file.

    The data section starts on the line after the column header, i.e. the first line
    that (after stripping) starts with ``q`` and contains ``q_res``. Every following line
    with exactly four numeric fields is kept; all other lines are skipped.

    Parameters:
        filepath (str or Path): Path to the .mft file

    Returns:
        np.ndarray: Array of shape (4, N) whose rows are q, refl, refl_err and q_res,
        so it can be unpacked directly as ``q, refl, refl_err, q_res = load_mft_data(path)``.

    Raises:
        ValueError: If no column header is found, or no valid data rows follow it.
    """
    filepath = Path(filepath)

    with filepath.open('r', encoding='utf-8') as f:
        lines = f.readlines()

    header_idx = next(
        (i for i, line in enumerate(lines)
         if line.strip().startswith('q') and 'q_res' in line),
        None,
    )
    # Without a default, next() would leak a bare StopIteration to the caller.
    if header_idx is None:
        raise ValueError(f"No column header (a line starting with 'q' and containing 'q_res') found in {filepath}")

    data = []
    for line in lines[header_idx + 1:]:
        parts = line.split()
        if len(parts) != 4:
            continue
        try:
            # float() accepts both 'e' and 'E' exponent markers.
            data.append([float(p) for p in parts])
        except ValueError:
            continue

    if not data:
        raise ValueError(f"No valid data found in {filepath}")

    return np.array(data).T
38
+
39
+
40
def load_ort_data(filepath):
    """Placeholder loader for .ort (ORSO) files.

    Raises:
        NotImplementedError: Always; this format is not supported yet.
    """
    message = "Loading ORT data is not implemented yet"
    raise NotImplementedError(message)
42
+
43
def load_csv_data(filepath):
    """Placeholder loader for .csv files.

    Raises:
        NotImplementedError: Always; this format is not supported yet.
    """
    message = "Loading CSV data is not implemented yet"
    raise NotImplementedError(message)
45
+
46
+
47
def load_data(filepath):
    """Dispatch to a format-specific loader based on the file extension.

    Args:
        filepath (str or Path): Path to the data file. The extension is matched
            case-insensitively against ``.mft``, ``.ort`` and ``.csv``.

    Returns:
        Whatever the selected loader returns.

    Raises:
        ValueError: If the extension is not supported.
    """
    # str() lets pathlib.Path inputs work too (Path has no .endswith).
    name = str(filepath).lower()
    if name.endswith('.mft'):
        return load_mft_data(filepath)
    elif name.endswith('.ort'):
        return load_ort_data(filepath)
    elif name.endswith('.csv'):
        return load_csv_data(filepath)
    else:
        raise ValueError(f"Unsupported file format: {filepath}")
@@ -0,0 +1,171 @@
1
+ from tqdm import trange
2
+
3
+ from time import perf_counter
4
+
5
+ import torch
6
+ from torch import nn, Tensor
7
+
8
+ from reflectorch.data_generation.reflectivity import kinematical_approximation
9
+ from reflectorch.data_generation.priors.multilayer_structures import SimpleMultilayerSampler
10
+
11
+
12
class MultilayerFit(object):
    """Refine a batch of candidate multilayer parameters by gradient descent.

    Every row of ``params`` is one candidate solution. All candidates are optimized jointly.
    The objective is the squared difference between the scaled simulated curves (kinematical
    approximation) and the scaled experimental curve, summed over all candidates and points.

    Args:
        q (Tensor): q values at which curves are simulated.
        exp_curve (Tensor): experimental curve; promoted to at least 2D before scaling.
        params (Tensor): initial candidate parameters (presumably ``(batch, num_params)`` — TODO confirm).
        convert_func: callable mapping ``params`` to keyword arguments of ``kinematical_approximation``.
        scale_curve_func: transform applied to simulated and experimental curves before comparison.
        min_bounds (Tensor): lower bounds used when clipping parameters.
        max_bounds (Tensor): upper bounds used when clipping parameters.
        optim_cls: torch optimizer class; defaults to ``torch.optim.Adam``.
        lr (float): learning rate passed to the optimizer.
    """

    def __init__(self,
                 q: Tensor,
                 exp_curve: Tensor,
                 params: Tensor,
                 convert_func,
                 scale_curve_func,
                 min_bounds: Tensor,
                 max_bounds: Tensor,
                 optim_cls=None,
                 lr: float = 5e-2,
                 ):

        self.q = q
        self.scale_curve_func = scale_curve_func
        self.scaled_exp_curve = self.scale_curve_func(torch.atleast_2d(exp_curve))
        self.params_to_fit = nn.Parameter(params.clone())
        self.convert_func = convert_func

        self.min_bounds = min_bounds
        self.max_bounds = max_bounds

        optim_cls = optim_cls or torch.optim.Adam
        self.optim = optim_cls([self.params_to_fit], lr)

        self.best_loss = float('inf')
        self.best_params, self.best_curve = None, None

        # Record the best initial candidate so a solution exists before any optimization step.
        self.get_best_solution()
        self.losses = []

    @classmethod
    def from_prior_sampler(
            cls,
            q: Tensor,
            exp_curve: Tensor,
            prior_sampler: SimpleMultilayerSampler,
            batch_size: int = 2 ** 13,
            **kwargs
    ):
        """Build a fit whose initial candidates are ``batch_size`` samples drawn from ``prior_sampler``."""
        _, scaled_params = prior_sampler.optimized_sample(batch_size)
        params = prior_sampler.restore_params2parametrized(scaled_params)

        return cls(
            q.float(), exp_curve.float(), params.float(),
            convert_func=get_convert_func(prior_sampler.multilayer_model),
            scale_curve_func=_save_log, min_bounds=prior_sampler.min_bounds,
            max_bounds=prior_sampler.max_bounds,
            **kwargs
        )

    @classmethod
    def from_prediction(
            cls,
            param: Tensor,
            prior_sampler,
            q: Tensor,
            exp_curve: Tensor,
            batch_size: int = 2 ** 13,
            rel_bounds: float = 0.3,
            **kwargs
    ):
        """Build a fit whose candidates are random perturbations of a single prediction ``param``.

        Uniform noise in ``[-rel_bounds, rel_bounds]`` is mapped through
        ``restore_params2parametrized`` and offset by ``min_bounds`` to form the perturbations.
        The first candidate is left unperturbed, so the original prediction is always included.
        """
        num_params = param.shape[-1]

        deltas = prior_sampler.restore_params2parametrized(
            (torch.rand(batch_size, num_params, device=param.device) * 2 - 1) * rel_bounds
        ) - prior_sampler.min_bounds

        deltas[0] = 0.

        params = torch.atleast_2d(param) + deltas

        return cls(
            q.float(), exp_curve.float(), params.float(),
            convert_func=get_convert_func(prior_sampler.multilayer_model),
            scale_curve_func=_save_log, min_bounds=prior_sampler.min_bounds,
            max_bounds=prior_sampler.max_bounds,
            **kwargs
        )

    def get_clipped_params(self):
        """Return a detached copy of the current parameters clamped to [min_bounds, max_bounds]."""
        return torch.clamp(self.params_to_fit.detach().clone(), self.min_bounds, self.max_bounds)

    def calc_loss(self, reduce=True, clip: bool = False):
        """Squared error between scaled simulated and experimental curves.

        Args:
            reduce (bool): if True, return the total sum; otherwise one loss per candidate (sum over the last axis).
            clip (bool): evaluate clipped (detached, hence non-differentiable) parameters instead of the live ones.
        """
        if clip:
            params = self.get_clipped_params()
        else:
            params = self.params_to_fit
        curves = self.get_curves(params=params)
        losses = ((self.scale_curve_func(curves) - self.scaled_exp_curve) ** 2)

        if reduce:
            return losses.sum()
        else:
            return losses.sum(-1)

    def get_curves(self, params: Tensor = None):
        """Simulate curves (kinematical approximation) for ``params`` (defaults to the live parameters), as a 2D tensor."""
        if params is None:
            params = self.params_to_fit
        curves = kinematical_approximation(self.q, **self.convert_func(params))
        return torch.atleast_2d(curves)

    def _step(self) -> float:
        """Perform one optimizer step on the summed loss, record it and return its value."""
        self.optim.zero_grad()
        loss = self.calc_loss()
        loss.backward()
        self.optim.step()
        loss_value = loss.item()
        self.losses.append(loss_value)
        return loss_value

    def run(self, num_iterations: int = 500, disable_tqdm: bool = False):
        """Run ``num_iterations`` optimization steps, showing the loss in a progress bar."""
        pbar = trange(num_iterations, disable=disable_tqdm)

        for _ in pbar:
            loss_value = self._step()
            pbar.set_description(f'Loss = {loss_value:.2e}')

    def clear(self):
        """Forget the recorded loss history (parameters and best solution are kept)."""
        self.losses.clear()

    def run_fixed_time(self, time_limit: float = 2.):
        """Run optimization steps until more than ``time_limit`` seconds have passed (at least one step)."""
        start = perf_counter()

        while True:
            self._step()
            if perf_counter() - start > time_limit:
                break

    @torch.no_grad()
    def get_best_solution(self, clip: bool = True):
        """Update and return the best candidate found so far.

        The candidate with the lowest per-candidate loss replaces the stored best solution
        if it improves on ``best_loss``. When ``clip`` is True, candidates are clamped to
        the bounds first. The stored parameters and curve are then the clipped ones the loss
        was computed for, so they always lie within the bounds and agree with ``best_loss``.

        Returns:
            Tensor: the best parameters so far (1D, one candidate).
        """
        params = self.get_clipped_params() if clip else self.params_to_fit.detach()
        curves = self.get_curves(params=params)
        losses = ((self.scale_curve_func(curves) - self.scaled_exp_curve) ** 2).sum(-1)

        idx = torch.argmin(losses)
        best_loss = losses[idx].item()

        if best_loss < self.best_loss:
            self.best_params = params[idx].clone()
            self.best_curve = curves[idx]
            self.best_loss = best_loss

        return self.best_params
156
+
157
+
158
def get_convert_func(multilayer_model):
    """Build a converter from model parameters to ``kinematical_approximation`` keyword arguments.

    The returned callable passes its input through ``multilayer_model.to_standard_params``.
    It then renames the plural keys of the result to the singular keyword names
    (``thickness``, ``roughness``, ``sld``).
    """
    def _convert(params):
        standard = multilayer_model.to_standard_params(params)
        return dict(
            thickness=standard['thicknesses'],
            roughness=standard['roughnesses'],
            sld=standard['slds'],
        )

    return _convert
168
+
169
+
170
+ def _save_log(curves, eps: float = 1e-10):
171
+ return torch.log10(curves + eps)
@@ -0,0 +1,193 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ import numpy as np
6
+
7
+ from reflectorch.inference.inference_model import (
8
+ InferenceModel,
9
+ )
10
+ from reflectorch.data_generation.reflectivity import kinematical_approximation_np, abeles_np
11
+
12
+ from reflectorch.data_generation.priors import (
13
+ MultilayerStructureParams,
14
+ SimpleMultilayerSampler,
15
+ )
16
+ from reflectorch.inference.record_time import print_time
17
+ from reflectorch.inference.scipy_fitter import standard_refl_fit
18
+ from reflectorch.inference.multilayer_fitter import MultilayerFit
19
+
20
+
21
class MultilayerInferenceModel(InferenceModel):
    """Inference model for multilayer structures.

    The prediction pipeline:
      1. run the network on the preprocessed curve;
      2. optionally refine the prediction with a batched gradient fit (:class:`MultilayerFit`);
      3. optionally polish the result with ``standard_refl_fit``.
    """

    def predict(self,
                intensity: np.ndarray,
                scattering_angle: np.ndarray,
                attenuation: np.ndarray,
                priors: np.ndarray = None,
                preprocessing_parameters: dict = None,
                polish: bool = True,
                use_raw_q: bool = False,
                **kwargs
                ) -> dict:
        """Preprocess raw measurement data and predict the multilayer parameters.

        Args:
            intensity, scattering_angle, attenuation: raw measurement arrays forwarded to ``self.preprocess``.
            priors: forwarded to :meth:`predict_from_preprocessed_curve`, which currently ignores it.
            preprocessing_parameters (dict, optional): extra keyword arguments for ``self.preprocess``.
            polish (bool): whether to polish the network prediction with ``standard_refl_fit``.
            use_raw_q (bool): if True, the raw curve (``preprocessed_dict['curve']``) is used instead of
                the interpolated one for the predicted and polished curves.
            **kwargs: forwarded to :meth:`predict_from_preprocessed_curve`.

        Returns:
            dict: the preprocessing output, updated with the prediction results.
        """
        with print_time("everything"):
            with print_time("preprocess"):
                preprocessed_dict = self.preprocess(
                    intensity, scattering_angle, attenuation, **(preprocessing_parameters or {})
                )

            preprocessed_curve = preprocessed_dict["curve_interp"]

            raw_curve, raw_q = preprocessed_dict["curve"], preprocessed_dict["q_values"]

            with print_time("predict_from_preprocessed_curve"):
                preprocessed_dict.update(self.predict_from_preprocessed_curve(
                    preprocessed_curve, priors,
                    raw_curve=(raw_curve if use_raw_q else None),
                    raw_q=raw_q,
                    polish=polish,
                    use_raw_q=use_raw_q,
                    **kwargs
                ))

        return preprocessed_dict

    def predict_from_preprocessed_curve(self,
                                        curve: np.ndarray,
                                        priors: np.ndarray = None, *,  # priors are currently ignored
                                        polish: bool = True,
                                        raw_curve: np.ndarray = None,
                                        raw_q: np.ndarray = None,
                                        clip_prediction: bool = True,
                                        use_raw_q: bool = False,
                                        use_sampler: bool = False,
                                        fitted_time_limit: float = 3.,
                                        sampler_rel_bounds: float = 0.3,
                                        polish_with_abeles: bool = False,
                                        **kwargs
                                        ) -> dict:
        """Predict parameters from an already preprocessed curve.

        Args:
            curve: preprocessed curve on the model's q grid (``self.q``).
            priors: currently ignored.
            polish: polish with the kinematical approximation via ``standard_refl_fit``.
            raw_curve: curve to polish against; if None, ``curve`` is used on the ``self.q`` grid.
            raw_q: q values of ``raw_curve``. If not None, the polished curve is evaluated on them.
            clip_prediction, use_raw_q: accepted for interface compatibility; unused here.
            use_sampler: refine the network prediction with :class:`MultilayerFit` before polishing.
            fitted_time_limit: time budget in seconds for the sampler refinement.
            sampler_rel_bounds: relative perturbation size for the sampler candidates.
            polish_with_abeles: when polishing, additionally polish using ``abeles_np``. Its results
                overwrite those of the kinematical polish.

        Returns:
            dict with "params", "param_names", "curve_predicted" and, if polishing,
            "params_polished" and "curve_polished".
        """
        scaled_curve = self._scale_curve(curve)

        predicted_params, parametrized = self._simple_prediction(scaled_curve)

        if use_sampler:
            # NOTE(review): only ``parametrized`` is refined here. ``predicted_params`` (used for
            # "curve_predicted" below) still holds the raw network prediction.
            parametrized: Tensor = self._sampler_solution(
                curve, parametrized,
                time_limit=fitted_time_limit,
                rel_bounds=sampler_rel_bounds,
            )

        # The caller-provided q values (possibly None) are kept for evaluating the polished curve.
        init_raw_q = raw_q

        if raw_curve is None:
            raw_curve = curve
            raw_q = self.q.squeeze().cpu().numpy()
            raw_q_t = self.q
        else:
            raw_q_t = torch.from_numpy(raw_q).to(self.q)

        prediction_dict = {
            "params": parametrized.squeeze().cpu().numpy(),
            "param_names": list(self._prior_sampler.multilayer_model.PARAMETER_NAMES),
            "curve_predicted": predicted_params.reflectivity(raw_q_t).squeeze().cpu().numpy()
        }

        if polish:
            prediction_dict.update(self._polish_prediction(
                raw_q, raw_curve, parametrized, q_values=init_raw_q, use_kinematical=True
            ))
            if polish_with_abeles:
                prediction_dict.update(self._polish_prediction(
                    raw_q, raw_curve, parametrized, q_values=init_raw_q, use_kinematical=False
                ))

        return prediction_dict

    def _simple_prediction(self, scaled_curve) -> Tuple[MultilayerStructureParams, Tensor]:
        """Run the network in eval mode without gradients and restore its output to physical parameters."""
        with torch.no_grad():
            self.trainer.model.eval()
            scaled_params = self.trainer.model(scaled_curve)

            predicted_params, parametrized = self._restore_predicted_params(scaled_params)
        return predicted_params, parametrized

    def _restore_predicted_params(self, scaled_params: Tensor) -> Tuple[MultilayerStructureParams, Tensor]:
        """Convert scaled network outputs to (structure params object, parametrized tensor)."""
        parametrized = self._prior_sampler.restore_params2parametrized(scaled_params)
        predicted_params: MultilayerStructureParams = self._prior_sampler.restore_params(scaled_params)
        return predicted_params, parametrized

    @print_time
    def _sampler_solution(
            self,
            curve: 'Tensor | np.ndarray',
            predicted_params: Tensor,
            batch_size: int = 2 ** 13,
            time_limit: float = 3.,
            rel_bounds: float = 0.3,
    ) -> Tensor:
        """Refine ``predicted_params`` with a time-limited batched gradient fit and return the best candidate."""
        fit_obj = MultilayerFit.from_prediction(
            predicted_params, self._prior_sampler, self.q, torch.as_tensor(curve).to(self.q),
            batch_size=batch_size, rel_bounds=rel_bounds,
        )

        fit_obj.run_fixed_time(time_limit)

        best_params = fit_obj.get_best_solution()

        return best_params

    @property
    def _prior_sampler(self) -> SimpleMultilayerSampler:
        """The prior sampler of the training data loader."""
        return self.trainer.loader.prior_sampler

    @print_time
    def _polish_prediction(self,
                           q: np.ndarray,
                           curve: np.ndarray,
                           predicted_params: Tensor,
                           q_values: np.ndarray,
                           use_kinematical: bool = True,
                           ) -> dict:
        """Polish ``predicted_params`` against ``curve`` with ``standard_refl_fit``.

        The fit uses ``kinematical_approximation_np`` or ``abeles_np`` depending on
        ``use_kinematical``. The returned curve is always computed with ``abeles_np`` on
        ``q_values`` (or ``q`` if ``q_values`` is None). On any failure, the error is logged
        and the initial parameters are returned together with a zero curve.

        Returns:
            dict with "params_polished" (parameter array) and "curve_polished".
        """
        init_params = predicted_params.squeeze().cpu().numpy()

        refl_generator = kinematical_approximation_np if use_kinematical else abeles_np

        try:
            polished_params_arr, curve_polished = standard_refl_fit(
                q, curve, init_params, restore_params_func=self._prior_sampler.restore_np_params,
                refl_generator=refl_generator,
                bounds=self._prior_sampler.get_np_bounds(),
            )
            restored_params = self._prior_sampler.restore_np_params(polished_params_arr)
            if q_values is None:
                q_values = q
            curve_polished = abeles_np(q_values, **restored_params)

        except Exception as err:
            # Best-effort polishing: log and fall back to the unpolished parameter array. A separate
            # name for the restored dict keeps "params_polished" an array in this fallback.
            self.log.exception(err)
            polished_params_arr = init_params
            curve_polished = np.zeros_like(q)

        return {
            'params_polished': polished_params_arr,
            'curve_polished': curve_polished,
        }