reflectorch-1.3.0-py3-none-any.whl → reflectorch-1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of reflectorch might be problematic.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/data_generation/priors/parametric_subpriors.py
@@ -1,369 +1,369 @@
- from typing import Tuple, Dict, Type, List
-
- import torch
- from torch import Tensor
-
- from reflectorch.data_generation.priors.base import PriorSampler
- from reflectorch.data_generation.priors.params import AbstractParams
- from reflectorch.data_generation.priors.no_constraints import (
-     DEFAULT_DEVICE,
-     DEFAULT_DTYPE,
- )
-
- from reflectorch.data_generation.priors.parametric_models import (
-     MULTILAYER_MODELS,
-     NuisanceParamsWrapper,
-     ParametricModel,
- )
- from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
-
-
- class BasicParams(AbstractParams):
-     """Parameter class compatible with different parameterizations of the SLD profile. It stores the parameters as well as their minimum and maximum subprior bounds.
-
-     Args:
-         parameters (Tensor): the values of the thin film parameters
-         min_bounds (Tensor): the minimum subprior bounds of the parameters
-         max_bounds (Tensor): the maximum subprior bounds of the parameters
-         max_num_layers (int, optional): the maximum number of layers (for box model parameterizations it is the number of layers). Defaults to None.
-         param_model (ParametricModel, optional): the parametric model. Defaults to the box model parameterization with number of layers given by max_num_layers.
-     """
-
-     __slots__ = (
-         'parameters',
-         'min_bounds',
-         'max_bounds',
-         'max_num_layers',
-         'param_model',
-     )
-     PARAM_NAMES = __slots__
-     PARAM_MODEL_CLS: Type[ParametricModel]
-     MAX_NUM_LAYERS: int = 30
-
-     def __init__(self,
-                  parameters: Tensor,
-                  min_bounds: Tensor,
-                  max_bounds: Tensor,
-                  max_num_layers: int = None,
-                  param_model: ParametricModel = None,
-                  ):
-
-         max_num_layers = max_num_layers or self.MAX_NUM_LAYERS
-         self.param_model = param_model or self.PARAM_MODEL_CLS(max_num_layers)
-         self.max_num_layers = max_num_layers
-         self.parameters = parameters
-         self.min_bounds = min_bounds
-         self.max_bounds = max_bounds
-
-     def get_param_labels(self, **kwargs) -> List[str]:
-         """gets the parameter labels"""
-         return self.param_model.get_param_labels(**kwargs)
-
-     def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
-         r"""computes the reflectivity curves directly from the parameters
-
-         Args:
-             q (Tensor): the q values
-             log (bool, optional): whether to apply logarithm to the curves. Defaults to False.
-
-         Returns:
-             Tensor: the simulated reflectivity curves
-         """
-         return self.param_model.reflectivity(q, self.parameters, log=log, **kwargs)
-
-     @property
-     def max_layer_num(self) -> int: # keep for back compatibility but TODO: unify api among different params
-         """gets the maximum number of layers"""
-         return self.max_num_layers
-
-     @property
-     def num_params(self) -> int:
-         """get the number of parameters (parameter dimensionality)"""
-         return self.param_model.param_dim
-
-     @property
-     def thicknesses(self):
-         """gets the thicknesses"""
-         params = self.param_model.to_standard_params(self.parameters)
-         return params['thickness']
-
-     @property
-     def roughnesses(self):
-         """gets the roughnesses"""
-         params = self.param_model.to_standard_params(self.parameters)
-         return params['roughness']
-
-     @property
-     def slds(self):
-         """gets the slds"""
-         params = self.param_model.to_standard_params(self.parameters)
-         return params['sld']
-
-     @property
-     def real_slds(self):
-         """gets the real part of the slds"""
-         params = self.param_model.to_standard_params(self.parameters)
-         return params['sld'].real
-
-     @property
-     def imag_slds(self):
-         """gets the imaginary part of the slds (only for complex dtypes)"""
-         params = self.param_model.to_standard_params(self.parameters)
-         return params['sld'].imag
-
-     @staticmethod
-     def rearrange_context_from_params(
-             scaled_params: Tensor,
-             context: Tensor,
-             inference: bool = False,
-             from_params: bool = False,
-     ):
-         if inference:
-             if from_params:
-                 num_params = scaled_params.shape[-1] // 3
-                 scaled_params = scaled_params[:, num_params:]
-                 context = torch.cat([context, scaled_params], dim=-1)
-             return context
-
-         num_params = scaled_params.shape[-1] // 3
-         assert num_params * 3 == scaled_params.shape[-1]
-         scaled_params, bound_context = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
-         context = torch.cat([context, bound_context], dim=-1)
-         return scaled_params, context
-
-     @staticmethod
-     def restore_params_from_context(scaled_params: Tensor, context: Tensor):
-         num_params = scaled_params.shape[-1]
-         scaled_bounds = context[:, -2 * num_params:]
-         scaled_params = torch.cat([scaled_params, scaled_bounds], dim=-1)
-         return scaled_params
-
-     def as_tensor(self, add_bounds: bool = True, **kwargs) -> Tensor:
-         """converts the instance of the class to a Pytorch tensor
-
-         Args:
-             add_bounds (bool, optional): whether to add the subprior bounds to the tensor. Defaults to True.
-
-         Returns:
-             Tensor: the Pytorch tensor obtained from the instance of the class
-         """
-         if not add_bounds:
-             return self.parameters
-         return torch.cat([self.parameters, self.min_bounds, self.max_bounds], -1)
-
-     @classmethod
-     def from_tensor(cls, params: Tensor, **kwargs):
-         """initializes an instance of the class from a Pytorch tensor
-
-         Args:
-             params (Tensor): Pytorch tensor containing the parameter values, min subprior bounds and max subprior bounds
-
-         Returns:
-             BasicParams: the instance of the class
-         """
-         num_params = params.shape[-1] // 3
-
-         params, min_bounds, max_bounds = torch.split(
-             params, [num_params, num_params, num_params], dim=-1
-         )
-
-         return cls(
-             params,
-             min_bounds,
-             max_bounds,
-             **kwargs
-         )
-
-     def scale_with_q(self, q_ratio: float):
-         """scales the parameters based on the q ratio
-
-         Args:
-             q_ratio (float): the scaling ratio
-         """
-         self.parameters = self.param_model.scale_with_q(self.parameters, q_ratio)
-         self.min_bounds = self.param_model.scale_with_q(self.min_bounds, q_ratio)
-         self.max_bounds = self.param_model.scale_with_q(self.max_bounds, q_ratio)
-
-
- class SubpriorParametricSampler(PriorSampler, ScalerMixin):
-     PARAM_CLS = BasicParams
-
-     def __init__(self,
-                  param_ranges: Dict[str, Tuple[float, float]],
-                  bound_width_ranges: Dict[str, Tuple[float, float]],
-                  model_name: str,
-                  device: torch.device = DEFAULT_DEVICE,
-                  dtype: torch.dtype = DEFAULT_DTYPE,
-                  max_num_layers: int = 50,
-                  logdist: bool = False,
-                  scale_params_by_ranges=False,
-                  scaled_range: Tuple[float, float] = (-1., 1.),
-                  **kwargs
-                  ):
-         """Prior sampler for the parameters of a parametric model and their subprior bounds
-
-         Args:
-             param_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with its range
-             bound_width_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with the range for sampling the widths of the subprior interval
-             model_name (str): the name of the parametric model
-             device (torch.device, optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
-             dtype (torch.dtype, optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
-             max_num_layers (int, optional): the maximum number of layers (for box model parameterizations it is the number of layers). Defaults to 50.
-             logdist (bool, optional): if True the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
-             scale_params_by_ranges (bool, optional): if True the parameters are scaled with respect to their ranges instead of being scaled with respect to their prior bounds. Defaults to False.
-             scaled_range (Tuple[float, float], optional): the range for scaling the parameters. Defaults to (-1., 1.)
-         """
-         self.scaled_range = scaled_range
-
-         self.shift_param_config = kwargs.pop('shift_param_config', {})
-
-         base_model: ParametricModel = MULTILAYER_MODELS[model_name](max_num_layers, logdist=logdist, **kwargs)
-         if any(self.shift_param_config.values()):
-             self.param_model = NuisanceParamsWrapper(
-                 base_model=base_model,
-                 nuisance_params_config=self.shift_param_config,
-                 **kwargs,
-             )
-         else:
-             self.param_model = base_model
-
-         self.device = device
-         self.dtype = dtype
-         self.num_layers = max_num_layers
-
-         self.PARAM_CLS.PARAM_MODEL_CLS = MULTILAYER_MODELS[model_name]
-         self.PARAM_CLS.MAX_NUM_LAYERS = max_num_layers
-
-         self._param_dim = self.param_model.param_dim
-         self.min_bounds, self.max_bounds, self.min_delta, self.max_delta = self.param_model.init_bounds(
-             param_ranges, bound_width_ranges, device=device, dtype=dtype
-         )
-
-         self.param_ranges = param_ranges
-         self.bound_width_ranges = bound_width_ranges
-         self.model_name = model_name
-         self.logdist = logdist
-         self.scale_params_by_ranges = scale_params_by_ranges
-
-     @property
-     def max_num_layers(self) -> int:
-         """gets the maximum number of layers"""
-         return self.num_layers
-
-     @property
-     def param_dim(self) -> int:
-         """get the number of parameters (parameter dimensionality)"""
-         return self._param_dim
-
-     def sample(self, batch_size: int) -> BasicParams:
-         """sample a batch of parameters
-
-         Args:
-             batch_size (int): the batch size
-
-         Returns:
-             BasicParams: sampled parameters
-         """
-         params, min_bounds, max_bounds = self.param_model.sample(
-             batch_size, self.min_bounds, self.max_bounds, self.min_delta, self.max_delta
-         )
-
-         params = BasicParams(
-             parameters=params,
-             min_bounds=min_bounds,
-             max_bounds=max_bounds,
-             max_num_layers=self.max_num_layers,
-             param_model=self.param_model,
-         )
-
-         return params
-
-     def scale_params(self, params: BasicParams) -> Tensor:
-         """scale the parameters to a ML-friendly range
-
-         Args:
-             params (BasicParams): the parameters to be scaled
-
-         Returns:
-             Tensor: the scaled parameters
-         """
-         if self.scale_params_by_ranges:
-             scaled_params = torch.cat([
-                 self._scale(params.parameters, self.min_bounds, self.max_bounds), #parameters and subprior bounds are scaled with respect to the parameter ranges
-                 self._scale(params.min_bounds, self.min_bounds, self.max_bounds),
-                 self._scale(params.max_bounds, self.min_bounds, self.max_bounds),
-             ], -1)
-             return scaled_params
-         else:
-             scaled_params = torch.cat([
-                 self._scale(params.parameters, params.min_bounds, params.max_bounds), #each parameter scaled with respect to its subprior bounds
-                 self._scale(params.min_bounds, self.min_bounds, self.max_bounds), #the subprior bounds are scaled with respect to the parameter ranges
-                 self._scale(params.max_bounds, self.min_bounds, self.max_bounds),
-             ], -1)
-             return scaled_params
-
-     def restore_params(self, scaled_params: Tensor) -> BasicParams:
-         """restore the parameters to their original range
-
-         Args:
-             scaled_params (Tensor): the scaled parameters
-
-         Returns:
-             BasicParams: the parameters restored to their original range
-         """
-         num_params = scaled_params.shape[-1] // 3
-         scaled_params, scaled_min_bounds, scaled_max_bounds = torch.split(
-             scaled_params, num_params, -1
-         )
-         if self.scale_params_by_ranges:
-             min_bounds = self._restore(scaled_min_bounds, self.min_bounds, self.max_bounds)
-             max_bounds = self._restore(scaled_max_bounds, self.min_bounds, self.max_bounds)
-             params = self._restore(scaled_params, self.min_bounds, self.max_bounds)
-         else:
-             min_bounds = self._restore(scaled_min_bounds, self.min_bounds, self.max_bounds)
-             max_bounds = self._restore(scaled_max_bounds, self.min_bounds, self.max_bounds)
-             params = self._restore(scaled_params, min_bounds, max_bounds)
-
-         return BasicParams(
-             parameters=params,
-             min_bounds=min_bounds,
-             max_bounds=max_bounds,
-             max_num_layers=self.max_num_layers,
-             param_model=self.param_model,
-         )
-
-     def scale_bounds(self, bounds: Tensor) -> Tensor:
-         return self._scale(bounds, self.min_bounds, self.max_bounds)
-
-     def log_prob(self, params: BasicParams) -> Tensor:
-         log_prob = torch.zeros(params.batch_size, device=self.device, dtype=self.dtype)
-         log_prob[~self.get_indices_within_bounds(params)] = -float('inf')
-         return log_prob
-
-     def get_indices_within_domain(self, params: BasicParams) -> Tensor:
-         return self.get_indices_within_bounds(params)
-
-     def get_indices_within_bounds(self, params: BasicParams) -> Tensor:
-         return (
-             torch.all(params.parameters >= params.min_bounds, -1) &
-             torch.all(params.parameters <= params.max_bounds, -1)
-         )
-
-     def filter_params(self, params: BasicParams) -> BasicParams:
-         indices = self.get_indices_within_domain(params)
-         return params[indices]
-
-     def clamp_params(
-             self, params: BasicParams, inplace: bool = False
-     ) -> BasicParams:
-         if inplace:
-             params.parameters = torch.clamp_(params.parameters, params.min_bounds, params.max_bounds)
-             return params
-
-         return BasicParams(
-             parameters=torch.clamp(params.parameters, params.min_bounds, params.max_bounds),
-             min_bounds=params.min_bounds.clone(),
-             max_bounds=params.max_bounds.clone(),
-             max_num_layers=self.max_num_layers,
-             param_model=self.param_model,
-         )
+ from typing import Tuple, Dict, Type, List
+
+ import torch
+ from torch import Tensor
+
+ from reflectorch.data_generation.priors.base import PriorSampler
+ from reflectorch.data_generation.priors.params import AbstractParams
+ from reflectorch.data_generation.priors.no_constraints import (
+     DEFAULT_DEVICE,
+     DEFAULT_DTYPE,
+ )
+
+ from reflectorch.data_generation.priors.parametric_models import (
+     MULTILAYER_MODELS,
+     NuisanceParamsWrapper,
+     ParametricModel,
+ )
+ from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
+
+
+ class BasicParams(AbstractParams):
+     """Parameter class compatible with different parameterizations of the SLD profile. It stores the parameters as well as their minimum and maximum subprior bounds.
+
+     Args:
+         parameters (Tensor): the values of the thin film parameters
+         min_bounds (Tensor): the minimum subprior bounds of the parameters
+         max_bounds (Tensor): the maximum subprior bounds of the parameters
+         max_num_layers (int, optional): the maximum number of layers (for box model parameterizations it is the number of layers). Defaults to None.
+         param_model (ParametricModel, optional): the parametric model. Defaults to the box model parameterization with number of layers given by max_num_layers.
+     """
+
+     __slots__ = (
+         'parameters',
+         'min_bounds',
+         'max_bounds',
+         'max_num_layers',
+         'param_model',
+     )
+     PARAM_NAMES = __slots__
+     PARAM_MODEL_CLS: Type[ParametricModel]
+     MAX_NUM_LAYERS: int = 30
+
+     def __init__(self,
+                  parameters: Tensor,
+                  min_bounds: Tensor,
+                  max_bounds: Tensor,
+                  max_num_layers: int = None,
+                  param_model: ParametricModel = None,
+                  ):
+
+         max_num_layers = max_num_layers or self.MAX_NUM_LAYERS
+         self.param_model = param_model or self.PARAM_MODEL_CLS(max_num_layers)
+         self.max_num_layers = max_num_layers
+         self.parameters = parameters
+         self.min_bounds = min_bounds
+         self.max_bounds = max_bounds
+
+     def get_param_labels(self, **kwargs) -> List[str]:
+         """gets the parameter labels"""
+         return self.param_model.get_param_labels(**kwargs)
+
+     def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
+         r"""computes the reflectivity curves directly from the parameters
+
+         Args:
+             q (Tensor): the q values
+             log (bool, optional): whether to apply logarithm to the curves. Defaults to False.
+
+         Returns:
+             Tensor: the simulated reflectivity curves
+         """
+         return self.param_model.reflectivity(q, self.parameters, log=log, **kwargs)
+
+     @property
+     def max_layer_num(self) -> int: # keep for back compatibility but TODO: unify api among different params
+         """gets the maximum number of layers"""
+         return self.max_num_layers
+
+     @property
+     def num_params(self) -> int:
+         """get the number of parameters (parameter dimensionality)"""
+         return self.param_model.param_dim
+
+     @property
+     def thicknesses(self):
+         """gets the thicknesses"""
+         params = self.param_model.to_standard_params(self.parameters)
+         return params['thickness']
+
+     @property
+     def roughnesses(self):
+         """gets the roughnesses"""
+         params = self.param_model.to_standard_params(self.parameters)
+         return params['roughness']
+
+     @property
+     def slds(self):
+         """gets the slds"""
+         params = self.param_model.to_standard_params(self.parameters)
+         return params['sld']
+
+     @property
+     def real_slds(self):
+         """gets the real part of the slds"""
+         params = self.param_model.to_standard_params(self.parameters)
+         return params['sld'].real
+
+     @property
+     def imag_slds(self):
+         """gets the imaginary part of the slds (only for complex dtypes)"""
+         params = self.param_model.to_standard_params(self.parameters)
+         return params['sld'].imag
+
+     @staticmethod
+     def rearrange_context_from_params(
+             scaled_params: Tensor,
+             context: Tensor,
+             inference: bool = False,
+             from_params: bool = False,
+     ):
+         if inference:
+             if from_params:
+                 num_params = scaled_params.shape[-1] // 3
+                 scaled_params = scaled_params[:, num_params:]
+                 context = torch.cat([context, scaled_params], dim=-1)
+             return context
+
+         num_params = scaled_params.shape[-1] // 3
+         assert num_params * 3 == scaled_params.shape[-1]
+         scaled_params, bound_context = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
+         context = torch.cat([context, bound_context], dim=-1)
+         return scaled_params, context
+
+     @staticmethod
+     def restore_params_from_context(scaled_params: Tensor, context: Tensor):
+         num_params = scaled_params.shape[-1]
+         scaled_bounds = context[:, -2 * num_params:]
+         scaled_params = torch.cat([scaled_params, scaled_bounds], dim=-1)
+         return scaled_params
+
+     def as_tensor(self, add_bounds: bool = True, **kwargs) -> Tensor:
+         """converts the instance of the class to a Pytorch tensor
+
+         Args:
+             add_bounds (bool, optional): whether to add the subprior bounds to the tensor. Defaults to True.
+
+         Returns:
+             Tensor: the Pytorch tensor obtained from the instance of the class
+         """
+         if not add_bounds:
+             return self.parameters
+         return torch.cat([self.parameters, self.min_bounds, self.max_bounds], -1)
+
+     @classmethod
+     def from_tensor(cls, params: Tensor, **kwargs):
+         """initializes an instance of the class from a Pytorch tensor
+
+         Args:
+             params (Tensor): Pytorch tensor containing the parameter values, min subprior bounds and max subprior bounds
+
+         Returns:
+             BasicParams: the instance of the class
+         """
+         num_params = params.shape[-1] // 3
+
+         params, min_bounds, max_bounds = torch.split(
+             params, [num_params, num_params, num_params], dim=-1
+         )
+
+         return cls(
+             params,
+             min_bounds,
+             max_bounds,
+             **kwargs
+         )
+
+     def scale_with_q(self, q_ratio: float):
+         """scales the parameters based on the q ratio
+
+         Args:
+             q_ratio (float): the scaling ratio
+         """
+         self.parameters = self.param_model.scale_with_q(self.parameters, q_ratio)
+         self.min_bounds = self.param_model.scale_with_q(self.min_bounds, q_ratio)
+         self.max_bounds = self.param_model.scale_with_q(self.max_bounds, q_ratio)
+
+
+ class SubpriorParametricSampler(PriorSampler, ScalerMixin):
+     PARAM_CLS = BasicParams
+
+     def __init__(self,
+                  param_ranges: Dict[str, Tuple[float, float]],
+                  bound_width_ranges: Dict[str, Tuple[float, float]],
+                  model_name: str,
+                  device: torch.device = DEFAULT_DEVICE,
+                  dtype: torch.dtype = DEFAULT_DTYPE,
+                  max_num_layers: int = 50,
+                  logdist: bool = False,
+                  scale_params_by_ranges=False,
+                  scaled_range: Tuple[float, float] = (-1., 1.),
+                  **kwargs
+                  ):
+         """Prior sampler for the parameters of a parametric model and their subprior bounds
+
+         Args:
+             param_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with its range
+             bound_width_ranges (Dict[str, Tuple[float, float]]): dictionary containing the name of each type of parameter together with the range for sampling the widths of the subprior interval
+             model_name (str): the name of the parametric model
+             device (torch.device, optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
+             dtype (torch.dtype, optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
+             max_num_layers (int, optional): the maximum number of layers (for box model parameterizations it is the number of layers). Defaults to 50.
+             logdist (bool, optional): if True the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
+             scale_params_by_ranges (bool, optional): if True the parameters are scaled with respect to their ranges instead of being scaled with respect to their prior bounds. Defaults to False.
+             scaled_range (Tuple[float, float], optional): the range for scaling the parameters. Defaults to (-1., 1.)
+         """
+         self.scaled_range = scaled_range
+
+         self.shift_param_config = kwargs.pop('shift_param_config', {})
+
+         base_model: ParametricModel = MULTILAYER_MODELS[model_name](max_num_layers, logdist=logdist, **kwargs)
+         if any(self.shift_param_config.values()):
+             self.param_model = NuisanceParamsWrapper(
+                 base_model=base_model,
+                 nuisance_params_config=self.shift_param_config,
+                 **kwargs,
+             )
+         else:
+             self.param_model = base_model
+
+         self.device = device
+         self.dtype = dtype
+         self.num_layers = max_num_layers
+
+         self.PARAM_CLS.PARAM_MODEL_CLS = MULTILAYER_MODELS[model_name]
+         self.PARAM_CLS.MAX_NUM_LAYERS = max_num_layers
+
+         self._param_dim = self.param_model.param_dim
+         self.min_bounds, self.max_bounds, self.min_delta, self.max_delta = self.param_model.init_bounds(
+             param_ranges, bound_width_ranges, device=device, dtype=dtype
+         )
+
+         self.param_ranges = param_ranges
+         self.bound_width_ranges = bound_width_ranges
+         self.model_name = model_name
+         self.logdist = logdist
+         self.scale_params_by_ranges = scale_params_by_ranges
+
+     @property
+     def max_num_layers(self) -> int:
+         """gets the maximum number of layers"""
+         return self.num_layers
+
+     @property
+     def param_dim(self) -> int:
+         """get the number of parameters (parameter dimensionality)"""
+         return self._param_dim
+
+     def sample(self, batch_size: int) -> BasicParams:
+         """sample a batch of parameters
+
+         Args:
+             batch_size (int): the batch size
+
+         Returns:
+             BasicParams: sampled parameters
+         """
+         params, min_bounds, max_bounds = self.param_model.sample(
+             batch_size, self.min_bounds, self.max_bounds, self.min_delta, self.max_delta
+         )
+
+         params = BasicParams(
+             parameters=params,
+             min_bounds=min_bounds,
+             max_bounds=max_bounds,
+             max_num_layers=self.max_num_layers,
+             param_model=self.param_model,
+         )
+
+         return params
+
+     def scale_params(self, params: BasicParams) -> Tensor:
+         """scale the parameters to a ML-friendly range
+
+         Args:
+             params (BasicParams): the parameters to be scaled
+
+         Returns:
+             Tensor: the scaled parameters
+         """
+         if self.scale_params_by_ranges:
+             scaled_params = torch.cat([
+                 self._scale(params.parameters, self.min_bounds, self.max_bounds), #parameters and subprior bounds are scaled with respect to the parameter ranges
+                 self._scale(params.min_bounds, self.min_bounds, self.max_bounds),
+                 self._scale(params.max_bounds, self.min_bounds, self.max_bounds),
+             ], -1)
+             return scaled_params
+         else:
+             scaled_params = torch.cat([
+                 self._scale(params.parameters, params.min_bounds, params.max_bounds), #each parameter scaled with respect to its subprior bounds
+                 self._scale(params.min_bounds, self.min_bounds, self.max_bounds), #the subprior bounds are scaled with respect to the parameter ranges
+                 self._scale(params.max_bounds, self.min_bounds, self.max_bounds),
+             ], -1)
+             return scaled_params
+
+     def restore_params(self, scaled_params: Tensor) -> BasicParams:
+         """restore the parameters to their original range
+
+         Args:
+             scaled_params (Tensor): the scaled parameters
+
+         Returns:
+             BasicParams: the parameters restored to their original range
+         """
+         num_params = scaled_params.shape[-1] // 3
+         scaled_params, scaled_min_bounds, scaled_max_bounds = torch.split(
+             scaled_params, num_params, -1
+         )
+         if self.scale_params_by_ranges:
+             min_bounds = self._restore(scaled_min_bounds, self.min_bounds, self.max_bounds)
+             max_bounds = self._restore(scaled_max_bounds, self.min_bounds, self.max_bounds)
+             params = self._restore(scaled_params, self.min_bounds, self.max_bounds)
+         else:
+             min_bounds = self._restore(scaled_min_bounds, self.min_bounds, self.max_bounds)
+             max_bounds = self._restore(scaled_max_bounds, self.min_bounds, self.max_bounds)
+             params = self._restore(scaled_params, min_bounds, max_bounds)
+
+         return BasicParams(
+             parameters=params,
+             min_bounds=min_bounds,
+             max_bounds=max_bounds,
+             max_num_layers=self.max_num_layers,
+             param_model=self.param_model,
+         )
+
+     def scale_bounds(self, bounds: Tensor) -> Tensor:
+         return self._scale(bounds, self.min_bounds, self.max_bounds)
+
+     def log_prob(self, params: BasicParams) -> Tensor:
+         log_prob = torch.zeros(params.batch_size, device=self.device, dtype=self.dtype)
+         log_prob[~self.get_indices_within_bounds(params)] = -float('inf')
+         return log_prob
+
+     def get_indices_within_domain(self, params: BasicParams) -> Tensor:
+         return self.get_indices_within_bounds(params)
+
+     def get_indices_within_bounds(self, params: BasicParams) -> Tensor:
+         return (
+             torch.all(params.parameters >= params.min_bounds, -1) &
+             torch.all(params.parameters <= params.max_bounds, -1)
+         )
+
+     def filter_params(self, params: BasicParams) -> BasicParams:
+         indices = self.get_indices_within_domain(params)
+         return params[indices]
+
+     def clamp_params(
+             self, params: BasicParams, inplace: bool = False
+     ) -> BasicParams:
+         if inplace:
+             params.parameters = torch.clamp_(params.parameters, params.min_bounds, params.max_bounds)
+             return params
+
+         return BasicParams(
+             parameters=torch.clamp(params.parameters, params.min_bounds, params.max_bounds),
+             min_bounds=params.min_bounds.clone(),
+             max_bounds=params.max_bounds.clone(),
+             max_num_layers=self.max_num_layers,
+             param_model=self.param_model,
+         )
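
For orientation, the packed tensor layout that as_tensor, from_tensor, and restore_params above rely on is [parameters | min_bounds | max_bounds] concatenated along the last dimension, so its size is 3 * num_params. Below is a minimal standalone sketch of that round trip in plain PyTorch; it does not import reflectorch, the shapes and toy bounds are illustrative, and the affine min-max form assumed for ScalerMixin._scale / _restore (which are not shown in this hunk) is an assumption based on scaled_range defaulting to (-1., 1.).

import torch

batch_size, num_params = 4, 5
parameters = torch.rand(batch_size, num_params)
min_bounds = parameters - 0.1  # toy subprior bounds enclosing each parameter
max_bounds = parameters + 0.1

# as_tensor(add_bounds=True): pack along the last dimension -> size 3 * num_params
packed = torch.cat([parameters, min_bounds, max_bounds], dim=-1)
assert packed.shape[-1] == 3 * num_params

# from_tensor / restore_params: split the packed tensor back into three equal blocks
p, lo, hi = torch.split(packed, [num_params] * 3, dim=-1)
assert torch.equal(p, parameters) and torch.equal(lo, min_bounds) and torch.equal(hi, max_bounds)

# scale_params (default path) maps each parameter into scaled_range relative to its
# own subprior bounds; the affine min-max map below is an assumed stand-in for _scale
lo_s, hi_s = -1.0, 1.0  # scaled_range
scaled = (parameters - min_bounds) / (max_bounds - min_bounds) * (hi_s - lo_s) + lo_s
restored = (scaled - lo_s) / (hi_s - lo_s) * (max_bounds - min_bounds) + min_bounds
assert torch.allclose(restored, parameters)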