reflectorch 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

@@ -13,9 +13,8 @@ from huggingface_hub import hf_hub_download
13
13
 
14
14
  from reflectorch.data_generation.priors import Params, BasicParams, ExpUniformSubPriorSampler, UniformSubPriorParams
15
15
  from reflectorch.data_generation.priors.parametric_models import NuisanceParamsWrapper
16
- from reflectorch.data_generation.q_generator import ConstantQ, VariableQ
16
+ from reflectorch.data_generation.q_generator import ConstantQ, VariableQ, MaskedVariableQ
17
17
  from reflectorch.data_generation.utils import get_density_profiles, get_param_labels
18
- from reflectorch.inference.plotting import plot_prediction_results
19
18
  from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
20
19
  from reflectorch.paths import CONFIG_DIR, ROOT_DIR, SAVED_MODELS_DIR
21
20
  from reflectorch.runs.utils import (
@@ -29,7 +28,8 @@ from reflectorch.inference.preprocess_exp import StandardPreprocessing
29
28
  from reflectorch.inference.scipy_fitter import standard_refl_fit, refl_fit, get_fit_with_growth
30
29
  from reflectorch.inference.sampler_solution import simple_sampler_solution, get_best_mse_param
31
30
  from reflectorch.inference.record_time import print_time
32
- from reflectorch.utils import to_t
31
+ from reflectorch.inference.plotting import plot_reflectivity, plot_prediction_results, print_prediction_results
32
+ from reflectorch.utils import get_filtering_mask, to_t
33
33
 
34
34
  class EasyInferenceModel(object):
35
35
  """Facilitates the inference process using pretrained models
@@ -150,21 +150,158 @@ class EasyInferenceModel(object):
150
150
  inputs_str = ", ".join(additional_inputs)
151
151
  print(f"The following quantities are additional inputs to the network: {inputs_str}.")
152
152
 
153
+ def preprocess_and_predict(self,
154
+ reflectivity_curve: Union[np.ndarray, Tensor],
155
+ q_values: Union[np.ndarray, Tensor] = None,
156
+ prior_bounds: Union[np.ndarray, List[Tuple]] = None,
157
+ sigmas: Union[np.ndarray, Tensor] = None,
158
+ q_resolution: float = None,
159
+ ambient_sld: float = None,
160
+ clip_prediction: bool = True,
161
+ polish_prediction: bool = False,
162
+ polishing_method: str = 'trf',
163
+ polishing_kwargs_reflectivity: dict = None,
164
+ use_sigmas_for_polishing: bool = False,
165
+ polishing_max_nfev: int = None,
166
+ fit_growth: bool = False,
167
+ max_d_change: float = 5.,
168
+ calc_pred_curve: bool = True,
169
+ calc_pred_sld_profile: bool = False,
170
+ calc_polished_sld_profile: bool = False,
171
+ sld_profile_padding_left: float = 0.2,
172
+ sld_profile_padding_right: float = 1.1,
173
+ kwargs_param_labels: dict = {},
174
+
175
+ truncate_index_left: int = None,
176
+ truncate_index_right: int = None,
177
+ enable_error_bars_filtering: bool = True,
178
+ filter_threshold=0.3,
179
+ filter_remove_singles=True,
180
+ filter_remove_consecutives=True,
181
+ filter_consecutive=3,
182
+ filter_q_start_trunc=0.1,
183
+ ):
184
+
185
+ ## Preprocess the data for inference (remove negative intensities, truncation, filer out points with high error bars)
186
+ (q_values, reflectivity_curve, sigmas, q_resolution,
187
+ q_values_original, reflectivity_curve_original, sigmas_original, q_resolution_original) = self._preprocess_input_data(
188
+ reflectivity_curve=reflectivity_curve,
189
+ q_values=q_values,
190
+ sigmas=sigmas,
191
+ q_resolution=q_resolution,
192
+ truncate_index_left=truncate_index_left,
193
+ truncate_index_right=truncate_index_right,
194
+ enable_error_bars_filtering=enable_error_bars_filtering,
195
+ filter_threshold=filter_threshold,
196
+ filter_remove_singles=filter_remove_singles,
197
+ filter_remove_consecutives=filter_remove_consecutives,
198
+ filter_consecutive=filter_consecutive,
199
+ filter_q_start_trunc=filter_q_start_trunc,
200
+ )
201
+
202
+ ### Interpolate the experimental data if needed by the embedding network
203
+ interp_data = self.interpolate_data_to_model_q(
204
+ q_exp=q_values,
205
+ refl_exp=reflectivity_curve,
206
+ sigmas_exp=sigmas,
207
+ q_res_exp=q_resolution,
208
+ as_dict=True
209
+ )
210
+
211
+ q_model = interp_data["q_model"]
212
+ reflectivity_curve_interp = interp_data["reflectivity"]
213
+ sigmas_interp = interp_data.get("sigmas")
214
+ q_resolution_interp = interp_data.get("q_resolution")
215
+ key_padding_mask = interp_data.get("key_padding_mask")
216
+
217
+ ### Make the prediction
218
+ prediction_dict = self.predict(
219
+ reflectivity_curve=reflectivity_curve_interp,
220
+ q_values=q_model,
221
+ sigmas=sigmas_interp,
222
+ q_resolution=q_resolution_interp,
223
+ key_padding_mask=key_padding_mask,
224
+ prior_bounds=prior_bounds,
225
+ ambient_sld=ambient_sld,
226
+ clip_prediction=clip_prediction,
227
+ polish_prediction=False, ###do the polishing outside the predict method on the full data
228
+ supress_sld_amb_back_shift=True, ###do not shift back the slds by the ambient yet
229
+ calc_pred_curve=calc_pred_curve,
230
+ calc_pred_sld_profile=calc_pred_sld_profile,
231
+ sld_profile_padding_left=sld_profile_padding_left,
232
+ sld_profile_padding_right=sld_profile_padding_right,
233
+ kwargs_param_labels=kwargs_param_labels,
234
+ )
235
+
236
+ ### Save interpolated data
237
+ prediction_dict['q_model'] = q_model
238
+ prediction_dict['reflectivity_curve_interp'] = reflectivity_curve_interp
239
+ if q_resolution_interp is not None:
240
+ prediction_dict['q_resolution_interp'] = q_resolution_interp
241
+ if sigmas_interp is not None:
242
+ prediction_dict['sigmas_interp'] = sigmas_interp
243
+ if key_padding_mask is not None:
244
+ prediction_dict['key_padding_mask'] = key_padding_mask
245
+
246
+ ### Perform polishing on the original data
247
+ if polish_prediction:
248
+ polishing_kwargs = polishing_kwargs_reflectivity or {}
249
+ polishing_kwargs.setdefault('dq', q_resolution_original)
250
+
251
+ prior_bounds = np.array(prior_bounds)
252
+ if ambient_sld:
253
+ sld_indices = self._shift_slds_by_ambient(prior_bounds, ambient_sld)
254
+
255
+ polished_dict = self._polish_prediction(
256
+ q=q_values_original,
257
+ curve=reflectivity_curve_original,
258
+ predicted_params=prediction_dict['predicted_params_object'],
259
+ priors=prior_bounds,
260
+ ambient_sld_tensor=torch.atleast_2d(torch.as_tensor(ambient_sld)).to(self.device) if ambient_sld is not None else None,
261
+ calc_polished_sld_profile=calc_polished_sld_profile,
262
+ sld_x_axis=torch.from_numpy(prediction_dict['predicted_sld_xaxis']),
263
+ polishing_kwargs_reflectivity = polishing_kwargs,
264
+ error_bars=sigmas_original if use_sigmas_for_polishing else None,
265
+ polishing_method=polishing_method,
266
+ polishing_max_nfev=polishing_max_nfev,
267
+ fit_growth=fit_growth,
268
+ max_d_change=max_d_change,
269
+ )
270
+
271
+ prediction_dict.update(polished_dict)
272
+ if fit_growth and "polished_params_array" in prediction_dict:
273
+ prediction_dict["param_names"].append("max_d_change")
274
+
275
+ ### Shift back the slds for nonzero ambient
276
+ if ambient_sld:
277
+ self._restore_slds_after_ambient_shift(prediction_dict, sld_indices, ambient_sld)
278
+
279
+ return prediction_dict
280
+
281
+
153
282
  def predict(self,
154
283
  reflectivity_curve: Union[np.ndarray, Tensor],
155
284
  q_values: Union[np.ndarray, Tensor] = None,
156
285
  prior_bounds: Union[np.ndarray, List[Tuple]] = None,
157
- q_resolution: Union[float, np.ndarray] = None,
286
+ sigmas: Union[np.ndarray, Tensor] = None,
287
+ key_padding_mask: Union[np.ndarray, Tensor] = None,
288
+ q_resolution: float = None,
158
289
  ambient_sld: float = None,
159
- clip_prediction: bool = False,
160
- polish_prediction: bool = False,
290
+ clip_prediction: bool = True,
291
+ polish_prediction: bool = False,
292
+ polishing_method: str = 'trf',
161
293
  polishing_kwargs_reflectivity: dict = None,
294
+ polishing_max_nfev: int = None,
162
295
  fit_growth: bool = False,
163
296
  max_d_change: float = 5.,
164
297
  use_q_shift: bool = False,
165
298
  calc_pred_curve: bool = True,
166
299
  calc_pred_sld_profile: bool = False,
167
300
  calc_polished_sld_profile: bool = False,
301
+ sld_profile_padding_left: float = 0.2,
302
+ sld_profile_padding_right: float = 1.1,
303
+ supress_sld_amb_back_shift: bool = False,
304
+ kwargs_param_labels: dict = {},
168
305
  ):
169
306
  """Predict the thin film parameters
170
307
 
@@ -172,17 +309,13 @@ class EasyInferenceModel(object):
172
309
  reflectivity_curve (Union[np.ndarray, Tensor]): The reflectivity curve (which has been already preprocessed, normalized and interpolated).
173
310
  q_values (Union[np.ndarray, Tensor], optional): The momentum transfer (q) values for the reflectivity curve (in units of inverse angstroms).
174
311
  prior_bounds (Union[np.ndarray, List[Tuple]], optional): the prior bounds for the thin film parameters.
175
- q_resolution (Union[float, np.ndarray], optional): the instrumental resolution. Either as a float with meaning dq/q for linear smearing or as a numpy array with meaning dq for pointwise smearing.
176
- ambient_sld (float, optional): the SLD of the ambient medium (fronting), if different from air.
177
312
  clip_prediction (bool, optional): If ``True``, the values of the predicted parameters are clipped to not be outside the interval set by the prior bounds. Defaults to False.
178
313
  polish_prediction (bool, optional): If ``True``, the neural network predictions are further polished using a simple least mean squares (LMS) fit. Only for the standard box-model parameterization. Defaults to False.
179
- polishing_kwargs_reflectivity (dict): extra arguments for the reflectivity function used during polishing.
180
314
  fit_growth (bool, optional): If ``True``, an additional parameters is introduced during the LMS polishing to account for the change in the thickness of the upper layer during the in-situ measurement of the reflectivity curve (a linear growth is assumed). Defaults to False.
181
315
  max_d_change (float): The maximum possible change in the thickness of the upper layer during the in-situ measurement, relevant when polish_prediction and fit_growth are True. Defaults to 5.
182
316
  use_q_shift: If ``True``, the prediction is performed for a batch of slightly shifted versions of the input curve and the best result is returned, which is meant to mitigate the influence of imperfect sample alignment, as introduced in Greco et al. (only for models with fixed q-discretization). Defaults to False.
183
317
  calc_pred_curve (bool, optional): Whether to calculate the curve corresponding to the predicted parameters. Defaults to True.
184
318
  calc_pred_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the predicted parameters. Defaults to False.
185
- calc_polished_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the polished parameters. Defaults to False.
186
319
 
187
320
  Returns:
188
321
  dict: dictionary containing the predictions
@@ -192,54 +325,48 @@ class EasyInferenceModel(object):
192
325
  prior_bounds = np.array(prior_bounds)
193
326
 
194
327
  if ambient_sld:
195
- n_layers = self.trainer.loader.prior_sampler.max_num_layers
196
- sld_indices = slice(2*n_layers+1, 3*n_layers+2)
197
- prior_bounds[sld_indices, ...] -= ambient_sld
198
- training_min_bounds = self.trainer.loader.prior_sampler.min_bounds.squeeze().cpu().numpy()
199
- training_max_bounds = self.trainer.loader.prior_sampler.max_bounds.squeeze().cpu().numpy()
200
- lower_bound_check = (prior_bounds[sld_indices, 0] >= training_min_bounds[sld_indices]).all()
201
- upper_bound_check = (prior_bounds[sld_indices, 1] <= training_max_bounds[sld_indices]).all()
202
- assert lower_bound_check and upper_bound_check, "Shifting the layer SLDs by the ambient SLD exceeded the training ranges."
328
+ sld_indices = self._shift_slds_by_ambient(prior_bounds, ambient_sld)
203
329
 
204
- try:
205
- scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)
206
- except ValueError as e:
207
- print(str(e))
208
- return None
330
+ scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)
209
331
 
