reflectorch 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/data_generation/__init__.py +2 -0
- reflectorch/data_generation/priors/parametric_models.py +1 -1
- reflectorch/data_generation/q_generator.py +70 -36
- reflectorch/data_generation/utils.py +1 -0
- reflectorch/inference/inference_model.py +711 -188
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/plotting.py +505 -86
- reflectorch/inference/preprocess_exp/interpolation.py +5 -2
- reflectorch/inference/scipy_fitter.py +19 -5
- reflectorch/ml/trainers.py +9 -0
- reflectorch/models/__init__.py +1 -0
- reflectorch/models/encoders/__init__.py +2 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/mlp_networks.py +10 -4
- reflectorch/runs/utils.py +5 -2
- reflectorch/utils.py +30 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/METADATA +3 -2
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/RECORD +21 -19
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/licenses/LICENSE.txt +0 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/top_level.txt +0 -0
|
@@ -13,9 +13,8 @@ from huggingface_hub import hf_hub_download
|
|
|
13
13
|
|
|
14
14
|
from reflectorch.data_generation.priors import Params, BasicParams, ExpUniformSubPriorSampler, UniformSubPriorParams
|
|
15
15
|
from reflectorch.data_generation.priors.parametric_models import NuisanceParamsWrapper
|
|
16
|
-
from reflectorch.data_generation.q_generator import ConstantQ, VariableQ
|
|
16
|
+
from reflectorch.data_generation.q_generator import ConstantQ, VariableQ, MaskedVariableQ
|
|
17
17
|
from reflectorch.data_generation.utils import get_density_profiles, get_param_labels
|
|
18
|
-
from reflectorch.inference.plotting import plot_prediction_results
|
|
19
18
|
from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
|
|
20
19
|
from reflectorch.paths import CONFIG_DIR, ROOT_DIR, SAVED_MODELS_DIR
|
|
21
20
|
from reflectorch.runs.utils import (
|
|
@@ -29,7 +28,8 @@ from reflectorch.inference.preprocess_exp import StandardPreprocessing
|
|
|
29
28
|
from reflectorch.inference.scipy_fitter import standard_refl_fit, refl_fit, get_fit_with_growth
|
|
30
29
|
from reflectorch.inference.sampler_solution import simple_sampler_solution, get_best_mse_param
|
|
31
30
|
from reflectorch.inference.record_time import print_time
|
|
32
|
-
from reflectorch.
|
|
31
|
+
from reflectorch.inference.plotting import plot_reflectivity, plot_prediction_results, print_prediction_results
|
|
32
|
+
from reflectorch.utils import get_filtering_mask, to_t
|
|
33
33
|
|
|
34
34
|
class EasyInferenceModel(object):
|
|
35
35
|
"""Facilitates the inference process using pretrained models
|
|
@@ -150,21 +150,158 @@ class EasyInferenceModel(object):
|
|
|
150
150
|
inputs_str = ", ".join(additional_inputs)
|
|
151
151
|
print(f"The following quantities are additional inputs to the network: {inputs_str}.")
|
|
152
152
|
|
|
153
|
+
def preprocess_and_predict(
    self,
    reflectivity_curve: Union[np.ndarray, Tensor],
    q_values: Union[np.ndarray, Tensor] = None,
    prior_bounds: Union[np.ndarray, List[Tuple]] = None,
    sigmas: Union[np.ndarray, Tensor] = None,
    q_resolution: float = None,
    ambient_sld: float = None,
    clip_prediction: bool = True,
    polish_prediction: bool = False,
    polishing_method: str = 'trf',
    polishing_kwargs_reflectivity: dict = None,
    use_sigmas_for_polishing: bool = False,
    polishing_max_nfev: int = None,
    fit_growth: bool = False,
    max_d_change: float = 5.,
    calc_pred_curve: bool = True,
    calc_pred_sld_profile: bool = False,
    calc_polished_sld_profile: bool = False,
    sld_profile_padding_left: float = 0.2,
    sld_profile_padding_right: float = 1.1,
    kwargs_param_labels: dict = None,
    truncate_index_left: int = None,
    truncate_index_right: int = None,
    enable_error_bars_filtering: bool = True,
    filter_threshold=0.3,
    filter_remove_singles=True,
    filter_remove_consecutives=True,
    filter_consecutive=3,
    filter_q_start_trunc=0.1,
):
    """Preprocess raw experimental data, predict the film parameters and optionally polish them.

    Pipeline:
      1. ``self._preprocess_input_data`` drops points with non-positive intensity, applies the
         optional index truncation and the optional error-bar filtering, and also returns copies
         of the unprocessed arrays.
      2. ``self.interpolate_data_to_model_q`` maps the preprocessed data onto the q values used
         by the model (interpolating if the embedding network needs it); it may also return a
         key padding mask.
      3. ``self.predict`` runs the network. Polishing is disabled there and the SLDs are NOT
         shifted back by the ambient SLD, so both steps can be done here.
      4. If ``polish_prediction`` is True, the prediction is polished against the ORIGINAL
         (unpreprocessed, uninterpolated) data.
      5. If ``ambient_sld`` is non-zero, the SLD entries of the parameter arrays are shifted
         back by ``ambient_sld``.

    Args:
        reflectivity_curve: measured reflectivity values.
        q_values: momentum transfer values of the measured curve.
        prior_bounds: prior bounds for the thin film parameters (one ``(min, max)`` pair per parameter).
        sigmas: error bars of the measured curve.
        q_resolution: instrumental resolution, either a float (dq/q, linear smearing) or an array
            (pointwise dq).
        ambient_sld: SLD of the ambient (fronting) medium, if different from air.
        clip_prediction: forwarded to ``self.predict``.
        polish_prediction: if True, refine the network prediction with a least-squares fit on
            the original data.
        polishing_method: fitting method forwarded to ``self._polish_prediction``.
        polishing_kwargs_reflectivity: extra keyword arguments for the reflectivity function used
            during polishing. ``'dq'`` defaults to the original q resolution. The dict passed
            in is not modified.
        use_sigmas_for_polishing: if True, the original ``sigmas`` are passed as error bars to
            the polishing fit.
        polishing_max_nfev: maximum number of function evaluations for the polishing fit.
        fit_growth: if True, polishing also fits a linear thickness change of the upper layer;
            ``"max_d_change"`` is then appended to ``param_names``.
        max_d_change: maximum thickness change allowed when ``fit_growth`` is True.
        calc_pred_curve: whether ``self.predict`` computes the predicted curve.
        calc_pred_sld_profile: whether ``self.predict`` computes the predicted SLD profile.
        calc_polished_sld_profile: whether the polishing step computes the polished SLD profile.
            It reuses the predicted profile's x axis, which is only available when
            ``calc_pred_sld_profile`` is also True.
        sld_profile_padding_left: left padding of the SLD profile axis, forwarded to ``self.predict``.
        sld_profile_padding_right: right padding of the SLD profile axis, forwarded to ``self.predict``.
        kwargs_param_labels: keyword arguments for the parameter model's ``get_param_labels``.
            ``None`` means no extra arguments.
        truncate_index_left: start index of the slice applied after removing non-positive points.
        truncate_index_right: stop index of the slice applied after removing non-positive points.
        enable_error_bars_filtering, filter_threshold, filter_remove_singles,
        filter_remove_consecutives, filter_consecutive, filter_q_start_trunc:
            forwarded unchanged to ``self._preprocess_input_data`` (error-bar based filtering).

    Returns:
        dict: the dictionary produced by ``self.predict``, extended with:
            - ``'q_model'`` and ``'reflectivity_curve_interp'``;
            - ``'q_resolution_interp'``, ``'sigmas_interp'`` and ``'key_padding_mask'`` when
              available;
            - the entries returned by ``self._polish_prediction`` when polishing is enabled.
        The ambient SLD back-shift is applied only to ``predicted_params_array`` and
        ``polished_params_array``, not to the parameter objects.
    """
    if kwargs_param_labels is None:
        kwargs_param_labels = {}

    # Preprocess the data for inference (remove non-positive intensities, truncate,
    # filter out points with high error bars); the *_original values are kept for polishing.
    (q_values, reflectivity_curve, sigmas, q_resolution,
     q_values_original, reflectivity_curve_original, sigmas_original,
     q_resolution_original) = self._preprocess_input_data(
        reflectivity_curve=reflectivity_curve,
        q_values=q_values,
        sigmas=sigmas,
        q_resolution=q_resolution,
        truncate_index_left=truncate_index_left,
        truncate_index_right=truncate_index_right,
        enable_error_bars_filtering=enable_error_bars_filtering,
        filter_threshold=filter_threshold,
        filter_remove_singles=filter_remove_singles,
        filter_remove_consecutives=filter_remove_consecutives,
        filter_consecutive=filter_consecutive,
        filter_q_start_trunc=filter_q_start_trunc,
    )

    # Interpolate the experimental data if needed by the embedding network.
    interp_data = self.interpolate_data_to_model_q(
        q_exp=q_values,
        refl_exp=reflectivity_curve,
        sigmas_exp=sigmas,
        q_res_exp=q_resolution,
        as_dict=True,
    )

    q_model = interp_data["q_model"]
    reflectivity_curve_interp = interp_data["reflectivity"]
    sigmas_interp = interp_data.get("sigmas")
    q_resolution_interp = interp_data.get("q_resolution")
    key_padding_mask = interp_data.get("key_padding_mask")

    # Make the prediction.
    prediction_dict = self.predict(
        reflectivity_curve=reflectivity_curve_interp,
        q_values=q_model,
        sigmas=sigmas_interp,
        q_resolution=q_resolution_interp,
        key_padding_mask=key_padding_mask,
        prior_bounds=prior_bounds,
        ambient_sld=ambient_sld,
        clip_prediction=clip_prediction,
        polish_prediction=False,  # polishing is done below, on the full original data
        supress_sld_amb_back_shift=True,  # SLDs are shifted back by the ambient at the end of this method
        calc_pred_curve=calc_pred_curve,
        calc_pred_sld_profile=calc_pred_sld_profile,
        sld_profile_padding_left=sld_profile_padding_left,
        sld_profile_padding_right=sld_profile_padding_right,
        kwargs_param_labels=kwargs_param_labels,
    )

    # Save the interpolated data.
    prediction_dict['q_model'] = q_model
    prediction_dict['reflectivity_curve_interp'] = reflectivity_curve_interp
    if q_resolution_interp is not None:
        prediction_dict['q_resolution_interp'] = q_resolution_interp
    if sigmas_interp is not None:
        prediction_dict['sigmas_interp'] = sigmas_interp
    if key_padding_mask is not None:
        prediction_dict['key_padding_mask'] = key_padding_mask

    # The prediction above is expressed with SLDs shifted by the ambient SLD. The SLD indices
    # (and the shifted prior bounds used for polishing) are needed for the back-shift below
    # whether or not polishing is requested, so they are computed unconditionally here.
    prior_bounds = np.array(prior_bounds)
    sld_indices = None
    if ambient_sld:
        sld_indices = self._shift_slds_by_ambient(prior_bounds, ambient_sld)

    # Perform polishing on the original data.
    if polish_prediction:
        # Copy so that the caller's dict is not mutated by setdefault.
        polishing_kwargs = dict(polishing_kwargs_reflectivity or {})
        polishing_kwargs.setdefault('dq', q_resolution_original)

        # 'predicted_sld_xaxis' is only stored by predict() when calc_pred_sld_profile is True.
        predicted_sld_xaxis = prediction_dict.get('predicted_sld_xaxis')
        sld_x_axis = torch.from_numpy(predicted_sld_xaxis) if predicted_sld_xaxis is not None else None

        polished_dict = self._polish_prediction(
            q=q_values_original,
            curve=reflectivity_curve_original,
            predicted_params=prediction_dict['predicted_params_object'],
            priors=prior_bounds,
            ambient_sld_tensor=torch.atleast_2d(torch.as_tensor(ambient_sld)).to(self.device) if ambient_sld is not None else None,
            calc_polished_sld_profile=calc_polished_sld_profile,
            sld_x_axis=sld_x_axis,
            polishing_kwargs_reflectivity=polishing_kwargs,
            error_bars=sigmas_original if use_sigmas_for_polishing else None,
            polishing_method=polishing_method,
            polishing_max_nfev=polishing_max_nfev,
            fit_growth=fit_growth,
            max_d_change=max_d_change,
        )

        prediction_dict.update(polished_dict)
        if fit_growth and "polished_params_array" in prediction_dict:
            prediction_dict["param_names"].append("max_d_change")

    # Shift back the SLDs for a nonzero ambient SLD.
    if sld_indices is not None:
        self._restore_slds_after_ambient_shift(prediction_dict, sld_indices, ambient_sld)

    return prediction_dict
|
|
280
|
+
|
|
281
|
+
|
|
153
282
|
def predict(self,
|
|
154
283
|
reflectivity_curve: Union[np.ndarray, Tensor],
|
|
155
284
|
q_values: Union[np.ndarray, Tensor] = None,
|
|
156
285
|
prior_bounds: Union[np.ndarray, List[Tuple]] = None,
|
|
157
|
-
|
|
286
|
+
sigmas: Union[np.ndarray, Tensor] = None,
|
|
287
|
+
key_padding_mask: Union[np.ndarray, Tensor] = None,
|
|
288
|
+
q_resolution: float = None,
|
|
158
289
|
ambient_sld: float = None,
|
|
159
|
-
clip_prediction: bool =
|
|
160
|
-
polish_prediction: bool = False,
|
|
290
|
+
clip_prediction: bool = True,
|
|
291
|
+
polish_prediction: bool = False,
|
|
292
|
+
polishing_method: str = 'trf',
|
|
161
293
|
polishing_kwargs_reflectivity: dict = None,
|
|
294
|
+
polishing_max_nfev: int = None,
|
|
162
295
|
fit_growth: bool = False,
|
|
163
296
|
max_d_change: float = 5.,
|
|
164
297
|
use_q_shift: bool = False,
|
|
165
298
|
calc_pred_curve: bool = True,
|
|
166
299
|
calc_pred_sld_profile: bool = False,
|
|
167
300
|
calc_polished_sld_profile: bool = False,
|
|
301
|
+
sld_profile_padding_left: float = 0.2,
|
|
302
|
+
sld_profile_padding_right: float = 1.1,
|
|
303
|
+
supress_sld_amb_back_shift: bool = False,
|
|
304
|
+
kwargs_param_labels: dict = {},
|
|
168
305
|
):
|
|
169
306
|
"""Predict the thin film parameters
|
|
170
307
|
|
|
@@ -172,17 +309,13 @@ class EasyInferenceModel(object):
|
|
|
172
309
|
reflectivity_curve (Union[np.ndarray, Tensor]): The reflectivity curve (which has been already preprocessed, normalized and interpolated).
|
|
173
310
|
q_values (Union[np.ndarray, Tensor], optional): The momentum transfer (q) values for the reflectivity curve (in units of inverse angstroms).
|
|
174
311
|
prior_bounds (Union[np.ndarray, List[Tuple]], optional): the prior bounds for the thin film parameters.
|
|
175
|
-
q_resolution (Union[float, np.ndarray], optional): the instrumental resolution. Either as a float with meaning dq/q for linear smearing or as a numpy array with meaning dq for pointwise smearing.
|
|
176
|
-
ambient_sld (float, optional): the SLD of the ambient medium (fronting), if different from air.
|
|
177
312
|
clip_prediction (bool, optional): If ``True``, the values of the predicted parameters are clipped to not be outside the interval set by the prior bounds. Defaults to False.
|
|
178
313
|
polish_prediction (bool, optional): If ``True``, the neural network predictions are further polished using a simple least mean squares (LMS) fit. Only for the standard box-model parameterization. Defaults to False.
|
|
179
|
-
polishing_kwargs_reflectivity (dict): extra arguments for the reflectivity function used during polishing.
|
|
180
314
|
fit_growth (bool, optional): If ``True``, an additional parameters is introduced during the LMS polishing to account for the change in the thickness of the upper layer during the in-situ measurement of the reflectivity curve (a linear growth is assumed). Defaults to False.
|
|
181
315
|
max_d_change (float): The maximum possible change in the thickness of the upper layer during the in-situ measurement, relevant when polish_prediction and fit_growth are True. Defaults to 5.
|
|
182
316
|
use_q_shift: If ``True``, the prediction is performed for a batch of slightly shifted versions of the input curve and the best result is returned, which is meant to mitigate the influence of imperfect sample alignment, as introduced in Greco et al. (only for models with fixed q-discretization). Defaults to False.
|
|
183
317
|
calc_pred_curve (bool, optional): Whether to calculate the curve corresponding to the predicted parameters. Defaults to True.
|
|
184
318
|
calc_pred_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the predicted parameters. Defaults to False.
|
|
185
|
-
calc_polished_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the polished parameters. Defaults to False.
|
|
186
319
|
|
|
187
320
|
Returns:
|
|
188
321
|
dict: dictionary containing the predictions
|
|
@@ -192,54 +325,48 @@ class EasyInferenceModel(object):
|
|
|
192
325
|
prior_bounds = np.array(prior_bounds)
|
|
193
326
|
|
|
194
327
|
if ambient_sld:
|
|
195
|
-
|
|
196
|
-
sld_indices = slice(2*n_layers+1, 3*n_layers+2)
|
|
197
|
-
prior_bounds[sld_indices, ...] -= ambient_sld
|
|
198
|
-
training_min_bounds = self.trainer.loader.prior_sampler.min_bounds.squeeze().cpu().numpy()
|
|
199
|
-
training_max_bounds = self.trainer.loader.prior_sampler.max_bounds.squeeze().cpu().numpy()
|
|
200
|
-
lower_bound_check = (prior_bounds[sld_indices, 0] >= training_min_bounds[sld_indices]).all()
|
|
201
|
-
upper_bound_check = (prior_bounds[sld_indices, 1] <= training_max_bounds[sld_indices]).all()
|
|
202
|
-
assert lower_bound_check and upper_bound_check, "Shifting the layer SLDs by the ambient SLD exceeded the training ranges."
|
|
328
|
+
sld_indices = self._shift_slds_by_ambient(prior_bounds, ambient_sld)
|
|
203
329
|
|
|
204
|
-
|
|
205
|
-
scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)
|
|
206
|
-
except ValueError as e:
|
|
207
|
-
print(str(e))
|
|
208
|
-
return None
|
|
330
|
+
scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)
|
|
209
331
|
|
|
210
|
-
if
|
|
332
|
+
if isinstance(self.trainer.loader.q_generator, ConstantQ):
|
|
211
333
|
q_values = self.trainer.loader.q_generator.q
|
|
212
334
|
else:
|
|
213
335
|
q_values = torch.atleast_2d(to_t(q_values)).to(scaled_curve)
|
|
214
336
|
|
|
337
|
+
scaled_q_values = self.trainer.loader.q_generator.scale_q(q_values).to(torch.float32) if self.trainer.train_with_q_input else None
|
|
338
|
+
|
|
339
|
+
if q_resolution is not None:
|
|
340
|
+
q_resolution_tensor = torch.atleast_2d(torch.as_tensor(q_resolution)).to(scaled_curve)
|
|
341
|
+
if isinstance(q_resolution, float):
|
|
342
|
+
unscaled_q_resolutions = q_resolution_tensor
|
|
343
|
+
else:
|
|
344
|
+
unscaled_q_resolutions = (q_resolution_tensor / q_values).nanmean(dim=-1, keepdim=True) ##when q_values is padded with 0s, there will be nan at the padded positions
|
|
345
|
+
scaled_q_resolutions = self.trainer.loader.smearing.scale_resolutions(unscaled_q_resolutions) if self.trainer.condition_on_q_resolutions else None
|
|
346
|
+
scaled_conditioning_params = scaled_q_resolutions
|
|
347
|
+
if polishing_kwargs_reflectivity is None:
|
|
348
|
+
polishing_kwargs_reflectivity = {'dq': q_resolution}
|
|
349
|
+
else:
|
|
350
|
+
q_resolution_tensor = None
|
|
351
|
+
scaled_conditioning_params = None
|
|
352
|
+
|
|
353
|
+
if key_padding_mask is not None:
|
|
354
|
+
key_padding_mask = torch.as_tensor(key_padding_mask, device=self.device)
|
|
355
|
+
key_padding_mask = key_padding_mask.unsqueeze(0) if key_padding_mask.dim() == 1 else key_padding_mask
|
|
356
|
+
|
|
215
357
|
if use_q_shift and not self.trainer.train_with_q_input:
|
|
216
358
|
predicted_params = self._qshift_prediction(reflectivity_curve, scaled_prior_bounds, num = 1024, dq_coef = 1.)
|
|
217
|
-
|
|
218
359
|
else:
|
|
219
360
|
with torch.no_grad():
|
|
220
361
|
self.trainer.model.eval()
|
|
221
|
-
|
|
222
|
-
scaled_q_values = self.trainer.loader.q_generator.scale_q(q_values).to(torch.float32) if self.trainer.train_with_q_input else None
|
|
223
|
-
|
|
224
|
-
if q_resolution is not None:
|
|
225
|
-
q_resolution_tensor = torch.atleast_2d(torch.as_tensor(q_resolution)).to(scaled_curve)
|
|
226
|
-
if isinstance(q_resolution, float):
|
|
227
|
-
unscaled_q_resolutions = q_resolution_tensor
|
|
228
|
-
else:
|
|
229
|
-
unscaled_q_resolutions = (q_resolution_tensor / q_values).mean(dim=-1, keepdim=True)
|
|
230
|
-
scaled_q_resolutions = self.trainer.loader.smearing.scale_resolutions(unscaled_q_resolutions) if self.trainer.condition_on_q_resolutions else None
|
|
231
|
-
scaled_conditioning_params = scaled_q_resolutions
|
|
232
|
-
if polishing_kwargs_reflectivity is None:
|
|
233
|
-
polishing_kwargs_reflectivity = {'dq': q_resolution}
|
|
234
|
-
else:
|
|
235
|
-
q_resolution_tensor = None
|
|
236
|
-
scaled_conditioning_params = None
|
|
237
|
-
|
|
362
|
+
|
|
238
363
|
scaled_predicted_params = self.trainer.model(
|
|
239
364
|
curves=scaled_curve,
|
|
240
365
|
bounds=scaled_prior_bounds,
|
|
241
366
|
q_values=scaled_q_values,
|
|
242
367
|
conditioning_params = scaled_conditioning_params,
|
|
368
|
+
key_padding_mask = key_padding_mask,
|
|
369
|
+
unscaled_q_values = q_values,
|
|
243
370
|
)
|
|
244
371
|
|
|
245
372
|
predicted_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
|
|
@@ -250,155 +377,60 @@ class EasyInferenceModel(object):
|
|
|
250
377
|
prediction_dict = {
|
|
251
378
|
"predicted_params_object": predicted_params,
|
|
252
379
|
"predicted_params_array": predicted_params.parameters.squeeze().cpu().numpy(),
|
|
253
|
-
"param_names" : self.trainer.loader.prior_sampler.param_model.get_param_labels()
|
|
380
|
+
"param_names" : self.trainer.loader.prior_sampler.param_model.get_param_labels(**kwargs_param_labels)
|
|
254
381
|
}
|
|
382
|
+
|
|
383
|
+
key_padding_mask = None if key_padding_mask is None else key_padding_mask.squeeze().cpu().numpy()
|
|
255
384
|
|
|
256
385
|
if calc_pred_curve:
|
|
257
386
|
predicted_curve = predicted_params.reflectivity(q=q_values, dq=q_resolution_tensor).squeeze().cpu().numpy()
|
|
258
|
-
prediction_dict[ "predicted_curve"] = predicted_curve
|
|
387
|
+
prediction_dict[ "predicted_curve"] = predicted_curve if key_padding_mask is None else predicted_curve[key_padding_mask]
|
|
259
388
|
|
|
260
|
-
ambient_sld_tensor = torch.atleast_2d(torch.as_tensor(ambient_sld
|
|
389
|
+
ambient_sld_tensor = torch.atleast_2d(torch.as_tensor(ambient_sld, device=self.device)) if ambient_sld is not None else None
|
|
261
390
|
if calc_pred_sld_profile:
|
|
262
391
|
predicted_sld_xaxis, predicted_sld_profile, _ = get_density_profiles(
|
|
263
|
-
predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds
|
|
392
|
+
predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds + (ambient_sld_tensor or 0), ambient_sld_tensor,
|
|
393
|
+
num=1024, padding_left=sld_profile_padding_left, padding_right=sld_profile_padding_right,
|
|
264
394
|
)
|
|
265
395
|
prediction_dict['predicted_sld_profile'] = predicted_sld_profile.squeeze().cpu().numpy()
|
|
266
396
|
prediction_dict['predicted_sld_xaxis'] = predicted_sld_xaxis.squeeze().cpu().numpy()
|
|
267
397
|
else:
|
|
268
398
|
predicted_sld_xaxis = None
|
|
399
|
+
|
|
400
|
+
refl_curve_polish = reflectivity_curve if key_padding_mask is None else reflectivity_curve[key_padding_mask]
|
|
401
|
+
q_polish = q_values.squeeze().cpu().numpy() if key_padding_mask is None else q_values.squeeze().cpu().numpy()[key_padding_mask]
|
|
402
|
+
prediction_dict['q_plot_pred'] = q_polish
|
|
269
403
|
|
|
270
404
|
if polish_prediction:
|
|
271
405
|
if ambient_sld_tensor:
|
|
272
406
|
ambient_sld_tensor = ambient_sld_tensor.cpu()
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
407
|
+
|
|
408
|
+
polished_dict = self._polish_prediction(
|
|
409
|
+
q = q_polish,
|
|
410
|
+
curve = refl_curve_polish,
|
|
411
|
+
predicted_params = predicted_params,
|
|
412
|
+
priors = np.array(prior_bounds),
|
|
413
|
+
error_bars = sigmas,
|
|
414
|
+
fit_growth = fit_growth,
|
|
415
|
+
max_d_change = max_d_change,
|
|
416
|
+
calc_polished_curve = calc_pred_curve,
|
|
417
|
+
calc_polished_sld_profile = calc_polished_sld_profile,
|
|
418
|
+
ambient_sld_tensor=ambient_sld_tensor,
|
|
419
|
+
sld_x_axis = predicted_sld_xaxis,
|
|
420
|
+
polishing_method=polishing_method,
|
|
421
|
+
polishing_max_nfev=polishing_max_nfev,
|
|
422
|
+
polishing_kwargs_reflectivity=polishing_kwargs_reflectivity,
|
|
423
|
+
)
|
|
285
424
|
prediction_dict.update(polished_dict)
|
|
286
425
|
|
|
287
426
|
if fit_growth and "polished_params_array" in prediction_dict:
|
|
288
427
|
prediction_dict["param_names"].append("max_d_change")
|
|
289
428
|
|
|
290
|
-
if ambient_sld: #Note: the SLD shift will only be reflected in predicted_params_array but not in predicted_params_object
|
|
291
|
-
prediction_dict
|
|
292
|
-
if "polished_params_array" in prediction_dict:
|
|
293
|
-
prediction_dict["polished_params_array"][sld_indices] += ambient_sld
|
|
429
|
+
if ambient_sld and not supress_sld_amb_back_shift: #Note: the SLD shift will only be reflected in predicted_params_array but not in predicted_params_object; supress_sld_amb_back_shift is required for the 'preprocess_and_predict' method
|
|
430
|
+
self._restore_slds_after_ambient_shift(prediction_dict, sld_indices, ambient_sld)
|
|
294
431
|
|
|
295
432
|
return prediction_dict
|
|
296
|
-
|
|
297
|
-
def predict_using_widget(self, reflectivity_curve, **kwargs):
|
|
298
|
-
"""
|
|
299
|
-
"""
|
|
300
|
-
|
|
301
|
-
NUM_INTERVALS = self.trainer.loader.prior_sampler.param_dim
|
|
302
|
-
param_labels = self.trainer.loader.prior_sampler.param_model.get_param_labels()
|
|
303
|
-
min_bounds = self.trainer.loader.prior_sampler.min_bounds.cpu().numpy().flatten()
|
|
304
|
-
max_bounds = self.trainer.loader.prior_sampler.max_bounds.cpu().numpy().flatten()
|
|
305
|
-
max_deltas = self.trainer.loader.prior_sampler.max_delta.cpu().numpy().flatten()
|
|
306
|
-
|
|
307
|
-
print(f'Adjust the sliders for each parameter and press "Predict". Repeat as desired. Press "Close Widget" to finish.')
|
|
308
|
-
|
|
309
|
-
interval_widgets = []
|
|
310
|
-
for i in range(NUM_INTERVALS):
|
|
311
|
-
label = widgets.Label(value=f'{param_labels[i]}')
|
|
312
|
-
initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
|
|
313
|
-
slider = widgets.FloatRangeSlider(
|
|
314
|
-
value=[min_bounds[i], initial_max],
|
|
315
|
-
min=min_bounds[i],
|
|
316
|
-
max=max_bounds[i],
|
|
317
|
-
step=0.01,
|
|
318
|
-
layout=widgets.Layout(width='400px'),
|
|
319
|
-
style={'description_width': '60px'}
|
|
320
|
-
)
|
|
321
|
-
|
|
322
|
-
def validate_range(change, slider=slider, max_width=max_deltas[i]):
|
|
323
|
-
min_val, max_val = change['new']
|
|
324
|
-
if max_val - min_val > max_width:
|
|
325
|
-
old_min_val, old_max_val = change['old']
|
|
326
|
-
if abs(old_min_val - min_val) > abs(old_max_val - max_val):
|
|
327
|
-
max_val = min_val + max_width
|
|
328
|
-
else:
|
|
329
|
-
min_val = max_val - max_width
|
|
330
|
-
slider.value = [min_val, max_val]
|
|
331
|
-
|
|
332
|
-
slider.observe(validate_range, names='value')
|
|
333
|
-
interval_widgets.append((slider, widgets.HBox([label, slider])))
|
|
334
|
-
|
|
335
|
-
sliders_box = widgets.VBox([iw[1] for iw in interval_widgets])
|
|
336
|
-
|
|
337
|
-
output = widgets.Output()
|
|
338
|
-
predict_button = widgets.Button(description="Predict")
|
|
339
|
-
close_button = widgets.Button(description="Close Widget")
|
|
340
|
-
|
|
341
|
-
container = widgets.VBox([sliders_box, widgets.HBox([predict_button, close_button]), output])
|
|
342
|
-
display(container)
|
|
343
|
-
|
|
344
|
-
@output.capture(clear_output=True)
|
|
345
|
-
def on_predict_click(_):
|
|
346
|
-
if 'prior_bounds' in kwargs:
|
|
347
|
-
array_values = kwargs.pop('prior_bounds')
|
|
348
|
-
for i, (s, _) in enumerate(interval_widgets):
|
|
349
|
-
s.value = tuple(array_values[i])
|
|
350
|
-
else:
|
|
351
|
-
values = [(s.value[0], s.value[1]) for s, _ in interval_widgets]
|
|
352
|
-
array_values = np.array(values)
|
|
353
|
-
|
|
354
|
-
prediction_result = self.predict(reflectivity_curve=reflectivity_curve,
|
|
355
|
-
prior_bounds=array_values,
|
|
356
|
-
**kwargs)
|
|
357
|
-
param_names = self.trainer.loader.prior_sampler.param_model.get_param_labels()
|
|
358
|
-
for param_name, pred_param_val in zip(param_names, prediction_result["predicted_params_array"]):
|
|
359
|
-
print(f'{param_name.ljust(14)} : {pred_param_val:.2f}')
|
|
360
|
-
|
|
361
|
-
plot_prediction_results(
|
|
362
|
-
prediction_result,
|
|
363
|
-
q_exp=kwargs['q_values'],
|
|
364
|
-
curve_exp=reflectivity_curve,
|
|
365
|
-
q_model=kwargs['q_values'],
|
|
366
|
-
)
|
|
367
|
-
self.prediction_result = prediction_result
|
|
368
|
-
|
|
369
|
-
def on_close_click(_):
|
|
370
|
-
container.close()
|
|
371
|
-
print("Widget closed.")
|
|
372
|
-
|
|
373
|
-
predict_button.on_click(on_predict_click)
|
|
374
|
-
close_button.on_click(on_close_click)
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
|
|
378
|
-
assert isinstance(self.trainer.loader.q_generator, ConstantQ), "Prediction with q shifts available only for models with fixed discretization"
|
|
379
|
-
q = self.trainer.loader.q_generator.q.squeeze().float()
|
|
380
|
-
dq_max = (q[1] - q[0]) * dq_coef
|
|
381
|
-
q_shifts = torch.linspace(-dq_max, dq_max, num).to(q)
|
|
382
|
-
|
|
383
|
-
curve = to_t(curve).to(scaled_bounds)
|
|
384
|
-
shifted_curves = _qshift_interp(q.squeeze(), curve, q_shifts)
|
|
385
|
-
|
|
386
|
-
assert shifted_curves.shape == (num, q.shape[0])
|
|
387
|
-
|
|
388
|
-
scaled_curves = self.trainer.loader.curves_scaler.scale(shifted_curves)
|
|
389
|
-
scaled_prior_bounds = torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)
|
|
390
|
-
|
|
391
|
-
with torch.no_grad():
|
|
392
|
-
self.trainer.model.eval()
|
|
393
|
-
scaled_predicted_params = self.trainer.model(scaled_curves, scaled_prior_bounds)
|
|
394
|
-
restored_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
|
|
395
|
-
|
|
396
|
-
best_param = get_best_mse_param(
|
|
397
|
-
restored_params,
|
|
398
|
-
self._get_likelihood(q=self.trainer.loader.q_generator.q, curve=curve),
|
|
399
|
-
)
|
|
400
|
-
return best_param
|
|
401
|
-
|
|
433
|
+
|
|
402
434
|
def _polish_prediction(self,
|
|
403
435
|
q: np.ndarray,
|
|
404
436
|
curve: np.ndarray,
|
|
@@ -410,6 +442,9 @@ class EasyInferenceModel(object):
|
|
|
410
442
|
max_d_change: float = 5.,
|
|
411
443
|
calc_polished_curve: bool = True,
|
|
412
444
|
calc_polished_sld_profile: bool = False,
|
|
445
|
+
error_bars: np.ndarray = None,
|
|
446
|
+
polishing_method: str = 'trf',
|
|
447
|
+
polishing_max_nfev: int = None,
|
|
413
448
|
polishing_kwargs_reflectivity: dict = None,
|
|
414
449
|
) -> dict:
|
|
415
450
|
params = predicted_params.parameters.squeeze().cpu().numpy()
|
|
@@ -440,6 +475,9 @@ class EasyInferenceModel(object):
|
|
|
440
475
|
init_params = params,
|
|
441
476
|
bounds=priors.T,
|
|
442
477
|
prior_sampler=self.trainer.loader.prior_sampler,
|
|
478
|
+
error_bars=error_bars,
|
|
479
|
+
method=polishing_method,
|
|
480
|
+
polishing_max_nfev=polishing_max_nfev,
|
|
443
481
|
reflectivity_kwargs=polishing_kwargs_reflectivity,
|
|
444
482
|
)
|
|
445
483
|
polished_params = BasicParams(
|
|
@@ -458,18 +496,23 @@ class EasyInferenceModel(object):
|
|
|
458
496
|
if calc_polished_curve:
|
|
459
497
|
polished_params_dict['polished_curve'] = curve_polished
|
|
460
498
|
|
|
499
|
+
if ambient_sld_tensor is not None:
|
|
500
|
+
ambient_sld_tensor = ambient_sld_tensor.to(polished_params.slds.device)
|
|
501
|
+
|
|
502
|
+
|
|
461
503
|
if calc_polished_sld_profile:
|
|
462
504
|
_, sld_profile_polished, _ = get_density_profiles(
|
|
463
|
-
polished_params.thicknesses, polished_params.roughnesses, polished_params.slds, ambient_sld_tensor,
|
|
505
|
+
polished_params.thicknesses, polished_params.roughnesses, polished_params.slds + (ambient_sld_tensor or 0), ambient_sld_tensor,
|
|
506
|
+
z_axis=sld_x_axis.to(polished_params.slds.device),
|
|
464
507
|
)
|
|
465
|
-
polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().numpy()
|
|
508
|
+
polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().cpu().numpy()
|
|
466
509
|
|
|
467
510
|
return polished_params_dict
|
|
468
511
|
|
|
469
512
|
def _scale_curve(self, curve: Union[np.ndarray, Tensor]):
|
|
470
513
|
if not isinstance(curve, Tensor):
|
|
471
514
|
curve = torch.from_numpy(curve).float()
|
|
472
|
-
curve =
|
|
515
|
+
curve = curve.unsqueeze(0).to(self.device)
|
|
473
516
|
scaled_curve = self.trainer.loader.curves_scaler.scale(curve)
|
|
474
517
|
return scaled_curve
|
|
475
518
|
|
|
@@ -500,29 +543,509 @@ class EasyInferenceModel(object):
|
|
|
500
543
|
f"- The number of layers or parameterization type differs from the one used during training.\n\n"
|
|
501
544
|
f" Check the configuration or the summary of expected parameters."
|
|
502
545
|
)
|
|
503
|
-
raise ValueError(msg) from e
|
|
546
|
+
raise ValueError(msg) from e
|
|
504
547
|
|
|
505
|
-
def
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
exp_curve_interp = interp_reflectivity(q_model, q_exp, curve_exp)
|
|
519
|
-
|
|
520
|
-
return q_model, exp_curve_interp
|
|
548
|
+
def _shift_slds_by_ambient(self, prior_bounds: np.ndarray, ambient_sld: float):
|
|
549
|
+
n_layers = self.trainer.loader.prior_sampler.max_num_layers
|
|
550
|
+
sld_indices = slice(2*n_layers+1, 3*n_layers+2)
|
|
551
|
+
prior_bounds[sld_indices, ...] -= ambient_sld
|
|
552
|
+
|
|
553
|
+
training_min_bounds = self.trainer.loader.prior_sampler.min_bounds.squeeze().cpu().numpy()
|
|
554
|
+
training_max_bounds = self.trainer.loader.prior_sampler.max_bounds.squeeze().cpu().numpy()
|
|
555
|
+
lower_bound_check = (prior_bounds[sld_indices, 0] >= training_min_bounds[sld_indices]).all()
|
|
556
|
+
upper_bound_check = (prior_bounds[sld_indices, 1] <= training_max_bounds[sld_indices]).all()
|
|
557
|
+
assert lower_bound_check and upper_bound_check, "Shifting the layer SLDs by the ambient SLD exceeded the training ranges."
|
|
558
|
+
|
|
559
|
+
return sld_indices
|
|
521
560
|
|
|
561
|
+
def _restore_slds_after_ambient_shift(self, prediction_dict, sld_indices, ambient_sld):
|
|
562
|
+
prediction_dict["predicted_params_array"][sld_indices] += ambient_sld
|
|
563
|
+
if "polished_params_array" in prediction_dict:
|
|
564
|
+
prediction_dict["polished_params_array"][sld_indices] += ambient_sld
|
|
565
|
+
|
|
522
566
|
def _get_likelihood(self, q, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
    """Build a ``LogLikelihood`` for ``curve`` sampled at ``q``.

    The error bars are synthetic: ``curve * rel_err + abs_err``, i.e. a
    relative error plus a small absolute floor.
    """
    error_bars = curve * rel_err + abs_err
    prior_sampler = self.trainer.loader.prior_sampler
    return LogLikelihood(q, curve, prior_sampler, error_bars)
|
|
570
|
+
|
|
571
|
+
def get_param_labels(self, **kwargs):
    """Return the parameter labels of the prior sampler's parametric model.

    Keyword arguments are forwarded unchanged to ``param_model.get_param_labels``.
    """
    param_model = self.trainer.loader.prior_sampler.param_model
    return param_model.get_param_labels(**kwargs)
|
|
573
|
+
|
|
574
|
+
@staticmethod
def _preprocess_input_data(
    reflectivity_curve,
    q_values,
    sigmas=None,
    q_resolution=None,
    truncate_index_left=None,
    truncate_index_right=None,
    enable_error_bars_filtering=True,
    filter_threshold=0.3,
    filter_remove_singles=True,
    filter_remove_consecutives=True,
    filter_consecutive=3,
    filter_q_start_trunc=0.1):
    """Clean an experimental reflectivity curve before inference.

    Steps, applied in this order:
      1. drop points whose reflectivity is not strictly positive;
      2. keep only ``[truncate_index_left:truncate_index_right]`` of what
         remains. The indices refer to the array *after* step 1.
      3. if ``enable_error_bars_filtering`` is set and ``sigmas`` is given,
         drop the points rejected by ``get_filtering_mask``.

    ``q_resolution`` is masked together with the other arrays only when it is
    a NumPy array. A scalar or ``None`` is passed through unchanged.

    Args:
        reflectivity_curve: reflectivity values (array-like).
        q_values: q values, same length as ``reflectivity_curve`` (array-like).
        sigmas: optional error bars, same length (array-like).
        q_resolution: optional q resolution, either an array or a scalar.
        truncate_index_left, truncate_index_right: optional slice bounds for step 2.
        enable_error_bars_filtering, filter_*: options forwarded to
            ``get_filtering_mask`` for step 3.

    Returns:
        tuple: ``(q_values, reflectivity_curve, sigmas, q_resolution,
        q_values_original, reflectivity_curve_original, sigmas_original,
        q_resolution_original)``. The first four are the processed data; the
        last four are untouched copies of the inputs, kept for polishing.

    Raises:
        ValueError: if ``q_values`` is None.
    """
    if q_values is None:
        raise ValueError("q_values must be provided for preprocessing.")

    # Accept any array-like input. Arrays (and subclasses) pass through
    # unchanged and are copied below.
    reflectivity_curve = np.asanyarray(reflectivity_curve)
    q_values = np.asanyarray(q_values)
    if sigmas is not None:
        sigmas = np.asanyarray(sigmas)

    # Save untouched copies of the inputs for polishing
    reflectivity_curve_original = reflectivity_curve.copy()
    q_values_original = q_values.copy()
    q_resolution_original = q_resolution.copy() if isinstance(q_resolution, np.ndarray) else q_resolution
    sigmas_original = sigmas.copy() if sigmas is not None else None

    def _select(selector, q, refl, sig, q_res):
        """Apply the same index/mask to every per-point array (scalar q_res is kept)."""
        q = q[selector]
        refl = refl[selector]
        if sig is not None:
            sig = sig[selector]
        if isinstance(q_res, np.ndarray):
            q_res = q_res[selector]
        return q, refl, sig, q_res

    # 1. Remove points with non-positive intensities
    positive_mask = reflectivity_curve > 0.0
    q_values, reflectivity_curve, sigmas, q_resolution = _select(
        positive_mask, q_values, reflectivity_curve, sigmas, q_resolution)

    # 2. Truncate arrays
    if truncate_index_left is not None or truncate_index_right is not None:
        q_values, reflectivity_curve, sigmas, q_resolution = _select(
            slice(truncate_index_left, truncate_index_right),
            q_values, reflectivity_curve, sigmas, q_resolution)

    # 3. Filter high-error points
    if enable_error_bars_filtering and sigmas is not None:
        valid_mask = get_filtering_mask(
            q_values,
            reflectivity_curve,
            sigmas,
            threshold=filter_threshold,
            consecutive=filter_consecutive,
            remove_singles=filter_remove_singles,
            remove_consecutives=filter_remove_consecutives,
            q_start_trunc=filter_q_start_trunc
        )
        q_values, reflectivity_curve, sigmas, q_resolution = _select(
            valid_mask, q_values, reflectivity_curve, sigmas, q_resolution)

    return (q_values, reflectivity_curve, sigmas, q_resolution,
            q_values_original, reflectivity_curve_original,
            sigmas_original, q_resolution_original)
|
|
635
|
+
|
|
636
|
+
def interpolate_data_to_model_q(
    self,
    q_exp,
    refl_exp,
    sigmas_exp=None,
    q_res_exp=None,
    as_dict=False
):
    """Map experimental data onto the q grid expected by the model.

    The behaviour depends on the trainer's q generator:

    * ``ConstantQ``: interpolate onto the fixed training grid ``q_generator.q``.
    * ``VariableQ`` with ``n_q_range[0] == n_q_range[1]``: interpolate onto
      ``n_q_range[0]`` points spanning the overlap of the data range with
      ``[q_min_range[0], q_max_range[1]]``. The points are log-spaced when
      ``q_generator.mode == 'logspace'`` and linearly spaced otherwise. With a
      variable number of points the data are returned unchanged.
    * ``MaskedVariableQ``: if the number of points lies within ``n_q_range``
      the data are only zero-padded to ``n_q_range[1]``. Otherwise they are
      first linearly interpolated to the nearest allowed count and then
      padded. A boolean ``key_padding_mask`` is returned that is True on the
      real (non-padded) points.

    Arrays passed as ``sigmas_exp``/``q_res_exp`` are resampled and padded
    like the reflectivity. Scalars and ``None`` are passed through unchanged.
    All returned grids are NumPy arrays.

    Returns:
        If ``as_dict`` is False, a tuple
        ``(q_model, reflectivity[, sigmas][, q_resolution][, key_padding_mask])``;
        optional entries appear only when they are not None.
        If ``as_dict`` is True, a dict with keys ``"q_model"`` and
        ``"reflectivity"``, plus ``"sigmas"``, ``"q_resolution"`` and
        ``"key_padding_mask"`` when they are not None.

    Raises:
        TypeError: if the q generator type is not supported.
    """
    q_generator = self.trainer.loader.q_generator

    def _pad(arr, pad_to, value=0.0):
        if arr is None:
            return None
        return np.pad(arr, (0, pad_to - len(arr)), constant_values=value)

    def _interp_or_keep(q_model, q_exp, arr):
        """Interpolate arrays, keep floats or None unchanged."""
        if arr is None:
            return None
        return np.interp(q_model, q_exp, arr) if isinstance(arr, np.ndarray) else arr

    def _pad_or_keep(arr, max_n):
        """Pad arrays, keep floats or None unchanged."""
        if arr is None:
            return None
        return _pad(arr, max_n, 0.0) if isinstance(arr, np.ndarray) else arr

    def _prepare_return(q, refl, sigmas=None, q_res=None, mask=None, as_dict=False):
        if as_dict:
            result = {"q_model": q, "reflectivity": refl}
            if sigmas is not None: result["sigmas"] = sigmas
            if q_res is not None: result["q_resolution"] = q_res
            if mask is not None: result["key_padding_mask"] = mask
            return result
        result = [q, refl]
        if sigmas is not None: result.append(sigmas)
        if q_res is not None: result.append(q_res)
        if mask is not None: result.append(mask)
        return tuple(result)

    # ConstantQ
    if isinstance(q_generator, ConstantQ):
        q_model = q_generator.q.cpu().numpy()
        refl_out = interp_reflectivity(q_model, q_exp, refl_exp)
        sigmas_out = _interp_or_keep(q_model, q_exp, sigmas_exp)
        q_res_out = _interp_or_keep(q_model, q_exp, q_res_exp)
        return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, None, as_dict)

    # VariableQ
    elif isinstance(q_generator, VariableQ):
        if q_generator.n_q_range[0] == q_generator.n_q_range[1]:
            n_q_model = q_generator.n_q_range[0]
            q_min = max(q_exp.min(), q_generator.q_min_range[0])
            q_max = min(q_exp.max(), q_generator.q_max_range[1])
            logspace = q_generator.mode == 'logspace'
            if logspace:
                # Built with NumPy so that every branch returns a NumPy grid
                # (previously this branch alone returned a torch tensor).
                q_model = np.logspace(np.log10(float(q_min)), np.log10(float(q_max)), int(n_q_model))
            else:
                q_model = np.linspace(q_min, q_max, n_q_model)
        else:
            return _prepare_return(q_exp, refl_exp, sigmas_exp, q_res_exp, None, as_dict)

        refl_out = interp_reflectivity(q_model, q_exp, refl_exp, logspace=logspace)
        sigmas_out = _interp_or_keep(q_model, q_exp, sigmas_exp)
        q_res_out = _interp_or_keep(q_model, q_exp, q_res_exp)
        return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, None, as_dict)

    # MaskedVariableQ
    elif isinstance(q_generator, MaskedVariableQ):
        min_n, max_n = q_generator.n_q_range
        n_exp = len(q_exp)

        if min_n <= n_exp <= max_n:
            # Pad only
            q_model = _pad(q_exp, max_n, 0.0)
            refl_out = _pad(refl_exp, max_n, 0.0)
            sigmas_out = _pad_or_keep(sigmas_exp, max_n)
            q_res_out = _pad_or_keep(q_res_exp, max_n)
            key_padding_mask = np.zeros(max_n, dtype=bool)
            key_padding_mask[:n_exp] = True

        else:
            # Interpolate + pad
            n_interp = min(max(n_exp, min_n), max_n)
            q_min = max(q_exp.min(), q_generator.q_min_range[0])
            q_max = min(q_exp.max(), q_generator.q_max_range[1])
            q_interp = np.linspace(q_min, q_max, n_interp)

            refl_interp = interp_reflectivity(q_interp, q_exp, refl_exp)
            sigmas_interp = _interp_or_keep(q_interp, q_exp, sigmas_exp)
            q_res_interp = _interp_or_keep(q_interp, q_exp, q_res_exp)

            q_model = _pad(q_interp, max_n, 0.0)
            refl_out = _pad(refl_interp, max_n, 0.0)
            sigmas_out = _pad_or_keep(sigmas_interp, max_n)
            q_res_out = _pad_or_keep(q_res_interp, max_n)
            key_padding_mask = np.zeros(max_n, dtype=bool)
            key_padding_mask[:n_interp] = True

        return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, key_padding_mask, as_dict)

    else:
        raise TypeError(f"Unsupported QGenerator type: {type(q_generator)}")
