reflectorch 1.2.1-py3-none-any.whl → 1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (41)
  1. reflectorch/data_generation/__init__.py +4 -0
  2. reflectorch/data_generation/dataset.py +27 -7
  3. reflectorch/data_generation/noise.py +115 -9
  4. reflectorch/data_generation/priors/parametric_models.py +91 -16
  5. reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
  6. reflectorch/data_generation/priors/sampler_strategies.py +67 -3
  7. reflectorch/data_generation/q_generator.py +97 -43
  8. reflectorch/data_generation/reflectivity/__init__.py +53 -11
  9. reflectorch/data_generation/reflectivity/kinematical.py +4 -5
  10. reflectorch/data_generation/reflectivity/smearing.py +25 -10
  11. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  12. reflectorch/data_generation/smearing.py +42 -11
  13. reflectorch/data_generation/utils.py +93 -18
  14. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  15. reflectorch/inference/inference_model.py +795 -159
  16. reflectorch/inference/loading_data.py +37 -0
  17. reflectorch/inference/plotting.py +517 -0
  18. reflectorch/inference/preprocess_exp/interpolation.py +5 -2
  19. reflectorch/inference/scipy_fitter.py +98 -7
  20. reflectorch/ml/__init__.py +2 -0
  21. reflectorch/ml/basic_trainer.py +18 -6
  22. reflectorch/ml/callbacks.py +5 -4
  23. reflectorch/ml/loggers.py +25 -0
  24. reflectorch/ml/schedulers.py +116 -0
  25. reflectorch/ml/trainers.py +131 -23
  26. reflectorch/models/__init__.py +2 -1
  27. reflectorch/models/encoders/__init__.py +2 -2
  28. reflectorch/models/encoders/conv_encoder.py +54 -40
  29. reflectorch/models/encoders/fno.py +23 -16
  30. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  31. reflectorch/models/networks/__init__.py +2 -0
  32. reflectorch/models/networks/mlp_networks.py +331 -153
  33. reflectorch/models/networks/residual_net.py +31 -5
  34. reflectorch/runs/train.py +0 -1
  35. reflectorch/runs/utils.py +48 -11
  36. reflectorch/utils.py +30 -0
  37. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/METADATA +20 -17
  38. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/RECORD +41 -36
  39. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/WHEEL +1 -1
  40. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info/licenses}/LICENSE.txt +0 -0
  41. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/top_level.txt +0 -0
@@ -12,8 +12,10 @@ from IPython.display import display
12
12
  from huggingface_hub import hf_hub_download
13
13
 
14
14
  from reflectorch.data_generation.priors import Params, BasicParams, ExpUniformSubPriorSampler, UniformSubPriorParams
15
- from reflectorch.data_generation.q_generator import ConstantQ, VariableQ
15
+ from reflectorch.data_generation.priors.parametric_models import NuisanceParamsWrapper
16
+ from reflectorch.data_generation.q_generator import ConstantQ, VariableQ, MaskedVariableQ
16
17
  from reflectorch.data_generation.utils import get_density_profiles, get_param_labels
18
+ from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
17
19
  from reflectorch.paths import CONFIG_DIR, ROOT_DIR, SAVED_MODELS_DIR
