reflectorch 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. See the registry's advisory page for this release for more details.
- reflectorch/data_generation/__init__.py +2 -0
- reflectorch/data_generation/dataset.py +27 -7
- reflectorch/data_generation/noise.py +115 -9
- reflectorch/data_generation/priors/parametric_models.py +90 -15
- reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
- reflectorch/data_generation/priors/sampler_strategies.py +67 -3
- reflectorch/data_generation/q_generator.py +31 -11
- reflectorch/data_generation/reflectivity/__init__.py +56 -14
- reflectorch/data_generation/reflectivity/abeles.py +31 -16
- reflectorch/data_generation/reflectivity/kinematical.py +5 -6
- reflectorch/data_generation/reflectivity/memory_eff.py +1 -1
- reflectorch/data_generation/reflectivity/smearing.py +25 -10
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/smearing.py +42 -11
- reflectorch/data_generation/utils.py +92 -18
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/inference_model.py +220 -105
- reflectorch/inference/plotting.py +98 -0
- reflectorch/inference/scipy_fitter.py +84 -7
- reflectorch/ml/__init__.py +2 -0
- reflectorch/ml/basic_trainer.py +18 -6
- reflectorch/ml/callbacks.py +5 -4
- reflectorch/ml/loggers.py +25 -0
- reflectorch/ml/schedulers.py +116 -0
- reflectorch/ml/trainers.py +122 -23
- reflectorch/models/__init__.py +1 -1
- reflectorch/models/encoders/__init__.py +0 -2
- reflectorch/models/encoders/conv_encoder.py +54 -40
- reflectorch/models/encoders/fno.py +23 -16
- reflectorch/models/networks/__init__.py +2 -0
- reflectorch/models/networks/mlp_networks.py +324 -152
- reflectorch/models/networks/residual_net.py +31 -5
- reflectorch/runs/train.py +0 -1
- reflectorch/runs/utils.py +43 -9
- {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/METADATA +19 -17
- {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/RECORD +39 -36
- {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info/licenses}/LICENSE.txt +0 -0
- {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -41,6 +41,7 @@ from reflectorch.data_generation.noise import (
|
|
|
41
41
|
ScalingNoise,
|
|
42
42
|
BackgroundNoise,
|
|
43
43
|
BasicExpIntensityNoise,
|
|
44
|
+
GaussianExpIntensityNoise,
|
|
44
45
|
BasicQNoiseGenerator,
|
|
45
46
|
)
|
|
46
47
|
from reflectorch.data_generation.scale_curves import (
|
|
@@ -111,6 +112,7 @@ __all__ = [
|
|
|
111
112
|
"LogLikelihood",
|
|
112
113
|
"PoissonLogLikelihood",
|
|
113
114
|
"BasicExpIntensityNoise",
|
|
115
|
+
"GaussianExpIntensityNoise",
|
|
114
116
|
"BasicQNoiseGenerator",
|
|
115
117
|
"ConstantAngle",
|
|
116
118
|
"SubpriorParametricSampler",
|
|
@@ -25,6 +25,7 @@ class BasicDataset(object):
|
|
|
25
25
|
curves_scaler (CurvesScaler, optional): the reflectivity curve scaler. Defaults to an instance of LogAffineCurvesScaler,
|
|
26
26
|
which scales the curves to the range [-1, 1], the minimum considered intensity being 1e-10.
|
|
27
27
|
calc_denoised_curves (bool, optional): whether to add the curves without noise to the dictionary. Defaults to False.
|
|
28
|
+
calc_nonsmeared_curves (bool, optional): whether to add the curves without smearing to the dictionary (only relevant when smearing is applied). Defaults to False.
|
|
28
29
|
smearing (Smearing, optional): curve smearing generator. Defaults to None.
|
|
29
30
|
"""
|
|
30
31
|
def __init__(self,
|
|
@@ -34,6 +35,7 @@ class BasicDataset(object):
|
|
|
34
35
|
q_noise: QNoiseGenerator = None,
|
|
35
36
|
curves_scaler: CurvesScaler = None,
|
|
36
37
|
calc_denoised_curves: bool = False,
|
|
38
|
+
calc_nonsmeared_curves: bool = False,
|
|
37
39
|
smearing: Smearing = None,
|
|
38
40
|
):
|
|
39
41
|
self.q_generator = q_generator
|
|
@@ -43,6 +45,7 @@ class BasicDataset(object):
|
|
|
43
45
|
self.prior_sampler = prior_sampler
|
|
44
46
|
self.smearing = smearing
|
|
45
47
|
self.calc_denoised_curves = calc_denoised_curves
|
|
48
|
+
self.calc_nonsmeared_curves = calc_nonsmeared_curves
|
|
46
49
|
|
|
47
50
|
def update_batch_data(self, batch_data: BATCH_DATA_TYPE) -> None:
|
|
48
51
|
"""implement in a subclass to edit batch_data dict inplace"""
|
|
@@ -74,7 +77,15 @@ class BasicDataset(object):
|
|
|
74
77
|
|
|
75
78
|
batch_data['q_values'] = q_values
|
|
76
79
|
|
|
77
|
-
|
|
80
|
+
refl_kwargs = {}
|
|
81
|
+
|
|
82
|
+
curves, q_resolutions, nonsmeared_curves = self._calc_curves(q_values, params, refl_kwargs)
|
|
83
|
+
|
|
84
|
+
if torch.is_tensor(q_resolutions):
|
|
85
|
+
batch_data['q_resolutions'] = q_resolutions
|
|
86
|
+
|
|
87
|
+
if torch.is_tensor(nonsmeared_curves):
|
|
88
|
+
batch_data['nonsmeared_curves'] = nonsmeared_curves
|
|
78
89
|
|
|
79
90
|
if self.calc_denoised_curves:
|
|
80
91
|
batch_data['curves'] = curves
|
|
@@ -88,10 +99,13 @@ class BasicDataset(object):
|
|
|
88
99
|
batch_data['scaled_noisy_curves'] = scaled_noisy_curves
|
|
89
100
|
|
|
90
101
|
is_finite = torch.all(torch.isfinite(scaled_noisy_curves), -1)
|
|
102
|
+
|
|
91
103
|
if not torch.all(is_finite).item():
|
|
92
104
|
infinite_indices = ~is_finite
|
|
93
|
-
|
|
94
|
-
|
|
105
|
+
to_recalculate = infinite_indices.sum().item()
|
|
106
|
+
warnings.warn(f'Infinite number appeared in the curve simulation! Recalculate {to_recalculate} curves.')
|
|
107
|
+
recalculated_batch_data = self.get_batch(to_recalculate)
|
|
108
|
+
_insert_batch_data(batch_data, recalculated_batch_data, infinite_indices)
|
|
95
109
|
|
|
96
110
|
is_finite = torch.all(torch.isfinite(batch_data['scaled_noisy_curves']), -1)
|
|
97
111
|
assert torch.all(is_finite).item()
|
|
@@ -100,13 +114,19 @@ class BasicDataset(object):
|
|
|
100
114
|
|
|
101
115
|
return batch_data
|
|
102
116
|
|
|
103
|
-
def _calc_curves(self, q_values: Tensor, params: BasicParams):
|
|
117
|
+
def _calc_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs):
|
|
118
|
+
nonsmeared_curves = None
|
|
119
|
+
|
|
104
120
|
if self.smearing:
|
|
105
|
-
|
|
121
|
+
if self.calc_nonsmeared_curves:
|
|
122
|
+
nonsmeared_curves = params.reflectivity(q_values, **refl_kwargs)
|
|
123
|
+
curves, q_resolutions = self.smearing.get_curves(q_values, params, refl_kwargs)
|
|
106
124
|
else:
|
|
107
|
-
curves = params.reflectivity(q_values)
|
|
125
|
+
curves = params.reflectivity(q_values, **refl_kwargs)
|
|
126
|
+
q_resolutions = None
|
|
127
|
+
|
|
108
128
|
curves = curves.to(q_values)
|
|
109
|
-
return curves
|
|
129
|
+
return curves, q_resolutions, nonsmeared_curves
|
|
110
130
|
|
|
111
131
|
|
|
112
132
|
def _insert_batch_data(tgt_batch_data, add_batch_data, indices):
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
from typing import Union, Tuple
|
|
1
|
+
from typing import List, Union, Tuple
|
|
2
2
|
from math import log10
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
from torch import Tensor
|
|
6
6
|
|
|
7
7
|
from reflectorch.data_generation.process_data import ProcessData
|
|
8
|
-
from reflectorch.data_generation.utils import uniform_sampler
|
|
8
|
+
from reflectorch.data_generation.utils import logdist_sampler, uniform_sampler
|
|
9
9
|
|
|
10
10
|
__all__ = [
|
|
11
11
|
"QNoiseGenerator",
|
|
@@ -18,6 +18,7 @@ __all__ = [
|
|
|
18
18
|
"ShiftNoise",
|
|
19
19
|
"BackgroundNoise",
|
|
20
20
|
"BasicExpIntensityNoise",
|
|
21
|
+
"GaussianExpIntensityNoise",
|
|
21
22
|
"BasicQNoiseGenerator",
|
|
22
23
|
]
|
|
23
24
|
|
|
@@ -102,18 +103,22 @@ class BasicQNoiseGenerator(QNoiseGenerator):
|
|
|
102
103
|
Defaults to (0, 1e-3).
|
|
103
104
|
"""
|
|
104
105
|
def __init__(self,
|
|
106
|
+
apply_systematic_shifts: bool = True,
|
|
105
107
|
shift_std: float = 1e-3,
|
|
108
|
+
apply_gaussian_noise: bool = False,
|
|
106
109
|
noise_std: Union[float, Tuple[float, float]] = (0, 1e-3),
|
|
107
110
|
add_to_context: bool = False,
|
|
108
111
|
):
|
|
109
|
-
self.q_shift = QSystematicShiftGenerator(shift_std, add_to_context=add_to_context)
|
|
110
|
-
self.q_noise = QNormalNoiseGenerator(noise_std, add_to_context=add_to_context)
|
|
112
|
+
self.q_shift = QSystematicShiftGenerator(shift_std, add_to_context=add_to_context) if apply_systematic_shifts else None
|
|
113
|
+
self.q_noise = QNormalNoiseGenerator(noise_std, add_to_context=add_to_context) if apply_gaussian_noise else None
|
|
111
114
|
|
|
112
115
|
def apply(self, qs: Tensor, context: dict = None):
|
|
113
|
-
"""applies
|
|
116
|
+
"""applies noise to the q values"""
|
|
114
117
|
qs = torch.atleast_2d(qs)
|
|
115
|
-
|
|
116
|
-
|
|
118
|
+
if self.q_shift:
|
|
119
|
+
qs = self.q_shift.apply(qs, context)
|
|
120
|
+
if self.q_noise:
|
|
121
|
+
qs = self.q_noise.apply(qs, context)
|
|
117
122
|
return qs
|
|
118
123
|
|
|
119
124
|
|
|
@@ -154,6 +159,51 @@ class MultiplicativeLogNormalNoiseGenerator(IntensityNoiseGenerator):
|
|
|
154
159
|
|
|
155
160
|
return noise * curves
|
|
156
161
|
|
|
162
|
+
class GaussianNoiseGenerator(IntensityNoiseGenerator):
    """Noise generator which applies noise as R_n = R + eps, with eps ~ N(0, sigmas) and sigmas = relative_errors * R.

    The Gaussian perturbation itself is symmetric around zero; only the resulting noisy curve is
    clamped at zero so that the returned reflectivity stays non-negative.

    Args:
        relative_errors (Union[float, Tuple[float, float], List[Tuple[float, float]]]): relative error specification:
            a single number (fixed relative error for every point), a ``(min, max)`` pair (relative errors are
            sampled uniformly from this range) or a list of ``(min, max)`` pairs (one range per channel of a
            3D batch of curves).
        consistent_rel_err (bool, optional): If True, a single relative error is sampled per curve (and per channel)
            and used for all points of that curve; otherwise a relative error is sampled independently for every
            point. Defaults to False.
        add_to_context (bool, optional): If True and a context dictionary is passed to ``apply``, the relative errors
            and the resulting standard deviations are stored in it under ``'relative_errors'`` and ``'sigmas'``.
            Defaults to False.
    """

    def __init__(self, relative_errors: Union[float, Tuple[float, float], List[float], List[Tuple[float, float]]],
                 consistent_rel_err: bool = False,
                 add_to_context: bool = False):
        self.relative_errors = relative_errors
        self.consistent_rel_err = consistent_rel_err
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None) -> Tensor:
        """Applies Gaussian noise to the curves.

        Args:
            curves (Tensor): the reflectivity curves. A 3D tensor is treated as (batch, channels, n_q),
                any other tensor as (batch, n_q).
            context (dict, optional): context dictionary, updated in place when ``add_to_context`` is True.

        Returns:
            Tensor: the noisy curves, clamped to be non-negative.
        """
        relative_errors = self.relative_errors
        is_multichannel = curves.dim() == 3
        num_channels = curves.shape[1] if is_multichannel else 1
        sampler_kwargs = dict(device=curves.device, dtype=curves.dtype)

        if isinstance(relative_errors, (int, float)):
            # fixed relative error for every point of every curve
            relative_errors = torch.ones_like(curves) * relative_errors

        elif isinstance(relative_errors, (list, tuple)) and isinstance(relative_errors[0], (int, float)):
            # a single (min, max) range shared by all channels
            if self.consistent_rel_err:
                relative_errors = uniform_sampler(
                    *relative_errors, curves.shape[0], num_channels, **sampler_kwargs
                )
                if is_multichannel:
                    # (batch, channels) -> (batch, channels, 1) so that it broadcasts over the q axis;
                    # this must also happen for 3D curves with a single channel
                    relative_errors = relative_errors.unsqueeze(-1)
            else:
                relative_errors = uniform_sampler(*relative_errors, *curves.shape, **sampler_kwargs)

        else:
            # one (min, max) range per channel, stacked along the channel axis
            num_points = 1 if self.consistent_rel_err else curves.shape[-1]
            relative_errors = torch.stack(
                [uniform_sampler(*item, curves.shape[0], num_points, **sampler_kwargs) for item in relative_errors],
                dim=1,
            )

        sigmas = relative_errors * curves
        # symmetric noise; clamping the noise itself would only ever push the curves upwards
        noise = torch.normal(mean=0., std=sigmas)
        noisy_curves = (curves + noise).clamp_min_(0.0)

        if self.add_to_context and context is not None:
            context['relative_errors'] = relative_errors
            context['sigmas'] = sigmas

        return noisy_curves
|
|
157
207
|
|
|
158
208
|
class PoissonNoiseGenerator(IntensityNoiseGenerator):
|
|
159
209
|
"""Noise generator which applies Poisson noise to the reflectivity curves
|
|
@@ -273,6 +323,7 @@ class BackgroundNoise(IntensityNoiseGenerator):
|
|
|
273
323
|
|
|
274
324
|
Args:
|
|
275
325
|
background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
|
|
326
|
+
add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary. Defaults to False.
|
|
276
327
|
"""
|
|
277
328
|
def __init__(self,
|
|
278
329
|
background_range: tuple = (1.0e-10, 1.0e-8),
|
|
@@ -283,7 +334,7 @@ class BackgroundNoise(IntensityNoiseGenerator):
|
|
|
283
334
|
|
|
284
335
|
def apply(self, curves: Tensor, context: dict = None) -> Tensor:
|
|
285
336
|
"""applies background noise to the curves"""
|
|
286
|
-
backgrounds =
|
|
337
|
+
backgrounds = logdist_sampler(
|
|
287
338
|
*self.background_range, curves.shape[0], 1,
|
|
288
339
|
device=curves.device, dtype=curves.dtype
|
|
289
340
|
)
|
|
@@ -294,6 +345,61 @@ class BackgroundNoise(IntensityNoiseGenerator):
|
|
|
294
345
|
|
|
295
346
|
return curves
|
|
296
347
|
|
|
348
|
+
class GaussianExpIntensityNoise(IntensityNoiseGenerator):
    """
    A composite noise generator that applies shift, background and Gaussian noise to reflectivity curves.

    The noise types are applied in the following order:
        1. Shift noise (optional): applies a multiplicative shift to the curves, equivalent to a vertical shift
           in logarithmic space.
        2. Background noise (optional): adds a constant background value to the curves.
        3. Gaussian noise: applies relative Gaussian noise (to account for count-based Poisson noise as well as
           other sources of error).

    Args:
        relative_errors (Union[float, Tuple[float, float], List[Tuple[float, float]]]): The relative errors for the
            Gaussian noise (see ``GaussianNoiseGenerator``). Defaults to (0.001, 0.15).
        consistent_rel_err (bool, optional): If True, uses a consistent relative error for Gaussian noise across all
            points in a curve. Defaults to False.
        apply_shift (bool, optional): If True, applies shift noise to the curves. Defaults to False.
        shift_range (tuple, optional): The range of shift factors for shift noise. Defaults to (-0.1, 0.2e-2).
        apply_background (bool, optional): If True, applies background noise to the curves. Defaults to False.
        background_range (tuple, optional): The range from which the background value is sampled.
            Defaults to (1.0e-10, 1.0e-8).
        same_background_across_channels (bool, optional): Intended to control whether the same background is applied
            to all channels of a multi-channel curve. Defaults to False.
            NOTE(review): the flag is stored on the instance but is not forwarded to ``BackgroundNoise``,
            so it currently has no effect on the generated noise.
        add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary.
            Defaults to False.
    """
    def __init__(self,
                 relative_errors: Union[float, Tuple[float, float], List[Tuple[float, float]]] = (0.001, 0.15),
                 consistent_rel_err: bool = False,
                 apply_shift: bool = False,
                 shift_range: tuple = (-0.1, 0.2e-2),
                 apply_background: bool = False,
                 background_range: tuple = (1.0e-10, 1.0e-8),
                 same_background_across_channels: bool = False,
                 add_to_context: bool = False,
                 ):
        # kept for introspection / configuration round-trips; see the NOTE in the class docstring
        self.same_background_across_channels = same_background_across_channels

        self.gaussian_noise = GaussianNoiseGenerator(
            relative_errors=relative_errors,
            consistent_rel_err=consistent_rel_err,
            add_to_context=add_to_context,
        )

        self.shift_noise = ShiftNoise(
            shift_range=shift_range, add_to_context=add_to_context
        ) if apply_shift else None

        self.background_noise = BackgroundNoise(
            background_range=background_range, add_to_context=add_to_context
        ) if apply_background else None

    def apply(self, curves: Tensor, context: dict = None) -> Tensor:
        """applies the enabled noise types to the input curves: shift, then background, then Gaussian noise"""
        if self.shift_noise is not None:
            curves = self.shift_noise.apply(curves, context)

        if self.background_noise is not None:
            curves = self.background_noise.apply(curves, context)

        return self.gaussian_noise.apply(curves, context)
|
|
297
403
|
|
|
298
404
|
class BasicExpIntensityNoise(IntensityNoiseGenerator):
|
|
299
405
|
"""
|
|
@@ -362,4 +468,4 @@ class BasicExpIntensityNoise(IntensityNoiseGenerator):
|
|
|
362
468
|
if self.background_noise:
|
|
363
469
|
curves = self.background_noise.apply(curves, context)
|
|
364
470
|
|
|
365
|
-
return curves
|
|
471
|
+
return curves
|
|
@@ -10,7 +10,6 @@ from reflectorch.data_generation.reflectivity import (
|
|
|
10
10
|
)
|
|
11
11
|
from reflectorch.data_generation.utils import (
|
|
12
12
|
get_param_labels,
|
|
13
|
-
get_param_labels_absorption_model,
|
|
14
13
|
)
|
|
15
14
|
from reflectorch.data_generation.priors.sampler_strategies import (
|
|
16
15
|
SamplerStrategy,
|
|
@@ -44,7 +43,7 @@ class ParametricModel(object):
|
|
|
44
43
|
@property
|
|
45
44
|
def param_dim(self) -> int:
|
|
46
45
|
"""get the number of parameters
|
|
47
|
-
|
|
46
|
+
|
|
48
47
|
Returns:
|
|
49
48
|
int:
|
|
50
49
|
"""
|
|
@@ -106,7 +105,7 @@ class ParametricModel(object):
|
|
|
106
105
|
|
|
107
106
|
return min_bounds, max_bounds, min_deltas, max_deltas
|
|
108
107
|
|
|
109
|
-
def get_param_labels(self) -> List[str]:
|
|
108
|
+
def get_param_labels(self, **kwargs) -> List[str]:
|
|
110
109
|
"""get the list with the name of the parameters
|
|
111
110
|
|
|
112
111
|
Returns:
|
|
@@ -158,9 +157,10 @@ class StandardModel(ParametricModel):
|
|
|
158
157
|
def _init_sampler_strategy(self,
|
|
159
158
|
constrained_roughness: bool = True,
|
|
160
159
|
max_thickness_share: float = 0.5,
|
|
160
|
+
nuisance_params_dim: int = 0,
|
|
161
161
|
**kwargs):
|
|
162
162
|
if constrained_roughness:
|
|
163
|
-
num_params = self.param_dim
|
|
163
|
+
num_params = self.param_dim + nuisance_params_dim
|
|
164
164
|
thickness_mask = torch.zeros(num_params, dtype=torch.bool)
|
|
165
165
|
roughness_mask = torch.zeros(num_params, dtype=torch.bool)
|
|
166
166
|
thickness_mask[:self.max_num_layers] = True
|
|
@@ -204,8 +204,8 @@ class StandardModel(ParametricModel):
|
|
|
204
204
|
|
|
205
205
|
return min_bounds, max_bounds, min_deltas, max_deltas
|
|
206
206
|
|
|
207
|
-
def get_param_labels(self) -> List[str]:
|
|
208
|
-
return get_param_labels(self.max_num_layers)
|
|
207
|
+
def get_param_labels(self, **kwargs) -> List[str]:
|
|
208
|
+
return get_param_labels(self.max_num_layers, **kwargs)
|
|
209
209
|
|
|
210
210
|
@staticmethod
|
|
211
211
|
def _params2dict(parametrized_model: Tensor):
|
|
@@ -250,9 +250,10 @@ class ModelWithAbsorption(StandardModel):
|
|
|
250
250
|
constrained_isld: bool = True,
|
|
251
251
|
max_thickness_share: float = 0.5,
|
|
252
252
|
max_sld_share: float = 0.2,
|
|
253
|
+
nuisance_params_dim: int = 0,
|
|
253
254
|
**kwargs):
|
|
254
255
|
if constrained_roughness:
|
|
255
|
-
num_params = self.param_dim
|
|
256
|
+
num_params = self.param_dim + nuisance_params_dim
|
|
256
257
|
thickness_mask = torch.zeros(num_params, dtype=torch.bool)
|
|
257
258
|
roughness_mask = torch.zeros(num_params, dtype=torch.bool)
|
|
258
259
|
thickness_mask[:self.max_num_layers] = True
|
|
@@ -262,10 +263,11 @@ class ModelWithAbsorption(StandardModel):
|
|
|
262
263
|
sld_mask = torch.zeros(num_params, dtype=torch.bool)
|
|
263
264
|
isld_mask = torch.zeros(num_params, dtype=torch.bool)
|
|
264
265
|
sld_mask[2 * self.max_num_layers + 1:3 * self.max_num_layers + 2] = True
|
|
265
|
-
isld_mask[3 * self.max_num_layers + 2:] = True
|
|
266
|
+
isld_mask[3 * self.max_num_layers + 2:4 * self.max_num_layers + 3] = True
|
|
266
267
|
return ConstrainedRoughnessAndImgSldSamplerStrategy(
|
|
267
268
|
thickness_mask, roughness_mask, sld_mask, isld_mask,
|
|
268
|
-
max_thickness_share=max_thickness_share, max_sld_share=max_sld_share
|
|
269
|
+
max_thickness_share=max_thickness_share, max_sld_share=max_sld_share,
|
|
270
|
+
**kwargs
|
|
269
271
|
)
|
|
270
272
|
else:
|
|
271
273
|
return ConstrainedRoughnessSamplerStrategy(
|
|
@@ -305,9 +307,9 @@ class ModelWithAbsorption(StandardModel):
|
|
|
305
307
|
|
|
306
308
|
return min_bounds, max_bounds, min_deltas, max_deltas
|
|
307
309
|
|
|
308
|
-
def get_param_labels(self) -> List[str]:
|
|
309
|
-
return
|
|
310
|
-
|
|
310
|
+
def get_param_labels(self, **kwargs) -> List[str]:
|
|
311
|
+
return get_param_labels(self.max_num_layers, parameterization_type='absorption', **kwargs)
|
|
312
|
+
|
|
311
313
|
@staticmethod
|
|
312
314
|
def _params2dict(parametrized_model: Tensor):
|
|
313
315
|
num_params = parametrized_model.shape[-1]
|
|
@@ -355,8 +357,9 @@ class ModelWithShifts(StandardModel):
|
|
|
355
357
|
|
|
356
358
|
return params
|
|
357
359
|
|
|
358
|
-
def get_param_labels(self) -> List[str]:
|
|
359
|
-
return get_param_labels(self.max_num_layers) + [r"$\Delta q$ (Å$^{{-1}}$)", r"$\Delta I$"]
|
|
360
|
+
def get_param_labels(self, **kwargs) -> List[str]:
|
|
361
|
+
return get_param_labels(self.max_num_layers, **kwargs) + [r"$\Delta q$ (Å$^{{-1}}$)", r"$\Delta I$"]
|
|
362
|
+
|
|
360
363
|
|
|
361
364
|
@staticmethod
|
|
362
365
|
def _params2dict(parametrized_model: Tensor):
|
|
@@ -384,7 +387,7 @@ class ModelWithShifts(StandardModel):
|
|
|
384
387
|
|
|
385
388
|
def reflectivity_with_shifts(q, thickness, roughness, sld, q_shift, norm_shift, **kwargs):
|
|
386
389
|
q = torch.atleast_2d(q) + q_shift
|
|
387
|
-
return reflectivity(q, thickness, roughness, sld, **kwargs) * norm_shift
|
|
390
|
+
return reflectivity(q, thickness, roughness, sld, **kwargs) * norm_shift
|
|
388
391
|
|
|
389
392
|
class NoFresnelModel(StandardModel):
|
|
390
393
|
NAME = 'no_fresnel_model'
|
|
@@ -765,3 +768,75 @@ def multilayer_model3(parametrized_model: Tensor, d_full_rel_max: int = 30):
|
|
|
765
768
|
sld=slds
|
|
766
769
|
)
|
|
767
770
|
return params
|
|
771
|
+
|
|
772
|
+
|
|
773
|
+
class NuisanceParamsWrapper(ParametricModel):
    """
    Decorates a base parametric model with extra nuisance parameters that can be switched on or off one by one.

    The enabled nuisance parameters are appended after the base model parameters, in the order in which
    they appear in ``nuisance_params_config``.

    Args:
        base_model (ParametricModel): The base parametric model, whose parameters come first.
        nuisance_params_config (Dict[str, bool]): Maps each nuisance parameter name to `True` (enabled)
            or `False` (disabled). Defaults to an empty configuration (no nuisance parameters).
    """

    def __init__(self, base_model: ParametricModel, nuisance_params_config: Dict[str, bool] = None, **kwargs):
        self.base_model = base_model
        self.nuisance_params_config = nuisance_params_config or {}

        # only enabled entries become parameters; the dict order fixes their position
        self.enabled_nuisance_params = [
            name for name, enabled in self.nuisance_params_config.items() if enabled
        ]
        n_nuisance = len(self.enabled_nuisance_params)

        self.PARAMETER_NAMES = self.base_model.PARAMETER_NAMES + tuple(self.enabled_nuisance_params)
        self._param_dim = self.base_model.param_dim + n_nuisance

        # NOTE(review): the attributes above are set before the parent constructor, presumably because it
        # calls _init_sampler_strategy (which reads them) — confirm against ParametricModel.__init__.
        super().__init__(base_model.max_num_layers, **kwargs)

    def _init_sampler_strategy(self, **kwargs):
        """Delegates to the base model, telling it how many extra (nuisance) parameters follow its own."""
        n_nuisance = len(self.enabled_nuisance_params)
        return self.base_model._init_sampler_strategy(nuisance_params_dim=n_nuisance, **kwargs)

    @property
    def param_dim(self) -> int:
        """Number of base model parameters plus the number of enabled nuisance parameters."""
        return self._param_dim

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Extracts base model parameters only (nuisance columns are dropped)."""
        n_base = self.base_model.param_dim
        return self.base_model.to_standard_params(parametrized_model[..., :n_base])

    def reflectivity(self, q, parametrized_model: Tensor, **kwargs) -> Tensor:
        """Computes reflectivity with the enabled nuisance parameters passed as keyword arguments.

        A ``log10_background`` nuisance parameter is converted to a linear ``background`` keyword.
        """
        n_base = self.base_model.param_dim
        base_params = parametrized_model[..., :n_base]
        nuisance_part = parametrized_model[..., n_base:]

        nuisance_dict = {}
        for idx, name in enumerate(self.enabled_nuisance_params):
            nuisance_dict[name] = nuisance_part[..., idx].unsqueeze(-1)

        if "log10_background" in nuisance_dict:
            log_background = nuisance_dict.pop("log10_background")
            nuisance_dict["background"] = 10 ** log_background

        return self.base_model.reflectivity(q, base_params, **nuisance_dict, **kwargs)

    def init_bounds(self, param_ranges: Dict[str, Tuple[float, float]],
                    bound_width_ranges: Dict[str, Tuple[float, float]], device=None, dtype=None):
        """Initialize bounds for the base parameters followed by the enabled nuisance parameters."""
        base_min_b, base_max_b, base_min_d, base_max_d = self.base_model.init_bounds(
            param_ranges, bound_width_ranges, device, dtype)

        nuisance_ranges = [param_ranges[name] for name in self.enabled_nuisance_params]
        nuisance_widths = [bound_width_ranges[name] for name in self.enabled_nuisance_params]

        if not nuisance_ranges:
            return base_min_b, base_max_b, base_min_d, base_max_d

        # (n, 2) -> transpose (2, n) -> new axis (2, 1, n): unpacks into lower and upper rows of shape (1, n)
        lo_bounds, hi_bounds = torch.tensor(nuisance_ranges, device=device, dtype=dtype).T[:, None]
        lo_deltas, hi_deltas = torch.tensor(nuisance_widths, device=device, dtype=dtype).T[:, None]

        return (
            torch.cat([base_min_b, lo_bounds], dim=-1),
            torch.cat([base_max_b, hi_bounds], dim=-1),
            torch.cat([base_min_d, lo_deltas], dim=-1),
            torch.cat([base_max_d, hi_deltas], dim=-1),
        )

    def get_param_labels(self, **kwargs) -> List[str]:
        """Labels of the base model followed by the raw names of the enabled nuisance parameters."""
        base_labels = self.base_model.get_param_labels(**kwargs)
        return base_labels + self.enabled_nuisance_params
|
|
@@ -12,6 +12,7 @@ from reflectorch.data_generation.priors.no_constraints import (
|
|
|
12
12
|
|
|
13
13
|
from reflectorch.data_generation.priors.parametric_models import (
|
|
14
14
|
MULTILAYER_MODELS,
|
|
15
|
+
NuisanceParamsWrapper,
|
|
15
16
|
ParametricModel,
|
|
16
17
|
)
|
|
17
18
|
from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
|
|
@@ -54,9 +55,9 @@ class BasicParams(AbstractParams):
|
|
|
54
55
|
self.min_bounds = min_bounds
|
|
55
56
|
self.max_bounds = max_bounds
|
|
56
57
|
|
|
57
|
-
def get_param_labels(self) -> List[str]:
|
|
58
|
+
def get_param_labels(self, **kwargs) -> List[str]:
|
|
58
59
|
"""gets the parameter labels"""
|
|
59
|
-
return self.param_model.get_param_labels()
|
|
60
|
+
return self.param_model.get_param_labels(**kwargs)
|
|
60
61
|
|
|
61
62
|
def reflectivity(self, q: Tensor, log: bool = False, **kwargs):
|
|
62
63
|
r"""computes the reflectivity curves directly from the parameters
|
|
@@ -97,6 +98,18 @@ class BasicParams(AbstractParams):
|
|
|
97
98
|
"""gets the slds"""
|
|
98
99
|
params = self.param_model.to_standard_params(self.parameters)
|
|
99
100
|
return params['sld']
|
|
101
|
+
|
|
102
|
+
@property
|
|
103
|
+
def real_slds(self):
|
|
104
|
+
"""gets the real part of the slds"""
|
|
105
|
+
params = self.param_model.to_standard_params(self.parameters)
|
|
106
|
+
return params['sld'].real
|
|
107
|
+
|
|
108
|
+
@property
|
|
109
|
+
def imag_slds(self):
|
|
110
|
+
"""gets the imaginary part of the slds (only for complex dtypes)"""
|
|
111
|
+
params = self.param_model.to_standard_params(self.parameters)
|
|
112
|
+
return params['sld'].imag
|
|
100
113
|
|
|
101
114
|
@staticmethod
|
|
102
115
|
def rearrange_context_from_params(
|
|
@@ -201,11 +214,19 @@ class SubpriorParametricSampler(PriorSampler, ScalerMixin):
|
|
|
201
214
|
scaled_range (Tuple[float, float], optional): the range for scaling the parameters. Defaults to (-1., 1.)
|
|
202
215
|
"""
|
|
203
216
|
self.scaled_range = scaled_range
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
)
|
|
217
|
+
|
|
218
|
+
self.shift_param_config = kwargs.pop('shift_param_config', {})
|
|
219
|
+
|
|
220
|
+
base_model: ParametricModel = MULTILAYER_MODELS[model_name](max_num_layers, logdist=logdist, **kwargs)
|
|
221
|
+
if any(self.shift_param_config.values()):
|
|
222
|
+
self.param_model = NuisanceParamsWrapper(
|
|
223
|
+
base_model=base_model,
|
|
224
|
+
nuisance_params_config=self.shift_param_config,
|
|
225
|
+
**kwargs,
|
|
226
|
+
)
|
|
227
|
+
else:
|
|
228
|
+
self.param_model = base_model
|
|
229
|
+
|
|
209
230
|
self.device = device
|
|
210
231
|
self.dtype = dtype
|
|
211
232
|
self.num_layers = max_num_layers
|
|
@@ -73,11 +73,13 @@ class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
|
|
|
73
73
|
roughness_mask: Tensor,
|
|
74
74
|
logdist: bool = False,
|
|
75
75
|
max_thickness_share: float = 0.5,
|
|
76
|
+
max_total_thickness: float = None,
|
|
76
77
|
):
|
|
77
78
|
super().__init__(logdist=logdist)
|
|
78
79
|
self.thickness_mask = thickness_mask
|
|
79
80
|
self.roughness_mask = roughness_mask
|
|
80
81
|
self.max_thickness_share = max_thickness_share
|
|
82
|
+
self.max_total_thickness = max_total_thickness
|
|
81
83
|
|
|
82
84
|
def sample(self, batch_size: int,
|
|
83
85
|
total_min_bounds: Tensor,
|
|
@@ -106,7 +108,8 @@ class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
|
|
|
106
108
|
thickness_mask=self.thickness_mask.to(device),
|
|
107
109
|
roughness_mask=self.roughness_mask.to(device),
|
|
108
110
|
widths_sampler_func=self.widths_sampler_func,
|
|
109
|
-
|
|
111
|
+
coef_roughness=self.max_thickness_share,
|
|
112
|
+
max_total_thickness=self.max_total_thickness,
|
|
110
113
|
)
|
|
111
114
|
|
|
112
115
|
class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
|
|
@@ -129,6 +132,7 @@ class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
|
|
|
129
132
|
logdist: bool = False,
|
|
130
133
|
max_thickness_share: float = 0.5,
|
|
131
134
|
max_sld_share: float = 0.2,
|
|
135
|
+
max_total_thickness: float = None,
|
|
132
136
|
):
|
|
133
137
|
super().__init__(logdist=logdist)
|
|
134
138
|
self.thickness_mask = thickness_mask
|
|
@@ -137,6 +141,7 @@ class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
|
|
|
137
141
|
self.isld_mask = isld_mask
|
|
138
142
|
self.max_thickness_share = max_thickness_share
|
|
139
143
|
self.max_sld_share = max_sld_share
|
|
144
|
+
self.max_total_thickness = max_total_thickness
|
|
140
145
|
|
|
141
146
|
def sample(self, batch_size: int,
|
|
142
147
|
total_min_bounds: Tensor,
|
|
@@ -169,6 +174,7 @@ class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
|
|
|
169
174
|
widths_sampler_func=self.widths_sampler_func,
|
|
170
175
|
coef_roughness=self.max_thickness_share,
|
|
171
176
|
coef_isld=self.max_sld_share,
|
|
177
|
+
max_total_thickness=self.max_total_thickness,
|
|
172
178
|
)
|
|
173
179
|
|
|
174
180
|
def basic_sampler(
|
|
@@ -214,15 +220,44 @@ def constrained_roughness_sampler(
|
|
|
214
220
|
thickness_mask: Tensor,
|
|
215
221
|
roughness_mask: Tensor,
|
|
216
222
|
widths_sampler_func,
|
|
217
|
-
|
|
223
|
+
coef_roughness: float = 0.5,
|
|
224
|
+
max_total_thickness: float = None,
|
|
218
225
|
):
|
|
219
226
|
params, min_bounds, max_bounds = basic_sampler(
|
|
220
227
|
batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
|
|
221
228
|
widths_sampler_func=widths_sampler_func,
|
|
222
229
|
)
|
|
223
230
|
|
|
231
|
+
if max_total_thickness is not None:
|
|
232
|
+
total_thickness = max_bounds[:, thickness_mask].sum(-1)
|
|
233
|
+
indices = total_thickness > max_total_thickness
|
|
234
|
+
|
|
235
|
+
if indices.any():
|
|
236
|
+
eps = 0.01
|
|
237
|
+
rand_scale = torch.rand_like(total_thickness) * eps + 1 - eps
|
|
238
|
+
scale_coef = max_total_thickness / total_thickness * rand_scale
|
|
239
|
+
scale_coef[~indices] = 1.0
|
|
240
|
+
min_bounds[:, thickness_mask] *= scale_coef[:, None]
|
|
241
|
+
max_bounds[:, thickness_mask] *= scale_coef[:, None]
|
|
242
|
+
params[:, thickness_mask] *= scale_coef[:, None]
|
|
243
|
+
|
|
244
|
+
min_bounds[:, thickness_mask] = torch.clamp_min(
|
|
245
|
+
min_bounds[:, thickness_mask],
|
|
246
|
+
total_min_bounds[:, thickness_mask],
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
max_bounds[:, thickness_mask] = torch.clamp_min(
|
|
250
|
+
max_bounds[:, thickness_mask],
|
|
251
|
+
total_min_bounds[:, thickness_mask],
|
|
252
|
+
)
|
|
253
|
+
|
|
254
|
+
params[:, thickness_mask] = torch.clamp_min(
|
|
255
|
+
params[:, thickness_mask],
|
|
256
|
+
total_min_bounds[:, thickness_mask],
|
|
257
|
+
)
|
|
258
|
+
|
|
224
259
|
max_roughness = torch.minimum(
|
|
225
|
-
get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=
|
|
260
|
+
get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
|
|
226
261
|
total_max_bounds[..., roughness_mask]
|
|
227
262
|
)
|
|
228
263
|
min_roughness = total_min_bounds[..., roughness_mask]
|
|
@@ -256,12 +291,41 @@ def constrained_roughness_and_isld_sampler(
|
|
|
256
291
|
widths_sampler_func,
|
|
257
292
|
coef_roughness: float = 0.5,
|
|
258
293
|
coef_isld: float = 0.2,
|
|
294
|
+
max_total_thickness: float = None,
|
|
259
295
|
):
|
|
260
296
|
params, min_bounds, max_bounds = basic_sampler(
|
|
261
297
|
batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
|
|
262
298
|
widths_sampler_func=widths_sampler_func,
|
|
263
299
|
)
|
|
264
300
|
|
|
301
|
+
if max_total_thickness is not None:
|
|
302
|
+
total_thickness = max_bounds[:, thickness_mask].sum(-1)
|
|
303
|
+
indices = total_thickness > max_total_thickness
|
|
304
|
+
|
|
305
|
+
if indices.any():
|
|
306
|
+
eps = 0.01
|
|
307
|
+
rand_scale = torch.rand_like(total_thickness) * eps + 1 - eps
|
|
308
|
+
scale_coef = max_total_thickness / total_thickness * rand_scale
|
|
309
|
+
scale_coef[~indices] = 1.0
|
|
310
|
+
min_bounds[:, thickness_mask] *= scale_coef[:, None]
|
|
311
|
+
max_bounds[:, thickness_mask] *= scale_coef[:, None]
|
|
312
|
+
params[:, thickness_mask] *= scale_coef[:, None]
|
|
313
|
+
|
|
314
|
+
min_bounds[:, thickness_mask] = torch.clamp_min(
|
|
315
|
+
min_bounds[:, thickness_mask],
|
|
316
|
+
total_min_bounds[:, thickness_mask],
|
|
317
|
+
)
|
|
318
|
+
|
|
319
|
+
max_bounds[:, thickness_mask] = torch.clamp_min(
|
|
320
|
+
max_bounds[:, thickness_mask],
|
|
321
|
+
total_min_bounds[:, thickness_mask],
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
params[:, thickness_mask] = torch.clamp_min(
|
|
325
|
+
params[:, thickness_mask],
|
|
326
|
+
total_min_bounds[:, thickness_mask],
|
|
327
|
+
)
|
|
328
|
+
|
|
265
329
|
max_roughness = torch.minimum(
|
|
266
330
|
get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
|
|
267
331
|
total_max_bounds[..., roughness_mask]
|