210
- if not self.trainer.train_with_q_input:
332
+ if isinstance(self.trainer.loader.q_generator, ConstantQ):
211
333
  q_values = self.trainer.loader.q_generator.q
212
334
  else:
213
335
  q_values = torch.atleast_2d(to_t(q_values)).to(scaled_curve)
214
336
 
337
+ scaled_q_values = self.trainer.loader.q_generator.scale_q(q_values).to(torch.float32) if self.trainer.train_with_q_input else None
338
+
339
+ if q_resolution is not None:
340
+ q_resolution_tensor = torch.atleast_2d(torch.as_tensor(q_resolution)).to(scaled_curve)
341
+ if isinstance(q_resolution, float):
342
+ unscaled_q_resolutions = q_resolution_tensor
343
+ else:
344
+ unscaled_q_resolutions = (q_resolution_tensor / q_values).nanmean(dim=-1, keepdim=True) ##when q_values is padded with 0s, there will be nan at the padded positions
345
+ scaled_q_resolutions = self.trainer.loader.smearing.scale_resolutions(unscaled_q_resolutions) if self.trainer.condition_on_q_resolutions else None
346
+ scaled_conditioning_params = scaled_q_resolutions
347
+ if polishing_kwargs_reflectivity is None:
348
+ polishing_kwargs_reflectivity = {'dq': q_resolution}
349
+ else:
350
+ q_resolution_tensor = None
351
+ scaled_conditioning_params = None
352
+
353
+ if key_padding_mask is not None:
354
+ key_padding_mask = torch.as_tensor(key_padding_mask, device=self.device)
355
+ key_padding_mask = key_padding_mask.unsqueeze(0) if key_padding_mask.dim() == 1 else key_padding_mask
356
+
215
357
  if use_q_shift and not self.trainer.train_with_q_input:
