reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/data_generation/priors/params.py

@@ -1,252 +1,252 @@

The removed and re-added sides of this hunk are identical as extracted: all 252 lines of the file are deleted and reinserted without textual changes, so the file content is shown once below.

from typing import List, Tuple

import torch
from torch import Tensor

from reflectorch.data_generation.utils import get_d_rhos, get_param_labels
from reflectorch.data_generation.reflectivity import reflectivity

__all__ = [
    "Params",
    "AbstractParams",
]


class AbstractParams(object):
    """Base class for parameters"""
    PARAM_NAMES: Tuple[str, ...]

    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor, context: Tensor, inference: bool = False, from_params: bool = False
    ):
        if inference:
            return context
        return scaled_params, context

    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        return scaled_params

    def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
        """computes the reflectivity curves directly from the parameters"""
        raise NotImplementedError

    def __iter__(self):
        for name in self.PARAM_NAMES:
            yield getattr(self, name)

    def to_(self, tgt):
        for name, arr in zip(self.PARAM_NAMES, self):
            setattr(self, name, _to(arr, tgt))

    def to(self, tgt):
        """performs Pytorch Tensor dtype and/or device conversion"""
        return self.__class__(*[_to(arr, tgt) for arr in self])

    def cuda(self):
        """moves the parameters to the GPU"""
        return self.to('cuda')

    def cpu(self):
        """moves the parameters to the CPU"""
        return self.to('cpu')

    def __getitem__(self, item) -> 'AbstractParams':
        return self.__class__(*[
            arr.__getitem__(item) if isinstance(arr, Tensor) else arr for arr in self
        ])

    def __setitem__(self, key, other):
        if not isinstance(other, AbstractParams):
            raise ValueError

        for param, other_param in zip(self, other):
            if isinstance(param, Tensor):
                param[key] = other_param

    def __add__(self, other):
        if not isinstance(other, AbstractParams):
            raise NotImplemented

        return self.__class__(*[
            torch.cat([param, other_param], 0)
            if isinstance(param, Tensor) else param
            for param, other_param in zip(self, other)
        ])

    def __eq__(self, other):
        if not isinstance(other, AbstractParams):
            raise NotImplemented

        return all([torch.allclose(param, other_param) for param, other_param in zip(self, other)])

    @classmethod
    def cat(cls, *params: 'AbstractParams'):
        return cls(
            *[torch.cat([getattr(p, name) for p in params], 0)
              if isinstance(getattr(params[0], name), Tensor) else
              getattr(params[0], name)
              for name in cls.PARAM_NAMES]
        )

    @property
    def _ref_tensor(self):
        return getattr(self, self.PARAM_NAMES[0])

    def scale_with_q(self, q_ratio: float):
        raise NotImplementedError

    @property
    def max_layer_num(self) -> int:
        """gets the number of layers"""
        return self._ref_tensor.shape[-1]

    @property
    def batch_size(self) -> int:
        """gets the batch size"""
        return self._ref_tensor.shape[0]

    @property
    def device(self):
        """gets the Pytorch device"""
        return self._ref_tensor.device

    @property
    def dtype(self):
        """gets the Pytorch data type"""
        return self._ref_tensor.dtype

    @property
    def d_rhos(self):
        raise NotImplementedError

    def as_tensor(self, use_drho: bool = False) -> Tensor:
        """converts the instance of the class to a Pytorch tensor"""
        raise NotImplementedError

    @classmethod
    def from_tensor(cls, params: Tensor):
        """initializes an instance of the class from a Pytorch tensor"""
        raise NotImplementedError

    @property
    def num_params(self) -> int:
        """get the number of parameters"""
        return self.layers_num2size(self.max_layer_num)

    @staticmethod
    def size2layers_num(size: int) -> int:
        """converts the number of parameters to the number of layers"""
        raise NotImplementedError

    @staticmethod
    def layers_num2size(layers_num: int) -> int:
        """converts the number of layers to the number of parameters"""
        raise NotImplementedError

    def get_param_labels(self) -> List[str]:
        """gets the parameter labels"""
        raise NotImplementedError

    def __repr__(self):
        return f'{self.__class__.__name__}(' \
               f'batch_size={self.batch_size}, ' \
               f'max_layer_num={self.max_layer_num}, ' \
               f'device={str(self.device)})'


def _to(arr, dest):
    if hasattr(arr, 'to'):
        arr = arr.to(dest)
    return arr


class Params(AbstractParams):
    """Parameter class for thickness, roughness and sld parameters

    Args:
        thicknesses (Tensor): batch of thicknesses (top to bottom)
        roughnesses (Tensor): batch of roughnesses (top to bottom)
        slds (Tensor): batch of slds (top to bottom)
    """
    MIN_THICKNESS: float = 0.5

    __slots__ = ('thicknesses', 'roughnesses', 'slds')
    PARAM_NAMES = __slots__

    def __init__(self, thicknesses: Tensor, roughnesses: Tensor, slds: Tensor):

        self.thicknesses = thicknesses
        self.roughnesses = roughnesses
        self.slds = slds

    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor, context: Tensor, inference: bool = False, from_params: bool = False
    ):
        if inference:
            return context
        return scaled_params, context

    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        return scaled_params

    def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
        """computes the reflectivity curves directly from the parameters

        Args:
            q (Tensor): the q values
            log (bool, optional): whether to apply logarithm to the curves. Defaults to False.

        Returns:
            Tensor: the simulated reflectivity curves
        """
        return reflectivity(q, self.thicknesses, self.roughnesses, self.slds, log=log, **kwargs)

    def scale_with_q(self, q_ratio: float):
        """scales the parameters based on the q ratio

        Args:
            q_ratio (float): the scaling ratio
        """
        self.thicknesses /= q_ratio
        self.roughnesses /= q_ratio
        self.slds *= q_ratio ** 2

    @property
    def d_rhos(self):
        """computes the differences in SLD values of the neighboring layers"""
        return get_d_rhos(self.slds)

    def as_tensor(self, use_drho: bool = False) -> Tensor:
        """converts the instance of the class to a Pytorch tensor"""
        if use_drho:
            return torch.cat([self.thicknesses, self.roughnesses, self.d_rhos], -1)
        else:
            return torch.cat([self.thicknesses, self.roughnesses, self.slds], -1)

    @classmethod
    def from_tensor(cls, params: Tensor):
        """initializes an instance of the class from a Pytorch tensor containing the values of the parameters"""
        layers_num = (params.shape[-1] - 2) // 3

        thicknesses, roughnesses, slds = torch.split(params, [layers_num, layers_num + 1, layers_num + 1], dim=-1)

        return cls(thicknesses, roughnesses, slds)

    @staticmethod
    def size2layers_num(size: int) -> int:
        """converts the number of parameters to the number of layers"""
        return (size - 2) // 3

    @staticmethod
    def layers_num2size(layers_num: int) -> int:
        """converts the number of layers to the number of parameters"""
        return layers_num * 3 + 2

    def get_param_labels(self, **kwargs) -> List[str]:
        """gets the parameter labels, the layers are numbered from the bottom to the top
        (i.e. opposite to the order in the Tensors)"""
        return get_param_labels(self.max_layer_num, **kwargs)
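Since this file defines the parameter container used throughout the package, a short usage sketch may help. It is not part of the diff: the batch size, layer count, q grid, and value ranges are illustrative assumptions, and it assumes the `reflectivity` kernel broadcasts a shared q grid across the batch. Per the `Params` docstring, a structure with N layers has N thicknesses and N + 1 roughnesses and SLDs (the substrate contributes one extra interface and SLD), which is where the 3N + 2 flat layout comes from.

```python
import torch
from reflectorch.data_generation.priors.params import Params

B, N = 4, 2  # assumed: batch of 4 samples, 2 layers on a substrate
params = Params(
    thicknesses=torch.rand(B, N) * 100.0,    # one thickness per layer
    roughnesses=torch.rand(B, N + 1) * 5.0,  # one roughness per interface
    slds=torch.rand(B, N + 1) * 10.0,        # one SLD per layer plus substrate
)

q = torch.linspace(0.02, 0.2, 128)   # illustrative q grid
curves = params.reflectivity(q)      # simulated curves, one per batch entry

# Round trip through the flat layout: N + (N+1) + (N+1) = 3N + 2 values
flat = params.as_tensor()
assert flat.shape == (B, Params.layers_num2size(N))
restored = Params.from_tensor(flat)
assert restored == params            # __eq__ checks torch.allclose field by field
```

One detail worth noting in the file itself: `__add__` and `__eq__` use `raise NotImplemented` (the constant, not an exception class), so mixing in a non-`AbstractParams` operand raises a `TypeError` about a non-exception rather than falling back to Python's reflected operators.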
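`scale_with_q` encodes the standard q-rescaling of a layer structure: lengths (thicknesses, roughnesses) are divided by the ratio and SLDs multiplied by its square, so the wavevectors k_z scale linearly with the ratio and the dimensionless combinations k·d and k_i·k_j·σ² are preserved. Continuing the sketch above, a check of this invariance; it assumes the default unsmeared kernel and a tolerance chosen for illustration.

```python
ratio = 2.0
reference = params.reflectivity(q)

scaled = Params(
    params.thicknesses.clone(),  # clone: scale_with_q modifies tensors in place
    params.roughnesses.clone(),
    params.slds.clone(),
)
scaled.scale_with_q(ratio)       # d /= r, sigma /= r, rho *= r**2

# The rescaled structure evaluated on the stretched grid should reproduce
# the original curve, since the Fresnel ratios and phase factors depend
# only on the preserved dimensionless combinations.
assert torch.allclose(reference, scaled.reflectivity(q * ratio), atol=1e-6)
```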