reflectorch-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of reflectorch might be problematic.

Files changed (83)
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0
reflectorch/inference/inference_model.py
@@ -0,0 +1,734 @@
+ import asyncio
+ import logging
+ from pathlib import Path
+ import time
+
+ import numpy as np
+ import torch
+ from torch import Tensor
+ from typing import List, Tuple, Union
+ import ipywidgets as widgets
+ from IPython.display import display
+ from huggingface_hub import hf_hub_download
+
+ from reflectorch.data_generation.priors import Params, BasicParams, ExpUniformSubPriorSampler, UniformSubPriorParams
+ from reflectorch.data_generation.q_generator import ConstantQ, VariableQ
+ from reflectorch.data_generation.utils import get_density_profiles, get_param_labels
+ from reflectorch.paths import CONFIG_DIR, ROOT_DIR, SAVED_MODELS_DIR
+ from reflectorch.runs.utils import (
+     get_trainer_by_name, train_from_config
+ )
+ from reflectorch.runs.config import load_config
+ from reflectorch.ml.trainers import PointEstimatorTrainer
+ from reflectorch.data_generation.likelihoods import LogLikelihood
+
+ from reflectorch.inference.preprocess_exp import StandardPreprocessing
+ from reflectorch.inference.scipy_fitter import standard_refl_fit, get_fit_with_growth
+ from reflectorch.inference.sampler_solution import simple_sampler_solution, get_best_mse_param
+ from reflectorch.inference.record_time import print_time
+ from reflectorch.utils import to_t
+
+ class EasyInferenceModel(object):
+     """Facilitates the inference process using pretrained models
+
+     Args:
+         config_name (str, optional): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension). Defaults to None.
+         model_name (str, optional): the name of the file containing the weights of the model (either with or without the '.pt' extension), only required if different from `'model_' + config_name + '.pt'`. Defaults to None.
+         root_dir (str, optional): path to the root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR). Defaults to None.
+         repo_id (str, optional): the id of the Huggingface repository from which the configuration files and model weights are downloaded automatically if not found locally (in the 'configs' and 'saved_models' subdirectories of the root directory). Defaults to 'valentinsingularity/reflectivity'.
+         trainer (PointEstimatorTrainer, optional): if provided, this trainer instance is used directly instead of being initialized from the configuration file. Defaults to None.
+         device (str, optional): the PyTorch device ('cuda' or 'cpu'). Defaults to 'cuda'.
+     """
+     def __init__(self, config_name: str = None, model_name: str = None, root_dir: str = None, repo_id: str = 'valentinsingularity/reflectivity',
+                  trainer: PointEstimatorTrainer = None, device: str = 'cuda'):
+         self.config_name = config_name
+         self.model_name = model_name
+         self.root_dir = root_dir
+         self.repo_id = repo_id
+         self.trainer = trainer
+         self.device = device
+
+         if trainer is None and self.config_name is not None:
+             self.load_model(self.config_name, self.model_name, self.root_dir)
+
+         self.prediction_result = None
+
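+     # Usage sketch (illustrative, not part of the original file): instantiating with a
+     # config name triggers load_model(), which downloads the config and weights from the
+     # Huggingface repo when they are not found locally. The config name is hypothetical.
+     #
+     #   model = EasyInferenceModel(config_name='my_config', device='cpu')
+     #   # equivalent to: model = EasyInferenceModel(); model.load_model('my_config', None, None)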
+     def load_model(self, config_name: str, model_name: str, root_dir: str) -> None:
+         """Loads a model for inference
+
+         Args:
+             config_name (str): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension).
+             model_name (str): the name of the file containing the weights of the model (either with or without the '.pt' extension), only required if different from `'model_' + config_name + '.pt'`.
+             root_dir (str): path to the root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR).
+         """
+         if self.config_name == config_name and self.trainer is not None:
+             return
+
+         if not config_name.endswith('.yaml'):
+             config_name_no_extension = config_name
+             self.config_name = config_name_no_extension + '.yaml'
+         else:
+             config_name_no_extension = config_name[:-5]
+             self.config_name = config_name
+
+         self.config_dir = Path(root_dir) / 'configs' if root_dir else CONFIG_DIR
+         self.model_name = model_name or 'model_' + config_name_no_extension + '.pt'
+         if not self.model_name.endswith('.pt'):
+             self.model_name += '.pt'
+         self.model_dir = Path(root_dir) / 'saved_models' if root_dir else SAVED_MODELS_DIR
+
+         config_path = Path(self.config_dir) / self.config_name
+         if config_path.exists():
+             print(f"Configuration file `{config_path}` found locally.")
+         else:
+             print(f"Configuration file `{config_path}` not found locally.")
+             if self.repo_id is None:
+                 raise ValueError("repo_id must be provided to download files from Huggingface.")
+             print("Downloading from Huggingface...")
+             hf_hub_download(repo_id=self.repo_id, subfolder='configs', filename=self.config_name, local_dir=config_path.parents[1])
+
+         model_path = Path(self.model_dir) / self.model_name
+         if model_path.exists():
+             print(f"Weights file `{model_path}` found locally.")
+         else:
+             print(f"Weights file `{model_path}` not found locally.")
+             if self.repo_id is None:
+                 raise ValueError("repo_id must be provided to download files from Huggingface.")
+             print("Downloading from Huggingface...")
+             hf_hub_download(repo_id=self.repo_id, subfolder='saved_models', filename=self.model_name, local_dir=model_path.parents[1])
+
+         self.trainer = get_trainer_by_name(config_name=config_name, config_dir=self.config_dir, model_path=model_path, load_weights=True, inference_device=self.device)
+         self.trainer.model.eval()
+
+         print(f'The model corresponds to a `{self.trainer.loader.prior_sampler.param_model.NAME}` parameterization with {self.trainer.loader.prior_sampler.max_num_layers} layers ({self.trainer.loader.prior_sampler.param_dim} predicted parameters)')
+         print("Parameter types and total ranges:")
+         for param, range_ in self.trainer.loader.prior_sampler.param_ranges.items():
+             print(f"- {param}: {range_}")
+         print("Allowed widths of the prior bound intervals (max-min):")
+         for param, range_ in self.trainer.loader.prior_sampler.bound_width_ranges.items():
+             print(f"- {param}: {range_}")
+
+         if isinstance(self.trainer.loader.q_generator, ConstantQ):
+             q_min = self.trainer.loader.q_generator.q[0].item()
+             q_max = self.trainer.loader.q_generator.q[-1].item()
+             n_q = self.trainer.loader.q_generator.q.shape[0]
+             print(f'The model was trained on curves discretized at {n_q} uniform points between q_min={q_min} and q_max={q_max}')
+         elif isinstance(self.trainer.loader.q_generator, VariableQ):
+             q_min_range = self.trainer.loader.q_generator.q_min_range
+             q_max_range = self.trainer.loader.q_generator.q_max_range
+             n_q_range = self.trainer.loader.q_generator.n_q_range
+             print(f'The model was trained on curves discretized at between {n_q_range[0]} and {n_q_range[1]} uniform points, with q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
+
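+     # File-resolution sketch (names are hypothetical): for config_name='my_config', the
+     # loader looks for 'configs/my_config.yaml' and 'saved_models/model_my_config.pt'
+     # under root_dir (or under the package ROOT_DIR when root_dir is None), and falls
+     # back to downloading both files from `repo_id` on Huggingface.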
+     def predict(self, reflectivity_curve: Union[np.ndarray, Tensor],
+                 q_values: Union[np.ndarray, Tensor] = None,
+                 prior_bounds: Union[np.ndarray, List[Tuple]] = None,
+                 clip_prediction: bool = False,
+                 polish_prediction: bool = False,
+                 fit_growth: bool = False,
+                 max_d_change: float = 5.,
+                 use_q_shift: bool = False,
+                 calc_pred_curve: bool = True,
+                 calc_pred_sld_profile: bool = False,
+                 ):
+         """Predict the thin film parameters
+
+         Args:
+             reflectivity_curve (Union[np.ndarray, Tensor]): the reflectivity curve (already preprocessed, normalized and interpolated).
+             q_values (Union[np.ndarray, Tensor], optional): the momentum transfer (q) values for the reflectivity curve (in units of inverse angstroms).
+             prior_bounds (Union[np.ndarray, List[Tuple]], optional): the prior bounds for the thin film parameters.
+             clip_prediction (bool, optional): if ``True``, the values of the predicted parameters are clipped to the interval set by the prior bounds. Defaults to False.
+             polish_prediction (bool, optional): if ``True``, the neural network predictions are further polished using a simple least mean squares (LMS) fit. Only for the standard box-model parameterization. Defaults to False.
+             fit_growth (bool, optional): if ``True``, an additional parameter is introduced during the LMS polishing to account for the change in the thickness of the upper layer during the in-situ measurement of the reflectivity curve (a linear growth is assumed). Defaults to False.
+             max_d_change (float): the maximum possible change in the thickness of the upper layer during the in-situ measurement, relevant when polish_prediction and fit_growth are True. Defaults to 5.
+             use_q_shift (bool, optional): if ``True``, the prediction is performed for a batch of slightly shifted versions of the input curve and the best result is returned, which is meant to mitigate the influence of imperfect sample alignment, as introduced in Greco et al. (only for models with fixed q-discretization). Defaults to False.
+             calc_pred_curve (bool, optional): whether to calculate the curve corresponding to the predicted parameters. Defaults to True.
+             calc_pred_sld_profile (bool, optional): whether to calculate the SLD profile corresponding to the predicted parameters. Defaults to False.
+
+         Returns:
+             dict: dictionary containing the predictions
+         """
+         scaled_curve = self._scale_curve(reflectivity_curve)
+         prior_bounds = np.array(prior_bounds)
+         scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)
+
+         if not self.trainer.train_with_q_input:
+             q_values = self.trainer.loader.q_generator.q
+         else:
+             q_values = torch.atleast_2d(to_t(q_values)).to(scaled_curve)
+
+         if use_q_shift and not self.trainer.train_with_q_input:
+             predicted_params = self._qshift_prediction(reflectivity_curve, scaled_prior_bounds, num=1024, dq_coef=1.)
+         else:
+             with torch.no_grad():
+                 self.trainer.model.eval()
+                 if self.trainer.train_with_q_input:
+                     scaled_q = self.trainer.loader.q_generator.scale_q(q_values).float()
+                     scaled_predicted_params = self.trainer.model(scaled_curve, scaled_prior_bounds, scaled_q)
+                 else:
+                     scaled_predicted_params = self.trainer.model(scaled_curve, scaled_prior_bounds)
+
+             predicted_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
+
+         if clip_prediction:
+             predicted_params = self.trainer.loader.prior_sampler.clamp_params(predicted_params)
+
+         prediction_dict = {
+             "predicted_params_object": predicted_params,
+             "predicted_params_array": predicted_params.parameters.squeeze().cpu().numpy(),
+             "param_names": self.trainer.loader.prior_sampler.param_model.get_param_labels(),
+         }
+
+         if calc_pred_curve:
+             predicted_curve = predicted_params.reflectivity(q_values).squeeze().cpu().numpy()
+             prediction_dict["predicted_curve"] = predicted_curve
+
+         if calc_pred_sld_profile:
+             predicted_sld_xaxis, predicted_sld_profile, _ = get_density_profiles(
+                 predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, num=1024,
+             )
+             prediction_dict['predicted_sld_profile'] = predicted_sld_profile.squeeze().cpu().numpy()
+             prediction_dict['predicted_sld_xaxis'] = predicted_sld_xaxis.squeeze().cpu().numpy()
+         else:
+             predicted_sld_xaxis = None
+
+         if polish_prediction:  # only for the standard box-model parameterization
+             polished_dict = self._polish_prediction(q=q_values.squeeze().cpu().numpy(),
+                                                     curve=reflectivity_curve,
+                                                     predicted_params=predicted_params,
+                                                     priors=np.array(prior_bounds),
+                                                     fit_growth=fit_growth,
+                                                     max_d_change=max_d_change,
+                                                     calc_polished_curve=calc_pred_curve,
+                                                     calc_polished_sld_profile=False,
+                                                     sld_x_axis=predicted_sld_xaxis,
+                                                     )
+             prediction_dict.update(polished_dict)
+
+         if fit_growth and "polished_params_array" in prediction_dict:
+             prediction_dict["param_names"].append("max_d_change")
+
+         return prediction_dict
+
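+     # Usage sketch (illustrative; the bounds and parameter ordering below are hypothetical,
+     # one (min, max) interval per predicted parameter of the chosen parameterization):
+     #
+     #   prior_bounds = [(80., 150.), (0., 10.), (0., 10.), (10., 14.), (19., 21.)]
+     #   result = model.predict(curve, q_values=q, prior_bounds=prior_bounds,
+     #                          polish_prediction=True, calc_pred_sld_profile=True)
+     #   print(dict(zip(result["param_names"], result["predicted_params_array"])))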
+     async def predict_using_widget(self, reflectivity_curve: Union[np.ndarray, torch.Tensor], **kwargs):
+         """Use an interactive Python widget for specifying the prior bounds before the prediction (works only in a Jupyter notebook).
+         The other arguments are the same as for the ``predict`` method.
+         """
+         NUM_INTERVALS = self.trainer.loader.prior_sampler.param_dim
+         param_labels = self.trainer.loader.prior_sampler.param_model.get_param_labels()
+         min_bounds = self.trainer.loader.prior_sampler.min_bounds.cpu().numpy().flatten()
+         max_bounds = self.trainer.loader.prior_sampler.max_bounds.cpu().numpy().flatten()
+         max_deltas = self.trainer.loader.prior_sampler.max_delta.cpu().numpy().flatten()
+
+         print(f'Parameter ranges: {self.trainer.loader.prior_sampler.param_ranges}')
+         print(f'Allowed widths of the prior bound intervals (max-min): {self.trainer.loader.prior_sampler.bound_width_ranges}')
+         print('Please fill in the values of the minimum and maximum prior bound for each parameter and press the button!')
+
+         def create_interval_widgets(n):
+             intervals = []
+             for i in range(n):
+                 interval_label = widgets.Label(value=f'{param_labels[i]}')
+                 initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
+                 slider = widgets.FloatRangeSlider(
+                     value=[min_bounds[i], initial_max],
+                     min=min_bounds[i],
+                     max=max_bounds[i],
+                     step=0.01,
+                     layout=widgets.Layout(width='400px'),
+                     style={'description_width': '60px'}
+                 )
+
+                 def validate_range(change, slider=slider, max_width=max_deltas[i]):
+                     min_val, max_val = change['new']
+                     if max_val - min_val > max_width:
+                         if change['name'] == 'value':
+                             if change['old'][0] != min_val:
+                                 max_val = min_val + max_width
+                             else:
+                                 min_val = max_val - max_width
+                             slider.value = [min_val, max_val]
+
+                 slider.observe(validate_range, names='value')
+
+                 interval_row = widgets.HBox([interval_label, slider])
+                 intervals.append((slider, interval_row))
+             return intervals
+
+         interval_widgets = create_interval_widgets(NUM_INTERVALS)
+         interval_box = widgets.VBox([widget[1] for widget in interval_widgets])
+         display(interval_box)
+
+         button = widgets.Button(description="Make prediction")
+         display(button)
+
+         prediction_result = None
+
+         def store_values(b, future):
+             values = []
+             for slider, _ in interval_widgets:
+                 values.append((slider.value[0], slider.value[1]))
+             array_values = np.array(values)
+
+             nonlocal prediction_result
+             prediction_result = self.predict(reflectivity_curve=reflectivity_curve, prior_bounds=array_values, **kwargs)
+             print(prediction_result["predicted_params_array"])
+
+             print("Prediction completed. Closing widget.")
+
+             for child in interval_box.children:
+                 child.close()
+             button.close()
+
+             future.set_result(prediction_result)
+
+         future = asyncio.Future()
+         button.on_click(lambda b: store_values(b, future))
+
+         return await future
+
+
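+     # Usage sketch (illustrative): in a Jupyter cell the coroutine can simply be awaited;
+     # the prediction runs once the button is clicked and the cell then resumes:
+     #
+     #   result = await model.predict_using_widget(curve, polish_prediction=True)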
+     def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
+         assert isinstance(self.trainer.loader.q_generator, ConstantQ), "Prediction with q shifts available only for models with fixed discretization"
+         q = self.trainer.loader.q_generator.q.squeeze().float()
+         dq_max = (q[1] - q[0]) * dq_coef
+         q_shifts = torch.linspace(-dq_max, dq_max, num).to(q)
+
+         curve = to_t(curve).to(scaled_bounds)
+         shifted_curves = _qshift_interp(q.squeeze(), curve, q_shifts)
+
+         assert shifted_curves.shape == (num, q.shape[0])
+
+         scaled_curves = self.trainer.loader.curves_scaler.scale(shifted_curves)
+         scaled_prior_bounds = torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)
+
+         with torch.no_grad():
+             self.trainer.model.eval()
+             scaled_predicted_params = self.trainer.model(scaled_curves, scaled_prior_bounds)
+             restored_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
+
+         best_param = get_best_mse_param(
+             restored_params,
+             self._get_likelihood(q=self.trainer.loader.q_generator.q, curve=curve),
+         )
+         return best_param
+
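+     # Sketch of the q-shift idea (as implemented above): the measured curve is
+     # re-interpolated onto `num` grids shifted by up to +/- one q-spacing, the network
+     # predicts parameters for every shifted copy in one batch, and get_best_mse_param
+     # keeps the candidate whose simulated curve best matches the measurement.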
+     def _polish_prediction(self,
+                            q: np.ndarray,
+                            curve: np.ndarray,
+                            predicted_params: BasicParams,
+                            priors: np.ndarray,
+                            sld_x_axis,
+                            fit_growth: bool = False,
+                            max_d_change: float = 5.,
+                            calc_polished_curve: bool = True,
+                            calc_polished_sld_profile: bool = False,
+                            ) -> dict:
+         params = torch.cat([
+             predicted_params.thicknesses.squeeze(),
+             predicted_params.roughnesses.squeeze(),
+             predicted_params.slds.squeeze()
+         ]).cpu().numpy()
+
+         polished_params_dict = {}
+
+         try:
+             if fit_growth:
+                 polished_params_arr, curve_polished = get_fit_with_growth(
+                     q=q,
+                     curve=curve,
+                     init_params=params,
+                     bounds=priors.T,
+                     max_d_change=max_d_change,
+                 )
+                 polished_params = BasicParams(
+                     torch.from_numpy(polished_params_arr[:-1][None]),
+                     torch.from_numpy(priors.T[0][None]),
+                     torch.from_numpy(priors.T[1][None]),
+                     # self.trainer.loader.prior_sampler.param_model
+                 )
+             else:
+                 polished_params_arr, curve_polished = standard_refl_fit(
+                     q=q,
+                     curve=curve,
+                     init_params=params,
+                     bounds=priors.T)
+                 polished_params = BasicParams(
+                     torch.from_numpy(polished_params_arr[None]),
+                     torch.from_numpy(priors.T[0][None]),
+                     torch.from_numpy(priors.T[1][None]),
+                     # self.trainer.loader.prior_sampler.param_model
+                 )
+         except Exception:
+             polished_params = predicted_params
+             polished_params_arr = get_prediction_array(polished_params)
+             curve_polished = np.zeros_like(q)
+
+         polished_params_dict['polished_params_array'] = polished_params_arr
+         if calc_polished_curve:
+             polished_params_dict['polished_curve'] = curve_polished
+
+         if calc_polished_sld_profile:
+             _, sld_profile_polished, _ = get_density_profiles(
+                 polished_params.thicknesses, polished_params.roughnesses, polished_params.slds, z_axis=sld_x_axis,
+             )
+             polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().cpu().numpy()
+
+         return polished_params_dict
+
+
+     def _scale_curve(self, curve: Union[np.ndarray, Tensor]):
+         if not isinstance(curve, Tensor):
+             curve = torch.from_numpy(curve).float()
+         curve = torch.atleast_2d(curve).to(self.device)
+         scaled_curve = self.trainer.loader.curves_scaler.scale(curve)
+         return scaled_curve
+
+     def _scale_prior_bounds(self, prior_bounds: List[Tuple]):
+         prior_bounds = torch.tensor(prior_bounds)
+         prior_bounds = prior_bounds.to(self.device).T
+         min_bounds, max_bounds = prior_bounds[:, None]
+
+         scaled_bounds = torch.cat([
+             self.trainer.loader.prior_sampler.scale_bounds(min_bounds),
+             self.trainer.loader.prior_sampler.scale_bounds(max_bounds)
+         ], -1)
+
+         return scaled_bounds.float()
+
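+     # Shape sketch (following the code above): for n parameters, prior_bounds arrives as
+     # n (min, max) pairs; after the transpose, min_bounds and max_bounds are each (1, n)
+     # tensors, and the scaled tensor passed to the network has shape (1, 2n).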
+     def _get_likelihood(self, q, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
+         return LogLikelihood(
+             q, curve, self.trainer.loader.prior_sampler, curve * rel_err + abs_err
+         )
+
+ class InferenceModel(object):
+     def __init__(self, name: str = None, trainer: PointEstimatorTrainer = None, preprocessing_parameters: dict = None,
+                  num_sampling: int = 2 ** 13):
+         self.log = logging.getLogger(__name__)
+         self.model_name = name
+         self.trainer = trainer
+         self.q = None
+         self.preprocessing = StandardPreprocessing(**(preprocessing_parameters or {}))
+         self._sampling_num = num_sampling
+
+         if trainer is None and self.model_name is not None:
+             self.load_model(self.model_name)
+         elif trainer is not None:
+             self._set_trainer(trainer, preprocessing_parameters)
+
+     ### API methods ###
+
+     def load_model(self, name: str) -> None:
+         self.log.debug(f"loading model {name}")
+         if self.model_name == name and self.trainer is not None:
+             return
+         self.model_name = name
+         self._set_trainer(get_trainer_by_name(name))
+         self.log.info(f"Model {name} is loaded.")
+
+     def train_model(self, name: str):
+         self.model_name = name
+         self.trainer = train_from_config(load_config(name))
+
+     def set_preprocessing_parameters(self, **kwargs) -> None:
+         self.preprocessing.set_parameters(**kwargs)
+
+     def preprocess(self,
+                    intensity: np.ndarray,
+                    scattering_angle: np.ndarray,
+                    attenuation: np.ndarray,
+                    update_params: bool = False,
+                    **kwargs) -> dict:
+         if update_params:
+             self.preprocessing.set_parameters(**kwargs)
+         preprocessed_dict = self.preprocessing(intensity, scattering_angle, attenuation, **kwargs)
+         return preprocessed_dict
+
+     def predict(self,
+                 intensity: np.ndarray,
+                 scattering_angle: np.ndarray,
+                 attenuation: np.ndarray,
+                 priors: np.ndarray,
+                 preprocessing_parameters: dict = None,
+                 polish: bool = True,
+                 use_sampler: bool = False,
+                 use_q_shift: bool = True,
+                 max_d_change: float = 5.,
+                 fit_growth: bool = True,
+                 ) -> dict:
+
+         with print_time("everything"):
+             with print_time("preprocess"):
+                 preprocessed_dict = self.preprocess(
+                     intensity, scattering_angle, attenuation, **(preprocessing_parameters or {})
+                 )
+
+             preprocessed_curve = preprocessed_dict["curve_interp"]
+             raw_curve, raw_q = preprocessed_dict["curve"], preprocessed_dict["q_values"]
+             q_ratio = preprocessed_dict["q_ratio"]
+
+             with print_time("predict_from_preprocessed_curve"):
+                 preprocessed_dict.update(self.predict_from_preprocessed_curve(
+                     preprocessed_curve, priors, raw_curve=raw_curve, raw_q=raw_q, polish=polish, q_ratio=q_ratio,
+                     use_sampler=use_sampler, use_q_shift=use_q_shift, max_d_change=max_d_change,
+                     fit_growth=fit_growth,
+                 ))
+
+         return preprocessed_dict
+
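+     # Usage sketch (illustrative; the array names are hypothetical): raw detector data
+     # pass through the footprint/attenuation/normalization preprocessing before inference:
+     #
+     #   result = InferenceModel('my_model').predict(intensity, scattering_angle,
+     #                                               attenuation, priors=prior_bounds)
+     #   fitted = result["params_polished"]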
+     def predict_from_preprocessed_curve(self,
+                                         curve: np.ndarray,
+                                         priors: np.ndarray, *,
+                                         polish: bool = True,
+                                         raw_curve: np.ndarray = None,
+                                         raw_q: np.ndarray = None,
+                                         clip_prediction: bool = True,
+                                         q_ratio: float = 1.,
+                                         use_sampler: bool = False,
+                                         use_q_shift: bool = True,
+                                         max_d_change: float = 5.,
+                                         fit_growth: bool = True,
+                                         ) -> dict:
+
+         scaled_curve = self._scale_curve(curve)
+         scaled_bounds, min_bounds, max_bounds = self._scale_priors(priors, q_ratio)
+
+         if not use_q_shift:
+             predicted_params: UniformSubPriorParams = self._simple_prediction(scaled_curve, scaled_bounds)
+         else:
+             predicted_params: UniformSubPriorParams = self._qshift_prediction(curve, scaled_bounds)
+
+         if use_sampler:
+             predicted_params: UniformSubPriorParams = self._sampler_solution(
+                 curve, predicted_params,
+             )
+
+         if clip_prediction:
+             predicted_params = self._prior_sampler.clamp_params(predicted_params)
+
+         if raw_curve is None:
+             raw_curve = curve
+         if raw_q is None:
+             raw_q = self.q.squeeze().cpu().numpy()
+             raw_q_t = self.q
+         else:
+             raw_q_t = torch.from_numpy(raw_q).to(self.q)
+
+         if q_ratio != 1.:
+             predicted_params.scale_with_q(q_ratio)
+             raw_q = raw_q * q_ratio
+             raw_q_t = raw_q_t * q_ratio
+
+         prediction_dict = {
+             "params": get_prediction_array(predicted_params),
+             "param_names": get_param_labels(
+                 predicted_params.max_layer_num,
+                 thickness_name='d',
+                 roughness_name='sigma',
+                 sld_name='rho',
+             ),
+             "curve_predicted": predicted_params.reflectivity(raw_q_t).squeeze().cpu().numpy()
+         }
+
+         sld_x_axis, sld_profile, _ = get_density_profiles(
+             predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, num=1024,
+         )
+
+         prediction_dict['sld_profile'] = sld_profile.squeeze().cpu().numpy()
+         prediction_dict['sld_x_axis'] = sld_x_axis.squeeze().cpu().numpy()
+
+         if polish:
+             prediction_dict.update(self._polish_prediction(
+                 raw_q, raw_curve, predicted_params, priors, sld_x_axis,
+                 max_d_change=max_d_change, fit_growth=fit_growth,
+             ))
+
+         if fit_growth and "params_polished" in prediction_dict:
+             prediction_dict["param_names"].append("max_d_change")
+
+         return prediction_dict
+
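+     # Returned keys (collected from the code above): "params", "param_names",
+     # "curve_predicted", "sld_profile", "sld_x_axis", plus "params_polished",
+     # "curve_polished" and "sld_profile_polished" when polishing is enabled.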
+     ### some shortcut methods for data processing ###
+
+     def _simple_prediction(self, scaled_curve, scaled_bounds) -> UniformSubPriorParams:
+         context = torch.cat([scaled_curve, scaled_bounds], -1)
+
+         with torch.no_grad():
+             self.trainer.model.eval()
+             scaled_params = self.trainer.model(context)
+
+         predicted_params: UniformSubPriorParams = self._restore_predicted_params(scaled_params, context)
+         return predicted_params
+
+     @print_time
+     def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> UniformSubPriorParams:
+         q = self.q.squeeze().float()
+         curve = to_t(curve).to(q)
+         dq_max = (q[1] - q[0]) * dq_coef
+         q_shifts = torch.linspace(-dq_max, dq_max, num).to(q)
+         shifted_curves = _qshift_interp(q.squeeze(), curve, q_shifts)
+
+         assert shifted_curves.shape == (num, q.shape[0])
+
+         scaled_curves = self.trainer.loader.curves_scaler.scale(shifted_curves)
+         context = torch.cat([scaled_curves, torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)], -1)
+
+         with torch.no_grad():
+             self.trainer.model.eval()
+             scaled_params = self.trainer.model(context)
+             restored_params = self._restore_predicted_params(scaled_params, context)
+
+         best_param = get_best_mse_param(
+             restored_params,
+             self._get_likelihood(curve),
+         )
+         return best_param
+
+     @print_time
+     def _polish_prediction(self,
+                            q: np.ndarray,
+                            curve: np.ndarray,
+                            predicted_params: Params,
+                            priors: np.ndarray,
+                            sld_x_axis,
+                            fit_growth: bool = True,
+                            max_d_change: float = 5.,
+                            ) -> dict:
+         params = torch.cat([
+             predicted_params.thicknesses.squeeze(),
+             predicted_params.roughnesses.squeeze(),
+             predicted_params.slds.squeeze()
+         ]).cpu().numpy()
+
+         polished_params_dict = {}
+
+         try:
+             if fit_growth:
+                 polished_params_arr, curve_polished = get_fit_with_growth(
+                     q, curve, params, bounds=priors.T,
+                     max_d_change=max_d_change,
+                 )
+                 polished_params = Params.from_tensor(torch.from_numpy(polished_params_arr[:-1][None]).to(self.q))
+             else:
+                 polished_params_arr, curve_polished = standard_refl_fit(q, curve, params, bounds=priors.T)
+                 polished_params = Params.from_tensor(torch.from_numpy(polished_params_arr[None]).to(self.q))
+         except Exception as err:
+             self.log.exception(err)
+             polished_params = predicted_params
+             polished_params_arr = get_prediction_array(polished_params)
+             curve_polished = np.zeros_like(q)
+
+         polished_params_dict['params_polished'] = polished_params_arr
+         polished_params_dict['curve_polished'] = curve_polished
+
+         _, sld_profile_polished, _ = get_density_profiles(
+             polished_params.thicknesses, polished_params.roughnesses, polished_params.slds, z_axis=sld_x_axis,
+         )
+
+         polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().cpu().numpy()
+
+         return polished_params_dict
+
+     def _restore_predicted_params(self, scaled_params: Tensor, context: Tensor) -> UniformSubPriorParams:
+         predicted_params: UniformSubPriorParams = self.trainer.loader.prior_sampler.restore_params(
+             self.trainer.loader.prior_sampler.PARAM_CLS.restore_params_from_context(scaled_params, context)
+         )
+         return predicted_params
+
+     def _input2context(self, curve: np.ndarray, priors: np.ndarray, q_ratio: float = 1.):
+         scaled_curve = self._scale_curve(curve)
+         scaled_bounds, min_bounds, max_bounds = self._scale_priors(priors, q_ratio)
+         scaled_input = torch.cat([scaled_curve, scaled_bounds], -1)
+         return scaled_input, min_bounds, max_bounds
+
+     def _scale_curve(self, curve: Union[np.ndarray, Tensor]):
+         if not isinstance(curve, Tensor):
+             curve = torch.from_numpy(curve).float()
+         curve = torch.atleast_2d(curve).to(self.q)
+         scaled_curve = self.trainer.loader.curves_scaler.scale(curve)
+         return scaled_curve.float()
+
+     def _scale_priors(self, priors: Union[np.ndarray, Tensor], q_ratio: float = 1.):
+         if not isinstance(priors, Tensor):
+             priors = torch.from_numpy(priors)
+
+         priors = priors.float().clone()
+
+         priors = priors.to(self.q).T
+         priors = self._prior_sampler.scale_bounds_with_q(priors, 1 / q_ratio)
+         priors = self._prior_sampler.clamp_bounds(priors)
+
+         min_bounds, max_bounds = priors[:, None].to(self.q)
+         prior_sampler = self._prior_sampler
+         scaled_bounds = torch.cat([
+             prior_sampler.scale_bounds(min_bounds), prior_sampler.scale_bounds(max_bounds)
+         ], -1)
+         return scaled_bounds.float(), min_bounds, max_bounds
+
+     @property
+     def _prior_sampler(self) -> ExpUniformSubPriorSampler:
+         return self.trainer.loader.prior_sampler
+
+     def _set_trainer(self, trainer, preprocessing_parameters: dict = None):
+         self.trainer = trainer
+         self.trainer.model.eval()
+         self._update_preprocessing(preprocessing_parameters)
+
+     def _update_preprocessing(self, preprocessing_parameters: dict = None):
+         self.log.debug(f"setting preprocessing_parameters {preprocessing_parameters}.")
+         self.q = self.trainer.loader.q_generator.q
+         self.preprocessing = StandardPreprocessing(
+             self.q.cpu().squeeze().numpy(),
+             **(preprocessing_parameters or {})
+         )
+         self.log.info(f"preprocessing params are set: {preprocessing_parameters}.")
+
+     @print_time
+     def _sampler_solution(
+             self,
+             curve: Union[Tensor, np.ndarray],
+             predicted_params: UniformSubPriorParams,
+     ) -> UniformSubPriorParams:
+
+         if not isinstance(curve, Tensor):
+             curve = torch.from_numpy(curve).float()
+         curve = curve.to(self.q)
+
+         refined_params = simple_sampler_solution(
+             self._get_likelihood(curve),
+             predicted_params,
+             self._prior_sampler.min_bounds,
+             self._prior_sampler.max_bounds,
+             num=self._sampling_num, coef=0.1,
+         )
+
+         return refined_params
+
+     def _get_likelihood(self, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
+         return LogLikelihood(
+             self.q, curve, self._prior_sampler, curve * rel_err + abs_err
+         )
+
+
+ def get_prediction_array(params: BasicParams) -> np.ndarray:
+     predict_arr = torch.cat([
+         params.thicknesses.squeeze(),
+         params.roughnesses.squeeze(),
+         params.slds.squeeze(),
+     ]).cpu().numpy()
+
+     return predict_arr
+
+
+ def _qshift_interp(q, r, q_shifts):
+     qs = q[None] + q_shifts[:, None]
+     eps = torch.finfo(r.dtype).eps
+     ind = torch.searchsorted(q[None].expand_as(qs).contiguous(), qs.contiguous())
+     ind = torch.clamp(ind - 1, 0, q.shape[0] - 2)
+     slopes = (r[1:] - r[:-1]) / (eps + (q[1:] - q[:-1]))
+     return r[ind] + slopes[ind] * (qs - q[ind])
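+
+ # Toy check (illustrative, not part of the original file): _qshift_interp performs
+ # batched piecewise-linear re-interpolation of r onto shifted q-grids, so on a linear
+ # curve each shifted row reproduces r(q + dq) exactly (up to edge extrapolation):
+ #
+ #   q = torch.linspace(0.01, 0.2, 128)
+ #   r = 2. * q + 1.
+ #   shifted = _qshift_interp(q, r, torch.tensor([q[0] - q[1], 0., q[1] - q[0]]))
+ #   assert shifted.shape == (3, 128)
+ #   assert torch.allclose(shifted[1], r)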