reflectorch-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (83)
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,354 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Tuple, Dict, Type, List
8
+
9
+ import torch
10
+ from torch import Tensor
11
+
12
+ from reflectorch.data_generation.priors.base import PriorSampler
13
+ from reflectorch.data_generation.priors.params import AbstractParams
14
+ from reflectorch.data_generation.priors.no_constraints import (
15
+ DEFAULT_DEVICE,
16
+ DEFAULT_DTYPE,
17
+ )
18
+
19
+ from reflectorch.data_generation.priors.parametric_models import (
20
+ MULTILAYER_MODELS,
21
+ ParametricModel,
22
+ )
23
+ from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
24
+
25
+
26
class BasicParams(AbstractParams):
    """Batch of thin film parameters bundled with their subprior bounds.

    Everything specific to the parameterization of the SLD profile (labels, reflectivity
    simulation, conversion to thickness / roughness / sld, q scaling) is delegated to ``param_model``.

    Args:
        parameters (Tensor): the values of the thin film parameters
        min_bounds (Tensor): the minimum subprior bounds of the parameters
        max_bounds (Tensor): the maximum subprior bounds of the parameters
        max_num_layers (int, optional): the maximum number of layers (for box model parameterizations it is the number of layers). A falsy value falls back to ``MAX_NUM_LAYERS``. Defaults to None.
        param_model (ParametricModel, optional): the parametric model. A falsy value falls back to ``PARAM_MODEL_CLS(max_num_layers)``. Defaults to None.
    """

    # the slot names double as the ordered field names used by the AbstractParams helpers
    PARAM_NAMES = __slots__ = (
        'parameters',
        'min_bounds',
        'max_bounds',
        'max_num_layers',
        'param_model',
    )
    # only declared here; SubpriorParametricSampler.__init__ assigns it
    PARAM_MODEL_CLS: Type[ParametricModel]
    MAX_NUM_LAYERS: int = 30

    def __init__(self,
                 parameters: Tensor,
                 min_bounds: Tensor,
                 max_bounds: Tensor,
                 max_num_layers: int = None,
                 param_model: ParametricModel = None,
                 ):
        # fall back to the class-level defaults for anything not (truthily) provided
        if not max_num_layers:
            max_num_layers = self.MAX_NUM_LAYERS
        if not param_model:
            param_model = self.PARAM_MODEL_CLS(max_num_layers)

        self.parameters = parameters
        self.min_bounds = min_bounds
        self.max_bounds = max_bounds
        self.max_num_layers = max_num_layers
        self.param_model = param_model

    def get_param_labels(self) -> List[str]:
        """returns the parameter labels provided by the parametric model"""
        return self.param_model.get_param_labels()

    def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
        r"""simulates the reflectivity curves for the stored parameters

        Args:
            q (Tensor): the q values
            log (bool, optional): whether to apply logarithm to the curves. Defaults to False.
            **kwargs: forwarded unchanged to ``param_model.reflectivity``

        Returns:
            Tensor: the simulated reflectivity curves
        """
        return self.param_model.reflectivity(q, self.parameters, log=log, **kwargs)

    @property
    def max_layer_num(self) -> int:  # keep for back compatibility but TODO: unify api among different params
        """alias of ``max_num_layers``"""
        return self.max_num_layers

    @property
    def num_params(self) -> int:
        """the parameter dimensionality reported by the parametric model"""
        return self.param_model.param_dim

    @property
    def thicknesses(self):
        """the 'thickness' entry of the standard parameters"""
        return self.param_model.to_standard_params(self.parameters)['thickness']

    @property
    def roughnesses(self):
        """the 'roughness' entry of the standard parameters"""
        return self.param_model.to_standard_params(self.parameters)['roughness']

    @property
    def slds(self):
        """the 'sld' entry of the standard parameters"""
        return self.param_model.to_standard_params(self.parameters)['sld']

    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor,
            context: Tensor,
            inference: bool = False,
            from_params: bool = False,
    ):
        """moves the scaled subprior bounds from the parameter tensor into the context

        Args:
            scaled_params (Tensor): scaled parameters followed by the scaled min and max bounds
                (when ``inference`` is True and ``from_params`` is False it is appended to the context as is)
            context (Tensor): the context to be extended along the last dimension
            inference (bool, optional): if True only the extended context is returned. Defaults to False.
            from_params (bool, optional): in inference mode, whether the first third of
                ``scaled_params`` has to be dropped before concatenation. Defaults to False.

        Returns:
            the extended context if ``inference`` is True, otherwise the tuple
            (scaled parameters without the bounds, extended context)
        """
        if not inference:
            third = scaled_params.shape[-1] // 3
            assert third * 3 == scaled_params.shape[-1]
            param_part, bound_part = torch.split(scaled_params, [third, 2 * third], dim=-1)
            return param_part, torch.cat([context, bound_part], dim=-1)

        if from_params:
            # keep only the trailing two thirds, i.e. the scaled bounds
            third = scaled_params.shape[-1] // 3
            scaled_params = scaled_params[:, third:]
        return torch.cat([context, scaled_params], dim=-1)

    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        """re-attaches the scaled bounds, read from the last ``2 * num_params`` context columns, to the scaled parameters"""
        width = scaled_params.shape[-1]
        trailing_bounds = context[:, -2 * width:]
        return torch.cat([scaled_params, trailing_bounds], dim=-1)

    def as_tensor(self, add_bounds: bool = True, **kwargs) -> Tensor:
        """converts the instance of the class to a Pytorch tensor

        Args:
            add_bounds (bool, optional): whether to append the min and max subprior bounds along the last dimension. Defaults to True.

        Returns:
            Tensor: the parameters, optionally followed by the min and the max bounds
        """
        if add_bounds:
            return torch.cat([self.parameters, self.min_bounds, self.max_bounds], -1)
        return self.parameters

    @classmethod
    def from_tensor(cls, params: Tensor, **kwargs):
        """initializes an instance of the class from a Pytorch tensor

        Args:
            params (Tensor): Pytorch tensor whose last dimension holds three equally sized chunks:
                the parameter values, the min subprior bounds and the max subprior bounds
            **kwargs: forwarded to the class constructor

        Returns:
            BasicParams: the instance of the class
        """
        chunk = params.shape[-1] // 3
        values, lower, upper = torch.split(params, [chunk] * 3, dim=-1)
        return cls(values, lower, upper, **kwargs)

    def scale_with_q(self, q_ratio: float):
        """scales the parameters and both bounds based on the q ratio

        Args:
            q_ratio (float): the scaling ratio
        """
        # the three tensors are rescaled the same way by the parametric model
        for field in ('parameters', 'min_bounds', 'max_bounds'):
            setattr(self, field, self.param_model.scale_with_q(getattr(self, field), q_ratio))
