reflectorch 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of reflectorch has been flagged by the registry scanner.

Files changed (39)
  1. reflectorch/data_generation/__init__.py +2 -0
  2. reflectorch/data_generation/dataset.py +27 -7
  3. reflectorch/data_generation/noise.py +115 -9
  4. reflectorch/data_generation/priors/parametric_models.py +90 -15
  5. reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
  6. reflectorch/data_generation/priors/sampler_strategies.py +67 -3
  7. reflectorch/data_generation/q_generator.py +31 -11
  8. reflectorch/data_generation/reflectivity/__init__.py +56 -14
  9. reflectorch/data_generation/reflectivity/abeles.py +31 -16
  10. reflectorch/data_generation/reflectivity/kinematical.py +5 -6
  11. reflectorch/data_generation/reflectivity/memory_eff.py +1 -1
  12. reflectorch/data_generation/reflectivity/smearing.py +25 -10
  13. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  14. reflectorch/data_generation/smearing.py +42 -11
  15. reflectorch/data_generation/utils.py +92 -18
  16. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  17. reflectorch/inference/inference_model.py +220 -105
  18. reflectorch/inference/plotting.py +98 -0
  19. reflectorch/inference/scipy_fitter.py +84 -7
  20. reflectorch/ml/__init__.py +2 -0
  21. reflectorch/ml/basic_trainer.py +18 -6
  22. reflectorch/ml/callbacks.py +5 -4
  23. reflectorch/ml/loggers.py +25 -0
  24. reflectorch/ml/schedulers.py +116 -0
  25. reflectorch/ml/trainers.py +122 -23
  26. reflectorch/models/__init__.py +1 -1
  27. reflectorch/models/encoders/__init__.py +0 -2
  28. reflectorch/models/encoders/conv_encoder.py +54 -40
  29. reflectorch/models/encoders/fno.py +23 -16
  30. reflectorch/models/networks/__init__.py +2 -0
  31. reflectorch/models/networks/mlp_networks.py +324 -152
  32. reflectorch/models/networks/residual_net.py +31 -5
  33. reflectorch/runs/train.py +0 -1
  34. reflectorch/runs/utils.py +43 -9
  35. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/METADATA +19 -17
  36. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/RECORD +39 -36
  37. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/WHEEL +1 -1
  38. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info/licenses}/LICENSE.txt +0 -0
  39. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/top_level.txt +0 -0
reflectorch/data_generation/__init__.py

@@ -41,6 +41,7 @@ from reflectorch.data_generation.noise import (
     ScalingNoise,
     BackgroundNoise,
     BasicExpIntensityNoise,
+    GaussianExpIntensityNoise,
     BasicQNoiseGenerator,
 )
 from reflectorch.data_generation.scale_curves import (
@@ -111,6 +112,7 @@ __all__ = [
     "LogLikelihood",
     "PoissonLogLikelihood",
     "BasicExpIntensityNoise",
+    "GaussianExpIntensityNoise",
     "BasicQNoiseGenerator",
     "ConstantAngle",
     "SubpriorParametricSampler",
reflectorch/data_generation/dataset.py

@@ -25,6 +25,7 @@ class BasicDataset(object):
         curves_scaler (CurvesScaler, optional): the reflectivity curve scaler. Defaults to an instance of LogAffineCurvesScaler,
             which scales the curves to the range [-1, 1], the minimum considered intensity being 1e-10.
         calc_denoised_curves (bool, optional): whether to add the curves without noise to the dictionary. Defaults to False.
+        calc_nonsmeared_curves (bool, optional): whether to add the curves without smearing to the dictionary (only relevant when smearing is applied). Defaults to False.
         smearing (Smearing, optional): curve smearing generator. Defaults to None.
     """
     def __init__(self,
@@ -34,6 +35,7 @@ class BasicDataset(object):
                  q_noise: QNoiseGenerator = None,
                  curves_scaler: CurvesScaler = None,
                  calc_denoised_curves: bool = False,
+                 calc_nonsmeared_curves: bool = False,
                  smearing: Smearing = None,
                  ):
         self.q_generator = q_generator
@@ -43,6 +45,7 @@ class BasicDataset(object):
         self.prior_sampler = prior_sampler
         self.smearing = smearing
         self.calc_denoised_curves = calc_denoised_curves
+        self.calc_nonsmeared_curves = calc_nonsmeared_curves

     def update_batch_data(self, batch_data: BATCH_DATA_TYPE) -> None:
         """implement in a subclass to edit batch_data dict inplace"""
@@ -74,7 +77,15 @@ class BasicDataset(object):

         batch_data['q_values'] = q_values

-        curves = self._calc_curves(q_values, params)
+        refl_kwargs = {}
+
+        curves, q_resolutions, nonsmeared_curves = self._calc_curves(q_values, params, refl_kwargs)
+
+        if torch.is_tensor(q_resolutions):
+            batch_data['q_resolutions'] = q_resolutions
+
+        if torch.is_tensor(nonsmeared_curves):
+            batch_data['nonsmeared_curves'] = nonsmeared_curves

         if self.calc_denoised_curves:
             batch_data['curves'] = curves
@@ -88,10 +99,13 @@ class BasicDataset(object):
         batch_data['scaled_noisy_curves'] = scaled_noisy_curves

         is_finite = torch.all(torch.isfinite(scaled_noisy_curves), -1)
+
         if not torch.all(is_finite).item():
             infinite_indices = ~is_finite
-            warnings.warn(f'Batch with {infinite_indices.sum().item()} curves with infinities skipped.')
-            return self.get_batch(batch_size = batch_size)
+            to_recalculate = infinite_indices.sum().item()
+            warnings.warn(f'Infinite number appeared in the curve simulation! Recalculate {to_recalculate} curves.')
+            recalculated_batch_data = self.get_batch(to_recalculate)
+            _insert_batch_data(batch_data, recalculated_batch_data, infinite_indices)

         is_finite = torch.all(torch.isfinite(batch_data['scaled_noisy_curves']), -1)
         assert torch.all(is_finite).item()
@@ -100,13 +114,19 @@ class BasicDataset(object):

         return batch_data

-    def _calc_curves(self, q_values: Tensor, params: BasicParams):
+    def _calc_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs):
+        nonsmeared_curves = None
+
         if self.smearing:
-            curves = self.smearing.get_curves(q_values, params)
+            if self.calc_nonsmeared_curves:
+                nonsmeared_curves = params.reflectivity(q_values, **refl_kwargs)
+            curves, q_resolutions = self.smearing.get_curves(q_values, params, refl_kwargs)
         else:
-            curves = params.reflectivity(q_values)
+            curves = params.reflectivity(q_values, **refl_kwargs)
+            q_resolutions = None

         curves = curves.to(q_values)
-        return curves
+        return curves, q_resolutions, nonsmeared_curves


 def _insert_batch_data(tgt_batch_data, add_batch_data, indices):
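
With these changes, `get_batch` can expose the smearing resolutions and the unsmeared reference curves alongside the noisy training data. A minimal sketch of consuming the batch dictionary (keys are taken from the diff; the dataset construction itself is assumed to have happened elsewhere):

    # Sketch: reading the new optional batch entries.
    batch = dataset.get_batch(512)

    inp = batch['scaled_noisy_curves']    # network input, as before
    q_res = batch.get('q_resolutions')    # set when the smearing generator reports resolutions
    clean = batch.get('nonsmeared_curves')  # set when calc_nonsmeared_curves=True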
reflectorch/data_generation/noise.py

@@ -1,11 +1,11 @@
-from typing import Union, Tuple
+from typing import List, Union, Tuple
 from math import log10

 import torch
 from torch import Tensor

 from reflectorch.data_generation.process_data import ProcessData
-from reflectorch.data_generation.utils import uniform_sampler
+from reflectorch.data_generation.utils import logdist_sampler, uniform_sampler

 __all__ = [
     "QNoiseGenerator",
@@ -18,6 +18,7 @@ __all__ = [
     "ShiftNoise",
     "BackgroundNoise",
     "BasicExpIntensityNoise",
+    "GaussianExpIntensityNoise",
     "BasicQNoiseGenerator",
 ]
@@ -102,18 +103,22 @@ class BasicQNoiseGenerator(QNoiseGenerator):
             Defaults to (0, 1e-3).
     """
     def __init__(self,
+                 apply_systematic_shifts: bool = True,
                  shift_std: float = 1e-3,
+                 apply_gaussian_noise: bool = False,
                  noise_std: Union[float, Tuple[float, float]] = (0, 1e-3),
                  add_to_context: bool = False,
                  ):
-        self.q_shift = QSystematicShiftGenerator(shift_std, add_to_context=add_to_context)
-        self.q_noise = QNormalNoiseGenerator(noise_std, add_to_context=add_to_context)
+        self.q_shift = QSystematicShiftGenerator(shift_std, add_to_context=add_to_context) if apply_systematic_shifts else None
+        self.q_noise = QNormalNoiseGenerator(noise_std, add_to_context=add_to_context) if apply_gaussian_noise else None

     def apply(self, qs: Tensor, context: dict = None):
-        """applies random noise to the q values"""
+        """applies noise to the q values"""
         qs = torch.atleast_2d(qs)
-        qs = self.q_shift.apply(qs, context)
-        qs = self.q_noise.apply(qs, context)
+        if self.q_shift:
+            qs = self.q_shift.apply(qs, context)
+        if self.q_noise:
+            qs = self.q_noise.apply(qs, context)
         return qs
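
The two q-noise components are now independently switchable; a minimal sketch using only the constructor arguments visible in the diff (tensor shapes are illustrative):

    import torch

    # Sketch: keep the systematic q-shift, leave per-point Gaussian q-noise off
    # (in 1.2.0 both generators were always active).
    q_noise = BasicQNoiseGenerator(
        apply_systematic_shifts=True, shift_std=1e-3,
        apply_gaussian_noise=False,
    )
    qs = torch.linspace(0.02, 0.15, 128).repeat(64, 1)  # (batch, n_points)
    noisy_qs = q_noise.apply(qs)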
@@ -154,6 +159,51 @@ class MultiplicativeLogNormalNoiseGenerator(IntensityNoiseGenerator):

         return noise * curves

+class GaussianNoiseGenerator(IntensityNoiseGenerator):
+    """Noise generator which applies noise as R_n = R + eps, with eps ~ N(0, sigmas) and sigmas = relative_errors * R
+
+    Args:
+        relative_errors (Union[float, Tuple[float, float], List[Tuple[float, float]]]): a fixed relative error, a range to sample it from, or one range per channel.
+        consistent_rel_err (bool): If True, the relative error is the same for all points of a curve; otherwise it is sampled uniformly per point.
+    """
+
+    def __init__(self, relative_errors: Union[float, Tuple[float, float], List[float], List[Tuple[float, float]]],
+                 consistent_rel_err: bool = False,
+                 add_to_context: bool = False):
+        self.relative_errors = relative_errors
+        self.consistent_rel_err = consistent_rel_err
+        self.add_to_context = add_to_context
+
+    def apply(self, curves: Tensor, context: dict = None):
+        """Applies Gaussian noise to the curves."""
+        relative_errors = self.relative_errors
+        num_channels = curves.shape[1] if curves.dim() == 3 else 1
+
+        if isinstance(relative_errors, float):
+            relative_errors = torch.ones_like(curves) * relative_errors
+
+        elif isinstance(relative_errors, (list, tuple)) and isinstance(relative_errors[0], float):
+            if self.consistent_rel_err:
+                relative_errors = uniform_sampler(*relative_errors, curves.shape[0], num_channels, device=curves.device, dtype=curves.dtype)
+                if num_channels > 1:
+                    relative_errors = relative_errors.unsqueeze(-1)
+            else:
+                relative_errors = uniform_sampler(*relative_errors, *curves.shape, device=curves.device, dtype=curves.dtype)
+
+        else:
+            if self.consistent_rel_err:
+                relative_errors = torch.stack([uniform_sampler(*item, curves.shape[0], 1, device=curves.device, dtype=curves.dtype) for item in relative_errors], dim=1)
+            else:
+                relative_errors = torch.stack([uniform_sampler(*item, curves.shape[0], curves.shape[-1], device=curves.device, dtype=curves.dtype) for item in relative_errors], dim=1)
+
+        sigmas = relative_errors * curves
+        noise = torch.normal(mean=0., std=sigmas).clamp_min_(0.0)
+
+        if self.add_to_context and context is not None:
+            context['relative_errors'] = relative_errors
+            context['sigmas'] = sigmas
+
+        return curves + noise

 class PoissonNoiseGenerator(IntensityNoiseGenerator):
     """Noise generator which applies Poisson noise to the reflectivity curves
@@ -273,6 +323,7 @@ class BackgroundNoise(IntensityNoiseGenerator):

     Args:
         background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
+        add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary. Defaults to False.
     """
     def __init__(self,
                  background_range: tuple = (1.0e-10, 1.0e-8),
@@ -283,7 +334,7 @@ class BackgroundNoise(IntensityNoiseGenerator):

     def apply(self, curves: Tensor, context: dict = None) -> Tensor:
         """applies background noise to the curves"""
-        backgrounds = uniform_sampler(
+        backgrounds = logdist_sampler(
             *self.background_range, curves.shape[0], 1,
             device=curves.device, dtype=curves.dtype
         )
@@ -294,6 +345,61 @@ class BackgroundNoise(IntensityNoiseGenerator):

         return curves

+class GaussianExpIntensityNoise(IntensityNoiseGenerator):
+    """
+    A composite noise generator that applies Gaussian, shift and background noise to reflectivity curves.
+
+    This class combines three types of noise:
+    1. Gaussian noise: Applies Gaussian noise (to account for count-based Poisson noise as well as other sources of error).
+    2. Shift noise: Applies a multiplicative shift to the curves, equivalent to a vertical shift in logarithmic space.
+    3. Background noise: Adds a constant background value to the curves.
+
+    Args:
+        relative_errors (Union[float, Tuple[float, float], List[Tuple[float, float]]]): The range of relative errors for Gaussian noise. Defaults to (0.001, 0.15).
+        consistent_rel_err (bool, optional): If True, uses a consistent relative error for Gaussian noise across all points in a curve. Defaults to False.
+        apply_shift (bool, optional): If True, applies shift noise to the curves. Defaults to False.
+        shift_range (tuple, optional): The range of shift factors for shift noise. Defaults to (-0.1, 0.2e-2).
+        apply_background (bool, optional): If True, applies background noise to the curves. Defaults to False.
+        background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
+        same_background_across_channels (bool, optional): If True, the same background is applied to all channels of a multi-channel curve. Defaults to False.
+        add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary. Defaults to False.
+    """
+    def __init__(self,
+                 relative_errors: Tuple[float, float] = (0.001, 0.15),
+                 consistent_rel_err: bool = False,
+                 apply_shift: bool = False,
+                 shift_range: tuple = (-0.1, 0.2e-2),
+                 apply_background: bool = False,
+                 background_range: tuple = (1.0e-10, 1.0e-8),
+                 same_background_across_channels: bool = False,
+                 add_to_context: bool = False,
+                 ):
+
+        self.gaussian_noise = GaussianNoiseGenerator(
+            relative_errors=relative_errors,
+            consistent_rel_err=consistent_rel_err,
+            add_to_context=add_to_context,
+        )
+
+        self.shift_noise = ShiftNoise(
+            shift_range=shift_range, add_to_context=add_to_context
+        ) if apply_shift else None
+
+        self.background_noise = BackgroundNoise(
+            background_range=background_range, add_to_context=add_to_context
+        ) if apply_background else None
+
+    def apply(self, curves: Tensor, context: dict = None):
+        """applies the specified types of noise to the input curves"""
+        if self.shift_noise:
+            curves = self.shift_noise(curves, context)
+
+        if self.background_noise:
+            curves = self.background_noise.apply(curves, context)
+
+        curves = self.gaussian_noise(curves, context)
+
+        return curves

 class BasicExpIntensityNoise(IntensityNoiseGenerator):
     """
@@ -362,4 +468,4 @@ class BasicExpIntensityNoise(IntensityNoiseGenerator):
         if self.background_noise:
             curves = self.background_noise.apply(curves, context)

-        return curves
+        return curves
reflectorch/data_generation/priors/parametric_models.py

@@ -10,7 +10,6 @@ from reflectorch.data_generation.reflectivity import (
 )
 from reflectorch.data_generation.utils import (
     get_param_labels,
-    get_param_labels_absorption_model,
 )
 from reflectorch.data_generation.priors.sampler_strategies import (
     SamplerStrategy,
@@ -44,7 +43,7 @@ class ParametricModel(object):
     @property
     def param_dim(self) -> int:
         """get the number of parameters
-
+
         Returns:
             int:
         """
@@ -106,7 +105,7 @@ class ParametricModel(object):

         return min_bounds, max_bounds, min_deltas, max_deltas

-    def get_param_labels(self) -> List[str]:
+    def get_param_labels(self, **kwargs) -> List[str]:
         """get the list with the name of the parameters

         Returns:
@@ -158,9 +157,10 @@ class StandardModel(ParametricModel):
     def _init_sampler_strategy(self,
                                constrained_roughness: bool = True,
                                max_thickness_share: float = 0.5,
+                               nuisance_params_dim: int = 0,
                                **kwargs):
         if constrained_roughness:
-            num_params = self.param_dim
+            num_params = self.param_dim + nuisance_params_dim
             thickness_mask = torch.zeros(num_params, dtype=torch.bool)
             roughness_mask = torch.zeros(num_params, dtype=torch.bool)
             thickness_mask[:self.max_num_layers] = True
@@ -204,8 +204,8 @@ class StandardModel(ParametricModel):

         return min_bounds, max_bounds, min_deltas, max_deltas

-    def get_param_labels(self) -> List[str]:
-        return get_param_labels(self.max_num_layers)
+    def get_param_labels(self, **kwargs) -> List[str]:
+        return get_param_labels(self.max_num_layers, **kwargs)

     @staticmethod
     def _params2dict(parametrized_model: Tensor):
@@ -250,9 +250,10 @@ class ModelWithAbsorption(StandardModel):
                                constrained_isld: bool = True,
                                max_thickness_share: float = 0.5,
                                max_sld_share: float = 0.2,
+                               nuisance_params_dim: int = 0,
                                **kwargs):
         if constrained_roughness:
-            num_params = self.param_dim
+            num_params = self.param_dim + nuisance_params_dim
             thickness_mask = torch.zeros(num_params, dtype=torch.bool)
             roughness_mask = torch.zeros(num_params, dtype=torch.bool)
             thickness_mask[:self.max_num_layers] = True
@@ -262,10 +263,11 @@ class ModelWithAbsorption(StandardModel):
             sld_mask = torch.zeros(num_params, dtype=torch.bool)
             isld_mask = torch.zeros(num_params, dtype=torch.bool)
             sld_mask[2 * self.max_num_layers + 1:3 * self.max_num_layers + 2] = True
-            isld_mask[3 * self.max_num_layers + 2:] = True
+            isld_mask[3 * self.max_num_layers + 2:4 * self.max_num_layers + 3] = True
             return ConstrainedRoughnessAndImgSldSamplerStrategy(
                 thickness_mask, roughness_mask, sld_mask, isld_mask,
-                max_thickness_share=max_thickness_share, max_sld_share=max_sld_share
+                max_thickness_share=max_thickness_share, max_sld_share=max_sld_share,
+                **kwargs
             )
         else:
             return ConstrainedRoughnessSamplerStrategy(
@@ -305,9 +307,9 @@ class ModelWithAbsorption(StandardModel):

         return min_bounds, max_bounds, min_deltas, max_deltas

-    def get_param_labels(self) -> List[str]:
-        return get_param_labels_absorption_model(self.max_num_layers)
-
+    def get_param_labels(self, **kwargs) -> List[str]:
+        return get_param_labels(self.max_num_layers, parameterization_type='absorption', **kwargs)
+
     @staticmethod
     def _params2dict(parametrized_model: Tensor):
         num_params = parametrized_model.shape[-1]
@@ -355,8 +357,9 @@ class ModelWithShifts(StandardModel):

         return params

-    def get_param_labels(self) -> List[str]:
-        return get_param_labels(self.max_num_layers) + [r"$\Delta q$ (Å$^{{-1}}$)", r"$\Delta I$"]
+    def get_param_labels(self, **kwargs) -> List[str]:
+        return get_param_labels(self.max_num_layers, **kwargs) + [r"$\Delta q$ (Å$^{{-1}}$)", r"$\Delta I$"]
+

     @staticmethod
     def _params2dict(parametrized_model: Tensor):
@@ -384,7 +387,7 @@ class ModelWithShifts(StandardModel):

 def reflectivity_with_shifts(q, thickness, roughness, sld, q_shift, norm_shift, **kwargs):
     q = torch.atleast_2d(q) + q_shift
-    return reflectivity(q, thickness, roughness, sld, **kwargs) * norm_shift
+    return reflectivity(q, thickness, roughness, sld, **kwargs) * norm_shift

 class NoFresnelModel(StandardModel):
     NAME = 'no_fresnel_model'
@@ -765,3 +768,75 @@ def multilayer_model3(parametrized_model: Tensor, d_full_rel_max: int = 30):
         sld=slds
     )
     return params
+
+
+class NuisanceParamsWrapper(ParametricModel):
+    """
+    Wraps a base model (e.g. StandardModel) to add nuisance parameters, allowing independent enabling/disabling.
+
+    Args:
+        base_model (ParametricModel): The base parametric model.
+        nuisance_params_config (Dict[str, bool]): Dictionary where keys are parameter names
+            and values are `True` (enable) or `False` (disable).
+    """
+
+    def __init__(self, base_model: ParametricModel, nuisance_params_config: Dict[str, bool] = None, **kwargs):
+        self.base_model = base_model
+        self.nuisance_params_config = nuisance_params_config or {}
+
+        self.enabled_nuisance_params = [name for name, is_enabled in self.nuisance_params_config.items() if is_enabled]
+
+        self.PARAMETER_NAMES = self.base_model.PARAMETER_NAMES + tuple(self.enabled_nuisance_params)
+        self._param_dim = self.base_model.param_dim + len(self.enabled_nuisance_params)
+
+        super().__init__(base_model.max_num_layers, **kwargs)
+
+    def _init_sampler_strategy(self, **kwargs):
+        return self.base_model._init_sampler_strategy(nuisance_params_dim=len(self.enabled_nuisance_params), **kwargs)
+
+    @property
+    def param_dim(self) -> int:
+        return self._param_dim
+
+    def to_standard_params(self, parametrized_model: Tensor) -> dict:
+        """Extracts base model parameters only."""
+        base_dim = self.base_model.param_dim
+        base_part = parametrized_model[..., :base_dim]
+        return self.base_model.to_standard_params(base_part)
+
+    def reflectivity(self, q, parametrized_model: Tensor, **kwargs) -> Tensor:
+        """Computes reflectivity with optional nuisance parameter shifts."""
+        base_dim = self.base_model.param_dim
+        base_params = parametrized_model[..., :base_dim]
+        nuisance_part = parametrized_model[..., base_dim:]
+
+        nuisance_dict = {param: nuisance_part[..., i].unsqueeze(-1) for i, param in enumerate(self.enabled_nuisance_params)}
+        if "log10_background" in nuisance_dict:
+            nuisance_dict["background"] = 10 ** nuisance_dict.pop("log10_background")
+
+        return self.base_model.reflectivity(q, base_params, **nuisance_dict, **kwargs)
+
+    def init_bounds(self, param_ranges: Dict[str, Tuple[float, float]],
+                    bound_width_ranges: Dict[str, Tuple[float, float]], device=None, dtype=None):
+        """Initialize bounds for enabled nuisance parameters."""
+        min_bounds_base, max_bounds_base, min_deltas_base, max_deltas_base = self.base_model.init_bounds(
+            param_ranges, bound_width_ranges, device, dtype)
+
+        ordered_bounds_nuisance = [param_ranges[k] for k in self.enabled_nuisance_params]
+        delta_bounds_nuisance = [bound_width_ranges[k] for k in self.enabled_nuisance_params]
+
+        if ordered_bounds_nuisance:
+            min_bounds_nuisance, max_bounds_nuisance = torch.tensor(ordered_bounds_nuisance, device=device, dtype=dtype).T[:, None]
+            min_deltas_nuisance, max_deltas_nuisance = torch.tensor(delta_bounds_nuisance, device=device, dtype=dtype).T[:, None]
+
+            min_bounds = torch.cat([min_bounds_base, min_bounds_nuisance], dim=-1)
+            max_bounds = torch.cat([max_bounds_base, max_bounds_nuisance], dim=-1)
+            min_deltas = torch.cat([min_deltas_base, min_deltas_nuisance], dim=-1)
+            max_deltas = torch.cat([max_deltas_base, max_deltas_nuisance], dim=-1)
+        else:
+            min_bounds, max_bounds, min_deltas, max_deltas = min_bounds_base, max_bounds_base, min_deltas_base, max_deltas_base
+
+        return min_bounds, max_bounds, min_deltas, max_deltas
+
+    def get_param_labels(self, **kwargs) -> List[str]:
+        return self.base_model.get_param_labels(**kwargs) + self.enabled_nuisance_params
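
A minimal sketch of the wrapper in isolation; 'log10_background' follows the special-casing visible in `reflectivity` above, while the registry key 'standard_model' and the disabled 'r_scale' entry are illustrative assumptions, not names confirmed by the diff:

    # Sketch: wrap a base model with one enabled nuisance parameter.
    base = MULTILAYER_MODELS['standard_model'](2)   # hypothetical registry key, 2 layers
    model = NuisanceParamsWrapper(
        base_model=base,
        nuisance_params_config={'log10_background': True, 'r_scale': False},  # 'r_scale' hypothetical
    )
    assert model.param_dim == base.param_dim + 1    # only enabled entries are appended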
reflectorch/data_generation/priors/parametric_subpriors.py

@@ -12,6 +12,7 @@ from reflectorch.data_generation.priors.no_constraints import (

 from reflectorch.data_generation.priors.parametric_models import (
     MULTILAYER_MODELS,
+    NuisanceParamsWrapper,
     ParametricModel,
 )
 from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
@@ -54,9 +55,9 @@ class BasicParams(AbstractParams):
         self.min_bounds = min_bounds
         self.max_bounds = max_bounds

-    def get_param_labels(self) -> List[str]:
+    def get_param_labels(self, **kwargs) -> List[str]:
         """gets the parameter labels"""
-        return self.param_model.get_param_labels()
+        return self.param_model.get_param_labels(**kwargs)

     def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
         r"""computes the reflectivity curves directly from the parameters
@@ -97,6 +98,18 @@ class BasicParams(AbstractParams):
         """gets the slds"""
         params = self.param_model.to_standard_params(self.parameters)
         return params['sld']
+
+    @property
+    def real_slds(self):
+        """gets the real part of the slds"""
+        params = self.param_model.to_standard_params(self.parameters)
+        return params['sld'].real
+
+    @property
+    def imag_slds(self):
+        """gets the imaginary part of the slds (only for complex dtypes)"""
+        params = self.param_model.to_standard_params(self.parameters)
+        return params['sld'].imag

     @staticmethod
     def rearrange_context_from_params(
@@ -201,11 +214,19 @@ class SubpriorParametricSampler(PriorSampler, ScalerMixin):
             scaled_range (Tuple[float, float], optional): the range for scaling the parameters. Defaults to (-1., 1.)
         """
         self.scaled_range = scaled_range
-        self.param_model: ParametricModel = MULTILAYER_MODELS[model_name](
-            max_num_layers,
-            logdist=logdist,
-            **kwargs
-        )
+
+        self.shift_param_config = kwargs.pop('shift_param_config', {})
+
+        base_model: ParametricModel = MULTILAYER_MODELS[model_name](max_num_layers, logdist=logdist, **kwargs)
+        if any(self.shift_param_config.values()):
+            self.param_model = NuisanceParamsWrapper(
+                base_model=base_model,
+                nuisance_params_config=self.shift_param_config,
+                **kwargs,
+            )
+        else:
+            self.param_model = base_model
+
         self.device = device
         self.dtype = dtype
         self.num_layers = max_num_layers
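
Downstream, the sampler only needs the new `shift_param_config` keyword to opt in. A hedged sketch (the remaining constructor arguments required by `SubpriorParametricSampler`, such as parameter ranges, are omitted, and the model key is illustrative):

    # Sketch: opting into the nuisance-parameter path via shift_param_config.
    sampler = SubpriorParametricSampler(
        model_name='standard_model',                  # hypothetical key into MULTILAYER_MODELS
        max_num_layers=2,
        shift_param_config={'log10_background': True},
        # ... remaining required arguments (parameter ranges, device, dtype) omitted
    )
    assert isinstance(sampler.param_model, NuisanceParamsWrapper)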
reflectorch/data_generation/priors/sampler_strategies.py

@@ -73,11 +73,13 @@ class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
                  roughness_mask: Tensor,
                  logdist: bool = False,
                  max_thickness_share: float = 0.5,
+                 max_total_thickness: float = None,
                  ):
         super().__init__(logdist=logdist)
         self.thickness_mask = thickness_mask
         self.roughness_mask = roughness_mask
         self.max_thickness_share = max_thickness_share
+        self.max_total_thickness = max_total_thickness

     def sample(self, batch_size: int,
                total_min_bounds: Tensor,
@@ -106,7 +108,8 @@ class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
             thickness_mask=self.thickness_mask.to(device),
             roughness_mask=self.roughness_mask.to(device),
             widths_sampler_func=self.widths_sampler_func,
-            coef=self.max_thickness_share,
+            coef_roughness=self.max_thickness_share,
+            max_total_thickness=self.max_total_thickness,
         )

 class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
@@ -129,6 +132,7 @@ class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
                  logdist: bool = False,
                  max_thickness_share: float = 0.5,
                  max_sld_share: float = 0.2,
+                 max_total_thickness: float = None,
                  ):
         super().__init__(logdist=logdist)
         self.thickness_mask = thickness_mask
@@ -137,6 +141,7 @@ class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
         self.isld_mask = isld_mask
         self.max_thickness_share = max_thickness_share
         self.max_sld_share = max_sld_share
+        self.max_total_thickness = max_total_thickness

     def sample(self, batch_size: int,
                total_min_bounds: Tensor,
@@ -169,6 +174,7 @@ class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
             widths_sampler_func=self.widths_sampler_func,
             coef_roughness=self.max_thickness_share,
             coef_isld=self.max_sld_share,
+            max_total_thickness=self.max_total_thickness,
         )

 def basic_sampler(
@@ -214,15 +220,44 @@ def constrained_roughness_sampler(
     thickness_mask: Tensor,
     roughness_mask: Tensor,
     widths_sampler_func,
-    coef: float = 0.5,
+    coef_roughness: float = 0.5,
+    max_total_thickness: float = None,
 ):
     params, min_bounds, max_bounds = basic_sampler(
         batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
         widths_sampler_func=widths_sampler_func,
     )

+    if max_total_thickness is not None:
+        total_thickness = max_bounds[:, thickness_mask].sum(-1)
+        indices = total_thickness > max_total_thickness
+
+        if indices.any():
+            eps = 0.01
+            rand_scale = torch.rand_like(total_thickness) * eps + 1 - eps
+            scale_coef = max_total_thickness / total_thickness * rand_scale
+            scale_coef[~indices] = 1.0
+            min_bounds[:, thickness_mask] *= scale_coef[:, None]
+            max_bounds[:, thickness_mask] *= scale_coef[:, None]
+            params[:, thickness_mask] *= scale_coef[:, None]
+
+            min_bounds[:, thickness_mask] = torch.clamp_min(
+                min_bounds[:, thickness_mask],
+                total_min_bounds[:, thickness_mask],
+            )
+
+            max_bounds[:, thickness_mask] = torch.clamp_min(
+                max_bounds[:, thickness_mask],
+                total_min_bounds[:, thickness_mask],
+            )
+
+            params[:, thickness_mask] = torch.clamp_min(
+                params[:, thickness_mask],
+                total_min_bounds[:, thickness_mask],
+            )
+
     max_roughness = torch.minimum(
-        get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef),
+        get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
         total_max_bounds[..., roughness_mask]
     )
     min_roughness = total_min_bounds[..., roughness_mask]
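
The new rescaling logic is easiest to see with numbers: when the sampled per-layer thickness upper bounds sum past the cap, everything thickness-related is multiplied by roughly max_total_thickness / total_thickness (jittered downward by up to 1%) and then clamped against the global lower bounds. A self-contained sketch of that arithmetic:

    import torch

    max_total_thickness = 300.0
    max_bounds_t = torch.tensor([[150.0, 250.0]])  # per-layer upper bounds; sum = 400 > 300
    total_thickness = max_bounds_t.sum(-1)

    eps = 0.01
    rand_scale = torch.rand_like(total_thickness) * eps + 1 - eps  # in [0.99, 1.0)
    scale_coef = max_total_thickness / total_thickness * rand_scale

    print(max_bounds_t * scale_coef[:, None])  # ~[[112, 187]], summing to just under 300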
@@ -256,12 +291,41 @@ def constrained_roughness_and_isld_sampler(
     widths_sampler_func,
     coef_roughness: float = 0.5,
     coef_isld: float = 0.2,
+    max_total_thickness: float = None,
 ):
     params, min_bounds, max_bounds = basic_sampler(
         batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
         widths_sampler_func=widths_sampler_func,
     )

+    if max_total_thickness is not None:
+        total_thickness = max_bounds[:, thickness_mask].sum(-1)
+        indices = total_thickness > max_total_thickness
+
+        if indices.any():
+            eps = 0.01
+            rand_scale = torch.rand_like(total_thickness) * eps + 1 - eps
+            scale_coef = max_total_thickness / total_thickness * rand_scale
+            scale_coef[~indices] = 1.0
+            min_bounds[:, thickness_mask] *= scale_coef[:, None]
+            max_bounds[:, thickness_mask] *= scale_coef[:, None]
+            params[:, thickness_mask] *= scale_coef[:, None]
+
+            min_bounds[:, thickness_mask] = torch.clamp_min(
+                min_bounds[:, thickness_mask],
+                total_min_bounds[:, thickness_mask],
+            )
+
+            max_bounds[:, thickness_mask] = torch.clamp_min(
+                max_bounds[:, thickness_mask],
+                total_min_bounds[:, thickness_mask],
+            )
+
+            params[:, thickness_mask] = torch.clamp_min(
+                params[:, thickness_mask],
+                total_min_bounds[:, thickness_mask],
+            )
+
     max_roughness = torch.minimum(
         get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
         total_max_bounds[..., roughness_mask]