reflectorch 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/data_generation/__init__.py +2 -0
- reflectorch/data_generation/dataset.py +27 -7
- reflectorch/data_generation/noise.py +115 -9
- reflectorch/data_generation/priors/parametric_models.py +90 -15
- reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
- reflectorch/data_generation/priors/sampler_strategies.py +67 -3
- reflectorch/data_generation/q_generator.py +31 -11
- reflectorch/data_generation/reflectivity/__init__.py +56 -14
- reflectorch/data_generation/reflectivity/abeles.py +31 -16
- reflectorch/data_generation/reflectivity/kinematical.py +5 -6
- reflectorch/data_generation/reflectivity/memory_eff.py +1 -1
- reflectorch/data_generation/reflectivity/smearing.py +25 -10
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/smearing.py +42 -11
- reflectorch/data_generation/utils.py +92 -18
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/inference_model.py +220 -105
- reflectorch/inference/plotting.py +98 -0
- reflectorch/inference/scipy_fitter.py +84 -7
- reflectorch/ml/__init__.py +2 -0
- reflectorch/ml/basic_trainer.py +18 -6
- reflectorch/ml/callbacks.py +5 -4
- reflectorch/ml/loggers.py +25 -0
- reflectorch/ml/schedulers.py +116 -0
- reflectorch/ml/trainers.py +122 -23
- reflectorch/models/__init__.py +1 -1
- reflectorch/models/encoders/__init__.py +0 -2
- reflectorch/models/encoders/conv_encoder.py +54 -40
- reflectorch/models/encoders/fno.py +23 -16
- reflectorch/models/networks/__init__.py +2 -0
- reflectorch/models/networks/mlp_networks.py +324 -152
- reflectorch/models/networks/residual_net.py +31 -5
- reflectorch/runs/train.py +0 -1
- reflectorch/runs/utils.py +43 -9
- {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/METADATA +19 -17
- {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/RECORD +39 -36
- {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info/licenses}/LICENSE.txt +0 -0
- {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -12,8 +12,11 @@ from IPython.display import display
|
|
|
12
12
|
from huggingface_hub import hf_hub_download
|
|
13
13
|
|
|
14
14
|
from reflectorch.data_generation.priors import Params, BasicParams, ExpUniformSubPriorSampler, UniformSubPriorParams
|
|
15
|
+
from reflectorch.data_generation.priors.parametric_models import NuisanceParamsWrapper
|
|
15
16
|
from reflectorch.data_generation.q_generator import ConstantQ, VariableQ
|
|
16
17
|
from reflectorch.data_generation.utils import get_density_profiles, get_param_labels
|
|
18
|
+
from reflectorch.inference.plotting import plot_prediction_results
|
|
19
|
+
from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
|
|
17
20
|
from reflectorch.paths import CONFIG_DIR, ROOT_DIR, SAVED_MODELS_DIR
|
|
18
21
|
from reflectorch.runs.utils import (
|
|
19
22
|
get_trainer_by_name, train_from_config
|
|
@@ -23,7 +26,7 @@ from reflectorch.ml.trainers import PointEstimatorTrainer
|
|
|
23
26
|
from reflectorch.data_generation.likelihoods import LogLikelihood
|
|
24
27
|
|
|
25
28
|
from reflectorch.inference.preprocess_exp import StandardPreprocessing
|
|
26
|
-
from reflectorch.inference.scipy_fitter import standard_refl_fit, get_fit_with_growth
|
|
29
|
+
from reflectorch.inference.scipy_fitter import standard_refl_fit, refl_fit, get_fit_with_growth
|
|
27
30
|
from reflectorch.inference.sampler_solution import simple_sampler_solution, get_best_mse_param
|
|
28
31
|
from reflectorch.inference.record_time import print_time
|
|
29
32
|
from reflectorch.utils import to_t
|
|
@@ -103,7 +106,9 @@ class EasyInferenceModel(object):
|
|
|
103
106
|
self.trainer = get_trainer_by_name(config_name=config_name, config_dir=self.config_dir, model_path=model_path, load_weights=True, inference_device = self.device)
|
|
104
107
|
self.trainer.model.eval()
|
|
105
108
|
|
|
106
|
-
|
|
109
|
+
param_model = self.trainer.loader.prior_sampler.param_model
|
|
110
|
+
param_model_name = param_model.base_model.NAME if isinstance(param_model, NuisanceParamsWrapper) else param_model.NAME
|
|
111
|
+
print(f'The model corresponds to a `{param_model_name}` parameterization with {self.trainer.loader.prior_sampler.max_num_layers} layers ({self.trainer.loader.prior_sampler.param_dim} predicted parameters)')
|
|
107
112
|
print("Parameter types and total ranges:")
|
|
108
113
|
for param, range_ in self.trainer.loader.prior_sampler.param_ranges.items():
|
|
109
114
|
print(f"- {param}: {range_}")
|
|
@@ -115,23 +120,51 @@ class EasyInferenceModel(object):
|
|
|
115
120
|
q_min = self.trainer.loader.q_generator.q[0].item()
|
|
116
121
|
q_max = self.trainer.loader.q_generator.q[-1].item()
|
|
117
122
|
n_q = self.trainer.loader.q_generator.q.shape[0]
|
|
118
|
-
print(f'The model was trained on curves discretized at {n_q} uniform points between
|
|
123
|
+
print(f'The model was trained on curves discretized at {n_q} uniform points between q_min={q_min} and q_max={q_max}')
|
|
119
124
|
elif isinstance(self.trainer.loader.q_generator, VariableQ):
|
|
120
125
|
q_min_range = self.trainer.loader.q_generator.q_min_range
|
|
121
126
|
q_max_range = self.trainer.loader.q_generator.q_max_range
|
|
122
127
|
n_q_range = self.trainer.loader.q_generator.n_q_range
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
128
|
+
if n_q_range[0] == n_q_range[1]:
|
|
129
|
+
n_q_fixed = n_q_range[0]
|
|
130
|
+
print(f'The model was trained on curves discretized at exactly {n_q_fixed} uniform points, '
|
|
131
|
+
f'between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
|
|
132
|
+
else:
|
|
133
|
+
print(f'The model was trained on curves discretized at a number between {n_q_range[0]} and {n_q_range[1]} '
|
|
134
|
+
f'of uniform points between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
|
|
135
|
+
|
|
136
|
+
if self.trainer.loader.smearing is not None:
|
|
137
|
+
q_res_min = self.trainer.loader.smearing.sigma_min
|
|
138
|
+
q_res_max = self.trainer.loader.smearing.sigma_max
|
|
139
|
+
if self.trainer.loader.smearing.constant_dq == False:
|
|
140
|
+
print(f"The model was trained with linear resolution smearing (dq/q) in the range [{q_res_min}, {q_res_max}]")
|
|
141
|
+
elif self.trainer.loader.smearing.constant_dq == True:
|
|
142
|
+
print(f"The model was trained with constant resolution smearing in the range [{q_res_min}, {q_res_max}]")
|
|
143
|
+
|
|
144
|
+
additional_inputs = ["prior bounds"]
|
|
145
|
+
if self.trainer.train_with_q_input:
|
|
146
|
+
additional_inputs.append("q values")
|
|
147
|
+
if self.trainer.condition_on_q_resolutions:
|
|
148
|
+
additional_inputs.append("the resolution dq/q")
|
|
149
|
+
if additional_inputs:
|
|
150
|
+
inputs_str = ", ".join(additional_inputs)
|
|
151
|
+
print(f"The following quantities are additional inputs to the network: {inputs_str}.")
|
|
152
|
+
|
|
153
|
+
def predict(self,
|
|
154
|
+
reflectivity_curve: Union[np.ndarray, Tensor],
|
|
126
155
|
q_values: Union[np.ndarray, Tensor] = None,
|
|
127
156
|
prior_bounds: Union[np.ndarray, List[Tuple]] = None,
|
|
157
|
+
q_resolution: Union[float, np.ndarray] = None,
|
|
158
|
+
ambient_sld: float = None,
|
|
128
159
|
clip_prediction: bool = False,
|
|
129
160
|
polish_prediction: bool = False,
|
|
161
|
+
polishing_kwargs_reflectivity: dict = None,
|
|
130
162
|
fit_growth: bool = False,
|
|
131
163
|
max_d_change: float = 5.,
|
|
132
164
|
use_q_shift: bool = False,
|
|
133
165
|
calc_pred_curve: bool = True,
|
|
134
166
|
calc_pred_sld_profile: bool = False,
|
|
167
|
+
calc_polished_sld_profile: bool = False,
|
|
135
168
|
):
|
|
136
169
|
"""Predict the thin film parameters
|
|
137
170
|
|
|
@@ -139,13 +172,17 @@ class EasyInferenceModel(object):
|
|
|
139
172
|
reflectivity_curve (Union[np.ndarray, Tensor]): The reflectivity curve (which has been already preprocessed, normalized and interpolated).
|
|
140
173
|
q_values (Union[np.ndarray, Tensor], optional): The momentum transfer (q) values for the reflectivity curve (in units of inverse angstroms).
|
|
141
174
|
prior_bounds (Union[np.ndarray, List[Tuple]], optional): the prior bounds for the thin film parameters.
|
|
175
|
+
q_resolution (Union[float, np.ndarray], optional): the instrumental resolution. Either as a float with meaning dq/q for linear smearing or as a numpy array with meaning dq for pointwise smearing.
|
|
176
|
+
ambient_sld (float, optional): the SLD of the ambient medium (fronting), if different from air.
|
|
142
177
|
clip_prediction (bool, optional): If ``True``, the values of the predicted parameters are clipped to not be outside the interval set by the prior bounds. Defaults to False.
|
|
143
178
|
polish_prediction (bool, optional): If ``True``, the neural network predictions are further polished using a simple least mean squares (LMS) fit. Only for the standard box-model parameterization. Defaults to False.
|
|
179
|
+
polishing_kwargs_reflectivity (dict): extra arguments for the reflectivity function used during polishing.
|
|
144
180
|
fit_growth (bool, optional): If ``True``, an additional parameters is introduced during the LMS polishing to account for the change in the thickness of the upper layer during the in-situ measurement of the reflectivity curve (a linear growth is assumed). Defaults to False.
|
|
145
181
|
max_d_change (float): The maximum possible change in the thickness of the upper layer during the in-situ measurement, relevant when polish_prediction and fit_growth are True. Defaults to 5.
|
|
146
182
|
use_q_shift: If ``True``, the prediction is performed for a batch of slightly shifted versions of the input curve and the best result is returned, which is meant to mitigate the influence of imperfect sample alignment, as introduced in Greco et al. (only for models with fixed q-discretization). Defaults to False.
|
|
147
183
|
calc_pred_curve (bool, optional): Whether to calculate the curve corresponding to the predicted parameters. Defaults to True.
|
|
148
184
|
calc_pred_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the predicted parameters. Defaults to False.
|
|
185
|
+
calc_polished_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the polished parameters. Defaults to False.
|
|
149
186
|
|
|
150
187
|
Returns:
|
|
151
188
|
dict: dictionary containing the predictions
|
|
@@ -153,7 +190,22 @@ class EasyInferenceModel(object):
|
|
|
153
190
|
|
|
154
191
|
scaled_curve = self._scale_curve(reflectivity_curve)
|
|
155
192
|
prior_bounds = np.array(prior_bounds)
|
|
156
|
-
|
|
193
|
+
|
|
194
|
+
if ambient_sld:
|
|
195
|
+
n_layers = self.trainer.loader.prior_sampler.max_num_layers
|
|
196
|
+
sld_indices = slice(2*n_layers+1, 3*n_layers+2)
|
|
197
|
+
prior_bounds[sld_indices, ...] -= ambient_sld
|
|
198
|
+
training_min_bounds = self.trainer.loader.prior_sampler.min_bounds.squeeze().cpu().numpy()
|
|
199
|
+
training_max_bounds = self.trainer.loader.prior_sampler.max_bounds.squeeze().cpu().numpy()
|
|
200
|
+
lower_bound_check = (prior_bounds[sld_indices, 0] >= training_min_bounds[sld_indices]).all()
|
|
201
|
+
upper_bound_check = (prior_bounds[sld_indices, 1] <= training_max_bounds[sld_indices]).all()
|
|
202
|
+
assert lower_bound_check and upper_bound_check, "Shifting the layer SLDs by the ambient SLD exceeded the training ranges."
|
|
203
|
+
|
|
204
|
+
try:
|
|
205
|
+
scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)
|
|
206
|
+
except ValueError as e:
|
|
207
|
+
print(str(e))
|
|
208
|
+
return None
|
|
157
209
|
|
|
158
210
|
if not self.trainer.train_with_q_input:
|
|
159
211
|
q_values = self.trainer.loader.q_generator.q
|
|
@@ -166,11 +218,29 @@ class EasyInferenceModel(object):
|
|
|
166
218
|
else:
|
|
167
219
|
with torch.no_grad():
|
|
168
220
|
self.trainer.model.eval()
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
221
|
+
|
|
222
|
+
scaled_q_values = self.trainer.loader.q_generator.scale_q(q_values).to(torch.float32) if self.trainer.train_with_q_input else None
|
|
223
|
+
|
|
224
|
+
if q_resolution is not None:
|
|
225
|
+
q_resolution_tensor = torch.atleast_2d(torch.as_tensor(q_resolution)).to(scaled_curve)
|
|
226
|
+
if isinstance(q_resolution, float):
|
|
227
|
+
unscaled_q_resolutions = q_resolution_tensor
|
|
228
|
+
else:
|
|
229
|
+
unscaled_q_resolutions = (q_resolution_tensor / q_values).mean(dim=-1, keepdim=True)
|
|
230
|
+
scaled_q_resolutions = self.trainer.loader.smearing.scale_resolutions(unscaled_q_resolutions) if self.trainer.condition_on_q_resolutions else None
|
|
231
|
+
scaled_conditioning_params = scaled_q_resolutions
|
|
232
|
+
if polishing_kwargs_reflectivity is None:
|
|
233
|
+
polishing_kwargs_reflectivity = {'dq': q_resolution}
|
|
172
234
|
else:
|
|
173
|
-
|
|
235
|
+
q_resolution_tensor = None
|
|
236
|
+
scaled_conditioning_params = None
|
|
237
|
+
|
|
238
|
+
scaled_predicted_params = self.trainer.model(
|
|
239
|
+
curves=scaled_curve,
|
|
240
|
+
bounds=scaled_prior_bounds,
|
|
241
|
+
q_values=scaled_q_values,
|
|
242
|
+
conditioning_params = scaled_conditioning_params,
|
|
243
|
+
)
|
|
174
244
|
|
|
175
245
|
predicted_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
|
|
176
246
|
|
|
@@ -184,19 +254,22 @@ class EasyInferenceModel(object):
|
|
|
184
254
|
}
|
|
185
255
|
|
|
186
256
|
if calc_pred_curve:
|
|
187
|
-
predicted_curve = predicted_params.reflectivity(q_values).squeeze().cpu().numpy()
|
|
257
|
+
predicted_curve = predicted_params.reflectivity(q=q_values, dq=q_resolution_tensor).squeeze().cpu().numpy()
|
|
188
258
|
prediction_dict[ "predicted_curve"] = predicted_curve
|
|
189
259
|
|
|
260
|
+
ambient_sld_tensor = torch.atleast_2d(torch.as_tensor(ambient_sld)).to(predicted_params.thicknesses.device) if ambient_sld is not None else None
|
|
190
261
|
if calc_pred_sld_profile:
|
|
191
262
|
predicted_sld_xaxis, predicted_sld_profile, _ = get_density_profiles(
|
|
192
|
-
predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, num=1024,
|
|
263
|
+
predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, ambient_sld_tensor, num=1024,
|
|
193
264
|
)
|
|
194
265
|
prediction_dict['predicted_sld_profile'] = predicted_sld_profile.squeeze().cpu().numpy()
|
|
195
266
|
prediction_dict['predicted_sld_xaxis'] = predicted_sld_xaxis.squeeze().cpu().numpy()
|
|
196
267
|
else:
|
|
197
268
|
predicted_sld_xaxis = None
|
|
198
269
|
|
|
199
|
-
if polish_prediction:
|
|
270
|
+
if polish_prediction:
|
|
271
|
+
if ambient_sld_tensor:
|
|
272
|
+
ambient_sld_tensor = ambient_sld_tensor.cpu()
|
|
200
273
|
polished_dict = self._polish_prediction(q = q_values.squeeze().cpu().numpy(),
|
|
201
274
|
curve = reflectivity_curve,
|
|
202
275
|
predicted_params = predicted_params,
|
|
@@ -204,19 +277,25 @@ class EasyInferenceModel(object):
|
|
|
204
277
|
fit_growth = fit_growth,
|
|
205
278
|
max_d_change = max_d_change,
|
|
206
279
|
calc_polished_curve = calc_pred_curve,
|
|
207
|
-
calc_polished_sld_profile =
|
|
280
|
+
calc_polished_sld_profile = calc_polished_sld_profile,
|
|
281
|
+
ambient_sld_tensor=ambient_sld_tensor,
|
|
208
282
|
sld_x_axis = predicted_sld_xaxis,
|
|
283
|
+
polishing_kwargs_reflectivity=polishing_kwargs_reflectivity,
|
|
209
284
|
)
|
|
210
285
|
prediction_dict.update(polished_dict)
|
|
211
286
|
|
|
212
287
|
if fit_growth and "polished_params_array" in prediction_dict:
|
|
213
288
|
prediction_dict["param_names"].append("max_d_change")
|
|
214
289
|
|
|
290
|
+
if ambient_sld: #Note: the SLD shift will only be reflected in predicted_params_array but not in predicted_params_object
|
|
291
|
+
prediction_dict["predicted_params_array"][sld_indices] += ambient_sld
|
|
292
|
+
if "polished_params_array" in prediction_dict:
|
|
293
|
+
prediction_dict["polished_params_array"][sld_indices] += ambient_sld
|
|
294
|
+
|
|
215
295
|
return prediction_dict
|
|
216
296
|
|
|
217
|
-
|
|
218
|
-
"""
|
|
219
|
-
The other arguments are the same as for the ``predict`` method.
|
|
297
|
+
def predict_using_widget(self, reflectivity_curve, **kwargs):
|
|
298
|
+
"""
|
|
220
299
|
"""
|
|
221
300
|
|
|
222
301
|
NUM_INTERVALS = self.trainer.loader.prior_sampler.param_dim
|
|
@@ -225,76 +304,74 @@ class EasyInferenceModel(object):
|
|
|
225
304
|
max_bounds = self.trainer.loader.prior_sampler.max_bounds.cpu().numpy().flatten()
|
|
226
305
|
max_deltas = self.trainer.loader.prior_sampler.max_delta.cpu().numpy().flatten()
|
|
227
306
|
|
|
228
|
-
print(f'
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
layout=widgets.Layout(width='400px'),
|
|
243
|
-
style={'description_width': '60px'}
|
|
244
|
-
)
|
|
245
|
-
|
|
246
|
-
def validate_range(change, slider=slider, max_width=max_deltas[i]):
|
|
247
|
-
min_val, max_val = change['new']
|
|
248
|
-
if max_val - min_val > max_width:
|
|
249
|
-
if change['name'] == 'value':
|
|
250
|
-
if change['old'][0] != min_val:
|
|
251
|
-
max_val = min_val + max_width
|
|
252
|
-
else:
|
|
253
|
-
min_val = max_val - max_width
|
|
254
|
-
slider.value = [min_val, max_val]
|
|
255
|
-
|
|
256
|
-
slider.observe(validate_range, names='value')
|
|
257
|
-
|
|
258
|
-
interval_row = widgets.HBox([interval_label, slider])
|
|
259
|
-
intervals.append((slider, interval_row))
|
|
260
|
-
return intervals
|
|
261
|
-
|
|
262
|
-
interval_widgets = create_interval_widgets(NUM_INTERVALS)
|
|
263
|
-
interval_box = widgets.VBox([widget[1] for widget in interval_widgets])
|
|
264
|
-
display(interval_box)
|
|
265
|
-
|
|
266
|
-
button = widgets.Button(description="Make prediction")
|
|
267
|
-
display(button)
|
|
268
|
-
|
|
269
|
-
prediction_result = None
|
|
270
|
-
|
|
271
|
-
def store_values(b, future):
|
|
272
|
-
print("Debug: Button clicked")
|
|
273
|
-
values = []
|
|
274
|
-
for slider, _ in interval_widgets:
|
|
275
|
-
values.append((slider.value[0], slider.value[1]))
|
|
276
|
-
array_values = np.array(values)
|
|
277
|
-
|
|
278
|
-
nonlocal prediction_result
|
|
279
|
-
prediction_result = self.predict(reflectivity_curve=reflectivity_curve, prior_bounds=array_values, **kwargs)
|
|
280
|
-
print(prediction_result["predicted_params_array"])
|
|
281
|
-
|
|
282
|
-
print("Prediction completed. Closing widget.")
|
|
283
|
-
|
|
284
|
-
for child in interval_box.children:
|
|
285
|
-
child.close()
|
|
286
|
-
button.close()
|
|
287
|
-
|
|
288
|
-
future.set_result(prediction_result)
|
|
289
|
-
|
|
290
|
-
button.on_click(store_values)
|
|
291
|
-
|
|
307
|
+
print(f'Adjust the sliders for each parameter and press "Predict". Repeat as desired. Press "Close Widget" to finish.')
|
|
308
|
+
|
|
309
|
+
interval_widgets = []
|
|
310
|
+
for i in range(NUM_INTERVALS):
|
|
311
|
+
label = widgets.Label(value=f'{param_labels[i]}')
|
|
312
|
+
initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
|
|
313
|
+
slider = widgets.FloatRangeSlider(
|
|
314
|
+
value=[min_bounds[i], initial_max],
|
|
315
|
+
min=min_bounds[i],
|
|
316
|
+
max=max_bounds[i],
|
|
317
|
+
step=0.01,
|
|
318
|
+
layout=widgets.Layout(width='400px'),
|
|
319
|
+
style={'description_width': '60px'}
|
|
320
|
+
)
|
|
292
321
|
|
|
293
|
-
|
|
322
|
+
def validate_range(change, slider=slider, max_width=max_deltas[i]):
|
|
323
|
+
min_val, max_val = change['new']
|
|
324
|
+
if max_val - min_val > max_width:
|
|
325
|
+
old_min_val, old_max_val = change['old']
|
|
326
|
+
if abs(old_min_val - min_val) > abs(old_max_val - max_val):
|
|
327
|
+
max_val = min_val + max_width
|
|
328
|
+
else:
|
|
329
|
+
min_val = max_val - max_width
|
|
330
|
+
slider.value = [min_val, max_val]
|
|
331
|
+
|
|
332
|
+
slider.observe(validate_range, names='value')
|
|
333
|
+
interval_widgets.append((slider, widgets.HBox([label, slider])))
|
|
334
|
+
|
|
335
|
+
sliders_box = widgets.VBox([iw[1] for iw in interval_widgets])
|
|
336
|
+
|
|
337
|
+
output = widgets.Output()
|
|
338
|
+
predict_button = widgets.Button(description="Predict")
|
|
339
|
+
close_button = widgets.Button(description="Close Widget")
|
|
340
|
+
|
|
341
|
+
container = widgets.VBox([sliders_box, widgets.HBox([predict_button, close_button]), output])
|
|
342
|
+
display(container)
|
|
343
|
+
|
|
344
|
+
@output.capture(clear_output=True)
|
|
345
|
+
def on_predict_click(_):
|
|
346
|
+
if 'prior_bounds' in kwargs:
|
|
347
|
+
array_values = kwargs.pop('prior_bounds')
|
|
348
|
+
for i, (s, _) in enumerate(interval_widgets):
|
|
349
|
+
s.value = tuple(array_values[i])
|
|
350
|
+
else:
|
|
351
|
+
values = [(s.value[0], s.value[1]) for s, _ in interval_widgets]
|
|
352
|
+
array_values = np.array(values)
|
|
353
|
+
|
|
354
|
+
prediction_result = self.predict(reflectivity_curve=reflectivity_curve,
|
|
355
|
+
prior_bounds=array_values,
|
|
356
|
+
**kwargs)
|
|
357
|
+
param_names = self.trainer.loader.prior_sampler.param_model.get_param_labels()
|
|
358
|
+
for param_name, pred_param_val in zip(param_names, prediction_result["predicted_params_array"]):
|
|
359
|
+
print(f'{param_name.ljust(14)} : {pred_param_val:.2f}')
|
|
360
|
+
|
|
361
|
+
plot_prediction_results(
|
|
362
|
+
prediction_result,
|
|
363
|
+
q_exp=kwargs['q_values'],
|
|
364
|
+
curve_exp=reflectivity_curve,
|
|
365
|
+
q_model=kwargs['q_values'],
|
|
366
|
+
)
|
|
367
|
+
self.prediction_result = prediction_result
|
|
294
368
|
|
|
295
|
-
|
|
369
|
+
def on_close_click(_):
|
|
370
|
+
container.close()
|
|
371
|
+
print("Widget closed.")
|
|
296
372
|
|
|
297
|
-
|
|
373
|
+
predict_button.on_click(on_predict_click)
|
|
374
|
+
close_button.on_click(on_close_click)
|
|
298
375
|
|
|
299
376
|
|
|
300
377
|
def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
|
|
@@ -328,18 +405,17 @@ class EasyInferenceModel(object):
|
|
|
328
405
|
predicted_params: BasicParams,
|
|
329
406
|
priors: np.ndarray,
|
|
330
407
|
sld_x_axis,
|
|
408
|
+
ambient_sld_tensor: Tensor = None,
|
|
331
409
|
fit_growth: bool = False,
|
|
332
410
|
max_d_change: float = 5.,
|
|
333
411
|
calc_polished_curve: bool = True,
|
|
334
412
|
calc_polished_sld_profile: bool = False,
|
|
413
|
+
polishing_kwargs_reflectivity: dict = None,
|
|
335
414
|
) -> dict:
|
|
336
|
-
params =
|
|
337
|
-
predicted_params.thicknesses.squeeze(),
|
|
338
|
-
predicted_params.roughnesses.squeeze(),
|
|
339
|
-
predicted_params.slds.squeeze()
|
|
340
|
-
]).cpu().numpy()
|
|
415
|
+
params = predicted_params.parameters.squeeze().cpu().numpy()
|
|
341
416
|
|
|
342
417
|
polished_params_dict = {}
|
|
418
|
+
polishing_kwargs_reflectivity = polishing_kwargs_reflectivity or {}
|
|
343
419
|
|
|
344
420
|
try:
|
|
345
421
|
if fit_growth:
|
|
@@ -354,20 +430,25 @@ class EasyInferenceModel(object):
|
|
|
354
430
|
torch.from_numpy(polished_params_arr[:-1][None]),
|
|
355
431
|
torch.from_numpy(priors.T[0][None]),
|
|
356
432
|
torch.from_numpy(priors.T[1][None]),
|
|
357
|
-
|
|
433
|
+
self.trainer.loader.prior_sampler.max_num_layers,
|
|
434
|
+
self.trainer.loader.prior_sampler.param_model
|
|
358
435
|
)
|
|
359
436
|
else:
|
|
360
|
-
polished_params_arr, curve_polished =
|
|
437
|
+
polished_params_arr, curve_polished = refl_fit(
|
|
361
438
|
q = q,
|
|
362
439
|
curve = curve,
|
|
363
440
|
init_params = params,
|
|
364
|
-
bounds=priors.T
|
|
441
|
+
bounds=priors.T,
|
|
442
|
+
prior_sampler=self.trainer.loader.prior_sampler,
|
|
443
|
+
reflectivity_kwargs=polishing_kwargs_reflectivity,
|
|
444
|
+
)
|
|
365
445
|
polished_params = BasicParams(
|
|
366
446
|
torch.from_numpy(polished_params_arr[None]),
|
|
367
447
|
torch.from_numpy(priors.T[0][None]),
|
|
368
448
|
torch.from_numpy(priors.T[1][None]),
|
|
369
|
-
|
|
370
|
-
|
|
449
|
+
self.trainer.loader.prior_sampler.max_num_layers,
|
|
450
|
+
self.trainer.loader.prior_sampler.param_model
|
|
451
|
+
)
|
|
371
452
|
except Exception as err:
|
|
372
453
|
polished_params = predicted_params
|
|
373
454
|
polished_params_arr = get_prediction_array(polished_params)
|
|
@@ -379,9 +460,9 @@ class EasyInferenceModel(object):
|
|
|
379
460
|
|
|
380
461
|
if calc_polished_sld_profile:
|
|
381
462
|
_, sld_profile_polished, _ = get_density_profiles(
|
|
382
|
-
polished_params.thicknesses, polished_params.roughnesses, polished_params.slds, z_axis=sld_x_axis,
|
|
463
|
+
polished_params.thicknesses, polished_params.roughnesses, polished_params.slds, ambient_sld_tensor, z_axis=sld_x_axis.cpu(),
|
|
383
464
|
)
|
|
384
|
-
polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().
|
|
465
|
+
polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().numpy()
|
|
385
466
|
|
|
386
467
|
return polished_params_dict
|
|
387
468
|
|
|
@@ -393,16 +474,50 @@ class EasyInferenceModel(object):
|
|
|
393
474
|
return scaled_curve
|
|
394
475
|
|
|
395
476
|
def _scale_prior_bounds(self, prior_bounds: List[Tuple]):
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
477
|
+
try:
|
|
478
|
+
prior_bounds = torch.tensor(prior_bounds)
|
|
479
|
+
prior_bounds = prior_bounds.to(self.device).T
|
|
480
|
+
min_bounds, max_bounds = prior_bounds[:, None]
|
|
481
|
+
|
|
482
|
+
scaled_bounds = torch.cat([
|
|
483
|
+
self.trainer.loader.prior_sampler.scale_bounds(min_bounds),
|
|
484
|
+
self.trainer.loader.prior_sampler.scale_bounds(max_bounds)
|
|
485
|
+
], -1)
|
|
486
|
+
|
|
487
|
+
return scaled_bounds.float()
|
|
488
|
+
|
|
489
|
+
except RuntimeError as e:
|
|
490
|
+
expected_param_dim = self.trainer.loader.prior_sampler.param_dim
|
|
491
|
+
actual_param_dim = prior_bounds.shape[1] if prior_bounds.ndim == 2 else len(prior_bounds)
|
|
492
|
+
|
|
493
|
+
msg = (
|
|
494
|
+
f"\n **Parameter dimension mismatch during inference!**\n"
|
|
495
|
+
f"- Model expects **{expected_param_dim}** parameters.\n"
|
|
496
|
+
f"- You provided **{actual_param_dim}** prior bounds.\n\n"
|
|
497
|
+
f"💡This often occurs when:\n"
|
|
498
|
+
f"- The model was trained with additional nuisance parameters like `r_scale`, `q_shift`, or `log10_background`,\n"
|
|
499
|
+
f" but they were not included in the `prior_bounds` passed to `.predict()`.\n"
|
|
500
|
+
f"- The number of layers or parameterization type differs from the one used during training.\n\n"
|
|
501
|
+
f" Check the configuration or the summary of expected parameters."
|
|
502
|
+
)
|
|
503
|
+
raise ValueError(msg) from e
|
|
504
|
+
|
|
505
|
+
def interpolate_data_to_model_q(self, q_exp, curve_exp):
    """Interpolate an experimental curve onto a q grid chosen from the model's q generator.

    Grid selection:
      * ``ConstantQ`` generator: the generator's own ``q`` values are used.
      * ``VariableQ`` generator whose ``n_q_range`` has equal endpoints: a
        uniform grid with that many points, spanning the experimental q range
        clipped to ``[q_min_range[0], q_max_range[1]]``.
      * anything else: the experimental grid ``q_exp`` is kept.

    Args:
        q_exp: experimental q values (must support ``.min()`` / ``.max()``).
        curve_exp: experimental reflectivity values measured at ``q_exp``.

    Returns:
        tuple: ``(q_model, exp_curve_interp)`` - the selected grid and the
        result of ``interp_reflectivity(q_model, q_exp, curve_exp)``.
    """
    q_generator = self.trainer.loader.q_generator

    if isinstance(q_generator, ConstantQ):
        q_model = q_generator.q.cpu().numpy()
    elif isinstance(q_generator, VariableQ) and q_generator.n_q_range[0] == q_generator.n_q_range[1]:
        n_q_model = q_generator.n_q_range[0]
        q_model_min = max(q_exp.min(), q_generator.q_min_range[0])
        q_model_max = min(q_exp.max(), q_generator.q_max_range[1])
        q_model = np.linspace(q_model_min, q_model_max, n_q_model)
    else:
        # Fallback keeps `q_model` bound for every generator configuration.
        # NOTE(review): the original nesting of this branch could not be
        # recovered from the view; confirm the fallback is wanted for
        # generator types other than ConstantQ / VariableQ.
        q_model = q_exp

    # The interpolation is applied on every path; a prior assignment of
    # `curve_exp` to the result was always overwritten here and was dropped.
    exp_curve_interp = interp_reflectivity(q_model, q_exp, curve_exp)

    return q_model, exp_curve_interp
|
|
406
521
|
|
|
407
522
|
def _get_likelihood(self, q, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
|
|
408
523
|
return LogLikelihood(
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from matplotlib import pyplot as plt
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def plot_prediction_results(
    prediction_dict: dict,
    q_exp: np.ndarray = None,
    curve_exp: np.ndarray = None,
    sigmas_exp: np.ndarray = None,
    q_model: np.ndarray = None,
):
    """
    Plot the experimental curve (with optional error bars), the predicted
    and polished curves, and also the predicted/polished SLD profiles.

    The figure has two panels: reflectivity curves on a logarithmic y axis
    (left) and SLD profiles (right). Every item is optional and is drawn only
    when the data it needs is available; a legend is added to a panel only if
    something was actually drawn in it. The figure is shown with ``plt.show()``
    and nothing is returned.

    Args:
        prediction_dict (dict): Dictionary containing 'predicted_curve',
            'predicted_sld_profile', 'predicted_sld_xaxis',
            and optionally 'polished_curve', 'sld_profile_polished'.
            The polished SLD profile is drawn over 'predicted_sld_xaxis'.
        q_exp (ndarray, optional): Experimental q-values.
        curve_exp (ndarray, optional): Experimental reflectivity curve.
        sigmas_exp (ndarray, optional): Error bars of the experimental reflectivity.
        q_model (ndarray, optional): The q-values on which prediction_dict's reflectivity
            was computed (e.g. from EasyInferenceModel.interpolate_data_to_model_q).
            Without it neither the predicted nor the polished curve is drawn.

    Example usage:
        prediction_dict = model.predict(...)
        plot_prediction_results(
            prediction_dict,
            q_exp=q_exp,
            curve_exp=curve_exp,
            sigmas_exp=sigmas_exp,
            q_model=q_model
        )
    """

    fig, (ax_refl, ax_sld) = plt.subplots(1, 2, figsize=(12, 6))

    # --- Left plot: Reflectivity curves ---
    ax_refl.set_yscale('log')
    ax_refl.set_xlabel('q [$Å^{-1}$]', fontsize=20)
    ax_refl.set_ylabel('R(q)', fontsize=20)
    ax_refl.tick_params(axis='both', which='major', labelsize=15)
    ax_refl.tick_params(axis='both', which='minor', labelsize=15)

    # Major y ticks at every second decade: 1, 1e-2, ..., 1e-10
    y_tick_locations = [10 ** (-2 * i) for i in range(6)]
    ax_refl.yaxis.set_major_locator(plt.FixedLocator(y_tick_locations))

    # Experimental data with error bars (if provided). `ecolor` already colours
    # the error bar lines, so no per-child recolouring is needed afterwards.
    if q_exp is not None and curve_exp is not None:
        ax_refl.errorbar(
            q_exp, curve_exp, yerr=sigmas_exp,
            xerr=None, c='b', ecolor='purple', elinewidth=1,
            marker='o', linestyle='none', markersize=3,
            label='exp. curve', zorder=1
        )

    # Predicted and polished curves both live on the model q grid
    if q_model is not None:
        if 'predicted_curve' in prediction_dict:
            ax_refl.plot(q_model, prediction_dict['predicted_curve'], c='red', lw=2, label='pred. curve')
        if 'polished_curve' in prediction_dict:
            ax_refl.plot(q_model, prediction_dict['polished_curve'], c='orange', ls='--', lw=2, label='polished pred. curve')

    # --- Right plot: SLD profiles ---
    ax_sld.set_xlabel('z [$Å$]', fontsize=20)
    ax_sld.set_ylabel('SLD [$10^{-6} Å^{-2}$]', fontsize=20)
    ax_sld.tick_params(axis='both', which='major', labelsize=15)
    ax_sld.tick_params(axis='both', which='minor', labelsize=15)

    # Both profiles share the x axis stored under 'predicted_sld_xaxis'
    if 'predicted_sld_xaxis' in prediction_dict:
        sld_xaxis = prediction_dict['predicted_sld_xaxis']
        if 'predicted_sld_profile' in prediction_dict:
            ax_sld.plot(sld_xaxis, prediction_dict['predicted_sld_profile'], c='red', label='pred. sld')
        if 'sld_profile_polished' in prediction_dict:
            ax_sld.plot(sld_xaxis, prediction_dict['sld_profile_polished'], c='orange', ls='--', label='polished sld')

    # Calling legend() on an axis without labelled artists makes matplotlib
    # emit a warning, so only add a legend where something was plotted.
    for axis in (ax_refl, ax_sld):
        handles, _ = axis.get_legend_handles_labels()
        if handles:
            axis.legend(fontsize=12)

    plt.tight_layout()
    plt.show()
|