|
|
742
|
+
|
|
743
|
+
def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
    """Predict from many slightly q-shifted copies of ``curve`` and keep the best candidate.

    This is only available for models trained on a fixed q grid (``ConstantQ``).

    The curve is re-interpolated for ``num`` shifts, evenly spaced in
    ``[-dq_max, dq_max]``, where ``dq_max`` is ``dq_coef`` times the first grid
    step. Each shifted curve is scaled and run through the network together
    with the scaled prior bounds. The candidate selected by
    ``get_best_mse_param`` against the unshifted curve is returned.
    """
    loader = self.trainer.loader
    q_gen = loader.q_generator
    assert isinstance(q_gen, ConstantQ), "Prediction with q shifts available only for models with fixed discretization"

    q_grid = q_gen.q.squeeze().float()
    max_shift = (q_grid[1] - q_grid[0]) * dq_coef
    shifts = torch.linspace(-max_shift, max_shift, num).to(q_grid)

    curve = to_t(curve).to(scaled_bounds)
    shifted_batch = _qshift_interp(q_grid.squeeze(), curve, shifts)
    assert shifted_batch.shape == (num, q_grid.shape[0])

    scaled_batch = loader.curves_scaler.scale(shifted_batch)
    bounds_batch = torch.atleast_2d(scaled_bounds).expand(scaled_batch.shape[0], -1)

    network = self.trainer.model
    with torch.no_grad():
        network.eval()
        scaled_params = network(scaled_batch, bounds_batch)
        candidates = loader.prior_sampler.restore_params(
            torch.cat([scaled_params, bounds_batch], dim=-1)
        )

    likelihood = self._get_likelihood(q=q_gen.q, curve=curve)
    return get_best_mse_param(candidates, likelihood)