18
20
  from reflectorch.runs.utils import (
19
21
  get_trainer_by_name, train_from_config
@@ -23,10 +25,11 @@ from reflectorch.ml.trainers import PointEstimatorTrainer
23
25
  from reflectorch.data_generation.likelihoods import LogLikelihood
24
26
 
25
27
  from reflectorch.inference.preprocess_exp import StandardPreprocessing
26
- from reflectorch.inference.scipy_fitter import standard_refl_fit, get_fit_with_growth
28
+ from reflectorch.inference.scipy_fitter import standard_refl_fit, refl_fit, get_fit_with_growth
27
29
  from reflectorch.inference.sampler_solution import simple_sampler_solution, get_best_mse_param
28
30
  from reflectorch.inference.record_time import print_time
29
- from reflectorch.utils import to_t
31
+ from reflectorch.inference.plotting import plot_reflectivity, plot_prediction_results, print_prediction_results
32
+ from reflectorch.utils import get_filtering_mask, to_t
30
33
 
31
34
  class EasyInferenceModel(object):
32
35
  """Facilitates the inference process using pretrained models
@@ -103,7 +106,9 @@ class EasyInferenceModel(object):
103
106
  self.trainer = get_trainer_by_name(config_name=config_name, config_dir=self.config_dir, model_path=model_path, load_weights=True, inference_device = self.device)
104
107
  self.trainer.model.eval()
105
108
 
106
- print(f'The model corresponds to a `{self.trainer.loader.prior_sampler.param_model.NAME}` parameterization with {self.trainer.loader.prior_sampler.max_num_layers} layers ({self.trainer.loader.prior_sampler.param_dim} predicted parameters)')
109
+ param_model = self.trainer.loader.prior_sampler.param_model
110
+ param_model_name = param_model.base_model.NAME if isinstance(param_model, NuisanceParamsWrapper) else param_model.NAME
111
+ print(f'The model corresponds to a `{param_model_name}` parameterization with {self.trainer.loader.prior_sampler.max_num_layers} layers ({self.trainer.loader.prior_sampler.param_dim} predicted parameters)')
107
112
  print("Parameter types and total ranges:")
108
113
  for param, range_ in self.trainer.loader.prior_sampler.param_ranges.items():
109
114
  print(f"- {param}: {range_}")
@@ -115,23 +120,188 @@ class EasyInferenceModel(object):
115
120
  q_min = self.trainer.loader.q_generator.q[0].item()
116
121
  q_max = self.trainer.loader.q_generator.q[-1].item()
117
122
  n_q = self.trainer.loader.q_generator.q.shape[0]
118
- print(f'The model was trained on curves discretized at {n_q} uniform points between between q_min={q_min} and q_max={q_max}')
123
+ print(f'The model was trained on curves discretized at {n_q} uniform points between q_min={q_min} and q_max={q_max}')
119
124
  elif isinstance(self.trainer.loader.q_generator, VariableQ):
120
125
  q_min_range = self.trainer.loader.q_generator.q_min_range
121
126
  q_max_range = self.trainer.loader.q_generator.q_max_range
122
127
  n_q_range = self.trainer.loader.q_generator.n_q_range
123
- print(f'The model was trained on curves discretized at a number between {n_q_range[0]} and {n_q_range[1]} of uniform points between between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
128
+ if n_q_range[0] == n_q_range[1]:
129
+ n_q_fixed = n_q_range[0]
130
+ print(f'The model was trained on curves discretized at exactly {n_q_fixed} uniform points, '
131
+ f'between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
132
+ else:
133
+ print(f'The model was trained on curves discretized at a number between {n_q_range[0]} and {n_q_range[1]} '
134
+ f'of uniform points between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
135
+
136
+ if self.trainer.loader.smearing is not None:
137
+ q_res_min = self.trainer.loader.smearing.sigma_min
138
+ q_res_max = self.trainer.loader.smearing.sigma_max
139
+ if self.trainer.loader.smearing.constant_dq == False:
140
+ print(f"The model was trained with linear resolution smearing (dq/q) in the range [{q_res_min}, {q_res_max}]")
141
+ elif self.trainer.loader.smearing.constant_dq == True:
142
+ print(f"The model was trained with constant resolution smearing in the range [{q_res_min}, {q_res_max}]")
143
+
144
+ additional_inputs = ["prior bounds"]
145
+ if self.trainer.train_with_q_input:
146
+ additional_inputs.append("q values")
147
+ if self.trainer.condition_on_q_resolutions:
148
+ additional_inputs.append("the resolution dq/q")
149
+ if additional_inputs:
150
+ inputs_str = ", ".join(additional_inputs)
151
+ print(f"The following quantities are additional inputs to the network: {inputs_str}.")
152
+
153
+ def preprocess_and_predict(self,
154
+ reflectivity_curve: Union[np.ndarray, Tensor],
155
+ q_values: Union[np.ndarray, Tensor] = None,
156
+ prior_bounds: Union[np.ndarray, List[Tuple]] = None,
157
+ sigmas: Union[np.ndarray, Tensor] = None,
158
+ q_resolution: float = None,
159
+ ambient_sld: float = None,
160
+ clip_prediction: bool = True,
161
+ polish_prediction: bool = False,
162
+ polishing_method: str = 'trf',
163
+ polishing_kwargs_reflectivity: dict = None,
164
+ use_sigmas_for_polishing: bool = False,
165
+ polishing_max_nfev: int = None,
166
+ fit_growth: bool = False,
167
+ max_d_change: float = 5.,
168
+ calc_pred_curve: bool = True,
169
+ calc_pred_sld_profile: bool = False,
170
+ calc_polished_sld_profile: bool = False,
171
+ sld_profile_padding_left: float = 0.2,
172
+ sld_profile_padding_right: float = 1.1,
173
+ kwargs_param_labels: dict = {},
174
+
175
+ truncate_index_left: int = None,
176
+ truncate_index_right: int = None,
177
+ enable_error_bars_filtering: bool = True,
178
+ filter_threshold=0.3,
179
+ filter_remove_singles=True,
180
+ filter_remove_consecutives=True,
181
+ filter_consecutive=3,
182
+ filter_q_start_trunc=0.1,
183
+ ):
184
+
185
+ ## Preprocess the data for inference (remove negative intensities, truncation, filer out points with high error bars)
186
+ (q_values, reflectivity_curve, sigmas, q_resolution,
187
+ q_values_original, reflectivity_curve_original, sigmas_original, q_resolution_original) = self._preprocess_input_data(
188
+ reflectivity_curve=reflectivity_curve,
189
+ q_values=q_values,
190
+ sigmas=sigmas,
191
+ q_resolution=q_resolution,
192
+ truncate_index_left=truncate_index_left,
193
+ truncate_index_right=truncate_index_right,
194
+ enable_error_bars_filtering=enable_error_bars_filtering,
195
+ filter_threshold=filter_threshold,
196
+ filter_remove_singles=filter_remove_singles,
197
+ filter_remove_consecutives=filter_remove_consecutives,
198
+ filter_consecutive=filter_consecutive,
199
+ filter_q_start_trunc=filter_q_start_trunc,
200
+ )
201
+
202
+ ### Interpolate the experimental data if needed by the embedding network
203
+ interp_data = self.interpolate_data_to_model_q(
204
+ q_exp=q_values,
205
+ refl_exp=reflectivity_curve,
206
+ sigmas_exp=sigmas,
207
+ q_res_exp=q_resolution,
208
+ as_dict=True
209
+ )
210
+
211
+ q_model = interp_data["q_model"]
212
+ reflectivity_curve_interp = interp_data["reflectivity"]
213
+ sigmas_interp = interp_data.get("sigmas")
214
+ q_resolution_interp = interp_data.get("q_resolution")
215
+ key_padding_mask = interp_data.get("key_padding_mask")
216
+
217
+ ### Make the prediction
218
+ prediction_dict = self.predict(
219
+ reflectivity_curve=reflectivity_curve_interp,
220
+ q_values=q_model,
221
+ sigmas=sigmas_interp,
222
+ q_resolution=q_resolution_interp,
223
+ key_padding_mask=key_padding_mask,
224
+ prior_bounds=prior_bounds,
225
+ ambient_sld=ambient_sld,
226
+ clip_prediction=clip_prediction,
227
+ polish_prediction=False, ###do the polishing outside the predict method on the full data
228
+ supress_sld_amb_back_shift=True, ###do not shift back the slds by the ambient yet
229
+ calc_pred_curve=calc_pred_curve,
230
+ calc_pred_sld_profile=calc_pred_sld_profile,
231
+ sld_profile_padding_left=sld_profile_padding_left,
232
+ sld_profile_padding_right=sld_profile_padding_right,
233
+ kwargs_param_labels=kwargs_param_labels,
234
+ )
235
+
236
+ ### Save interpolated data
237
+ prediction_dict['q_model'] = q_model
238
+ prediction_dict['reflectivity_curve_interp'] = reflectivity_curve_interp
239
+ if q_resolution_interp is not None:
240
+ prediction_dict['q_resolution_interp'] = q_resolution_interp
241
+ if sigmas_interp is not None:
242
+ prediction_dict['sigmas_interp'] = sigmas_interp
243
+ if key_padding_mask is not None:
244
+ prediction_dict['key_padding_mask'] = key_padding_mask
245
+
246
+ ### Perform polishing on the original data
247
+ if polish_prediction:
248
+ polishing_kwargs = polishing_kwargs_reflectivity or {}
249
+ polishing_kwargs.setdefault('dq', q_resolution_original)
250
+
251
+ prior_bounds = np.array(prior_bounds)
252
+ if ambient_sld:
253
+ sld_indices = self._shift_slds_by_ambient(prior_bounds, ambient_sld)
254
+
255
+ polished_dict = self._polish_prediction(
256
+ q=q_values_original,
257
+ curve=reflectivity_curve_original,
258
+ predicted_params=prediction_dict['predicted_params_object'],
259
+ priors=prior_bounds,
260
+ ambient_sld_tensor=torch.atleast_2d(torch.as_tensor(ambient_sld)).to(self.device) if ambient_sld is not None else None,
261
+ calc_polished_sld_profile=calc_polished_sld_profile,
262
+ sld_x_axis=torch.from_numpy(prediction_dict['predicted_sld_xaxis']),
263
+ polishing_kwargs_reflectivity = polishing_kwargs,
264
+ error_bars=sigmas_original if use_sigmas_for_polishing else None,
265
+ polishing_method=polishing_method,
266
+ polishing_max_nfev=polishing_max_nfev,
267
+ fit_growth=fit_growth,
268
+ max_d_change=max_d_change,
269
+ )
270
+
271
+ prediction_dict.update(polished_dict)
272
+ if fit_growth and "polished_params_array" in prediction_dict:
273
+ prediction_dict["param_names"].append("max_d_change")
274
+
275
+ ### Shift back the slds for nonzero ambient
276
+ if ambient_sld:
277
+ self._restore_slds_after_ambient_shift(prediction_dict, sld_indices, ambient_sld)
124
278
 
125
- def predict(self, reflectivity_curve: Union[np.ndarray, Tensor],
279
+ return prediction_dict
280
+
281
+
282
+ def predict(self,
283
+ reflectivity_curve: Union[np.ndarray, Tensor],
126
284
  q_values: Union[np.ndarray, Tensor] = None,
127
285
  prior_bounds: Union[np.ndarray, List[Tuple]] = None,
128
- clip_prediction: bool = False,
129
- polish_prediction: bool = False,
286
+ sigmas: Union[np.ndarray, Tensor] = None,
287
+ key_padding_mask: Union[np.ndarray, Tensor] = None,
288
+ q_resolution: float = None,
289
+ ambient_sld: float = None,
290
+ clip_prediction: bool = True,
291
+ polish_prediction: bool = False,
292
+ polishing_method: str = 'trf',
293
+ polishing_kwargs_reflectivity: dict = None,
294
+ polishing_max_nfev: int = None,
130
295
  fit_growth: bool = False,
131
296
  max_d_change: float = 5.,
132
297
  use_q_shift: bool = False,
133
298
  calc_pred_curve: bool = True,
134
299
  calc_pred_sld_profile: bool = False,
300
+ calc_polished_sld_profile: bool = False,
301
+ sld_profile_padding_left: float = 0.2,
302
+ sld_profile_padding_right: float = 1.1,
303
+ supress_sld_amb_back_shift: bool = False,
304
+ kwargs_param_labels: dict = {},
135
305
  ):
136
306
  """Predict the thin film parameters
137
307
 
@@ -153,24 +323,51 @@ class EasyInferenceModel(object):
153
323
 
154
324
  scaled_curve = self._scale_curve(reflectivity_curve)
155
325
  prior_bounds = np.array(prior_bounds)
326
+
327
+ if ambient_sld:
328
+ sld_indices = self._shift_slds_by_ambient(prior_bounds, ambient_sld)
329
+
156
330
  scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)
