reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
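The bulk of the diff below is the rewrite of reflectorch/inference/inference_model.py: the 1.3.0 `EasyInferenceModel` class (and the older `InferenceModel` wrapper) is replaced in 1.5.0 by a single `InferenceModel` class that keeps the same constructor signature and adds a combined `preprocess_and_predict` entry point. As orientation, here is a hedged usage sketch assembled only from the constructor defaults and docstrings visible in the diff; the config name, q grid, reflectivity values, and prior bounds are illustrative placeholders, not values shipped with the package:

import numpy as np
from reflectorch.inference.inference_model import InferenceModel

# Constructor signature as shown in the diff below; per the docstring, configs
# and weights are fetched from the Huggingface repo if not found locally.
model = InferenceModel(
    config_name="some_config",                   # placeholder config name
    repo_id="valentinsingularity/reflectivity",  # default per the docstring
    device="cuda",
)

# Placeholder experimental data: q in inverse angstroms, one (min, max) prior
# bound per predicted parameter, per the preprocess_and_predict docstring.
q = np.linspace(0.02, 0.15, 128)
refl = np.exp(-q * 50)                           # synthetic stand-in curve
prior_bounds = [(0.0, 300.0), (0.0, 20.0), (0.0, 25.0)]  # placeholder ranges

result = model.preprocess_and_predict(
    reflectivity_curve=refl,
    q_values=q,
    prior_bounds=prior_bounds,
    polish_prediction=True,  # optional SciPy least-squares refinement
)
print(sorted(result))        # dict of predictions, per the docstring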
@@ -1,852 +1,848 @@
- import asyncio
- import logging
- from pathlib import Path
- import time
-
- import numpy as np
- import torch
- from torch import Tensor
- from typing import List, Tuple, Union
- import ipywidgets as widgets
- from IPython.display import display
- from huggingface_hub import hf_hub_download
-
- from reflectorch.data_generation.priors import Params, BasicParams, ExpUniformSubPriorSampler, UniformSubPriorParams
- from reflectorch.data_generation.priors.parametric_models import NuisanceParamsWrapper
- from reflectorch.data_generation.q_generator import ConstantQ, VariableQ
- from reflectorch.data_generation.utils import get_density_profiles, get_param_labels
- from reflectorch.inference.plotting import plot_prediction_results
- from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
- from reflectorch.paths import CONFIG_DIR, ROOT_DIR, SAVED_MODELS_DIR
- from reflectorch.runs.utils import (
- get_trainer_by_name, train_from_config
- )
- from reflectorch.runs.config import load_config
- from reflectorch.ml.trainers import PointEstimatorTrainer
- from reflectorch.data_generation.likelihoods import LogLikelihood
-
- from reflectorch.inference.preprocess_exp import StandardPreprocessing
- from reflectorch.inference.scipy_fitter import standard_refl_fit, refl_fit, get_fit_with_growth
- from reflectorch.inference.sampler_solution import simple_sampler_solution, get_best_mse_param
- from reflectorch.inference.record_time import print_time
- from reflectorch.utils import to_t
-
- class EasyInferenceModel(object):
- """Facilitates the inference process using pretrained models
-
- Args:
- config_name (str, optional): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension). Defaults to None.
- model_name (str, optional): the name of the file containing the weights of the model (either with or without the '.pt' extension), only required if different than: `'model_' + config_name + '.pt'`. Defaults to None
- root_dir (str, optional): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR). Defaults to None.
- weights_format (str, optional): format (extension) of the weights file, either 'pt' or 'safetensors'. Defaults to 'safetensors'.
- repo_id (str, optional): the id of the Huggingface repository from which the configuration files and model weights should be downloaded automatically if not found locally (in the 'configs' and 'saved_models' subdirectories of the root directory). Defaults to 'valentinsingularity/reflectivity'.
- trainer (PointEstimatorTrainer, optional): if provided, this trainer instance is used directly instead of being initialized from the configuration file. Defaults to None.
- device (str, optional): the Pytorch device ('cuda' or 'cpu'). Defaults to 'cuda'.
- """
- def __init__(self, config_name: str = None, model_name: str = None, root_dir:str = None, weights_format: str = 'safetensors',
- repo_id: str = 'valentinsingularity/reflectivity', trainer: PointEstimatorTrainer = None, device='cuda'):
- self.config_name = config_name
- self.model_name = model_name
- self.root_dir = root_dir
- self.weights_format = weights_format
- self.repo_id = repo_id
- self.trainer = trainer
- self.device = device
-
- if trainer is None and self.config_name is not None:
- self.load_model(self.config_name, self.model_name, self.root_dir)
-
- self.prediction_result = None
-
- def load_model(self, config_name: str, model_name: str, root_dir: str) -> None:
- """Loads a model for inference
-
- Args:
- config_name (str): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension).
- model_name (str): the name of the file containing the weights of the model (either with or without the '.pt' or '.safetensors' extension), only required if different than: `'model_' + config_name + extension`.
- root_dir (str): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR).
- """
- if self.config_name == config_name and self.trainer is not None:
- return
-
- if not config_name.endswith('.yaml'):
- config_name_no_extension = config_name
- self.config_name = config_name_no_extension + '.yaml'
- else:
- config_name_no_extension = config_name[:-5]
- self.config_name = config_name
-
- self.config_dir = Path(root_dir) / 'configs' if root_dir else CONFIG_DIR
- weights_extension = '.' + self.weights_format
- self.model_name = model_name or 'model_' + config_name_no_extension + weights_extension
- if not self.model_name.endswith(weights_extension):
- self.model_name += weights_extension
- self.model_dir = Path(root_dir) / 'saved_models' if root_dir else SAVED_MODELS_DIR
-
- config_path = Path(self.config_dir) / self.config_name
- if config_path.exists():
- print(f"Configuration file `{config_path}` found locally.")
- else:
- print(f"Configuration file `{config_path}` not found locally.")
- if self.repo_id is None:
- raise ValueError("repo_id must be provided to download files from Huggingface.")
- print("Downloading from Huggingface...")
- hf_hub_download(repo_id=self.repo_id, subfolder='configs', filename=self.config_name, local_dir=config_path.parents[1])
-
- model_path = Path(self.model_dir) / self.model_name
- if model_path.exists():
- print(f"Weights file `{model_path}` found locally.")
- else:
- print(f"Weights file `{model_path}` not found locally.")
- if self.repo_id is None:
- raise ValueError("repo_id must be provided to download files from Huggingface.")
- print("Downloading from Huggingface...")
- hf_hub_download(repo_id=self.repo_id, subfolder='saved_models', filename=self.model_name, local_dir=model_path.parents[1])
-
- self.trainer = get_trainer_by_name(config_name=config_name, config_dir=self.config_dir, model_path=model_path, load_weights=True, inference_device = self.device)
- self.trainer.model.eval()
-
- param_model = self.trainer.loader.prior_sampler.param_model
- param_model_name = param_model.base_model.NAME if isinstance(param_model, NuisanceParamsWrapper) else param_model.NAME
- print(f'The model corresponds to a `{param_model_name}` parameterization with {self.trainer.loader.prior_sampler.max_num_layers} layers ({self.trainer.loader.prior_sampler.param_dim} predicted parameters)')
- print("Parameter types and total ranges:")
- for param, range_ in self.trainer.loader.prior_sampler.param_ranges.items():
- print(f"- {param}: {range_}")
- print("Allowed widths of the prior bound intervals (max-min):")
- for param, range_ in self.trainer.loader.prior_sampler.bound_width_ranges.items():
- print(f"- {param}: {range_}")
-
- if isinstance(self.trainer.loader.q_generator, ConstantQ):
- q_min = self.trainer.loader.q_generator.q[0].item()
- q_max = self.trainer.loader.q_generator.q[-1].item()
- n_q = self.trainer.loader.q_generator.q.shape[0]
- print(f'The model was trained on curves discretized at {n_q} uniform points between q_min={q_min} and q_max={q_max}')
- elif isinstance(self.trainer.loader.q_generator, VariableQ):
- q_min_range = self.trainer.loader.q_generator.q_min_range
- q_max_range = self.trainer.loader.q_generator.q_max_range
- n_q_range = self.trainer.loader.q_generator.n_q_range
- if n_q_range[0] == n_q_range[1]:
- n_q_fixed = n_q_range[0]
- print(f'The model was trained on curves discretized at exactly {n_q_fixed} uniform points, '
- f'between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
- else:
- print(f'The model was trained on curves discretized at a number between {n_q_range[0]} and {n_q_range[1]} '
- f'of uniform points between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
-
- if self.trainer.loader.smearing is not None:
- q_res_min = self.trainer.loader.smearing.sigma_min
- q_res_max = self.trainer.loader.smearing.sigma_max
- if self.trainer.loader.smearing.constant_dq == False:
- print(f"The model was trained with linear resolution smearing (dq/q) in the range [{q_res_min}, {q_res_max}]")
- elif self.trainer.loader.smearing.constant_dq == True:
- print(f"The model was trained with constant resolution smearing in the range [{q_res_min}, {q_res_max}]")
-
- additional_inputs = ["prior bounds"]
- if self.trainer.train_with_q_input:
- additional_inputs.append("q values")
- if self.trainer.condition_on_q_resolutions:
- additional_inputs.append("the resolution dq/q")
- if additional_inputs:
- inputs_str = ", ".join(additional_inputs)
- print(f"The following quantities are additional inputs to the network: {inputs_str}.")
-
- def predict(self,
- reflectivity_curve: Union[np.ndarray, Tensor],
- q_values: Union[np.ndarray, Tensor] = None,
- prior_bounds: Union[np.ndarray, List[Tuple]] = None,
- q_resolution: Union[float, np.ndarray] = None,
- ambient_sld: float = None,
- clip_prediction: bool = False,
- polish_prediction: bool = False,
- polishing_kwargs_reflectivity: dict = None,
- fit_growth: bool = False,
- max_d_change: float = 5.,
- use_q_shift: bool = False,
- calc_pred_curve: bool = True,
- calc_pred_sld_profile: bool = False,
- calc_polished_sld_profile: bool = False,
- ):
- """Predict the thin film parameters
-
- Args:
- reflectivity_curve (Union[np.ndarray, Tensor]): The reflectivity curve (which has been already preprocessed, normalized and interpolated).
- q_values (Union[np.ndarray, Tensor], optional): The momentum transfer (q) values for the reflectivity curve (in units of inverse angstroms).
- prior_bounds (Union[np.ndarray, List[Tuple]], optional): the prior bounds for the thin film parameters.
- q_resolution (Union[float, np.ndarray], optional): the instrumental resolution. Either as a float with meaning dq/q for linear smearing or as a numpy array with meaning dq for pointwise smearing.
- ambient_sld (float, optional): the SLD of the ambient medium (fronting), if different from air.
- clip_prediction (bool, optional): If ``True``, the values of the predicted parameters are clipped to not be outside the interval set by the prior bounds. Defaults to False.
- polish_prediction (bool, optional): If ``True``, the neural network predictions are further polished using a simple least mean squares (LMS) fit. Only for the standard box-model parameterization. Defaults to False.
- polishing_kwargs_reflectivity (dict): extra arguments for the reflectivity function used during polishing.
- fit_growth (bool, optional): If ``True``, an additional parameters is introduced during the LMS polishing to account for the change in the thickness of the upper layer during the in-situ measurement of the reflectivity curve (a linear growth is assumed). Defaults to False.
- max_d_change (float): The maximum possible change in the thickness of the upper layer during the in-situ measurement, relevant when polish_prediction and fit_growth are True. Defaults to 5.
- use_q_shift: If ``True``, the prediction is performed for a batch of slightly shifted versions of the input curve and the best result is returned, which is meant to mitigate the influence of imperfect sample alignment, as introduced in Greco et al. (only for models with fixed q-discretization). Defaults to False.
- calc_pred_curve (bool, optional): Whether to calculate the curve corresponding to the predicted parameters. Defaults to True.
- calc_pred_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the predicted parameters. Defaults to False.
- calc_polished_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the polished parameters. Defaults to False.
-
- Returns:
- dict: dictionary containing the predictions
- """
-
- scaled_curve = self._scale_curve(reflectivity_curve)
- prior_bounds = np.array(prior_bounds)
-
- if ambient_sld:
- n_layers = self.trainer.loader.prior_sampler.max_num_layers
- sld_indices = slice(2*n_layers+1, 3*n_layers+2)
- prior_bounds[sld_indices, ...] -= ambient_sld
- training_min_bounds = self.trainer.loader.prior_sampler.min_bounds.squeeze().cpu().numpy()
- training_max_bounds = self.trainer.loader.prior_sampler.max_bounds.squeeze().cpu().numpy()
- lower_bound_check = (prior_bounds[sld_indices, 0] >= training_min_bounds[sld_indices]).all()
- upper_bound_check = (prior_bounds[sld_indices, 1] <= training_max_bounds[sld_indices]).all()
- assert lower_bound_check and upper_bound_check, "Shifting the layer SLDs by the ambient SLD exceeded the training ranges."
-
- try:
- scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)
- except ValueError as e:
- print(str(e))
- return None
-
- if not self.trainer.train_with_q_input:
- q_values = self.trainer.loader.q_generator.q
- else:
- q_values = torch.atleast_2d(to_t(q_values)).to(scaled_curve)
-
- if use_q_shift and not self.trainer.train_with_q_input:
- predicted_params = self._qshift_prediction(reflectivity_curve, scaled_prior_bounds, num = 1024, dq_coef = 1.)
-
- else:
- with torch.no_grad():
- self.trainer.model.eval()
-
- scaled_q_values = self.trainer.loader.q_generator.scale_q(q_values).to(torch.float32) if self.trainer.train_with_q_input else None
-
- if q_resolution is not None:
- q_resolution_tensor = torch.atleast_2d(torch.as_tensor(q_resolution)).to(scaled_curve)
- if isinstance(q_resolution, float):
- unscaled_q_resolutions = q_resolution_tensor
- else:
- unscaled_q_resolutions = (q_resolution_tensor / q_values).mean(dim=-1, keepdim=True)
- scaled_q_resolutions = self.trainer.loader.smearing.scale_resolutions(unscaled_q_resolutions) if self.trainer.condition_on_q_resolutions else None
- scaled_conditioning_params = scaled_q_resolutions
- if polishing_kwargs_reflectivity is None:
- polishing_kwargs_reflectivity = {'dq': q_resolution}
- else:
- q_resolution_tensor = None
- scaled_conditioning_params = None
-
- scaled_predicted_params = self.trainer.model(
- curves=scaled_curve,
- bounds=scaled_prior_bounds,
- q_values=scaled_q_values,
- conditioning_params = scaled_conditioning_params,
- )
-
- predicted_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
-
- if clip_prediction:
- predicted_params = self.trainer.loader.prior_sampler.clamp_params(predicted_params)
-
- prediction_dict = {
- "predicted_params_object": predicted_params,
- "predicted_params_array": predicted_params.parameters.squeeze().cpu().numpy(),
- "param_names" : self.trainer.loader.prior_sampler.param_model.get_param_labels()
- }
-
- if calc_pred_curve:
- predicted_curve = predicted_params.reflectivity(q=q_values, dq=q_resolution_tensor).squeeze().cpu().numpy()
- prediction_dict[ "predicted_curve"] = predicted_curve
-
- ambient_sld_tensor = torch.atleast_2d(torch.as_tensor(ambient_sld)).to(predicted_params.thicknesses.device) if ambient_sld is not None else None
- if calc_pred_sld_profile:
- predicted_sld_xaxis, predicted_sld_profile, _ = get_density_profiles(
- predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, ambient_sld_tensor, num=1024,
- )
- prediction_dict['predicted_sld_profile'] = predicted_sld_profile.squeeze().cpu().numpy()
- prediction_dict['predicted_sld_xaxis'] = predicted_sld_xaxis.squeeze().cpu().numpy()
- else:
- predicted_sld_xaxis = None
-
- if polish_prediction:
- if ambient_sld_tensor:
- ambient_sld_tensor = ambient_sld_tensor.cpu()
- polished_dict = self._polish_prediction(q = q_values.squeeze().cpu().numpy(),
- curve = reflectivity_curve,
- predicted_params = predicted_params,
- priors = np.array(prior_bounds),
- fit_growth = fit_growth,
- max_d_change = max_d_change,
- calc_polished_curve = calc_pred_curve,
- calc_polished_sld_profile = calc_polished_sld_profile,
- ambient_sld_tensor=ambient_sld_tensor,
- sld_x_axis = predicted_sld_xaxis,
- polishing_kwargs_reflectivity=polishing_kwargs_reflectivity,
- )
- prediction_dict.update(polished_dict)
-
- if fit_growth and "polished_params_array" in prediction_dict:
- prediction_dict["param_names"].append("max_d_change")
-
- if ambient_sld: #Note: the SLD shift will only be reflected in predicted_params_array but not in predicted_params_object
- prediction_dict["predicted_params_array"][sld_indices] += ambient_sld
- if "polished_params_array" in prediction_dict:
- prediction_dict["polished_params_array"][sld_indices] += ambient_sld
-
- return prediction_dict
-
- def predict_using_widget(self, reflectivity_curve, **kwargs):
- """
- """
-
- NUM_INTERVALS = self.trainer.loader.prior_sampler.param_dim
- param_labels = self.trainer.loader.prior_sampler.param_model.get_param_labels()
- min_bounds = self.trainer.loader.prior_sampler.min_bounds.cpu().numpy().flatten()
- max_bounds = self.trainer.loader.prior_sampler.max_bounds.cpu().numpy().flatten()
- max_deltas = self.trainer.loader.prior_sampler.max_delta.cpu().numpy().flatten()
-
- print(f'Adjust the sliders for each parameter and press "Predict". Repeat as desired. Press "Close Widget" to finish.')
-
- interval_widgets = []
- for i in range(NUM_INTERVALS):
- label = widgets.Label(value=f'{param_labels[i]}')
- initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
- slider = widgets.FloatRangeSlider(
- value=[min_bounds[i], initial_max],
- min=min_bounds[i],
- max=max_bounds[i],
- step=0.01,
- layout=widgets.Layout(width='400px'),
- style={'description_width': '60px'}
- )
-
- def validate_range(change, slider=slider, max_width=max_deltas[i]):
- min_val, max_val = change['new']
- if max_val - min_val > max_width:
- old_min_val, old_max_val = change['old']
- if abs(old_min_val - min_val) > abs(old_max_val - max_val):
- max_val = min_val + max_width
- else:
- min_val = max_val - max_width
- slider.value = [min_val, max_val]
-
- slider.observe(validate_range, names='value')
- interval_widgets.append((slider, widgets.HBox([label, slider])))
-
- sliders_box = widgets.VBox([iw[1] for iw in interval_widgets])
-
- output = widgets.Output()
- predict_button = widgets.Button(description="Predict")
- close_button = widgets.Button(description="Close Widget")
-
- container = widgets.VBox([sliders_box, widgets.HBox([predict_button, close_button]), output])
- display(container)
-
- @output.capture(clear_output=True)
- def on_predict_click(_):
- if 'prior_bounds' in kwargs:
- array_values = kwargs.pop('prior_bounds')
- for i, (s, _) in enumerate(interval_widgets):
- s.value = tuple(array_values[i])
- else:
- values = [(s.value[0], s.value[1]) for s, _ in interval_widgets]
- array_values = np.array(values)
-
- prediction_result = self.predict(reflectivity_curve=reflectivity_curve,
- prior_bounds=array_values,
- **kwargs)
- param_names = self.trainer.loader.prior_sampler.param_model.get_param_labels()
- for param_name, pred_param_val in zip(param_names, prediction_result["predicted_params_array"]):
- print(f'{param_name.ljust(14)} : {pred_param_val:.2f}')
-
- plot_prediction_results(
- prediction_result,
- q_exp=kwargs['q_values'],
- curve_exp=reflectivity_curve,
- q_model=kwargs['q_values'],
- )
- self.prediction_result = prediction_result
-
- def on_close_click(_):
- container.close()
- print("Widget closed.")
-
- predict_button.on_click(on_predict_click)
- close_button.on_click(on_close_click)
-
-
- def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
- assert isinstance(self.trainer.loader.q_generator, ConstantQ), "Prediction with q shifts available only for models with fixed discretization"
- q = self.trainer.loader.q_generator.q.squeeze().float()
- dq_max = (q[1] - q[0]) * dq_coef
- q_shifts = torch.linspace(-dq_max, dq_max, num).to(q)
-
- curve = to_t(curve).to(scaled_bounds)
- shifted_curves = _qshift_interp(q.squeeze(), curve, q_shifts)
-
- assert shifted_curves.shape == (num, q.shape[0])
-
- scaled_curves = self.trainer.loader.curves_scaler.scale(shifted_curves)
- scaled_prior_bounds = torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)
-
- with torch.no_grad():
- self.trainer.model.eval()
- scaled_predicted_params = self.trainer.model(scaled_curves, scaled_prior_bounds)
- restored_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
-
- best_param = get_best_mse_param(
- restored_params,
- self._get_likelihood(q=self.trainer.loader.q_generator.q, curve=curve),
- )
- return best_param
-
- def _polish_prediction(self,
- q: np.ndarray,
- curve: np.ndarray,
- predicted_params: BasicParams,
- priors: np.ndarray,
- sld_x_axis,
- ambient_sld_tensor: Tensor = None,
- fit_growth: bool = False,
- max_d_change: float = 5.,
- calc_polished_curve: bool = True,
- calc_polished_sld_profile: bool = False,
- polishing_kwargs_reflectivity: dict = None,
- ) -> dict:
- params = predicted_params.parameters.squeeze().cpu().numpy()
-
- polished_params_dict = {}
- polishing_kwargs_reflectivity = polishing_kwargs_reflectivity or {}
-
- try:
- if fit_growth:
- polished_params_arr, curve_polished = get_fit_with_growth(
- q = q,
- curve = curve,
- init_params = params,
- bounds = priors.T,
- max_d_change = max_d_change,
- )
- polished_params = BasicParams(
- torch.from_numpy(polished_params_arr[:-1][None]),
- torch.from_numpy(priors.T[0][None]),
- torch.from_numpy(priors.T[1][None]),
- self.trainer.loader.prior_sampler.max_num_layers,
- self.trainer.loader.prior_sampler.param_model
- )
- else:
- polished_params_arr, curve_polished = refl_fit(
- q = q,
- curve = curve,
- init_params = params,
- bounds=priors.T,
- prior_sampler=self.trainer.loader.prior_sampler,
- reflectivity_kwargs=polishing_kwargs_reflectivity,
- )
- polished_params = BasicParams(
- torch.from_numpy(polished_params_arr[None]),
- torch.from_numpy(priors.T[0][None]),
- torch.from_numpy(priors.T[1][None]),
- self.trainer.loader.prior_sampler.max_num_layers,
- self.trainer.loader.prior_sampler.param_model
- )
- except Exception as err:
- polished_params = predicted_params
- polished_params_arr = get_prediction_array(polished_params)
- curve_polished = np.zeros_like(q)
-
- polished_params_dict['polished_params_array'] = polished_params_arr
- if calc_polished_curve:
- polished_params_dict['polished_curve'] = curve_polished
-
- if calc_polished_sld_profile:
- _, sld_profile_polished, _ = get_density_profiles(
- polished_params.thicknesses, polished_params.roughnesses, polished_params.slds, ambient_sld_tensor, z_axis=sld_x_axis.cpu(),
- )
- polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().numpy()
-
- return polished_params_dict
-
- def _scale_curve(self, curve: Union[np.ndarray, Tensor]):
- if not isinstance(curve, Tensor):
- curve = torch.from_numpy(curve).float()
- curve = torch.atleast_2d(curve).to(self.device)
- scaled_curve = self.trainer.loader.curves_scaler.scale(curve)
- return scaled_curve
-
- def _scale_prior_bounds(self, prior_bounds: List[Tuple]):
- try:
- prior_bounds = torch.tensor(prior_bounds)
- prior_bounds = prior_bounds.to(self.device).T
- min_bounds, max_bounds = prior_bounds[:, None]
-
- scaled_bounds = torch.cat([
- self.trainer.loader.prior_sampler.scale_bounds(min_bounds),
- self.trainer.loader.prior_sampler.scale_bounds(max_bounds)
- ], -1)
-
- return scaled_bounds.float()
-
- except RuntimeError as e:
- expected_param_dim = self.trainer.loader.prior_sampler.param_dim
- actual_param_dim = prior_bounds.shape[1] if prior_bounds.ndim == 2 else len(prior_bounds)
-
- msg = (
- f"\n **Parameter dimension mismatch during inference!**\n"
- f"- Model expects **{expected_param_dim}** parameters.\n"
- f"- You provided **{actual_param_dim}** prior bounds.\n\n"
- f"💡This often occurs when:\n"
- f"- The model was trained with additional nuisance parameters like `r_scale`, `q_shift`, or `log10_background`,\n"
- f" but they were not included in the `prior_bounds` passed to `.predict()`.\n"
- f"- The number of layers or parameterization type differs from the one used during training.\n\n"
- f" Check the configuration or the summary of expected parameters."
- )
- raise ValueError(msg) from e
-
- def interpolate_data_to_model_q(self, q_exp, curve_exp):
- if isinstance(self.trainer.loader.q_generator, ConstantQ):
- q_model = self.trainer.loader.q_generator.q.cpu().numpy()
- elif isinstance(self.trainer.loader.q_generator, VariableQ):
- if self.trainer.loader.q_generator.n_q_range[0] == self.trainer.loader.q_generator.n_q_range[1]:
- n_q_model = self.trainer.loader.q_generator.n_q_range[0]
- q_model_min = max(q_exp.min(), self.trainer.loader.q_generator.q_min_range[0])
- q_model_max = min(q_exp.max(), self.trainer.loader.q_generator.q_max_range[1])
- q_model = np.linspace(q_model_min, q_model_max, n_q_model)
- else:
- q_model = q_exp
- exp_curve_interp = curve_exp
-
- exp_curve_interp = interp_reflectivity(q_model, q_exp, curve_exp)
-
- return q_model, exp_curve_interp
-
- def _get_likelihood(self, q, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
- return LogLikelihood(
- q, curve, self.trainer.loader.prior_sampler, curve * rel_err + abs_err
- )
-
- class InferenceModel(object):
- def __init__(self, name: str = None, trainer: PointEstimatorTrainer = None, preprocessing_parameters: dict = None,
- num_sampling: int = 2 ** 13):
- self.log = logging.getLogger(__name__)
- self.model_name = name
- self.trainer = trainer
- self.q = None
- self.preprocessing = StandardPreprocessing(**(preprocessing_parameters or {}))
- self._sampling_num = num_sampling
-
- if trainer is None and self.model_name is not None:
- self.load_model(self.model_name)
- elif trainer is not None:
- self._set_trainer(trainer, preprocessing_parameters)
-
- ### API methods ###
-
- def load_model(self, name: str) -> None:
- self.log.debug(f"loading model {name}")
- if self.model_name == name and self.trainer is not None:
- return
- self.model_name = name
- self._set_trainer(get_trainer_by_name(name))
- self.log.info(f"Model {name} is loaded.")
-
- def train_model(self, name: str):
- self.model_name = name
- self.trainer = train_from_config(load_config(name))
-
- def set_preprocessing_parameters(self, **kwargs) -> None:
- self.preprocessing.set_parameters(**kwargs)
-
- def preprocess(self,
- intensity: np.ndarray,
- scattering_angle: np.ndarray,
- attenuation: np.ndarray,
- update_params: bool = False,
- **kwargs) -> dict:
- if update_params:
- self.preprocessing.set_parameters(**kwargs)
- preprocessed_dict = self.preprocessing(intensity, scattering_angle, attenuation, **kwargs)
- return preprocessed_dict
-
- def predict(self,
- intensity: np.ndarray,
- scattering_angle: np.ndarray,
- attenuation: np.ndarray,
- priors: np.ndarray,
- preprocessing_parameters: dict = None,
- polish: bool = True,
- use_sampler: bool = False,
- use_q_shift: bool = True,
- max_d_change: float = 5.,
- fit_growth: bool = True,
- ) -> dict:
-
- with print_time("everything"):
- with print_time("preprocess"):
- preprocessed_dict = self.preprocess(
- intensity, scattering_angle, attenuation, **(preprocessing_parameters or {})
- )
-
- preprocessed_curve = preprocessed_dict["curve_interp"]
- raw_curve, raw_q = preprocessed_dict["curve"], preprocessed_dict["q_values"]
- q_ratio = preprocessed_dict["q_ratio"]
-
- with print_time("predict_from_preprocessed_curve"):
- preprocessed_dict.update(self.predict_from_preprocessed_curve(
- preprocessed_curve, priors, raw_curve=raw_curve, raw_q=raw_q, polish=polish, q_ratio=q_ratio,
- use_sampler=use_sampler, use_q_shift=use_q_shift, max_d_change=max_d_change,
- fit_growth=fit_growth,
- ))
-
- return preprocessed_dict
-
- def predict_from_preprocessed_curve(self,
- curve: np.ndarray,
- priors: np.ndarray, *,
- polish: bool = True,
- raw_curve: np.ndarray = None,
- raw_q: np.ndarray = None,
- clip_prediction: bool = True,
- q_ratio: float = 1.,
- use_sampler: bool = False,
- use_q_shift: bool = True,
- max_d_change: float = 5.,
- fit_growth: bool = True,
- ) -> dict:
-
- scaled_curve = self._scale_curve(curve)
- scaled_bounds, min_bounds, max_bounds = self._scale_priors(priors, q_ratio)
-
- if not use_q_shift:
- predicted_params: UniformSubPriorParams = self._simple_prediction(scaled_curve, scaled_bounds)
- else:
- predicted_params: UniformSubPriorParams = self._qshift_prediction(curve, scaled_bounds)
-
- if use_sampler:
- predicted_params: UniformSubPriorParams = self._sampler_solution(
- curve, predicted_params,
- )
-
- if clip_prediction:
- predicted_params = self._prior_sampler.clamp_params(predicted_params)
-
- if raw_curve is None:
- raw_curve = curve
- if raw_q is None:
- raw_q = self.q.squeeze().cpu().numpy()
- raw_q_t = self.q
- else:
- raw_q_t = torch.from_numpy(raw_q).to(self.q)
-
- if q_ratio != 1.:
- predicted_params.scale_with_q(q_ratio)
- raw_q = raw_q * q_ratio
- raw_q_t = raw_q_t * q_ratio
-
- prediction_dict = {
- "params": get_prediction_array(predicted_params),
- "param_names": get_param_labels(
- predicted_params.max_layer_num,
- thickness_name='d',
- roughness_name='sigma',
- sld_name='rho',
- ),
- "curve_predicted": predicted_params.reflectivity(raw_q_t).squeeze().cpu().numpy()
- }
-
- sld_x_axis, sld_profile, _ = get_density_profiles(
- predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, num=1024,
- )
-
- prediction_dict['sld_profile'] = sld_profile.squeeze().cpu().numpy()
- prediction_dict['sld_x_axis'] = sld_x_axis.squeeze().cpu().numpy()
-
- if polish:
- prediction_dict.update(self._polish_prediction(
- raw_q, raw_curve, predicted_params, priors, sld_x_axis,
- max_d_change=max_d_change, fit_growth=fit_growth,
- ))
-
- if fit_growth and "params_polished" in prediction_dict:
- prediction_dict["param_names"].append("max_d_change")
-
- return prediction_dict
-
- ### some shortcut methods for data processing ###
-
- def _simple_prediction(self, scaled_curve, scaled_bounds) -> UniformSubPriorParams:
- context = torch.cat([scaled_curve, scaled_bounds], -1)
-
- with torch.no_grad():
- self.trainer.model.eval()
- scaled_params = self.trainer.model(context)
-
- predicted_params: UniformSubPriorParams = self._restore_predicted_params(scaled_params, context)
- return predicted_params
-
- @print_time
- def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> UniformSubPriorParams:
- q = self.q.squeeze().float()
- curve = to_t(curve).to(q)
- dq_max = (q[1] - q[0]) * dq_coef
- q_shifts = torch.linspace(-dq_max, dq_max, num).to(q)
- shifted_curves = _qshift_interp(q.squeeze(), curve, q_shifts)
-
- assert shifted_curves.shape == (num, q.shape[0])
-
- scaled_curves = self.trainer.loader.curves_scaler.scale(shifted_curves)
- context = torch.cat([scaled_curves, torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)], -1)
-
- with torch.no_grad():
- self.trainer.model.eval()
- scaled_params = self.trainer.model(context)
- restored_params = self._restore_predicted_params(scaled_params, context)
-
- best_param = get_best_mse_param(
- restored_params,
- self._get_likelihood(curve),
- )
- return best_param
-
- @print_time
- def _polish_prediction(self,
- q: np.ndarray,
- curve: np.ndarray,
- predicted_params: Params,
- priors: np.ndarray,
- sld_x_axis,
- fit_growth: bool = True,
- max_d_change: float = 5.,
- ) -> dict:
- params = torch.cat([
- predicted_params.thicknesses.squeeze(),
- predicted_params.roughnesses.squeeze(),
- predicted_params.slds.squeeze()
- ]).cpu().numpy()
-
- polished_params_dict = {}
-
- try:
- if fit_growth:
- polished_params_arr, curve_polished = get_fit_with_growth(
- q, curve, params, bounds=priors.T,
- max_d_change=max_d_change,
- )
- polished_params = Params.from_tensor(torch.from_numpy(polished_params_arr[:-1][None]).to(self.q))
- else:
- polished_params_arr, curve_polished = standard_refl_fit(q, curve, params, bounds=priors.T)
- polished_params = Params.from_tensor(torch.from_numpy(polished_params_arr[None]).to(self.q))
- except Exception as err:
- self.log.exception(err)
- polished_params = predicted_params
- polished_params_arr = get_prediction_array(polished_params)
- curve_polished = np.zeros_like(q)
-
- polished_params_dict['params_polished'] = polished_params_arr
- polished_params_dict['curve_polished'] = curve_polished
-
- sld_x_axis_polished, sld_profile_polished, _ = get_density_profiles(
- polished_params.thicknesses, polished_params.roughnesses, polished_params.slds, z_axis=sld_x_axis,
- )
-
- polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().cpu().numpy()
-
- return polished_params_dict
-
- def _restore_predicted_params(self, scaled_params: Tensor, context: Tensor) -> UniformSubPriorParams:
- predicted_params: UniformSubPriorParams = self.trainer.loader.prior_sampler.restore_params(
- self.trainer.loader.prior_sampler.PARAM_CLS.restore_params_from_context(scaled_params, context)
- )
- return predicted_params
-
- def _input2context(self, curve: np.ndarray, priors: np.ndarray, q_ratio: float = 1.):
- scaled_curve = self._scale_curve(curve)
- scaled_bounds, min_bounds, max_bounds = self._scale_priors(priors, q_ratio)
- scaled_input = torch.cat([scaled_curve, scaled_bounds], -1)
- return scaled_input, min_bounds, max_bounds
-
- def _scale_curve(self, curve: np.ndarray or Tensor):
- if not isinstance(curve, Tensor):
- curve = torch.from_numpy(curve).float()
- curve = torch.atleast_2d(curve).to(self.q)
- scaled_curve = self.trainer.loader.curves_scaler.scale(curve)
- return scaled_curve.float()
-
- def _scale_priors(self, priors: np.ndarray or Tensor, q_ratio: float = 1.):
- if not isinstance(priors, Tensor):
- priors = torch.from_numpy(priors)
-
- priors = priors.float().clone()
-
- priors = priors.to(self.q).T
- priors = self._prior_sampler.scale_bounds_with_q(priors, 1 / q_ratio)
- priors = self._prior_sampler.clamp_bounds(priors)
-
- min_bounds, max_bounds = priors[:, None].to(self.q)
- prior_sampler = self._prior_sampler
- scaled_bounds = torch.cat([
- prior_sampler.scale_bounds(min_bounds), prior_sampler.scale_bounds(max_bounds)
- ], -1)
- return scaled_bounds.float(), min_bounds, max_bounds
-
- @property
- def _prior_sampler(self) -> ExpUniformSubPriorSampler:
- return self.trainer.loader.prior_sampler
-
- def _set_trainer(self, trainer, preprocessing_parameters: dict = None):
- self.trainer = trainer
- self.trainer.model.eval()
- self._update_preprocessing(preprocessing_parameters)
-
- def _update_preprocessing(self, preprocessing_parameters: dict = None):
- self.log.debug(f"setting preprocessing_parameters {preprocessing_parameters}.")
- self.q = self.trainer.loader.q_generator.q
- self.preprocessing = StandardPreprocessing(
- self.q.cpu().squeeze().numpy(),
- **(preprocessing_parameters or {})
- )
- self.log.info(f"preprocessing params are set: {preprocessing_parameters}.")
-
- @print_time
- def _sampler_solution(
- self,
- curve: Tensor or np.ndarray,
- predicted_params: UniformSubPriorParams,
- ) -> UniformSubPriorParams:
-
- if not isinstance(curve, Tensor):
- curve = torch.from_numpy(curve).float()
- curve = curve.to(self.q)
-
- refined_params = simple_sampler_solution(
- self._get_likelihood(curve),
- predicted_params,
- self._prior_sampler.min_bounds,
- self._prior_sampler.max_bounds,
- num=self._sampling_num, coef=0.1,
- )
-
- return refined_params
-
- def _get_likelihood(self, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
- return LogLikelihood(
- self.q, curve, self._prior_sampler, curve * rel_err + abs_err
- )
-
-
- def get_prediction_array(params: BasicParams) -> np.ndarray:
- predict_arr = torch.cat([
- params.thicknesses.squeeze(),
- params.roughnesses.squeeze(),
- params.slds.squeeze(),
- ]).cpu().numpy()
-
- return predict_arr
-
-
- def _qshift_interp(q, r, q_shifts):
- qs = q[None] + q_shifts[:, None]
- eps = torch.finfo(r.dtype).eps
- ind = torch.searchsorted(q[None].expand_as(qs).contiguous(), qs.contiguous())
- ind = torch.clamp(ind - 1, 0, q.shape[0] - 2)
- slopes = (r[1:] - r[:-1]) / (eps + (q[1:] - q[:-1]))
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import Tensor
6
+ from typing import List, Tuple, Union
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ from reflectorch.data_generation.priors import BasicParams
10
+ from reflectorch.data_generation.priors.parametric_models import NuisanceParamsWrapper
11
+ from reflectorch.data_generation.q_generator import ConstantQ, VariableQ, MaskedVariableQ
12
+ from reflectorch.data_generation.utils import get_density_profiles
13
+ from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
14
+ from reflectorch.paths import CONFIG_DIR, SAVED_MODELS_DIR
15
+ from reflectorch.runs.utils import (
16
+ get_trainer_by_name
17
+ )
18
+ from reflectorch.ml.trainers import PointEstimatorTrainer
19
+ from reflectorch.data_generation.likelihoods import LogLikelihood
20
+
21
+ from reflectorch.inference.scipy_fitter import refl_fit, get_fit_with_growth
22
+ from reflectorch.inference.sampler_solution import get_best_mse_param
23
+ from reflectorch.utils import get_filtering_mask, to_t
24
+
25
+ from huggingface_hub.utils import disable_progress_bars
26
+
27
+ # that causes some Rust related errors when downloading models from Huggingface
28
+ disable_progress_bars()
29
+
30
+
31
+ class InferenceModel(object):
32
+ """Facilitates the inference process using pretrained models
33
+
34
+ Args:
35
+ config_name (str, optional): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension). Defaults to None.
36
+ model_name (str, optional): the name of the file containing the weights of the model (either with or without the '.pt' extension), only required if different than: `'model_' + config_name + '.pt'`. Defaults to None
37
+ root_dir (str, optional): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR). Defaults to None.
38
+ weights_format (str, optional): format (extension) of the weights file, either 'pt' or 'safetensors'. Defaults to 'safetensors'.
39
+ repo_id (str, optional): the id of the Huggingface repository from which the configuration files and model weights should be downloaded automatically if not found locally (in the 'configs' and 'saved_models' subdirectories of the root directory). Defaults to 'valentinsingularity/reflectivity'.
40
+ trainer (PointEstimatorTrainer, optional): if provided, this trainer instance is used directly instead of being initialized from the configuration file. Defaults to None.
41
+ device (str, optional): the Pytorch device ('cuda' or 'cpu'). Defaults to 'cuda'.
42
+ """
43
+ def __init__(self, config_name: str = None, model_name: str = None, root_dir:str = None, weights_format: str = 'safetensors',
44
+ repo_id: str = 'valentinsingularity/reflectivity', trainer: PointEstimatorTrainer = None, device='cuda'):
45
+ self.config_name = config_name
46
+ self.model_name = model_name
47
+ self.root_dir = root_dir
48
+ self.weights_format = weights_format
49
+ self.repo_id = repo_id
50
+ self.trainer = trainer
51
+ self.device = device
52
+
53
+ if trainer is None and self.config_name is not None:
54
+ self.load_model(self.config_name, self.model_name, self.root_dir)
55
+
56
+ self.prediction_result = None
57
+
58
+ def load_model(self, config_name: str, model_name: str, root_dir: str) -> None:
59
+ """Loads a model for inference
60
+
61
+ Args:
62
+ config_name (str): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension).
63
+ model_name (str): the name of the file containing the weights of the model (either with or without the '.pt' or '.safetensors' extension), only required if different than: `'model_' + config_name + extension`.
64
+ root_dir (str): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR).
65
+ """
66
+ if self.config_name == config_name and self.trainer is not None:
67
+ return
68
+
69
+ if not config_name.endswith('.yaml'):
70
+ config_name_no_extension = config_name
71
+ self.config_name = config_name_no_extension + '.yaml'
72
+ else:
73
+ config_name_no_extension = config_name[:-5]
74
+ self.config_name = config_name
75
+
76
+ self.config_dir = Path(root_dir) / 'configs' if root_dir else CONFIG_DIR
77
+ weights_extension = '.' + self.weights_format
78
+ self.model_name = model_name or 'model_' + config_name_no_extension + weights_extension
79
+ if not self.model_name.endswith(weights_extension):
80
+ self.model_name += weights_extension
81
+ self.model_dir = Path(root_dir) / 'saved_models' if root_dir else SAVED_MODELS_DIR
82
+
83
+ def _download_with_fallback(filename: str, local_target_dir: Path, legacy_subfolder: str):
84
+ """Try to download from repo root (new layout). If not found, retry with legacy `subfolder=legacy_subfolder`. Place result under local_target_dir using `local_dir`.
85
+ """
86
+ try: # new layout: files at repo root (same level as README.md)
87
+ hf_hub_download(repo_id=self.repo_id + '/' + config_name, filename=filename, local_dir=str(local_target_dir))
88
+ except Exception : # legacy layout fallback: e.g. subfolder='configs' or 'saved_models'
89
+ hf_hub_download(repo_id=self.repo_id, filename=filename, subfolder=legacy_subfolder, local_dir=str(local_target_dir.parent))
90
+
91
+ config_path = Path(self.config_dir) / self.config_name
92
+ if config_path.exists():
93
+ print(f"Configuration file `{config_path}` found locally.")
94
+ else:
95
+ print(f"Configuration file `{config_path}` not found locally.")
96
+ if self.repo_id is None:
97
+ raise ValueError("repo_id must be provided to download files from Huggingface.")
98
+ print("Downloading from Huggingface...")
99
+ _download_with_fallback(self.config_name, self.config_dir, legacy_subfolder='configs')
100
+
101
+ model_path = Path(self.model_dir) / self.model_name
102
+ if model_path.exists():
103
+ print(f"Weights file `{model_path}` found locally.")
104
+ else:
105
+ print(f"Weights file `{model_path}` not found locally.")
106
+ if self.repo_id is None:
107
+ raise ValueError("repo_id must be provided to download files from Huggingface.")
108
+ print("Downloading from Huggingface...")
109
+ _download_with_fallback(self.model_name, self.model_dir, legacy_subfolder='saved_models')
110
+
111
+ self.trainer = get_trainer_by_name(config_name=config_name, config_dir=self.config_dir, model_path=model_path, load_weights=True, inference_device = self.device)
112
+ self.trainer.model.eval()
113
+
114
+ param_model = self.trainer.loader.prior_sampler.param_model
115
+ param_model_name = param_model.base_model.NAME if isinstance(param_model, NuisanceParamsWrapper) else param_model.NAME
116
+ print(f'The model corresponds to a `{param_model_name}` parameterization with {self.trainer.loader.prior_sampler.max_num_layers} layers ({self.trainer.loader.prior_sampler.param_dim} predicted parameters)')
117
+ print("Parameter types and total ranges:")
118
+ for param, range_ in self.trainer.loader.prior_sampler.param_ranges.items():
119
+ print(f"- {param}: {range_}")
120
+ print("Allowed widths of the prior bound intervals (max-min):")
121
+ for param, range_ in self.trainer.loader.prior_sampler.bound_width_ranges.items():
122
+ print(f"- {param}: {range_}")
123
+
124
+ if isinstance(self.trainer.loader.q_generator, ConstantQ):
125
+ q_min = self.trainer.loader.q_generator.q[0].item()
126
+ q_max = self.trainer.loader.q_generator.q[-1].item()
127
+ n_q = self.trainer.loader.q_generator.q.shape[0]
128
+ print(f'The model was trained on curves discretized at {n_q} uniform points between q_min={q_min} and q_max={q_max}')
129
+ elif isinstance(self.trainer.loader.q_generator, VariableQ):
130
+ q_min_range = self.trainer.loader.q_generator.q_min_range
131
+ q_max_range = self.trainer.loader.q_generator.q_max_range
132
+ n_q_range = self.trainer.loader.q_generator.n_q_range
133
+ if n_q_range[0] == n_q_range[1]:
134
+ n_q_fixed = n_q_range[0]
135
+ print(f'The model was trained on curves discretized at exactly {n_q_fixed} uniform points, '
136
+ f'between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
137
+ else:
138
+ print(f'The model was trained on curves discretized at a number between {n_q_range[0]} and {n_q_range[1]} '
139
+ f'of uniform points between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
140
+
141
+ if self.trainer.loader.smearing is not None:
142
+ q_res_min = self.trainer.loader.smearing.sigma_min
143
+ q_res_max = self.trainer.loader.smearing.sigma_max
144
+ if self.trainer.loader.smearing.constant_dq == False:
145
+ print(f"The model was trained with linear resolution smearing (dq/q) in the range [{q_res_min}, {q_res_max}]")
146
+ elif self.trainer.loader.smearing.constant_dq == True:
147
+ print(f"The model was trained with constant resolution smearing in the range [{q_res_min}, {q_res_max}]")
148
+
149
+ additional_inputs = ["prior bounds"]
150
+ if self.trainer.train_with_q_input:
151
+ additional_inputs.append("q values")
152
+ if self.trainer.condition_on_q_resolutions:
153
+ additional_inputs.append("the resolution dq/q")
154
+ if additional_inputs:
155
+ inputs_str = ", ".join(additional_inputs)
156
+ print(f"The following quantities are additional inputs to the network: {inputs_str}.")
157
+
158
+ def preprocess_and_predict(self,
159
+ reflectivity_curve: np.ndarray,
160
+ q_values: np.ndarray = None,
161
+ prior_bounds: Union[np.ndarray, List[Tuple]] = None,
162
+ sigmas: np.ndarray = None,
163
+ q_resolution: Union[float, np.ndarray] = None,
164
+ ambient_sld: float = None,
165
+ clip_prediction: bool = True,
166
+ polish_prediction: bool = False,
167
+ polishing_method: str = 'trf',
168
+ polishing_kwargs_reflectivity: dict = None,
169
+ use_sigmas_for_polishing: bool = False,
170
+ polishing_max_steps: int = None,
171
+ fit_growth: bool = False,
172
+ max_d_change: float = 5.,
173
+ calc_pred_curve: bool = True,
174
+ calc_pred_sld_profile: bool = False,
175
+ calc_polished_sld_profile: bool = False,
176
+ sld_profile_padding_left: float = 0.2,
177
+ sld_profile_padding_right: float = 1.1,
178
+ kwargs_param_labels: dict = {},
179
+
180
+ truncate_index_left: int = None,
181
+ truncate_index_right: int = None,
182
+ enable_error_bars_filtering: bool = True,
183
+ filter_threshold=0.3,
184
+ filter_remove_singles=True,
185
+ filter_remove_consecutives=True,
186
+ filter_consecutive=3,
187
+ filter_q_start_trunc=0.1,
188
+ ):
+ """Preprocess experimental data (clean, truncate, filter, interpolate) and run prediction. This wrapper prepares inputs according to the model's Q generator, calls `predict(...)` on the interpolated/padded data, and (optionally) performs a polishing step on the original (pre-interpolation) data.
+
+ Args:
+ reflectivity_curve (Union[np.ndarray, Tensor]): 1D array of experimental reflectivity values.
+ q_values (Union[np.ndarray, Tensor]): 1D array of momentum transfer values for the reflectivity curve (in units of inverse angstroms).
+ prior_bounds (Union[np.ndarray, List[Tuple]]): Prior bounds for all parameters, shape ``(num_params, 2)`` as ``[(min, max), …]``.
+ sigmas (Union[np.ndarray, Tensor], optional): 1D array of experimental uncertainties (same length as `reflectivity_curve`). Used for error-bar filtering (if enabled) and for polishing (if requested).
+ q_resolution (Union[float, np.ndarray], optional): The q resolution for neutron reflectometry models. Can be either a float (dq/q) for linear resolution smearing (e.g. 0.05 meaning 5% resolution smearing) or an array of dq values for pointwise resolution smearing.
+ ambient_sld (float, optional): The SLD of the fronting (i.e. ambient) medium, for structures whose fronting medium is not air.
+ clip_prediction (bool, optional): If ``True``, the values of the predicted parameters are clipped so they do not fall outside the interval set by the prior bounds. Defaults to True.
+ polish_prediction (bool, optional): If ``True``, the neural network predictions are further polished using a least-squares fit. Defaults to False.
+ polishing_method (str): {'trf', 'dogbox', 'lm'} SciPy least-squares method used for polishing.
+ polishing_kwargs_reflectivity (dict, optional): Extra keyword arguments forwarded to the reflectivity computation during polishing.
+ use_sigmas_for_polishing (bool): If ``True``, weigh residuals by `sigmas` during polishing.
+ polishing_max_steps (int, optional): Maximum number of function evaluations for the SciPy optimizer.
+ fit_growth (bool, optional): (Deprecated) If ``True``, an additional parameter is introduced during the polishing to account for the change in the thickness of the upper layer during the in-situ measurement of the reflectivity curve (a linear growth is assumed). Defaults to False.
+ max_d_change (float): The maximum possible change in the thickness of the upper layer during the in-situ measurement, relevant when polish_prediction and fit_growth are True. Defaults to 5.
+ calc_pred_curve (bool, optional): Whether to calculate the curve corresponding to the predicted parameters. Defaults to True.
+ calc_pred_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the predicted parameters. Defaults to False.
+ calc_polished_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the polished parameters. Defaults to False.
+ sld_profile_padding_left (float, optional): Controls the amount of padding applied to the left side of the computed SLD profiles.
+ sld_profile_padding_right (float, optional): Controls the amount of padding applied to the right side of the computed SLD profiles.
+ kwargs_param_labels (dict, optional): Keyword arguments forwarded to ``get_param_labels``.
+ truncate_index_left (int, optional): Left index for truncating the input data; the data provided to the neural network is sliced to [truncate_index_left, truncate_index_right].
+ truncate_index_right (int, optional): Right index for truncating the input data (see ``truncate_index_left``).
+ enable_error_bars_filtering (bool, optional): If ``True``, data points with high error bars (above a threshold) are removed before constructing the input to the neural network (they are still used in the polishing step). Defaults to True.
+ filter_threshold (float, optional): The relative threshold (dR/R) for error-bar filtering. Defaults to 0.3.
+ filter_remove_singles (bool, optional): If ``True``, all isolated points exceeding the filtering threshold are eliminated. Defaults to True.
+ filter_remove_consecutives (bool, optional): If ``True``, when ``filter_consecutive`` consecutive points exceeding the filtering threshold are detected at a q position higher than ``filter_q_start_trunc``, all subsequent points in the curve are eliminated. Defaults to True.
+ filter_consecutive (int, optional): Number of consecutive above-threshold points that triggers the truncation described above. Defaults to 3.
+ filter_q_start_trunc (float, optional): The q value above which the consecutive-point truncation is applied. Defaults to 0.1.
+
+ Returns:
+ dict: dictionary containing the predictions
+ """
+
+ ## Preprocess the data for inference (remove negative intensities, truncate, filter out points with high error bars)
+ (q_values, reflectivity_curve, sigmas, q_resolution,
+ q_values_original, reflectivity_curve_original, sigmas_original, q_resolution_original) = self._preprocess_input_data(
+ reflectivity_curve=reflectivity_curve,
+ q_values=q_values,
+ sigmas=sigmas,
+ q_resolution=q_resolution,
+ truncate_index_left=truncate_index_left,
+ truncate_index_right=truncate_index_right,
+ enable_error_bars_filtering=enable_error_bars_filtering,
+ filter_threshold=filter_threshold,
+ filter_remove_singles=filter_remove_singles,
+ filter_remove_consecutives=filter_remove_consecutives,
+ filter_consecutive=filter_consecutive,
+ filter_q_start_trunc=filter_q_start_trunc,
+ )
+
+ ### Interpolate the experimental data if needed by the embedding network
+ interp_data = self.interpolate_data_to_model_q(
+ q_exp=q_values,
+ refl_exp=reflectivity_curve,
+ sigmas_exp=sigmas,
+ q_res_exp=q_resolution,
+ as_dict=True
+ )
+
+ q_model = interp_data["q_model"]
+ reflectivity_curve_interp = interp_data["reflectivity"]
+ sigmas_interp = interp_data.get("sigmas")
+ q_resolution_interp = interp_data.get("q_resolution")
+ key_padding_mask = interp_data.get("key_padding_mask")
+
+ ### Make the prediction
+ prediction_dict = self.predict(
+ reflectivity_curve=reflectivity_curve_interp,
+ q_values=q_model,
+ sigmas=sigmas_interp,
+ q_resolution=q_resolution_interp,
+ key_padding_mask=key_padding_mask,
+ prior_bounds=prior_bounds,
+ ambient_sld=ambient_sld,
+ clip_prediction=clip_prediction,
+ polish_prediction=False,  ### do the polishing outside the predict method, on the full data
+ supress_sld_amb_back_shift=True,  ### do not shift the slds back by the ambient yet
+ calc_pred_curve=calc_pred_curve,
+ calc_pred_sld_profile=calc_pred_sld_profile,
+ sld_profile_padding_left=sld_profile_padding_left,
+ sld_profile_padding_right=sld_profile_padding_right,
+ kwargs_param_labels=kwargs_param_labels,
+ )
+
+ ### Save interpolated data
+ prediction_dict['q_model'] = q_model
+ prediction_dict['reflectivity_curve_interp'] = reflectivity_curve_interp
+ if q_resolution_interp is not None:
+ prediction_dict['q_resolution_interp'] = q_resolution_interp
+ if sigmas_interp is not None:
+ prediction_dict['sigmas_interp'] = sigmas_interp
+ if key_padding_mask is not None:
+ prediction_dict['key_padding_mask'] = key_padding_mask
+
+ ### Shift the slds for nonzero ambient
+ prior_bounds = np.array(prior_bounds)
+ if ambient_sld:
+ sld_indices = self._shift_slds_by_ambient(prior_bounds, ambient_sld)
+
+ ### Perform polishing on the original data
+ if polish_prediction:
+ polishing_kwargs = polishing_kwargs_reflectivity or {}
+ polishing_kwargs.setdefault('dq', q_resolution_original)
+
+ polished_dict = self._polish_prediction(
+ q=q_values_original,
+ curve=reflectivity_curve_original,
+ predicted_params=prediction_dict['predicted_params_object'],
+ priors=prior_bounds,
+ ambient_sld_tensor=torch.atleast_2d(torch.as_tensor(ambient_sld)) if ambient_sld is not None else None,
+ calc_polished_sld_profile=calc_polished_sld_profile,
+ sld_x_axis=torch.from_numpy(prediction_dict['predicted_sld_xaxis']),
+ polishing_kwargs_reflectivity=polishing_kwargs,
+ error_bars=sigmas_original if use_sigmas_for_polishing else None,
+ polishing_method=polishing_method,
+ polishing_max_steps=polishing_max_steps,
+ fit_growth=fit_growth,
+ max_d_change=max_d_change,
+ )
+
+ prediction_dict.update(polished_dict)
+ if fit_growth and "polished_params_array" in prediction_dict:
+ prediction_dict["param_names"].append("max_d_change")
+
+ ### Shift back the slds for nonzero ambient
+ if ambient_sld:
+ self._restore_slds_after_ambient_shift(prediction_dict, sld_indices, ambient_sld)
+
+ return prediction_dict
+
+
+ def predict(self,
+ reflectivity_curve: Union[np.ndarray, Tensor],
+ q_values: Union[np.ndarray, Tensor] = None,
+ prior_bounds: Union[np.ndarray, List[Tuple]] = None,
+ sigmas: Union[np.ndarray, Tensor] = None,
+ key_padding_mask: Union[np.ndarray, Tensor] = None,
+ q_resolution: Union[float, np.ndarray] = None,
+ ambient_sld: float = None,
+ clip_prediction: bool = True,
+ polish_prediction: bool = False,
+ polishing_method: str = 'trf',
+ polishing_kwargs_reflectivity: dict = None,
+ polishing_max_steps: int = None,
+ fit_growth: bool = False,
+ max_d_change: float = 5.,
+ use_q_shift: bool = False,
+ calc_pred_curve: bool = True,
+ calc_pred_sld_profile: bool = False,
+ calc_polished_sld_profile: bool = False,
+ sld_profile_padding_left: float = 0.2,
+ sld_profile_padding_right: float = 1.1,
+ supress_sld_amb_back_shift: bool = False,
+ kwargs_param_labels: dict = {},
+ ):
+ """Predict the thin film parameters
+
+ Args:
+ reflectivity_curve (Union[np.ndarray, Tensor]): The reflectivity curve (already preprocessed, normalized and interpolated).
+ q_values (Union[np.ndarray, Tensor], optional): The momentum transfer (q) values for the reflectivity curve (in units of inverse angstroms).
+ prior_bounds (Union[np.ndarray, List[Tuple]]): The prior bounds for the predicted parameters.
+ sigmas (Union[np.ndarray, Tensor], optional): The error bars of the reflectivity curve, if available; they are used to weigh the residuals during the polishing step.
+ key_padding_mask (Union[np.ndarray, Tensor], optional): The key padding mask required for some embedding networks.
+ q_resolution (Union[float, np.ndarray], optional): The q resolution for neutron reflectometry models. Can be either a float dq/q for linear resolution smearing (e.g. 0.05 meaning 5% resolution smearing) or an array of dq values for pointwise resolution smearing.
+ ambient_sld (float, optional): The SLD of the fronting (i.e. ambient) medium, for structures whose fronting medium is not air.
+ clip_prediction (bool, optional): If ``True``, the values of the predicted parameters are clipped so they do not fall outside the interval set by the prior bounds. Defaults to True.
+ polish_prediction (bool, optional): If ``True``, the neural network predictions are further polished using a least-squares fit. Defaults to False.
+ polishing_method (str): {'trf', 'dogbox', 'lm'} SciPy least-squares method used for polishing.
+ polishing_kwargs_reflectivity (dict, optional): Extra keyword arguments forwarded to the reflectivity computation during polishing.
+ polishing_max_steps (int, optional): Maximum number of function evaluations for the SciPy optimizer.
+ fit_growth (bool, optional): (Deprecated) If ``True``, an additional parameter is introduced during the polishing to account for the change in the thickness of the upper layer during the in-situ measurement of the reflectivity curve (a linear growth is assumed). Defaults to False.
+ max_d_change (float): The maximum possible change in the thickness of the upper layer during the in-situ measurement, relevant when polish_prediction and fit_growth are True. Defaults to 5.
+ use_q_shift (bool, optional): (Deprecated) If ``True``, the prediction is performed for a batch of slightly shifted versions of the input curve and the best result is returned, which is meant to mitigate the influence of imperfect sample alignment, as introduced in Greco et al. (only for models with fixed q-discretization). Defaults to False.
+ calc_pred_curve (bool, optional): Whether to calculate the curve corresponding to the predicted parameters. Defaults to True.
+ calc_pred_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the predicted parameters. Defaults to False.
+ calc_polished_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the polished parameters. Defaults to False.
+ sld_profile_padding_left (float, optional): Controls the amount of padding applied to the left side of the computed SLD profiles.
+ sld_profile_padding_right (float, optional): Controls the amount of padding applied to the right side of the computed SLD profiles.
+ supress_sld_amb_back_shift (bool, optional): If ``True``, the predicted SLDs are not shifted back by the ambient SLD (used internally by ``preprocess_and_predict``). Defaults to False.
+ kwargs_param_labels (dict, optional): Keyword arguments forwarded to ``get_param_labels``.
+
+ Returns:
+ dict: dictionary containing the predictions
+ """
+
+ scaled_curve = self._scale_curve(reflectivity_curve)
+ if prior_bounds is None:
+ raise ValueError('Prior bounds were not provided')
+ prior_bounds = np.array(prior_bounds)
+
+ if ambient_sld:
+ sld_indices = self._shift_slds_by_ambient(prior_bounds, ambient_sld)
+
+ scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)
+
+ if isinstance(self.trainer.loader.q_generator, ConstantQ):
+ q_values = self.trainer.loader.q_generator.q
+ else:
+ if q_values is None:
+ raise ValueError('The q values were not provided')
+ q_values = torch.atleast_2d(to_t(q_values)).to(scaled_curve)
+
+ scaled_q_values = self.trainer.loader.q_generator.scale_q(q_values).to(torch.float32) if self.trainer.train_with_q_input else None
+
+ if q_resolution is None and self.trainer.loader.smearing is not None:
+ raise ValueError('The q resolution must be provided for NR models')
+
+ if q_resolution is not None:
+ q_resolution_tensor = torch.atleast_2d(torch.as_tensor(q_resolution)).to(scaled_curve)
+ if isinstance(q_resolution, float):
+ unscaled_q_resolutions = q_resolution_tensor
+ else:
+ unscaled_q_resolutions = (q_resolution_tensor / q_values).nanmean(dim=-1, keepdim=True)  ## when q_values is padded with 0s, there will be nan at the padded positions
+ scaled_q_resolutions = self.trainer.loader.smearing.scale_resolutions(unscaled_q_resolutions) if self.trainer.condition_on_q_resolutions else None
+ scaled_conditioning_params = scaled_q_resolutions
+ if polishing_kwargs_reflectivity is None:
+ polishing_kwargs_reflectivity = {'dq': q_resolution}
+ else:
+ q_resolution_tensor = None
+ scaled_conditioning_params = None
+
+ if key_padding_mask is not None:
+ key_padding_mask = torch.as_tensor(key_padding_mask, device=self.device)
+ key_padding_mask = key_padding_mask.unsqueeze(0) if key_padding_mask.dim() == 1 else key_padding_mask
+
+ if use_q_shift and not self.trainer.train_with_q_input:
+ predicted_params = self._qshift_prediction(reflectivity_curve, scaled_prior_bounds, num=1024, dq_coef=1.)
+ else:
+ with torch.no_grad():
+ self.trainer.model.eval()
+
+ scaled_predicted_params = self.trainer.model(
+ curves=scaled_curve,
+ bounds=scaled_prior_bounds,
+ q_values=scaled_q_values,
+ conditioning_params=scaled_conditioning_params,
+ key_padding_mask=key_padding_mask,
+ unscaled_q_values=q_values,
+ )
+
+ predicted_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
+
+ if clip_prediction:
+ predicted_params = self.trainer.loader.prior_sampler.clamp_params(predicted_params)
+
+ prediction_dict = {
+ "predicted_params_object": predicted_params,
+ "predicted_params_array": predicted_params.parameters.squeeze().cpu().numpy(),
+ "param_names": self.trainer.loader.prior_sampler.param_model.get_param_labels(**kwargs_param_labels)
+ }
+
+ key_padding_mask = None if key_padding_mask is None else key_padding_mask.squeeze().cpu().numpy()
+
+ if calc_pred_curve:
+ predicted_curve = predicted_params.reflectivity(q=q_values, dq=q_resolution_tensor).squeeze().cpu().numpy()
+ prediction_dict["predicted_curve"] = predicted_curve if key_padding_mask is None else predicted_curve[key_padding_mask]
+
+ ambient_sld_tensor = torch.atleast_2d(torch.as_tensor(ambient_sld, device=self.device)) if ambient_sld is not None else None
+ if calc_pred_sld_profile:
+ predicted_sld_xaxis, predicted_sld_profile, _ = get_density_profiles(
+ predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds + (ambient_sld_tensor or 0), ambient_sld_tensor,
+ num=1024, padding_left=sld_profile_padding_left, padding_right=sld_profile_padding_right,
+ )
+ prediction_dict['predicted_sld_profile'] = predicted_sld_profile.squeeze().cpu().numpy()
+ prediction_dict['predicted_sld_xaxis'] = predicted_sld_xaxis.squeeze().cpu().numpy()
+ else:
+ predicted_sld_xaxis = None
+
+ refl_curve_polish = reflectivity_curve if key_padding_mask is None else reflectivity_curve[key_padding_mask]
+ q_plot = q_values.squeeze().cpu().numpy()
+ q_polish = q_plot if key_padding_mask is None else q_plot[key_padding_mask]
+ prediction_dict['q_plot_pred'] = q_polish
+
+ if polish_prediction:
+ if ambient_sld_tensor is not None:
+ ambient_sld_tensor = ambient_sld_tensor.cpu()
+
+ polished_dict = self._polish_prediction(
+ q=q_polish,
+ curve=refl_curve_polish,
+ predicted_params=predicted_params,
+ priors=np.array(prior_bounds),
+ error_bars=sigmas,
+ fit_growth=fit_growth,
+ max_d_change=max_d_change,
+ calc_polished_curve=calc_pred_curve,
+ calc_polished_sld_profile=calc_polished_sld_profile,
+ ambient_sld_tensor=ambient_sld_tensor,
+ sld_x_axis=predicted_sld_xaxis,
+ polishing_method=polishing_method,
+ polishing_max_steps=polishing_max_steps,
+ polishing_kwargs_reflectivity=polishing_kwargs_reflectivity,
+ )
+ prediction_dict.update(polished_dict)
+
+ if fit_growth and "polished_params_array" in prediction_dict:
+ prediction_dict["param_names"].append("max_d_change")
+
+ if ambient_sld and not supress_sld_amb_back_shift:  # Note: the SLD shift will only be reflected in predicted_params_array but not in predicted_params_object; supress_sld_amb_back_shift is required for the 'preprocess_and_predict' method
+ self._restore_slds_after_ambient_shift(prediction_dict, sld_indices, ambient_sld)
+
+ return prediction_dict
+
+ def _polish_prediction(self,
+ q: np.ndarray,
+ curve: np.ndarray,
+ predicted_params: BasicParams,
+ priors: np.ndarray,
+ sld_x_axis,
+ ambient_sld_tensor: Tensor = None,
+ fit_growth: bool = False,
+ max_d_change: float = 5.,
+ calc_polished_curve: bool = True,
+ calc_polished_sld_profile: bool = False,
+ error_bars: np.ndarray = None,
+ polishing_method: str = 'trf',
+ polishing_max_steps: int = None,
+ polishing_kwargs_reflectivity: dict = None,
+ ) -> dict:
+ params = predicted_params.parameters.squeeze().cpu().numpy()
+
+ polished_params_dict = {}
+ polishing_kwargs_reflectivity = polishing_kwargs_reflectivity or {}
+ polished_params_err = None  # refl_fit provides errors; the growth fit does not
+
+ try:
+ if fit_growth:
+ polished_params_arr, curve_polished = get_fit_with_growth(
+ q=q,
+ curve=curve,
+ init_params=params,
+ bounds=priors.T,
+ max_d_change=max_d_change,
+ )
+ polished_params = BasicParams(
+ torch.from_numpy(polished_params_arr[:-1][None]),
+ torch.from_numpy(priors.T[0][None]),
+ torch.from_numpy(priors.T[1][None]),
+ self.trainer.loader.prior_sampler.max_num_layers,
+ self.trainer.loader.prior_sampler.param_model
+ )
+ else:
+ polished_params_arr, polished_params_err, curve_polished = refl_fit(
+ q=q,
+ curve=curve,
+ init_params=params,
+ bounds=priors.T,
+ prior_sampler=self.trainer.loader.prior_sampler,
+ error_bars=error_bars,
+ method=polishing_method,
+ polishing_max_steps=polishing_max_steps,
+ reflectivity_kwargs=polishing_kwargs_reflectivity,
+ )
+ polished_params = BasicParams(
+ torch.from_numpy(polished_params_arr[None]),
+ torch.from_numpy(priors.T[0][None]),
+ torch.from_numpy(priors.T[1][None]),
+ self.trainer.loader.prior_sampler.max_num_layers,
+ self.trainer.loader.prior_sampler.param_model
+ )
+ except Exception:  # fall back to the unpolished prediction if the fit fails
+ polished_params = predicted_params
+ polished_params_arr = get_prediction_array(polished_params)
+ curve_polished = np.zeros_like(q)
+ polished_params_err = None
+
+ polished_params_dict['polished_params_array'] = polished_params_arr
+
+ polished_params_dict['polished_params_error_array'] = (
+ np.array(polished_params_err)
+ if polished_params_err is not None
+ else np.full_like(polished_params_arr, np.nan, dtype=np.float64)
+ )
+ if calc_polished_curve:
+ polished_params_dict['polished_curve'] = curve_polished
+
+ if ambient_sld_tensor is not None:
+ ambient_sld_tensor = ambient_sld_tensor.to(polished_params.slds.device)
+
+ if calc_polished_sld_profile:
+ _, sld_profile_polished, _ = get_density_profiles(
+ polished_params.thicknesses, polished_params.roughnesses, polished_params.slds + (ambient_sld_tensor or 0), ambient_sld_tensor,
+ z_axis=sld_x_axis.to(polished_params.slds.device),
+ )
+ polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().cpu().numpy()
+
+ return polished_params_dict
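+ # Conceptually, the polishing performed by `refl_fit` resembles this simplified
+ # sketch (editor's illustration; `simulate_reflectivity` is a hypothetical
+ # stand-in for the package's reflectivity simulator):
+ #
+ #     from scipy.optimize import least_squares
+ #     weights = error_bars if error_bars is not None else 1.0
+ #     res = least_squares(
+ #         lambda p: (simulate_reflectivity(q, p) - curve) / weights,
+ #         x0=params,
+ #         bounds=(priors.T[0], priors.T[1]),  # bounds apply to 'trf'/'dogbox'
+ #         method='trf',
+ #     )
+ #     polished = res.x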
+
+ def _scale_curve(self, curve: Union[np.ndarray, Tensor]):
+ if not isinstance(curve, Tensor):
+ curve = torch.from_numpy(curve).float()
+ curve = curve.unsqueeze(0).to(self.device)
+ scaled_curve = self.trainer.loader.curves_scaler.scale(curve)
+ return scaled_curve
+
+ def _scale_prior_bounds(self, prior_bounds: List[Tuple]):
+ try:
+ prior_bounds = torch.tensor(prior_bounds)
+ prior_bounds = prior_bounds.to(self.device).T
+ min_bounds, max_bounds = prior_bounds[:, None]
+
+ scaled_bounds = torch.cat([
+ self.trainer.loader.prior_sampler.scale_bounds(min_bounds),
+ self.trainer.loader.prior_sampler.scale_bounds(max_bounds)
+ ], -1)
+
+ return scaled_bounds.float()
+
+ except RuntimeError as e:
+ expected_param_dim = self.trainer.loader.prior_sampler.param_dim
+ actual_param_dim = prior_bounds.shape[1] if prior_bounds.ndim == 2 else len(prior_bounds)
+
+ msg = (
+ f"\n **Parameter dimension mismatch during inference!**\n"
+ f"- Model expects **{expected_param_dim}** parameters.\n"
+ f"- You provided **{actual_param_dim}** prior bounds.\n\n"
+ f"💡This often occurs when:\n"
+ f"- The model was trained with additional nuisance parameters like `r_scale`, `q_shift`, or `log10_background`,\n"
+ f"  but they were not included in the `prior_bounds` passed to `.predict()`.\n"
+ f"- The number of layers or parameterization type differs from the one used during training.\n\n"
+ f" Check the configuration or the summary of expected parameters."
+ )
+ raise ValueError(msg) from e
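+ # Example of the mismatch described above (editor's illustration with assumed
+ # numbers): a model whose param_dim is 4 because it was trained with an extra
+ # `log10_background` nuisance parameter needs four (min, max) pairs, not three:
+ #
+ #     model.predict(curve, q, prior_bounds=[(0., 300.), (0., 20.), (0., 30.)])                # raises ValueError
+ #     model.predict(curve, q, prior_bounds=[(0., 300.), (0., 20.), (0., 30.), (-10., -4.)])   # matches param_dim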
+
+ def _shift_slds_by_ambient(self, prior_bounds: np.ndarray, ambient_sld: float):
+ n_layers = self.trainer.loader.prior_sampler.max_num_layers
+ sld_indices = slice(2*n_layers+1, 3*n_layers+2)
+ prior_bounds[sld_indices, ...] -= ambient_sld
+
+ training_min_bounds = self.trainer.loader.prior_sampler.min_bounds.squeeze().cpu().numpy()
+ training_max_bounds = self.trainer.loader.prior_sampler.max_bounds.squeeze().cpu().numpy()
+ lower_bound_check = (prior_bounds[sld_indices, 0] >= training_min_bounds[sld_indices]).all()
+ upper_bound_check = (prior_bounds[sld_indices, 1] <= training_max_bounds[sld_indices]).all()
+ assert lower_bound_check and upper_bound_check, "Shifting the layer SLDs by the ambient SLD exceeded the training ranges."
+
+ return sld_indices
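+ # The slice above assumes the standard parameter layout (editor's sketch, for
+ # n = max_num_layers): n thicknesses, n+1 roughnesses, n+1 SLDs, so the SLDs
+ # occupy indices 2n+1 .. 3n+1, i.e. slice(2*n+1, 3*n+2). For n = 2:
+ #     params      = [d1, d2, r1, r2, r_sub, sld1, sld2, sld_sub]
+ #     sld_indices = slice(5, 8)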
+
+ def _restore_slds_after_ambient_shift(self, prediction_dict, sld_indices, ambient_sld):
+ prediction_dict["predicted_params_array"][sld_indices] += ambient_sld
+ if "polished_params_array" in prediction_dict:
+ prediction_dict["polished_params_array"][sld_indices] += ambient_sld
+
+ def _get_likelihood(self, q, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
+ return LogLikelihood(
+ q, curve, self.trainer.loader.prior_sampler, curve * rel_err + abs_err
+ )
+
+ def get_param_labels(self, **kwargs):
+ return self.trainer.loader.prior_sampler.param_model.get_param_labels(**kwargs)
+
+ @staticmethod
+ def _preprocess_input_data(
+ reflectivity_curve,
+ q_values,
+ sigmas=None,
+ q_resolution=None,
+ truncate_index_left=None,
+ truncate_index_right=None,
+ enable_error_bars_filtering=True,
+ filter_threshold=0.3,
+ filter_remove_singles=True,
+ filter_remove_consecutives=True,
+ filter_consecutive=3,
+ filter_q_start_trunc=0.1):
+
+ # Save originals for polishing
+ reflectivity_curve_original = reflectivity_curve.copy()
+ q_values_original = q_values.copy() if q_values is not None else None
+ q_resolution_original = q_resolution.copy() if isinstance(q_resolution, np.ndarray) else q_resolution
+ sigmas_original = sigmas.copy() if sigmas is not None else None
+
+ # Remove points with non-positive intensities
+ positive_mask = reflectivity_curve > 0.0
+ reflectivity_curve = reflectivity_curve[positive_mask]
+ q_values = q_values[positive_mask]
+ if sigmas is not None:
+ sigmas = sigmas[positive_mask]
+ if isinstance(q_resolution, np.ndarray):
+ q_resolution = q_resolution[positive_mask]
+
+ # Truncate arrays
+ if truncate_index_left is not None or truncate_index_right is not None:
+ slice_obj = slice(truncate_index_left, truncate_index_right)
+ reflectivity_curve = reflectivity_curve[slice_obj]
+ q_values = q_values[slice_obj]
+ if sigmas is not None:
+ sigmas = sigmas[slice_obj]
+ if isinstance(q_resolution, np.ndarray):
+ q_resolution = q_resolution[slice_obj]
+
+ # Filter high-error points
+ if enable_error_bars_filtering and sigmas is not None:
+ valid_mask = get_filtering_mask(
+ q_values,
+ reflectivity_curve,
+ sigmas,
+ threshold=filter_threshold,
+ consecutive=filter_consecutive,
+ remove_singles=filter_remove_singles,
+ remove_consecutives=filter_remove_consecutives,
+ q_start_trunc=filter_q_start_trunc
+ )
+ reflectivity_curve = reflectivity_curve[valid_mask]
+ q_values = q_values[valid_mask]
+ sigmas = sigmas[valid_mask]
+ if isinstance(q_resolution, np.ndarray):
+ q_resolution = q_resolution[valid_mask]
+
+ return (q_values, reflectivity_curve, sigmas, q_resolution,
+ q_values_original, reflectivity_curve_original,
+ sigmas_original, q_resolution_original)
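+ # Sketch of the relative-error criterion behind `get_filtering_mask`
+ # (editor's illustration of the dR/R threshold documented above):
+ #     too_noisy = sigmas / reflectivity_curve > filter_threshold   # e.g. dR/R > 0.3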
+
+ def interpolate_data_to_model_q(
+ self,
+ q_exp,
+ refl_exp,
+ sigmas_exp=None,
+ q_res_exp=None,
+ as_dict=False
+ ):
+ q_generator = self.trainer.loader.q_generator
+
+ def _pad(arr, pad_to, value=0.0):
+ if arr is None:
+ return None
+ return np.pad(arr, (0, pad_to - len(arr)), constant_values=value)
+
+ def _interp_or_keep(q_model, q_exp, arr):
+ """Interpolate arrays, keep floats or None unchanged."""
+ if arr is None:
+ return None
+ return np.interp(q_model, q_exp, arr) if isinstance(arr, np.ndarray) else arr
+
+ def _pad_or_keep(arr, max_n):
+ """Pad arrays, keep floats or None unchanged."""
+ if arr is None:
+ return None
+ return _pad(arr, max_n, 0.0) if isinstance(arr, np.ndarray) else arr
+
+ def _prepare_return(q, refl, sigmas=None, q_res=None, mask=None, as_dict=False):
+ if as_dict:
+ result = {"q_model": q, "reflectivity": refl}
+ if sigmas is not None: result["sigmas"] = sigmas
+ if q_res is not None: result["q_resolution"] = q_res
+ if mask is not None: result["key_padding_mask"] = mask
+ return result
+ result = [q, refl]
+ if sigmas is not None: result.append(sigmas)
+ if q_res is not None: result.append(q_res)
+ if mask is not None: result.append(mask)
+ return tuple(result)
+
+ # ConstantQ
+ if isinstance(q_generator, ConstantQ):
+ q_model = q_generator.q.cpu().numpy()
+ refl_out = interp_reflectivity(q_model, q_exp, refl_exp)
+ sigmas_out = _interp_or_keep(q_model, q_exp, sigmas_exp)
+ q_res_out = _interp_or_keep(q_model, q_exp, q_res_exp)
+ return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, None, as_dict)
+
+ # VariableQ
+ elif isinstance(q_generator, VariableQ):
+ if q_generator.n_q_range[0] == q_generator.n_q_range[1]:
+ n_q_model = q_generator.n_q_range[0]
+ q_min = max(q_exp.min(), q_generator.q_min_range[0])
+ q_max = min(q_exp.max(), q_generator.q_max_range[1])
+ if q_generator.mode == 'logspace':
+ q_model = torch.logspace(start=torch.log10(torch.tensor(q_min, device=self.device)),
+ end=torch.log10(torch.tensor(q_max, device=self.device)),
+ steps=n_q_model, device=self.device).to('cpu')
+ logspace = True
+ else:
+ q_model = np.linspace(q_min, q_max, n_q_model)
+ logspace = False
+ else:
+ return _prepare_return(q_exp, refl_exp, sigmas_exp, q_res_exp, None, as_dict)
+
+ refl_out = interp_reflectivity(q_model, q_exp, refl_exp, logspace=logspace)
+ sigmas_out = _interp_or_keep(q_model, q_exp, sigmas_exp)
+ q_res_out = _interp_or_keep(q_model, q_exp, q_res_exp)
+ return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, None, as_dict)
+
+ # MaskedVariableQ
+ elif isinstance(q_generator, MaskedVariableQ):
+ min_n, max_n = q_generator.n_q_range
+ n_exp = len(q_exp)
+
+ if min_n <= n_exp <= max_n:
+ # Pad only
+ q_model = _pad(q_exp, max_n, 0.0)
+ refl_out = _pad(refl_exp, max_n, 0.0)
+ sigmas_out = _pad_or_keep(sigmas_exp, max_n)
+ q_res_out = _pad_or_keep(q_res_exp, max_n)
+ key_padding_mask = np.zeros(max_n, dtype=bool)
+ key_padding_mask[:n_exp] = True
+
+ else:
+ # Interpolate + pad
+ n_interp = min(max(n_exp, min_n), max_n)
+ q_min = max(q_exp.min(), q_generator.q_min_range[0])
+ q_max = min(q_exp.max(), q_generator.q_max_range[1])
+ q_interp = np.linspace(q_min, q_max, n_interp)
+
+ refl_interp = interp_reflectivity(q_interp, q_exp, refl_exp)
+ sigmas_interp = _interp_or_keep(q_interp, q_exp, sigmas_exp)
+ q_res_interp = _interp_or_keep(q_interp, q_exp, q_res_exp)
+
+ q_model = _pad(q_interp, max_n, 0.0)
+ refl_out = _pad(refl_interp, max_n, 0.0)
+ sigmas_out = _pad_or_keep(sigmas_interp, max_n)
+ q_res_out = _pad_or_keep(q_res_interp, max_n)
+ key_padding_mask = np.zeros(max_n, dtype=bool)
+ key_padding_mask[:n_interp] = True
+
+ return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, key_padding_mask, as_dict)
+
+ else:
+ raise TypeError(f"Unsupported QGenerator type: {type(q_generator)}")
+
+ def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
+ assert isinstance(self.trainer.loader.q_generator, ConstantQ), "Prediction with q shifts available only for models with fixed discretization"
+ q = self.trainer.loader.q_generator.q.squeeze().float()
+ dq_max = (q[1] - q[0]) * dq_coef
+ q_shifts = torch.linspace(-dq_max, dq_max, num).to(q)
+
+ curve = to_t(curve).to(scaled_bounds)
+ shifted_curves = _qshift_interp(q.squeeze(), curve, q_shifts)
+
+ assert shifted_curves.shape == (num, q.shape[0])
+
+ scaled_curves = self.trainer.loader.curves_scaler.scale(shifted_curves)
+ scaled_prior_bounds = torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)
+
+ with torch.no_grad():
+ self.trainer.model.eval()
+ scaled_predicted_params = self.trainer.model(scaled_curves, scaled_prior_bounds)
+ restored_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
+
+ best_param = get_best_mse_param(
+ restored_params,
+ self._get_likelihood(q=self.trainer.loader.q_generator.q, curve=curve),
+ )
+ return best_param
+
+
+
+ EasyInferenceModel = InferenceModel
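+ # Backward-compatible alias (assumed intent): the previous class name remains importable.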
+
+ def get_prediction_array(params: BasicParams) -> np.ndarray:
+ predict_arr = torch.cat([
+ params.thicknesses.squeeze(),
+ params.roughnesses.squeeze(),
+ params.slds.squeeze(),
+ ]).cpu().numpy()
+
+ return predict_arr
+
+
+ def _qshift_interp(q, r, q_shifts):
+ qs = q[None] + q_shifts[:, None]
+ eps = torch.finfo(r.dtype).eps
+ ind = torch.searchsorted(q[None].expand_as(qs).contiguous(), qs.contiguous())
+ ind = torch.clamp(ind - 1, 0, q.shape[0] - 2)
+ slopes = (r[1:] - r[:-1]) / (eps + (q[1:] - q[:-1]))
  return r[ind] + slopes[ind] * (qs - q[ind])
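+ # Tiny worked example of the linear shift-interpolation above (editor's sketch):
+ #     q = torch.tensor([0.1, 0.2, 0.3]); r = torch.tensor([1.0, 0.5, 0.25])
+ #     _qshift_interp(q, r, torch.tensor([0.0, 0.05]))
+ #     # row 0 reproduces r exactly; row 1 evaluates r at q + 0.05,
+ #     # e.g. 0.75 at q = 0.15 (halfway between 1.0 and 0.5)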