216
358
  predicted_params = self._qshift_prediction(reflectivity_curve, scaled_prior_bounds, num = 1024, dq_coef = 1.)
217
-
218
359
  else:
219
360
  with torch.no_grad():
220
361
  self.trainer.model.eval()
221
-
222
- scaled_q_values = self.trainer.loader.q_generator.scale_q(q_values).to(torch.float32) if self.trainer.train_with_q_input else None
223
-
224
- if q_resolution is not None:
225
- q_resolution_tensor = torch.atleast_2d(torch.as_tensor(q_resolution)).to(scaled_curve)
226
- if isinstance(q_resolution, float):
227
- unscaled_q_resolutions = q_resolution_tensor
228
- else:
229
- unscaled_q_resolutions = (q_resolution_tensor / q_values).mean(dim=-1, keepdim=True)
230
- scaled_q_resolutions = self.trainer.loader.smearing.scale_resolutions(unscaled_q_resolutions) if self.trainer.condition_on_q_resolutions else None
231
- scaled_conditioning_params = scaled_q_resolutions
232
- if polishing_kwargs_reflectivity is None:
233
- polishing_kwargs_reflectivity = {'dq': q_resolution}
234
- else:
235
- q_resolution_tensor = None
236
- scaled_conditioning_params = None
237
-
362
+
238
363
  scaled_predicted_params = self.trainer.model(
239
364
  curves=scaled_curve,
240
365
  bounds=scaled_prior_bounds,
241
366
  q_values=scaled_q_values,
242
367
  conditioning_params = scaled_conditioning_params,
368
+ key_padding_mask = key_padding_mask,
369
+ unscaled_q_values = q_values,
243
370
  )
244
371
 
245
372
  predicted_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
@@ -250,155 +377,60 @@ class EasyInferenceModel(object):
250
377
  prediction_dict = {
251
378
  "predicted_params_object": predicted_params,
252
379
  "predicted_params_array": predicted_params.parameters.squeeze().cpu().numpy(),
253
- "param_names" : self.trainer.loader.prior_sampler.param_model.get_param_labels()
380
+ "param_names" : self.trainer.loader.prior_sampler.param_model.get_param_labels(**kwargs_param_labels)
254
381
  }
382
+
383
+ key_padding_mask = None if key_padding_mask is None else key_padding_mask.squeeze().cpu().numpy()
255
384
 
256
385
  if calc_pred_curve:
257
386
  predicted_curve = predicted_params.reflectivity(q=q_values, dq=q_resolution_tensor).squeeze().cpu().numpy()
258
- prediction_dict[ "predicted_curve"] = predicted_curve
387
+ prediction_dict[ "predicted_curve"] = predicted_curve if key_padding_mask is None else predicted_curve[key_padding_mask]
259
388
 
260
- ambient_sld_tensor = torch.atleast_2d(torch.as_tensor(ambient_sld)).to(predicted_params.thicknesses.device) if ambient_sld is not None else None
389
+ ambient_sld_tensor = torch.atleast_2d(torch.as_tensor(ambient_sld, device=self.device)) if ambient_sld is not None else None
261
390
  if calc_pred_sld_profile:
262
391
  predicted_sld_xaxis, predicted_sld_profile, _ = get_density_profiles(
263
- predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, ambient_sld_tensor, num=1024,
392
+ predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds + (ambient_sld_tensor or 0), ambient_sld_tensor,
393
+ num=1024, padding_left=sld_profile_padding_left, padding_right=sld_profile_padding_right,
264
394
  )
265
395
  prediction_dict['predicted_sld_profile'] = predicted_sld_profile.squeeze().cpu().numpy()
266
396
  prediction_dict['predicted_sld_xaxis'] = predicted_sld_xaxis.squeeze().cpu().numpy()
267
397
  else:
268
398
  predicted_sld_xaxis = None
399
+
400
+ refl_curve_polish = reflectivity_curve if key_padding_mask is None else reflectivity_curve[key_padding_mask]
401
+ q_polish = q_values.squeeze().cpu().numpy() if key_padding_mask is None else q_values.squeeze().cpu().numpy()[key_padding_mask]
402
+ prediction_dict['q_plot_pred'] = q_polish
269
403
 
270
404
  if polish_prediction:
271
405
  if ambient_sld_tensor:
272
406
  ambient_sld_tensor = ambient_sld_tensor.cpu()
273
- polished_dict = self._polish_prediction(q = q_values.squeeze().cpu().numpy(),
274
- curve = reflectivity_curve,
275
- predicted_params = predicted_params,
276
- priors = np.array(prior_bounds),
277
- fit_growth = fit_growth,
278
- max_d_change = max_d_change,
279
- calc_polished_curve = calc_pred_curve,
280
- calc_polished_sld_profile = calc_polished_sld_profile,
281
- ambient_sld_tensor=ambient_sld_tensor,
282
- sld_x_axis = predicted_sld_xaxis,
283
- polishing_kwargs_reflectivity=polishing_kwargs_reflectivity,
284
- )
407
+
408
+ polished_dict = self._polish_prediction(
409
+ q = q_polish,
410
+ curve = refl_curve_polish,
411
+ predicted_params = predicted_params,
412
+ priors = np.array(prior_bounds),
413
+ error_bars = sigmas,
414
+ fit_growth = fit_growth,
415
+ max_d_change = max_d_change,
416
+ calc_polished_curve = calc_pred_curve,
417
+ calc_polished_sld_profile = calc_polished_sld_profile,
418
+ ambient_sld_tensor=ambient_sld_tensor,
419
+ sld_x_axis = predicted_sld_xaxis,
420
+ polishing_method=polishing_method,
421
+ polishing_max_nfev=polishing_max_nfev,
422
+ polishing_kwargs_reflectivity=polishing_kwargs_reflectivity,
423
+ )
285
424
  prediction_dict.update(polished_dict)
286
425
 
287
426
  if fit_growth and "polished_params_array" in prediction_dict:
288
427
  prediction_dict["param_names"].append("max_d_change")
289
428
 