157
331
 
158
- if not self.trainer.train_with_q_input:
332
+ if isinstance(self.trainer.loader.q_generator, ConstantQ):
159
333
  q_values = self.trainer.loader.q_generator.q
160
334
  else:
161
335
  q_values = torch.atleast_2d(to_t(q_values)).to(scaled_curve)
162
336
 
337
+ scaled_q_values = self.trainer.loader.q_generator.scale_q(q_values).to(torch.float32) if self.trainer.train_with_q_input else None
338
+
339
+ if q_resolution is not None:
340
+ q_resolution_tensor = torch.atleast_2d(torch.as_tensor(q_resolution)).to(scaled_curve)
341
+ if isinstance(q_resolution, float):
342
+ unscaled_q_resolutions = q_resolution_tensor
343
+ else:
344
+ unscaled_q_resolutions = (q_resolution_tensor / q_values).nanmean(dim=-1, keepdim=True) ##when q_values is padded with 0s, there will be nan at the padded positions
345
+ scaled_q_resolutions = self.trainer.loader.smearing.scale_resolutions(unscaled_q_resolutions) if self.trainer.condition_on_q_resolutions else None
346
+ scaled_conditioning_params = scaled_q_resolutions
347
+ if polishing_kwargs_reflectivity is None:
348
+ polishing_kwargs_reflectivity = {'dq': q_resolution}
349
+ else:
350
+ q_resolution_tensor = None
351
+ scaled_conditioning_params = None
352
+
353
+ if key_padding_mask is not None:
354
+ key_padding_mask = torch.as_tensor(key_padding_mask, device=self.device)
355
+ key_padding_mask = key_padding_mask.unsqueeze(0) if key_padding_mask.dim() == 1 else key_padding_mask
356
+
163
357
  if use_q_shift and not self.trainer.train_with_q_input:
164
358
  predicted_params = self._qshift_prediction(reflectivity_curve, scaled_prior_bounds, num = 1024, dq_coef = 1.)
165
-
166
359
  else:
167
360
  with torch.no_grad():
168
361
  self.trainer.model.eval()
169
- if self.trainer.train_with_q_input:
170
- scaled_q = self.trainer.loader.q_generator.scale_q(q_values).float()
171
- scaled_predicted_params = self.trainer.model(scaled_curve, scaled_prior_bounds, scaled_q)
172
- else:
173
- scaled_predicted_params = self.trainer.model(scaled_curve, scaled_prior_bounds)
362
+
363
+ scaled_predicted_params = self.trainer.model(
364
+ curves=scaled_curve,
365
+ bounds=scaled_prior_bounds,
366
+ q_values=scaled_q_values,
367
+ conditioning_params = scaled_conditioning_params,
368
+ key_padding_mask = key_padding_mask,
369
+ unscaled_q_values = q_values,
370
+ )
174
371
 
175
372
  predicted_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
176
373
 
@@ -180,166 +377,80 @@ class EasyInferenceModel(object):
180
377
  prediction_dict = {
181
378
  "predicted_params_object": predicted_params,
182
379
  "predicted_params_array": predicted_params.parameters.squeeze().cpu().numpy(),
183
- "param_names" : self.trainer.loader.prior_sampler.param_model.get_param_labels()
380
+ "param_names" : self.trainer.loader.prior_sampler.param_model.get_param_labels(**kwargs_param_labels)
184
381
  }
382
+
383
+ key_padding_mask = None if key_padding_mask is None else key_padding_mask.squeeze().cpu().numpy()
185
384
 
186
385
  if calc_pred_curve:
187
- predicted_curve = predicted_params.reflectivity(q_values).squeeze().cpu().numpy()
188
- prediction_dict[ "predicted_curve"] = predicted_curve
386
+ predicted_curve = predicted_params.reflectivity(q=q_values, dq=q_resolution_tensor).squeeze().cpu().numpy()
387
+ prediction_dict[ "predicted_curve"] = predicted_curve if key_padding_mask is None else predicted_curve[key_padding_mask]
189
388
 
389
+ ambient_sld_tensor = torch.atleast_2d(torch.as_tensor(ambient_sld, device=self.device)) if ambient_sld is not None else None
190
390
  if calc_pred_sld_profile:
191
391
  predicted_sld_xaxis, predicted_sld_profile, _ = get_density_profiles(
192
- predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, num=1024,
392
+ predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds + (ambient_sld_tensor or 0), ambient_sld_tensor,
393
+ num=1024, padding_left=sld_profile_padding_left, padding_right=sld_profile_padding_right,
193
394
  )
194
395
  prediction_dict['predicted_sld_profile'] = predicted_sld_profile.squeeze().cpu().numpy()
195
396
  prediction_dict['predicted_sld_xaxis'] = predicted_sld_xaxis.squeeze().cpu().numpy()
196
397
  else:
197
398
  predicted_sld_xaxis = None
399
+
400
+ refl_curve_polish = reflectivity_curve if key_padding_mask is None else reflectivity_curve[key_padding_mask]
401
+ q_polish = q_values.squeeze().cpu().numpy() if key_padding_mask is None else q_values.squeeze().cpu().numpy()[key_padding_mask]
402
+ prediction_dict['q_plot_pred'] = q_polish
198
403
 
199
- if polish_prediction: #only for standard box-model parameterization
200
- polished_dict = self._polish_prediction(q = q_values.squeeze().cpu().numpy(),
201
- curve = reflectivity_curve,
202
- predicted_params = predicted_params,
203
- priors = np.array(prior_bounds),
204
- fit_growth = fit_growth,
205
- max_d_change = max_d_change,
206
- calc_polished_curve = calc_pred_curve,
207
- calc_polished_sld_profile = False,
208
- sld_x_axis = predicted_sld_xaxis,
209
- )
404
+ if polish_prediction:
405
+ if ambient_sld_tensor:
406
+ ambient_sld_tensor = ambient_sld_tensor.cpu()
407
+
408
+ polished_dict = self._polish_prediction(
409
+ q = q_polish,
410
+ curve = refl_curve_polish,
411
+ predicted_params = predicted_params,
412
+ priors = np.array(prior_bounds),
413
+ error_bars = sigmas,
414
+ fit_growth = fit_growth,
415
+ max_d_change = max_d_change,
416
+ calc_polished_curve = calc_pred_curve,
417
+ calc_polished_sld_profile = calc_polished_sld_profile,
418
+ ambient_sld_tensor=ambient_sld_tensor,
419
+ sld_x_axis = predicted_sld_xaxis,
420
+ polishing_method=polishing_method,
421
+ polishing_max_nfev=polishing_max_nfev,
422
+ polishing_kwargs_reflectivity=polishing_kwargs_reflectivity,
423
+ )
210
424
  prediction_dict.update(polished_dict)
211
425
 
212
426
  if fit_growth and "polished_params_array" in prediction_dict:
213
427
  prediction_dict["param_names"].append("max_d_change")
214
428
 