179
+
180
+
181
class SubpriorParametricSampler(PriorSampler, ScalerMixin):
    PARAM_CLS = BasicParams

    def __init__(self,
                 param_ranges: Dict[str, Tuple[float, float]],
                 bound_width_ranges: Dict[str, Tuple[float, float]],
                 model_name: str,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 max_num_layers: int = 50,
                 logdist: bool = False,
                 scale_params_by_ranges: bool = False,
                 scaled_range: Tuple[float, float] = (-1., 1.),
                 **kwargs
                 ):
        """Prior sampler for the parameters of a parametric model and their subprior bounds

        Note that constructing a sampler also overwrites the class attributes
        ``PARAM_MODEL_CLS`` and ``MAX_NUM_LAYERS`` of ``PARAM_CLS``.

        Args:
            param_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with its range
            bound_width_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with the range for sampling the widths of the subprior interval
            model_name (str): the name of the parametric model (a key of ``MULTILAYER_MODELS``)
            device (torch.device, optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
            dtype (torch.dtype, optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
            max_num_layers (int, optional): the maximum number of layers (for box model parameterizations it is the number of layers). Defaults to 50.
            logdist (bool, optional): if True the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
            scale_params_by_ranges (bool, optional): if True the parameters are scaled with respect to their ranges instead of being scaled with respect to their prior bounds. Defaults to False.
            scaled_range (Tuple[float, float], optional): the range for scaling the parameters. Defaults to (-1., 1.)
            **kwargs: forwarded to the constructor of the parametric model
        """
        model_cls = MULTILAYER_MODELS[model_name]

        self.scaled_range = scaled_range
        self.param_model: ParametricModel = model_cls(
            max_num_layers,
            logdist=logdist,
            **kwargs
        )
        self.device = device
        self.dtype = dtype
        self.num_layers = max_num_layers

        # make parameter objects created without an explicit model use this sampler's settings
        self.PARAM_CLS.PARAM_MODEL_CLS = model_cls
        self.PARAM_CLS.MAX_NUM_LAYERS = max_num_layers

        self._param_dim = self.param_model.param_dim
        bounds = self.param_model.init_bounds(
            param_ranges, bound_width_ranges, device=device, dtype=dtype
        )
        self.min_bounds, self.max_bounds, self.min_delta, self.max_delta = bounds

        # keep the raw configuration around
        self.param_ranges = param_ranges
        self.bound_width_ranges = bound_width_ranges
        self.model_name = model_name
        self.logdist = logdist
        self.scale_params_by_ranges = scale_params_by_ranges

    @property
    def max_num_layers(self) -> int:
        """the maximum number of layers passed at construction"""
        return self.num_layers

    @property
    def param_dim(self) -> int:
        """the parameter dimensionality reported by the parametric model"""
        return self._param_dim

    def sample(self, batch_size: int) -> BasicParams:
        """sample a batch of parameters together with their subprior bounds

        Args:
            batch_size (int): the batch size

        Returns:
            BasicParams: sampled parameters
        """
        values, lower, upper = self.param_model.sample(
            batch_size, self.min_bounds, self.max_bounds, self.min_delta, self.max_delta
        )
        return BasicParams(
            parameters=values,
            min_bounds=lower,
            max_bounds=upper,
            max_num_layers=self.max_num_layers,
            param_model=self.param_model,
        )

    def scale_params(self, params: BasicParams) -> Tensor:
        """scale the parameters to a ML-friendly range

        The subprior bounds are always scaled with respect to the parameter ranges. The parameters
        are scaled with respect to the parameter ranges if ``scale_params_by_ranges`` is set,
        and with respect to their own subprior bounds otherwise.

        Args:
            params (BasicParams): the parameters to be scaled

        Returns:
            Tensor: scaled parameters, scaled min bounds and scaled max bounds concatenated along the last dimension
        """
        if self.scale_params_by_ranges:
            ref_min, ref_max = self.min_bounds, self.max_bounds
        else:
            ref_min, ref_max = params.min_bounds, params.max_bounds

        pieces = [
            self._scale(params.parameters, ref_min, ref_max),
            self._scale(params.min_bounds, self.min_bounds, self.max_bounds),
            self._scale(params.max_bounds, self.min_bounds, self.max_bounds),
        ]
        return torch.cat(pieces, -1)

    def restore_params(self, scaled_params: Tensor) -> BasicParams:
        """restore the parameters to their original range (inverse of ``scale_params``)

        Args:
            scaled_params (Tensor): the scaled parameters followed by the scaled min and max bounds

        Returns:
            BasicParams: the parameters restored to their original range
        """
        chunk = scaled_params.shape[-1] // 3
        scaled_values, scaled_lower, scaled_upper = torch.split(scaled_params, chunk, -1)

        # the bounds are restored the same way in both scaling modes
        lower = self._restore(scaled_lower, self.min_bounds, self.max_bounds)
        upper = self._restore(scaled_upper, self.min_bounds, self.max_bounds)

        if self.scale_params_by_ranges:
            values = self._restore(scaled_values, self.min_bounds, self.max_bounds)
        else:
            values = self._restore(scaled_values, lower, upper)

        return BasicParams(
            parameters=values,
            min_bounds=lower,
            max_bounds=upper,
            max_num_layers=self.max_num_layers,
            param_model=self.param_model,
        )

    def scale_bounds(self, bounds: Tensor) -> Tensor:
        """scale bounds with respect to the parameter ranges of the sampler"""
        return self._scale(bounds, self.min_bounds, self.max_bounds)

    def log_prob(self, params: BasicParams) -> Tensor:
        """returns 0 for batch entries within their subprior bounds and -inf for the others"""
        result = torch.zeros(params.batch_size, device=self.device, dtype=self.dtype)
        outside = ~self.get_indices_within_bounds(params)
        result[outside] = -float('inf')
        return result

    def get_indices_within_domain(self, params: BasicParams) -> Tensor:
        """same mask as ``get_indices_within_bounds``"""
        return self.get_indices_within_bounds(params)

    def get_indices_within_bounds(self, params: BasicParams) -> Tensor:
        """boolean mask of the entries whose parameters all lie within their own min and max bounds"""
        not_below = torch.all(params.parameters >= params.min_bounds, -1)
        not_above = torch.all(params.parameters <= params.max_bounds, -1)
        return not_below & not_above

    def filter_params(self, params: BasicParams) -> BasicParams:
        """keeps only the entries selected by ``get_indices_within_domain``"""
        return params[self.get_indices_within_domain(params)]

    def clamp_params(
            self, params: BasicParams, inplace: bool = False
    ) -> BasicParams:
        """clamps the parameters to their subprior bounds

        Args:
            params (BasicParams): the parameters to be clamped
            inplace (bool, optional): if True ``params`` is modified and returned, otherwise a new
                object with cloned bounds is created. Defaults to False.

        Returns:
            BasicParams: the clamped parameters
        """
        if not inplace:
            return BasicParams(
                parameters=torch.clamp(params.parameters, params.min_bounds, params.max_bounds),
                min_bounds=params.min_bounds.clone(),
                max_bounds=params.max_bounds.clone(),
                max_num_layers=self.max_num_layers,
                param_model=self.param_model,
            )

        params.parameters = torch.clamp_(params.parameters, params.min_bounds, params.max_bounds)
        return params
