reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of reflectorch might be problematic.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -128
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -280
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -223
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -1374
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +36 -36
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +523 -516
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -19
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -262
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -200
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -15
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -19
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -434
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -404
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +97 -97
  91. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  94. reflectorch-1.4.0.dist-info/RECORD +0 -88
  95. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
  96. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,842 +1,842 @@
from typing import Tuple, Dict, List

import torch
from torch import Tensor

from reflectorch.data_generation.reflectivity import (
    reflectivity,
    abeles_memory_eff,
    kinematical_approximation,
)
from reflectorch.data_generation.utils import (
    get_param_labels,
)
from reflectorch.data_generation.priors.sampler_strategies import (
    SamplerStrategy,
    BasicSamplerStrategy,
    ConstrainedRoughnessSamplerStrategy,
    ConstrainedRoughnessAndImgSldSamplerStrategy,
)

__all__ = [
    "MULTILAYER_MODELS",
    "ParametricModel",
]


class ParametricModel(object):
    """Base class for parameterizations of the SLD profile.

    Args:
        max_num_layers (int): the number of layers
    """
    NAME: str = ''
    PARAMETER_NAMES: Tuple[str, ...]

    def __init__(self, max_num_layers: int, **kwargs):
        self.max_num_layers = max_num_layers
        self._sampler_strategy = self._init_sampler_strategy(**kwargs)

    def _init_sampler_strategy(self, nuisance_params_dim: int = 0, **kwargs):
        return BasicSamplerStrategy(**kwargs)

    @property
    def param_dim(self) -> int:
        """get the number of parameters

        Returns:
            int:
        """
        return len(self.PARAMETER_NAMES)

    @property
    def sampler_strategy(self) -> SamplerStrategy:
        """get the sampler strategy

        Returns:
            SamplerStrategy:
        """
        return self._sampler_strategy

    def reflectivity(self, q, parametrized_model: Tensor, **kwargs) -> Tensor:
        """computes the reflectivity curves

        Args:
            q: the reciprocal space (q) positions
            parametrized_model (Tensor): the values of the parameters

        Returns:
            Tensor: the computed reflectivity curves
        """
        params = self.to_standard_params(parametrized_model)
        return reflectivity(q, **params, **kwargs)

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        raise NotImplementedError

    def from_standard_params(self, params: dict) -> Tensor:
        raise NotImplementedError

    def scale_with_q(self, parametrized_model: Tensor, q_ratio: float) -> Tensor:
        raise NotImplementedError

    def init_bounds(self,
                    param_ranges: Dict[str, Tuple[float, float]],
                    bound_width_ranges: Dict[str, Tuple[float, float]],
                    device=None,
                    dtype=None,
                    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """initializes arrays storing individually the upper and lower bounds from the dictionaries of parameter and bound width ranges

        Args:
            param_ranges (Dict[str, Tuple[float, float]]): parameter ranges
            bound_width_ranges (Dict[str, Tuple[float, float]]): bound width ranges
            device (optional): the Pytorch device. Defaults to None.
            dtype (optional): the Pytorch datatype. Defaults to None.

        Returns:
            Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        ordered_bounds = [param_ranges[k] for k in self.PARAMETER_NAMES]
        delta_bounds = [bound_width_ranges[k] for k in self.PARAMETER_NAMES]

        min_bounds, max_bounds = torch.tensor(ordered_bounds, device=device, dtype=dtype).T[:, None]
        min_deltas, max_deltas = torch.tensor(delta_bounds, device=device, dtype=dtype).T[:, None]

        return min_bounds, max_bounds, min_deltas, max_deltas

    def get_param_labels(self, **kwargs) -> List[str]:
        """get the list with the name of the parameters

        Returns:
            List[str]:
        """
        return list(self.PARAMETER_NAMES)

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """samples the parameter values and their prior bounds

        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): lower bounds of the parameter ranges
            total_max_bounds (Tensor): upper bounds of the parameter ranges
            total_min_delta (Tensor): lower widths of the subprior intervals
            total_max_delta (Tensor): upper widths of the subprior intervals

        Returns:
            Tensor: sampled parameters
        """
        return self.sampler_strategy.sample(
            batch_size,
            total_min_bounds,
            total_max_bounds,
            total_min_delta,
            total_max_delta,
        )


class StandardModel(ParametricModel):
    """Parameterization for the standard box model. The parameters are the thicknesses, roughnesses and real sld values of the layers."""
    NAME = 'standard_model'

    PARAMETER_NAMES = (
        "thicknesses",
        "roughnesses",
        "slds",
    )

    @property
    def param_dim(self) -> int:
        return 3 * self.max_num_layers + 2

    def _init_sampler_strategy(self,
                               constrained_roughness: bool = True,
                               max_thickness_share: float = 0.5,
                               nuisance_params_dim: int = 0,
                               **kwargs):
        if constrained_roughness:
            num_params = self.param_dim + nuisance_params_dim
            thickness_mask = torch.zeros(num_params, dtype=torch.bool)
            roughness_mask = torch.zeros(num_params, dtype=torch.bool)
            thickness_mask[:self.max_num_layers] = True
            roughness_mask[self.max_num_layers:2 * self.max_num_layers + 1] = True
            return ConstrainedRoughnessSamplerStrategy(
                thickness_mask, roughness_mask,
                max_thickness_share=max_thickness_share,
                **kwargs
            )
        else:
            return BasicSamplerStrategy(**kwargs)

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        return self._params2dict(parametrized_model)

    def init_bounds(self,
                    param_ranges: Dict[str, Tuple[float, float]],
                    bound_width_ranges: Dict[str, Tuple[float, float]],
                    device=None,
                    dtype=None,
                    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:

        other_ranges = [param_ranges[k] for k in self.PARAMETER_NAMES[3:]]
        other_delta_bounds = [bound_width_ranges[k] for k in self.PARAMETER_NAMES[3:]]

        ordered_bounds = (
            [param_ranges["thicknesses"]] * self.max_num_layers +
            [param_ranges["roughnesses"]] * (self.max_num_layers + 1) +
            [param_ranges["slds"]] * (self.max_num_layers + 1) +
            other_ranges
        )
        delta_bounds = (
            [bound_width_ranges["thicknesses"]] * self.max_num_layers +
            [bound_width_ranges["roughnesses"]] * (self.max_num_layers + 1) +
            [bound_width_ranges["slds"]] * (self.max_num_layers + 1) +
            other_delta_bounds
        )

        min_bounds, max_bounds = torch.tensor(ordered_bounds, device=device, dtype=dtype).T[:, None]
        min_deltas, max_deltas = torch.tensor(delta_bounds, device=device, dtype=dtype).T[:, None]

        return min_bounds, max_bounds, min_deltas, max_deltas

    def get_param_labels(self, **kwargs) -> List[str]:
        return get_param_labels(self.max_num_layers, **kwargs)

    @staticmethod
    def _params2dict(parametrized_model: Tensor):
        num_params = parametrized_model.shape[-1]
        num_layers = (num_params - 2) // 3
        assert num_layers * 3 + 2 == num_params

        d, sigma, sld = torch.split(
            parametrized_model, [num_layers, num_layers + 1, num_layers + 1], -1
        )
        params = dict(
            thickness=d,
            roughness=sigma,
            sld=sld
        )

        return params

    def reflectivity(self, q, parametrized_model: Tensor, **kwargs) -> Tensor:
        return reflectivity(
            q, **self._params2dict(parametrized_model), **kwargs
        )


class ModelWithAbsorption(StandardModel):
    """Parameterization for the box model in which the imaginary sld values of the layers are additional parameters."""
    NAME = 'model_with_absorption'

    PARAMETER_NAMES = (
        "thicknesses",
        "roughnesses",
        "slds",
        "islds",
    )

    @property
    def param_dim(self) -> int:
        return 4 * self.max_num_layers + 3

    def _init_sampler_strategy(self,
                               constrained_roughness: bool = True,
                               constrained_isld: bool = True,
                               max_thickness_share: float = 0.5,
                               max_sld_share: float = 0.2,
                               nuisance_params_dim: int = 0,
                               **kwargs):
        if constrained_roughness:
            num_params = self.param_dim + nuisance_params_dim
            thickness_mask = torch.zeros(num_params, dtype=torch.bool)
            roughness_mask = torch.zeros(num_params, dtype=torch.bool)
            thickness_mask[:self.max_num_layers] = True
            roughness_mask[self.max_num_layers:2 * self.max_num_layers + 1] = True

            if constrained_isld:
                sld_mask = torch.zeros(num_params, dtype=torch.bool)
                isld_mask = torch.zeros(num_params, dtype=torch.bool)
                sld_mask[2 * self.max_num_layers + 1:3 * self.max_num_layers + 2] = True
                isld_mask[3 * self.max_num_layers + 2:4 * self.max_num_layers + 3] = True
                return ConstrainedRoughnessAndImgSldSamplerStrategy(
                    thickness_mask, roughness_mask, sld_mask, isld_mask,
                    max_thickness_share=max_thickness_share, max_sld_share=max_sld_share,
                    **kwargs
                )
            else:
                return ConstrainedRoughnessSamplerStrategy(
                    thickness_mask, roughness_mask,
                    max_thickness_share=max_thickness_share,
                    **kwargs
                )
        else:
            return BasicSamplerStrategy(**kwargs)

    def init_bounds(self,
                    param_ranges: Dict[str, Tuple[float, float]],
                    bound_width_ranges: Dict[str, Tuple[float, float]],
                    device=None,
                    dtype=None,
                    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        other_ranges = [param_ranges[k] for k in self.PARAMETER_NAMES[4:]]
        other_delta_bounds = [bound_width_ranges[k] for k in self.PARAMETER_NAMES[4:]]

        ordered_bounds = (
            [param_ranges["thicknesses"]] * self.max_num_layers +
            [param_ranges["roughnesses"]] * (self.max_num_layers + 1) +
            [param_ranges["slds"]] * (self.max_num_layers + 1) +
            [param_ranges["islds"]] * (self.max_num_layers + 1) +
            other_ranges
        )
        delta_bounds = (
            [bound_width_ranges["thicknesses"]] * self.max_num_layers +
            [bound_width_ranges["roughnesses"]] * (self.max_num_layers + 1) +
            [bound_width_ranges["slds"]] * (self.max_num_layers + 1) +
            [bound_width_ranges["islds"]] * (self.max_num_layers + 1) +
            other_delta_bounds
        )

        min_bounds, max_bounds = torch.tensor(ordered_bounds, device=device, dtype=dtype).T[:, None]
        min_deltas, max_deltas = torch.tensor(delta_bounds, device=device, dtype=dtype).T[:, None]

        return min_bounds, max_bounds, min_deltas, max_deltas

    def get_param_labels(self, **kwargs) -> List[str]:
        return get_param_labels(self.max_num_layers, parameterization_type='absorption', **kwargs)

    @staticmethod
    def _params2dict(parametrized_model: Tensor):
        num_params = parametrized_model.shape[-1]
        num_layers = (num_params - 3) // 4
        assert num_layers * 4 + 3 == num_params

        d, sigma, sld, isld = torch.split(
            parametrized_model, [num_layers, num_layers + 1, num_layers + 1, num_layers + 1], -1
        )
        params = dict(
            thickness=d,
            roughness=sigma,
            sld=sld + 1j * isld
        )

        return params

    def reflectivity(self, q, parametrized_model: Tensor, **kwargs) -> Tensor:
        return reflectivity(
            q, **self._params2dict(parametrized_model), **kwargs
        )


class ModelWithShifts(StandardModel):
    """Variant of the standard box model parameterization in which two additional parameters are considered: the shift in the q positions (additive) and the shift in
    intensity (multiplicative, or additive in log domain)."""
    NAME = 'model_with_shifts'

    PARAMETER_NAMES = (
        "thicknesses",
        "roughnesses",
        "slds",
        "q_shift",
        "norm_shift",
    )

    @property
    def param_dim(self) -> int:
        return 3 * self.max_num_layers + 4

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        params = self._params2dict(parametrized_model)
        params.pop('q_shift')
        params.pop('norm_shift')

        return params

    def get_param_labels(self, **kwargs) -> List[str]:
        return get_param_labels(self.max_num_layers, **kwargs) + [r"$\Delta q$ (Å$^{{-1}}$)", r"$\Delta I$"]

    @staticmethod
    def _params2dict(parametrized_model: Tensor):
        num_params = parametrized_model.shape[-1]
        num_layers = (num_params - 4) // 3
        assert num_layers * 3 + 4 == num_params

        d, sigma, sld, q_shift, norm_shift = torch.split(
            parametrized_model, [num_layers, num_layers + 1, num_layers + 1, 1, 1], -1
        )
        params = dict(
            thickness=d,
            roughness=sigma,
            sld=sld,
            q_shift=q_shift,
            norm_shift=norm_shift,
        )

        return params

    def reflectivity(self, q, parametrized_model: Tensor, **kwargs) -> Tensor:
        return reflectivity_with_shifts(
            q, **self._params2dict(parametrized_model), **kwargs
        )


def reflectivity_with_shifts(q, thickness, roughness, sld, q_shift, norm_shift, **kwargs):
    q = torch.atleast_2d(q) + q_shift
    return reflectivity(q, thickness, roughness, sld, **kwargs) * norm_shift


class NoFresnelModel(StandardModel):
    NAME = 'no_fresnel_model'

    def reflectivity(self, q, parametrized_model: Tensor, **kwargs) -> Tensor:
        return kinematical_approximation(
            q, **self._params2dict(parametrized_model), apply_fresnel=False, **kwargs
        )


class BasicMultilayerModel1(ParametricModel):
    NAME = 'repeating_multilayer_v1'

    PARAMETER_NAMES = (
        "d_full_rel",
        "rel_sigmas",
        "d_block",
        "s_block_rel",
        "r_block",
        "dr",
        "d3_rel",
        "s3_rel",
        "r3",
        "d_sio2",
        "s_sio2",
        "s_si",
        "r_sio2",
        "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        return multilayer_model1(parametrized_model, self.max_num_layers)

    def reflectivity(self, q, parametrized_model: Tensor, **kwargs) -> Tensor:
        params = self.to_standard_params(parametrized_model)
        return reflectivity(q, abeles_func=abeles_memory_eff, **params, **kwargs)


class BasicMultilayerModel2(BasicMultilayerModel1):
    NAME = 'repeating_multilayer_v2'

    PARAMETER_NAMES = (
        "d_full_rel",
        "rel_sigmas",
        "dr_sigmoid_rel_pos",
        "dr_sigmoid_rel_width",
        "d_block",
        "s_block_rel",
        "r_block",
        "dr",
        "d3_rel",
        "s3_rel",
        "r3",
        "d_sio2",
        "s_sio2",
        "s_si",
        "r_sio2",
        "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        return multilayer_model2(parametrized_model, self.max_num_layers)


class BasicMultilayerModel3(BasicMultilayerModel1):
    """Parameterization for a thin film composed of repeating identical monolayers, each monolayer consisting of two boxes with distinct SLDs.
    A sigmoid envelope modulating the SLD profile of the monolayers defines the film thickness and the roughness at the top interface.
    A second sigmoid envelope can be used to modulate the amplitude of the monolayer SLDs as a function of the displacement from the position of the first sigmoid.
    These two sigmoids allow one to model a thin film that is coherently ordered up to a certain coherent thickness and gets incoherently ordered or amorphous toward the top of the film.
    In addition, a layer between the substrate and the multilayer ("phase layer") is introduced to account for the interface structure,
    which does not necessarily have to be identical to the multilayer period.
    """

    NAME = 'repeating_multilayer_v3'

    PARAMETER_NAMES = (
        "d_full_rel",
        "rel_sigmas",
        "dr_sigmoid_rel_pos",
        "dr_sigmoid_rel_width",
        "d_block1_rel",
        "d_block",
        "s_block_rel",
        "r_block",
        "dr",
        "d3_rel",
        "s3_rel",
        "r3",
        "d_sio2",
        "s_sio2",
        "s_si",
        "r_sio2",
        "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        return multilayer_model3(parametrized_model, self.max_num_layers)


class MultilayerModel1WithShifts(BasicMultilayerModel1):
    NAME = 'repeating_multilayer_v1_with_shifts'

    PARAMETER_NAMES = (
        "d_full_rel",
        "rel_sigmas",
        "d_block",
        "s_block_rel",
        "r_block",
        "dr",
        "d3_rel",
        "s3_rel",
        "r3",
        "d_sio2",
        "s_sio2",
        "s_si",
        "r_sio2",
        "r_si",
        "q_shift",
        "norm_shift",
    )

    def reflectivity(self, q, parametrized_model: Tensor, **kwargs) -> Tensor:
        q_shift, norm_shift = parametrized_model[..., -2:].T[..., None]
        return reflectivity_with_shifts(
            q, q_shift=q_shift, norm_shift=norm_shift, abeles_func=abeles_memory_eff,
            **self.to_standard_params(parametrized_model), **kwargs
        )


class MultilayerModel3WithShifts(BasicMultilayerModel3):
    NAME = 'repeating_multilayer_v3_with_shifts'

    PARAMETER_NAMES = (
        "d_full_rel",
        "rel_sigmas",
        "dr_sigmoid_rel_pos",
        "dr_sigmoid_rel_width",
        "d_block1_rel",
        "d_block",
        "s_block_rel",
        "r_block",
        "dr",
        "d3_rel",
        "s3_rel",
        "r3",
        "d_sio2",
        "s_sio2",
        "s_si",
        "r_sio2",
        "r_si",
        "q_shift",
        "norm_shift",
    )

    def reflectivity(self, q, parametrized_model: Tensor, **kwargs) -> Tensor:
        q_shift, norm_shift = parametrized_model[..., -2:].T[..., None]
        return reflectivity_with_shifts(
            q, q_shift=q_shift, norm_shift=norm_shift, abeles_func=abeles_memory_eff,
            **self.to_standard_params(parametrized_model), **kwargs
        )


MULTILAYER_MODELS = {
    'standard_model': StandardModel,
    'model_with_absorption': ModelWithAbsorption,
    'model_with_shifts': ModelWithShifts,
    'no_fresnel_model': NoFresnelModel,
    'repeating_multilayer_v1': BasicMultilayerModel1,
    'repeating_multilayer_v2': BasicMultilayerModel2,
    'repeating_multilayer_v3': BasicMultilayerModel3,
    'repeating_multilayer_v1_with_shifts': MultilayerModel1WithShifts,
    'repeating_multilayer_v3_with_shifts': MultilayerModel3WithShifts,
}


def multilayer_model1(parametrized_model: Tensor, d_full_rel_max: int = 30) -> dict:
    n = d_full_rel_max

    (
        d_full_rel,
        rel_sigmas,
        d_block,
        s_block_rel,
        r_block,
        dr,
        d3_rel,
        s3_rel,
        r3,
        d_sio2,
        s_sio2,
        s_si,
        r_sio2,
        r_si,
        *_,
    ) = parametrized_model.T

    batch_size = parametrized_model.shape[0]

    r_positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)[None].repeat(batch_size, 1)

    r_modulations = torch.sigmoid(-(r_positions - 2 * d_full_rel[..., None]) / rel_sigmas[..., None])

    r_block = r_block[:, None].repeat(1, n)
    dr = dr[:, None].repeat(1, n)

    sld_blocks = torch.stack([r_block, r_block + dr], -1).flatten(1)

    sld_blocks = r_modulations * sld_blocks

    d3 = d3_rel * d_block

    thicknesses = torch.cat(
        [(d_block / 2)[:, None].repeat(1, n * 2), d3[:, None], d_sio2[:, None]], -1
    )

    s_block = s_block_rel * d_block

    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, n * 2), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )

    slds = torch.cat(
        [sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1
    )

    params = dict(
        thickness=thicknesses,
        roughness=roughnesses,
        sld=slds
    )
    return params


def multilayer_model2(parametrized_model: Tensor, d_full_rel_max: int = 30) -> dict:
    n = d_full_rel_max

    (
        d_full_rel,
        rel_sigmas,
        dr_sigmoid_rel_pos,
        dr_sigmoid_rel_width,
        d_block,
        s_block_rel,
        r_block,
        dr,
        d3_rel,
        s3_rel,
        r3,
        d_sio2,
        s_sio2,
        s_si,
        r_sio2,
        r_si,
        *_,
    ) = parametrized_model.T

    batch_size = parametrized_model.shape[0]

    r_positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)[None].repeat(batch_size, 1)

    r_modulations = torch.sigmoid(-(r_positions - 2 * d_full_rel[..., None]) / rel_sigmas[..., None])

    r_block = r_block[:, None].repeat(1, n)
    dr = dr[:, None].repeat(1, n)

    dr_positions = r_positions[:, ::2]

    dr_modulations = torch.sigmoid(
        -(dr_positions - (2 * d_full_rel * dr_sigmoid_rel_pos)[..., None]) / dr_sigmoid_rel_width[..., None]
    )

    dr = dr * dr_modulations

    sld_blocks = torch.stack([r_block, r_block + dr], -1).flatten(1)

    sld_blocks = r_modulations * sld_blocks

    d3 = d3_rel * d_block

    thicknesses = torch.cat(
        [(d_block / 2)[:, None].repeat(1, n * 2), d3[:, None], d_sio2[:, None]], -1
    )

    s_block = s_block_rel * d_block

    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, n * 2), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )

    slds = torch.cat(
        [sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1
    )

    params = dict(
        thickness=thicknesses,
        roughness=roughnesses,
        sld=slds
    )
    return params


def multilayer_model3(parametrized_model: Tensor, d_full_rel_max: int = 30):
    n = d_full_rel_max

    (
        d_full_rel,
        rel_sigmas,
        dr_sigmoid_rel_pos,
        dr_sigmoid_rel_width,
        d_block1_rel,
        d_block,
        s_block_rel,
        r_block,
        dr,
        d3_rel,
        s3_rel,
        r3,
        d_sio2,
        s_sio2,
        s_si,
        r_sio2,
        r_si,
        *_,
    ) = parametrized_model.T

    batch_size = parametrized_model.shape[0]

    r_positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)[None].repeat(batch_size, 1)

    r_modulations = torch.sigmoid(
        -(
            r_positions - 2 * d_full_rel[..., None]
        ) / rel_sigmas[..., None]
    )

    dr_positions = r_positions[:, ::2]

    dr_modulations = dr[..., None] * (1 - torch.sigmoid(
        -(
            dr_positions - 2 * d_full_rel[..., None] + 2 * dr_sigmoid_rel_pos[..., None]
        ) / dr_sigmoid_rel_width[..., None]
    ))

    r_block = r_block[..., None].repeat(1, n)
    dr = dr[..., None].repeat(1, n)

    sld_blocks = torch.stack(
        [
            r_block + dr_modulations * (1 - d_block1_rel[..., None]),
            r_block + dr - dr_modulations * d_block1_rel[..., None]
        ], -1).flatten(1)

    sld_blocks = r_modulations * sld_blocks

    d3 = d3_rel * d_block

    d1, d2 = d_block * d_block1_rel, d_block * (1 - d_block1_rel)

    thickness_blocks = torch.stack([d1[:, None].repeat(1, n), d2[:, None].repeat(1, n)], -1).flatten(1)

    thicknesses = torch.cat(
        [thickness_blocks, d3[:, None], d_sio2[:, None]], -1
    )

    s_block = s_block_rel * d_block

    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, n * 2), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )

    slds = torch.cat(
        [sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1
    )

    params = dict(
        thickness=thicknesses,
        roughness=roughnesses,
        sld=slds
    )
    return params


class NuisanceParamsWrapper(ParametricModel):
    """
    Wraps a base model (e.g. StandardModel) to add nuisance parameters, allowing independent enabling/disabling.

    Args:
        base_model (ParametricModel): The base parametric model.
        nuisance_params_config (Dict[str, bool]): Dictionary where keys are parameter names
            and values are `True` (enable) or `False` (disable).
    """

    def __init__(self, base_model: ParametricModel, nuisance_params_config: Dict[str, bool] = None, **kwargs):
        self.base_model = base_model
        self.nuisance_params_config = nuisance_params_config or {}

        self.enabled_nuisance_params = [name for name, is_enabled in self.nuisance_params_config.items() if is_enabled]

        self.PARAMETER_NAMES = self.base_model.PARAMETER_NAMES + tuple(self.enabled_nuisance_params)
        self._param_dim = self.base_model.param_dim + len(self.enabled_nuisance_params)

        super().__init__(base_model.max_num_layers, **kwargs)

    def _init_sampler_strategy(self, **kwargs):
        return self.base_model._init_sampler_strategy(nuisance_params_dim=len(self.enabled_nuisance_params), **kwargs)

    @property
    def param_dim(self) -> int:
        return self._param_dim

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Extracts base model parameters only."""
        base_dim = self.base_model.param_dim
        base_part = parametrized_model[..., :base_dim]
        return self.base_model.to_standard_params(base_part)

    def reflectivity(self, q, parametrized_model: Tensor, **kwargs) -> Tensor:
        """Computes reflectivity with optional nuisance parameter shifts."""
        base_dim = self.base_model.param_dim
        base_params = parametrized_model[..., :base_dim]
        nuisance_part = parametrized_model[..., base_dim:]

        nuisance_dict = {param: nuisance_part[..., i].unsqueeze(-1) for i, param in enumerate(self.enabled_nuisance_params)}
        if "log10_background" in nuisance_dict:
            nuisance_dict["background"] = 10 ** nuisance_dict.pop("log10_background")

        return self.base_model.reflectivity(q, base_params, **nuisance_dict, **kwargs)

    def init_bounds(self, param_ranges: Dict[str, Tuple[float, float]],
                    bound_width_ranges: Dict[str, Tuple[float, float]], device=None, dtype=None):
        """Initialize bounds for enabled nuisance parameters."""
        min_bounds_base, max_bounds_base, min_deltas_base, max_deltas_base = self.base_model.init_bounds(
            param_ranges, bound_width_ranges, device, dtype)

        ordered_bounds_nuisance = [param_ranges[k] for k in self.enabled_nuisance_params]
        delta_bounds_nuisance = [bound_width_ranges[k] for k in self.enabled_nuisance_params]

        if ordered_bounds_nuisance:
            min_bounds_nuisance, max_bounds_nuisance = torch.tensor(ordered_bounds_nuisance, device=device, dtype=dtype).T[:, None]
            min_deltas_nuisance, max_deltas_nuisance = torch.tensor(delta_bounds_nuisance, device=device, dtype=dtype).T[:, None]

            min_bounds = torch.cat([min_bounds_base, min_bounds_nuisance], dim=-1)
            max_bounds = torch.cat([max_bounds_base, max_bounds_nuisance], dim=-1)
            min_deltas = torch.cat([min_deltas_base, min_deltas_nuisance], dim=-1)
            max_deltas = torch.cat([max_deltas_base, max_deltas_nuisance], dim=-1)
        else:
            min_bounds, max_bounds, min_deltas, max_deltas = min_bounds_base, max_bounds_base, min_deltas_base, max_deltas_base

        return min_bounds, max_bounds, min_deltas, max_deltas

    def get_param_labels(self, **kwargs) -> List[str]:
        return self.base_model.get_param_labels(**kwargs) + self.enabled_nuisance_params
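
For orientation, the following is a minimal usage sketch of the parametric-model API shown in the hunk above. It is not code from the package: it assumes reflectorch 1.5.0 with the import paths of this file, and the layer count, parameter values, q grid, and variable names are illustrative. The background handling mirrors what NuisanceParamsWrapper itself does (a 'log10_background' column converted to a 'background' keyword).

import torch

from reflectorch.data_generation.priors.parametric_models import (
    MULTILAYER_MODELS,
    NuisanceParamsWrapper,
)

# Standard box model with up to two layers on a substrate: 3 * 2 + 2 = 8 parameters.
model = MULTILAYER_MODELS['standard_model'](max_num_layers=2)
assert model.param_dim == 8
labels = model.get_param_labels()  # thickness/roughness/sld labels per layer

# One parameter vector: 2 thicknesses, 3 roughnesses, 3 SLDs (illustrative values).
params = torch.tensor([[80.0, 120.0, 3.0, 5.0, 2.0, 4.0, 6.0, 2.07]])
q = torch.linspace(0.02, 0.15, 128)[None]  # illustrative q grid

curves = model.reflectivity(q, params)  # simulated reflectivity, shape (1, 128)

# Wrapping the same base model appends the enabled nuisance parameters as extra
# trailing columns of the parameter tensor; NuisanceParamsWrapper converts the
# 'log10_background' column to a 'background' keyword before calling the base model.
wrapped = NuisanceParamsWrapper(
    base_model=MULTILAYER_MODELS['standard_model'](max_num_layers=2),
    nuisance_params_config={'log10_background': True},
)
params_bg = torch.cat([params, torch.tensor([[-6.0]])], dim=-1)
curves_bg = wrapped.reflectivity(q, params_bg)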