215
- return prediction_dict
216
-
217
- async def predict_using_widget(self, reflectivity_curve: Union[np.ndarray, torch.Tensor], **kwargs):
218
- """Use an interactive Python widget for specifying the prior bounds before the prediction (works only in a Jupyter notebook).
219
- The other arguments are the same as for the ``predict`` method.
220
- """
221
-
222
- NUM_INTERVALS = self.trainer.loader.prior_sampler.param_dim
223
- param_labels = self.trainer.loader.prior_sampler.param_model.get_param_labels()
224
- min_bounds = self.trainer.loader.prior_sampler.min_bounds.cpu().numpy().flatten()
225
- max_bounds = self.trainer.loader.prior_sampler.max_bounds.cpu().numpy().flatten()
226
- max_deltas = self.trainer.loader.prior_sampler.max_delta.cpu().numpy().flatten()
227
-
228
- print(f'Parameter ranges: {self.trainer.loader.prior_sampler.param_ranges}')
229
- print(f'Allowed widths of the prior bound intervals (max-min): {self.trainer.loader.prior_sampler.bound_width_ranges}')
230
- print(f'Please fill in the values of the minimum and maximum prior bound for each parameter and press the button!')
231
-
232
- def create_interval_widgets(n):
233
- intervals = []
234
- for i in range(n):
235
- interval_label = widgets.Label(value=f'{param_labels[i]}')
236
- initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
237
- slider = widgets.FloatRangeSlider(
238
- value=[min_bounds[i], initial_max],
239
- min=min_bounds[i],
240
- max=max_bounds[i],
241
- step=0.01,
242
- layout=widgets.Layout(width='400px'),
243
- style={'description_width': '60px'}
244
- )
245
-
246
- def validate_range(change, slider=slider, max_width=max_deltas[i]):
247
- min_val, max_val = change['new']
248
- if max_val - min_val > max_width:
249
- if change['name'] == 'value':
250
- if change['old'][0] != min_val:
251
- max_val = min_val + max_width
252
- else:
253
- min_val = max_val - max_width
254
- slider.value = [min_val, max_val]
255
-
256
- slider.observe(validate_range, names='value')
257
-
258
- interval_row = widgets.HBox([interval_label, slider])
259
- intervals.append((slider, interval_row))
260
- return intervals
261
-
262
- interval_widgets = create_interval_widgets(NUM_INTERVALS)
263
- interval_box = widgets.VBox([widget[1] for widget in interval_widgets])
264
- display(interval_box)
265
-
266
- button = widgets.Button(description="Make prediction")
267
- display(button)
268
-
269
- prediction_result = None
270
-
271
- def store_values(b, future):
272
- print("Debug: Button clicked")
273
- values = []
274
- for slider, _ in interval_widgets:
275
- values.append((slider.value[0], slider.value[1]))
276
- array_values = np.array(values)
277
-
278
- nonlocal prediction_result
279
- prediction_result = self.predict(reflectivity_curve=reflectivity_curve, prior_bounds=array_values, **kwargs)
280
- print(prediction_result["predicted_params_array"])
281
-
282
- print("Prediction completed. Closing widget.")
283
-
284
- for child in interval_box.children:
285
- child.close()
286
- button.close()
287
-
288
- future.set_result(prediction_result)
429
+ if ambient_sld and not supress_sld_amb_back_shift: #Note: the SLD shift will only be reflected in predicted_params_array but not in predicted_params_object; supress_sld_amb_back_shift is required for the 'preprocess_and_predict' method
430
+ self._restore_slds_after_ambient_shift(prediction_dict, sld_indices, ambient_sld)
289
431
 
290
- button.on_click(store_values)
432
+ return prediction_dict
291
433
 
292
-
293
- future = asyncio.Future()
294
-
295
- button.on_click(lambda b: store_values(b, future))
296
-
297
- return await future
298
-
299
-
300
- def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
301
- assert isinstance(self.trainer.loader.q_generator, ConstantQ), "Prediction with q shifts available only for models with fixed discretization"
302
- q = self.trainer.loader.q_generator.q.squeeze().float()
303
- dq_max = (q[1] - q[0]) * dq_coef
304
- q_shifts = torch.linspace(-dq_max, dq_max, num).to(q)
305
-
306
- curve = to_t(curve).to(scaled_bounds)
307
- shifted_curves = _qshift_interp(q.squeeze(), curve, q_shifts)
308
-
309
- assert shifted_curves.shape == (num, q.shape[0])
310
-
311
- scaled_curves = self.trainer.loader.curves_scaler.scale(shifted_curves)
312
- scaled_prior_bounds = torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)
313
-
314
- with torch.no_grad():
315
- self.trainer.model.eval()
316
- scaled_predicted_params = self.trainer.model(scaled_curves, scaled_prior_bounds)
317
- restored_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
318
-
319
- best_param = get_best_mse_param(
320
- restored_params,
321
- self._get_likelihood(q=self.trainer.loader.q_generator.q, curve=curve),
322
- )
323
- return best_param
324
-
325
434
  def _polish_prediction(self,
326
435
  q: np.ndarray,
327
436
  curve: np.ndarray,
328
437
  predicted_params: BasicParams,
329
438
  priors: np.ndarray,
330
439
  sld_x_axis,
440
+ ambient_sld_tensor: Tensor = None,
331
441
  fit_growth: bool = False,
332
442
  max_d_change: float = 5.,
333
443
  calc_polished_curve: bool = True,
334
444
  calc_polished_sld_profile: bool = False,
445
+ error_bars: np.ndarray = None,
446
+ polishing_method: str = 'trf',
447
+ polishing_max_nfev: int = None,
448
+ polishing_kwargs_reflectivity: dict = None,
335
449
  ) -> dict:
336
- params = torch.cat([
337
- predicted_params.thicknesses.squeeze(),
338
- predicted_params.roughnesses.squeeze(),
339
- predicted_params.slds.squeeze()
340
- ]).cpu().numpy()
450
+ params = predicted_params.parameters.squeeze().cpu().numpy()
341
451
 
342
452
  polished_params_dict = {}
453
+ polishing_kwargs_reflectivity = polishing_kwargs_reflectivity or {}
343
454
 
344
455
  try:
345
456
  if fit_growth:
@@ -358,18 +469,24 @@ class EasyInferenceModel(object):
358
469
  self.trainer.loader.prior_sampler.param_model
359
470
  )
360
471
  else:
361
- polished_params_arr, curve_polished = standard_refl_fit(
472
+ polished_params_arr, curve_polished = refl_fit(
362
473
  q = q,
363
474
  curve = curve,
364
475
  init_params = params,
365
- bounds=priors.T)
476
+ bounds=priors.T,
477
+ prior_sampler=self.trainer.loader.prior_sampler,
478
+ error_bars=error_bars,
479
+ method=polishing_method,
480
+ polishing_max_nfev=polishing_max_nfev,
481
+ reflectivity_kwargs=polishing_kwargs_reflectivity,
482
+ )
366
483
  polished_params = BasicParams(
367
484
  torch.from_numpy(polished_params_arr[None]),
368
485
  torch.from_numpy(priors.T[0][None]),
369
486
  torch.from_numpy(priors.T[1][None]),
370
487
  self.trainer.loader.prior_sampler.max_num_layers,
371
488
  self.trainer.loader.prior_sampler.param_model
372
- )
489
+ )
373
490
  except Exception as err:
374
491
  polished_params = predicted_params
375
492
  polished_params_arr = get_prediction_array(polished_params)
@@ -379,9 +496,14 @@ class EasyInferenceModel(object):
379
496
  if calc_polished_curve:
380
497
  polished_params_dict['polished_curve'] = curve_polished
381
498
 
499
+ if ambient_sld_tensor is not None:
500
+ ambient_sld_tensor = ambient_sld_tensor.to(polished_params.slds.device)
501
+
502
+
382
503
  if calc_polished_sld_profile:
383
504
  _, sld_profile_polished, _ = get_density_profiles(
384
- polished_params.thicknesses, polished_params.roughnesses, polished_params.slds, z_axis=sld_x_axis,
505
+ polished_params.thicknesses, polished_params.roughnesses, polished_params.slds + (ambient_sld_tensor or 0), ambient_sld_tensor,
506
+ z_axis=sld_x_axis.to(polished_params.slds.device),
385
507
  )
386
508
  polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().cpu().numpy()