290
- if ambient_sld: #Note: the SLD shift will only be reflected in predicted_params_array but not in predicted_params_object
291
- prediction_dict["predicted_params_array"][sld_indices] += ambient_sld
292
- if "polished_params_array" in prediction_dict:
293
- prediction_dict["polished_params_array"][sld_indices] += ambient_sld
429
+ if ambient_sld and not supress_sld_amb_back_shift: #Note: the SLD shift will only be reflected in predicted_params_array but not in predicted_params_object; supress_sld_amb_back_shift is required for the 'preprocess_and_predict' method
430
+ self._restore_slds_after_ambient_shift(prediction_dict, sld_indices, ambient_sld)
294
431
 
295
432
  return prediction_dict
296
-
297
- def predict_using_widget(self, reflectivity_curve, **kwargs):
298
- """
299
- """
300
-
301
- NUM_INTERVALS = self.trainer.loader.prior_sampler.param_dim
302
- param_labels = self.trainer.loader.prior_sampler.param_model.get_param_labels()
303
- min_bounds = self.trainer.loader.prior_sampler.min_bounds.cpu().numpy().flatten()
304
- max_bounds = self.trainer.loader.prior_sampler.max_bounds.cpu().numpy().flatten()
305
- max_deltas = self.trainer.loader.prior_sampler.max_delta.cpu().numpy().flatten()
306
-
307
- print(f'Adjust the sliders for each parameter and press "Predict". Repeat as desired. Press "Close Widget" to finish.')
308
-
309
- interval_widgets = []
310
- for i in range(NUM_INTERVALS):
311
- label = widgets.Label(value=f'{param_labels[i]}')
312
- initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
313
- slider = widgets.FloatRangeSlider(
314
- value=[min_bounds[i], initial_max],
315
- min=min_bounds[i],
316
- max=max_bounds[i],
317
- step=0.01,
318
- layout=widgets.Layout(width='400px'),
319
- style={'description_width': '60px'}
320
- )
321
-
322
- def validate_range(change, slider=slider, max_width=max_deltas[i]):
323
- min_val, max_val = change['new']
324
- if max_val - min_val > max_width:
325
- old_min_val, old_max_val = change['old']
326
- if abs(old_min_val - min_val) > abs(old_max_val - max_val):
327
- max_val = min_val + max_width
328
- else:
329
- min_val = max_val - max_width
330
- slider.value = [min_val, max_val]
331
-
332
- slider.observe(validate_range, names='value')
333
- interval_widgets.append((slider, widgets.HBox([label, slider])))
334
-
335
- sliders_box = widgets.VBox([iw[1] for iw in interval_widgets])
336
-
337
- output = widgets.Output()
338
- predict_button = widgets.Button(description="Predict")
339
- close_button = widgets.Button(description="Close Widget")
340
-
341
- container = widgets.VBox([sliders_box, widgets.HBox([predict_button, close_button]), output])
342
- display(container)
343
-
344
- @output.capture(clear_output=True)
345
- def on_predict_click(_):
346
- if 'prior_bounds' in kwargs:
347
- array_values = kwargs.pop('prior_bounds')
348
- for i, (s, _) in enumerate(interval_widgets):
349
- s.value = tuple(array_values[i])
350
- else:
351
- values = [(s.value[0], s.value[1]) for s, _ in interval_widgets]
352
- array_values = np.array(values)
353
-
354
- prediction_result = self.predict(reflectivity_curve=reflectivity_curve,
355
- prior_bounds=array_values,
356
- **kwargs)
357
- param_names = self.trainer.loader.prior_sampler.param_model.get_param_labels()
358
- for param_name, pred_param_val in zip(param_names, prediction_result["predicted_params_array"]):
359
- print(f'{param_name.ljust(14)} : {pred_param_val:.2f}')
360
-
361
- plot_prediction_results(
362
- prediction_result,
363
- q_exp=kwargs['q_values'],
364
- curve_exp=reflectivity_curve,
365
- q_model=kwargs['q_values'],
366
- )
367
- self.prediction_result = prediction_result
368
-
369
- def on_close_click(_):
370
- container.close()
371
- print("Widget closed.")
372
-
373
- predict_button.on_click(on_predict_click)
374
- close_button.on_click(on_close_click)
375
-
376
-
377
- def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
378
- assert isinstance(self.trainer.loader.q_generator, ConstantQ), "Prediction with q shifts available only for models with fixed discretization"
379
- q = self.trainer.loader.q_generator.q.squeeze().float()
380
- dq_max = (q[1] - q[0]) * dq_coef
381
- q_shifts = torch.linspace(-dq_max, dq_max, num).to(q)
382
-
383
- curve = to_t(curve).to(scaled_bounds)
384
- shifted_curves = _qshift_interp(q.squeeze(), curve, q_shifts)
385
-
386
- assert shifted_curves.shape == (num, q.shape[0])
387
-
388
- scaled_curves = self.trainer.loader.curves_scaler.scale(shifted_curves)
389
- scaled_prior_bounds = torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)
390
-
391
- with torch.no_grad():
392
- self.trainer.model.eval()
393
- scaled_predicted_params = self.trainer.model(scaled_curves, scaled_prior_bounds)
394
- restored_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
395
-
396
- best_param = get_best_mse_param(
397
- restored_params,
398
- self._get_likelihood(q=self.trainer.loader.q_generator.q, curve=curve),
399
- )
400
- return best_param
401
-
433
+
402
434
  def _polish_prediction(self,
403
435
  q: np.ndarray,
404
436
  curve: np.ndarray,
@@ -410,6 +442,9 @@ class EasyInferenceModel(object):
410
442
  max_d_change: float = 5.,
411
443
  calc_polished_curve: bool = True,
412
444
  calc_polished_sld_profile: bool = False,
445
+ error_bars: np.ndarray = None,
446
+ polishing_method: str = 'trf',
447
+ polishing_max_nfev: int = None,
413
448
  polishing_kwargs_reflectivity: dict = None,
414
449
  ) -> dict:
415
450
  params = predicted_params.parameters.squeeze().cpu().numpy()
@@ -440,6 +475,9 @@ class EasyInferenceModel(object):
440
475
  init_params = params,
441
476
  bounds=priors.T,
442
477
  prior_sampler=self.trainer.loader.prior_sampler,
478
+ error_bars=error_bars,
479
+ method=polishing_method,
480
+ polishing_max_nfev=polishing_max_nfev,
443
481
  reflectivity_kwargs=polishing_kwargs_reflectivity,
444
482
  )
