reflectorch 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +23 -0
- reflectorch/data_generation/__init__.py +130 -0
- reflectorch/data_generation/dataset.py +196 -0
- reflectorch/data_generation/likelihoods.py +86 -0
- reflectorch/data_generation/noise.py +371 -0
- reflectorch/data_generation/priors/__init__.py +66 -0
- reflectorch/data_generation/priors/base.py +61 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
- reflectorch/data_generation/priors/independent_priors.py +201 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +110 -0
- reflectorch/data_generation/priors/no_constraints.py +212 -0
- reflectorch/data_generation/priors/parametric_models.py +767 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
- reflectorch/data_generation/priors/params.py +258 -0
- reflectorch/data_generation/priors/sampler_strategies.py +306 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +377 -0
- reflectorch/data_generation/priors/utils.py +124 -0
- reflectorch/data_generation/process_data.py +47 -0
- reflectorch/data_generation/q_generator.py +232 -0
- reflectorch/data_generation/reflectivity/__init__.py +56 -0
- reflectorch/data_generation/reflectivity/abeles.py +81 -0
- reflectorch/data_generation/reflectivity/kinematical.py +58 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +123 -0
- reflectorch/data_generation/scale_curves.py +118 -0
- reflectorch/data_generation/smearing.py +67 -0
- reflectorch/data_generation/utils.py +154 -0
- reflectorch/extensions/__init__.py +6 -0
- reflectorch/extensions/jupyter/__init__.py +12 -0
- reflectorch/extensions/jupyter/callbacks.py +40 -0
- reflectorch/extensions/matplotlib/__init__.py +11 -0
- reflectorch/extensions/matplotlib/losses.py +38 -0
- reflectorch/inference/__init__.py +22 -0
- reflectorch/inference/inference_model.py +734 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +16 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +171 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +37 -0
- reflectorch/ml/basic_trainer.py +286 -0
- reflectorch/ml/callbacks.py +86 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +38 -0
- reflectorch/ml/schedulers.py +246 -0
- reflectorch/ml/trainers.py +126 -0
- reflectorch/ml/utils.py +9 -0
- reflectorch/models/__init__.py +22 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +27 -0
- reflectorch/models/encoders/conv_encoder.py +211 -0
- reflectorch/models/encoders/conv_res_net.py +119 -0
- reflectorch/models/encoders/fno.py +127 -0
- reflectorch/models/encoders/transformers.py +56 -0
- reflectorch/models/networks/__init__.py +18 -0
- reflectorch/models/networks/mlp_networks.py +256 -0
- reflectorch/models/networks/residual_net.py +131 -0
- reflectorch/paths.py +33 -0
- reflectorch/runs/__init__.py +35 -0
- reflectorch/runs/config.py +31 -0
- reflectorch/runs/slurm_utils.py +99 -0
- reflectorch/runs/train.py +85 -0
- reflectorch/runs/utils.py +300 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +74 -0
- reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
- reflectorch-1.0.0.dist-info/METADATA +115 -0
- reflectorch-1.0.0.dist-info/RECORD +83 -0
- reflectorch-1.0.0.dist-info/WHEEL +5 -0
- reflectorch-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,734 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import time
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
from torch import Tensor
|
|
9
|
+
from typing import List, Tuple, Union
|
|
10
|
+
import ipywidgets as widgets
|
|
11
|
+
from IPython.display import display
|
|
12
|
+
from huggingface_hub import hf_hub_download
|
|
13
|
+
|
|
14
|
+
from reflectorch.data_generation.priors import Params, BasicParams, ExpUniformSubPriorSampler, UniformSubPriorParams
|
|
15
|
+
from reflectorch.data_generation.q_generator import ConstantQ, VariableQ
|
|
16
|
+
from reflectorch.data_generation.utils import get_density_profiles, get_param_labels
|
|
17
|
+
from reflectorch.paths import CONFIG_DIR, ROOT_DIR, SAVED_MODELS_DIR
|
|
18
|
+
from reflectorch.runs.utils import (
|
|
19
|
+
get_trainer_by_name, train_from_config
|
|
20
|
+
)
|
|
21
|
+
from reflectorch.runs.config import load_config
|
|
22
|
+
from reflectorch.ml.trainers import PointEstimatorTrainer
|
|
23
|
+
from reflectorch.data_generation.likelihoods import LogLikelihood
|
|
24
|
+
|
|
25
|
+
from reflectorch.inference.preprocess_exp import StandardPreprocessing
|
|
26
|
+
from reflectorch.inference.scipy_fitter import standard_refl_fit, get_fit_with_growth
|
|
27
|
+
from reflectorch.inference.sampler_solution import simple_sampler_solution, get_best_mse_param
|
|
28
|
+
from reflectorch.inference.record_time import print_time
|
|
29
|
+
from reflectorch.utils import to_t
|
|
30
|
+
|
|
31
|
+
class EasyInferenceModel(object):
    """Facilitates the inference process using pretrained models

    Args:
        config_name (str, optional): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension). Defaults to None.
        model_name (str, optional): the name of the file containing the weights of the model (either with or without the '.pt' extension), only required if different than: `'model_' + config_name + '.pt'`. Defaults to None
        root_dir (str, optional): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR). Defaults to None.
        repo_id (str, optional): the id of the Huggingface repository from which the configuration files and model weights should be downloaded automatically if not found locally (in the 'configs' and 'saved_models' subdirectories of the root directory). Defaults to 'valentinsingularity/reflectivity'.
        trainer (PointEstimatorTrainer, optional): if provided, this trainer instance is used directly instead of being initialized from the configuration file. Defaults to None.
        device (str, optional): the Pytorch device ('cuda' or 'cpu'). Defaults to 'cuda'.
    """
    def __init__(self, config_name: str = None, model_name: str = None, root_dir:str = None, repo_id: str = 'valentinsingularity/reflectivity',
                 trainer: "PointEstimatorTrainer" = None, device='cuda'):
        self.config_name = config_name
        self.model_name = model_name
        self.root_dir = root_dir
        self.repo_id = repo_id
        self.trainer = trainer
        self.device = device

        # A trainer passed explicitly takes precedence over loading from a config.
        if trainer is None and self.config_name is not None:
            self.load_model(self.config_name, self.model_name, self.root_dir)

        # Holds the most recent result produced by ``predict_using_widget``.
        self.prediction_result = None

    def _ensure_local_file(self, path: Path, description: str, subfolder: str, filename: str) -> None:
        """Make sure ``path`` exists locally, downloading it from Huggingface otherwise.

        Args:
            path (Path): expected local location of the file.
            description (str): word used at the start of the progress messages ('Configuration' or 'Weights').
            subfolder (str): subfolder of the Huggingface repository holding the file.
            filename (str): name of the file inside ``subfolder``.

        Raises:
            ValueError: if the file is missing locally and ``self.repo_id`` is None.
        """
        if path.exists():
            print(f"{description} file `{path}` found locally.")
            return

        print(f"{description} file `{path}` not found locally.")
        if self.repo_id is None:
            raise ValueError("repo_id must be provided to download files from Huggingface.")
        print("Downloading from Huggingface...")
        # ``path.parents[1]`` is the root directory; together with ``subfolder`` the
        # download is requested at the location given by ``path``.
        hf_hub_download(repo_id=self.repo_id, subfolder=subfolder, filename=filename, local_dir=path.parents[1])

    def load_model(self, config_name: str, model_name: str, root_dir: str) -> None:
        """Loads a model for inference

        Args:
            config_name (str): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension).
            model_name (str): the name of the file containing the weights of the model (either with or without the '.pt' extension), only required if different than: `'model_' + config_name + '.pt'`.
            root_dir (str): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR).

        Raises:
            ValueError: if a required file is missing locally and no ``repo_id`` was set.
        """
        if self.config_name == config_name and self.trainer is not None:
            return

        if not config_name.endswith('.yaml'):
            config_name_no_extension = config_name
            self.config_name = config_name_no_extension + '.yaml'
        else:
            config_name_no_extension = config_name[:-5]
            self.config_name = config_name

        self.config_dir = Path(root_dir) / 'configs' if root_dir else CONFIG_DIR
        self.model_name = model_name or 'model_' + config_name_no_extension + '.pt'
        if not self.model_name.endswith('.pt'):
            self.model_name += '.pt'
        self.model_dir = Path(root_dir) / 'saved_models' if root_dir else SAVED_MODELS_DIR

        config_path = Path(self.config_dir) / self.config_name
        self._ensure_local_file(config_path, "Configuration", subfolder='configs', filename=self.config_name)

        model_path = Path(self.model_dir) / self.model_name
        self._ensure_local_file(model_path, "Weights", subfolder='saved_models', filename=self.model_name)

        self.trainer = get_trainer_by_name(config_name=config_name, config_dir=self.config_dir, model_path=model_path, load_weights=True, inference_device = self.device)
        self.trainer.model.eval()

        # Print a short summary of what the loaded model expects.
        prior_sampler = self.trainer.loader.prior_sampler
        print(f'The model corresponds to a `{prior_sampler.param_model.NAME}` parameterization with {prior_sampler.max_num_layers} layers ({prior_sampler.param_dim} predicted parameters)')
        print("Parameter types and total ranges:")
        for param, range_ in prior_sampler.param_ranges.items():
            print(f"- {param}: {range_}")
        print("Allowed widths of the prior bound intervals (max-min):")
        for param, range_ in prior_sampler.bound_width_ranges.items():
            print(f"- {param}: {range_}")

        q_generator = self.trainer.loader.q_generator
        if isinstance(q_generator, ConstantQ):
            q_min = q_generator.q[0].item()
            q_max = q_generator.q[-1].item()
            n_q = q_generator.q.shape[0]
            print(f'The model was trained on curves discretized at {n_q} uniform points between q_min={q_min} and q_max={q_max}')
        elif isinstance(q_generator, VariableQ):
            q_min_range = q_generator.q_min_range
            q_max_range = q_generator.q_max_range
            n_q_range = q_generator.n_q_range
            print(f'The model was trained on curves discretized at a number between {n_q_range[0]} and {n_q_range[1]} of uniform points between q_min in [{q_min_range[0]}, {q_min_range[1]}] and q_max in [{q_max_range[0]}, {q_max_range[1]}]')

    def predict(self, reflectivity_curve: Union[np.ndarray, Tensor],
                q_values: Union[np.ndarray, Tensor] = None,
                prior_bounds: Union[np.ndarray, List[Tuple]] = None,
                clip_prediction: bool = False,
                polish_prediction: bool = False,
                fit_growth: bool = False,
                max_d_change: float = 5.,
                use_q_shift: bool = False,
                calc_pred_curve: bool = True,
                calc_pred_sld_profile: bool = False,
                ):
        """Predict the thin film parameters

        Args:
            reflectivity_curve (Union[np.ndarray, Tensor]): The reflectivity curve (which has been already preprocessed, normalized and interpolated).
            q_values (Union[np.ndarray, Tensor], optional): The momentum transfer (q) values for the reflectivity curve (in units of inverse angstroms). Only used by models trained with q as an input; otherwise the q grid of the model is used.
            prior_bounds (Union[np.ndarray, List[Tuple]]): the prior bounds for the thin film parameters, one (min, max) pair per parameter. Must be provided even though the signature defaults to None.
            clip_prediction (bool, optional): If ``True``, the values of the predicted parameters are clipped to not be outside the interval set by the prior bounds. Defaults to False.
            polish_prediction (bool, optional): If ``True``, the neural network predictions are further polished using a simple least mean squares (LMS) fit. Only for the standard box-model parameterization. Defaults to False.
            fit_growth (bool, optional): If ``True``, an additional parameter is introduced during the LMS polishing to account for the change in the thickness of the upper layer during the in-situ measurement of the reflectivity curve (a linear growth is assumed). Defaults to False.
            max_d_change (float): The maximum possible change in the thickness of the upper layer during the in-situ measurement, relevant when polish_prediction and fit_growth are True. Defaults to 5.
            use_q_shift: If ``True``, the prediction is performed for a batch of slightly shifted versions of the input curve and the best result is returned, which is meant to mitigate the influence of imperfect sample alignment, as introduced in Greco et al. (only for models with fixed q-discretization). Defaults to False.
            calc_pred_curve (bool, optional): Whether to calculate the curve corresponding to the predicted parameters. Defaults to True.
            calc_pred_sld_profile (bool, optional): Whether to calculate the SLD profile corresponding to the predicted parameters. Defaults to False.

        Returns:
            dict: dictionary containing the predictions

        Raises:
            ValueError: if ``prior_bounds`` is not provided.
        """
        # Fail early with a clear message: ``np.array(None)`` would otherwise only
        # blow up later, inside the tensor conversion of the bounds.
        if prior_bounds is None:
            raise ValueError("prior_bounds must be provided to make a prediction.")

        scaled_curve = self._scale_curve(reflectivity_curve)
        prior_bounds = np.array(prior_bounds)
        scaled_prior_bounds = self._scale_prior_bounds(prior_bounds)

        if not self.trainer.train_with_q_input:
            q_values = self.trainer.loader.q_generator.q
        else:
            q_values = torch.atleast_2d(to_t(q_values)).to(scaled_curve)

        if use_q_shift and not self.trainer.train_with_q_input:
            predicted_params = self._qshift_prediction(reflectivity_curve, scaled_prior_bounds, num = 1024, dq_coef = 1.)

        else:
            with torch.no_grad():
                self.trainer.model.eval()
                if self.trainer.train_with_q_input:
                    scaled_q = self.trainer.loader.q_generator.scale_q(q_values).float()
                    scaled_predicted_params = self.trainer.model(scaled_curve, scaled_prior_bounds, scaled_q)
                else:
                    scaled_predicted_params = self.trainer.model(scaled_curve, scaled_prior_bounds)

            predicted_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))

        if clip_prediction:
            predicted_params = self.trainer.loader.prior_sampler.clamp_params(predicted_params)

        prediction_dict = {
            "predicted_params_object": predicted_params,
            "predicted_params_array": predicted_params.parameters.squeeze().cpu().numpy(),
            # Copy the labels: the list may be extended below and must not alter
            # whatever object the parameter model handed out.
            "param_names" : list(self.trainer.loader.prior_sampler.param_model.get_param_labels())
        }

        if calc_pred_curve:
            predicted_curve = predicted_params.reflectivity(q_values).squeeze().cpu().numpy()
            prediction_dict[ "predicted_curve"] = predicted_curve

        if calc_pred_sld_profile:
            predicted_sld_xaxis, predicted_sld_profile, _ = get_density_profiles(
                predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, num=1024,
            )
            prediction_dict['predicted_sld_profile'] = predicted_sld_profile.squeeze().cpu().numpy()
            prediction_dict['predicted_sld_xaxis'] = predicted_sld_xaxis.squeeze().cpu().numpy()
        else:
            predicted_sld_xaxis = None

        if polish_prediction: #only for standard box-model parameterization
            polished_dict = self._polish_prediction(q = q_values.squeeze().cpu().numpy(),
                                                    curve = reflectivity_curve,
                                                    predicted_params = predicted_params,
                                                    priors = np.array(prior_bounds),
                                                    fit_growth = fit_growth,
                                                    max_d_change = max_d_change,
                                                    calc_polished_curve = calc_pred_curve,
                                                    calc_polished_sld_profile = False,
                                                    sld_x_axis = predicted_sld_xaxis,
                                                    )
            prediction_dict.update(polished_dict)

            if fit_growth and "polished_params_array" in prediction_dict:
                prediction_dict["param_names"].append("max_d_change")

        return prediction_dict

    async def predict_using_widget(self, reflectivity_curve: Union[np.ndarray, torch.Tensor], **kwargs):
        """Use an interactive Python widget for specifying the prior bounds before the prediction (works only in a Jupyter notebook).
        The other arguments are the same as for the ``predict`` method.

        Returns:
            dict: the dictionary returned by ``predict`` once the button has been pressed.
            The same dictionary is also stored in ``self.prediction_result``.

        Raises:
            Exception: whatever ``predict`` raised inside the button callback is re-raised to the awaiting caller.
        """

        prior_sampler = self.trainer.loader.prior_sampler
        NUM_INTERVALS = prior_sampler.param_dim
        param_labels = prior_sampler.param_model.get_param_labels()
        min_bounds = prior_sampler.min_bounds.cpu().numpy().flatten()
        max_bounds = prior_sampler.max_bounds.cpu().numpy().flatten()
        max_deltas = prior_sampler.max_delta.cpu().numpy().flatten()

        print(f'Parameter ranges: {prior_sampler.param_ranges}')
        print(f'Allowed widths of the prior bound intervals (max-min): {prior_sampler.bound_width_ranges}')
        print(f'Please fill in the values of the minimum and maximum prior bound for each parameter and press the button!')

        def create_interval_widgets(n):
            """Build one labelled range slider per parameter."""
            intervals = []
            for i in range(n):
                interval_label = widgets.Label(value=f'{param_labels[i]}')
                initial_max = min(max_bounds[i], min_bounds[i] + max_deltas[i])
                slider = widgets.FloatRangeSlider(
                    value=[min_bounds[i], initial_max],
                    min=min_bounds[i],
                    max=max_bounds[i],
                    step=0.01,
                    layout=widgets.Layout(width='400px'),
                    style={'description_width': '60px'}
                )

                # Defaults bind the current slider/width (avoids late-binding closures).
                def validate_range(change, slider=slider, max_width=max_deltas[i]):
                    """Keep the selected interval no wider than ``max_width``."""
                    min_val, max_val = change['new']
                    if max_val - min_val > max_width:
                        if change['name'] == 'value':
                            # Move whichever end the user did not just drag.
                            if change['old'][0] != min_val:
                                max_val = min_val + max_width
                            else:
                                min_val = max_val - max_width
                        slider.value = [min_val, max_val]

                slider.observe(validate_range, names='value')

                interval_row = widgets.HBox([interval_label, slider])
                intervals.append((slider, interval_row))
            return intervals

        interval_widgets = create_interval_widgets(NUM_INTERVALS)
        interval_box = widgets.VBox([widget[1] for widget in interval_widgets])
        display(interval_box)

        button = widgets.Button(description="Make prediction")
        display(button)

        # The future is resolved from the button callback and awaited below.
        future = asyncio.get_running_loop().create_future()

        def close_widgets():
            for child in interval_box.children:
                child.close()
            button.close()

        def store_values(b):
            """Button callback: read the sliders, predict and resolve the future."""
            if future.done():
                return
            print("Debug: Button clicked")
            array_values = np.array([(slider.value[0], slider.value[1]) for slider, _ in interval_widgets])

            try:
                prediction_result = self.predict(reflectivity_curve=reflectivity_curve, prior_bounds=array_values, **kwargs)
            except Exception as exc:
                # Hand the failure to the awaiting caller instead of leaving it waiting forever.
                close_widgets()
                future.set_exception(exc)
                return

            self.prediction_result = prediction_result
            print(prediction_result["predicted_params_array"])

            print("Prediction completed. Closing widget.")

            close_widgets()

            future.set_result(prediction_result)

        # Register the handler exactly once; ``on_click`` calls it with the button only.
        button.on_click(store_values)

        return await future

    def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> "BasicParams":
        """Predict on a batch of q-shifted copies of ``curve`` and keep the best one.

        Args:
            curve: the (unscaled) reflectivity curve.
            scaled_bounds: the scaled prior bounds, as produced by ``_scale_prior_bounds``.
            num (int): number of shifted curves to evaluate. Defaults to 1000.
            dq_coef (float): maximum shift, in units of the spacing between the first two q points. Defaults to 1.

        Returns:
            BasicParams: the candidate selected by ``get_best_mse_param``.
        """
        assert isinstance(self.trainer.loader.q_generator, ConstantQ), "Prediction with q shifts available only for models with fixed discretization"
        q = self.trainer.loader.q_generator.q.squeeze().float()
        dq_max = (q[1] - q[0]) * dq_coef
        q_shifts = torch.linspace(-dq_max, dq_max, num).to(q)

        curve = to_t(curve).to(scaled_bounds)
        shifted_curves = _qshift_interp(q.squeeze(), curve, q_shifts)

        assert shifted_curves.shape == (num, q.shape[0])

        scaled_curves = self.trainer.loader.curves_scaler.scale(shifted_curves)
        # Repeat the same bounds for every shifted curve in the batch.
        scaled_prior_bounds = torch.atleast_2d(scaled_bounds).expand(scaled_curves.shape[0], -1)

        with torch.no_grad():
            self.trainer.model.eval()
            scaled_predicted_params = self.trainer.model(scaled_curves, scaled_prior_bounds)
            restored_params = self.trainer.loader.prior_sampler.restore_params(torch.cat([scaled_predicted_params, scaled_prior_bounds], dim=-1))

            best_param = get_best_mse_param(
                restored_params,
                self._get_likelihood(q=self.trainer.loader.q_generator.q, curve=curve),
            )
        return best_param

    def _polish_prediction(self,
                           q: np.ndarray,
                           curve: np.ndarray,
                           predicted_params: "BasicParams",
                           priors: np.ndarray,
                           sld_x_axis,
                           fit_growth: bool = False,
                           max_d_change: float = 5.,
                           calc_polished_curve: bool = True,
                           calc_polished_sld_profile: bool = False,
                           ) -> dict:
        """Refine the predicted parameters with a least-squares fit.

        Args:
            q (np.ndarray): the q values of the curve.
            curve (np.ndarray): the reflectivity curve to fit.
            predicted_params (BasicParams): the network prediction, used as initial guess.
            priors (np.ndarray): the prior bounds, one (min, max) row per parameter.
            sld_x_axis: z axis passed to ``get_density_profiles`` when the polished SLD profile is requested.
            fit_growth (bool): if True, ``get_fit_with_growth`` is used and its last fitted value is excluded from the polished parameter object. Defaults to False.
            max_d_change (float): forwarded to ``get_fit_with_growth``. Defaults to 5.
            calc_polished_curve (bool): whether to include the polished curve in the result. Defaults to True.
            calc_polished_sld_profile (bool): whether to include the polished SLD profile. Defaults to False.

        Returns:
            dict: 'polished_params_array' and, optionally, 'polished_curve' and 'sld_profile_polished'.
        """
        params = torch.cat([
            predicted_params.thicknesses.squeeze(),
            predicted_params.roughnesses.squeeze(),
            predicted_params.slds.squeeze()
        ]).cpu().numpy()

        polished_params_dict = {}

        try:
            if fit_growth:
                polished_params_arr, curve_polished = get_fit_with_growth(
                    q = q,
                    curve = curve,
                    init_params = params,
                    bounds = priors.T,
                    max_d_change = max_d_change,
                )
                # The last entry is the extra growth parameter, not a film parameter.
                film_params_arr = polished_params_arr[:-1]
            else:
                polished_params_arr, curve_polished = standard_refl_fit(
                    q = q,
                    curve = curve,
                    init_params = params,
                    bounds=priors.T)
                film_params_arr = polished_params_arr

            polished_params = BasicParams(
                torch.from_numpy(film_params_arr[None]),
                torch.from_numpy(priors.T[0][None]),
                torch.from_numpy(priors.T[1][None]),
            )
        except Exception as err:
            # Best effort: keep the network prediction, but do not hide the failure.
            logging.getLogger(__name__).warning(
                "Polishing the prediction failed (%s); falling back to the unpolished parameters.", err
            )
            polished_params = predicted_params
            polished_params_arr = get_prediction_array(polished_params)
            curve_polished = np.zeros_like(q)

        polished_params_dict['polished_params_array'] = polished_params_arr
        if calc_polished_curve:
            polished_params_dict['polished_curve'] = curve_polished

        if calc_polished_sld_profile:
            _, sld_profile_polished, _ = get_density_profiles(
                polished_params.thicknesses, polished_params.roughnesses, polished_params.slds, z_axis=sld_x_axis,
            )
            polished_params_dict['sld_profile_polished'] = sld_profile_polished.squeeze().cpu().numpy()

        return polished_params_dict

    def _scale_curve(self, curve: Union[np.ndarray, Tensor]):
        """Convert ``curve`` to an at-least-2D tensor on ``self.device`` and scale it with the loader's curve scaler."""
        if not isinstance(curve, Tensor):
            curve = torch.from_numpy(curve).float()
        curve = torch.atleast_2d(curve).to(self.device)
        scaled_curve = self.trainer.loader.curves_scaler.scale(curve)
        return scaled_curve

    def _scale_prior_bounds(self, prior_bounds: List[Tuple]):
        """Scale (min, max) prior bounds and concatenate them along the last axis.

        Args:
            prior_bounds: one (min, max) pair per parameter.

        Returns:
            Tensor: float tensor holding the scaled minima followed by the scaled maxima.
        """
        prior_bounds = torch.tensor(prior_bounds)
        prior_bounds = prior_bounds.to(self.device).T
        # After the transpose row 0 holds all minima and row 1 all maxima;
        # ``[:, None]`` gives each of them a leading axis of size 1.
        min_bounds, max_bounds = prior_bounds[:, None]

        scaled_bounds = torch.cat([
            self.trainer.loader.prior_sampler.scale_bounds(min_bounds),
            self.trainer.loader.prior_sampler.scale_bounds(max_bounds)
        ], -1)

        return scaled_bounds.float()

    def _get_likelihood(self, q, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
        """Build a ``LogLikelihood`` for ``curve`` with uncertainties ``curve * rel_err + abs_err``."""
        return LogLikelihood(
            q, curve, self.trainer.loader.prior_sampler, curve * rel_err + abs_err
        )
|
|
408
|
+
|
|
409
|
+
class InferenceModel(object):
|
|
410
|
+
def __init__(self, name: str = None, trainer: PointEstimatorTrainer = None, preprocessing_parameters: dict = None,
             num_sampling: int = 2 ** 13):
    """Set up the inference model, optionally loading or attaching a trainer.

    Args:
        name (str, optional): name of the model; loaded via ``load_model`` when no trainer is given.
        trainer (PointEstimatorTrainer, optional): trainer to use directly; takes precedence over ``name``.
        preprocessing_parameters (dict, optional): keyword arguments for ``StandardPreprocessing``.
        num_sampling (int): stored as ``_sampling_num``. Defaults to 2 ** 13.
    """
    preprocessing_kwargs = preprocessing_parameters if preprocessing_parameters else {}

    self.log = logging.getLogger(__name__)
    self.model_name = name
    self.trainer = trainer
    self.q = None
    self.preprocessing = StandardPreprocessing(**preprocessing_kwargs)
    self._sampling_num = num_sampling

    if trainer is not None:
        self._set_trainer(trainer, preprocessing_parameters)
    elif name is not None:
        # No trainer supplied: build one from the model name.
        self.load_model(name)
|
|
423
|
+
|
|
424
|
+
### API methods ###
|
|
425
|
+
|
|
426
|
+
def load_model(self, name: str) -> None:
    """Load the model called ``name`` unless it is already the active one.

    Args:
        name (str): name passed to ``get_trainer_by_name``.
    """
    self.log.debug(f"loading model {name}")

    already_loaded = self.model_name == name and self.trainer is not None
    if not already_loaded:
        self.model_name = name
        new_trainer = get_trainer_by_name(name)
        self._set_trainer(new_trainer)
        self.log.info(f"Model {name} is loaded.")
|
|
433
|
+
|
|
434
|
+
def train_model(self, name: str):
    """Train a model from the configuration called ``name`` and keep the resulting trainer.

    Args:
        name (str): configuration name passed to ``load_config``; also stored as ``model_name``.
    """
    self.model_name = name
    config = load_config(name)
    self.trainer = train_from_config(config)
|
|
437
|
+
|
|
438
|
+
def set_preprocessing_parameters(self, **kwargs) -> None:
|
|
439
|
+
self.preprocessing.set_parameters(**kwargs)
|
|
440
|
+
|
|
441
|
+
def preprocess(self,
|
|
442
|
+
intensity: np.ndarray,
|
|
443
|
+
scattering_angle: np.ndarray,
|
|
444
|
+
attenuation: np.ndarray,
|
|
445
|
+
update_params: bool = False,
|
|
446
|
+
**kwargs) -> dict:
|
|
447
|
+
if update_params:
|
|
448
|
+
self.preprocessing.set_parameters(**kwargs)
|
|
449
|
+
preprocessed_dict = self.preprocessing(intensity, scattering_angle, attenuation, **kwargs)
|
|
450
|
+
return preprocessed_dict
|
|
451
|
+
|
|
452
|
+
def predict(self,
            intensity: np.ndarray,
            scattering_angle: np.ndarray,
            attenuation: np.ndarray,
            priors: np.ndarray,
            preprocessing_parameters: dict = None,
            polish: bool = True,
            use_sampler: bool = False,
            use_q_shift: bool = True,
            max_d_change: float = 5.,
            fit_growth: bool = True,
            ) -> dict:
    """Preprocess raw data and predict the parameters, timing each stage.

    Args:
        intensity (np.ndarray): forwarded to ``preprocess``.
        scattering_angle (np.ndarray): forwarded to ``preprocess``.
        attenuation (np.ndarray): forwarded to ``preprocess``.
        priors (np.ndarray): forwarded to ``predict_from_preprocessed_curve``.
        preprocessing_parameters (dict, optional): extra keyword arguments for ``preprocess``.
        polish, use_sampler, use_q_shift, max_d_change, fit_growth: forwarded to
            ``predict_from_preprocessed_curve``.

    Returns:
        dict: the preprocessing dictionary updated with the prediction entries.
    """
    preprocessing_kwargs = preprocessing_parameters or {}

    with print_time("everything"):
        with print_time("preprocess"):
            preprocessed_dict = self.preprocess(intensity, scattering_angle, attenuation, **preprocessing_kwargs)

        preprocessed_curve = preprocessed_dict["curve_interp"]
        raw_curve = preprocessed_dict["curve"]
        raw_q = preprocessed_dict["q_values"]
        q_ratio = preprocessed_dict["q_ratio"]

        with print_time("predict_from_preprocessed_curve"):
            prediction = self.predict_from_preprocessed_curve(
                preprocessed_curve,
                priors,
                raw_curve=raw_curve,
                raw_q=raw_q,
                polish=polish,
                q_ratio=q_ratio,
                use_sampler=use_sampler,
                use_q_shift=use_q_shift,
                max_d_change=max_d_change,
                fit_growth=fit_growth,
            )
            preprocessed_dict.update(prediction)

        return preprocessed_dict
|
|
483
|
+
|
|
484
|
+
def predict_from_preprocessed_curve(self,
                                    curve: np.ndarray,
                                    priors: np.ndarray, *,
                                    polish: bool = True,
                                    raw_curve: np.ndarray = None,
                                    raw_q: np.ndarray = None,
                                    clip_prediction: bool = True,
                                    q_ratio: float = 1.,
                                    use_sampler: bool = False,
                                    use_q_shift: bool = True,
                                    max_d_change: float = 5.,
                                    fit_growth: bool = True,
                                    ) -> dict:
    """Predict the parameters for an already preprocessed curve.

    Args:
        curve (np.ndarray): the preprocessed curve fed to the model.
        priors (np.ndarray): the prior bounds.
        polish (bool): whether to merge the output of ``_polish_prediction`` into the result. Defaults to True.
        raw_curve (np.ndarray, optional): curve used for polishing; falls back to ``curve``.
        raw_q (np.ndarray, optional): q values used for the predicted curve and polishing; falls back to ``self.q``.
        clip_prediction (bool): whether to clamp the prediction via the prior sampler. Defaults to True.
        q_ratio (float): when different from 1, the prediction and the q values are rescaled by it. Defaults to 1.
        use_sampler (bool): whether to post-process the prediction with ``_sampler_solution``. Defaults to False.
        use_q_shift (bool): whether to use ``_qshift_prediction`` instead of ``_simple_prediction``. Defaults to True.
        max_d_change (float): forwarded to ``_polish_prediction``. Defaults to 5.
        fit_growth (bool): forwarded to ``_polish_prediction``. Defaults to True.

    Returns:
        dict: 'params', 'param_names', 'curve_predicted', 'sld_profile', 'sld_x_axis' and, when
        polishing is enabled, the entries returned by ``_polish_prediction``.
    """
    scaled_curve = self._scale_curve(curve)
    scaled_bounds, min_bounds, max_bounds = self._scale_priors(priors, q_ratio)

    if use_q_shift:
        predicted_params = self._qshift_prediction(curve, scaled_bounds)
    else:
        predicted_params = self._simple_prediction(scaled_curve, scaled_bounds)

    if use_sampler:
        predicted_params = self._sampler_solution(curve, predicted_params)

    if clip_prediction:
        predicted_params = self._prior_sampler.clamp_params(predicted_params)

    if raw_curve is None:
        raw_curve = curve

    # Keep a numpy and a tensor version of the q axis side by side.
    if raw_q is not None:
        raw_q_t = torch.from_numpy(raw_q).to(self.q)
    else:
        raw_q = self.q.squeeze().cpu().numpy()
        raw_q_t = self.q

    if q_ratio != 1.:
        predicted_params.scale_with_q(q_ratio)
        raw_q = raw_q * q_ratio
        raw_q_t = raw_q_t * q_ratio

    prediction_dict = {}
    prediction_dict["params"] = get_prediction_array(predicted_params)
    prediction_dict["param_names"] = get_param_labels(
        predicted_params.max_layer_num,
        thickness_name='d',
        roughness_name='sigma',
        sld_name='rho',
    )
    prediction_dict["curve_predicted"] = predicted_params.reflectivity(raw_q_t).squeeze().cpu().numpy()

    sld_x_axis, sld_profile, _ = get_density_profiles(
        predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, num=1024,
    )
    prediction_dict['sld_profile'] = sld_profile.squeeze().cpu().numpy()
    prediction_dict['sld_x_axis'] = sld_x_axis.squeeze().cpu().numpy()

    if polish:
        polished = self._polish_prediction(
            raw_q, raw_curve, predicted_params, priors, sld_x_axis,
            max_d_change=max_d_change, fit_growth=fit_growth,
        )
        prediction_dict.update(polished)

    # The extra label is only added when polished parameters are present.
    if fit_growth and "params_polished" in prediction_dict:
        prediction_dict["param_names"].append("max_d_change")

    return prediction_dict
|
|
555
|
+
|
|
556
|
+
### some shortcut methods for data processing ###
|
|
557
|
+
|
|
558
|
+
def _simple_prediction(self, scaled_curve, scaled_bounds) -> UniformSubPriorParams:
    """Run a single forward pass of the model and restore its output.

    Args:
        scaled_curve: Already scaled reflectivity curve(s).
        scaled_bounds: Already scaled prior bounds; concatenated to the
            curve along the last dimension to form the model input.

    Returns:
        The parameters produced by ``_restore_predicted_params`` from the
        raw network output and the model input.
    """
    model_input = torch.cat((scaled_curve, scaled_bounds), dim=-1)

    # Inference only: evaluation mode and no autograd bookkeeping.
    with torch.no_grad():
        network = self.trainer.model
        network.eval()
        network_output = network(model_input)

    return self._restore_predicted_params(network_output, model_input)
|
|
567
|
+
|
|
568
|
+
@print_time
def _qshift_prediction(self, curve, scaled_bounds, num: int = 1000, dq_coef: float = 1.) -> UniformSubPriorParams:
    """Predict from many q-shifted copies of the curve and keep the best one.

    The curve is re-interpolated on ``num`` grids shifted by values spread
    evenly over ``[-dq_max, dq_max]``, where ``dq_max`` is the spacing
    between the first two q points multiplied by ``dq_coef``. Every shifted
    curve is passed through the model, and the candidate chosen by
    ``get_best_mse_param`` against the unshifted curve is returned.

    Args:
        curve: Measured curve; converted with ``to_t`` and moved to the
            device/dtype of the q axis.
        scaled_bounds: Scaled prior bounds, repeated for every shifted curve.
        num: Number of q shifts to evaluate.
        dq_coef: Multiplier applied to the first q spacing to get the
            largest shift.

    Returns:
        The best candidate as selected by ``get_best_mse_param``.
    """
    q_axis = self.q.squeeze().float()
    curve = to_t(curve).to(q_axis)

    max_shift = (q_axis[1] - q_axis[0]) * dq_coef
    shifts = torch.linspace(-max_shift, max_shift, num).to(q_axis)
    shifted = _qshift_interp(q_axis, curve, shifts)

    # Sanity check: one interpolated curve per shift, on the original grid.
    assert shifted.shape == (num, q_axis.shape[0])

    scaled = self.trainer.loader.curves_scaler.scale(shifted)
    # The same bounds are used for every shifted curve.
    tiled_bounds = torch.atleast_2d(scaled_bounds).expand(scaled.shape[0], -1)
    model_input = torch.cat([scaled, tiled_bounds], -1)

    with torch.no_grad():
        self.trainer.model.eval()
        network_output = self.trainer.model(model_input)
        candidates = self._restore_predicted_params(network_output, model_input)

    return get_best_mse_param(candidates, self._get_likelihood(curve))
|
|
591
|
+
|
|
592
|
+
@print_time
def _polish_prediction(self,
                       q: np.ndarray,
                       curve: np.ndarray,
                       predicted_params: Params,
                       priors: np.ndarray,
                       sld_x_axis,
                       fit_growth: bool = True,
                       max_d_change: float = 5.,
                       ) -> dict:
    """Refine the predicted parameters with a conventional fit.

    The prediction is flattened with ``get_prediction_array`` and used as
    the starting point of ``get_fit_with_growth`` (when ``fit_growth`` is
    true) or ``standard_refl_fit``. If the fit raises, the error is logged
    and the unpolished prediction is returned instead, together with an
    all-zero ``curve_polished``.

    Args:
        q: q values of the curve being fitted.
        curve: Reflectivity values to fit.
        predicted_params: Prediction used as the starting point.
        priors: Parameter bounds; passed to the fit transposed.
        sld_x_axis: z axis on which the polished SLD profile is evaluated.
        fit_growth: Use the fit that has one extra trailing parameter
            (presumably a thickness-growth term -- confirm against
            ``get_fit_with_growth``).
        max_d_change: Forwarded to ``get_fit_with_growth``.

    Returns:
        A dict with the keys ``params_polished``, ``curve_polished`` and
        ``sld_profile_polished``. With ``fit_growth`` the parameter array
        carries one extra trailing entry on success and on failure alike.
    """
    # Same flat layout that the fallback branch and the callers use.
    params = get_prediction_array(predicted_params)

    try:
        if fit_growth:
            polished_params_arr, curve_polished = get_fit_with_growth(
                q, curve, params, bounds=priors.T,
                max_d_change=max_d_change,
            )
            # The trailing entry is the extra fit parameter, not a layer parameter.
            layer_params_arr = polished_params_arr[:-1]
        else:
            polished_params_arr, curve_polished = standard_refl_fit(q, curve, params, bounds=priors.T)
            layer_params_arr = polished_params_arr
        polished_params = Params.from_tensor(torch.from_numpy(layer_params_arr[None]).to(self.q))
    except Exception as err:
        # Best effort: keep the network prediction when the fit fails.
        self.log.exception(err)
        polished_params = predicted_params
        polished_params_arr = get_prediction_array(polished_params)
        if fit_growth:
            # Keep the array length identical to the successful case so it
            # still lines up with the extra name the caller appends; zero
            # stands for "no growth fitted".
            polished_params_arr = np.append(polished_params_arr, polished_params_arr.dtype.type(0))
        curve_polished = np.zeros_like(q)

    _, sld_profile_polished, _ = get_density_profiles(
        polished_params.thicknesses, polished_params.roughnesses, polished_params.slds, z_axis=sld_x_axis,
    )

    return {
        'params_polished': polished_params_arr,
        'curve_polished': curve_polished,
        'sld_profile_polished': sld_profile_polished.squeeze().cpu().numpy(),
    }
|
|
636
|
+
|
|
637
|
+
def _restore_predicted_params(self, scaled_params: Tensor, context: Tensor) -> UniformSubPriorParams:
    """Convert raw network output back into restored parameters.

    The sampler's ``PARAM_CLS.restore_params_from_context`` first rebuilds
    the parameters from the network output and the model input; the
    sampler's ``restore_params`` is then applied to that result.

    Args:
        scaled_params: Output of the model.
        context: The input the model was evaluated on.

    Returns:
        The result of the sampler's ``restore_params``.
    """
    sampler = self.trainer.loader.prior_sampler
    params_from_context = sampler.PARAM_CLS.restore_params_from_context(scaled_params, context)
    return sampler.restore_params(params_from_context)
|
|
642
|
+
|
|
643
|
+
def _input2context(self, curve: np.ndarray, priors: np.ndarray, q_ratio: float = 1.):
    """Build the model input from a curve and its prior bounds.

    Args:
        curve: Curve to be scaled with ``_scale_curve``.
        priors: Prior bounds to be scaled with ``_scale_priors``.
        q_ratio: Forwarded to ``_scale_priors``.

    Returns:
        A tuple ``(scaled_input, min_bounds, max_bounds)`` where
        ``scaled_input`` is the scaled curve and the scaled bounds joined
        along the last dimension.
    """
    curve_part = self._scale_curve(curve)
    bounds_part, lower, upper = self._scale_priors(priors, q_ratio)
    model_input = torch.cat((curve_part, bounds_part), dim=-1)
    return model_input, lower, upper
|
|
648
|
+
|
|
649
|
+
def _scale_curve(self, curve: np.ndarray or Tensor):
    """Scale a curve with the loader's curve scaler.

    Args:
        curve: A tensor, or a NumPy array that is converted to a float
            tensor first.

    Returns:
        The scaled curve as a float tensor. The input is made at least
        2-D and moved to the device/dtype of ``self.q`` before scaling.
    """
    if isinstance(curve, Tensor):
        curve_t = curve
    else:
        curve_t = torch.from_numpy(curve).float()

    curve_t = torch.atleast_2d(curve_t).to(self.q)
    return self.trainer.loader.curves_scaler.scale(curve_t).float()
|
|
655
|
+
|
|
656
|
+
def _scale_priors(self, priors: np.ndarray or Tensor, q_ratio: float = 1.):
    """Scale prior bounds into the form expected by the model.

    The bounds are copied, transposed, rescaled with ``1 / q_ratio`` and
    clamped by the prior sampler, then split into minimum and maximum
    bounds which are scaled separately and joined along the last dimension.

    Args:
        priors: Bounds as a tensor or NumPy array; after the transpose the
            first dimension must have exactly two entries (min, max).
        q_ratio: Ratio whose inverse is passed to ``scale_bounds_with_q``.

    Returns:
        A tuple ``(scaled_bounds, min_bounds, max_bounds)``;
        ``scaled_bounds`` is a float tensor.
    """
    bounds = priors if isinstance(priors, Tensor) else torch.from_numpy(priors)

    # Work on a copy so the caller's data is never touched.
    bounds = bounds.float().clone()
    bounds = bounds.to(self.q).T

    sampler = self._prior_sampler
    bounds = sampler.scale_bounds_with_q(bounds, 1 / q_ratio)
    bounds = sampler.clamp_bounds(bounds)

    # Unpacking over the first dimension yields the two rows, each keeping
    # a leading axis of size one.
    min_bounds, max_bounds = bounds[:, None].to(self.q)

    scaled_min = sampler.scale_bounds(min_bounds)
    scaled_max = sampler.scale_bounds(max_bounds)
    scaled_bounds = torch.cat((scaled_min, scaled_max), dim=-1)
    return scaled_bounds.float(), min_bounds, max_bounds
|
|
672
|
+
|
|
673
|
+
@property
def _prior_sampler(self) -> ExpUniformSubPriorSampler:
    """Shortcut to the prior sampler held by the trainer's loader."""
    loader = self.trainer.loader
    return loader.prior_sampler
|
|
676
|
+
|
|
677
|
+
def _set_trainer(self, trainer, preprocessing_parameters: dict = None):
    """Attach a trainer, switch its model to eval mode and refresh preprocessing.

    Args:
        trainer: Trainer object providing ``model`` and ``loader``.
        preprocessing_parameters: Optional keyword arguments forwarded to
            ``_update_preprocessing``.
    """
    self.trainer = trainer
    network = self.trainer.model
    network.eval()
    self._update_preprocessing(preprocessing_parameters)
|
|
681
|
+
|
|
682
|
+
def _update_preprocessing(self, preprocessing_parameters: dict = None):
    """Refresh ``self.q`` and rebuild the preprocessing object.

    ``self.q`` is taken from the loader's q generator, and a new
    ``StandardPreprocessing`` is created from that q axis (as a squeezed
    NumPy array on the CPU) plus the optional keyword arguments.

    Args:
        preprocessing_parameters: Keyword arguments for
            ``StandardPreprocessing``; ``None`` means no extra arguments.
    """
    self.log.debug(f"setting preprocessing_parameters {preprocessing_parameters}.")

    self.q = self.trainer.loader.q_generator.q
    q_numpy = self.q.cpu().squeeze().numpy()
    extra_kwargs = preprocessing_parameters or {}
    self.preprocessing = StandardPreprocessing(q_numpy, **extra_kwargs)

    self.log.info(f"preprocessing params are set: {preprocessing_parameters}.")
|
|
690
|
+
|
|
691
|
+
@print_time
def _sampler_solution(
        self,
        curve: Tensor or np.ndarray,
        predicted_params: UniformSubPriorParams,
) -> UniformSubPriorParams:
    """Refine a prediction with ``simple_sampler_solution``.

    Args:
        curve: Measured curve as a tensor, or a NumPy array that is
            converted to a float tensor; moved to the device/dtype of
            ``self.q``.
        predicted_params: Prediction to be refined.

    Returns:
        The refined parameters returned by ``simple_sampler_solution``,
        called with the likelihood of ``curve``, the sampler's min/max
        bounds, ``self._sampling_num`` samples and ``coef=0.1``.
    """
    if isinstance(curve, Tensor):
        curve_t = curve
    else:
        curve_t = torch.from_numpy(curve).float()
    curve_t = curve_t.to(self.q)

    likelihood = self._get_likelihood(curve_t)
    return simple_sampler_solution(
        likelihood,
        predicted_params,
        self._prior_sampler.min_bounds,
        self._prior_sampler.max_bounds,
        num=self._sampling_num, coef=0.1,
    )
|
|
711
|
+
|
|
712
|
+
def _get_likelihood(self, curve, rel_err: float = 0.1, abs_err: float = 1e-12):
    """Create a ``LogLikelihood`` for ``curve`` on the model's q axis.

    Args:
        curve: Curve the likelihood is built for.
        rel_err: Factor multiplied with the curve.
        abs_err: Constant added on top of the relative term.

    Returns:
        ``LogLikelihood`` built from ``self.q``, the curve, the prior
        sampler and ``curve * rel_err + abs_err`` (presumably the
        per-point uncertainty -- confirm against ``LogLikelihood``).
    """
    error_term = curve * rel_err + abs_err
    return LogLikelihood(self.q, curve, self._prior_sampler, error_term)
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
def get_prediction_array(params: BasicParams) -> np.ndarray:
    """Flatten a parameter object into one NumPy array.

    The thicknesses, roughnesses and SLDs are squeezed and concatenated in
    that order along the first dimension.

    A group holding a single value (for example one thickness) becomes a
    0-d tensor after ``squeeze()``, which ``torch.cat`` refuses to
    concatenate; each group is therefore promoted back to at least 1-D.
    Groups with more than one value are left exactly as before.

    Args:
        params: Object exposing ``thicknesses``, ``roughnesses`` and
            ``slds`` tensors.

    Returns:
        The concatenated values as a NumPy array on the CPU.
    """
    groups = (params.thicknesses, params.roughnesses, params.slds)
    flat_groups = [torch.atleast_1d(group.squeeze()) for group in groups]
    return torch.cat(flat_groups).cpu().numpy()
|
|
726
|
+
|
|
727
|
+
|
|
728
|
+
def _qshift_interp(q, r, q_shifts):
|
|
729
|
+
qs = q[None] + q_shifts[:, None]
|
|
730
|
+
eps = torch.finfo(r.dtype).eps
|
|
731
|
+
ind = torch.searchsorted(q[None].expand_as(qs).contiguous(), qs.contiguous())
|
|
732
|
+
ind = torch.clamp(ind - 1, 0, q.shape[0] - 2)
|
|
733
|
+
slopes = (r[1:] - r[:-1]) / (eps + (q[1:] - q[:-1]))
|
|
734
|
+
return r[ind] + slopes[ind] * (qs - q[ind])
|