387
509
 
@@ -390,26 +512,540 @@ class EasyInferenceModel(object):
390
512
  def _scale_curve(self, curve: Union[np.ndarray, Tensor]):
391
513
  if not isinstance(curve, Tensor):
392
514
  curve = torch.from_numpy(curve).float()
393
- curve = torch.atleast_2d(curve).to(self.device)
515
+ curve = curve.unsqueeze(0).to(self.device)
394
516
  scaled_curve = self.trainer.loader.curves_scaler.scale(curve)
395
517
  return scaled_curve
396
518
 
397
519
  def _scale_prior_bounds(self, prior_bounds: List[Tuple]):
398
- prior_bounds = torch.tensor(prior_bounds)
399
- prior_bounds = prior_bounds.to(self.device).T
400
- min_bounds, max_bounds = prior_bounds[:, None]
401
-
402
- scaled_bounds = torch.cat([
403
- self.trainer.loader.prior_sampler.scale_bounds(min_bounds),
404
- self.trainer.loader.prior_sampler.scale_bounds(max_bounds)
405
- ], -1)
406
-
407
- return scaled_bounds.float()
520
+ try:
521
+ prior_bounds = torch.tensor(prior_bounds)
522
+ prior_bounds = prior_bounds.to(self.device).T
523
+ min_bounds, max_bounds = prior_bounds[:, None]
524
+
525
+ scaled_bounds = torch.cat([
526
+ self.trainer.loader.prior_sampler.scale_bounds(min_bounds),
527
+ self.trainer.loader.prior_sampler.scale_bounds(max_bounds)
528
+ ], -1)
529
+
530
+ return scaled_bounds.float()
531
+
532
+ except RuntimeError as e:
533
+ expected_param_dim = self.trainer.loader.prior_sampler.param_dim
534
+ actual_param_dim = prior_bounds.shape[1] if prior_bounds.ndim == 2 else len(prior_bounds)
535
+
536
+ msg = (
537
+ f"\n **Parameter dimension mismatch during inference!**\n"
538
+ f"- Model expects **{expected_param_dim}** parameters.\n"
539
+ f"- You provided **{actual_param_dim}** prior bounds.\n\n"
540
+ f"💡This often occurs when:\n"
541
+ f"- The model was trained with additional nuisance parameters like `r_scale`, `q_shift`, or `log10_background`,\n"
542
+ f" but they were not included in the `prior_bounds` passed to `.predict()`.\n"
543
+ f"- The number of layers or parameterization type differs from the one used during training.\n\n"
544
+ f" Check the configuration or the summary of expected parameters."
545
+ )
546
+ raise ValueError(msg) from e
408
547
 
