reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of reflectorch might be problematic.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -128
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -280
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -223
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -1374
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +36 -36
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +523 -516
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -19
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -262
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -200
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -15
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -19
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -434
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -404
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +97 -97
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.4.0.dist-info/RECORD +0 -88
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/inference/multilayer_inference_model.py
@@ -1,193 +1,193 @@
All 193 lines of the file are removed and re-added with identical content:

from typing import Tuple

import torch
from torch import Tensor
import numpy as np

from reflectorch.inference.inference_model import (
    InferenceModel,
)
from reflectorch.data_generation.reflectivity import kinematical_approximation_np, abeles_np

from reflectorch.data_generation.priors import (
    MultilayerStructureParams,
    SimpleMultilayerSampler,
)
from reflectorch.inference.record_time import print_time
from reflectorch.inference.scipy_fitter import standard_refl_fit
from reflectorch.inference.multilayer_fitter import MultilayerFit


class MultilayerInferenceModel(InferenceModel):
    def predict(self,
                intensity: np.ndarray,
                scattering_angle: np.ndarray,
                attenuation: np.ndarray,
                priors: np.ndarray = None,
                preprocessing_parameters: dict = None,
                polish: bool = True,
                use_raw_q: bool = False,
                **kwargs
                ) -> dict:

        with print_time("everything"):
            with print_time("preprocess"):
                preprocessed_dict = self.preprocess(
                    intensity, scattering_angle, attenuation, **(preprocessing_parameters or {})
                )

            preprocessed_curve = preprocessed_dict["curve_interp"]

            raw_curve, raw_q = preprocessed_dict["curve"], preprocessed_dict["q_values"]

            with print_time("predict_from_preprocessed_curve"):
                preprocessed_dict.update(self.predict_from_preprocessed_curve(
                    preprocessed_curve, priors,
                    raw_curve=(raw_curve if use_raw_q else None),
                    raw_q=raw_q,
                    polish=polish,
                    use_raw_q=use_raw_q,
                    **kwargs
                ))

        return preprocessed_dict

    def predict_from_preprocessed_curve(self,
                                        curve: np.ndarray,
                                        priors: np.ndarray = None, *,  # ignore the priors so far
                                        polish: bool = True,
                                        raw_curve: np.ndarray = None,
                                        raw_q: np.ndarray = None,
                                        clip_prediction: bool = True,
                                        use_raw_q: bool = False,
                                        use_sampler: bool = False,
                                        fitted_time_limit: float = 3.,
                                        sampler_rel_bounds: float = 0.3,
                                        polish_with_abeles: bool = False,
                                        **kwargs
                                        ) -> dict:

        scaled_curve = self._scale_curve(curve)

        predicted_params, parametrized = self._simple_prediction(scaled_curve)

        if use_sampler:
            parametrized: Tensor = self._sampler_solution(
                curve, parametrized,
                time_limit=fitted_time_limit,
                rel_bounds=sampler_rel_bounds,
            )

        init_raw_q = raw_q

        if raw_curve is None:
            raw_curve = curve
            raw_q = self.q.squeeze().cpu().numpy()
            raw_q_t = self.q
        else:
            raw_q_t = torch.from_numpy(raw_q).to(self.q)

        # if q_ratio != 1.:
        #     predicted_params.scale_with_q(q_ratio)
        #     raw_q = raw_q * q_ratio
        #     raw_q_t = raw_q_t * q_ratio

        prediction_dict = {
            "params": parametrized.squeeze().cpu().numpy(),
            "param_names": list(self._prior_sampler.multilayer_model.PARAMETER_NAMES),
            "curve_predicted": predicted_params.reflectivity(raw_q_t).squeeze().cpu().numpy()
        }

        # sld_x_axis, sld_profile, _ = get_density_profiles(
        #     predicted_params.thicknesses, predicted_params.roughnesses, predicted_params.slds, num=1024,
        # )
        #
        # prediction_dict['sld_profile'] = sld_profile.squeeze().cpu().numpy()
        # prediction_dict['sld_x_axis'] = sld_x_axis.squeeze().cpu().numpy()

        if polish:
            prediction_dict.update(self._polish_prediction(
                raw_q, raw_curve, parametrized, q_values=init_raw_q, use_kinematical=True
            ))
        if polish_with_abeles:
            prediction_dict.update(self._polish_prediction(
                raw_q, raw_curve, parametrized, q_values=init_raw_q, use_kinematical=False
            ))

        return prediction_dict

    def _simple_prediction(self, scaled_curve) -> Tuple[MultilayerStructureParams, Tensor]:
        with torch.no_grad():
            self.trainer.model.eval()
            scaled_params = self.trainer.model(scaled_curve)

        predicted_params, parametrized = self._restore_predicted_params(scaled_params)
        return predicted_params, parametrized

    def _restore_predicted_params(self, scaled_params: Tensor) -> Tuple[MultilayerStructureParams, Tensor]:
        parametrized = self._prior_sampler.restore_params2parametrized(scaled_params)
        predicted_params: MultilayerStructureParams = self._prior_sampler.restore_params(scaled_params)
        return predicted_params, parametrized

    @print_time
    def _sampler_solution(
            self,
            curve: Tensor or np.ndarray,
            predicted_params: Tensor,
            batch_size: int = 2 ** 13,
            time_limit: float = 3.,
            rel_bounds: float = 0.3,
    ) -> Tensor:

        fit_obj = MultilayerFit.from_prediction(
            predicted_params, self._prior_sampler, self.q, torch.as_tensor(curve).to(self.q),
            batch_size=batch_size, rel_bounds=rel_bounds,
        )

        fit_obj.run_fixed_time(time_limit)

        best_params = fit_obj.get_best_solution()

        return best_params

    @property
    def _prior_sampler(self) -> SimpleMultilayerSampler:
        return self.trainer.loader.prior_sampler

    @print_time
    def _polish_prediction(self,
                           q: np.ndarray,
                           curve: np.ndarray,
                           predicted_params: Tensor,
                           q_values: np.ndarray,
                           use_kinematical: bool = True,
                           ) -> dict:

        params = predicted_params.squeeze().cpu().numpy()
        polished_params_dict = {}

        if use_kinematical:
            refl_generator = kinematical_approximation_np
        else:
            refl_generator = abeles_np

        try:
            polished_params_arr, curve_polished = standard_refl_fit(
                q, curve, params, restore_params_func=self._prior_sampler.restore_np_params,
                refl_generator=refl_generator,
                bounds=self._prior_sampler.get_np_bounds(),
            )
            params = self._prior_sampler.restore_np_params(polished_params_arr)
            if q_values is None:
                q_values = q
            curve_polished = abeles_np(q_values, **params)

        except Exception as err:
            self.log.exception(err)
            polished_params_arr = params
            curve_polished = np.zeros_like(q)

        polished_params_dict['params_polished'] = polished_params_arr
        polished_params_dict['curve_polished'] = curve_polished

        return polished_params_dict
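For context, a minimal usage sketch of MultilayerInferenceModel.predict as shown in the diff above. It assumes a model instance has already been constructed with a trained network; load_trained_multilayer_model and the data file names are illustrative placeholders, not part of the reflectorch API shown in this diff.

import numpy as np

# Illustrative placeholder: how the trainer, weights, and q grid are loaded
# is not part of this diff, so this loader is assumed.
model = load_trained_multilayer_model()  # -> MultilayerInferenceModel

# Measured reflectometry data as 1D NumPy arrays (file names are illustrative).
intensity = np.loadtxt("intensity.dat")
scattering_angle = np.loadtxt("scattering_angle.dat")
attenuation = np.ones_like(intensity)  # attenuation factor per point

result = model.predict(
    intensity,
    scattering_angle,
    attenuation,
    polish=True,       # refine the network prediction with the scipy-based fit
    use_raw_q=False,   # evaluate the predicted curve on the model's own q grid
)

print(result["param_names"])      # names of the multilayer parameters
print(result["params"])           # parameters predicted by the network
print(result["curve_predicted"])  # reflectivity curve computed from the prediction
print(result["params_polished"])  # parameters after the least-squares polish step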