445
483
  polished_params = BasicParams(
@@ -458,18 +496,23 @@ class EasyInferenceModel(object):
458
496
  if calc_polished_curve:
459
497
  polished_params_dict['polished_curve'] = curve_polished
460
498
 
499
+ if ambient_sld_tensor is not None:
500
+ ambient_sld_tensor = ambient_sld_tensor.to(polished_params.slds.device)
501
+
502
+
461
503
  if calc_polished_sld_profile:
462
504
  _, sld_profile_polished, _ = get_density_profiles(
463
- polished_params.thicknesses, polished_params.roughnesses, polished_params.slds, ambient_sld_tensor, z_axis=sld_x_axis.cpu(),
505
+ polished_params.thicknesses, polished_params.roughnesses, polished_params.slds + (ambient_sld_tensor or 0), ambient_sld_tensor,
506
+ z_axis=sld_x_axis.to(polished_params.slds.device),
464
507
  )
465
- polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().numpy()
508
+ polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().cpu().numpy()
466
509
 
467
510
  return polished_params_dict
468
511
 
469
512
  def _scale_curve(self, curve: Union[np.ndarray, Tensor]):
470
513
  if not isinstance(curve, Tensor):
471
514
  curve = torch.from_numpy(curve).float()
472
- curve = torch.atleast_2d(curve).to(self.device)
515
+ curve = curve.unsqueeze(0).to(self.device)
473
516
  scaled_curve = self.trainer.loader.curves_scaler.scale(curve)
474
517
  return scaled_curve
475
518
 
@@ -500,29 +543,509 @@ class EasyInferenceModel(object):
500
543
  f"- The number of layers or parameterization type differs from the one used during training.\n\n"
501
544
  f" Check the configuration or the summary of expected parameters."
502
545
  )
503
- raise ValueError(msg) from e
546
+ raise ValueError(msg) from e
504
547
 
505
- def interpolate_data_to_model_q(self, q_exp, curve_exp):
506
- if isinstance(self.trainer.loader.q_generator, ConstantQ):
507
- q_model = self.trainer.loader.q_generator.q.cpu().numpy()
508
- elif isinstance(self.trainer.loader.q_generator, VariableQ):
509
- if self.trainer.loader.q_generator.n_q_range[0] == self.trainer.loader.q_generator.n_q_range[1]:
510
- n_q_model = self.trainer.loader.q_generator.n_q_range[0]
511
- q_model_min = max(q_exp.min(), self.trainer.loader.q_generator.q_min_range[0])
512
- q_model_max = min(q_exp.max(), self.trainer.loader.q_generator.q_max_range[1])
513
- q_model = np.linspace(q_model_min, q_model_max, n_q_model)
514
- else:
515
- q_model = q_exp
516
- exp_curve_interp = curve_exp
517
-
518
- exp_curve_interp = interp_reflectivity(q_model, q_exp, curve_exp)
519
-
520
- return q_model, exp_curve_interp
548
+ def _shift_slds_by_ambient(self, prior_bounds: np.ndarray, ambient_sld: float):
549
+ n_layers = self.trainer.loader.prior_sampler.max_num_layers
550
+ sld_indices = slice(2*n_layers+1, 3*n_layers+2)
551
+ prior_bounds[sld_indices, ...] -= ambient_sld
552
+
553
+ training_min_bounds = self.trainer.loader.prior_sampler.min_bounds.squeeze().cpu().numpy()
554
+ training_max_bounds = self.trainer.loader.prior_sampler.max_bounds.squeeze().cpu().numpy()
555
+ lower_bound_check = (prior_bounds[sld_indices, 0] >= training_min_bounds[sld_indices]).all()
556
+ upper_bound_check = (prior_bounds[sld_indices, 1] <= training_max_bounds[sld_indices]).all()
557
+ assert lower_bound_check and upper_bound_check, "Shifting the layer SLDs by the ambient SLD exceeded the training ranges."
558
+
559
+ return sld_indices
521
560
 