548
+ def _shift_slds_by_ambient(self, prior_bounds: np.ndarray, ambient_sld: float):
549
+ n_layers = self.trainer.loader.prior_sampler.max_num_layers
550
+ sld_indices = slice(2*n_layers+1, 3*n_layers+2)
551
+ prior_bounds[sld_indices, ...] -= ambient_sld
552
+
553
+ training_min_bounds = self.trainer.loader.prior_sampler.min_bounds.squeeze().cpu().numpy()
554
+ training_max_bounds = self.trainer.loader.prior_sampler.max_bounds.squeeze().cpu().numpy()
555
+ lower_bound_check = (prior_bounds[sld_indices, 0] >= training_min_bounds[sld_indices]).all()
556
+ upper_bound_check = (prior_bounds[sld_indices, 1] <= training_max_bounds[sld_indices]).all()
557
+ assert lower_bound_check and upper_bound_check, "Shifting the layer SLDs by the ambient SLD exceeded the training ranges."
558
+
559
+ return sld_indices
560
+
561
+ def _restore_slds_after_ambient_shift(self, prediction_dict, sld_indices, ambient_sld):
562
+ prediction_dict["predicted_params_array"][sld_indices] += ambient_sld
563
+ if "polished_params_array" in prediction_dict:
564
+ prediction_dict["polished_params_array"][sld_indices] += ambient_sld
565
+
409
566
def _get_likelihood(self, q, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
    """Build a ``LogLikelihood`` for ``curve`` on ``q`` using the model's prior sampler.

    The uncertainty passed to ``LogLikelihood`` is ``curve * rel_err + abs_err``.
    """
    uncertainty = curve * rel_err + abs_err
    prior_sampler = self.trainer.loader.prior_sampler
    return LogLikelihood(q, curve, prior_sampler, uncertainty)
570
+
571
def get_param_labels(self, **kwargs):
    """Return the parameter labels of the trained model's parametrization.

    Keyword arguments are forwarded unchanged to ``param_model.get_param_labels``.
    """
    param_model = self.trainer.loader.prior_sampler.param_model
    labels = param_model.get_param_labels(**kwargs)
    return labels
573
+
574
+ @staticmethod
575
+ def _preprocess_input_data(
576
+ reflectivity_curve,
577
+ q_values,
578
+ sigmas=None,
579
+ q_resolution=None,
580
+ truncate_index_left=None,
581
+ truncate_index_right=None,
582
+ enable_error_bars_filtering=True,
583
+ filter_threshold=0.3,
584
+ filter_remove_singles=True,
585
+ filter_remove_consecutives=True,
586
+ filter_consecutive=3,
587
+ filter_q_start_trunc=0.1):
588
+
589
+ # Save originals for polishing
590
+ reflectivity_curve_original = reflectivity_curve.copy()
591
+ q_values_original = q_values.copy() if q_values is not None else None
592
+ q_resolution_original = q_resolution.copy() if isinstance(q_resolution, np.ndarray) else q_resolution
593
+ sigmas_original = sigmas.copy() if sigmas is not None else None
594
+
595
+ # Remove points with non-positive intensities
596
+ nonnegative_mask = reflectivity_curve > 0.0
597
+ reflectivity_curve = reflectivity_curve[nonnegative_mask]
598
+ q_values = q_values[nonnegative_mask]
599
+ if sigmas is not None:
600
+ sigmas = sigmas[nonnegative_mask]
601
+ if isinstance(q_resolution, np.ndarray):
602
+ q_resolution = q_resolution[nonnegative_mask]
603
+
604
+ # Truncate arrays
605
+ if truncate_index_left is not None or truncate_index_right is not None:
606
+ slice_obj = slice(truncate_index_left, truncate_index_right)
607
+ reflectivity_curve = reflectivity_curve[slice_obj]
608
+ q_values = q_values[slice_obj]
609
+ if sigmas is not None:
610
+ sigmas = sigmas[slice_obj]
611
+ if isinstance(q_resolution, np.ndarray):
612
+ q_resolution = q_resolution[slice_obj]
613
+
614
+ # Filter high-error points
615
+ if enable_error_bars_filtering and sigmas is not None:
616
+ valid_mask = get_filtering_mask(
617
+ q_values,
618
+ reflectivity_curve,
619
+ sigmas,
620
+ threshold=filter_threshold,
621
+ consecutive=filter_consecutive,
622
+ remove_singles=filter_remove_singles,
623
+ remove_consecutives=filter_remove_consecutives,
624
+ q_start_trunc=filter_q_start_trunc
625
+ )
626
+ reflectivity_curve = reflectivity_curve[valid_mask]
627
+ q_values = q_values[valid_mask]
628
+ sigmas = sigmas[valid_mask]
629
+ if isinstance(q_resolution, np.ndarray):
630
+ q_resolution = q_resolution[valid_mask]
631
+
632
+ return (q_values, reflectivity_curve, sigmas, q_resolution,
633
+ q_values_original, reflectivity_curve_original,
634
+ sigmas_original, q_resolution_original)
635
+
636
def interpolate_data_to_model_q(
    self,
    q_exp,
    refl_exp,
    sigmas_exp=None,
    q_res_exp=None,
    as_dict=False
):
    """Map experimental data onto the q discretization expected by the trained model.

    Behavior depends on the type of ``self.trainer.loader.q_generator``:

    * ``ConstantQ``: the data is interpolated onto the fixed model grid.
    * ``MaskedVariableQ``: if the number of points lies within ``n_q_range``, the data is
      only zero-padded to ``n_q_range[1]``. Otherwise it is first linearly interpolated
      onto ``clip(len(q_exp), *n_q_range)`` points spanning the overlap of the data with
      the training q range, then padded. A boolean ``key_padding_mask`` that is True on
      the entries holding data is also returned.
    * ``VariableQ``: if training used a fixed number of q points, the data is
      interpolated onto that many points spanning the overlap of the data with the
      training q range. The grid is log-spaced when the generator's ``mode`` is
      ``'logspace'`` and linear otherwise. If the number of points was not fixed, the
      data is returned unchanged.

    Array-valued ``sigmas_exp`` / ``q_res_exp`` are interpolated/padded like the curve.
    Scalars and ``None`` pass through unchanged. All returned arrays are NumPy arrays.

    Args:
        q_exp: experimental q values (NumPy array).
        refl_exp: experimental reflectivity (NumPy array).
        sigmas_exp: optional error bars (array or scalar).
        q_res_exp: optional q resolution (array or scalar).
        as_dict: return a dict instead of a tuple.

    Returns:
        If ``as_dict`` is False, a tuple ``(q_model, reflectivity[, sigmas][, q_resolution]
        [, key_padding_mask])``, where optional entries appear only when not ``None``.
        If ``as_dict`` is True, a dict with keys ``"q_model"`` and ``"reflectivity"``,
        plus ``"sigmas"``, ``"q_resolution"`` and ``"key_padding_mask"`` when available.

    Raises:
        TypeError: if the q generator type is not supported.
    """
    q_generator = self.trainer.loader.q_generator

    def _pad(arr, pad_to, value=0.0):
        """Right-pad a 1D array with ``value`` up to length ``pad_to``; None passes through."""
        if arr is None:
            return None
        return np.pad(arr, (0, pad_to - len(arr)), constant_values=value)

    def _interp_or_keep(q_model, q_exp, arr):
        """Interpolate arrays, keep floats or None unchanged."""
        if arr is None:
            return None
        return np.interp(q_model, q_exp, arr) if isinstance(arr, np.ndarray) else arr

    def _pad_or_keep(arr, max_n):
        """Pad arrays, keep floats or None unchanged."""
        if arr is None:
            return None
        return _pad(arr, max_n, 0.0) if isinstance(arr, np.ndarray) else arr

    def _prepare_return(q, refl, sigmas=None, q_res=None, mask=None, as_dict=False):
        """Pack the outputs as a dict or a tuple, omitting entries that are None."""
        if as_dict:
            result = {"q_model": q, "reflectivity": refl}
            if sigmas is not None: result["sigmas"] = sigmas
            if q_res is not None: result["q_resolution"] = q_res
            if mask is not None: result["key_padding_mask"] = mask
            return result
        result = [q, refl]
        if sigmas is not None: result.append(sigmas)
        if q_res is not None: result.append(q_res)
        if mask is not None: result.append(mask)
        return tuple(result)

    # ConstantQ
    if isinstance(q_generator, ConstantQ):
        q_model = q_generator.q.cpu().numpy()
        refl_out = interp_reflectivity(q_model, q_exp, refl_exp)
        sigmas_out = _interp_or_keep(q_model, q_exp, sigmas_exp)
        q_res_out = _interp_or_keep(q_model, q_exp, q_res_exp)
        return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, None, as_dict)

    # MaskedVariableQ. Checked before VariableQ so that this more specific branch is
    # still reached if MaskedVariableQ subclasses VariableQ. If the classes are
    # unrelated, the order does not matter.
    if isinstance(q_generator, MaskedVariableQ):
        min_n, max_n = q_generator.n_q_range
        n_exp = len(q_exp)

        if min_n <= n_exp <= max_n:
            # Pad only
            q_model = _pad(q_exp, max_n, 0.0)
            refl_out = _pad(refl_exp, max_n, 0.0)
            sigmas_out = _pad_or_keep(sigmas_exp, max_n)
            q_res_out = _pad_or_keep(q_res_exp, max_n)
            key_padding_mask = np.zeros(max_n, dtype=bool)
            key_padding_mask[:n_exp] = True
        else:
            # Interpolate + pad
            n_interp = min(max(n_exp, min_n), max_n)
            q_min = max(q_exp.min(), q_generator.q_min_range[0])
            q_max = min(q_exp.max(), q_generator.q_max_range[1])
            q_interp = np.linspace(q_min, q_max, n_interp)

            refl_interp = interp_reflectivity(q_interp, q_exp, refl_exp)
            sigmas_interp = _interp_or_keep(q_interp, q_exp, sigmas_exp)
            q_res_interp = _interp_or_keep(q_interp, q_exp, q_res_exp)

            q_model = _pad(q_interp, max_n, 0.0)
            refl_out = _pad(refl_interp, max_n, 0.0)
            sigmas_out = _pad_or_keep(sigmas_interp, max_n)
            q_res_out = _pad_or_keep(q_res_interp, max_n)
            key_padding_mask = np.zeros(max_n, dtype=bool)
            key_padding_mask[:n_interp] = True

        return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, key_padding_mask, as_dict)

    # VariableQ
    if isinstance(q_generator, VariableQ):
        if q_generator.n_q_range[0] != q_generator.n_q_range[1]:
            # The model accepts a variable number of points: pass the data through.
            return _prepare_return(q_exp, refl_exp, sigmas_exp, q_res_exp, None, as_dict)

        n_q_model = q_generator.n_q_range[0]
        q_min = max(q_exp.min(), q_generator.q_min_range[0])
        q_max = min(q_exp.max(), q_generator.q_max_range[1])
        logspace = q_generator.mode == 'logspace'
        # Build the grid with NumPy in both modes so every branch returns NumPy arrays.
        # The previous torch.logspace path returned a float32 torch tensor here.
        if logspace:
            q_model = np.logspace(np.log10(q_min), np.log10(q_max), n_q_model)
        else:
            q_model = np.linspace(q_min, q_max, n_q_model)

        refl_out = interp_reflectivity(q_model, q_exp, refl_exp, logspace=logspace)
        sigmas_out = _interp_or_keep(q_model, q_exp, sigmas_exp)
        q_res_out = _interp_or_keep(q_model, q_exp, q_res_exp)
        return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, None, as_dict)

    raise TypeError(f"Unsupported QGenerator type: {type(q_generator)}")
742
+
743
def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
    """Predict parameters for ``num`` q-shifted copies of ``curve`` and keep the best one.

    Only available for ``ConstantQ`` models. The shifts are ``num`` values spaced linearly
    in ``[-dq_max, dq_max]`` with ``dq_max = (q[1] - q[0]) * dq_coef``. Each shifted curve
    is scaled and passed through the network together with ``scaled_bounds``. The
    restored predictions are ranked by ``get_best_mse_param`` against a likelihood built
    from the unshifted curve. Side effect: puts ``self.trainer.model`` into eval mode.
    """
    loader = self.trainer.loader
    assert isinstance(loader.q_generator, ConstantQ), "Prediction with q shifts available only for models with fixed discretization"

    q_grid = loader.q_generator.q.squeeze().float()
    max_shift = (q_grid[1] - q_grid[0]) * dq_coef
    shifts = torch.linspace(-max_shift, max_shift, num).to(q_grid)

    curve = to_t(curve).to(scaled_bounds)
    shifted_curves = _qshift_interp(q_grid, curve, shifts)
    assert shifted_curves.shape == (num, q_grid.shape[0])

    scaled_curves = loader.curves_scaler.scale(shifted_curves)
    batch_bounds = torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)

    network = self.trainer.model
    with torch.no_grad():
        network.eval()
        raw_predictions = network(scaled_curves, batch_bounds)
        candidates = loader.prior_sampler.restore_params(
            torch.cat([raw_predictions, batch_bounds], dim=-1)
        )

    likelihood = self._get_likelihood(q=loader.q_generator.q, curve=curve)
    return get_best_mse_param(candidates, likelihood)