|
|
767
|
+
|
|
768
|
+
def predict_using_widget(self, reflectivity_curve, **kwargs):
    """Choose prior bounds interactively with sliders and run ``self.predict``.

    The widget shows one range slider per model parameter. Each slider is
    limited to the training bounds ``[min_bounds, max_bounds]``, and its width
    is kept at most ``max_delta``. Every click on "Predict" calls
    ``self.predict`` with the current slider ranges as ``prior_bounds`` (plus
    any other ``kwargs``), prints and plots the result, and stores it in
    ``self.widget_prediction_result``.

    Args:
        reflectivity_curve: experimental reflectivity passed to ``self.predict``.
        **kwargs: forwarded to ``self.predict``. ``kwargs['q_values']`` is also
            used for plotting. If ``prior_bounds`` is given, it is only used to
            initialise the sliders, and every prediction uses the slider values.
    """
    prior_sampler = self.trainer.loader.prior_sampler
    num_params = prior_sampler.param_dim
    param_labels = prior_sampler.param_model.get_param_labels()
    min_bounds = prior_sampler.min_bounds.cpu().numpy().flatten()
    max_bounds = prior_sampler.max_bounds.cpu().numpy().flatten()
    max_deltas = prior_sampler.max_delta.cpu().numpy().flatten()

    # Pop once up front: the user bounds only seed the sliders, so every
    # prediction (including the first) uses exactly what the sliders display.
    # This also keeps them out of the **kwargs forwarded to self.predict.
    initial_bounds = kwargs.pop('prior_bounds', None)

    print('Adjust the sliders for each parameter and press "Predict". Repeat as desired. Press "Close Widget" to finish.')

    interval_widgets = []
    for i in range(num_params):
        label = widgets.Label(value=f'{param_labels[i]}')
        initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
        slider = widgets.FloatRangeSlider(
            value=[min_bounds[i], initial_max],
            min=min_bounds[i],
            max=max_bounds[i],
            step=0.01,
            layout=widgets.Layout(width='400px'),
            style={'description_width': '60px'}
        )

        # Default arguments bind the current slider/width (avoids late binding).
        def validate_range(change, slider=slider, max_width=max_deltas[i]):
            min_val, max_val = change['new']
            if max_val - min_val > max_width:
                old_min_val, old_max_val = change['old']
                # Keep the end the user just moved; adjust the other one.
                if abs(old_min_val - min_val) > abs(old_max_val - max_val):
                    max_val = min_val + max_width
                else:
                    min_val = max_val - max_width
                slider.value = [min_val, max_val]

        slider.observe(validate_range, names='value')
        if initial_bounds is not None:
            # Assigned after `observe` so the width constraint is applied.
            slider.value = tuple(initial_bounds[i])
        interval_widgets.append((slider, widgets.HBox([label, slider])))

    sliders_box = widgets.VBox([row for _slider, row in interval_widgets])

    output = widgets.Output()
    predict_button = widgets.Button(description="Predict")
    close_button = widgets.Button(description="Close Widget")

    container = widgets.VBox([sliders_box, widgets.HBox([predict_button, close_button]), output])
    display(container)

    @output.capture(clear_output=True)
    def on_predict_click(_):
        array_values = np.array([(s.value[0], s.value[1]) for s, _row in interval_widgets])

        prediction_result = self.predict(reflectivity_curve=reflectivity_curve,
                                         prior_bounds=array_values,
                                         **kwargs)
        print_prediction_results(prediction_result)

        plot_prediction_results(
            prediction_result,
            q_exp=kwargs['q_values'],
            curve_exp=reflectivity_curve,
        )
        self.widget_prediction_result = prediction_result

    def on_close_click(_):
        container.close()
        print("Widget closed.")

    predict_button.on_click(on_predict_click)
    close_button.on_click(on_close_click)