561
+ def _restore_slds_after_ambient_shift(self, prediction_dict, sld_indices, ambient_sld):
562
+ prediction_dict["predicted_params_array"][sld_indices] += ambient_sld
563
+ if "polished_params_array" in prediction_dict:
564
+ prediction_dict["polished_params_array"][sld_indices] += ambient_sld
565
+
522
566
  def _get_likelihood(self, q, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
523
567
  return LogLikelihood(
524
568
  q, curve, self.trainer.loader.prior_sampler, curve * rel_err + abs_err
525
569
  )
570
+
571
+ def get_param_labels(self, **kwargs):
572
+ return self.trainer.loader.prior_sampler.param_model.get_param_labels(**kwargs)
573
+
574
+ @staticmethod
575
+ def _preprocess_input_data(
576
+ reflectivity_curve,
577
+ q_values,
578
+ sigmas=None,
579
+ q_resolution=None,
580
+ truncate_index_left=None,
581
+ truncate_index_right=None,
582
+ enable_error_bars_filtering=True,
583
+ filter_threshold=0.3,
584
+ filter_remove_singles=True,
585
+ filter_remove_consecutives=True,
586
+ filter_consecutive=3,
587
+ filter_q_start_trunc=0.1):
588
+
589
+ # Save originals for polishing
590
+ reflectivity_curve_original = reflectivity_curve.copy()
591
+ q_values_original = q_values.copy() if q_values is not None else None
592
+ q_resolution_original = q_resolution.copy() if isinstance(q_resolution, np.ndarray) else q_resolution
593
+ sigmas_original = sigmas.copy() if sigmas is not None else None
594
+
595
+ # Remove points with non-positive intensities
596
+ nonnegative_mask = reflectivity_curve > 0.0
597
+ reflectivity_curve = reflectivity_curve[nonnegative_mask]
598
+ q_values = q_values[nonnegative_mask]
599
+ if sigmas is not None:
600
+ sigmas = sigmas[nonnegative_mask]
601
+ if isinstance(q_resolution, np.ndarray):
602
+ q_resolution = q_resolution[nonnegative_mask]
603
+
604
+ # Truncate arrays
605
+ if truncate_index_left is not None or truncate_index_right is not None:
606
+ slice_obj = slice(truncate_index_left, truncate_index_right)
607
+ reflectivity_curve = reflectivity_curve[slice_obj]
608
+ q_values = q_values[slice_obj]
609
+ if sigmas is not None:
610
+ sigmas = sigmas[slice_obj]
611
+ if isinstance(q_resolution, np.ndarray):
612
+ q_resolution = q_resolution[slice_obj]
613
+
614
+ # Filter high-error points
615
+ if enable_error_bars_filtering and sigmas is not None:
616
+ valid_mask = get_filtering_mask(
617
+ q_values,
618
+ reflectivity_curve,
619
+ sigmas,
620
+ threshold=filter_threshold,
621
+ consecutive=filter_consecutive,
622
+ remove_singles=filter_remove_singles,
623
+ remove_consecutives=filter_remove_consecutives,
624
+ q_start_trunc=filter_q_start_trunc
625
+ )
626
+ reflectivity_curve = reflectivity_curve[valid_mask]
627
+ q_values = q_values[valid_mask]
628
+ sigmas = sigmas[valid_mask]
629
+ if isinstance(q_resolution, np.ndarray):
630
+ q_resolution = q_resolution[valid_mask]
631
+
632
+ return (q_values, reflectivity_curve, sigmas, q_resolution,
633
+ q_values_original, reflectivity_curve_original,
634
+ sigmas_original, q_resolution_original)
635
+
636
+ def interpolate_data_to_model_q(
637
+ self,
638
+ q_exp,
639
+ refl_exp,
640
+ sigmas_exp=None,
641
+ q_res_exp=None,
642
+ as_dict=False
643
+ ):
644
+ q_generator = self.trainer.loader.q_generator
645
+
646
+ def _pad(arr, pad_to, value=0.0):
647
+ if arr is None:
648
+ return None
649
+ return np.pad(arr, (0, pad_to - len(arr)), constant_values=value)
650
+
651
+ def _interp_or_keep(q_model, q_exp, arr):
652
+ """Interpolate arrays, keep floats or None unchanged."""
653
+ if arr is None:
654
+ return None
655
+ return np.interp(q_model, q_exp, arr) if isinstance(arr, np.ndarray) else arr
656
+
657
+ def _pad_or_keep(arr, max_n):
658
+ """Pad arrays, keep floats or None unchanged."""
659
+ if arr is None:
660
+ return None
661
+ return _pad(arr, max_n, 0.0) if isinstance(arr, np.ndarray) else arr
662
+
663
+ def _prepare_return(q, refl, sigmas=None, q_res=None, mask=None, as_dict=False):
664
+ if as_dict:
665
+ result = {"q_model": q, "reflectivity": refl}
666
+ if sigmas is not None: result["sigmas"] = sigmas
667
+ if q_res is not None: result["q_resolution"] = q_res
668
+ if mask is not None: result["key_padding_mask"] = mask
669
+ return result
670
+ result = [q, refl]
671
+ if sigmas is not None: result.append(sigmas)
672
+ if q_res is not None: result.append(q_res)
673
+ if mask is not None: result.append(mask)
674
+ return tuple(result)
675
+
676
+ # ConstantQ
677
+ if isinstance(q_generator, ConstantQ):
678
+ q_model = q_generator.q.cpu().numpy()
679
+ refl_out = interp_reflectivity(q_model, q_exp, refl_exp)
680
+ sigmas_out = _interp_or_keep(q_model, q_exp, sigmas_exp)
681
+ q_res_out = _interp_or_keep(q_model, q_exp, q_res_exp)
682
+ return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, None, as_dict)
683
+
684
+ # VariableQ
685
+ elif isinstance(q_generator, VariableQ):
686
+ if q_generator.n_q_range[0] == q_generator.n_q_range[1]:
687
+ n_q_model = q_generator.n_q_range[0]
688
+ q_min = max(q_exp.min(), q_generator.q_min_range[0])
689
+ q_max = min(q_exp.max(), q_generator.q_max_range[1])
690
+ if self.trainer.loader.q_generator.mode == 'logspace':
691
+ q_model = torch.logspace(start=torch.log10(torch.tensor(q_min, device=self.device)),
692
+ end=torch.log10(torch.tensor(q_max, device=self.device)),
693
+ steps=n_q_model, device=self.device).to('cpu')
694
+ logspace = True
695
+ else:
696
+ q_model = np.linspace(q_min, q_max, n_q_model)
697
+ logspace = False
698
+ else:
699
+ return _prepare_return(q_exp, refl_exp, sigmas_exp, q_res_exp, None, as_dict)
700
+
701
+ refl_out = interp_reflectivity(q_model, q_exp, refl_exp, logspace=logspace)
702
+ sigmas_out = _interp_or_keep(q_model, q_exp, sigmas_exp)
703
+ q_res_out = _interp_or_keep(q_model, q_exp, q_res_exp)
704
+ return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, None, as_dict)
705
+
706
+ # MaskedVariableQ
707
+ elif isinstance(q_generator, MaskedVariableQ):
708
+ min_n, max_n = q_generator.n_q_range
709
+ n_exp = len(q_exp)
710
+
711
+ if min_n <= n_exp <= max_n:
712
+ # Pad only
713
+ q_model = _pad(q_exp, max_n, 0.0)
714
+ refl_out = _pad(refl_exp, max_n, 0.0)
715
+ sigmas_out = _pad_or_keep(sigmas_exp, max_n)
716
+ q_res_out = _pad_or_keep(q_res_exp, max_n)
717
+ key_padding_mask = np.zeros(max_n, dtype=bool)
718
+ key_padding_mask[:n_exp] = True
719
+
720
+ else:
721
+ # Interpolate + pad
722
+ n_interp = min(max(n_exp, min_n), max_n)
723
+ q_min = max(q_exp.min(), q_generator.q_min_range[0])
724
+ q_max = min(q_exp.max(), q_generator.q_max_range[1])
725
+ q_interp = np.linspace(q_min, q_max, n_interp)
726
+
727
+ refl_interp = interp_reflectivity(q_interp, q_exp, refl_exp)
728
+ sigmas_interp = _interp_or_keep(q_interp, q_exp, sigmas_exp)
729
+ q_res_interp = _interp_or_keep(q_interp, q_exp, q_res_exp)
730
+
731
+ q_model = _pad(q_interp, max_n, 0.0)
732
+ refl_out = _pad(refl_interp, max_n, 0.0)
733
+ sigmas_out = _pad_or_keep(sigmas_interp, max_n)
734
+ q_res_out = _pad_or_keep(q_res_interp, max_n)
735
+ key_padding_mask = np.zeros(max_n, dtype=bool)
736
+ key_padding_mask[:n_interp] = True
737
+
738
+ return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, key_padding_mask, as_dict)
739
+
740
+ else:
741
+ raise TypeError(f"Unsupported QGenerator type: {type(q_generator)}")
742
+
743
+ def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
744
+ assert isinstance(self.trainer.loader.q_generator, ConstantQ), "Prediction with q shifts available only for models with fixed discretization"
745
+ q = self.trainer.loader.q_generator.q.squeeze().float()
746
+ dq_max = (q[1] - q[0]) * dq_coef
747
+ q_shifts = torch.linspace(-dq_max, dq_max, num).to(q)
748
+
749
+ curve = to_t(curve).to(scaled_bounds)
750
+ shifted_curves = _qshift_interp(q.squeeze(), curve, q_shifts)
751
+
752
+ assert shifted_curves.shape == (num, q.shape[0])
753
+
754
+ scaled_curves = self.trainer.loader.curves_scaler.scale(shifted_curves)
755
+ scaled_prior_bounds = torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)
756
+
757
+ with torch.no_grad():
758
+ self.trainer.model.eval()
759
+ scaled_predicted_params = self.trainer.model(scaled_curves, scaled_prior_bounds)
760
+ restored_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
761
+
762
+ best_param = get_best_mse_param(
763
+ restored_params,
764
+ self._get_likelihood(q=self.trainer.loader.q_generator.q, curve=curve),
765
+ )
766
+ return best_param
767
+
768
+ def predict_using_widget(self, reflectivity_curve, **kwargs):
769
+ """
770
+ """
771
+
772
+ NUM_INTERVALS = self.trainer.loader.prior_sampler.param_dim
773
+ param_labels = self.trainer.loader.prior_sampler.param_model.get_param_labels()
774
+ min_bounds = self.trainer.loader.prior_sampler.min_bounds.cpu().numpy().flatten()
775
+ max_bounds = self.trainer.loader.prior_sampler.max_bounds.cpu().numpy().flatten()
776
+ max_deltas = self.trainer.loader.prior_sampler.max_delta.cpu().numpy().flatten()
777
+
778
+ print(f'Adjust the sliders for each parameter and press "Predict". Repeat as desired. Press "Close Widget" to finish.')
779
+
780
+ interval_widgets = []
781
+ for i in range(NUM_INTERVALS):
782
+ label = widgets.Label(value=f'{param_labels[i]}')
783
+ initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
784
+ slider = widgets.FloatRangeSlider(
785
+ value=[min_bounds[i], initial_max],
786
+ min=min_bounds[i],
787
+ max=max_bounds[i],
788
+ step=0.01,
789
+ layout=widgets.Layout(width='400px'),
790
+ style={'description_width': '60px'}
791
+ )
792
+
793
+ def validate_range(change, slider=slider, max_width=max_deltas[i]):
794
+ min_val, max_val = change['new']
795
+ if max_val - min_val > max_width:
796
+ old_min_val, old_max_val = change['old']
797
+ if abs(old_min_val - min_val) > abs(old_max_val - max_val):
798
+ max_val = min_val + max_width
799
+ else:
800
+ min_val = max_val - max_width
801
+ slider.value = [min_val, max_val]
802
+
803
+ slider.observe(validate_range, names='value')
804
+ interval_widgets.append((slider, widgets.HBox([label, slider])))
805
+
806
+ sliders_box = widgets.VBox([iw[1] for iw in interval_widgets])
807
+
808
+ output = widgets.Output()
809
+ predict_button = widgets.Button(description="Predict")
810
+ close_button = widgets.Button(description="Close Widget")
811
+
812
+ container = widgets.VBox([sliders_box, widgets.HBox([predict_button, close_button]), output])
813
+ display(container)
814
+
815
+ @output.capture(clear_output=True)
816
+ def on_predict_click(_):
817
+ if 'prior_bounds' in kwargs:
818
+ array_values = kwargs.pop('prior_bounds')
819
+ for i, (s, _) in enumerate(interval_widgets):
820
+ s.value = tuple(array_values[i])
821
+ else:
822
+ values = [(s.value[0], s.value[1]) for s, _ in interval_widgets]
823
+ array_values = np.array(values)
824
+
825
+ prediction_result = self.predict(reflectivity_curve=reflectivity_curve,
826
+ prior_bounds=array_values,
827
+ **kwargs)
828
+ param_names = self.trainer.loader.prior_sampler.param_model.get_param_labels()
829
+ print_prediction_results(prediction_result)
830
+
831
+ plot_prediction_results(
832
+ prediction_result,
833
+ q_exp=kwargs['q_values'],
834
+ curve_exp=reflectivity_curve,
835
+ )
836
+ self.widget_prediction_result = prediction_result
837
+
838
+ def on_close_click(_):
839
+ container.close()
840
+ print("Widget closed.")
841
+
842
+ predict_button.on_click(on_predict_click)
843
+ close_button.on_click(on_close_click)
844
+
845
+ def preprocess_and_predict_using_widget(self,
846
+ reflectivity_curve,
847
+ q_values=None,
848
+ sigmas=None,
849
+ q_resolution=None,
850
+ prior_bounds=None,
851
+ ambient_sld=None,
852
+ ):
853
+ """
854
+ Interactive widget around `preprocess_and_predict`
855
+ Results are stored in `self.widget_prediction_result`
856
+ """
857
+
858
+ if q_values is None:
859
+ raise ValueError("q_values must be provided for this widget.")
860
+
861
+ N = len(reflectivity_curve)
862
+
863
+ # ---------- Priors sliders ----------
864
+ param_labels = self.trainer.loader.prior_sampler.param_model.get_param_labels()
865
+ min_bounds = self.trainer.loader.prior_sampler.min_bounds.cpu().numpy().flatten()
866
+ max_bounds = self.trainer.loader.prior_sampler.max_bounds.cpu().numpy().flatten()
867
+ max_deltas = self.trainer.loader.prior_sampler.max_delta.cpu().numpy().flatten()
868
+ NUM = len(param_labels)
869
+
870
+ sliders, rows = [], []
871
+ init_pb = np.array(prior_bounds) if prior_bounds is not None else None
872
+ for i in range(NUM):
873
+ init_min = float(init_pb[i, 0]) if init_pb is not None else float(min_bounds[i])
874
+ init_max = float(init_pb[i, 1]) if init_pb is not None else float(min(min_bounds[i] + max_deltas[i], max_bounds[i]))
875
+ lab = widgets.Label(value=param_labels[i])
876
+ s = widgets.FloatRangeSlider(
877
+ value=[init_min, init_max],
878
+ min=float(min_bounds[i]),
879
+ max=float(max_bounds[i]),
880
+ step=0.01,
881
+ layout=widgets.Layout(width='420px'),
882
+ readout_format='.3f'
883
+ )
884
+ # Constrain slider widths
885
+ def _mk_validator(slider, max_width=float(max_deltas[i])):
886
+ def _validate(change):
887
+ a, b = change['new']
888
+ if b - a > max_width:
889
+ oa, ob = change['old']
890
+ if abs(oa - a) > abs(ob - b):
891
+ b = a + max_width
892
+ else:
893
+ a = b - max_width
894
+ slider.value = (a, b)
895
+ return _validate
896
+ s.observe(_mk_validator(s), names='value')
897
+ sliders.append(s)
898
+ rows.append(widgets.HBox([lab, s]))
899
+ priors_box = widgets.VBox([widgets.HTML("<b>Priors</b>")] + rows)
900
+
901
+ # ---------- Preprocess & Predict controls ----------
902
+ # Preprocessing
903
+ trunc_L = widgets.IntSlider(description='truncate left', min=0, max=max(0, N-1), step=1, value=0)
904
+ trunc_R = widgets.IntSlider(description='truncate right', min=1, max=N, step=1, value=N)
905
+ enable_filt = widgets.Checkbox(description='filter error bars', value=True)
906
+ thr = widgets.FloatSlider(description='filtering threshold', min=0.0, max=1.0, step=0.01, value=0.3)
907
+ rem_single = widgets.Checkbox(description='remove singles', value=True)
908
+ rem_cons = widgets.Checkbox(description='remove consecutives', value=True)
909
+ consec = widgets.IntSlider(description='num. consecutive', min=1, max=10, step=1, value=3)
910
+ qstart = widgets.FloatSlider(description='q_start_trunc', min=0.0, max=1.0, step=0.01, value=0.1)
911
+
912
+ # Polishing
913
+ polish = widgets.Checkbox(description='polish prediction', value=True)
914
+ use_sigmas_polish = widgets.Checkbox(description='use sigmas during polishing', value=True)
915
+
916
+ # Plotting
917
+ pred_show_yerr = widgets.Checkbox(description='show error bars', value=True)
918
+ pred_show_xerr = widgets.Checkbox(description='show q-resolution', value=False)
919
+ pred_logx = widgets.Checkbox(description='log x-axis', value=False)
920
+ plot_sld = widgets.Checkbox(description='plot SLD profile', value=True)
921
+ sld_pad_left = widgets.FloatText(description='SLD pad left', value=0.2, step=0.1)
922
+ sld_pad_right = widgets.FloatText(description='SLD pad right', value=1.1, step=0.1)
923
+
924
+ # Color pickers
925
+ exp_color_picker = widgets.ColorPicker(description='exp color', value='#0000FF') # blue
926
+ exp_errcolor_picker = widgets.ColorPicker(description='errbar color', value='#800080') # purple
927
+ pred_color_picker = widgets.ColorPicker(description='pred color', value='#FF0000') # red
928
+ pol_color_picker = widgets.ColorPicker(description='polished color', value='#FFA500') # orange
929
+ sld_pred_color_picker = widgets.ColorPicker(description='SLD pred color', value='#FF0000') # red
930
+ sld_pol_color_picker = widgets.ColorPicker(description='SLD pol color', value='#FFA500') # orange
931
+
932
+ # Compute toggles
933
+ calc_curve = widgets.Checkbox(description='calc curve', value=True)
934
+ calc_sld_pred = widgets.Checkbox(description='calc predicted SLD', value=True)
935
+ calc_sld_pol = widgets.Checkbox(description='calc polished SLD', value=True)
936
+
937
+ btn_predict = widgets.Button(description='Predict')
938
+ btn_close = widgets.Button(description='Close')
939
+
940
+ controls_box = widgets.VBox([
941
+ widgets.HTML("<b>Preprocess & Predict</b>"),
942
+ widgets.HTML("<i>Preprocessing</i>"),
943
+ widgets.HBox([trunc_L, trunc_R]),
944
+ widgets.HBox([enable_filt, rem_single, rem_cons]),
945
+ widgets.HBox([thr, consec, qstart]),
946
+ widgets.HTML("<i>Polishing</i>"),
947
+ widgets.HBox([polish, use_sigmas_polish]),
948
+ widgets.HTML("<i>Plotting</i>"),
949
+ widgets.HBox([pred_show_yerr, pred_show_xerr, pred_logx, plot_sld]),
950
+ widgets.HBox([sld_pad_left, sld_pad_right]),
951
+ widgets.HTML("<i>Colors</i>"),
952
+ widgets.HBox([exp_color_picker, exp_errcolor_picker]),
953
+ widgets.HBox([pred_color_picker, pol_color_picker]),
954
+ widgets.HBox([sld_pred_color_picker, sld_pol_color_picker]),
955
+ widgets.HTML("<i>Compute</i>"),
956
+ widgets.HBox([calc_curve, calc_sld_pred, calc_sld_pol]),
957
+ widgets.HBox([btn_predict, btn_close]),
958
+ ])
959
+
960
+ out_predict = widgets.Output()
961
+ container = widgets.VBox([priors_box, controls_box, out_predict])
962
+ display(container)
963
+
964
+ def _sync_trunc(_):
965
+ if trunc_L.value >= trunc_R.value:
966
+ trunc_L.value = max(0, trunc_R.value - 1)
967
+ trunc_L.observe(_sync_trunc, names='value')
968
+ trunc_R.observe(_sync_trunc, names='value')
969
+
970
+ def _current_priors():
971
+ return np.array([s.value for s in sliders], dtype=np.float32) # (param_dim, 2)
972
+
973
+ @out_predict.capture(clear_output=True)
974
+ def _on_predict(_):
975
+ out_predict.clear_output(wait=True)
976
+
977
+ res = self.preprocess_and_predict(
978
+ reflectivity_curve=reflectivity_curve,
979
+ q_values=q_values,
980
+ prior_bounds=_current_priors(),
981
+ sigmas=sigmas,
982
+ q_resolution=q_resolution,
983
+ ambient_sld=ambient_sld,
984
+ clip_prediction=True,
985
+ polish_prediction=polish.value,
986
+ use_sigmas_for_polishing=use_sigmas_polish.value,
987
+ calc_pred_curve=calc_curve.value,
988
+ calc_pred_sld_profile=(calc_sld_pred.value or plot_sld.value),
989
+ calc_polished_sld_profile=(calc_sld_pol.value or plot_sld.value),
990
+ sld_profile_padding_left=float(sld_pad_left.value),
991
+ sld_profile_padding_right=float(sld_pad_right.value),
992
+
993
+ truncate_index_left=trunc_L.value,
994
+ truncate_index_right=trunc_R.value,
995
+ enable_error_bars_filtering=enable_filt.value,
996
+ filter_threshold=thr.value,
997
+ filter_remove_singles=rem_single.value,
998
+ filter_remove_consecutives=rem_cons.value,
999
+ filter_consecutive=consec.value,
1000
+ filter_q_start_trunc=qstart.value,
1001
+ )
1002
+
1003
+ # Full experimental data as scatter
1004
+ q_exp_plot = q_values
1005
+ r_exp_plot = reflectivity_curve
1006
+ yerr_plot = (sigmas if pred_show_yerr.value else None)
1007
+ xerr_plot = (q_resolution if pred_show_xerr.value else None)
1008
+
1009
+ # Predicted curve only on the model region
1010
+ q_pred = res.get('q_plot_pred', None)
1011
+ r_pred = res.get('predicted_curve', None)
1012
+
1013
+ # Polished curve on the full experimental grid
1014
+ q_pol = q_values if ('polished_curve' in res) else None
1015
+ r_pol = res.get('polished_curve', None)
1016
+
1017
+ # SLD profiles
1018
+ z_sld = res.get('predicted_sld_xaxis', None)
1019
+ sld_pred = res.get('predicted_sld_profile', None)
1020
+ sld_pol = res.get('sld_profile_polished', None)
1021
+
1022
+ print_prediction_results(res)
1023
+
1024
+ plot_reflectivity(
1025
+ q_exp=q_exp_plot, r_exp=r_exp_plot,
1026
+ yerr=yerr_plot, xerr=xerr_plot,
1027
+ exp_style=('errorbar' if pred_show_yerr.value or pred_show_xerr.value else 'scatter'),
1028
+ exp_color=exp_color_picker.value,
1029
+ exp_errcolor=exp_errcolor_picker.value,
1030
+ q_pred=q_pred, r_pred=r_pred, pred_color=pred_color_picker.value,
1031
+ q_pol=q_pol, r_pol=r_pol, pol_color=pol_color_picker.value,
1032
+ z_sld=z_sld, sld_pred=sld_pred, sld_pol=sld_pol,
1033
+ sld_pred_color=sld_pred_color_picker.value,
1034
+ sld_pol_color=sld_pol_color_picker.value,
1035
+ plot_sld_profile=plot_sld.value,
1036
+ logx=pred_logx.value, logy=True,
1037
+ figsize=(12,6),
1038
+ legend=True
1039
+ )
1040
+
1041
+ self.widget_prediction_result = res
1042
+
1043
+ def _on_close(_):
1044
+ container.close()
1045
+ print("Widget closed.")
1046
+
1047
+ btn_predict.on_click(_on_predict)
1048
+ btn_close.on_click(_on_close)
526
1049
 
527
1050
  class InferenceModel(object):
528
1051
  def __init__(self, name: str = None, trainer: PointEstimatorTrainer = None, preprocessing_parameters: dict = None,