reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic; see the package's registry page for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,369 +1,369 @@
|
|
|
1
|
-
from typing import Tuple, Dict, Type, List
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
from torch import Tensor
|
|
5
|
-
|
|
6
|
-
from reflectorch.data_generation.priors.base import PriorSampler
|
|
7
|
-
from reflectorch.data_generation.priors.params import AbstractParams
|
|
8
|
-
from reflectorch.data_generation.priors.no_constraints import (
|
|
9
|
-
DEFAULT_DEVICE,
|
|
10
|
-
DEFAULT_DTYPE,
|
|
11
|
-
)
|
|
12
|
-
|
|
13
|
-
from reflectorch.data_generation.priors.parametric_models import (
|
|
14
|
-
MULTILAYER_MODELS,
|
|
15
|
-
NuisanceParamsWrapper,
|
|
16
|
-
ParametricModel,
|
|
17
|
-
)
|
|
18
|
-
from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class BasicParams(AbstractParams):
    """Parameter class compatible with different parameterizations of the SLD profile.

    It stores the parameters as well as their minimum and maximum subprior bounds.

    Args:
        parameters (Tensor): the values of the thin film parameters
        min_bounds (Tensor): the minimum subprior bounds of the parameters
        max_bounds (Tensor): the maximum subprior bounds of the parameters
        max_num_layers (int, optional): the maximum number of layers (for box model parameterizations
            it is the number of layers). Defaults to None, in which case the class attribute
            ``MAX_NUM_LAYERS`` is used.
        param_model (ParametricModel, optional): the parametric model. Defaults to
            ``PARAM_MODEL_CLS(max_num_layers)``. ``PARAM_MODEL_CLS`` has no value on this class
            itself; it is assigned by ``SubpriorParametricSampler.__init__``, so a sampler must have
            been constructed before relying on this default.
    """

    __slots__ = (
        'parameters',
        'min_bounds',
        'max_bounds',
        'max_num_layers',
        'param_model',
    )
    PARAM_NAMES = __slots__
    # Class-level configuration shared by all instances; SubpriorParametricSampler.__init__
    # overwrites both attributes on its PARAM_CLS.
    PARAM_MODEL_CLS: Type[ParametricModel]
    MAX_NUM_LAYERS: int = 30

    def __init__(self,
                 parameters: Tensor,
                 min_bounds: Tensor,
                 max_bounds: Tensor,
                 max_num_layers: int = None,
                 param_model: ParametricModel = None,
                 ):
        # A falsy max_num_layers (None, or 0) falls back to the class-level default.
        max_num_layers = max_num_layers or self.MAX_NUM_LAYERS
        self.param_model = param_model or self.PARAM_MODEL_CLS(max_num_layers)
        self.max_num_layers = max_num_layers
        self.parameters = parameters
        self.min_bounds = min_bounds
        self.max_bounds = max_bounds

    @staticmethod
    def _num_params_from_width(width: int) -> int:
        """Returns ``width // 3`` after checking that ``width`` splits into three equal parts.

        Tensors handled by this class concatenate ``[parameters, min_bounds, max_bounds]`` along the
        last dimension, so that dimension must be a multiple of 3.

        Raises:
            ValueError: if ``width`` is not a multiple of 3.
        """
        if width % 3:
            raise ValueError(
                f'Expected the last dimension to hold parameters, min bounds and max bounds '
                f'of equal size (a multiple of 3), got {width}.'
            )
        return width // 3

    def get_param_labels(self, **kwargs) -> List[str]:
        """gets the parameter labels (delegates to ``param_model.get_param_labels``)"""
        return self.param_model.get_param_labels(**kwargs)

    def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
        r"""computes the reflectivity curves directly from the parameters

        Args:
            q (Tensor): the q values
            log (bool, optional): whether to apply logarithm to the curves. Defaults to False.
            **kwargs: forwarded to ``param_model.reflectivity``

        Returns:
            Tensor: the simulated reflectivity curves
        """
        return self.param_model.reflectivity(q, self.parameters, log=log, **kwargs)

    @property
    def max_layer_num(self) -> int:  # keep for back compatibility but TODO: unify api among different params
        """gets the maximum number of layers"""
        return self.max_num_layers

    @property
    def num_params(self) -> int:
        """get the number of parameters (parameter dimensionality)"""
        return self.param_model.param_dim

    @property
    def thicknesses(self):
        """gets the thicknesses (the 'thickness' entry of ``param_model.to_standard_params``)"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['thickness']

    @property
    def roughnesses(self):
        """gets the roughnesses (the 'roughness' entry of ``param_model.to_standard_params``)"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['roughness']

    @property
    def slds(self):
        """gets the slds (the 'sld' entry of ``param_model.to_standard_params``)"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['sld']

    @property
    def real_slds(self):
        """gets the real part of the slds"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['sld'].real

    @property
    def imag_slds(self):
        """gets the imaginary part of the slds (only for complex dtypes)"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['sld'].imag

    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor,
            context: Tensor,
            inference: bool = False,
            from_params: bool = False,
    ):
        """Moves the scaled subprior bounds out of ``scaled_params`` and appends them to ``context``.

        Args:
            scaled_params (Tensor): during training, the concatenation
                ``[params, min_bounds, max_bounds]`` along the last dimension. During inference it is
                the same concatenation when ``from_params`` is True; otherwise it is appended to the
                context as a whole (presumably it then holds only the scaled bounds - confirm with callers).
            context (Tensor): the context tensor, extended along its last dimension
            inference (bool, optional): whether the call happens at inference time. Defaults to False.
            from_params (bool, optional): only used when ``inference`` is True; whether the leading
                parameter values must be dropped from ``scaled_params`` first. Defaults to False.

        Returns:
            At inference, the extended context (Tensor). Otherwise, the tuple
            ``(scaled parameter values, extended context)``.

        Raises:
            ValueError: if ``scaled_params`` must be split into three parts but its last dimension
                is not a multiple of 3.
        """
        if inference:
            if from_params:
                num_params = BasicParams._num_params_from_width(scaled_params.shape[-1])
                scaled_params = scaled_params[..., num_params:]
            context = torch.cat([context, scaled_params], dim=-1)
            return context

        num_params = BasicParams._num_params_from_width(scaled_params.shape[-1])
        scaled_params, bound_context = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
        context = torch.cat([context, bound_context], dim=-1)
        return scaled_params, context

    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        """Re-attaches the scaled subprior bounds stored at the end of ``context`` to ``scaled_params``.

        This undoes the training-time path of ``rearrange_context_from_params``, which appends the
        ``2 * num_params`` bound entries at the end of the context.

        Args:
            scaled_params (Tensor): the scaled parameter values (without bounds)
            context (Tensor): the context whose trailing ``2 * num_params`` entries are the scaled bounds

        Returns:
            Tensor: the concatenation ``[scaled_params, scaled_min_bounds, scaled_max_bounds]``
        """
        num_params = scaled_params.shape[-1]
        # Use an explicit non-negative start index: ``context[..., -0:]`` would return the whole
        # context (instead of nothing) when num_params == 0.
        start = max(context.shape[-1] - 2 * num_params, 0)
        scaled_bounds = context[..., start:]
        scaled_params = torch.cat([scaled_params, scaled_bounds], dim=-1)
        return scaled_params

    def as_tensor(self, add_bounds: bool = True, **kwargs) -> Tensor:
        """converts the instance of the class to a Pytorch tensor

        Args:
            add_bounds (bool, optional): whether to add the subprior bounds to the tensor. Defaults to True.
            **kwargs: accepted for API compatibility and ignored.

        Returns:
            Tensor: ``parameters`` alone, or ``[parameters, min_bounds, max_bounds]`` concatenated
            along the last dimension
        """
        if not add_bounds:
            return self.parameters
        return torch.cat([self.parameters, self.min_bounds, self.max_bounds], -1)

    @classmethod
    def from_tensor(cls, params: Tensor, **kwargs):
        """initializes an instance of the class from a Pytorch tensor

        Args:
            params (Tensor): Pytorch tensor containing the parameter values, min subprior bounds and
                max subprior bounds concatenated along the last dimension
            **kwargs: forwarded to the constructor (e.g. ``max_num_layers``, ``param_model``)

        Returns:
            BasicParams: the instance of the class

        Raises:
            ValueError: if the last dimension of ``params`` is not a multiple of 3.
        """
        num_params = cls._num_params_from_width(params.shape[-1])

        params, min_bounds, max_bounds = torch.split(
            params, [num_params, num_params, num_params], dim=-1
        )

        return cls(
            params,
            min_bounds,
            max_bounds,
            **kwargs
        )

    def scale_with_q(self, q_ratio: float):
        """scales the parameters and both subprior bounds in place based on the q ratio

        Args:
            q_ratio (float): the scaling ratio (applied through ``param_model.scale_with_q``)
        """
        self.parameters = self.param_model.scale_with_q(self.parameters, q_ratio)
        self.min_bounds = self.param_model.scale_with_q(self.min_bounds, q_ratio)
        self.max_bounds = self.param_model.scale_with_q(self.max_bounds, q_ratio)
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
class SubpriorParametricSampler(PriorSampler, ScalerMixin):
    """Prior sampler for the parameters of a parametric model together with their subprior bounds."""

    PARAM_CLS = BasicParams

    def __init__(self,
                 param_ranges: Dict[str, Tuple[float, float]],
                 bound_width_ranges: Dict[str, Tuple[float, float]],
                 model_name: str,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 max_num_layers: int = 50,
                 logdist: bool = False,
                 scale_params_by_ranges: bool = False,
                 scaled_range: Tuple[float, float] = (-1., 1.),
                 **kwargs
                 ):
        """Prior sampler for the parameters of a parametric model and their subprior bounds

        Args:
            param_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with its range
            bound_width_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with the range for sampling the widths of the subprior interval
            model_name (str): the name of the parametric model (a key of ``MULTILAYER_MODELS``)
            device (torch.device, optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
            dtype (torch.dtype, optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
            max_num_layers (int, optional): the maximum number of layers (for box model parameterizations it is the number of layers). Defaults to 50.
            logdist (bool, optional): if True the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
            scale_params_by_ranges (bool, optional): if True the parameters are scaled with respect to their ranges instead of being scaled with respect to their prior bounds. Defaults to False.
            scaled_range (Tuple[float, float], optional): the range for scaling the parameters. Defaults to (-1., 1.)
            **kwargs: the optional ``shift_param_config`` entry configures nuisance parameters; the
                remaining entries are forwarded to the parametric model (and to the nuisance wrapper).

        Note:
            The constructor assigns ``PARAM_CLS.PARAM_MODEL_CLS`` and ``PARAM_CLS.MAX_NUM_LAYERS`` as
            class attributes. They are therefore shared by every instance of ``PARAM_CLS`` and
            overwritten by the most recently constructed sampler.
        """
        self.scaled_range = scaled_range

        # Removed from kwargs before they are forwarded to the base parametric model.
        self.shift_param_config = kwargs.pop('shift_param_config', {})

        base_model: ParametricModel = MULTILAYER_MODELS[model_name](max_num_layers, logdist=logdist, **kwargs)
        if any(self.shift_param_config.values()):
            self.param_model = NuisanceParamsWrapper(
                base_model=base_model,
                nuisance_params_config=self.shift_param_config,
                **kwargs,
            )
        else:
            self.param_model = base_model

        self.device = device
        self.dtype = dtype
        self.num_layers = max_num_layers

        # Class-level (global) side effect, see the Note in the docstring.
        # NOTE(review): with nuisance parameters enabled, PARAM_MODEL_CLS is still the *base* model
        # class, not the wrapper - confirm that PARAM_CLS instances created without an explicit
        # ``param_model`` are not expected to include the nuisance parameters.
        self.PARAM_CLS.PARAM_MODEL_CLS = MULTILAYER_MODELS[model_name]
        self.PARAM_CLS.MAX_NUM_LAYERS = max_num_layers

        self._param_dim = self.param_model.param_dim
        # min_bounds / max_bounds are the global parameter ranges; min_delta / max_delta bound the
        # widths of the sampled subprior intervals (as produced by param_model.init_bounds).
        self.min_bounds, self.max_bounds, self.min_delta, self.max_delta = self.param_model.init_bounds(
            param_ranges, bound_width_ranges, device=device, dtype=dtype
        )

        self.param_ranges = param_ranges
        self.bound_width_ranges = bound_width_ranges
        self.model_name = model_name
        self.logdist = logdist
        self.scale_params_by_ranges = scale_params_by_ranges

    @property
    def max_num_layers(self) -> int:
        """gets the maximum number of layers"""
        return self.num_layers

    @property
    def param_dim(self) -> int:
        """get the number of parameters (parameter dimensionality)"""
        return self._param_dim

    def _make_params(self, parameters: Tensor, min_bounds: Tensor, max_bounds: Tensor) -> BasicParams:
        """Wraps raw tensors into a ``PARAM_CLS`` instance bound to this sampler's model."""
        return self.PARAM_CLS(
            parameters=parameters,
            min_bounds=min_bounds,
            max_bounds=max_bounds,
            max_num_layers=self.max_num_layers,
            param_model=self.param_model,
        )

    def sample(self, batch_size: int) -> BasicParams:
        """sample a batch of parameters

        Args:
            batch_size (int): the batch size

        Returns:
            BasicParams: sampled parameters (an instance of ``PARAM_CLS``)
        """
        params, min_bounds, max_bounds = self.param_model.sample(
            batch_size, self.min_bounds, self.max_bounds, self.min_delta, self.max_delta
        )
        return self._make_params(params, min_bounds, max_bounds)

    def scale_params(self, params: BasicParams) -> Tensor:
        """scale the parameters to a ML-friendly range

        The subprior bounds are always scaled with respect to the parameter ranges. The parameter
        values are scaled with respect to the parameter ranges when ``scale_params_by_ranges`` is
        True, and with respect to their own subprior bounds otherwise.

        Args:
            params (BasicParams): the parameters to be scaled

        Returns:
            Tensor: ``[scaled parameters, scaled min bounds, scaled max bounds]`` concatenated along
            the last dimension
        """
        if self.scale_params_by_ranges:
            ref_min, ref_max = self.min_bounds, self.max_bounds
        else:
            ref_min, ref_max = params.min_bounds, params.max_bounds

        return torch.cat([
            self._scale(params.parameters, ref_min, ref_max),
            self._scale(params.min_bounds, self.min_bounds, self.max_bounds),
            self._scale(params.max_bounds, self.min_bounds, self.max_bounds),
        ], -1)

    def restore_params(self, scaled_params: Tensor) -> BasicParams:
        """restore the parameters to their original range (inverse of ``scale_params``)

        Args:
            scaled_params (Tensor): the scaled parameters, i.e. scaled values, min bounds and max
                bounds concatenated along the last dimension

        Returns:
            BasicParams: the parameters restored to their original range

        Raises:
            ValueError: if the last dimension of ``scaled_params`` is not a multiple of 3.
        """
        width = scaled_params.shape[-1]
        if width % 3:
            raise ValueError(
                f'Expected the last dimension to hold parameters, min bounds and max bounds '
                f'of equal size (a multiple of 3), got {width}.'
            )
        num_params = width // 3
        scaled_values, scaled_min_bounds, scaled_max_bounds = torch.split(
            scaled_params, [num_params, num_params, num_params], dim=-1
        )

        # Bounds are always restored w.r.t. the parameter ranges (mirrors scale_params).
        min_bounds = self._restore(scaled_min_bounds, self.min_bounds, self.max_bounds)
        max_bounds = self._restore(scaled_max_bounds, self.min_bounds, self.max_bounds)
        if self.scale_params_by_ranges:
            params = self._restore(scaled_values, self.min_bounds, self.max_bounds)
        else:
            params = self._restore(scaled_values, min_bounds, max_bounds)

        return self._make_params(params, min_bounds, max_bounds)

    def scale_bounds(self, bounds: Tensor) -> Tensor:
        """scales bounds with respect to the parameter ranges"""
        return self._scale(bounds, self.min_bounds, self.max_bounds)

    def log_prob(self, params: BasicParams) -> Tensor:
        """returns 0 for samples whose parameters lie within their subprior bounds and -inf otherwise"""
        log_prob = torch.zeros(params.batch_size, device=self.device, dtype=self.dtype)
        log_prob[~self.get_indices_within_bounds(params)] = -float('inf')
        return log_prob

    def get_indices_within_domain(self, params: BasicParams) -> Tensor:
        """boolean mask of samples inside the domain (here: inside their subprior bounds)"""
        return self.get_indices_within_bounds(params)

    def get_indices_within_bounds(self, params: BasicParams) -> Tensor:
        """boolean mask of samples whose every parameter satisfies min_bound <= value <= max_bound"""
        return (
            torch.all(params.parameters >= params.min_bounds, -1) &
            torch.all(params.parameters <= params.max_bounds, -1)
        )

    def filter_params(self, params: BasicParams) -> BasicParams:
        """keeps only the samples within the domain (relies on the boolean-mask indexing of the params object)"""
        indices = self.get_indices_within_domain(params)
        return params[indices]

    def clamp_params(
            self, params: BasicParams, inplace: bool = False
    ) -> BasicParams:
        """clamps the parameter values into their subprior bounds

        Args:
            params (BasicParams): the parameters to clamp
            inplace (bool, optional): if True, ``params.parameters`` is clamped in place and ``params``
                is returned; otherwise a new instance with cloned bounds is returned. Defaults to False.

        Returns:
            BasicParams: the clamped parameters
        """
        if inplace:
            params.parameters = torch.clamp_(params.parameters, params.min_bounds, params.max_bounds)
            return params

        return self._make_params(
            torch.clamp(params.parameters, params.min_bounds, params.max_bounds),
            params.min_bounds.clone(),
            params.max_bounds.clone(),
        )
|
|
1
|
+
from typing import Tuple, Dict, Type, List
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from reflectorch.data_generation.priors.base import PriorSampler
|
|
7
|
+
from reflectorch.data_generation.priors.params import AbstractParams
|
|
8
|
+
from reflectorch.data_generation.priors.no_constraints import (
|
|
9
|
+
DEFAULT_DEVICE,
|
|
10
|
+
DEFAULT_DTYPE,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from reflectorch.data_generation.priors.parametric_models import (
|
|
14
|
+
MULTILAYER_MODELS,
|
|
15
|
+
NuisanceParamsWrapper,
|
|
16
|
+
ParametricModel,
|
|
17
|
+
)
|
|
18
|
+
from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BasicParams(AbstractParams):
    """Parameter class compatible with different parameterizations of the SLD profile.

    It stores the parameters as well as their minimum and maximum subprior bounds.

    Args:
        parameters (Tensor): the values of the thin film parameters
        min_bounds (Tensor): the minimum subprior bounds of the parameters
        max_bounds (Tensor): the maximum subprior bounds of the parameters
        max_num_layers (int, optional): the maximum number of layers (for box model parameterizations
            it is the number of layers). Defaults to None, in which case the class attribute
            ``MAX_NUM_LAYERS`` is used.
        param_model (ParametricModel, optional): the parametric model. Defaults to
            ``PARAM_MODEL_CLS(max_num_layers)``. ``PARAM_MODEL_CLS`` has no value on this class
            itself; it is assigned by ``SubpriorParametricSampler.__init__``, so a sampler must have
            been constructed before relying on this default.
    """

    __slots__ = (
        'parameters',
        'min_bounds',
        'max_bounds',
        'max_num_layers',
        'param_model',
    )
    PARAM_NAMES = __slots__
    # Class-level configuration shared by all instances; SubpriorParametricSampler.__init__
    # overwrites both attributes on its PARAM_CLS.
    PARAM_MODEL_CLS: Type[ParametricModel]
    MAX_NUM_LAYERS: int = 30

    def __init__(self,
                 parameters: Tensor,
                 min_bounds: Tensor,
                 max_bounds: Tensor,
                 max_num_layers: int = None,
                 param_model: ParametricModel = None,
                 ):
        # A falsy max_num_layers (None, or 0) falls back to the class-level default.
        max_num_layers = max_num_layers or self.MAX_NUM_LAYERS
        self.param_model = param_model or self.PARAM_MODEL_CLS(max_num_layers)
        self.max_num_layers = max_num_layers
        self.parameters = parameters
        self.min_bounds = min_bounds
        self.max_bounds = max_bounds

    @staticmethod
    def _num_params_from_width(width: int) -> int:
        """Returns ``width // 3`` after checking that ``width`` splits into three equal parts.

        Tensors handled by this class concatenate ``[parameters, min_bounds, max_bounds]`` along the
        last dimension, so that dimension must be a multiple of 3.

        Raises:
            ValueError: if ``width`` is not a multiple of 3.
        """
        if width % 3:
            raise ValueError(
                f'Expected the last dimension to hold parameters, min bounds and max bounds '
                f'of equal size (a multiple of 3), got {width}.'
            )
        return width // 3

    def get_param_labels(self, **kwargs) -> List[str]:
        """gets the parameter labels (delegates to ``param_model.get_param_labels``)"""
        return self.param_model.get_param_labels(**kwargs)

    def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
        r"""computes the reflectivity curves directly from the parameters

        Args:
            q (Tensor): the q values
            log (bool, optional): whether to apply logarithm to the curves. Defaults to False.
            **kwargs: forwarded to ``param_model.reflectivity``

        Returns:
            Tensor: the simulated reflectivity curves
        """
        return self.param_model.reflectivity(q, self.parameters, log=log, **kwargs)

    @property
    def max_layer_num(self) -> int:  # keep for back compatibility but TODO: unify api among different params
        """gets the maximum number of layers"""
        return self.max_num_layers

    @property
    def num_params(self) -> int:
        """get the number of parameters (parameter dimensionality)"""
        return self.param_model.param_dim

    @property
    def thicknesses(self):
        """gets the thicknesses (the 'thickness' entry of ``param_model.to_standard_params``)"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['thickness']

    @property
    def roughnesses(self):
        """gets the roughnesses (the 'roughness' entry of ``param_model.to_standard_params``)"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['roughness']

    @property
    def slds(self):
        """gets the slds (the 'sld' entry of ``param_model.to_standard_params``)"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['sld']

    @property
    def real_slds(self):
        """gets the real part of the slds"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['sld'].real

    @property
    def imag_slds(self):
        """gets the imaginary part of the slds (only for complex dtypes)"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['sld'].imag

    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor,
            context: Tensor,
            inference: bool = False,
            from_params: bool = False,
    ):
        """Moves the scaled subprior bounds out of ``scaled_params`` and appends them to ``context``.

        Args:
            scaled_params (Tensor): during training, the concatenation
                ``[params, min_bounds, max_bounds]`` along the last dimension. During inference it is
                the same concatenation when ``from_params`` is True; otherwise it is appended to the
                context as a whole (presumably it then holds only the scaled bounds - confirm with callers).
            context (Tensor): the context tensor, extended along its last dimension
            inference (bool, optional): whether the call happens at inference time. Defaults to False.
            from_params (bool, optional): only used when ``inference`` is True; whether the leading
                parameter values must be dropped from ``scaled_params`` first. Defaults to False.

        Returns:
            At inference, the extended context (Tensor). Otherwise, the tuple
            ``(scaled parameter values, extended context)``.

        Raises:
            ValueError: if ``scaled_params`` must be split into three parts but its last dimension
                is not a multiple of 3.
        """
        if inference:
            if from_params:
                num_params = BasicParams._num_params_from_width(scaled_params.shape[-1])
                scaled_params = scaled_params[..., num_params:]
            context = torch.cat([context, scaled_params], dim=-1)
            return context

        num_params = BasicParams._num_params_from_width(scaled_params.shape[-1])
        scaled_params, bound_context = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
        context = torch.cat([context, bound_context], dim=-1)
        return scaled_params, context

    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        """Re-attaches the scaled subprior bounds stored at the end of ``context`` to ``scaled_params``.

        This undoes the training-time path of ``rearrange_context_from_params``, which appends the
        ``2 * num_params`` bound entries at the end of the context.

        Args:
            scaled_params (Tensor): the scaled parameter values (without bounds)
            context (Tensor): the context whose trailing ``2 * num_params`` entries are the scaled bounds

        Returns:
            Tensor: the concatenation ``[scaled_params, scaled_min_bounds, scaled_max_bounds]``
        """
        num_params = scaled_params.shape[-1]
        # Use an explicit non-negative start index: ``context[..., -0:]`` would return the whole
        # context (instead of nothing) when num_params == 0.
        start = max(context.shape[-1] - 2 * num_params, 0)
        scaled_bounds = context[..., start:]
        scaled_params = torch.cat([scaled_params, scaled_bounds], dim=-1)
        return scaled_params

    def as_tensor(self, add_bounds: bool = True, **kwargs) -> Tensor:
        """converts the instance of the class to a Pytorch tensor

        Args:
            add_bounds (bool, optional): whether to add the subprior bounds to the tensor. Defaults to True.
            **kwargs: accepted for API compatibility and ignored.

        Returns:
            Tensor: ``parameters`` alone, or ``[parameters, min_bounds, max_bounds]`` concatenated
            along the last dimension
        """
        if not add_bounds:
            return self.parameters
        return torch.cat([self.parameters, self.min_bounds, self.max_bounds], -1)

    @classmethod
    def from_tensor(cls, params: Tensor, **kwargs):
        """initializes an instance of the class from a Pytorch tensor

        Args:
            params (Tensor): Pytorch tensor containing the parameter values, min subprior bounds and
                max subprior bounds concatenated along the last dimension
            **kwargs: forwarded to the constructor (e.g. ``max_num_layers``, ``param_model``)

        Returns:
            BasicParams: the instance of the class

        Raises:
            ValueError: if the last dimension of ``params`` is not a multiple of 3.
        """
        num_params = cls._num_params_from_width(params.shape[-1])

        params, min_bounds, max_bounds = torch.split(
            params, [num_params, num_params, num_params], dim=-1
        )

        return cls(
            params,
            min_bounds,
            max_bounds,
            **kwargs
        )

    def scale_with_q(self, q_ratio: float):
        """scales the parameters and both subprior bounds in place based on the q ratio

        Args:
            q_ratio (float): the scaling ratio (applied through ``param_model.scale_with_q``)
        """
        self.parameters = self.param_model.scale_with_q(self.parameters, q_ratio)
        self.min_bounds = self.param_model.scale_with_q(self.min_bounds, q_ratio)
        self.max_bounds = self.param_model.scale_with_q(self.max_bounds, q_ratio)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class SubpriorParametricSampler(PriorSampler, ScalerMixin):
|
|
189
|
+
PARAM_CLS = BasicParams
|
|
190
|
+
|
|
191
|
+
def __init__(self,
|
|
192
|
+
param_ranges: Dict[str, Tuple[float, float]],
|
|
193
|
+
bound_width_ranges: Dict[str, Tuple[float, float]],
|
|
194
|
+
model_name: str,
|
|
195
|
+
device: torch.device = DEFAULT_DEVICE,
|
|
196
|
+
dtype: torch.dtype = DEFAULT_DTYPE,
|
|
197
|
+
max_num_layers: int = 50,
|
|
198
|
+
logdist: bool = False,
|
|
199
|
+
scale_params_by_ranges = False,
|
|
200
|
+
scaled_range: Tuple[float, float] = (-1., 1.),
|
|
201
|
+
**kwargs
|
|
202
|
+
):
|
|
203
|
+
"""Prior sampler for the parameters of a parametric model and their subprior bounds
|
|
204
|
+
|
|
205
|
+
Args:
|
|
206
|
+
param_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with its range
|
|
207
|
+
bound_width_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with the range for sampling the widths of the subprior interval
|
|
208
|
+
model_name (str): the name of the parametric model
|
|
209
|
+
device (torch.device, optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
|
|
210
|
+
dtype (torch.dtype, optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
|
|
211
|
+
max_num_layers (int, optional): the maximum number of layers (for box model parameterizations it is the number of layers). Defaults to 50.
|
|
212
|
+
logdist (bool, optional): if True the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
|
|
213
|
+
scale_params_by_ranges (bool, optional): if True the parameters are scaled with respect to their ranges instead of being scaled with respect to their prior bounds. Defaults to False.
|
|
214
|
+
scaled_range (Tuple[float, float], optional): the range for scaling the parameters. Defaults to (-1., 1.)
|
|
215
|
+
"""
|
|
216
|
+
self.scaled_range = scaled_range
|
|
217
|
+
|
|
218
|
+
self.shift_param_config = kwargs.pop('shift_param_config', {})
|
|
219
|
+
|
|
220
|
+
base_model: ParametricModel = MULTILAYER_MODELS[model_name](max_num_layers, logdist=logdist, **kwargs)
|
|
221
|
+
if any(self.shift_param_config.values()):
|
|
222
|
+
self.param_model = NuisanceParamsWrapper(
|
|
223
|
+
base_model=base_model,
|
|
224
|
+
nuisance_params_config=self.shift_param_config,
|
|
225
|
+
**kwargs,
|
|
226
|
+
)
|
|
227
|
+
else:
|
|
228
|
+
self.param_model = base_model
|
|
229
|
+
|
|
230
|
+
self.device = device
|
|
231
|
+
self.dtype = dtype
|
|
232
|
+
self.num_layers = max_num_layers
|
|
233
|
+
|
|
234
|
+
self.PARAM_CLS.PARAM_MODEL_CLS = MULTILAYER_MODELS[model_name]
|
|
235
|
+
self.PARAM_CLS.MAX_NUM_LAYERS = max_num_layers
|
|
236
|
+
|
|
237
|
+
self._param_dim = self.param_model.param_dim
|
|
238
|
+
self.min_bounds, self.max_bounds, self.min_delta, self.max_delta = self.param_model.init_bounds(
|
|
239
|
+
param_ranges, bound_width_ranges, device=device, dtype=dtype
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
self.param_ranges = param_ranges
|
|
243
|
+
self.bound_width_ranges = bound_width_ranges
|
|
244
|
+
self.model_name = model_name
|
|
245
|
+
self.logdist = logdist
|
|
246
|
+
self.scale_params_by_ranges = scale_params_by_ranges
|
|
247
|
+
|
|
248
|
+
@property
def max_num_layers(self) -> int:
    """Maximum number of layers the sampler was configured with.

    Read-only accessor for the ``num_layers`` attribute.
    """
    layer_count = self.num_layers
    return layer_count
|
|
252
|
+
|
|
253
|
+
@property
def param_dim(self) -> int:
    """Number of parameters (the parameter dimensionality).

    Read-only accessor for the cached ``_param_dim`` attribute.
    """
    dimensionality = self._param_dim
    return dimensionality
|
|
257
|
+
|
|
258
|
+
def sample(self, batch_size: int) -> BasicParams:
    """Draw a batch of parameters together with their subprior bounds.

    Delegates the actual sampling to ``self.param_model.sample``, passing
    the stored prior bounds and bound-width limits, and wraps the three
    returned tensors in a ``BasicParams`` object.

    Args:
        batch_size (int): the batch size

    Returns:
        BasicParams: sampled parameters with their per-sample subprior bounds
    """
    sampled, sub_min, sub_max = self.param_model.sample(
        batch_size,
        self.min_bounds,
        self.max_bounds,
        self.min_delta,
        self.max_delta,
    )
    return BasicParams(
        parameters=sampled,
        min_bounds=sub_min,
        max_bounds=sub_max,
        max_num_layers=self.max_num_layers,
        param_model=self.param_model,
    )
|
|
280
|
+
|
|
281
|
+
def scale_params(self, params: BasicParams) -> Tensor:
    """Scale parameters and their subprior bounds to a ML-friendly range.

    The result concatenates, along the last dimension and in this order,
    the scaled parameters, the scaled minimum subprior bounds and the
    scaled maximum subprior bounds.

    The subprior bounds are always scaled relative to the prior bounds
    ``self.min_bounds`` / ``self.max_bounds``. The parameters themselves are
    scaled relative to those prior bounds when ``self.scale_params_by_ranges``
    is True, and relative to their own subprior bounds otherwise.

    Args:
        params (BasicParams): the parameters to be scaled

    Returns:
        Tensor: the scaled parameters followed by the scaled bounds
    """
    if self.scale_params_by_ranges:
        # parameters measured against the full parameter ranges
        ref_low, ref_high = self.min_bounds, self.max_bounds
    else:
        # parameters measured against their own subprior interval
        ref_low, ref_high = params.min_bounds, params.max_bounds

    pieces = [
        self._scale(params.parameters, ref_low, ref_high),
        self._scale(params.min_bounds, self.min_bounds, self.max_bounds),
        self._scale(params.max_bounds, self.min_bounds, self.max_bounds),
    ]
    return torch.cat(pieces, -1)
|
|
304
|
+
|
|
305
|
+
def restore_params(self, scaled_params: Tensor) -> BasicParams:
    """Restore scaled parameters and subprior bounds to their original range.

    ``scaled_params`` is split along its last dimension into three chunks of
    equal width: scaled parameters, scaled minimum subprior bounds and scaled
    maximum subprior bounds. The bounds are restored relative to the prior
    bounds ``self.min_bounds`` / ``self.max_bounds``. The parameters are
    restored relative to those prior bounds when
    ``self.scale_params_by_ranges`` is True, and relative to the restored
    subprior bounds otherwise.

    Args:
        scaled_params (Tensor): the scaled parameters

    Returns:
        BasicParams: the parameters restored to their original range
    """
    chunk_width = scaled_params.shape[-1] // 3
    scaled_values, scaled_low, scaled_high = torch.split(
        scaled_params, chunk_width, -1
    )

    min_bounds = self._restore(scaled_low, self.min_bounds, self.max_bounds)
    max_bounds = self._restore(scaled_high, self.min_bounds, self.max_bounds)

    if self.scale_params_by_ranges:
        restored = self._restore(scaled_values, self.min_bounds, self.max_bounds)
    else:
        # parameters were scaled against their own subprior interval
        restored = self._restore(scaled_values, min_bounds, max_bounds)

    return BasicParams(
        parameters=restored,
        min_bounds=min_bounds,
        max_bounds=max_bounds,
        max_num_layers=self.max_num_layers,
        param_model=self.param_model,
    )
|
|
334
|
+
|
|
335
|
+
def scale_bounds(self, bounds: Tensor) -> Tensor:
    """Scale ``bounds`` relative to the prior bounds.

    Args:
        bounds (Tensor): bound values to scale

    Returns:
        Tensor: ``bounds`` scaled with ``self._scale`` against
        ``self.min_bounds`` / ``self.max_bounds``
    """
    low, high = self.min_bounds, self.max_bounds
    return self._scale(bounds, low, high)
|
|
337
|
+
|
|
338
|
+
def log_prob(self, params: BasicParams) -> Tensor:
    """Per-sample log-probability indicator.

    Returns a tensor of length ``params.batch_size`` (on ``self.device`` with
    ``self.dtype``) that is 0 for samples whose parameters lie within their
    subprior bounds and ``-inf`` for all others.

    Args:
        params (BasicParams): the parameters to evaluate

    Returns:
        Tensor: 0 or ``-inf`` per sample
    """
    zeros = torch.zeros(params.batch_size, device=self.device, dtype=self.dtype)
    inside = self.get_indices_within_bounds(params)
    return zeros.masked_fill(~inside, -float('inf'))
|
|
342
|
+
|
|
343
|
+
def get_indices_within_domain(self, params: BasicParams) -> Tensor:
    """Boolean mask of samples inside the valid domain.

    For this sampler the domain is defined by the subprior bounds, so this
    simply forwards to ``self.get_indices_within_bounds``.
    """
    mask = self.get_indices_within_bounds(params)
    return mask
|
|
345
|
+
|
|
346
|
+
def get_indices_within_bounds(self, params: BasicParams) -> Tensor:
    """Boolean mask over the batch of samples lying within their subprior bounds.

    A sample is inside when every entry along the last dimension satisfies
    ``min_bounds <= parameters <= max_bounds`` (bounds are inclusive).

    Args:
        params (BasicParams): parameters with their subprior bounds

    Returns:
        Tensor: boolean tensor reduced over the last dimension
    """
    values = params.parameters
    above_min = values >= params.min_bounds
    below_max = values <= params.max_bounds
    return (above_min & below_max).all(-1)
|
|
351
|
+
|
|
352
|
+
def filter_params(self, params: BasicParams) -> BasicParams:
    """Keep only the samples that fall inside the valid domain.

    Args:
        params (BasicParams): the parameters to filter

    Returns:
        BasicParams: ``params`` indexed by the mask from
        ``self.get_indices_within_domain``
    """
    return params[self.get_indices_within_domain(params)]
|
|
355
|
+
|
|
356
|
+
def clamp_params(
    self, params: BasicParams, inplace: bool = False
) -> BasicParams:
    """Clamp each parameter into its subprior interval.

    Args:
        params (BasicParams): the parameters to clamp
        inplace (bool, optional): if True, ``params.parameters`` is clamped in
            place with ``torch.clamp_`` and the same ``params`` object is
            returned. Otherwise a new ``BasicParams`` is built with clamped
            parameters and cloned bounds. Defaults to False.

    Returns:
        BasicParams: the clamped parameters
    """
    if not inplace:
        clamped = torch.clamp(params.parameters, params.min_bounds, params.max_bounds)
        return BasicParams(
            parameters=clamped,
            min_bounds=params.min_bounds.clone(),
            max_bounds=params.max_bounds.clone(),
            max_num_layers=self.max_num_layers,
            param_model=self.param_model,
        )

    params.parameters = torch.clamp_(params.parameters, params.min_bounds, params.max_bounds)
    return params
|