reflectorch 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -0
- reflectorch/data_generation/__init__.py +128 -0
- reflectorch/data_generation/dataset.py +216 -0
- reflectorch/data_generation/likelihoods.py +80 -0
- reflectorch/data_generation/noise.py +471 -0
- reflectorch/data_generation/priors/__init__.py +60 -0
- reflectorch/data_generation/priors/base.py +55 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
- reflectorch/data_generation/priors/independent_priors.py +195 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +104 -0
- reflectorch/data_generation/priors/no_constraints.py +206 -0
- reflectorch/data_generation/priors/parametric_models.py +842 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
- reflectorch/data_generation/priors/params.py +252 -0
- reflectorch/data_generation/priors/sampler_strategies.py +370 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +371 -0
- reflectorch/data_generation/priors/utils.py +118 -0
- reflectorch/data_generation/process_data.py +41 -0
- reflectorch/data_generation/q_generator.py +280 -0
- reflectorch/data_generation/reflectivity/__init__.py +102 -0
- reflectorch/data_generation/reflectivity/abeles.py +97 -0
- reflectorch/data_generation/reflectivity/kinematical.py +71 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +138 -0
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/scale_curves.py +112 -0
- reflectorch/data_generation/smearing.py +99 -0
- reflectorch/data_generation/utils.py +223 -0
- reflectorch/extensions/__init__.py +0 -0
- reflectorch/extensions/jupyter/__init__.py +11 -0
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -0
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -0
- reflectorch/extensions/matplotlib/losses.py +32 -0
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/__init__.py +28 -0
- reflectorch/inference/inference_model.py +848 -0
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +55 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/plotting.py +524 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +19 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +364 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +32 -0
- reflectorch/ml/basic_trainer.py +292 -0
- reflectorch/ml/callbacks.py +81 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +56 -0
- reflectorch/ml/schedulers.py +356 -0
- reflectorch/ml/trainers.py +201 -0
- reflectorch/ml/utils.py +2 -0
- reflectorch/models/__init__.py +16 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +19 -0
- reflectorch/models/encoders/conv_encoder.py +219 -0
- reflectorch/models/encoders/conv_res_net.py +115 -0
- reflectorch/models/encoders/fno.py +134 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -0
- reflectorch/models/networks/mlp_networks.py +434 -0
- reflectorch/models/networks/residual_net.py +157 -0
- reflectorch/paths.py +29 -0
- reflectorch/runs/__init__.py +31 -0
- reflectorch/runs/config.py +25 -0
- reflectorch/runs/slurm_utils.py +93 -0
- reflectorch/runs/train.py +78 -0
- reflectorch/runs/utils.py +405 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +98 -0
- reflectorch-1.5.1.dist-info/METADATA +151 -0
- reflectorch-1.5.1.dist-info/RECORD +96 -0
- reflectorch-1.5.1.dist-info/WHEEL +5 -0
- reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
- reflectorch-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,369 @@
|
|
|
1
|
+
from typing import Tuple, Dict, Type, List
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from reflectorch.data_generation.priors.base import PriorSampler
|
|
7
|
+
from reflectorch.data_generation.priors.params import AbstractParams
|
|
8
|
+
from reflectorch.data_generation.priors.no_constraints import (
|
|
9
|
+
DEFAULT_DEVICE,
|
|
10
|
+
DEFAULT_DTYPE,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from reflectorch.data_generation.priors.parametric_models import (
|
|
14
|
+
MULTILAYER_MODELS,
|
|
15
|
+
NuisanceParamsWrapper,
|
|
16
|
+
ParametricModel,
|
|
17
|
+
)
|
|
18
|
+
from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BasicParams(AbstractParams):
    """Parameter class compatible with different parameterizations of the SLD profile.

    It stores the parameters as well as their minimum and maximum subprior bounds.

    Args:
        parameters (Tensor): the values of the thin film parameters
        min_bounds (Tensor): the minimum subprior bounds of the parameters
        max_bounds (Tensor): the maximum subprior bounds of the parameters
        max_num_layers (int, optional): the maximum number of layers (for box model
            parameterizations it is the number of layers). If None, the class attribute
            ``MAX_NUM_LAYERS`` is used. Defaults to None.
        param_model (ParametricModel, optional): the parametric model. If None, an instance of
            ``PARAM_MODEL_CLS`` is built with ``max_num_layers``. Defaults to None.
    """

    __slots__ = (
        'parameters',
        'min_bounds',
        'max_bounds',
        'max_num_layers',
        'param_model',
    )
    # AbstractParams iterates over PARAM_NAMES and passes the values positionally to the
    # constructor (to(), __getitem__, __add__, cat), so this order must match __init__.
    PARAM_NAMES = __slots__
    PARAM_MODEL_CLS: Type[ParametricModel]
    MAX_NUM_LAYERS: int = 30

    def __init__(self,
                 parameters: Tensor,
                 min_bounds: Tensor,
                 max_bounds: Tensor,
                 max_num_layers: int = None,
                 param_model: ParametricModel = None,
                 ):
        # Explicit ``is None`` checks: a falsy but valid value (e.g. 0 layers) must not be
        # silently replaced by the class-level default.
        if max_num_layers is None:
            max_num_layers = self.MAX_NUM_LAYERS
        if param_model is None:
            param_model = self.PARAM_MODEL_CLS(max_num_layers)

        self.param_model = param_model
        self.max_num_layers = max_num_layers
        self.parameters = parameters
        self.min_bounds = min_bounds
        self.max_bounds = max_bounds

    def get_param_labels(self, **kwargs) -> List[str]:
        """gets the parameter labels (delegated to the parametric model)"""
        return self.param_model.get_param_labels(**kwargs)

    def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
        r"""computes the reflectivity curves directly from the parameters

        Args:
            q (Tensor): the q values
            log (bool, optional): whether to apply logarithm to the curves. Defaults to False.

        Returns:
            Tensor: the simulated reflectivity curves
        """
        return self.param_model.reflectivity(q, self.parameters, log=log, **kwargs)

    @property
    def max_layer_num(self) -> int:  # keep for back compatibility but TODO: unify api among different params
        """gets the maximum number of layers"""
        return self.max_num_layers

    @property
    def num_params(self) -> int:
        """get the number of parameters (parameter dimensionality)"""
        return self.param_model.param_dim

    @property
    def thicknesses(self):
        """gets the thicknesses"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['thickness']

    @property
    def roughnesses(self):
        """gets the roughnesses"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['roughness']

    @property
    def slds(self):
        """gets the slds"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['sld']

    @property
    def real_slds(self):
        """gets the real part of the slds"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['sld'].real

    @property
    def imag_slds(self):
        """gets the imaginary part of the slds (only for complex dtypes)"""
        params = self.param_model.to_standard_params(self.parameters)
        return params['sld'].imag

    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor,
            context: Tensor,
            inference: bool = False,
            from_params: bool = False,
    ):
        """Moves the (scaled) subprior bounds from the parameter tensor into the context.

        Args:
            scaled_params (Tensor): outside inference (and with ``from_params=True``), the
                concatenation ``[params, min_bounds, max_bounds]`` along the last dimension.
                With ``inference=True`` and ``from_params=False`` it is appended to the context
                as-is (presumably already only the scaled bounds - TODO confirm with callers).
            context (Tensor): the context tensor, extended along the last dimension.
            inference (bool, optional): if True, only the extended context is returned.
            from_params (bool, optional): only used with ``inference=True``; if True, the
                parameter part of ``scaled_params`` is dropped before appending.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: the extended context if ``inference`` is True,
            otherwise ``(scaled parameter values, extended context)``.
        """
        if inference:
            if from_params:
                num_params = scaled_params.shape[-1] // 3
                # assumes a 2D (batch, 3 * num_params) tensor - keep only the bounds
                scaled_params = scaled_params[:, num_params:]
            context = torch.cat([context, scaled_params], dim=-1)
            return context

        num_params = scaled_params.shape[-1] // 3
        assert num_params * 3 == scaled_params.shape[-1]
        scaled_params, bound_context = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
        context = torch.cat([context, bound_context], dim=-1)
        return scaled_params, context

    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        """Re-appends the scaled bounds stored in the last ``2 * num_params`` context entries.

        Inverse of ``rearrange_context_from_params`` with ``inference=False``.
        """
        num_params = scaled_params.shape[-1]
        scaled_bounds = context[:, -2 * num_params:]
        scaled_params = torch.cat([scaled_params, scaled_bounds], dim=-1)
        return scaled_params

    def as_tensor(self, add_bounds: bool = True, **kwargs) -> Tensor:
        """converts the instance of the class to a Pytorch tensor

        Args:
            add_bounds (bool, optional): whether to add the subprior bounds to the tensor. Defaults to True.

        Returns:
            Tensor: ``parameters`` alone, or ``[parameters, min_bounds, max_bounds]``
            concatenated along the last dimension
        """
        if not add_bounds:
            return self.parameters
        return torch.cat([self.parameters, self.min_bounds, self.max_bounds], -1)

    @classmethod
    def from_tensor(cls, params: Tensor, **kwargs):
        """initializes an instance of the class from a Pytorch tensor

        Args:
            params (Tensor): Pytorch tensor containing the parameter values, min subprior bounds
                and max subprior bounds concatenated along the last dimension
            **kwargs: forwarded to the constructor (``max_num_layers``, ``param_model``)

        Returns:
            BasicParams: the instance of the class
        """
        num_params = params.shape[-1] // 3

        params, min_bounds, max_bounds = torch.split(
            params, [num_params, num_params, num_params], dim=-1
        )

        return cls(
            params,
            min_bounds,
            max_bounds,
            **kwargs
        )

    def scale_with_q(self, q_ratio: float):
        """scales the parameters and their subprior bounds based on the q ratio

        Args:
            q_ratio (float): the scaling ratio
        """
        self.parameters = self.param_model.scale_with_q(self.parameters, q_ratio)
        self.min_bounds = self.param_model.scale_with_q(self.min_bounds, q_ratio)
        self.max_bounds = self.param_model.scale_with_q(self.max_bounds, q_ratio)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class SubpriorParametricSampler(PriorSampler, ScalerMixin):
    """Prior sampler for the parameters of a parametric model together with their subprior bounds."""

    PARAM_CLS = BasicParams

    def __init__(self,
                 param_ranges: Dict[str, Tuple[float, float]],
                 bound_width_ranges: Dict[str, Tuple[float, float]],
                 model_name: str,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 max_num_layers: int = 50,
                 logdist: bool = False,
                 scale_params_by_ranges=False,
                 scaled_range: Tuple[float, float] = (-1., 1.),
                 **kwargs
                 ):
        """Prior sampler for the parameters of a parametric model and their subprior bounds

        Args:
            param_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with its range
            bound_width_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with the range for sampling the widths of the subprior interval
            model_name (str): the name of the parametric model
            device (torch.device, optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
            dtype (torch.dtype, optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
            max_num_layers (int, optional): the maximum number of layers (for box model parameterizations it is the number of layers). Defaults to 50.
            logdist (bool, optional): if True the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
            scale_params_by_ranges (bool, optional): if True the parameters are scaled with respect to their ranges instead of being scaled with respect to their prior bounds. Defaults to False.
            scaled_range (Tuple[float, float], optional): the range for scaling the parameters. Defaults to (-1., 1.)
            **kwargs: ``shift_param_config`` (dict, optional) is popped and, if any of its values
                is truthy, used to wrap the model in a ``NuisanceParamsWrapper``; the remaining
                keyword arguments are forwarded to the parametric model constructor (and to the
                wrapper).
        """
        self.scaled_range = scaled_range

        # A missing key and an explicit ``None`` (e.g. ``shift_param_config: null`` in a
        # config file) both mean "no nuisance parameters"; calling ``.values()`` on None
        # would otherwise fail below.
        self.shift_param_config = kwargs.pop('shift_param_config', None) or {}

        base_model: ParametricModel = MULTILAYER_MODELS[model_name](max_num_layers, logdist=logdist, **kwargs)
        if any(self.shift_param_config.values()):
            self.param_model = NuisanceParamsWrapper(
                base_model=base_model,
                nuisance_params_config=self.shift_param_config,
                **kwargs,
            )
        else:
            self.param_model = base_model

        self.device = device
        self.dtype = dtype
        self.num_layers = max_num_layers

        # NOTE: these assignments mutate class-level attributes of PARAM_CLS; they become the
        # defaults for every later instance created without an explicit model / layer count.
        self.PARAM_CLS.PARAM_MODEL_CLS = MULTILAYER_MODELS[model_name]
        self.PARAM_CLS.MAX_NUM_LAYERS = max_num_layers

        self._param_dim = self.param_model.param_dim
        self.min_bounds, self.max_bounds, self.min_delta, self.max_delta = self.param_model.init_bounds(
            param_ranges, bound_width_ranges, device=device, dtype=dtype
        )

        self.param_ranges = param_ranges
        self.bound_width_ranges = bound_width_ranges
        self.model_name = model_name
        self.logdist = logdist
        self.scale_params_by_ranges = scale_params_by_ranges

    @property
    def max_num_layers(self) -> int:
        """gets the maximum number of layers"""
        return self.num_layers

    @property
    def param_dim(self) -> int:
        """get the number of parameters (parameter dimensionality)"""
        return self._param_dim

    def sample(self, batch_size: int) -> BasicParams:
        """sample a batch of parameters

        Args:
            batch_size (int): the batch size

        Returns:
            BasicParams: sampled parameters
        """
        params, min_bounds, max_bounds = self.param_model.sample(
            batch_size, self.min_bounds, self.max_bounds, self.min_delta, self.max_delta
        )

        params = BasicParams(
            parameters=params,
            min_bounds=min_bounds,
            max_bounds=max_bounds,
            max_num_layers=self.max_num_layers,
            param_model=self.param_model,
        )

        return params

    def scale_params(self, params: BasicParams) -> Tensor:
        """scale the parameters to a ML-friendly range

        The subprior bounds are always scaled with respect to the global parameter ranges.
        The parameter values are scaled with respect to the global ranges if
        ``scale_params_by_ranges`` is True, otherwise with respect to their own subprior bounds.

        Args:
            params (BasicParams): the parameters to be scaled

        Returns:
            Tensor: ``[scaled params, scaled min bounds, scaled max bounds]`` concatenated along
            the last dimension
        """
        if self.scale_params_by_ranges:
            scaled_values = self._scale(params.parameters, self.min_bounds, self.max_bounds)
        else:
            scaled_values = self._scale(params.parameters, params.min_bounds, params.max_bounds)

        scaled_min_bounds = self._scale(params.min_bounds, self.min_bounds, self.max_bounds)
        scaled_max_bounds = self._scale(params.max_bounds, self.min_bounds, self.max_bounds)

        return torch.cat([scaled_values, scaled_min_bounds, scaled_max_bounds], -1)

    def restore_params(self, scaled_params: Tensor) -> BasicParams:
        """restore the parameters to their original range (inverse of ``scale_params``)

        Args:
            scaled_params (Tensor): the scaled parameters

        Returns:
            BasicParams: the parameters restored to their original range
        """
        num_params = scaled_params.shape[-1] // 3
        scaled_values, scaled_min_bounds, scaled_max_bounds = torch.split(
            scaled_params, num_params, -1
        )

        # The bounds were scaled with respect to the global ranges in both modes.
        min_bounds = self._restore(scaled_min_bounds, self.min_bounds, self.max_bounds)
        max_bounds = self._restore(scaled_max_bounds, self.min_bounds, self.max_bounds)

        if self.scale_params_by_ranges:
            params = self._restore(scaled_values, self.min_bounds, self.max_bounds)
        else:
            params = self._restore(scaled_values, min_bounds, max_bounds)

        return BasicParams(
            parameters=params,
            min_bounds=min_bounds,
            max_bounds=max_bounds,
            max_num_layers=self.max_num_layers,
            param_model=self.param_model,
        )

    def scale_bounds(self, bounds: Tensor) -> Tensor:
        """scales bounds with respect to the global parameter ranges"""
        return self._scale(bounds, self.min_bounds, self.max_bounds)

    def log_prob(self, params: BasicParams) -> Tensor:
        """returns 0 for samples within their subprior bounds and -inf otherwise"""
        log_prob = torch.zeros(params.batch_size, device=self.device, dtype=self.dtype)
        log_prob[~self.get_indices_within_bounds(params)] = -float('inf')
        return log_prob

    def get_indices_within_domain(self, params: BasicParams) -> Tensor:
        """boolean mask of the samples whose parameters lie within their subprior bounds"""
        return self.get_indices_within_bounds(params)

    def get_indices_within_bounds(self, params: BasicParams) -> Tensor:
        """boolean mask (over the batch) of samples with all parameters within [min_bounds, max_bounds]"""
        return (
            torch.all(params.parameters >= params.min_bounds, -1) &
            torch.all(params.parameters <= params.max_bounds, -1)
        )

    def filter_params(self, params: BasicParams) -> BasicParams:
        """keeps only the samples whose parameters lie within their subprior bounds"""
        indices = self.get_indices_within_domain(params)
        return params[indices]

    def clamp_params(
            self, params: BasicParams, inplace: bool = False
    ) -> BasicParams:
        """clamps the parameters to their subprior bounds

        Args:
            params (BasicParams): the parameters to clamp
            inplace (bool, optional): if True, ``params.parameters`` is clamped in place and
                ``params`` is returned; otherwise a new instance with cloned bounds is returned.
                Defaults to False.

        Returns:
            BasicParams: the clamped parameters
        """
        if inplace:
            params.parameters = torch.clamp_(params.parameters, params.min_bounds, params.max_bounds)
            return params

        return BasicParams(
            parameters=torch.clamp(params.parameters, params.min_bounds, params.max_bounds),
            min_bounds=params.min_bounds.clone(),
            max_bounds=params.max_bounds.clone(),
            max_num_layers=self.max_num_layers,
            param_model=self.param_model,
        )
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
from typing import List, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from reflectorch.data_generation.utils import get_d_rhos, get_param_labels
|
|
7
|
+
from reflectorch.data_generation.reflectivity import reflectivity
|
|
8
|
+
|
|
9
|
+
# Names exported by ``from ... import *``.
__all__ = ["Params", "AbstractParams"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AbstractParams(object):
    """Base class for parameters.

    Subclasses define ``PARAM_NAMES``, the ordered attribute names that ``__iter__``
    yields. ``to``, ``__getitem__``, ``__add__`` and ``cat`` pass these values
    positionally to the subclass constructor, so the order must match its ``__init__``.
    """
    PARAM_NAMES: Tuple[str, ...]

    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor, context: Tensor, inference: bool = False, from_params: bool = False
    ):
        """Default: no rearrangement. Returns ``context`` if ``inference`` else ``(scaled_params, context)``."""
        if inference:
            return context
        return scaled_params, context

    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        """Default: returns ``scaled_params`` unchanged."""
        return scaled_params

    def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
        """computes the reflectivity curves directly from the parameters"""
        raise NotImplementedError

    def __iter__(self):
        for name in self.PARAM_NAMES:
            yield getattr(self, name)

    def to_(self, tgt):
        """in-place variant of ``to``: rebinds each attribute to its converted value"""
        for name, arr in zip(self.PARAM_NAMES, self):
            setattr(self, name, _to(arr, tgt))

    def to(self, tgt):
        """performs Pytorch Tensor dtype and/or device conversion"""
        return self.__class__(*[_to(arr, tgt) for arr in self])

    def cuda(self):
        """moves the parameters to the GPU"""
        return self.to('cuda')

    def cpu(self):
        """moves the parameters to the CPU"""
        return self.to('cpu')

    def __getitem__(self, item) -> 'AbstractParams':
        # Only tensor fields are indexed; other fields (ints, models, ...) are passed through.
        return self.__class__(*[
            arr.__getitem__(item) if isinstance(arr, Tensor) else arr for arr in self
        ])

    def __setitem__(self, key, other):
        if not isinstance(other, AbstractParams):
            raise ValueError(
                f'Expected an AbstractParams instance, got {type(other).__name__}'
            )

        for param, other_param in zip(self, other):
            if isinstance(param, Tensor):
                param[key] = other_param

    def __add__(self, other):
        """concatenates two parameter batches along the batch dimension"""
        if not isinstance(other, AbstractParams):
            # NotImplemented is a sentinel, not an exception: returning it lets Python try
            # other.__radd__ and raise the standard TypeError otherwise.
            return NotImplemented

        return self.__class__(*[
            torch.cat([param, other_param], 0)
            if isinstance(param, Tensor) else param
            for param, other_param in zip(self, other)
        ])

    @staticmethod
    def _fields_equal(param, other_param) -> bool:
        """compares two fields: tensors with ``torch.allclose``, other values with ``==``"""
        if isinstance(param, Tensor) or isinstance(other_param, Tensor):
            return (
                isinstance(param, Tensor)
                and isinstance(other_param, Tensor)
                and torch.allclose(param, other_param)
            )
        return bool(param == other_param)

    def __eq__(self, other):
        if not isinstance(other, AbstractParams):
            return NotImplemented

        # Non-tensor fields (e.g. layer counts or model objects) cannot go through
        # torch.allclose, which only accepts tensors.
        return all(
            self._fields_equal(param, other_param)
            for param, other_param in zip(self, other)
        )

    @classmethod
    def cat(cls, *params: 'AbstractParams'):
        """concatenates several parameter batches along the batch dimension

        Non-tensor fields are taken from the first instance.
        """
        return cls(
            *[torch.cat([getattr(p, name) for p in params], 0)
              if isinstance(getattr(params[0], name), Tensor) else
              getattr(params[0], name)
              for name in cls.PARAM_NAMES]
        )

    @property
    def _ref_tensor(self):
        # The first named field is used as the reference for shape/device/dtype queries.
        return getattr(self, self.PARAM_NAMES[0])

    def scale_with_q(self, q_ratio: float):
        raise NotImplementedError

    @property
    def max_layer_num(self) -> int:
        """gets the number of layers (last dimension of the reference tensor)"""
        return self._ref_tensor.shape[-1]

    @property
    def batch_size(self) -> int:
        """gets the batch size (first dimension of the reference tensor)"""
        return self._ref_tensor.shape[0]

    @property
    def device(self):
        """gets the Pytorch device"""
        return self._ref_tensor.device

    @property
    def dtype(self):
        """gets the Pytorch data type"""
        return self._ref_tensor.dtype

    @property
    def d_rhos(self):
        raise NotImplementedError

    def as_tensor(self, use_drho: bool = False) -> Tensor:
        """converts the instance of the class to a Pytorch tensor"""
        raise NotImplementedError

    @classmethod
    def from_tensor(cls, params: Tensor):
        """initializes an instance of the class from a Pytorch tensor"""
        raise NotImplementedError

    @property
    def num_params(self) -> int:
        """get the number of parameters"""
        return self.layers_num2size(self.max_layer_num)

    @staticmethod
    def size2layers_num(size: int) -> int:
        """converts the number of parameters to the number of layers"""
        raise NotImplementedError

    @staticmethod
    def layers_num2size(layers_num: int) -> int:
        """converts the number of layers to the number of parameters"""
        raise NotImplementedError

    def get_param_labels(self) -> List[str]:
        """gets the parameter labels"""
        raise NotImplementedError

    def __repr__(self):
        return f'{self.__class__.__name__}(' \
               f'batch_size={self.batch_size}, ' \
               f'max_layer_num={self.max_layer_num}, ' \
               f'device={str(self.device)})'
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _to(arr, dest):
|
|
160
|
+
if hasattr(arr, 'to'):
|
|
161
|
+
arr = arr.to(dest)
|
|
162
|
+
return arr
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class Params(AbstractParams):
    """Parameter class for thickness, roughness and sld parameters

    Args:
        thicknesses (Tensor): batch of thicknesses (top to bottom)
        roughnesses (Tensor): batch of roughnesses (top to bottom)
        slds (Tensor): batch of slds (top to bottom)
    """
    MIN_THICKNESS: float = 0.5

    __slots__ = ('thicknesses', 'roughnesses', 'slds')
    # Positional order used by AbstractParams (to(), __getitem__, cat, ...) - matches __init__.
    PARAM_NAMES = __slots__

    def __init__(self, thicknesses: Tensor, roughnesses: Tensor, slds: Tensor):

        self.thicknesses = thicknesses
        self.roughnesses = roughnesses
        self.slds = slds

    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor, context: Tensor, inference: bool = False, from_params: bool = False
    ):
        """No rearrangement: returns ``context`` if ``inference`` else ``(scaled_params, context)``."""
        if inference:
            return context
        return scaled_params, context

    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        """Returns ``scaled_params`` unchanged."""
        return scaled_params

    def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
        """computes the reflectivity curves directly from the parameters

        Args:
            q (Tensor): the q values
            log (bool, optional): whether to apply logarithm to the curves. Defaults to False.

        Returns:
            Tensor: the simulated reflectivity curves
        """
        return reflectivity(q, self.thicknesses, self.roughnesses, self.slds, log=log, **kwargs)

    def scale_with_q(self, q_ratio: float):
        """scales the parameters based on the q ratio

        Thicknesses and roughnesses are divided by ``q_ratio`` and slds multiplied by
        ``q_ratio ** 2``. The attributes are rebound to new tensors rather than modified in
        place, so tensors shared with the caller (e.g. the views produced by
        ``from_tensor``) are left untouched.

        Args:
            q_ratio (float): the scaling ratio
        """
        self.thicknesses = self.thicknesses / q_ratio
        self.roughnesses = self.roughnesses / q_ratio
        self.slds = self.slds * q_ratio ** 2

    @property
    def d_rhos(self):
        """computes the differences in SLD values of the neighboring layers"""
        return get_d_rhos(self.slds)

    def as_tensor(self, use_drho: bool = False) -> Tensor:
        """converts the instance of the class to a Pytorch tensor

        Args:
            use_drho (bool, optional): if True, ``d_rhos`` replaces the slds. Defaults to False.

        Returns:
            Tensor: ``[thicknesses, roughnesses, slds or d_rhos]`` concatenated along the last dimension
        """
        if use_drho:
            return torch.cat([self.thicknesses, self.roughnesses, self.d_rhos], -1)
        else:
            return torch.cat([self.thicknesses, self.roughnesses, self.slds], -1)

    @classmethod
    def from_tensor(cls, params: Tensor):
        """initializes an instance of the class from a Pytorch tensor containing the values of the parameters

        The last dimension is split as ``[thicknesses (L), roughnesses (L + 1), slds (L + 1)]``
        with ``L = (size - 2) // 3``. The resulting fields are views of ``params``.
        """
        layers_num = (params.shape[-1] - 2) // 3

        thicknesses, roughnesses, slds = torch.split(params, [layers_num, layers_num + 1, layers_num + 1], dim=-1)

        return cls(thicknesses, roughnesses, slds)

    @staticmethod
    def size2layers_num(size: int) -> int:
        """converts the number of parameters to the number of layers"""
        return (size - 2) // 3

    @staticmethod
    def layers_num2size(layers_num: int) -> int:
        """converts the number of layers to the number of parameters"""
        return layers_num * 3 + 2

    def get_param_labels(self, **kwargs) -> List[str]:
        """gets the parameter labels, the layers are numbered from the bottom to the top
        (i.e. opposite to the order in the Tensors)"""
        return get_param_labels(self.max_layer_num, **kwargs)
|