|
|
844
|
+
|
|
845
|
+
def preprocess_and_predict_using_widget(self,
                                        reflectivity_curve,
                                        q_values=None,
                                        sigmas=None,
                                        q_resolution=None,
                                        prior_bounds=None,
                                        ambient_sld=None,
                                        ):
    """Interactive widget around ``self.preprocess_and_predict``.

    The widget shows:

    * one range slider per parameter for the prior bounds. Each slider is
      limited to the training bounds and to a width of at most ``max_delta``.
      It is initialised from ``prior_bounds`` when that is given.
    * controls for truncation, error-bar filtering, polishing, plotting
      options, colors and which profiles to compute.

    Each click on "Predict" runs ``self.preprocess_and_predict`` with the
    current settings and prints the results. It then plots the full
    experimental data together with the predicted and polished curves and SLD
    profiles. The returned dict is stored in ``self.widget_prediction_result``.

    Raises:
        ValueError: if ``q_values`` is None.
    """
    if q_values is None:
        raise ValueError("q_values must be provided for this widget.")

    n_points = len(reflectivity_curve)

    # ---------- Priors sliders ----------
    sampler = self.trainer.loader.prior_sampler
    labels = sampler.param_model.get_param_labels()
    lower_limits = sampler.min_bounds.cpu().numpy().flatten()
    upper_limits = sampler.max_bounds.cpu().numpy().flatten()
    width_limits = sampler.max_delta.cpu().numpy().flatten()
    user_bounds = None if prior_bounds is None else np.array(prior_bounds)

    def _width_limiter(target, max_width):
        """Return an observer keeping ``target``'s range no wider than ``max_width``."""
        def _observer(change):
            lo, hi = change['new']
            if not (hi - lo > max_width):
                return
            old_lo, old_hi = change['old']
            # Keep whichever end moved the most; pull the other one in.
            if abs(old_lo - lo) > abs(old_hi - hi):
                hi = lo + max_width
            else:
                lo = hi - max_width
            target.value = (lo, hi)
        return _observer

    prior_sliders = []
    prior_rows = []
    for idx, label_text in enumerate(labels):
        if user_bounds is None:
            start = float(lower_limits[idx])
            stop = float(min(lower_limits[idx] + width_limits[idx], upper_limits[idx]))
        else:
            start = float(user_bounds[idx, 0])
            stop = float(user_bounds[idx, 1])
        label_widget = widgets.Label(value=label_text)
        range_slider = widgets.FloatRangeSlider(
            value=[start, stop],
            min=float(lower_limits[idx]),
            max=float(upper_limits[idx]),
            step=0.01,
            layout=widgets.Layout(width='420px'),
            readout_format='.3f'
        )
        range_slider.observe(_width_limiter(range_slider, float(width_limits[idx])), names='value')
        prior_sliders.append(range_slider)
        prior_rows.append(widgets.HBox([label_widget, range_slider]))
    priors_panel = widgets.VBox([widgets.HTML("<b>Priors</b>")] + prior_rows)

    # ---------- Preprocess & Predict controls ----------
    # Preprocessing
    trunc_left = widgets.IntSlider(description='truncate left', min=0, max=max(0, n_points - 1), step=1, value=0)
    trunc_right = widgets.IntSlider(description='truncate right', min=1, max=n_points, step=1, value=n_points)
    filter_toggle = widgets.Checkbox(description='filter error bars', value=True)
    threshold_slider = widgets.FloatSlider(description='filtering threshold', min=0.0, max=1.0, step=0.01, value=0.3)
    singles_toggle = widgets.Checkbox(description='remove singles', value=True)
    consecutives_toggle = widgets.Checkbox(description='remove consecutives', value=True)
    consecutive_slider = widgets.IntSlider(description='num. consecutive', min=1, max=10, step=1, value=3)
    q_start_slider = widgets.FloatSlider(description='q_start_trunc', min=0.0, max=1.0, step=0.01, value=0.1)

    # Polishing
    polish_toggle = widgets.Checkbox(description='polish prediction', value=True)
    polish_sigmas_toggle = widgets.Checkbox(description='use sigmas during polishing', value=True)

    # Plotting
    yerr_toggle = widgets.Checkbox(description='show error bars', value=True)
    xerr_toggle = widgets.Checkbox(description='show q-resolution', value=False)
    logx_toggle = widgets.Checkbox(description='log x-axis', value=False)
    sld_plot_toggle = widgets.Checkbox(description='plot SLD profile', value=True)
    pad_left_box = widgets.FloatText(description='SLD pad left', value=0.2, step=0.1)
    pad_right_box = widgets.FloatText(description='SLD pad right', value=1.1, step=0.1)

    # Color pickers: blue, purple, red, orange, red, orange
    exp_color = widgets.ColorPicker(description='exp color', value='#0000FF')
    exp_errcolor = widgets.ColorPicker(description='errbar color', value='#800080')
    pred_color = widgets.ColorPicker(description='pred color', value='#FF0000')
    pol_color = widgets.ColorPicker(description='polished color', value='#FFA500')
    sld_pred_color = widgets.ColorPicker(description='SLD pred color', value='#FF0000')
    sld_pol_color = widgets.ColorPicker(description='SLD pol color', value='#FFA500')

    # Compute toggles
    curve_toggle = widgets.Checkbox(description='calc curve', value=True)
    sld_pred_toggle = widgets.Checkbox(description='calc predicted SLD', value=True)
    sld_pol_toggle = widgets.Checkbox(description='calc polished SLD', value=True)

    predict_btn = widgets.Button(description='Predict')
    close_btn = widgets.Button(description='Close')

    control_rows = [
        widgets.HTML("<b>Preprocess & Predict</b>"),
        widgets.HTML("<i>Preprocessing</i>"),
        widgets.HBox([trunc_left, trunc_right]),
        widgets.HBox([filter_toggle, singles_toggle, consecutives_toggle]),
        widgets.HBox([threshold_slider, consecutive_slider, q_start_slider]),
        widgets.HTML("<i>Polishing</i>"),
        widgets.HBox([polish_toggle, polish_sigmas_toggle]),
        widgets.HTML("<i>Plotting</i>"),
        widgets.HBox([yerr_toggle, xerr_toggle, logx_toggle, sld_plot_toggle]),
        widgets.HBox([pad_left_box, pad_right_box]),
        widgets.HTML("<i>Colors</i>"),
        widgets.HBox([exp_color, exp_errcolor]),
        widgets.HBox([pred_color, pol_color]),
        widgets.HBox([sld_pred_color, sld_pol_color]),
        widgets.HTML("<i>Compute</i>"),
        widgets.HBox([curve_toggle, sld_pred_toggle, sld_pol_toggle]),
        widgets.HBox([predict_btn, close_btn]),
    ]
    controls_panel = widgets.VBox(control_rows)

    output_area = widgets.Output()
    panel = widgets.VBox([priors_panel, controls_panel, output_area])
    display(panel)

    def _keep_truncation_ordered(_change):
        # Left truncation index must stay strictly below the right one.
        if trunc_left.value < trunc_right.value:
            return
        trunc_left.value = max(0, trunc_right.value - 1)

    for trunc_slider in (trunc_left, trunc_right):
        trunc_slider.observe(_keep_truncation_ordered, names='value')

    @output_area.capture(clear_output=True)
    def _on_predict(_):
        output_area.clear_output(wait=True)

        current_priors = np.array([s.value for s in prior_sliders], dtype=np.float32)  # (param_dim, 2)
        result = self.preprocess_and_predict(
            reflectivity_curve=reflectivity_curve,
            q_values=q_values,
            prior_bounds=current_priors,
            sigmas=sigmas,
            q_resolution=q_resolution,
            ambient_sld=ambient_sld,
            clip_prediction=True,
            polish_prediction=polish_toggle.value,
            use_sigmas_for_polishing=polish_sigmas_toggle.value,
            calc_pred_curve=curve_toggle.value,
            calc_pred_sld_profile=(sld_pred_toggle.value or sld_plot_toggle.value),
            calc_polished_sld_profile=(sld_pol_toggle.value or sld_plot_toggle.value),
            sld_profile_padding_left=float(pad_left_box.value),
            sld_profile_padding_right=float(pad_right_box.value),
            truncate_index_left=trunc_left.value,
            truncate_index_right=trunc_right.value,
            enable_error_bars_filtering=filter_toggle.value,
            filter_threshold=threshold_slider.value,
            filter_remove_singles=singles_toggle.value,
            filter_remove_consecutives=consecutives_toggle.value,
            filter_consecutive=consecutive_slider.value,
            filter_q_start_trunc=q_start_slider.value,
        )

        # Full experimental data (with optional error bars)
        yerr = sigmas if yerr_toggle.value else None
        xerr = q_resolution if xerr_toggle.value else None

        # Predicted curve lives on the model region; polished curve on the full experimental grid
        q_predicted = result.get('q_plot_pred', None)
        r_predicted = result.get('predicted_curve', None)
        q_polished = q_values if ('polished_curve' in result) else None
        r_polished = result.get('polished_curve', None)

        # SLD profiles
        z_axis = result.get('predicted_sld_xaxis', None)
        sld_predicted = result.get('predicted_sld_profile', None)
        sld_polished = result.get('sld_profile_polished', None)

        print_prediction_results(result)

        use_errorbars = yerr_toggle.value or xerr_toggle.value
        plot_reflectivity(
            q_exp=q_values, r_exp=reflectivity_curve,
            yerr=yerr, xerr=xerr,
            exp_style=('errorbar' if use_errorbars else 'scatter'),
            exp_color=exp_color.value,
            exp_errcolor=exp_errcolor.value,
            q_pred=q_predicted, r_pred=r_predicted, pred_color=pred_color.value,
            q_pol=q_polished, r_pol=r_polished, pol_color=pol_color.value,
            z_sld=z_axis, sld_pred=sld_predicted, sld_pol=sld_polished,
            sld_pred_color=sld_pred_color.value,
            sld_pol_color=sld_pol_color.value,
            plot_sld_profile=sld_plot_toggle.value,
            logx=logx_toggle.value, logy=True,
            figsize=(12, 6),
            legend=True
        )

        self.widget_prediction_result = result

    def _on_close(_):
        panel.close()
        print("Widget closed.")

    predict_btn.on_click(_on_predict)
    close_btn.on_click(_on_close)
|
|
526
1049
|
|
|
527
1050
|
class InferenceModel(object):
|
|
528
1051
|
def __init__(self, name: str = None, trainer: PointEstimatorTrainer = None, preprocessing_parameters: dict = None,
|