reflectorch 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
- reflectorch/data_generation/__init__.py +4 -0
- reflectorch/data_generation/dataset.py +27 -7
- reflectorch/data_generation/noise.py +115 -9
- reflectorch/data_generation/priors/parametric_models.py +91 -16
- reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
- reflectorch/data_generation/priors/sampler_strategies.py +67 -3
- reflectorch/data_generation/q_generator.py +97 -43
- reflectorch/data_generation/reflectivity/__init__.py +53 -11
- reflectorch/data_generation/reflectivity/kinematical.py +4 -5
- reflectorch/data_generation/reflectivity/smearing.py +25 -10
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/smearing.py +42 -11
- reflectorch/data_generation/utils.py +93 -18
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/inference_model.py +795 -159
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/plotting.py +517 -0
- reflectorch/inference/preprocess_exp/interpolation.py +5 -2
- reflectorch/inference/scipy_fitter.py +98 -7
- reflectorch/ml/__init__.py +2 -0
- reflectorch/ml/basic_trainer.py +18 -6
- reflectorch/ml/callbacks.py +5 -4
- reflectorch/ml/loggers.py +25 -0
- reflectorch/ml/schedulers.py +116 -0
- reflectorch/ml/trainers.py +131 -23
- reflectorch/models/__init__.py +2 -1
- reflectorch/models/encoders/__init__.py +2 -2
- reflectorch/models/encoders/conv_encoder.py +54 -40
- reflectorch/models/encoders/fno.py +23 -16
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +2 -0
- reflectorch/models/networks/mlp_networks.py +331 -153
- reflectorch/models/networks/residual_net.py +31 -5
- reflectorch/runs/train.py +0 -1
- reflectorch/runs/utils.py +48 -11
- reflectorch/utils.py +30 -0
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/METADATA +20 -17
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/RECORD +41 -36
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info/licenses}/LICENSE.txt +0 -0
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/top_level.txt +0 -0
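The bulk of the diff below comes from `reflectorch/inference/inference_model.py` (+795 −159): `EasyInferenceModel` gains a `preprocess_and_predict` entry point that filters and truncates the experimental curve, interpolates it onto the model's q-grid (`ConstantQ`, `VariableQ`, or `MaskedVariableQ`), runs the network, and can optionally polish the result with the new `refl_fit`. The usage sketch here is illustrative only: the constructor keywords (`config_name`, `device`), the top-level import path, and the example prior-bounds layout are assumptions inferred from the signatures visible in the diff, not a verified API.

import numpy as np
from reflectorch import EasyInferenceModel  # import path assumed; the class lives in reflectorch.inference.inference_model

# Experimental q values, reflectivity, and error bars (placeholder data for illustration)
q_exp = np.linspace(0.02, 0.15, 200)
refl_exp = np.random.rand(200)      # replace with the measured reflectivity curve
sigmas_exp = 0.1 * refl_exp         # replace with the measured error bars

# One (min, max) prior interval per predicted parameter, in the order given by
# model.get_param_labels(); a hypothetical single-layer example is shown here.
prior_bounds = [(10., 300.), (0., 20.), (0., 6.)]

model = EasyInferenceModel(config_name="...", device="cuda")  # keyword names assumed, config name is a placeholder
result = model.preprocess_and_predict(
    reflectivity_curve=refl_exp,
    q_values=q_exp,
    sigmas=sigmas_exp,
    prior_bounds=prior_bounds,
    polish_prediction=True,        # refine the network prediction with a least-squares fit
    calc_pred_sld_profile=True,    # also return the predicted SLD profile
)
print(dict(zip(result["param_names"], result["predicted_params_array"])))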
@@ -12,8 +12,10 @@ from IPython.display import display
 from huggingface_hub import hf_hub_download
 
 from reflectorch.data_generation.priors import Params, BasicParams, ExpUniformSubPriorSampler, UniformSubPriorParams
-from reflectorch.data_generation.
+from reflectorch.data_generation.priors.parametric_models import NuisanceParamsWrapper
+from reflectorch.data_generation.q_generator import ConstantQ, VariableQ, MaskedVariableQ
 from reflectorch.data_generation.utils import get_density_profiles, get_param_labels
+from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
 from reflectorch.paths import CONFIG_DIR, ROOT_DIR, SAVED_MODELS_DIR
 from reflectorch.runs.utils import (
     get_trainer_by_name, train_from_config
@@ -23,10 +25,11 @@ from reflectorch.ml.trainers import PointEstimatorTrainer
 from reflectorch.data_generation.likelihoods import LogLikelihood
 
 from reflectorch.inference.preprocess_exp import StandardPreprocessing
-from reflectorch.inference.scipy_fitter import standard_refl_fit, get_fit_with_growth
+from reflectorch.inference.scipy_fitter import standard_refl_fit, refl_fit, get_fit_with_growth
 from reflectorch.inference.sampler_solution import simple_sampler_solution, get_best_mse_param
 from reflectorch.inference.record_time import print_time
-from reflectorch.
+from reflectorch.inference.plotting import plot_reflectivity, plot_prediction_results, print_prediction_results
+from reflectorch.utils import get_filtering_mask, to_t
 
 class EasyInferenceModel(object):
     """Facilitates the inference process using pretrained models
@@ -103,7 +106,9 @@ class EasyInferenceModel(object):
         self.trainer = get_trainer_by_name(config_name=config_name, config_dir=self.config_dir, model_path=model_path, load_weights=True, inference_device = self.device)
         self.trainer.model.eval()
 
-
+        param_model = self.trainer.loader.prior_sampler.param_model
+        param_model_name = param_model.base_model.NAME if isinstance(param_model, NuisanceParamsWrapper) else param_model.NAME
+        print(f'The model corresponds to a `{param_model_name}` parameterization with {self.trainer.loader.prior_sampler.max_num_layers} layers ({self.trainer.loader.prior_sampler.param_dim} predicted parameters)')
         print("Parameter types and total ranges:")
         for param, range_ in self.trainer.loader.prior_sampler.param_ranges.items():
             print(f"- {param}: {range_}")
@@ -115,23 +120,188 @@ class EasyInferenceModel(object):
             q_min = self.trainer.loader.q_generator.q[0].item()
             q_max = self.trainer.loader.q_generator.q[-1].item()
             n_q = self.trainer.loader.q_generator.q.shape[0]
-            print(f'The model was trained on curves discretized at {n_q} uniform points between
+            print(f'The model was trained on curves discretized at {n_q} uniform points between q_min={q_min} and q_max={q_max}')
         elif isinstance(self.trainer.loader.q_generator, VariableQ):
             q_min_range = self.trainer.loader.q_generator.q_min_range
             q_max_range = self.trainer.loader.q_generator.q_max_range
             n_q_range = self.trainer.loader.q_generator.n_q_range
-
+            if n_q_range[0] == n_q_range[1]:
+                n_q_fixed = n_q_range[0]
+                print(f'The model was trained on curves discretized at exactly {n_q_fixed} uniform points, '
+                      f'between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
+            else:
+                print(f'The model was trained on curves discretized at a number between {n_q_range[0]} and {n_q_range[1]} '
+                      f'of uniform points between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')
+
+        if self.trainer.loader.smearing is not None:
+            q_res_min = self.trainer.loader.smearing.sigma_min
+            q_res_max = self.trainer.loader.smearing.sigma_max
+            if self.trainer.loader.smearing.constant_dq == False:
+                print(f"The model was trained with linear resolution smearing (dq/q) in the range [{q_res_min}, {q_res_max}]")
+            elif self.trainer.loader.smearing.constant_dq == True:
+                print(f"The model was trained with constant resolution smearing in the range [{q_res_min}, {q_res_max}]")
+
+        additional_inputs = ["prior bounds"]
+        if self.trainer.train_with_q_input:
+            additional_inputs.append("q values")
+        if self.trainer.condition_on_q_resolutions:
+            additional_inputs.append("the resolution dq/q")
+        if additional_inputs:
+            inputs_str = ", ".join(additional_inputs)
+            print(f"The following quantities are additional inputs to the network: {inputs_str}.")
+
+    def preprocess_and_predict(self,
+                               reflectivity_curve: Union[np.ndarray, Tensor],
+                               q_values: Union[np.ndarray, Tensor] = None,
+                               prior_bounds: Union[np.ndarray, List[Tuple]] = None,
+                               sigmas: Union[np.ndarray, Tensor] = None,
+                               q_resolution: float = None,
+                               ambient_sld: float = None,
+                               clip_prediction: bool = True,
+                               polish_prediction: bool = False,
+                               polishing_method: str = 'trf',
+                               polishing_kwargs_reflectivity: dict = None,
+                               use_sigmas_for_polishing: bool = False,
+                               polishing_max_nfev: int = None,
+                               fit_growth: bool = False,
+                               max_d_change: float = 5.,
+                               calc_pred_curve: bool = True,
+                               calc_pred_sld_profile: bool = False,
+                               calc_polished_sld_profile: bool = False,
+                               sld_profile_padding_left: float = 0.2,
+                               sld_profile_padding_right: float = 1.1,
+                               kwargs_param_labels: dict = {},
+
+                               truncate_index_left: int = None,
+                               truncate_index_right: int = None,
+                               enable_error_bars_filtering: bool = True,
+                               filter_threshold=0.3,
+                               filter_remove_singles=True,
+                               filter_remove_consecutives=True,
+                               filter_consecutive=3,
+                               filter_q_start_trunc=0.1,
+                               ):
+
+        ## Preprocess the data for inference (remove negative intensities, truncation, filer out points with high error bars)
+        (q_values, reflectivity_curve, sigmas, q_resolution,
+         q_values_original, reflectivity_curve_original, sigmas_original, q_resolution_original) = self._preprocess_input_data(
+            reflectivity_curve=reflectivity_curve,
+            q_values=q_values,
+            sigmas=sigmas,
+            q_resolution=q_resolution,
+            truncate_index_left=truncate_index_left,
+            truncate_index_right=truncate_index_right,
+            enable_error_bars_filtering=enable_error_bars_filtering,
+            filter_threshold=filter_threshold,
+            filter_remove_singles=filter_remove_singles,
+            filter_remove_consecutives=filter_remove_consecutives,
+            filter_consecutive=filter_consecutive,
+            filter_q_start_trunc=filter_q_start_trunc,
+        )
+
+        ### Interpolate the experimental data if needed by the embedding network
+        interp_data = self.interpolate_data_to_model_q(
+            q_exp=q_values,
+            refl_exp=reflectivity_curve,
+            sigmas_exp=sigmas,
+            q_res_exp=q_resolution,
+            as_dict=True
+        )
+
+        q_model = interp_data["q_model"]
+        reflectivity_curve_interp = interp_data["reflectivity"]
+        sigmas_interp = interp_data.get("sigmas")
+        q_resolution_interp = interp_data.get("q_resolution")
+        key_padding_mask = interp_data.get("key_padding_mask")
+
+        ### Make the prediction
+        prediction_dict = self.predict(
+            reflectivity_curve=reflectivity_curve_interp,
+            q_values=q_model,
+            sigmas=sigmas_interp,
+            q_resolution=q_resolution_interp,
+            key_padding_mask=key_padding_mask,
+            prior_bounds=prior_bounds,
+            ambient_sld=ambient_sld,
+            clip_prediction=clip_prediction,
+            polish_prediction=False, ###do the polishing outside the predict method on the full data
+            supress_sld_amb_back_shift=True, ###do not shift back the slds by the ambient yet
+            calc_pred_curve=calc_pred_curve,
+            calc_pred_sld_profile=calc_pred_sld_profile,
+            sld_profile_padding_left=sld_profile_padding_left,
+            sld_profile_padding_right=sld_profile_padding_right,
+            kwargs_param_labels=kwargs_param_labels,
+        )
+
+        ### Save interpolated data
+        prediction_dict['q_model'] = q_model
+        prediction_dict['reflectivity_curve_interp'] = reflectivity_curve_interp
+        if q_resolution_interp is not None:
+            prediction_dict['q_resolution_interp'] = q_resolution_interp
+        if sigmas_interp is not None:
+            prediction_dict['sigmas_interp'] = sigmas_interp
+        if key_padding_mask is not None:
+            prediction_dict['key_padding_mask'] = key_padding_mask
+
+        ### Perform polishing on the original data
+        if polish_prediction:
+            polishing_kwargs = polishing_kwargs_reflectivity or {}
+            polishing_kwargs.setdefault('dq', q_resolution_original)
+
+            prior_bounds = np.array(prior_bounds)
+            if ambient_sld:
+                sld_indices = self._shift_slds_by_ambient(prior_bounds, ambient_sld)
+
+            polished_dict = self._polish_prediction(
+                q=q_values_original,
+                curve=reflectivity_curve_original,
+                predicted_params=prediction_dict['predicted_params_object'],
+                priors=prior_bounds,
+                ambient_sld_tensor=torch.atleast_2d(torch.as_tensor(ambient_sld)).to(self.device) if ambient_sld is not None else None,
+                calc_polished_sld_profile=calc_polished_sld_profile,
+                sld_x_axis=torch.from_numpy(prediction_dict['predicted_sld_xaxis']),
+                polishing_kwargs_reflectivity = polishing_kwargs,
+                error_bars=sigmas_original if use_sigmas_for_polishing else None,
+                polishing_method=polishing_method,
+                polishing_max_nfev=polishing_max_nfev,
+                fit_growth=fit_growth,
+                max_d_change=max_d_change,
+            )
+
+            prediction_dict.update(polished_dict)
+            if fit_growth and "polished_params_array" in prediction_dict:
+                prediction_dict["param_names"].append("max_d_change")
+
+        ### Shift back the slds for nonzero ambient
+        if ambient_sld:
+            self._restore_slds_after_ambient_shift(prediction_dict, sld_indices, ambient_sld)
 
-
+        return prediction_dict
+
+
+    def predict(self,
+                reflectivity_curve: Union[np.ndarray, Tensor],
                 q_values: Union[np.ndarray, Tensor] = None,
                 prior_bounds: Union[np.ndarray, List[Tuple]] = None,
-
-
+                sigmas: Union[np.ndarray, Tensor] = None,
+                key_padding_mask: Union[np.ndarray, Tensor] = None,
+                q_resolution: float = None,
+                ambient_sld: float = None,
+                clip_prediction: bool = True,
+                polish_prediction: bool = False,
+                polishing_method: str = 'trf',
+                polishing_kwargs_reflectivity: dict = None,
+                polishing_max_nfev: int = None,
                 fit_growth: bool = False,
                 max_d_change: float = 5.,
                 use_q_shift: bool = False,
                 calc_pred_curve: bool = True,
                 calc_pred_sld_profile: bool = False,
+                calc_polished_sld_profile: bool = False,
+                sld_profile_padding_left: float = 0.2,
+                sld_profile_padding_right: float = 1.1,
+                supress_sld_amb_back_shift: bool = False,
+                kwargs_param_labels: dict = {},
                 ):
         """Predict the thin film parameters
 
@@ -153,24 +323,51 @@ class EasyInferenceModel(object):
 
         scaled_curve = self._scale_curve(reflectivity_curve)
         prior_bounds = np.array(prior_bounds)
+
+        if ambient_sld:
+            sld_indices = self._shift_slds_by_ambient(prior_bounds, ambient_sld)
+
         scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)
 
-        if
+        if isinstance(self.trainer.loader.q_generator, ConstantQ):
             q_values = self.trainer.loader.q_generator.q
         else:
             q_values = torch.atleast_2d(to_t(q_values)).to(scaled_curve)
 
+        scaled_q_values = self.trainer.loader.q_generator.scale_q(q_values).to(torch.float32) if self.trainer.train_with_q_input else None
+
+        if q_resolution is not None:
+            q_resolution_tensor = torch.atleast_2d(torch.as_tensor(q_resolution)).to(scaled_curve)
+            if isinstance(q_resolution, float):
+                unscaled_q_resolutions = q_resolution_tensor
+            else:
+                unscaled_q_resolutions = (q_resolution_tensor / q_values).nanmean(dim=-1, keepdim=True) ##when q_values is padded with 0s, there will be nan at the padded positions
+            scaled_q_resolutions = self.trainer.loader.smearing.scale_resolutions(unscaled_q_resolutions) if self.trainer.condition_on_q_resolutions else None
+            scaled_conditioning_params = scaled_q_resolutions
+            if polishing_kwargs_reflectivity is None:
+                polishing_kwargs_reflectivity = {'dq': q_resolution}
+        else:
+            q_resolution_tensor = None
+            scaled_conditioning_params = None
+
+        if key_padding_mask is not None:
+            key_padding_mask = torch.as_tensor(key_padding_mask, device=self.device)
+            key_padding_mask = key_padding_mask.unsqueeze(0) if key_padding_mask.dim() == 1 else key_padding_mask
+
         if use_q_shift and not self.trainer.train_with_q_input:
             predicted_params = self._qshift_prediction(reflectivity_curve, scaled_prior_bounds, num = 1024, dq_coef = 1.)
-
         else:
             with torch.no_grad():
                 self.trainer.model.eval()
-
-
-
-
-
+
+                scaled_predicted_params = self.trainer.model(
+                    curves=scaled_curve,
+                    bounds=scaled_prior_bounds,
+                    q_values=scaled_q_values,
+                    conditioning_params = scaled_conditioning_params,
+                    key_padding_mask = key_padding_mask,
+                    unscaled_q_values = q_values,
+                )
 
             predicted_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
 
@@ -180,166 +377,80 @@ class EasyInferenceModel(object):
         prediction_dict = {
             "predicted_params_object": predicted_params,
             "predicted_params_array": predicted_params.parameters.squeeze().cpu().numpy(),
-            "param_names" : self.trainer.loader.prior_sampler.param_model.get_param_labels()
+            "param_names" : self.trainer.loader.prior_sampler.param_model.get_param_labels(**kwargs_param_labels)
         }
+
+        key_padding_mask = None if key_padding_mask is None else key_padding_mask.squeeze().cpu().numpy()
 
         if calc_pred_curve:
-            predicted_curve = predicted_params.reflectivity(q_values).squeeze().cpu().numpy()
-            prediction_dict[ "predicted_curve"] = predicted_curve
+            predicted_curve = predicted_params.reflectivity(q=q_values, dq=q_resolution_tensor).squeeze().cpu().numpy()
+            prediction_dict[ "predicted_curve"] = predicted_curve if key_padding_mask is None else predicted_curve[key_padding_mask]
 
+        ambient_sld_tensor = torch.atleast_2d(torch.as_tensor(ambient_sld, device=self.device)) if ambient_sld is not None else None
         if calc_pred_sld_profile:
             predicted_sld_xaxis, predicted_sld_profile, _ = get_density_profiles(
-                predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds,
+                predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds + (ambient_sld_tensor or 0), ambient_sld_tensor,
+                num=1024, padding_left=sld_profile_padding_left, padding_right=sld_profile_padding_right,
             )
             prediction_dict['predicted_sld_profile'] = predicted_sld_profile.squeeze().cpu().numpy()
             prediction_dict['predicted_sld_xaxis'] = predicted_sld_xaxis.squeeze().cpu().numpy()
         else:
             predicted_sld_xaxis = None
+
+        refl_curve_polish = reflectivity_curve if key_padding_mask is None else reflectivity_curve[key_padding_mask]
+        q_polish = q_values.squeeze().cpu().numpy() if key_padding_mask is None else q_values.squeeze().cpu().numpy()[key_padding_mask]
+        prediction_dict['q_plot_pred'] = q_polish
 
-        if polish_prediction:
-
-
-
-
-
-
-
-
-
-
+        if polish_prediction:
+            if ambient_sld_tensor:
+                ambient_sld_tensor = ambient_sld_tensor.cpu()
+
+            polished_dict = self._polish_prediction(
+                q = q_polish,
+                curve = refl_curve_polish,
+                predicted_params = predicted_params,
+                priors = np.array(prior_bounds),
+                error_bars = sigmas,
+                fit_growth = fit_growth,
+                max_d_change = max_d_change,
+                calc_polished_curve = calc_pred_curve,
+                calc_polished_sld_profile = calc_polished_sld_profile,
+                ambient_sld_tensor=ambient_sld_tensor,
+                sld_x_axis = predicted_sld_xaxis,
+                polishing_method=polishing_method,
+                polishing_max_nfev=polishing_max_nfev,
+                polishing_kwargs_reflectivity=polishing_kwargs_reflectivity,
+            )
             prediction_dict.update(polished_dict)
 
             if fit_growth and "polished_params_array" in prediction_dict:
                 prediction_dict["param_names"].append("max_d_change")
 
-
-
-    async def predict_using_widget(self, reflectivity_curve: Union[np.ndarray, torch.Tensor], **kwargs):
-        """Use an interactive Python widget for specifying the prior bounds before the prediction (works only in a Jupyter notebook).
-        The other arguments are the same as for the ``predict`` method.
-        """
-
-        NUM_INTERVALS = self.trainer.loader.prior_sampler.param_dim
-        param_labels = self.trainer.loader.prior_sampler.param_model.get_param_labels()
-        min_bounds = self.trainer.loader.prior_sampler.min_bounds.cpu().numpy().flatten()
-        max_bounds = self.trainer.loader.prior_sampler.max_bounds.cpu().numpy().flatten()
-        max_deltas = self.trainer.loader.prior_sampler.max_delta.cpu().numpy().flatten()
-
-        print(f'Parameter ranges: {self.trainer.loader.prior_sampler.param_ranges}')
-        print(f'Allowed widths of the prior bound intervals (max-min): {self.trainer.loader.prior_sampler.bound_width_ranges}')
-        print(f'Please fill in the values of the minimum and maximum prior bound for each parameter and press the button!')
-
-        def create_interval_widgets(n):
-            intervals = []
-            for i in range(n):
-                interval_label = widgets.Label(value=f'{param_labels[i]}')
-                initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
-                slider = widgets.FloatRangeSlider(
-                    value=[min_bounds[i], initial_max],
-                    min=min_bounds[i],
-                    max=max_bounds[i],
-                    step=0.01,
-                    layout=widgets.Layout(width='400px'),
-                    style={'description_width': '60px'}
-                )
-
-                def validate_range(change, slider=slider, max_width=max_deltas[i]):
-                    min_val, max_val = change['new']
-                    if max_val - min_val > max_width:
-                        if change['name'] == 'value':
-                            if change['old'][0] != min_val:
-                                max_val = min_val + max_width
-                            else:
-                                min_val = max_val - max_width
-                        slider.value = [min_val, max_val]
-
-                slider.observe(validate_range, names='value')
-
-                interval_row = widgets.HBox([interval_label, slider])
-                intervals.append((slider, interval_row))
-            return intervals
-
-        interval_widgets = create_interval_widgets(NUM_INTERVALS)
-        interval_box = widgets.VBox([widget[1] for widget in interval_widgets])
-        display(interval_box)
-
-        button = widgets.Button(description="Make prediction")
-        display(button)
-
-        prediction_result = None
-
-        def store_values(b, future):
-            print("Debug: Button clicked")
-            values = []
-            for slider, _ in interval_widgets:
-                values.append((slider.value[0], slider.value[1]))
-            array_values = np.array(values)
-
-            nonlocal prediction_result
-            prediction_result = self.predict(reflectivity_curve=reflectivity_curve, prior_bounds=array_values, **kwargs)
-            print(prediction_result["predicted_params_array"])
-
-            print("Prediction completed. Closing widget.")
-
-            for child in interval_box.children:
-                child.close()
-            button.close()
-
-            future.set_result(prediction_result)
+        if ambient_sld and not supress_sld_amb_back_shift: #Note: the SLD shift will only be reflected in predicted_params_array but not in predicted_params_object; supress_sld_amb_back_shift is required for the 'preprocess_and_predict' method
+            self._restore_slds_after_ambient_shift(prediction_dict, sld_indices, ambient_sld)
 
-
+        return prediction_dict
 
-
-        future = asyncio.Future()
-
-        button.on_click(lambda b: store_values(b, future))
-
-        return await future
-
-
-    def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
-        assert isinstance(self.trainer.loader.q_generator, ConstantQ), "Prediction with q shifts available only for models with fixed discretization"
-        q = self.trainer.loader.q_generator.q.squeeze().float()
-        dq_max = (q[1] - q[0]) * dq_coef
-        q_shifts = torch.linspace(-dq_max, dq_max, num).to(q)
-
-        curve = to_t(curve).to(scaled_bounds)
-        shifted_curves = _qshift_interp(q.squeeze(), curve, q_shifts)
-
-        assert shifted_curves.shape == (num, q.shape[0])
-
-        scaled_curves = self.trainer.loader.curves_scaler.scale(shifted_curves)
-        scaled_prior_bounds = torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)
-
-        with torch.no_grad():
-            self.trainer.model.eval()
-            scaled_predicted_params = self.trainer.model(scaled_curves, scaled_prior_bounds)
-            restored_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
-
-        best_param = get_best_mse_param(
-            restored_params,
-            self._get_likelihood(q=self.trainer.loader.q_generator.q, curve=curve),
-        )
-        return best_param
-
     def _polish_prediction(self,
                            q: np.ndarray,
                            curve: np.ndarray,
                            predicted_params: BasicParams,
                            priors: np.ndarray,
                            sld_x_axis,
+                           ambient_sld_tensor: Tensor = None,
                            fit_growth: bool = False,
                            max_d_change: float = 5.,
                            calc_polished_curve: bool = True,
                            calc_polished_sld_profile: bool = False,
+                           error_bars: np.ndarray = None,
+                           polishing_method: str = 'trf',
+                           polishing_max_nfev: int = None,
+                           polishing_kwargs_reflectivity: dict = None,
                            ) -> dict:
-        params =
-            predicted_params.thicknesses.squeeze(),
-            predicted_params.roughnesses.squeeze(),
-            predicted_params.slds.squeeze()
-        ]).cpu().numpy()
+        params = predicted_params.parameters.squeeze().cpu().numpy()
 
         polished_params_dict = {}
+        polishing_kwargs_reflectivity = polishing_kwargs_reflectivity or {}
 
         try:
             if fit_growth:
@@ -358,18 +469,24 @@ class EasyInferenceModel(object):
                     self.trainer.loader.prior_sampler.param_model
                 )
             else:
-                polished_params_arr, curve_polished =
+                polished_params_arr, curve_polished = refl_fit(
                     q = q,
                     curve = curve,
                     init_params = params,
-                    bounds=priors.T
+                    bounds=priors.T,
+                    prior_sampler=self.trainer.loader.prior_sampler,
+                    error_bars=error_bars,
+                    method=polishing_method,
+                    polishing_max_nfev=polishing_max_nfev,
+                    reflectivity_kwargs=polishing_kwargs_reflectivity,
+                )
             polished_params = BasicParams(
                 torch.from_numpy(polished_params_arr[None]),
                 torch.from_numpy(priors.T[0][None]),
                 torch.from_numpy(priors.T[1][None]),
                 self.trainer.loader.prior_sampler.max_num_layers,
                 self.trainer.loader.prior_sampler.param_model
-
+            )
         except Exception as err:
             polished_params = predicted_params
             polished_params_arr = get_prediction_array(polished_params)
@@ -379,9 +496,14 @@ class EasyInferenceModel(object):
         if calc_polished_curve:
             polished_params_dict['polished_curve'] = curve_polished
 
+        if ambient_sld_tensor is not None:
+            ambient_sld_tensor = ambient_sld_tensor.to(polished_params.slds.device)
+
+
         if calc_polished_sld_profile:
             _, sld_profile_polished, _ = get_density_profiles(
-                polished_params.thicknesses, polished_params.roughnesses, polished_params.slds,
+                polished_params.thicknesses, polished_params.roughnesses, polished_params.slds + (ambient_sld_tensor or 0), ambient_sld_tensor,
+                z_axis=sld_x_axis.to(polished_params.slds.device),
             )
             polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().cpu().numpy()
 
@@ -390,26 +512,540 @@ class EasyInferenceModel(object):
     def _scale_curve(self, curve: Union[np.ndarray, Tensor]):
         if not isinstance(curve, Tensor):
             curve = torch.from_numpy(curve).float()
-        curve =
+        curve = curve.unsqueeze(0).to(self.device)
         scaled_curve = self.trainer.loader.curves_scaler.scale(curve)
         return scaled_curve
 
     def _scale_prior_bounds(self, prior_bounds: List[Tuple]):
-
-
-
-
-
-
-
-
-
-
+        try:
+            prior_bounds = torch.tensor(prior_bounds)
+            prior_bounds = prior_bounds.to(self.device).T
+            min_bounds, max_bounds = prior_bounds[:, None]
+
+            scaled_bounds = torch.cat([
+                self.trainer.loader.prior_sampler.scale_bounds(min_bounds),
+                self.trainer.loader.prior_sampler.scale_bounds(max_bounds)
+            ], -1)
+
+            return scaled_bounds.float()
+
+        except RuntimeError as e:
+            expected_param_dim = self.trainer.loader.prior_sampler.param_dim
+            actual_param_dim = prior_bounds.shape[1] if prior_bounds.ndim == 2 else len(prior_bounds)
+
+            msg = (
+                f"\n **Parameter dimension mismatch during inference!**\n"
+                f"- Model expects **{expected_param_dim}** parameters.\n"
+                f"- You provided **{actual_param_dim}** prior bounds.\n\n"
+                f"💡This often occurs when:\n"
+                f"- The model was trained with additional nuisance parameters like `r_scale`, `q_shift`, or `log10_background`,\n"
+                f"  but they were not included in the `prior_bounds` passed to `.predict()`.\n"
+                f"- The number of layers or parameterization type differs from the one used during training.\n\n"
+                f" Check the configuration or the summary of expected parameters."
+            )
+            raise ValueError(msg) from e
 
+    def _shift_slds_by_ambient(self, prior_bounds: np.ndarray, ambient_sld: float):
+        n_layers = self.trainer.loader.prior_sampler.max_num_layers
+        sld_indices = slice(2*n_layers+1, 3*n_layers+2)
+        prior_bounds[sld_indices, ...] -= ambient_sld
+
+        training_min_bounds = self.trainer.loader.prior_sampler.min_bounds.squeeze().cpu().numpy()
+        training_max_bounds = self.trainer.loader.prior_sampler.max_bounds.squeeze().cpu().numpy()
+        lower_bound_check = (prior_bounds[sld_indices, 0] >= training_min_bounds[sld_indices]).all()
+        upper_bound_check = (prior_bounds[sld_indices, 1] <= training_max_bounds[sld_indices]).all()
+        assert lower_bound_check and upper_bound_check, "Shifting the layer SLDs by the ambient SLD exceeded the training ranges."
+
+        return sld_indices
+
+    def _restore_slds_after_ambient_shift(self, prediction_dict, sld_indices, ambient_sld):
+        prediction_dict["predicted_params_array"][sld_indices] += ambient_sld
+        if "polished_params_array" in prediction_dict:
+            prediction_dict["polished_params_array"][sld_indices] += ambient_sld
+
     def _get_likelihood(self, q, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
         return LogLikelihood(
             q, curve, self.trainer.loader.prior_sampler, curve * rel_err + abs_err
         )
+
+    def get_param_labels(self, **kwargs):
+        return self.trainer.loader.prior_sampler.param_model.get_param_labels(**kwargs)
+
+    @staticmethod
+    def _preprocess_input_data(
+            reflectivity_curve,
+            q_values,
+            sigmas=None,
+            q_resolution=None,
+            truncate_index_left=None,
+            truncate_index_right=None,
+            enable_error_bars_filtering=True,
+            filter_threshold=0.3,
+            filter_remove_singles=True,
+            filter_remove_consecutives=True,
+            filter_consecutive=3,
+            filter_q_start_trunc=0.1):
+
+        # Save originals for polishing
+        reflectivity_curve_original = reflectivity_curve.copy()
+        q_values_original = q_values.copy() if q_values is not None else None
+        q_resolution_original = q_resolution.copy() if isinstance(q_resolution, np.ndarray) else q_resolution
+        sigmas_original = sigmas.copy() if sigmas is not None else None
+
+        # Remove points with non-positive intensities
+        nonnegative_mask = reflectivity_curve > 0.0
+        reflectivity_curve = reflectivity_curve[nonnegative_mask]
+        q_values = q_values[nonnegative_mask]
+        if sigmas is not None:
+            sigmas = sigmas[nonnegative_mask]
+        if isinstance(q_resolution, np.ndarray):
+            q_resolution = q_resolution[nonnegative_mask]
+
+        # Truncate arrays
+        if truncate_index_left is not None or truncate_index_right is not None:
+            slice_obj = slice(truncate_index_left, truncate_index_right)
+            reflectivity_curve = reflectivity_curve[slice_obj]
+            q_values = q_values[slice_obj]
+            if sigmas is not None:
+                sigmas = sigmas[slice_obj]
+            if isinstance(q_resolution, np.ndarray):
+                q_resolution = q_resolution[slice_obj]
+
+        # Filter high-error points
+        if enable_error_bars_filtering and sigmas is not None:
+            valid_mask = get_filtering_mask(
+                q_values,
+                reflectivity_curve,
+                sigmas,
+                threshold=filter_threshold,
+                consecutive=filter_consecutive,
+                remove_singles=filter_remove_singles,
+                remove_consecutives=filter_remove_consecutives,
+                q_start_trunc=filter_q_start_trunc
+            )
+            reflectivity_curve = reflectivity_curve[valid_mask]
+            q_values = q_values[valid_mask]
+            sigmas = sigmas[valid_mask]
+            if isinstance(q_resolution, np.ndarray):
+                q_resolution = q_resolution[valid_mask]
+
+        return (q_values, reflectivity_curve, sigmas, q_resolution,
+                q_values_original, reflectivity_curve_original,
+                sigmas_original, q_resolution_original)
+
+    def interpolate_data_to_model_q(
+            self,
+            q_exp,
+            refl_exp,
+            sigmas_exp=None,
+            q_res_exp=None,
+            as_dict=False
+    ):
+        q_generator = self.trainer.loader.q_generator
+
+        def _pad(arr, pad_to, value=0.0):
+            if arr is None:
+                return None
+            return np.pad(arr, (0, pad_to - len(arr)), constant_values=value)
+
+        def _interp_or_keep(q_model, q_exp, arr):
+            """Interpolate arrays, keep floats or None unchanged."""
+            if arr is None:
+                return None
+            return np.interp(q_model, q_exp, arr) if isinstance(arr, np.ndarray) else arr
+
+        def _pad_or_keep(arr, max_n):
+            """Pad arrays, keep floats or None unchanged."""
+            if arr is None:
+                return None
+            return _pad(arr, max_n, 0.0) if isinstance(arr, np.ndarray) else arr
+
+        def _prepare_return(q, refl, sigmas=None, q_res=None, mask=None, as_dict=False):
+            if as_dict:
+                result = {"q_model": q, "reflectivity": refl}
+                if sigmas is not None: result["sigmas"] = sigmas
+                if q_res is not None: result["q_resolution"] = q_res
+                if mask is not None: result["key_padding_mask"] = mask
+                return result
+            result = [q, refl]
+            if sigmas is not None: result.append(sigmas)
+            if q_res is not None: result.append(q_res)
+            if mask is not None: result.append(mask)
+            return tuple(result)
+
+        # ConstantQ
+        if isinstance(q_generator, ConstantQ):
+            q_model = q_generator.q.cpu().numpy()
+            refl_out = interp_reflectivity(q_model, q_exp, refl_exp)
+            sigmas_out = _interp_or_keep(q_model, q_exp, sigmas_exp)
+            q_res_out = _interp_or_keep(q_model, q_exp, q_res_exp)
+            return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, None, as_dict)
+
+        # VariableQ
+        elif isinstance(q_generator, VariableQ):
+            if q_generator.n_q_range[0] == q_generator.n_q_range[1]:
+                n_q_model = q_generator.n_q_range[0]
+                q_min = max(q_exp.min(), q_generator.q_min_range[0])
+                q_max = min(q_exp.max(), q_generator.q_max_range[1])
+                if self.trainer.loader.q_generator.mode == 'logspace':
+                    q_model = torch.logspace(start=torch.log10(torch.tensor(q_min, device=self.device)),
+                                             end=torch.log10(torch.tensor(q_max, device=self.device)),
+                                             steps=n_q_model, device=self.device).to('cpu')
+                    logspace = True
+                else:
+                    q_model = np.linspace(q_min, q_max, n_q_model)
+                    logspace = False
+            else:
+                return _prepare_return(q_exp, refl_exp, sigmas_exp, q_res_exp, None, as_dict)
+
+            refl_out = interp_reflectivity(q_model, q_exp, refl_exp, logspace=logspace)
+            sigmas_out = _interp_or_keep(q_model, q_exp, sigmas_exp)
+            q_res_out = _interp_or_keep(q_model, q_exp, q_res_exp)
+            return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, None, as_dict)
+
+        # MaskedVariableQ
+        elif isinstance(q_generator, MaskedVariableQ):
+            min_n, max_n = q_generator.n_q_range
+            n_exp = len(q_exp)
+
+            if min_n <= n_exp <= max_n:
+                # Pad only
+                q_model = _pad(q_exp, max_n, 0.0)
+                refl_out = _pad(refl_exp, max_n, 0.0)
+                sigmas_out = _pad_or_keep(sigmas_exp, max_n)
+                q_res_out = _pad_or_keep(q_res_exp, max_n)
+                key_padding_mask = np.zeros(max_n, dtype=bool)
+                key_padding_mask[:n_exp] = True
+
+            else:
+                # Interpolate + pad
+                n_interp = min(max(n_exp, min_n), max_n)
+                q_min = max(q_exp.min(), q_generator.q_min_range[0])
+                q_max = min(q_exp.max(), q_generator.q_max_range[1])
+                q_interp = np.linspace(q_min, q_max, n_interp)
+
+                refl_interp = interp_reflectivity(q_interp, q_exp, refl_exp)
+                sigmas_interp = _interp_or_keep(q_interp, q_exp, sigmas_exp)
+                q_res_interp = _interp_or_keep(q_interp, q_exp, q_res_exp)
+
+                q_model = _pad(q_interp, max_n, 0.0)
+                refl_out = _pad(refl_interp, max_n, 0.0)
+                sigmas_out = _pad_or_keep(sigmas_interp, max_n)
+                q_res_out = _pad_or_keep(q_res_interp, max_n)
+                key_padding_mask = np.zeros(max_n, dtype=bool)
+                key_padding_mask[:n_interp] = True
+
+            return _prepare_return(q_model, refl_out, sigmas_out, q_res_out, key_padding_mask, as_dict)
+
+        else:
+            raise TypeError(f"Unsupported QGenerator type: {type(q_generator)}")
+
+    def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> BasicParams:
+        assert isinstance(self.trainer.loader.q_generator, ConstantQ), "Prediction with q shifts available only for models with fixed discretization"
+        q = self.trainer.loader.q_generator.q.squeeze().float()
+        dq_max = (q[1] - q[0]) * dq_coef
+        q_shifts = torch.linspace(-dq_max, dq_max, num).to(q)
+
+        curve = to_t(curve).to(scaled_bounds)
+        shifted_curves = _qshift_interp(q.squeeze(), curve, q_shifts)
+
+        assert shifted_curves.shape == (num, q.shape[0])
+
+        scaled_curves = self.trainer.loader.curves_scaler.scale(shifted_curves)
+        scaled_prior_bounds = torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)
+
+        with torch.no_grad():
+            self.trainer.model.eval()
+            scaled_predicted_params = self.trainer.model(scaled_curves, scaled_prior_bounds)
+            restored_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))
+
+        best_param = get_best_mse_param(
+            restored_params,
+            self._get_likelihood(q=self.trainer.loader.q_generator.q, curve=curve),
+        )
+        return best_param
+
+    def predict_using_widget(self, reflectivity_curve, **kwargs):
+        """
+        """
+
+        NUM_INTERVALS = self.trainer.loader.prior_sampler.param_dim
+        param_labels = self.trainer.loader.prior_sampler.param_model.get_param_labels()
+        min_bounds = self.trainer.loader.prior_sampler.min_bounds.cpu().numpy().flatten()
+        max_bounds = self.trainer.loader.prior_sampler.max_bounds.cpu().numpy().flatten()
+        max_deltas = self.trainer.loader.prior_sampler.max_delta.cpu().numpy().flatten()
+
+        print(f'Adjust the sliders for each parameter and press "Predict". Repeat as desired. Press "Close Widget" to finish.')
+
+        interval_widgets = []
+        for i in range(NUM_INTERVALS):
+            label = widgets.Label(value=f'{param_labels[i]}')
+            initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
+            slider = widgets.FloatRangeSlider(
+                value=[min_bounds[i], initial_max],
+                min=min_bounds[i],
+                max=max_bounds[i],
+                step=0.01,
+                layout=widgets.Layout(width='400px'),
+                style={'description_width': '60px'}
+            )
+
+            def validate_range(change, slider=slider, max_width=max_deltas[i]):
+                min_val, max_val = change['new']
+                if max_val - min_val > max_width:
+                    old_min_val, old_max_val = change['old']
+                    if abs(old_min_val - min_val) > abs(old_max_val - max_val):
+                        max_val = min_val + max_width
+                    else:
+                        min_val = max_val - max_width
+                    slider.value = [min_val, max_val]
+
+            slider.observe(validate_range, names='value')
+            interval_widgets.append((slider, widgets.HBox([label, slider])))
+
+        sliders_box = widgets.VBox([iw[1] for iw in interval_widgets])
+
+        output = widgets.Output()
+        predict_button = widgets.Button(description="Predict")
+        close_button = widgets.Button(description="Close Widget")
+
+        container = widgets.VBox([sliders_box, widgets.HBox([predict_button, close_button]), output])
+        display(container)
+
+        @output.capture(clear_output=True)
+        def on_predict_click(_):
+            if 'prior_bounds' in kwargs:
+                array_values = kwargs.pop('prior_bounds')
+                for i, (s, _) in enumerate(interval_widgets):
+                    s.value = tuple(array_values[i])
+            else:
+                values = [(s.value[0], s.value[1]) for s, _ in interval_widgets]
+                array_values = np.array(values)
+
+            prediction_result = self.predict(reflectivity_curve=reflectivity_curve,
+                                             prior_bounds=array_values,
+                                             **kwargs)
+            param_names = self.trainer.loader.prior_sampler.param_model.get_param_labels()
+            print_prediction_results(prediction_result)
+
+            plot_prediction_results(
+                prediction_result,
+                q_exp=kwargs['q_values'],
+                curve_exp=reflectivity_curve,
+            )
+            self.widget_prediction_result = prediction_result
+
+        def on_close_click(_):
+            container.close()
+            print("Widget closed.")
+
+        predict_button.on_click(on_predict_click)
+        close_button.on_click(on_close_click)
+
+    def preprocess_and_predict_using_widget(self,
+                                            reflectivity_curve,
+                                            q_values=None,
+                                            sigmas=None,
+                                            q_resolution=None,
+                                            prior_bounds=None,
+                                            ambient_sld=None,
+                                            ):
+        """
+        Interactive widget around `preprocess_and_predict`
+        Results are stored in `self.widget_prediction_result`
+        """
+
+        if q_values is None:
+            raise ValueError("q_values must be provided for this widget.")
+
+        N = len(reflectivity_curve)
+
+        # ---------- Priors sliders ----------
+        param_labels = self.trainer.loader.prior_sampler.param_model.get_param_labels()
+        min_bounds = self.trainer.loader.prior_sampler.min_bounds.cpu().numpy().flatten()
+        max_bounds = self.trainer.loader.prior_sampler.max_bounds.cpu().numpy().flatten()
+        max_deltas = self.trainer.loader.prior_sampler.max_delta.cpu().numpy().flatten()
+        NUM = len(param_labels)
+
+        sliders, rows = [], []
+        init_pb = np.array(prior_bounds) if prior_bounds is not None else None
+        for i in range(NUM):
+            init_min = float(init_pb[i, 0]) if init_pb is not None else float(min_bounds[i])
+            init_max = float(init_pb[i, 1]) if init_pb is not None else float(min(min_bounds[i] + max_deltas[i], max_bounds[i]))
+            lab = widgets.Label(value=param_labels[i])
+            s = widgets.FloatRangeSlider(
+                value=[init_min, init_max],
+                min=float(min_bounds[i]),
+                max=float(max_bounds[i]),
+                step=0.01,
+                layout=widgets.Layout(width='420px'),
+                readout_format='.3f'
+            )
+            # Constrain slider widths
+            def _mk_validator(slider, max_width=float(max_deltas[i])):
+                def _validate(change):
+                    a, b = change['new']
+                    if b - a > max_width:
+                        oa, ob = change['old']
+                        if abs(oa - a) > abs(ob - b):
+                            b = a + max_width
+                        else:
+                            a = b - max_width
+                        slider.value = (a, b)
+                return _validate
+            s.observe(_mk_validator(s), names='value')
+            sliders.append(s)
+            rows.append(widgets.HBox([lab, s]))
+        priors_box = widgets.VBox([widgets.HTML("<b>Priors</b>")] + rows)
+
+        # ---------- Preprocess & Predict controls ----------
+        # Preprocessing
+        trunc_L = widgets.IntSlider(description='truncate left', min=0, max=max(0, N-1), step=1, value=0)
+        trunc_R = widgets.IntSlider(description='truncate right', min=1, max=N, step=1, value=N)
+        enable_filt = widgets.Checkbox(description='filter error bars', value=True)
+        thr = widgets.FloatSlider(description='filtering threshold', min=0.0, max=1.0, step=0.01, value=0.3)
+        rem_single = widgets.Checkbox(description='remove singles', value=True)
+        rem_cons = widgets.Checkbox(description='remove consecutives', value=True)
+        consec = widgets.IntSlider(description='num. consecutive', min=1, max=10, step=1, value=3)
+        qstart = widgets.FloatSlider(description='q_start_trunc', min=0.0, max=1.0, step=0.01, value=0.1)
+
+        # Polishing
+        polish = widgets.Checkbox(description='polish prediction', value=True)
+        use_sigmas_polish = widgets.Checkbox(description='use sigmas during polishing', value=True)
+
+        # Plotting
+        pred_show_yerr = widgets.Checkbox(description='show error bars', value=True)
+        pred_show_xerr = widgets.Checkbox(description='show q-resolution', value=False)
+        pred_logx = widgets.Checkbox(description='log x-axis', value=False)
+        plot_sld = widgets.Checkbox(description='plot SLD profile', value=True)
+        sld_pad_left = widgets.FloatText(description='SLD pad left', value=0.2, step=0.1)
+        sld_pad_right = widgets.FloatText(description='SLD pad right', value=1.1, step=0.1)
+
+        # Color pickers
+        exp_color_picker = widgets.ColorPicker(description='exp color', value='#0000FF')  # blue
+        exp_errcolor_picker = widgets.ColorPicker(description='errbar color', value='#800080')  # purple
+        pred_color_picker = widgets.ColorPicker(description='pred color', value='#FF0000')  # red
+        pol_color_picker = widgets.ColorPicker(description='polished color', value='#FFA500')  # orange
+        sld_pred_color_picker = widgets.ColorPicker(description='SLD pred color', value='#FF0000')  # red
+        sld_pol_color_picker = widgets.ColorPicker(description='SLD pol color', value='#FFA500')  # orange
+
+        # Compute toggles
+        calc_curve = widgets.Checkbox(description='calc curve', value=True)
+        calc_sld_pred = widgets.Checkbox(description='calc predicted SLD', value=True)
+        calc_sld_pol = widgets.Checkbox(description='calc polished SLD', value=True)
+
+        btn_predict = widgets.Button(description='Predict')
+        btn_close = widgets.Button(description='Close')
+
+        controls_box = widgets.VBox([
+            widgets.HTML("<b>Preprocess & Predict</b>"),
+            widgets.HTML("<i>Preprocessing</i>"),
+            widgets.HBox([trunc_L, trunc_R]),
+            widgets.HBox([enable_filt, rem_single, rem_cons]),
+            widgets.HBox([thr, consec, qstart]),
+            widgets.HTML("<i>Polishing</i>"),
+            widgets.HBox([polish, use_sigmas_polish]),
+            widgets.HTML("<i>Plotting</i>"),
+            widgets.HBox([pred_show_yerr, pred_show_xerr, pred_logx, plot_sld]),
+            widgets.HBox([sld_pad_left, sld_pad_right]),
+            widgets.HTML("<i>Colors</i>"),
+            widgets.HBox([exp_color_picker, exp_errcolor_picker]),
+            widgets.HBox([pred_color_picker, pol_color_picker]),
+            widgets.HBox([sld_pred_color_picker, sld_pol_color_picker]),
+            widgets.HTML("<i>Compute</i>"),
+            widgets.HBox([calc_curve, calc_sld_pred, calc_sld_pol]),
+            widgets.HBox([btn_predict, btn_close]),
+        ])
+
+        out_predict = widgets.Output()
+        container = widgets.VBox([priors_box, controls_box, out_predict])
+        display(container)
+
+        def _sync_trunc(_):
+            if trunc_L.value >= trunc_R.value:
+                trunc_L.value = max(0, trunc_R.value - 1)
+        trunc_L.observe(_sync_trunc, names='value')
+        trunc_R.observe(_sync_trunc, names='value')
+
+        def _current_priors():
+            return np.array([s.value for s in sliders], dtype=np.float32)  # (param_dim, 2)
+
+        @out_predict.capture(clear_output=True)
+        def _on_predict(_):
+            out_predict.clear_output(wait=True)
+
+            res = self.preprocess_and_predict(
+                reflectivity_curve=reflectivity_curve,
+                q_values=q_values,
+                prior_bounds=_current_priors(),
+                sigmas=sigmas,
+                q_resolution=q_resolution,
+                ambient_sld=ambient_sld,
+                clip_prediction=True,
+                polish_prediction=polish.value,
+                use_sigmas_for_polishing=use_sigmas_polish.value,
+                calc_pred_curve=calc_curve.value,
+                calc_pred_sld_profile=(calc_sld_pred.value or plot_sld.value),
+                calc_polished_sld_profile=(calc_sld_pol.value or plot_sld.value),
+                sld_profile_padding_left=float(sld_pad_left.value),
+                sld_profile_padding_right=float(sld_pad_right.value),
+
+                truncate_index_left=trunc_L.value,
+                truncate_index_right=trunc_R.value,
+                enable_error_bars_filtering=enable_filt.value,
+                filter_threshold=thr.value,
+                filter_remove_singles=rem_single.value,
+                filter_remove_consecutives=rem_cons.value,
+                filter_consecutive=consec.value,
+                filter_q_start_trunc=qstart.value,
+            )
+
+            # Full experimental data as scatter
+            q_exp_plot = q_values
+            r_exp_plot = reflectivity_curve
+            yerr_plot = (sigmas if pred_show_yerr.value else None)
+            xerr_plot = (q_resolution if pred_show_xerr.value else None)
+
+            # Predicted curve only on the model region
+            q_pred = res.get('q_plot_pred', None)
+            r_pred = res.get('predicted_curve', None)
+
+            # Polished curve on the full experimental grid
+            q_pol = q_values if ('polished_curve' in res) else None
+            r_pol = res.get('polished_curve', None)
+
+            # SLD profiles
+            z_sld = res.get('predicted_sld_xaxis', None)
+            sld_pred = res.get('predicted_sld_profile', None)
+            sld_pol = res.get('sld_profile_polished', None)
+
+            print_prediction_results(res)
+
+            plot_reflectivity(
+                q_exp=q_exp_plot, r_exp=r_exp_plot,
+                yerr=yerr_plot, xerr=xerr_plot,
+                exp_style=('errorbar' if pred_show_yerr.value or pred_show_xerr.value else 'scatter'),
+                exp_color=exp_color_picker.value,
+                exp_errcolor=exp_errcolor_picker.value,
+                q_pred=q_pred, r_pred=r_pred, pred_color=pred_color_picker.value,
+                q_pol=q_pol, r_pol=r_pol, pol_color=pol_color_picker.value,
+                z_sld=z_sld, sld_pred=sld_pred, sld_pol=sld_pol,
+                sld_pred_color=sld_pred_color_picker.value,
+                sld_pol_color=sld_pol_color_picker.value,
+                plot_sld_profile=plot_sld.value,
+                logx=pred_logx.value, logy=True,
+                figsize=(12,6),
+                legend=True
+            )
+
+            self.widget_prediction_result = res
+
+        def _on_close(_):
+            container.close()
+            print("Widget closed.")
+
+        btn_predict.on_click(_on_predict)
+        btn_close.on_click(_on_close)
 
 class InferenceModel(object):
     def __init__(self, name: str = None, trainer: PointEstimatorTrainer = None, preprocessing_parameters: dict = None,