767
+
768
def predict_using_widget(self, reflectivity_curve, **kwargs):
    """Choose prior bounds interactively with sliders and run ``self.predict``.

    One range slider is shown per model parameter. Each slider is limited to the
    training range ``[min_bounds, max_bounds]``, and its width is kept at most
    ``max_delta`` for that parameter. Pressing "Predict" calls
    ``self.predict(reflectivity_curve=..., prior_bounds=<slider values>, **kwargs)``,
    prints and plots the result, and stores it in ``self.widget_prediction_result``.

    Args:
        reflectivity_curve: the experimental curve forwarded to ``predict``.
        **kwargs: extra keyword arguments forwarded to ``predict``. ``q_values`` must be
            present because it is used as the q axis of the plot. If ``prior_bounds`` is
            given, it is used, and loaded into the sliders, for the first prediction only.
    """
    prior_sampler = self.trainer.loader.prior_sampler
    NUM_INTERVALS = prior_sampler.param_dim
    param_labels = prior_sampler.param_model.get_param_labels()
    min_bounds = prior_sampler.min_bounds.cpu().numpy().flatten()
    max_bounds = prior_sampler.max_bounds.cpu().numpy().flatten()
    max_deltas = prior_sampler.max_delta.cpu().numpy().flatten()

    print('Adjust the sliders for each parameter and press "Predict". Repeat as desired. Press "Close Widget" to finish.')

    interval_widgets = []
    for i in range(NUM_INTERVALS):
        label = widgets.Label(value=f'{param_labels[i]}')
        initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
        slider = widgets.FloatRangeSlider(
            value=[min_bounds[i], initial_max],
            min=min_bounds[i],
            max=max_bounds[i],
            step=0.01,
            layout=widgets.Layout(width='400px'),
            style={'description_width': '60px'}
        )

        # Default arguments bind this iteration's slider and width. A plain closure would
        # see only the last loop values.
        def validate_range(change, slider=slider, max_width=max_deltas[i]):
            min_val, max_val = change['new']
            if max_val - min_val > max_width:
                old_min_val, old_max_val = change['old']
                # Keep the handle the user moved and pull the other one along.
                if abs(old_min_val - min_val) > abs(old_max_val - max_val):
                    max_val = min_val + max_width
                else:
                    min_val = max_val - max_width
                slider.value = [min_val, max_val]

        slider.observe(validate_range, names='value')
        interval_widgets.append((slider, widgets.HBox([label, slider])))

    sliders_box = widgets.VBox([iw[1] for iw in interval_widgets])

    output = widgets.Output()
    predict_button = widgets.Button(description="Predict")
    close_button = widgets.Button(description="Close Widget")

    container = widgets.VBox([sliders_box, widgets.HBox([predict_button, close_button]), output])
    display(container)

    @output.capture(clear_output=True)
    def on_predict_click(_):
        if 'prior_bounds' in kwargs:
            # Caller-supplied bounds are consumed by the first click and mirrored into the
            # sliders. Later clicks read the sliders.
            array_values = kwargs.pop('prior_bounds')
            for i, (s, _row) in enumerate(interval_widgets):
                s.value = tuple(array_values[i])
        else:
            values = [(s.value[0], s.value[1]) for s, _row in interval_widgets]
            array_values = np.array(values)

        prediction_result = self.predict(reflectivity_curve=reflectivity_curve,
                                         prior_bounds=array_values,
                                         **kwargs)
        print_prediction_results(prediction_result)

        plot_prediction_results(
            prediction_result,
            q_exp=kwargs['q_values'],
            curve_exp=reflectivity_curve,
        )
        self.widget_prediction_result = prediction_result

    def on_close_click(_):
        container.close()
        print("Widget closed.")

    predict_button.on_click(on_predict_click)
    close_button.on_click(on_close_click)