@@ -0,0 +1,258 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Tuple
8
+
9
+ import torch
10
+ from torch import Tensor
11
+
12
+ from reflectorch.data_generation.utils import get_d_rhos, get_param_labels
13
+ from reflectorch.data_generation.reflectivity import reflectivity
14
+
15
+ __all__ = [
16
+ "Params",
17
+ "AbstractParams",
18
+ ]
19
+
20
+
21
class AbstractParams(object):
    """Base class for parameters

    Subclasses list their field names in ``PARAM_NAMES``; iteration, indexing, concatenation,
    comparison and device/dtype conversion operate on these fields in that order. The first
    field serves as the reference tensor for ``batch_size``, ``device``, ``dtype`` and ``max_layer_num``.
    """
    PARAM_NAMES: Tuple[str, ...]

    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor, context: Tensor, inference: bool = False, from_params: bool = False
    ):
        """returns the context unchanged if ``inference`` is True, otherwise (scaled_params, context) unchanged"""
        if inference:
            return context
        return scaled_params, context

    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        """returns the scaled parameters unchanged"""
        return scaled_params

    def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
        """computes the reflectivity curves directly from the parameters"""
        raise NotImplementedError

    def __iter__(self):
        """iterates over the field values in the order given by ``PARAM_NAMES``"""
        for name in self.PARAM_NAMES:
            yield getattr(self, name)

    def to_(self, tgt):
        """in-place variant of ``to``: every field is replaced by its converted value"""
        for name, arr in zip(self.PARAM_NAMES, self):
            setattr(self, name, _to(arr, tgt))

    def to(self, tgt):
        """performs Pytorch Tensor dtype and/or device conversion"""
        return self.__class__(*[_to(arr, tgt) for arr in self])

    def cuda(self):
        """moves the parameters to the GPU"""
        return self.to('cuda')

    def cpu(self):
        """moves the parameters to the CPU"""
        return self.to('cpu')

    def __getitem__(self, item) -> 'AbstractParams':
        """indexes every tensor field with ``item``; non-tensor fields are passed through"""
        return self.__class__(*[
            arr.__getitem__(item) if isinstance(arr, Tensor) else arr for arr in self
        ])

    def __setitem__(self, key, other):
        """writes the tensor fields of ``other`` into ``self[key]``; non-tensor fields are left untouched

        Raises:
            ValueError: if ``other`` is not an instance of AbstractParams
        """
        if not isinstance(other, AbstractParams):
            raise ValueError(
                f'expected an AbstractParams instance, got {type(other).__name__}'
            )

        for param, other_param in zip(self, other):
            if isinstance(param, Tensor):
                param[key] = other_param

    def __add__(self, other):
        """concatenates the tensor fields of both operands along dimension 0

        Non-tensor fields are taken from the left operand.
        """
        if not isinstance(other, AbstractParams):
            # NotImplemented is a sentinel to be returned, not an exception to be raised
            return NotImplemented

        return self.__class__(*[
            torch.cat([param, other_param], 0)
            if isinstance(param, Tensor) else param
            for param, other_param in zip(self, other)
        ])

    @staticmethod
    def _fields_equal(param, other_param) -> bool:
        """compares two field values: ``torch.allclose`` for tensor pairs, ``==`` for non-tensor pairs"""
        param_is_tensor = isinstance(param, Tensor)
        other_is_tensor = isinstance(other_param, Tensor)
        if param_is_tensor and other_is_tensor:
            return torch.allclose(param, other_param)
        if param_is_tensor or other_is_tensor:
            return False
        return bool(param == other_param)

    def __eq__(self, other):
        """field-wise comparison of two parameter objects

        Tensor fields are compared with ``torch.allclose``; non-tensor fields
        (which ``torch.allclose`` does not accept) are compared with ``==``.
        """
        if not isinstance(other, AbstractParams):
            # lets Python fall back to its default handling instead of failing with a TypeError
            return NotImplemented

        return all(
            self._fields_equal(param, other_param)
            for param, other_param in zip(self, other)
        )

    @classmethod
    def cat(cls, *params: 'AbstractParams'):
        """concatenates several parameter objects along dimension 0

        Fields that are not tensors in the first object are taken from the first object.
        """
        return cls(
            *[torch.cat([getattr(p, name) for p in params], 0)
              if isinstance(getattr(params[0], name), Tensor) else
              getattr(params[0], name)
              for name in cls.PARAM_NAMES]
        )

    @property
    def _ref_tensor(self):
        """the first field, used as reference for shape, device and dtype"""
        return getattr(self, self.PARAM_NAMES[0])

    def scale_with_q(self, q_ratio: float):
        """scales the parameters based on the q ratio"""
        raise NotImplementedError

    @property
    def max_layer_num(self) -> int:
        """gets the number of layers (size of the last dimension of the reference tensor)"""
        return self._ref_tensor.shape[-1]

    @property
    def batch_size(self) -> int:
        """gets the batch size (size of the first dimension of the reference tensor)"""
        return self._ref_tensor.shape[0]

    @property
    def device(self):
        """gets the Pytorch device"""
        return self._ref_tensor.device

    @property
    def dtype(self):
        """gets the Pytorch data type"""
        return self._ref_tensor.dtype

    @property
    def d_rhos(self):
        raise NotImplementedError

    def as_tensor(self, use_drho: bool = False) -> Tensor:
        """converts the instance of the class to a Pytorch tensor"""
        raise NotImplementedError

    @classmethod
    def from_tensor(cls, params: Tensor):
        """initializes an instance of the class from a Pytorch tensor"""
        raise NotImplementedError

    @property
    def num_params(self) -> int:
        """get the number of parameters"""
        return self.layers_num2size(self.max_layer_num)

    @staticmethod
    def size2layers_num(size: int) -> int:
        """converts the number of parameters to the number of layers"""
        raise NotImplementedError

    @staticmethod
    def layers_num2size(layers_num: int) -> int:
        """converts the number of layers to the number of parameters"""
        raise NotImplementedError

    def get_param_labels(self) -> List[str]:
        """gets the parameter labels"""
        raise NotImplementedError

    def __repr__(self):
        return f'{self.__class__.__name__}(' \
               f'batch_size={self.batch_size}, ' \
               f'max_layer_num={self.max_layer_num}, ' \
               f'device={str(self.device)})'
