reflectorch 1.2.0-py3-none-any.whl → 1.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of reflectorch has been flagged as potentially problematic; see the original diff page for more details.

Files changed (39):
  1. reflectorch/data_generation/__init__.py +2 -0
  2. reflectorch/data_generation/dataset.py +27 -7
  3. reflectorch/data_generation/noise.py +115 -9
  4. reflectorch/data_generation/priors/parametric_models.py +90 -15
  5. reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
  6. reflectorch/data_generation/priors/sampler_strategies.py +67 -3
  7. reflectorch/data_generation/q_generator.py +31 -11
  8. reflectorch/data_generation/reflectivity/__init__.py +56 -14
  9. reflectorch/data_generation/reflectivity/abeles.py +31 -16
  10. reflectorch/data_generation/reflectivity/kinematical.py +5 -6
  11. reflectorch/data_generation/reflectivity/memory_eff.py +1 -1
  12. reflectorch/data_generation/reflectivity/smearing.py +25 -10
  13. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  14. reflectorch/data_generation/smearing.py +42 -11
  15. reflectorch/data_generation/utils.py +92 -18
  16. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  17. reflectorch/inference/inference_model.py +220 -105
  18. reflectorch/inference/plotting.py +98 -0
  19. reflectorch/inference/scipy_fitter.py +84 -7
  20. reflectorch/ml/__init__.py +2 -0
  21. reflectorch/ml/basic_trainer.py +18 -6
  22. reflectorch/ml/callbacks.py +5 -4
  23. reflectorch/ml/loggers.py +25 -0
  24. reflectorch/ml/schedulers.py +116 -0
  25. reflectorch/ml/trainers.py +122 -23
  26. reflectorch/models/__init__.py +1 -1
  27. reflectorch/models/encoders/__init__.py +0 -2
  28. reflectorch/models/encoders/conv_encoder.py +54 -40
  29. reflectorch/models/encoders/fno.py +23 -16
  30. reflectorch/models/networks/__init__.py +2 -0
  31. reflectorch/models/networks/mlp_networks.py +324 -152
  32. reflectorch/models/networks/residual_net.py +31 -5
  33. reflectorch/runs/train.py +0 -1
  34. reflectorch/runs/utils.py +43 -9
  35. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/METADATA +19 -17
  36. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/RECORD +39 -36
  37. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/WHEEL +1 -1
  38. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info/licenses}/LICENSE.txt +0 -0
  39. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/top_level.txt +0 -0
@@ -12,8 +12,11 @@ from IPython.display import display
12
12
  from huggingface_hub import hf_hub_download
13
13
 
14
14
  from reflectorch.data_generation.priors import Params, BasicParams, ExpUniformSubPriorSampler, UniformSubPriorParams
15
+ from reflectorch.data_generation.priors.parametric_models import NuisanceParamsWrapper
15
16
  from reflectorch.data_generation.q_generator import ConstantQ, VariableQ
16
17
  from reflectorch.data_generation.utils import get_density_profiles, get_param_labels
18
+ from reflectorch.inference.plotting import plot_prediction_results
19
+ from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
17
20
  from reflectorch.paths import CONFIG_DIR, ROOT_DIR, SAVED_MODELS_DIR