844
+
845
def preprocess_and_predict_using_widget(self,
                                        reflectivity_curve,
                                        q_values=None,
                                        sigmas=None,
                                        q_resolution=None,
                                        prior_bounds=None,
                                        ambient_sld=None,
                                        ):
    """
    Interactive widget around `preprocess_and_predict`
    Results are stored in `self.widget_prediction_result`

    Shows one range slider per model parameter. The sliders are limited to the training
    ranges and to a width of at most ``max_delta``. The widget also has controls for
    truncation, error-bar filtering, polishing, plotting, colors and what to compute.
    Pressing "Predict" runs ``self.preprocess_and_predict`` with the current settings,
    prints the results, and plots the full experimental data with the predicted and
    polished curves and, optionally, SLD profiles.

    Args:
        reflectivity_curve: experimental reflectivity values.
        q_values: q values of the experimental points (required).
        sigmas: optional error bars, forwarded to preprocessing and plotting.
        q_resolution: optional q resolution, forwarded to preprocessing and plotting.
        prior_bounds: optional initial ``(min, max)`` per parameter for the sliders.
            Defaults to ``[min_bound, min(min_bound + max_delta, max_bound)]``.
        ambient_sld: optional ambient SLD forwarded to ``preprocess_and_predict``.

    Raises:
        ValueError: if ``q_values`` is missing, the curve is empty, ``q_values`` and
            ``reflectivity_curve`` differ in length, or ``prior_bounds`` does not give a
            ``(min, max)`` pair for every model parameter.
    """

    if q_values is None:
        raise ValueError("q_values must be provided for this widget.")

    N = len(reflectivity_curve)
    # Fail early with a clear message. An empty curve would otherwise create a truncation
    # slider with min > max, and a length mismatch would only fail later inside
    # preprocessing.
    if N == 0:
        raise ValueError("reflectivity_curve must contain at least one point.")
    if len(q_values) != N:
        raise ValueError(
            f"q_values and reflectivity_curve must have the same length "
            f"(got {len(q_values)} and {N})."
        )

    # ---------- Priors sliders ----------
    prior_sampler = self.trainer.loader.prior_sampler
    param_labels = prior_sampler.param_model.get_param_labels()
    min_bounds = prior_sampler.min_bounds.cpu().numpy().flatten()
    max_bounds = prior_sampler.max_bounds.cpu().numpy().flatten()
    max_deltas = prior_sampler.max_delta.cpu().numpy().flatten()
    NUM = len(param_labels)

    init_pb = np.array(prior_bounds) if prior_bounds is not None else None
    if init_pb is not None and (init_pb.ndim != 2 or init_pb.shape[0] < NUM or init_pb.shape[1] < 2):
        raise ValueError(
            f"prior_bounds must provide a (min, max) pair for each of the {NUM} parameters, "
            f"got an array of shape {init_pb.shape}."
        )

    sliders, rows = [], []
    for i in range(NUM):
        init_min = float(init_pb[i, 0]) if init_pb is not None else float(min_bounds[i])
        init_max = float(init_pb[i, 1]) if init_pb is not None else float(min(min_bounds[i] + max_deltas[i], max_bounds[i]))
        lab = widgets.Label(value=param_labels[i])
        s = widgets.FloatRangeSlider(
            value=[init_min, init_max],
            min=float(min_bounds[i]),
            max=float(max_bounds[i]),
            step=0.01,
            layout=widgets.Layout(width='420px'),
            readout_format='.3f'
        )

        # Constrain slider widths. The factory binds this iteration's slider and width.
        def _mk_validator(slider, max_width=float(max_deltas[i])):
            def _validate(change):
                a, b = change['new']
                if b - a > max_width:
                    oa, ob = change['old']
                    # Keep the handle the user moved and pull the other one along.
                    if abs(oa - a) > abs(ob - b):
                        b = a + max_width
                    else:
                        a = b - max_width
                    slider.value = (a, b)
            return _validate
        s.observe(_mk_validator(s), names='value')
        sliders.append(s)
        rows.append(widgets.HBox([lab, s]))
    priors_box = widgets.VBox([widgets.HTML("<b>Priors</b>")] + rows)

    # ---------- Preprocess & Predict controls ----------
    # Preprocessing
    trunc_L = widgets.IntSlider(description='truncate left', min=0, max=max(0, N-1), step=1, value=0)
    trunc_R = widgets.IntSlider(description='truncate right', min=1, max=N, step=1, value=N)
    enable_filt = widgets.Checkbox(description='filter error bars', value=True)
    thr = widgets.FloatSlider(description='filtering threshold', min=0.0, max=1.0, step=0.01, value=0.3)
    rem_single = widgets.Checkbox(description='remove singles', value=True)
    rem_cons = widgets.Checkbox(description='remove consecutives', value=True)
    consec = widgets.IntSlider(description='num. consecutive', min=1, max=10, step=1, value=3)
    qstart = widgets.FloatSlider(description='q_start_trunc', min=0.0, max=1.0, step=0.01, value=0.1)

    # Polishing
    polish = widgets.Checkbox(description='polish prediction', value=True)
    use_sigmas_polish = widgets.Checkbox(description='use sigmas during polishing', value=True)

    # Plotting
    pred_show_yerr = widgets.Checkbox(description='show error bars', value=True)
    pred_show_xerr = widgets.Checkbox(description='show q-resolution', value=False)
    pred_logx = widgets.Checkbox(description='log x-axis', value=False)
    plot_sld = widgets.Checkbox(description='plot SLD profile', value=True)
    sld_pad_left = widgets.FloatText(description='SLD pad left', value=0.2, step=0.1)
    sld_pad_right = widgets.FloatText(description='SLD pad right', value=1.1, step=0.1)

    # Color pickers
    exp_color_picker = widgets.ColorPicker(description='exp color', value='#0000FF')  # blue
    exp_errcolor_picker = widgets.ColorPicker(description='errbar color', value='#800080')  # purple
    pred_color_picker = widgets.ColorPicker(description='pred color', value='#FF0000')  # red
    pol_color_picker = widgets.ColorPicker(description='polished color', value='#FFA500')  # orange
    sld_pred_color_picker = widgets.ColorPicker(description='SLD pred color', value='#FF0000')  # red
    sld_pol_color_picker = widgets.ColorPicker(description='SLD pol color', value='#FFA500')  # orange

    # Compute toggles
    calc_curve = widgets.Checkbox(description='calc curve', value=True)
    calc_sld_pred = widgets.Checkbox(description='calc predicted SLD', value=True)
    calc_sld_pol = widgets.Checkbox(description='calc polished SLD', value=True)

    btn_predict = widgets.Button(description='Predict')
    btn_close = widgets.Button(description='Close')

    controls_box = widgets.VBox([
        widgets.HTML("<b>Preprocess & Predict</b>"),
        widgets.HTML("<i>Preprocessing</i>"),
        widgets.HBox([trunc_L, trunc_R]),
        widgets.HBox([enable_filt, rem_single, rem_cons]),
        widgets.HBox([thr, consec, qstart]),
        widgets.HTML("<i>Polishing</i>"),
        widgets.HBox([polish, use_sigmas_polish]),
        widgets.HTML("<i>Plotting</i>"),
        widgets.HBox([pred_show_yerr, pred_show_xerr, pred_logx, plot_sld]),
        widgets.HBox([sld_pad_left, sld_pad_right]),
        widgets.HTML("<i>Colors</i>"),
        widgets.HBox([exp_color_picker, exp_errcolor_picker]),
        widgets.HBox([pred_color_picker, pol_color_picker]),
        widgets.HBox([sld_pred_color_picker, sld_pol_color_picker]),
        widgets.HTML("<i>Compute</i>"),
        widgets.HBox([calc_curve, calc_sld_pred, calc_sld_pol]),
        widgets.HBox([btn_predict, btn_close]),
    ])

    out_predict = widgets.Output()
    container = widgets.VBox([priors_box, controls_box, out_predict])
    display(container)

    # Keep the left truncation index strictly below the right one.
    def _sync_trunc(_):
        if trunc_L.value >= trunc_R.value:
            trunc_L.value = max(0, trunc_R.value - 1)
    trunc_L.observe(_sync_trunc, names='value')
    trunc_R.observe(_sync_trunc, names='value')

    def _current_priors():
        return np.array([s.value for s in sliders], dtype=np.float32)  # (param_dim, 2)

    @out_predict.capture(clear_output=True)
    def _on_predict(_):
        out_predict.clear_output(wait=True)

        res = self.preprocess_and_predict(
            reflectivity_curve=reflectivity_curve,
            q_values=q_values,
            prior_bounds=_current_priors(),
            sigmas=sigmas,
            q_resolution=q_resolution,
            ambient_sld=ambient_sld,
            clip_prediction=True,
            polish_prediction=polish.value,
            use_sigmas_for_polishing=use_sigmas_polish.value,
            calc_pred_curve=calc_curve.value,
            calc_pred_sld_profile=(calc_sld_pred.value or plot_sld.value),
            calc_polished_sld_profile=(calc_sld_pol.value or plot_sld.value),
            sld_profile_padding_left=float(sld_pad_left.value),
            sld_profile_padding_right=float(sld_pad_right.value),

            truncate_index_left=trunc_L.value,
            truncate_index_right=trunc_R.value,
            enable_error_bars_filtering=enable_filt.value,
            filter_threshold=thr.value,
            filter_remove_singles=rem_single.value,
            filter_remove_consecutives=rem_cons.value,
            filter_consecutive=consec.value,
            filter_q_start_trunc=qstart.value,
        )

        # Full experimental data as scatter
        q_exp_plot = q_values
        r_exp_plot = reflectivity_curve
        yerr_plot = (sigmas if pred_show_yerr.value else None)
        xerr_plot = (q_resolution if pred_show_xerr.value else None)

        # Predicted curve only on the model region
        q_pred = res.get('q_plot_pred', None)
        r_pred = res.get('predicted_curve', None)

        # Polished curve on the full experimental grid
        q_pol = q_values if ('polished_curve' in res) else None
        r_pol = res.get('polished_curve', None)

        # SLD profiles
        z_sld = res.get('predicted_sld_xaxis', None)
        sld_pred = res.get('predicted_sld_profile', None)
        sld_pol = res.get('sld_profile_polished', None)

        print_prediction_results(res)

        plot_reflectivity(
            q_exp=q_exp_plot, r_exp=r_exp_plot,
            yerr=yerr_plot, xerr=xerr_plot,
            exp_style=('errorbar' if pred_show_yerr.value or pred_show_xerr.value else 'scatter'),
            exp_color=exp_color_picker.value,
            exp_errcolor=exp_errcolor_picker.value,
            q_pred=q_pred, r_pred=r_pred, pred_color=pred_color_picker.value,
            q_pol=q_pol, r_pol=r_pol, pol_color=pol_color_picker.value,
            z_sld=z_sld, sld_pred=sld_pred, sld_pol=sld_pol,
            sld_pred_color=sld_pred_color_picker.value,
            sld_pol_color=sld_pol_color_picker.value,
            plot_sld_profile=plot_sld.value,
            logx=pred_logx.value, logy=True,
            figsize=(12, 6),
            legend=True
        )

        self.widget_prediction_result = res

    def _on_close(_):
        container.close()
        print("Widget closed.")

    btn_predict.on_click(_on_predict)
    btn_close.on_click(_on_close)
413
1049
 
414
1050
  class InferenceModel(object):
415
1051
  def __init__(self, name: str = None, trainer: PointEstimatorTrainer = None, preprocessing_parameters: dict = None,