163
+
164
+
165
def _to(arr, dest):
    """returns ``arr.to(dest)`` for objects exposing a ``to`` attribute and ``arr`` itself otherwise"""
    return arr.to(dest) if hasattr(arr, 'to') else arr
169
+
170
+
171
class Params(AbstractParams):
    """Parameter class holding a batch of thicknesses, roughnesses and slds

    Args:
        thicknesses (Tensor): batch of thicknesses (top to bottom)
        roughnesses (Tensor): batch of roughnesses (top to bottom)
        slds (Tensor): batch of slds (top to bottom)
    """
    MIN_THICKNESS: float = 0.5

    # the slot names double as the ordered field names used by the AbstractParams helpers
    PARAM_NAMES = __slots__ = ('thicknesses', 'roughnesses', 'slds')

    def __init__(self, thicknesses: Tensor, roughnesses: Tensor, slds: Tensor):
        self.thicknesses, self.roughnesses, self.slds = thicknesses, roughnesses, slds

    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor, context: Tensor, inference: bool = False, from_params: bool = False
    ):
        """returns the context alone in inference mode, otherwise the untouched (scaled_params, context) pair"""
        return context if inference else (scaled_params, context)

    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        """returns the scaled parameters unchanged"""
        return scaled_params

    def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
        """simulates the reflectivity curves for the stored parameters

        Args:
            q (Tensor): the q values
            log (bool, optional): whether to apply logarithm to the curves. Defaults to False.
            **kwargs: forwarded unchanged to ``reflectivity``

        Returns:
            Tensor: the simulated reflectivity curves
        """
        return reflectivity(
            q,
            self.thicknesses,
            self.roughnesses,
            self.slds,
            log=log,
            **kwargs
        )

    def scale_with_q(self, q_ratio: float):
        """scales the parameters based on the q ratio

        Thicknesses and roughnesses are divided by the ratio, slds are multiplied by its square.
        The augmented assignments operate on the stored objects.

        Args:
            q_ratio (float): the scaling ratio
        """
        self.thicknesses /= q_ratio
        self.roughnesses /= q_ratio
        self.slds *= q_ratio ** 2

    @property
    def d_rhos(self):
        """the result of ``get_d_rhos`` applied to the slds (differences in SLD values of the neighboring layers)"""
        return get_d_rhos(self.slds)

    def as_tensor(self, use_drho: bool = False) -> Tensor:
        """converts the instance of the class to a Pytorch tensor

        Thicknesses, roughnesses and either the slds or (if ``use_drho`` is set) the d_rhos
        are concatenated along the last dimension.
        """
        sld_part = self.d_rhos if use_drho else self.slds
        return torch.cat([self.thicknesses, self.roughnesses, sld_part], -1)

    @classmethod
    def from_tensor(cls, params: Tensor):
        """initializes an instance of the class from a Pytorch tensor containing the values of the parameters

        The last dimension is split into N thicknesses, N + 1 roughnesses and N + 1 slds.
        """
        n_layers = (params.shape[-1] - 2) // 3
        split_sizes = [n_layers, n_layers + 1, n_layers + 1]
        return cls(*torch.split(params, split_sizes, dim=-1))

    @staticmethod
    def size2layers_num(size: int) -> int:
        """converts the number of parameters to the number of layers"""
        return (size - 2) // 3

    @staticmethod
    def layers_num2size(layers_num: int) -> int:
        """converts the number of layers to the number of parameters"""
        return 3 * layers_num + 2

    def get_param_labels(self, **kwargs) -> List[str]:
        """gets the parameter labels, the layers are numbered from the bottom to the top
        (i.e. opposite to the order in the Tensors)"""
        return get_param_labels(self.max_layer_num, **kwargs)