18
21
  from reflectorch.runs.utils import (
19
22
  get_trainer_by_name, train_from_config
@@ -23,7 +26,7 @@ from reflectorch.ml.trainers import PointEstimatorTrainer
23
26
  from reflectorch.data_generation.likelihoods import LogLikelihood
24
27
 
25
28
  from reflectorch.inference.preprocess_exp import StandardPreprocessing
26
- from reflectorch.inference.scipy_fitter import standard_refl_fit, get_fit_with_growth
29
+ from reflectorch.inference.scipy_fitter import standard_refl_fit, refl_fit, get_fit_with_growth
27
30
  from reflectorch.inference.sampler_solution import simple_sampler_solution, get_best_mse_param
28
31
  from reflectorch.inference.record_time import print_time
29
32
  from reflectorch.utils import to_t
@@ -103,7 +106,9 @@ class EasyInferenceModel(object):
103
106
  self.trainer = get_trainer_by_name(config_name=config_name, config_dir=self.config_dir, model_path=model_path, load_weights=True, inference_device = self.device)
104
107
  self.trainer.model.eval()
105
108
 
106
- print(f'The model corresponds to a `{self.trainer.loader.prior_sampler.param_model.NAME}` parameterization with {self.trainer.loader.prior_sampler.max_num_layers} layers ({self.trainer.loader.prior_sampler.param_dim} predicted parameters)')
109
+ param_model = self.trainer.loader.prior_sampler.param_model
110
+ param_model_name = param_model.base_model.NAME if isinstance(param_model, NuisanceParamsWrapper) else param_model.NAME
111
+ print(f'The model corresponds to a `{param_model_name}` parameterization with {self.trainer.loader.prior_sampler.max_num_layers} layers ({self.trainer.loader.prior_sampler.param_dim} predicted parameters)')
107
112
  print("Parameter types and total ranges:")
108
113
  for param, range_ in self.trainer.loader.prior_sampler.param_ranges.items():
109
114
  print(f"- {param}: {range_}")
@@ -115,23 +120,51 @@ class EasyInferenceModel(object):
115
120
  q_min = self.trainer.loader.q_generator.q[0].item()
116
121
  q_max = self.trainer.loader.q_generator.q[-1].item()
117
122
  n_q = self.trainer.loader.q_generator.q.shape[0]
118
- print(f'The model was trained on curves discretized at {n_q} uniform points between between q_min={q_min} and q_max={q_max}')
123
+ print(f'The model was trained on curves discretized at {n_q} uniform points between q_min={q_min} and q_max={q_max}')
119
124
  elif isinstance(self.trainer.loader.q_generator, VariableQ):
120
125
  q_min_range = self.trainer.loader.q_generator.q_min_range
121
126
  q_max_range = self.trainer.loader.q_generator.q_max_range
122
127
  n_q_range = self.trainer.loader.q_generator.n_q_range
123
- print(f'The model was trained on curves discretized at a number between {n_q_range[0]} and {n_q_range[1]} of uniform points between between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
124
-
125
- def predict(self, reflectivity_curve: Union[np.ndarray, Tensor],
128
+ if n_q_range[0] == n_q_range[1]:
129
+ n_q_fixed = n_q_range[0]
130
+ print(f'The model was trained on curves discretized at exactly {n_q_fixed} uniform points, '
131
+ f'between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
132
+ else:
133
+ print(f'The model was trained on curves discretized at a number between {n_q_range[0]} and {n_q_range[1]} '
134
+ f'of uniform points between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
135
+
136
+ if self.trainer.loader.smearing is not None:
137
+ q_res_min = self.trainer.loader.smearing.sigma_min
138
+ q_res_max = self.trainer.loader.smearing.sigma_max
139
+ if self.trainer.loader.smearing.constant_dq == False:
140
+ print(f"The model was trained with linear resolution smearing (dq/q) in the range [{q_res_min}, {q_res_max}]")
141
+ elif self.trainer.loader.smearing.constant_dq == True:
142
+ print(f"The model was trained with constant resolution smearing in the range [{q_res_min}, {q_res_max}]")
143
+
144
+ additional_inputs = ["prior bounds"]
145
+ if self.trainer.train_with_q_input:
146
+ additional_inputs.append("q values")
147
+ if self.trainer.condition_on_q_resolutions:
148
+ additional_inputs.append("the resolution dq/q")
149
+ if additional_inputs:
150
+ inputs_str = ", ".join(additional_inputs)
151
+ print(f"The following quantities are additional inputs to the network: {inputs_str}.")
152
+
153
+ def predict(self,
154
+ reflectivity_curve: Union[np.ndarray, Tensor],
126
155
  q_values: Union[np.ndarray, Tensor] = None,
127
156
  prior_bounds: Union[np.ndarray, List[Tuple]] = None,
157
+ q_resolution: Union[float, np.ndarray] = None,
158
+ ambient_sld: float = None,
128
159
  clip_prediction: bool = False,
129
160
  polish_prediction: bool = False,
161
+ polishing_kwargs_reflectivity: dict = None,
130
162
  fit_growth: bool = False,
131
163
  max_d_change: float = 5.,
132
164
  use_q_shift: bool = False,
133
165
  calc_pred_curve: bool = True,
134
166
  calc_pred_sld_profile: bool = False,
167
+ calc_polished_sld_profile: bool = False,
135
168
  ):
136
169
  """Predict the thin film parameters
137
170
 
@@ -139,13 +172,17 @@ class EasyInferenceModel(object):
139
172
  reflectivity_curve (Union[np.ndarray, Tensor]): The reflectivity curve (which has been already preprocessed, normalized and interpolated).
140
173
  q_values (Union[np.ndarray, Tensor], optional): The momentum transfer (q) values for the reflectivity curve (in units of inverse angstroms).
141
174
  prior_bounds (Union[np.ndarray, List[Tuple]], optional): the prior bounds for the thin film parameters.
175
+ q_resolution (Union[float, np.ndarray], optional): the instrumental resolution. Either as a float with meaning dq/q for linear smearing or as a numpy array with meaning dq for pointwise smearing.
176
+ ambient_sld (float, optional): the SLD of the ambient medium (fronting), if different from air.
142
177
  clip_prediction (bool, optional): If ``True``, the values of the predicted parameters are clipped to not be outside the interval set by the prior bounds. Defaults to False.
143
178
  polish_prediction (bool, optional): If ``True``, the neural network predictions are further polished using a simple least mean squares (LMS) fit. Only for the standard box-model parameterization. Defaults to False.
179
+ polishing_kwargs_reflectivity (dict): extra arguments for the reflectivity function used during polishing.
144
180
  fit_growth (bool, optional): If ``True``, an additional parameters is introduced during the LMS polishing to account for the change in the thickness of the upper layer during the in-situ measurement of the reflectivity curve (a linear growth is assumed). Defaults to False.
145
181
  max_d_change (float): The maximum possible change in the thickness of the upper layer during the in-situ measurement, relevant when polish_prediction and fit_growth are True. Defaults to 5.
146
182
  use_q_shift: If ``True``, the prediction is performed for a batch of slightly shifted versions of the input curve and the best result is returned, which is meant to mitigate the influence of imperfect sample alignment, as introduced in Greco et al. (only for models with fixed q-discretization). Defaults to False.
147
183
  calc_pred_curve (bool, optional): Whether to calculate the curve corresponding to the predicted parameters. Defaults to True.
148
184
  calc_pred_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the predicted parameters. Defaults to False.
185
+ calc_polished_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the polished parameters. Defaults to False.
149
186
 
150
187
  Returns:
151
188
  dict: dictionary containing the predictions
@@ -153,7 +190,22 @@ class EasyInferenceModel(object):
153
190
 
154
191
  scaled_curve = self._scale_curve(reflectivity_curve)
155
192
  prior_bounds = np.array(prior_bounds)
156
- scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)
193
+
194
+ if ambient_sld:
195
+ n_layers = self.trainer.loader.prior_sampler.max_num_layers
196
+ sld_indices = slice(2*n_layers+1, 3*n_layers+2)
197
+ prior_bounds[sld_indices, ...] -= ambient_sld
198
+ training_min_bounds = self.trainer.loader.prior_sampler.min_bounds.squeeze().cpu().numpy()
199
+ training_max_bounds = self.trainer.loader.prior_sampler.max_bounds.squeeze().cpu().numpy()
200
+ lower_bound_check = (prior_bounds[sld_indices, 0] >= training_min_bounds[sld_indices]).all()
201
+ upper_bound_check = (prior_bounds[sld_indices, 1] <= training_max_bounds[sld_indices]).all()
202
+ assert lower_bound_check and upper_bound_check, "Shifting the layer SLDs by the ambient SLD exceeded the training ranges."
203
+
204
+ try:
205
+ scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)
206
+ except ValueError as e:
207
+ print(str(e))
208
+ return None
157
209
 
158
210
  if not self.trainer.train_with_q_input:
159
211
  q_values = self.trainer.loader.q_generator.q
@@ -166,11 +218,29 @@ class EasyInferenceModel(object):
166
218
  else:
167
219
  with torch.no_grad():
168
220
  self.trainer.model.eval()
169
- if self.trainer.train_with_q_input:
170
- scaled_q = self.trainer.loader.q_generator.scale_q(q_values).float()
171
- scaled_predicted_params = self.trainer.model(scaled_curve, scaled_prior_bounds, scaled_q)
221
+
222
+ scaled_q_values = self.trainer.loader.q_generator.scale_q(q_values).to(torch.float32) if self.trainer.train_with_q_input else None
223
+
224
+ if q_resolution is not None:
225
+ q_resolution_tensor = torch.atleast_2d(torch.as_tensor(q_resolution)).to(scaled_curve)
226
+ if isinstance(q_resolution, float):
227
+ unscaled_q_resolutions = q_resolution_tensor
228
+ else:
229
+ unscaled_q_resolutions = (q_resolution_tensor / q_values).mean(dim=-1, keepdim=True)
230
+ scaled_q_resolutions = self.trainer.loader.smearing.scale_resolutions(unscaled_q_resolutions) if self.trainer.condition_on_q_resolutions else None
231
+ scaled_conditioning_params = scaled_q_resolutions
232
+ if polishing_kwargs_reflectivity is None:
233
+ polishing_kwargs_reflectivity = {'dq': q_resolution}
172
234
  else:
173
- scaled_predicted_params = self.trainer.model(scaled_curve, scaled_prior_bounds)
235
+ q_resolution_tensor = None
236
+ scaled_conditioning_params = None
237
+
238
+ scaled_predicted_params = self.trainer.model(
239
+ curves=scaled_curve,
240
+ bounds=scaled_prior_bounds,
241
+ q_values=scaled_q_values,
242
+ conditioning_params = scaled_conditioning_params,
243
+ )
174
244
 
175
245
  predicted_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
176
246
 
@@ -184,19 +254,22 @@ class EasyInferenceModel(object):
184
254
  }
185
255
 
186
256
  if calc_pred_curve:
187
- predicted_curve = predicted_params.reflectivity(q_values).squeeze().cpu().numpy()
257
+ predicted_curve = predicted_params.reflectivity(q=q_values, dq=q_resolution_tensor).squeeze().cpu().numpy()
188
258
  prediction_dict[ "predicted_curve"] = predicted_curve
189
259
 
260
+ ambient_sld_tensor = torch.atleast_2d(torch.as_tensor(ambient_sld)).to(predicted_params.thicknesses.device) if ambient_sld is not None else None
190
261
  if calc_pred_sld_profile:
191
262
  predicted_sld_xaxis, predicted_sld_profile, _ = get_density_profiles(
192
- predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, num=1024,
263
+ predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, ambient_sld_tensor, num=1024,
193
264
  )
194
265
  prediction_dict['predicted_sld_profile'] = predicted_sld_profile.squeeze().cpu().numpy()
195
266
  prediction_dict['predicted_sld_xaxis'] = predicted_sld_xaxis.squeeze().cpu().numpy()
196
267
  else:
197
268
  predicted_sld_xaxis = None
198
269
 
199
- if polish_prediction: #only for standard box-model parameterization
270
+ if polish_prediction:
271
+ if ambient_sld_tensor:
272
+ ambient_sld_tensor = ambient_sld_tensor.cpu()
200
273
  polished_dict = self._polish_prediction(q = q_values.squeeze().cpu().numpy(),
201
274
  curve = reflectivity_curve,
202
275
  predicted_params = predicted_params,
@@ -204,19 +277,25 @@ class EasyInferenceModel(object):
204
277
  fit_growth = fit_growth,
205
278
  max_d_change = max_d_change,
206
279
  calc_polished_curve = calc_pred_curve,
207
- calc_polished_sld_profile = False,
280
+ calc_polished_sld_profile = calc_polished_sld_profile,
281
+ ambient_sld_tensor=ambient_sld_tensor,
208
282
  sld_x_axis = predicted_sld_xaxis,
283
+ polishing_kwargs_reflectivity=polishing_kwargs_reflectivity,
209
284
  )
210
285
  prediction_dict.update(polished_dict)
211
286
 
212
287
  if fit_growth and "polished_params_array" in prediction_dict:
213
288
  prediction_dict["param_names"].append("max_d_change")
214
289
 
290
+ if ambient_sld: #Note: the SLD shift will only be reflected in predicted_params_array but not in predicted_params_object
291
+ prediction_dict["predicted_params_array"][sld_indices] += ambient_sld
292
+ if "polished_params_array" in prediction_dict:
293
+ prediction_dict["polished_params_array"][sld_indices] += ambient_sld
294
+
215
295
  return prediction_dict
216
296
 
217
- async def predict_using_widget(self, reflectivity_curve: Union[np.ndarray, torch.Tensor], **kwargs):
218
- """Use an interactive Python widget for specifying the prior bounds before the prediction (works only in a Jupyter notebook).
219
- The other arguments are the same as for the ``predict`` method.
297
+ def predict_using_widget(self, reflectivity_curve, **kwargs):
298
+ """
220
299
  """
221
300
 
222
301
  NUM_INTERVALS = self.trainer.loader.prior_sampler.param_dim
@@ -225,76 +304,74 @@ class EasyInferenceModel(object):
225
304
  max_bounds = self.trainer.loader.prior_sampler.max_bounds.cpu().numpy().flatten()
226
305
  max_deltas = self.trainer.loader.prior_sampler.max_delta.cpu().numpy().flatten()
227
306
 
228
- print(f'Parameter ranges: {self.trainer.loader.prior_sampler.param_ranges}')
229
- print(f'Allowed widths of the prior bound intervals (max-min): {self.trainer.loader.prior_sampler.bound_width_ranges}')
230
- print(f'Please fill in the values of the minimum and maximum prior bound for each parameter and press the button!')
231
-
232
- def create_interval_widgets(n):
233
- intervals = []
234
- for i in range(n):
235
- interval_label = widgets.Label(value=f'{param_labels[i]}')
236
- initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
237
- slider = widgets.FloatRangeSlider(
238
- value=[min_bounds[i], initial_max],
239
- min=min_bounds[i],
240
- max=max_bounds[i],
241
- step=0.01,
242
- layout=widgets.Layout(width='400px'),
243
- style={'description_width': '60px'}
244
- )
245
-
246
- def validate_range(change, slider=slider, max_width=max_deltas[i]):
247
- min_val, max_val = change['new']
248
- if max_val - min_val > max_width:
249
- if change['name'] == 'value':
250
- if change['old'][0] != min_val:
251
- max_val = min_val + max_width
252
- else:
253
- min_val = max_val - max_width
254
- slider.value = [min_val, max_val]
255
-
256
- slider.observe(validate_range, names='value')
257
-
258
- interval_row = widgets.HBox([interval_label, slider])
259
- intervals.append((slider, interval_row))
260
- return intervals
261
-
262
- interval_widgets = create_interval_widgets(NUM_INTERVALS)
263
- interval_box = widgets.VBox([widget[1] for widget in interval_widgets])
264
- display(interval_box)
265
-
266
- button = widgets.Button(description="Make prediction")
267
- display(button)
268
-
269
- prediction_result = None
270
-
271
- def store_values(b, future):
272
- print("Debug: Button clicked")
273
- values = []
274
- for slider, _ in interval_widgets:
275
- values.append((slider.value[0], slider.value[1]))
276
- array_values = np.array(values)
277
-
278
- nonlocal prediction_result
279
- prediction_result = self.predict(reflectivity_curve=reflectivity_curve, prior_bounds=array_values, **kwargs)
280
- print(prediction_result["predicted_params_array"])
281
-
282
- print("Prediction completed. Closing widget.")
283
-
284
- for child in interval_box.children:
285
- child.close()
286
- button.close()
287
-
288
- future.set_result(prediction_result)
289
-
290
- button.on_click(store_values)
291
-
307
+ print(f'Adjust the sliders for each parameter and press "Predict". Repeat as desired. Press "Close Widget" to finish.')
308
+
309
+ interval_widgets = []
310
+ for i in range(NUM_INTERVALS):
311
+ label = widgets.Label(value=f'{param_labels[i]}')
312
+ initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
313
+ slider = widgets.FloatRangeSlider(
314
+ value=[min_bounds[i], initial_max],
315
+ min=min_bounds[i],
316
+ max=max_bounds[i],
317
+ step=0.01,
318
+ layout=widgets.Layout(width='400px'),
319
+ style={'description_width': '60px'}
320
+ )
292
321
 
293
- future = asyncio.Future()
322
+ def validate_range(change, slider=slider, max_width=max_deltas[i]):
323
+ min_val, max_val = change['new']
324
+ if max_val - min_val > max_width:
325
+ old_min_val, old_max_val = change['old']
326
+ if abs(old_min_val - min_val) > abs(old_max_val - max_val):
327
+ max_val = min_val + max_width
328
+ else:
329
+ min_val = max_val - max_width
330
+ slider.value = [min_val, max_val]
331
+
332
+ slider.observe(validate_range, names='value')
333
+ interval_widgets.append((slider, widgets.HBox([label, slider])))
334
+
335
+ sliders_box = widgets.VBox([iw[1] for iw in interval_widgets])
336
+
337
+ output = widgets.Output()
338
+ predict_button = widgets.Button(description="Predict")
339
+ close_button = widgets.Button(description="Close Widget")
340
+
341
+ container = widgets.VBox([sliders_box, widgets.HBox([predict_button, close_button]), output])
342
+ display(container)
343
+
344
+ @output.capture(clear_output=True)
345
+ def on_predict_click(_):
346
+ if 'prior_bounds' in kwargs:
347
+ array_values = kwargs.pop('prior_bounds')
348
+ for i, (s, _) in enumerate(interval_widgets):
349
+ s.value = tuple(array_values[i])
350
+ else:
351
+ values = [(s.value[0], s.value[1]) for s, _ in interval_widgets]
352
+ array_values = np.array(values)
353
+
354
+ prediction_result = self.predict(reflectivity_curve=reflectivity_curve,
355
+ prior_bounds=array_values,
356
+ **kwargs)
357
+ param_names = self.trainer.loader.prior_sampler.param_model.get_param_labels()
358
+ for param_name, pred_param_val in zip(param_names, prediction_result["predicted_params_array"]):
359
+ print(f'{param_name.ljust(14)} : {pred_param_val:.2f}')
360
+
361
+ plot_prediction_results(
362
+ prediction_result,
363
+ q_exp=kwargs['q_values'],
364
+ curve_exp=reflectivity_curve,
365
+ q_model=kwargs['q_values'],
366
+ )
367
+ self.prediction_result = prediction_result
294
368
 
295
- button.on_click(lambda b: store_values(b, future))
369
+ def on_close_click(_):
370
+ container.close()
371
+ print("Widget closed.")
296
372
 
297
- return await future
373
+ predict_button.on_click(on_predict_click)
374
+ close_button.on_click(on_close_click)
298
375
 
299
376
 
300
377
  def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
@@ -328,18 +405,17 @@ class EasyInferenceModel(object):
328
405
  predicted_params: BasicParams,
329
406
  priors: np.ndarray,
330
407
  sld_x_axis,
408
+ ambient_sld_tensor: Tensor = None,
331
409
  fit_growth: bool = False,
332
410
  max_d_change: float = 5.,
333
411
  calc_polished_curve: bool = True,
334
412
  calc_polished_sld_profile: bool = False,
413
+ polishing_kwargs_reflectivity: dict = None,
335
414
  ) -> dict:
336
- params = torch.cat([
337
- predicted_params.thicknesses.squeeze(),
338
- predicted_params.roughnesses.squeeze(),
339
- predicted_params.slds.squeeze()
340
- ]).cpu().numpy()
415
+ params = predicted_params.parameters.squeeze().cpu().numpy()
341
416
 
342
417
  polished_params_dict = {}
418
+ polishing_kwargs_reflectivity = polishing_kwargs_reflectivity or {}
343
419
 
344
420
  try:
345
421
  if fit_growth:
@@ -354,20 +430,25 @@ class EasyInferenceModel(object):
354
430
  torch.from_numpy(polished_params_arr[:-1][None]),
355
431
  torch.from_numpy(priors.T[0][None]),
356
432
  torch.from_numpy(priors.T[1][None]),
357
- #self.trainer.loader.prior_sampler.param_model
433
+ self.trainer.loader.prior_sampler.max_num_layers,
434
+ self.trainer.loader.prior_sampler.param_model
358
435
  )
359
436
  else:
360
- polished_params_arr, curve_polished = standard_refl_fit(
437
+ polished_params_arr, curve_polished = refl_fit(
361
438
  q = q,
362
439
  curve = curve,
363
440
  init_params = params,
364
- bounds=priors.T)
441
+ bounds=priors.T,
442
+ prior_sampler=self.trainer.loader.prior_sampler,
443
+ reflectivity_kwargs=polishing_kwargs_reflectivity,
444
+ )
365
445
  polished_params = BasicParams(
366
446
  torch.from_numpy(polished_params_arr[None]),
367
447
  torch.from_numpy(priors.T[0][None]),
368
448
  torch.from_numpy(priors.T[1][None]),
369
- #self.trainer.loader.prior_sampler.param_model
370
- )
449
+ self.trainer.loader.prior_sampler.max_num_layers,
450
+ self.trainer.loader.prior_sampler.param_model
451
+ )
371
452
  except Exception as err:
372
453
  polished_params = predicted_params
373
454
  polished_params_arr = get_prediction_array(polished_params)
@@ -379,9 +460,9 @@ class EasyInferenceModel(object):
379
460
 
380
461
  if calc_polished_sld_profile:
381
462
  _, sld_profile_polished, _ = get_density_profiles(
382
- polished_params.thicknesses, polished_params.roughnesses, polished_params.slds, z_axis=sld_x_axis,
463
+ polished_params.thicknesses, polished_params.roughnesses, polished_params.slds, ambient_sld_tensor, z_axis=sld_x_axis.cpu(),
383
464
  )
384
- polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().cpu().numpy()
465
+ polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().numpy()
385
466
 
386
467
  return polished_params_dict
387
468
 
@@ -393,16 +474,50 @@ class EasyInferenceModel(object):
393
474
  return scaled_curve
394
475
 
395
476
  def _scale_prior_bounds(self, prior_bounds: List[Tuple]):
396
- prior_bounds = torch.tensor(prior_bounds)
397
- prior_bounds = prior_bounds.to(self.device).T
398
- min_bounds, max_bounds = prior_bounds[:, None]
477
+ try:
478
+ prior_bounds = torch.tensor(prior_bounds)
479
+ prior_bounds = prior_bounds.to(self.device).T
480
+ min_bounds, max_bounds = prior_bounds[:, None]
481
+
482
+ scaled_bounds = torch.cat([
483
+ self.trainer.loader.prior_sampler.scale_bounds(min_bounds),
484
+ self.trainer.loader.prior_sampler.scale_bounds(max_bounds)
485
+ ], -1)
486
+
487
+ return scaled_bounds.float()
488
+
489
+ except RuntimeError as e:
490
+ expected_param_dim = self.trainer.loader.prior_sampler.param_dim
491
+ actual_param_dim = prior_bounds.shape[1] if prior_bounds.ndim == 2 else len(prior_bounds)
492
+
493
+ msg = (
494
+ f"\n **Parameter dimension mismatch during inference!**\n"
495
+ f"- Model expects **{expected_param_dim}** parameters.\n"
496
+ f"- You provided **{actual_param_dim}** prior bounds.\n\n"
497
+ f"💡This often occurs when:\n"
498
+ f"- The model was trained with additional nuisance parameters like `r_scale`, `q_shift`, or `log10_background`,\n"
499
+ f" but they were not included in the `prior_bounds` passed to `.predict()`.\n"
500
+ f"- The number of layers or parameterization type differs from the one used during training.\n\n"
501
+ f" Check the configuration or the summary of expected parameters."
502
+ )
503
+ raise ValueError(msg) from e
504
+
505
+ def interpolate_data_to_model_q(self, q_exp, curve_exp):
506
+ if isinstance(self.trainer.loader.q_generator, ConstantQ):
507
+ q_model = self.trainer.loader.q_generator.q.cpu().numpy()
508
+ elif isinstance(self.trainer.loader.q_generator, VariableQ):
509
+ if self.trainer.loader.q_generator.n_q_range[0] == self.trainer.loader.q_generator.n_q_range[1]:
510
+ n_q_model = self.trainer.loader.q_generator.n_q_range[0]
511
+ q_model_min = max(q_exp.min(), self.trainer.loader.q_generator.q_min_range[0])
512
+ q_model_max = min(q_exp.max(), self.trainer.loader.q_generator.q_max_range[1])
513
+ q_model = np.linspace(q_model_min, q_model_max, n_q_model)
514
+ else:
515
+ q_model = q_exp
516
+ exp_curve_interp = curve_exp
399
517
 
400
- scaled_bounds = torch.cat([
401
- self.trainer.loader.prior_sampler.scale_bounds(min_bounds),
402
- self.trainer.loader.prior_sampler.scale_bounds(max_bounds)
403
- ], -1)
518
+ exp_curve_interp = interp_reflectivity(q_model, q_exp, curve_exp)
404
519
 
405
- return scaled_bounds.float()
520
+ return q_model, exp_curve_interp
406
521
 
407
522
  def _get_likelihood(self, q, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
408
523
  return LogLikelihood(
@@ -0,0 +1,98 @@
1
+ from matplotlib import pyplot as plt
2
+ import numpy as np
3
+
4
+
5
+ def plot_prediction_results(
6
+ prediction_dict: dict,
7
+ q_exp: np.ndarray = None,
8
+ curve_exp: np.ndarray = None,
9
+ sigmas_exp: np.ndarray = None,
10
+ q_model: np.ndarray = None,
11
+ ):
12
+ """
13
+ Plot the experimental curve (with optional error bars), the predicted
14
+ and polished curves, and also the predicted/polished SLD profiles.
15
+
16
+ Args:
17
+ prediction_dict (dict): Dictionary containing 'predicted_curve',
18
+ 'predicted_sld_profile', 'predicted_sld_xaxis',
19
+ and optionally 'polished_curve', 'sld_profile_polished'.
20
+ q_exp (ndarray, optional): Experimental q-values.
21
+ curve_exp (ndarray, optional): Experimental reflectivity curve.
22
+ sigmas_exp (ndarray, optional): Error bars of the experimental reflectivity.
23
+ q_model (ndarray, optional): The q-values on which prediction_dict's reflectivity
24
+ was computed (e.g. from EasyInferenceModel.interpolate_data_to_model_q).
25
+
26
+ Example usage:
27
+ prediction_dict = model.predict(...)
28
+ plot_prediction_results(
29
+ prediction_dict,
30
+ q_exp=q_exp,
31
+ curve_exp=curve_exp,
32
+ sigmas_exp=sigmas_exp,
33
+ q_model=q_model
34
+ )
35
+ """
36
+
37
+ fig, ax = plt.subplots(1, 2, figsize=(12, 6))
38
+
39
+ # --- Left plot: Reflectivity curves ---
40
+ ax[0].set_yscale('log')
41
+ ax[0].set_xlabel('q [$Å^{-1}$]', fontsize=20)
42
+ ax[0].set_ylabel('R(q)', fontsize=20)
43
+ ax[0].tick_params(axis='both', which='major', labelsize=15)
44
+ ax[0].tick_params(axis='both', which='minor', labelsize=15)
45
+
46
+ # Optionally set major y ticks (log scale)
47
+ y_tick_locations = [10 ** (-2 * i) for i in range(6)]
48
+ ax[0].yaxis.set_major_locator(plt.FixedLocator(y_tick_locations))
49
+
50
+ # Plot experimental data with error bars (if provided)
51
+ if q_exp is not None and curve_exp is not None:
52
+ el = ax[0].errorbar(
53
+ q_exp, curve_exp, yerr=sigmas_exp,
54
+ xerr=None, c='b', ecolor='purple', elinewidth=1,
55
+ marker='o', linestyle='none', markersize=3,
56
+ label='exp. curve', zorder=1
57
+ )
58
+ # Change the color of error bar lines (optional)
59
+ elines = el.get_children()
60
+ if len(elines) > 1:
61
+ elines[1].set_color('purple')
62
+
63
+ # Plot predicted curve
64
+ if 'predicted_curve' in prediction_dict and q_model is not None:
65
+ ax[0].plot(q_model, prediction_dict['predicted_curve'], c='red', lw=2, label='pred. curve')
66
+
67
+ # Plot polished curve (if present)
68
+ if 'polished_curve' in prediction_dict and q_model is not None:
69
+ ax[0].plot(q_model, prediction_dict['polished_curve'], c='orange', ls='--', lw=2, label='polished pred. curve')
70
+
71
+ ax[0].legend(fontsize=12)
72
+
73
+ # --- Right plot: SLD profiles ---
74
+ ax[1].set_xlabel('z [$Å$]', fontsize=20)
75
+ ax[1].set_ylabel('SLD [$10^{-6} Å^{-2}$]', fontsize=20)
76
+ ax[1].tick_params(axis='both', which='major', labelsize=15)
77
+ ax[1].tick_params(axis='both', which='minor', labelsize=15)
78
+
79
+ # Predicted SLD
80
+ if 'predicted_sld_xaxis' in prediction_dict and 'predicted_sld_profile' in prediction_dict:
81
+ ax[1].plot(
82
+ prediction_dict['predicted_sld_xaxis'],
83
+ prediction_dict['predicted_sld_profile'],
84
+ c='red', label='pred. sld'
85
+ )
86
+
87
+ # Polished SLD
88
+ if 'sld_profile_polished' in prediction_dict and 'predicted_sld_xaxis' in prediction_dict:
89
+ ax[1].plot(
90
+ prediction_dict['predicted_sld_xaxis'],
91
+ prediction_dict['sld_profile_polished'],
92
+ c='orange', ls='--', label='polished sld'
93
+ )
94
+
95
+ ax[1].legend(fontsize=12)
96
+
97
+ plt.tight_layout()
98
+ plt.show()