reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -128
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -280
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -223
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -1374
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +36 -36
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +523 -516
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -19
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -262
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -200
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -15
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -19
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -434
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -404
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +97 -97
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.4.0.dist-info/RECORD +0 -88
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,471 +1,471 @@
|
|
|
1
|
-
from typing import List, Union, Tuple
|
|
2
|
-
from math import log10
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
from torch import Tensor
|
|
6
|
-
|
|
7
|
-
from reflectorch.data_generation.process_data import ProcessData
|
|
8
|
-
from reflectorch.data_generation.utils import logdist_sampler, uniform_sampler
|
|
9
|
-
|
|
10
|
-
__all__ = [
    "QNoiseGenerator",
    "IntensityNoiseGenerator",
    "QNormalNoiseGenerator",
    "QSystematicShiftGenerator",
    "PoissonNoiseGenerator",
    "MultiplicativeLogNormalNoiseGenerator",
    # GaussianNoiseGenerator is a documented generator defined in this module (and used by
    # GaussianExpIntensityNoise); export it alongside the other intensity noise generators.
    "GaussianNoiseGenerator",
    "ScalingNoise",
    "ShiftNoise",
    "BackgroundNoise",
    "BasicExpIntensityNoise",
    "GaussianExpIntensityNoise",
    "BasicQNoiseGenerator",
]
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class QNoiseGenerator(ProcessData):
    """Base class for q noise generators.

    The base implementation is a no-op: the q values are handed back untouched.
    Subclasses override :meth:`apply` to perturb them.
    """

    def apply(self, qs: Tensor, context: dict = None):
        """Returns ``qs`` unchanged (``context`` is ignored)."""
        unchanged_qs = qs
        return unchanged_qs
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class QNormalNoiseGenerator(QNoiseGenerator):
    """Q noise generator which adds to each q value of the reflectivity curve a noise sampled from a normal distribution.

    Args:
        std (Union[float, Tuple[float, float]], optional): the standard deviation of the normal distribution (the same for all curves in the batch if provided as a float,
            or uniformly sampled for each curve in the batch if provided as a tuple). Defaults to (0, 1e-3).
        add_to_context (bool, optional): If True, the standard deviations used are stored in the context under ``'q_stds'``. Defaults to False.
    """
    def __init__(self,
                 std: Union[float, Tuple[float, float]] = (0, 1e-3),
                 add_to_context: bool = False
                 ):
        self.std = std
        self.add_to_context = add_to_context

    def apply(self, qs: Tensor, context: dict = None):
        """applies noise to the q values

        Args:
            qs (Tensor): the q values. A 1D tensor is treated as a single curve; otherwise the first
                dimension indexes the curves of the batch.
            context (dict, optional): dictionary receiving ``'q_stds'`` when ``add_to_context`` is True.

        Returns:
            Tensor: the noisy q values, clamped to be non-negative.
        """
        std = self.std

        if isinstance(std, (list, tuple)):
            # One standard deviation per curve, broadcast over all q points of that curve.
            if qs.dim() > 1:
                std_shape = (qs.shape[0],) + (1,) * (qs.dim() - 1)
            else:
                std_shape = (1,)
            std = uniform_sampler(*std, *std_shape, device=qs.device, dtype=qs.dtype)
        else:
            std = torch.empty_like(qs).fill_(std)

        # Draw an independent sample for every q point. Sampling with the per-curve std tensor
        # directly (torch.normal(mean=0., std=std)) would draw a single value per curve, i.e. a
        # systematic shift instead of point-wise noise (and a (n, n) result for 1D input).
        noise = torch.randn_like(qs) * std

        if self.add_to_context and context is not None:
            context['q_stds'] = std

        qs = torch.clamp_min_(qs + noise, 0.)

        return qs
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
class QSystematicShiftGenerator(QNoiseGenerator):
    """Q noise generator which samples one q shift per curve from a normal distribution and adds it to all q values of that curve.

    Args:
        std (float): the standard deviation of the normal distribution
        add_to_context (bool, optional): If True, the sampled shifts are stored in the context under ``'q_shifts'``. Defaults to True.
    """
    def __init__(self, std: float, add_to_context: bool = True):
        self.add_to_context = add_to_context
        self.std = std

    def apply(self, qs: Tensor, context: dict = None):
        """Adds a single normally distributed shift to every curve and clamps the result at zero.

        A 1D ``qs`` receives one shift; otherwise one shift is drawn per entry of the first dimension.
        """
        shift_shape = (1,) if qs.dim() == 1 else (qs.shape[0], 1)
        std_tensor = torch.ones(*shift_shape, device=qs.device, dtype=qs.dtype) * self.std

        shifts = torch.normal(mean=0., std=std_tensor)

        if context is not None and self.add_to_context:
            context['q_shifts'] = shifts

        return (qs + shifts).clamp_min_(0.)
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
class BasicQNoiseGenerator(QNoiseGenerator):
    """Q noise generator combining systematic shifts (one change shared by all q points of a curve) and random noise (an independent change per q point).

    Args:
        apply_systematic_shifts (bool, optional): whether to apply systematic q shifts. Defaults to True.
        shift_std (float, optional): the standard deviation of the normal distribution for systematic q shifts
            (i.e. same change applied to all q points in the curve). Defaults to 1e-3.
        apply_gaussian_noise (bool, optional): whether to apply random q noise. Defaults to False.
        noise_std (Union[float, Tuple[float, float]], optional): the standard deviation of the normal distribution for random q noise
            (i.e. different changes applied to each q point in the curve). The standard deviation is the same
            for all curves in the batch if provided as a float, or uniformly sampled for each curve in the batch if provided as a tuple.
            Defaults to (0, 1e-3).
        add_to_context (bool, optional): forwarded to both sub-generators. Defaults to False.
    """
    def __init__(self,
                 apply_systematic_shifts: bool = True,
                 shift_std: float = 1e-3,
                 apply_gaussian_noise: bool = False,
                 noise_std: Union[float, Tuple[float, float]] = (0, 1e-3),
                 add_to_context: bool = False,
                 ):
        if apply_systematic_shifts:
            self.q_shift = QSystematicShiftGenerator(shift_std, add_to_context=add_to_context)
        else:
            self.q_shift = None

        if apply_gaussian_noise:
            self.q_noise = QNormalNoiseGenerator(noise_std, add_to_context=add_to_context)
        else:
            self.q_noise = None

    def apply(self, qs: Tensor, context: dict = None):
        """Promotes ``qs`` to at least 2D, then applies the systematic shift followed by the random noise (each only if enabled)."""
        qs = torch.atleast_2d(qs)
        for generator in (self.q_shift, self.q_noise):
            if generator:
                qs = generator.apply(qs, context)
        return qs
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
class IntensityNoiseGenerator(ProcessData):
    """Base class for intensity noise generators.

    Concrete generators must override :meth:`apply`; calling it on the base class raises ``NotImplementedError``.
    """

    def apply(self, curves: Tensor, context: dict = None):
        """Abstract hook: subclasses return the noisy curves."""
        raise NotImplementedError
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
class MultiplicativeLogNormalNoiseGenerator(IntensityNoiseGenerator):
    r"""Noise generator which applies noise as :math:`R_n = R * b^{\epsilon}` , where :math:`b` is a base and :math:`\epsilon` is sampled from the normal distribution :math:`\epsilon \sim \mathcal{N}(0, std)` .
    In logarithmic space, this translates to :math:`\log_b(R_n) = \log_b(R) + \epsilon` .

    A separate :math:`\epsilon` is drawn for every point of every curve.

    Args:
        std (Union[float, Tuple[float, float]]): the standard deviation of the normal distribution from which the noise is sampled. The standard deviation is the same
            for all curves in the batch if provided as a float, or uniformly sampled for each curve in the batch if provided as a tuple.
        base (float, optional): the base of the logarithm. Defaults to 10.
        add_to_context (bool, optional): If True, the standard deviations used are stored in the context under ``'std_lognormal'``. Defaults to False.
    """
    def __init__(self, std: Union[float, Tuple[float, float]], base: float = 10, add_to_context: bool = False):
        self.std = std
        self.base = base
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves

        A 1D ``curves`` tensor is treated as a single curve; otherwise the first dimension indexes the batch.
        """
        std = self.std

        if isinstance(std, (list, tuple)):
            # One standard deviation per curve, broadcast over all points of that curve.
            if curves.dim() > 1:
                std_shape = (curves.shape[0],) + (1,) * (curves.dim() - 1)
            else:
                std_shape = (1,)
            std = uniform_sampler(*std, *std_shape, device=curves.device, dtype=curves.dtype)
        else:
            std = torch.ones_like(curves) * std

        # Independent epsilon per point. Sampling with the per-curve std tensor directly would draw
        # one epsilon per curve (a pure vertical shift), unlike the float-std branch.
        noise = self.base ** (torch.randn_like(curves) * std)

        if self.add_to_context and context is not None:
            context['std_lognormal'] = std

        return noise * curves
|
|
161
|
-
|
|
162
|
-
class GaussianNoiseGenerator(IntensityNoiseGenerator):
    """Noise generator which applies noise as R_n = R + eps, with eps~N(0, sigmas) and sigmas = relative_errors * R.

    The noisy curves are clamped to be non-negative.

    Args:
        relative_errors (Union[float, Tuple[float, float], List[float], List[Tuple[float, float]]]): the relative errors.
            A single number is used as a fixed relative error; a pair ``(low, high)`` is sampled uniformly;
            a list of pairs gives one ``(low, high)`` range per channel of a multi-channel input.
        consistent_rel_err (bool): If True the relative error is the same for all points of a curve (or channel),
            otherwise it is sampled independently for every point. Defaults to False.
        add_to_context (bool): If True, the sampled ``relative_errors`` and ``sigmas`` are stored in the context. Defaults to False.
    """

    def __init__(self, relative_errors: Union[float, Tuple[float, float], List[float], List[Tuple[float, float]]],
                 consistent_rel_err: bool = False,
                 add_to_context: bool = False):
        self.relative_errors = relative_errors
        self.consistent_rel_err = consistent_rel_err
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """Applies Gaussian noise to the curves.

        ``curves`` may be 2D or 3D; in the 3D case the second dimension is treated as the channel dimension.
        """
        relative_errors = self.relative_errors
        num_channels = curves.shape[1] if curves.dim() == 3 else 1

        # Accept ints as well as floats (e.g. ``0`` or ``(0, 1)``).
        if isinstance(relative_errors, (int, float)):
            relative_errors = torch.ones_like(curves) * relative_errors

        elif isinstance(relative_errors, (list, tuple)) and isinstance(relative_errors[0], (int, float)):
            if self.consistent_rel_err:
                relative_errors = uniform_sampler(*relative_errors, curves.shape[0], num_channels, device=curves.device, dtype=curves.dtype)
                # (batch, channels) -> (batch, channels, 1) so it broadcasts over the points of every channel;
                # needed for any 3D input, including a single channel.
                if curves.dim() == 3:
                    relative_errors = relative_errors.unsqueeze(-1)
            else:
                relative_errors = uniform_sampler(*relative_errors, *curves.shape, device=curves.device, dtype=curves.dtype)

        else:
            # One (low, high) range per channel, stacked along the channel dimension.
            if self.consistent_rel_err:
                relative_errors = torch.stack([uniform_sampler(*item, curves.shape[0], 1, device=curves.device, dtype=curves.dtype) for item in relative_errors], dim=1)
            else:
                relative_errors = torch.stack([uniform_sampler(*item, curves.shape[0], curves.shape[-1], device=curves.device, dtype=curves.dtype) for item in relative_errors], dim=1)

        sigmas = relative_errors * curves
        # eps must stay zero-mean; clamping the noise itself would turn it into strictly positive
        # (half-normal) noise and bias every curve upwards.
        noise = torch.normal(mean=0., std=sigmas)

        if self.add_to_context and context is not None:
            context['relative_errors'] = relative_errors
            context['sigmas'] = sigmas

        # Keep the resulting intensities non-negative instead.
        return (curves + noise).clamp_min_(0.0)
|
|
207
|
-
|
|
208
|
-
class PoissonNoiseGenerator(IntensityNoiseGenerator):
    """Noise generator which applies Poisson noise to the reflectivity curves.

    For every point a standard deviation ``sigma = R * rel_err + abs_errors`` is built. The curve is converted
    to effective counts ``I = R / sigma**2`` and ``R_n = Poisson(I * R) / I`` is returned, so that ``R_n`` has
    mean ``R`` and variance ``sigma**2``.

    Args:
        relative_errors (Tuple[float, float], optional): the range of relative errors to apply to the intensity curves. Defaults to (0.05, 0.35).
        abs_errors (float, optional): a small constant added to sigma to prevent division by zero. Defaults to 1e-8.
        add_to_context (bool, optional): If ``True``, the sigmas are stored in the context under ``'sigmas'``. Defaults to False.
        consistent_rel_err (bool, optional): If ``True``, the same relative error is used for all points in a curve. Defaults to True.
        logdist (bool, optional): If ``True``, the relative errors are sampled uniformly in logarithmic space. Defaults to False.
    """
    def __init__(self,
                 relative_errors: Tuple[float, float] = (0.05, 0.35),
                 abs_errors: float = 1e-8,
                 add_to_context: bool = False,
                 consistent_rel_err: bool = True,
                 logdist: bool = False,
                 ):
        self.relative_errors = relative_errors
        self.abs_errors = abs_errors
        self.add_to_context = add_to_context
        self.consistent_rel_err = consistent_rel_err
        self.logdist = logdist

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves"""
        if self.consistent_rel_err:
            sigmas = self._gen_consistent_sigmas(curves)
        else:
            sigmas = self._gen_sigmas(curves)

        intensities = curves / sigmas ** 2
        noisy_curves = torch.poisson(intensities * curves) / intensities
        # Where R == 0 the effective count I is 0 and the ratio above is 0/0 = NaN;
        # such points carry no counts, so keep their (zero) value.
        curves = torch.where(intensities > 0, noisy_curves, curves)

        if self.add_to_context and context is not None:
            context['sigmas'] = sigmas
        return curves

    def _sample_relative_errors(self, shape, like: Tensor) -> Tensor:
        """Samples relative errors of the given shape from ``relative_errors`` (log-uniformly if ``logdist``)."""
        low, high = self.relative_errors
        u = torch.rand(*shape, device=like.device, dtype=like.dtype)
        if self.logdist:
            log_low, log_high = log10(low), log10(high)
            return 10 ** (u * (log_high - log_low) + log_low)
        return u * (high - low) + low

    def _gen_consistent_sigmas(self, curves):
        """Sigmas using one relative error per curve (a 1D input is treated as a single curve)."""
        if curves.dim() > 1:
            rel_err_shape = (curves.shape[0],) + (1,) * (curves.dim() - 1)
        else:
            rel_err_shape = (1,)
        # logdist is honoured here too; previously it was silently ignored for consistent errors.
        rel_err = self._sample_relative_errors(rel_err_shape, curves)
        sigmas = curves * rel_err + self.abs_errors
        return sigmas

    def _gen_sigmas(self, curves):
        """Sigmas using an independently sampled relative error for every point."""
        rel_err = self._sample_relative_errors(curves.shape, curves)
        sigmas = curves * rel_err + self.abs_errors
        return sigmas
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
class ScalingNoise(IntensityNoiseGenerator):
    """Noise generator which stretches or compresses reflectivity curves vertically in the logarithmic domain.

    Each curve is raised to the power ``1 + scale_factor``, i.e. ``log(R)`` becomes ``(1 + scale_factor) * log(R)``.

    Args:
        scale_range (tuple, optional): the range of scaling factors (one factor sampled per curve in the batch). Defaults to (-0.2e-2, 0.2e-2).
        add_to_context (bool, optional): If True, the sampled factors are stored in the context under ``'intensity_scales'``. Defaults to False.
    """
    def __init__(self,
                 scale_range: tuple = (-0.2e-2, 0.2e-2),
                 add_to_context: bool = False,
                 ):
        self.add_to_context = add_to_context
        self.scale_range = scale_range

    def apply(self, curves: Tensor, context: dict = None):
        """Raises each curve to its own power ``1 + scale_factor``."""
        batch_size = curves.shape[0]
        scale_factors = uniform_sampler(*self.scale_range, batch_size, 1, device=curves.device, dtype=curves.dtype)

        if context is not None and self.add_to_context:
            context['intensity_scales'] = scale_factors

        exponents = 1 + scale_factors
        return curves ** exponents
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
class ShiftNoise(IntensityNoiseGenerator):
    """Noise generator which applies shifting noise to reflectivity curves (equivalent to a vertical shift of the entire curve in the logarithmic domain).

    The output is R * (1 + shift_factor), which corresponds in logarithmic domain to log(R) + log(1 + shift_factor).

    Args:
        shift_range (tuple, optional): the range of shift factors (one factor sampled per curve in the batch). Defaults to (-0.1, 0.2e-2).
        add_to_context (bool, optional): If True, the sampled factors are stored in the context under ``'intensity_shifts'``. Defaults to False.
    """
    def __init__(self,
                 shift_range: tuple = (-0.1, 0.2e-2),
                 add_to_context: bool = False,
                 ):
        self.add_to_context = add_to_context
        self.shift_range = shift_range

    def apply(self, curves: Tensor, context: dict = None):
        """Multiplies each curve by its own factor ``1 + shift_factor``."""
        batch_size = curves.shape[0]
        shift_factors = uniform_sampler(*self.shift_range, batch_size, 1, device=curves.device, dtype=curves.dtype)

        if context is not None and self.add_to_context:
            context['intensity_shifts'] = shift_factors

        multipliers = 1 + shift_factors
        return curves * multipliers
|
|
319
|
-
|
|
320
|
-
class BackgroundNoise(IntensityNoiseGenerator):
    """
    Noise generator which adds a constant background to reflectivity curves.

    One background value per curve is drawn with ``logdist_sampler`` and added to every point of that curve.

    Args:
        background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
        add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary. Defaults to False.
    """
    def __init__(self,
                 background_range: tuple = (1.0e-10, 1.0e-8),
                 add_to_context: bool = False,
                 ):
        self.add_to_context = add_to_context
        self.background_range = background_range

    def apply(self, curves: Tensor, context: dict = None) -> Tensor:
        """Adds a per-curve constant background (stored under ``'backgrounds'`` when requested)."""
        batch_size = curves.shape[0]
        background_values = logdist_sampler(*self.background_range, batch_size, 1, device=curves.device, dtype=curves.dtype)

        if context is not None and self.add_to_context:
            context['backgrounds'] = background_values

        return curves + background_values
|
|
347
|
-
|
|
348
|
-
class GaussianExpIntensityNoise(IntensityNoiseGenerator):
    """
    A composite noise generator that applies Gaussian, shift and background noise to reflectivity curves.

    This class combines three types of noise:
        1. Gaussian noise: Applies Gaussian noise (to account for count-based Poisson noise as well as other sources of error)
        2. Shift noise: Applies a multiplicative shift to the curves, equivalent to a vertical shift in logarithmic space.
        3. Background noise: Adds a constant background value to the curves.

    The noise is applied in the order shift -> background -> Gaussian.

    Args:
        relative_errors (Union[float, Tuple[float, float], List[Tuple[float, float]]]): The range of relative errors for Gaussian noise. Defaults to (0.001, 0.15).
        consistent_rel_err (bool, optional): If True, uses a consistent relative error for Gaussian noise across all points in a curve. Defaults to False.
        shift_range (tuple, optional): The range of shift factors for shift noise. Defaults to (-0.1, 0.2e-2).
        background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
        apply_shift (bool, optional): If True, applies shift noise to the curves. Defaults to False.
        apply_background (bool, optional): If True, applies background noise to the curves. Defaults to False.
        same_background_across_channels(bool, optional): Only relevant for 3D (multi-channel) curves of shape
            ``(batch, channels, n)``. If True, one background is sampled per curve and added to all of its channels;
            otherwise one background is sampled per channel. Defaults to False.
        add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary. Defaults to False.
    """
    def __init__(self,
                 relative_errors: Union[float, Tuple[float, float], List[Tuple[float, float]]] = (0.001, 0.15),
                 consistent_rel_err: bool = False,
                 apply_shift: bool = False,
                 shift_range: tuple = (-0.1, 0.2e-2),
                 apply_background: bool = False,
                 background_range: tuple = (1.0e-10, 1.0e-8),
                 same_background_across_channels: bool = False,
                 add_to_context: bool = False,
                 ):

        self.gaussian_noise = GaussianNoiseGenerator(
            relative_errors=relative_errors,
            consistent_rel_err=consistent_rel_err,
            add_to_context=add_to_context,
        )

        self.shift_noise = ShiftNoise(
            shift_range=shift_range, add_to_context=add_to_context
        ) if apply_shift else None

        self.background_noise = BackgroundNoise(
            background_range=background_range, add_to_context=add_to_context
        ) if apply_background else None

        # Previously accepted but never used; now honoured for 3D curves in _apply_background.
        self.same_background_across_channels = same_background_across_channels

    def apply(self, curves: Tensor, context: dict = None):
        """applies the specified types of noise to the input curves"""
        if self.shift_noise:
            curves = self.shift_noise(curves, context)

        if self.background_noise:
            curves = self._apply_background(curves, context)

        curves = self.gaussian_noise(curves, context)

        return curves

    def _apply_background(self, curves: Tensor, context: dict = None) -> Tensor:
        """Adds background noise; for 3D curves samples it per curve or per channel as configured."""
        if curves.dim() != 3:
            return self.background_noise.apply(curves, context)

        # BackgroundNoise samples one value per row of a 2D tensor, so a (batch, 1) background
        # cannot broadcast against (batch, channels, n) directly. Flatten first.
        batch_size, num_channels = curves.shape[:2]
        if self.same_background_across_channels:
            # Fold the channels into the point axis: one background per curve covers all channels.
            flat_curves = curves.reshape(batch_size, -1)
        else:
            # Every channel becomes its own row and gets an independent background.
            flat_curves = curves.reshape(batch_size * num_channels, -1)

        noisy_curves = self.background_noise.apply(flat_curves, context).reshape(curves.shape)

        if self.background_noise.add_to_context and context is not None:
            # Restore a (batch, channels or 1, 1) layout that broadcasts against the 3D curves.
            context['backgrounds'] = context['backgrounds'].reshape(batch_size, -1, 1)
        return noisy_curves
|
|
403
|
-
|
|
404
|
-
class BasicExpIntensityNoise(IntensityNoiseGenerator):
|
|
405
|
-
"""
|
|
406
|
-
A composite noise generator that applies Poisson, scaling, shift and background noise to reflectivity curves.
|
|
407
|
-
|
|
408
|
-
This class combines four types of noise:
|
|
409
|
-
|
|
410
|
-
1. **Poisson noise**: Simulates count-based noise common in photon counting experiments.
|
|
411
|
-
2. **Scaling noise**: Applies a scaling transformation to the curves, equivalent to a vertical stretch or compression in logarithmic space.
|
|
412
|
-
3. **Shift noise**: Applies a multiplicative shift to the curves, equivalent to a vertical shift in logarithmic space.
|
|
413
|
-
4. **Background noise**: Adds a constant background value to the curves.
|
|
414
|
-
|
|
415
|
-
Args:
|
|
416
|
-
relative_errors (Tuple[float, float], optional): The range of relative errors for Poisson noise. Defaults to (0.001, 0.15).
|
|
417
|
-
abs_errors (float, optional): A small constant added to prevent division by zero in Poisson noise. Defaults to 1e-8.
|
|
418
|
-
scale_range (tuple, optional): The range of scaling factors for scaling noise. Defaults to (-2e-2, 2e-2).
|
|
419
|
-
shift_range (tuple, optional): The range of shift factors for shift noise. Defaults to (-0.1, 0.2e-2).
|
|
420
|
-
background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
|
|
421
|
-
apply_shift (bool, optional): If True, applies shift noise to the curves. Defaults to False.
|
|
422
|
-
apply_scaling (bool, optional): If True, applies scaling noise to the curves. Defaults to False.
|
|
423
|
-
apply_background (bool, optional): If True, applies background noise to the curves. Defaults to False.
|
|
424
|
-
consistent_rel_err (bool, optional): If True, uses a consistent relative error for Poisson noise across all points in a curve. Defaults to False.
|
|
425
|
-
add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary. Defaults to False.
|
|
426
|
-
logdist (bool, optional): If True, samples relative errors for Poisson noise in logarithmic space. Defaults to False.
|
|
427
|
-
"""
|
|
428
|
-
def __init__(self,
|
|
429
|
-
relative_errors: Tuple[float, float] = (0.001, 0.15),
|
|
430
|
-
abs_errors: float = 1e-8,
|
|
431
|
-
scale_range: tuple = (-2e-2, 2e-2),
|
|
432
|
-
shift_range: tuple = (-0.1, 0.2e-2),
|
|
433
|
-
background_range: tuple = (1.0e-10, 1.0e-8),
|
|
434
|
-
apply_shift: bool = False,
|
|
435
|
-
apply_scaling: bool = False,
|
|
436
|
-
apply_background: bool = False,
|
|
437
|
-
consistent_rel_err: bool = False,
|
|
438
|
-
add_to_context: bool = False,
|
|
439
|
-
logdist: bool = False,
|
|
440
|
-
):
|
|
441
|
-
self.poisson_noise = PoissonNoiseGenerator(
|
|
442
|
-
relative_errors=relative_errors,
|
|
443
|
-
abs_errors=abs_errors,
|
|
444
|
-
consistent_rel_err=consistent_rel_err,
|
|
445
|
-
add_to_context=add_to_context,
|
|
446
|
-
logdist=logdist,
|
|
447
|
-
)
|
|
448
|
-
self.scaling_noise = ScalingNoise(
|
|
449
|
-
scale_range=scale_range, add_to_context=add_to_context
|
|
450
|
-
) if apply_scaling else None
|
|
451
|
-
|
|
452
|
-
self.shift_noise = ShiftNoise(
|
|
453
|
-
shift_range=shift_range, add_to_context=add_to_context
|
|
454
|
-
) if apply_shift else None
|
|
455
|
-
|
|
456
|
-
self.background_noise = BackgroundNoise(
|
|
457
|
-
background_range=background_range, add_to_context=add_to_context
|
|
458
|
-
) if apply_background else None
|
|
459
|
-
|
|
460
|
-
def apply(self, curves: Tensor, context: dict = None):
|
|
461
|
-
"""applies the specified types of noise to the input curves"""
|
|
462
|
-
if self.scaling_noise:
|
|
463
|
-
curves = self.scaling_noise(curves, context)
|
|
464
|
-
if self.shift_noise:
|
|
465
|
-
curves = self.shift_noise(curves, context)
|
|
466
|
-
curves = self.poisson_noise(curves, context)
|
|
467
|
-
|
|
468
|
-
if self.background_noise:
|
|
469
|
-
curves = self.background_noise.apply(curves, context)
|
|
470
|
-
|
|
1
|
+
from typing import List, Union, Tuple
|
|
2
|
+
from math import log10
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from reflectorch.data_generation.process_data import ProcessData
|
|
8
|
+
from reflectorch.data_generation.utils import logdist_sampler, uniform_sampler
|
|
9
|
+
|
|
10
|
+
# Names exported by ``from <module> import *``.
__all__ = [
    "QNoiseGenerator", "IntensityNoiseGenerator",
    "QNormalNoiseGenerator", "QSystematicShiftGenerator",
    "PoissonNoiseGenerator", "MultiplicativeLogNormalNoiseGenerator",
    "ScalingNoise", "ShiftNoise", "BackgroundNoise",
    "BasicExpIntensityNoise", "GaussianExpIntensityNoise",
    "BasicQNoiseGenerator",
]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class QNoiseGenerator(ProcessData):
    """Base class for q noise generators.

    The base implementation is a no-op: the q values are handed back untouched.
    Subclasses override :meth:`apply` to perturb them.
    """

    def apply(self, qs: Tensor, context: dict = None):
        """Returns ``qs`` unchanged (``context`` is ignored)."""
        unchanged_qs = qs
        return unchanged_qs
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class QNormalNoiseGenerator(QNoiseGenerator):
    """Q noise generator which adds to each q value of the reflectivity curve a noise sampled from a normal distribution.

    Args:
        std (Union[float, Tuple[float, float]], optional): the standard deviation of the normal distribution (the same for all curves in the batch if provided as a float,
            or uniformly sampled for each curve in the batch if provided as a tuple). Defaults to (0, 1e-3).
        add_to_context (bool, optional): If True, the standard deviations used are stored in the context under ``'q_stds'``. Defaults to False.
    """
    def __init__(self,
                 std: Union[float, Tuple[float, float]] = (0, 1e-3),
                 add_to_context: bool = False
                 ):
        self.std = std
        self.add_to_context = add_to_context

    def apply(self, qs: Tensor, context: dict = None):
        """applies noise to the q values

        Args:
            qs (Tensor): the q values. A 1D tensor is treated as a single curve; otherwise the first
                dimension indexes the curves of the batch.
            context (dict, optional): dictionary receiving ``'q_stds'`` when ``add_to_context`` is True.

        Returns:
            Tensor: the noisy q values, clamped to be non-negative.
        """
        std = self.std

        if isinstance(std, (list, tuple)):
            # One standard deviation per curve, broadcast over all q points of that curve.
            if qs.dim() > 1:
                std_shape = (qs.shape[0],) + (1,) * (qs.dim() - 1)
            else:
                std_shape = (1,)
            std = uniform_sampler(*std, *std_shape, device=qs.device, dtype=qs.dtype)
        else:
            std = torch.empty_like(qs).fill_(std)

        # Draw an independent sample for every q point. Sampling with the per-curve std tensor
        # directly (torch.normal(mean=0., std=std)) would draw a single value per curve, i.e. a
        # systematic shift instead of point-wise noise (and a (n, n) result for 1D input).
        noise = torch.randn_like(qs) * std

        if self.add_to_context and context is not None:
            context['q_stds'] = std

        qs = torch.clamp_min_(qs + noise, 0.)

        return qs
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class QSystematicShiftGenerator(QNoiseGenerator):
    """Q noise generator which samples one q shift per curve from a normal distribution and adds it to all q values of that curve.

    Args:
        std (float): the standard deviation of the normal distribution
        add_to_context (bool, optional): If True, the sampled shifts are stored in the context under ``'q_shifts'``. Defaults to True.
    """
    def __init__(self, std: float, add_to_context: bool = True):
        self.add_to_context = add_to_context
        self.std = std

    def apply(self, qs: Tensor, context: dict = None):
        """Adds a single normally distributed shift to every curve and clamps the result at zero.

        A 1D ``qs`` receives one shift; otherwise one shift is drawn per entry of the first dimension.
        """
        shift_shape = (1,) if qs.dim() == 1 else (qs.shape[0], 1)
        std_tensor = torch.ones(*shift_shape, device=qs.device, dtype=qs.dtype) * self.std

        shifts = torch.normal(mean=0., std=std_tensor)

        if context is not None and self.add_to_context:
            context['q_shifts'] = shifts

        return (qs + shifts).clamp_min_(0.)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class BasicQNoiseGenerator(QNoiseGenerator):
    """Q noise generator which applies both systematic shifts (same change for all q points in the curve) and random noise
    (different changes per q point in the curve).

    Args:
        apply_systematic_shifts (bool, optional): whether to apply systematic q shifts. Defaults to True.
        shift_std (float, optional): the standard deviation of the normal distribution for systematic q shifts
            (i.e. same change applied to all q points in the curve). Defaults to 1e-3.
        apply_gaussian_noise (bool, optional): whether to apply random (per q point) noise. Defaults to False.
        noise_std (Union[float, Tuple[float, float]], optional): the standard deviation of the normal distribution for random q noise
            (i.e. different changes applied to each q point in the curve). The standard deviation is the same
            for all curves in the batch if provided as a float, or uniformly sampled for each curve in the batch if provided as a tuple.
            Defaults to (0, 1e-3).
        add_to_context (bool, optional): forwarded to both underlying generators; if True, the sampled shifts are stored in
            ``context['q_shifts']`` and the noise standard deviations in ``context['q_stds']``. Defaults to False.
    """
    def __init__(self,
                 apply_systematic_shifts: bool = True,
                 shift_std: float = 1e-3,
                 apply_gaussian_noise: bool = False,
                 noise_std: Union[float, Tuple[float, float]] = (0, 1e-3),
                 add_to_context: bool = False,
                 ):
        self.q_shift = (
            QSystematicShiftGenerator(shift_std, add_to_context=add_to_context)
            if apply_systematic_shifts else None
        )
        self.q_noise = (
            QNormalNoiseGenerator(noise_std, add_to_context=add_to_context)
            if apply_gaussian_noise else None
        )

    def apply(self, qs: Tensor, context: dict = None):
        """applies noise to the q values

        A 1-D input is promoted to shape (1, N) first, so the output is always at least 2-D.
        The systematic shift (if enabled) is applied before the random noise (if enabled).
        """
        qs = torch.atleast_2d(qs)
        if self.q_shift is not None:
            qs = self.q_shift.apply(qs, context)
        if self.q_noise is not None:
            qs = self.q_noise.apply(qs, context)
        return qs
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class IntensityNoiseGenerator(ProcessData):
    """Base class for intensity noise generators.

    Subclasses must override :meth:`apply`.
    """

    def apply(self, curves: Tensor, context: dict = None):
        """Apply intensity noise to ``curves``; the base class does not implement this."""
        raise NotImplementedError()
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class MultiplicativeLogNormalNoiseGenerator(IntensityNoiseGenerator):
    r"""Noise generator which applies noise as :math:`R_n = R * b^{\epsilon}` , where :math:`b` is a base and :math:`\epsilon` is sampled from the normal distribution :math:`\epsilon \sim \mathcal{N}(0, std)` .
    In logarithmic space, this translates to :math:`\log_b(R_n) = \log_b(R) + \epsilon` .

    An independent :math:`\epsilon` is drawn for every point; the points of one curve are expected along the last axis.

    Args:
        std (Union[float, Tuple[float, float]]): the standard deviation of the normal distribution from which the noise is sampled. The standard deviation is the same
            for all curves in the batch if provided as a float, or uniformly sampled for each curve in the batch if provided as a tuple.
        base (float, optional): the base of the logarithm. Defaults to 10.
        add_to_context (bool, optional): if True, the standard deviations used are stored in ``context['std_lognormal']``. Defaults to False.
    """
    def __init__(self, std: Union[float, Tuple[float, float]], base: float = 10, add_to_context: bool = False):
        self.std = std
        self.base = base
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves"""
        std = self.std

        if isinstance(std, (list, tuple)):
            # One std per curve, shape curves.shape[:-1] + (1,), shared by all points of that curve.
            std = uniform_sampler(*std, *curves.shape[:-1], 1, device=curves.device, dtype=curves.dtype)
        else:
            std = torch.ones_like(curves) * std

        # Sample epsilon with the full curve shape: torch.normal(mean=0., std=std) would return a tensor
        # shaped like ``std``, i.e. a single epsilon per curve when ``std`` is sampled per curve.
        noise = self.base ** (torch.randn_like(curves) * std)

        if self.add_to_context and context is not None:
            context['std_lognormal'] = std

        return noise * curves
|
|
161
|
+
|
|
162
|
+
class GaussianNoiseGenerator(IntensityNoiseGenerator):
    """Noise generator which applies noise as R_n = R + eps, with eps~N(0, sigmas) and sigmas = relative_errors * R.
    The noisy curves are clamped to be non-negative.

    Args:
        relative_errors (Union[float, Tuple[float, float], List[Tuple[float, float]]]): a single number is used as the relative
            error of every point; a ``(min, max)`` pair is sampled uniformly; a list of ``(min, max)`` pairs gives one range per
            channel of a multi-channel (3-D, ``(batch, channels, points)``) batch of curves.
        consistent_rel_err (bool, optional): If True the relative error is the same for all points of a curve (sampled once per
            curve / channel), otherwise it is sampled independently for every point. Defaults to False.
        add_to_context (bool, optional): If True, ``relative_errors`` and ``sigmas`` are stored in the context. Defaults to False.
    """

    def __init__(self, relative_errors: Union[float, Tuple[float, float], List[float], List[Tuple[float, float]]],
                 consistent_rel_err: bool = False,
                 add_to_context: bool = False):
        self.relative_errors = relative_errors
        self.consistent_rel_err = consistent_rel_err
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """Applies Gaussian noise to the curves."""
        relative_errors = self.relative_errors
        num_channels = curves.shape[1] if curves.dim() == 3 else 1

        # Accept ints as well as floats so that e.g. ``0`` or ``(0, 1)`` are not mistaken for per-channel ranges.
        if isinstance(relative_errors, (int, float)):
            relative_errors = torch.ones_like(curves) * relative_errors

        elif isinstance(relative_errors, (list, tuple)) and isinstance(relative_errors[0], (int, float)):
            if self.consistent_rel_err:
                relative_errors = uniform_sampler(*relative_errors, curves.shape[0], num_channels, device=curves.device, dtype=curves.dtype)
                if num_channels > 1:
                    # (batch, channels) -> (batch, channels, 1) to broadcast over the points of each channel
                    relative_errors = relative_errors.unsqueeze(-1)
            else:
                relative_errors = uniform_sampler(*relative_errors, *curves.shape, device=curves.device, dtype=curves.dtype)

        else:
            # one (min, max) range per channel; stacking on dim=1 yields (batch, channels, 1 or points)
            if self.consistent_rel_err:
                relative_errors = torch.stack([uniform_sampler(*item, curves.shape[0], 1, device=curves.device, dtype=curves.dtype) for item in relative_errors], dim=1)
            else:
                relative_errors = torch.stack([uniform_sampler(*item, curves.shape[0], curves.shape[-1], device=curves.device, dtype=curves.dtype) for item in relative_errors], dim=1)

        sigmas = relative_errors * curves
        # Zero-mean noise as documented; clamping the noise itself would make it half-normal and bias the curves upward.
        noise = torch.normal(mean=0., std=sigmas)

        if self.add_to_context and context is not None:
            context['relative_errors'] = relative_errors
            context['sigmas'] = sigmas

        # Clamp the noisy curve (not the noise) to avoid unphysical negative reflectivities.
        return (curves + noise).clamp_min_(0.0)
|
|
207
|
+
|
|
208
|
+
class PoissonNoiseGenerator(IntensityNoiseGenerator):
    """Noise generator which applies Poisson noise to the reflectivity curves.

    For every point an absolute error ``sigma = rel_err * R + abs_errors`` is computed, and the noisy curve is
    ``Poisson(I * R) / I`` with ``I = R / sigma**2``, which has mean ``R`` and standard deviation ``sigma`` (for ``R > 0``).
    The points of one curve are expected along the last axis.

    Args:
        relative_errors (Tuple[float, float], optional): the range of relative errors to apply to the intensity curves. Defaults to (0.05, 0.35).
        abs_errors (float, optional): a small constant added to the sigmas to prevent division by zero. Defaults to 1e-8.
        add_to_context (bool, optional): If ``True``, the sigmas are stored in ``context['sigmas']``. Defaults to False.
        consistent_rel_err (bool, optional): If ``True``, the same relative error is used for all points in a curve. Defaults to True.
        logdist (bool, optional): If ``True``, the relative errors are sampled uniformly in logarithmic space
            (both in the consistent and in the point-wise mode). Defaults to False.
    """
    def __init__(self,
                 relative_errors: Tuple[float, float] = (0.05, 0.35),
                 abs_errors: float = 1e-8,
                 add_to_context: bool = False,
                 consistent_rel_err: bool = True,
                 logdist: bool = False,
                 ):
        self.relative_errors = relative_errors
        self.abs_errors = abs_errors
        self.add_to_context = add_to_context
        self.consistent_rel_err = consistent_rel_err
        self.logdist = logdist

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves"""
        if self.consistent_rel_err:
            sigmas = self._gen_consistent_sigmas(curves)
        else:
            sigmas = self._gen_sigmas(curves)

        # I = R / sigma^2 makes Poisson(I * R) / I have mean R and variance sigma^2
        intensities = curves / sigmas ** 2
        curves = torch.poisson(intensities * curves) / intensities

        if self.add_to_context and context is not None:
            context['sigmas'] = sigmas
        return curves

    def _sample_rel_err(self, shape, curves: Tensor) -> Tensor:
        """Samples relative errors of the given shape from ``relative_errors`` (log-uniformly if ``logdist``)."""
        low, high = self.relative_errors[0], self.relative_errors[1]
        unit = torch.rand(shape, device=curves.device, dtype=curves.dtype)
        if self.logdist:
            log_low, log_high = log10(low), log10(high)
            return 10 ** (unit * (log_high - log_low) + log_low)
        return unit * (high - low) + low

    def _gen_consistent_sigmas(self, curves):
        # One relative error per curve (shape curves.shape[:-1] + (1,)), broadcast over its points.
        rel_err = self._sample_rel_err(curves.shape[:-1] + (1,), curves)
        sigmas = curves * rel_err + self.abs_errors
        return sigmas

    def _gen_sigmas(self, curves):
        # An independent relative error for every point.
        rel_err = self._sample_rel_err(curves.shape, curves)
        sigmas = curves * rel_err + self.abs_errors
        return sigmas
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class ScalingNoise(IntensityNoiseGenerator):
    """Noise generator which applies scaling noise to reflectivity curves (equivalent to a vertical stretch or compression of the curve in the logarithmic domain).
    The output is R^(1 + scale_factor), which corresponds in logarithmic domain to (1 + scale_factor) * log(R).

    One scale factor is sampled per curve, where the points of a curve lie along the last axis
    (a 1-D input is one curve, 2-D ``(batch, points)`` gets one factor per row, 3-D ``(batch, channels, points)`` one per channel).

    Args:
        scale_range (tuple, optional): the range of scaling factors (one factor sampled per curve in the batch). Defaults to (-0.2e-2, 0.2e-2).
        add_to_context (bool, optional): If True, the sampled factors are stored in ``context['intensity_scales']``. Defaults to False.
    """
    def __init__(self,
                 scale_range: tuple = (-0.2e-2, 0.2e-2),
                 add_to_context: bool = False,
                 ):
        self.scale_range = scale_range
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves"""
        # shape curves.shape[:-1] + (1,): (batch, 1) for 2-D input, as before; also valid for 1-D and 3-D input
        scales = uniform_sampler(
            *self.scale_range, *curves.shape[:-1], 1,
            device=curves.device, dtype=curves.dtype
        )
        if self.add_to_context and context is not None:
            context['intensity_scales'] = scales

        curves = curves ** (1 + scales)

        return curves
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class ShiftNoise(IntensityNoiseGenerator):
    """Noise generator which applies shifting noise to reflectivity curves (equivalent to a vertical shift of the entire curve in the logarithmic domain).
    The output is R * (1 + shift_factor), which corresponds in logarithmic domain to log(R) + log(1 + shift_factor).

    One shift factor is sampled per curve, where the points of a curve lie along the last axis
    (a 1-D input is one curve, 2-D ``(batch, points)`` gets one factor per row, 3-D ``(batch, channels, points)`` one per channel).

    Args:
        shift_range (tuple, optional): the range of shift factors (one factor sampled per curve in the batch). Defaults to (-0.1, 0.2e-2).
        add_to_context (bool, optional): If True, the sampled factors are stored in ``context['intensity_shifts']``. Defaults to False.
    """
    def __init__(self,
                 shift_range: tuple = (-0.1, 0.2e-2),
                 add_to_context: bool = False,
                 ):
        self.shift_range = shift_range
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves"""
        # shape curves.shape[:-1] + (1,): (batch, 1) for 2-D input, as before; also valid for 1-D and 3-D input
        intensity_shifts = uniform_sampler(
            *self.shift_range, *curves.shape[:-1], 1,
            device=curves.device, dtype=curves.dtype
        )
        if self.add_to_context and context is not None:
            context['intensity_shifts'] = intensity_shifts

        curves = curves * (1 + intensity_shifts)

        return curves
|
|
319
|
+
|
|
320
|
+
class BackgroundNoise(IntensityNoiseGenerator):
    """
    Noise generator which adds a constant background to reflectivity curves.

    One background value is sampled with ``logdist_sampler`` per curve, where the points of a curve lie along the
    last axis (a 1-D input is one curve, 2-D ``(batch, points)`` gets one value per row,
    3-D ``(batch, channels, points)`` one value per channel).

    Args:
        background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
        add_to_context (bool, optional): If True, the sampled backgrounds are stored in ``context['backgrounds']``. Defaults to False.
    """
    def __init__(self,
                 background_range: tuple = (1.0e-10, 1.0e-8),
                 add_to_context: bool = False,
                 ):
        self.background_range = background_range
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None) -> Tensor:
        """applies background noise to the curves"""
        # shape curves.shape[:-1] + (1,): (batch, 1) for 2-D input, as before; also valid for 1-D and 3-D input
        backgrounds = logdist_sampler(
            *self.background_range, *curves.shape[:-1], 1,
            device=curves.device, dtype=curves.dtype
        )
        if self.add_to_context and context is not None:
            context['backgrounds'] = backgrounds

        curves = curves + backgrounds

        return curves
|
|
347
|
+
|
|
348
|
+
class GaussianExpIntensityNoise(IntensityNoiseGenerator):
    """
    A composite noise generator that applies Gaussian, shift and background noise to reflectivity curves.

    This class combines three types of noise:
        1. Gaussian noise: Applies Gaussian noise (to account for count-based Poisson noise as well as other sources of error)
        2. Shift noise: Applies a multiplicative shift to the curves, equivalent to a vertical shift in logarithmic space.
        3. Background noise: Adds a constant background value to the curves.

    The shift (if enabled) is applied first, then the background (if enabled), and finally the Gaussian noise.

    Args:
        relative_errors (Union[float, Tuple[float, float], List[Tuple[float, float]]]): The range of relative errors for Gaussian noise. Defaults to (0.001, 0.15).
        consistent_rel_err (bool, optional): If True, uses a consistent relative error for Gaussian noise across all points in a curve. Defaults to False.
        shift_range (tuple, optional): The range of shift factors for shift noise. Defaults to (-0.1, 0.2e-2).
        background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
        apply_shift (bool, optional): If True, applies shift noise to the curves. Defaults to False.
        apply_background (bool, optional): If True, applies background noise to the curves. Defaults to False.
        same_background_across_channels(bool, optional): If True, the same background is applied to all channels of a multi-channel
            (3-D, ``(batch, channels, points)``) curve; otherwise each channel gets its own background. Defaults to False.
        add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary. Defaults to False.
    """
    def __init__(self,
                 relative_errors: Union[float, Tuple[float, float], List[Tuple[float, float]]] = (0.001, 0.15),
                 consistent_rel_err: bool = False,
                 apply_shift: bool = False,
                 shift_range: tuple = (-0.1, 0.2e-2),
                 apply_background: bool = False,
                 background_range: tuple = (1.0e-10, 1.0e-8),
                 same_background_across_channels: bool = False,
                 add_to_context: bool = False,
                 ):

        self.gaussian_noise = GaussianNoiseGenerator(
            relative_errors=relative_errors,
            consistent_rel_err=consistent_rel_err,
            add_to_context=add_to_context,
        )

        self.shift_noise = ShiftNoise(
            shift_range=shift_range, add_to_context=add_to_context
        ) if apply_shift else None

        self.background_noise = BackgroundNoise(
            background_range=background_range, add_to_context=add_to_context
        ) if apply_background else None

        # previously accepted but ignored; now honoured for 3-D (multi-channel) curves
        self.same_background_across_channels = same_background_across_channels

    def apply(self, curves: Tensor, context: dict = None):
        """applies the specified types of noise to the input curves"""
        if self.shift_noise is not None:
            curves = self.shift_noise(curves, context)

        if self.background_noise is not None:
            curves = self._apply_background(curves, context)

        curves = self.gaussian_noise(curves, context)

        return curves

    def _apply_background(self, curves: Tensor, context: dict = None) -> Tensor:
        """Adds background noise: one value per channel, or one per sample if ``same_background_across_channels``."""
        if curves.dim() != 3:
            return self.background_noise.apply(curves, context)

        batch_size, num_channels, num_points = curves.shape
        if self.same_background_across_channels:
            # treat all channels of a sample as one long row -> a single background per sample
            rows = curves.reshape(batch_size, num_channels * num_points)
            backgrounds_shape = (batch_size, 1, 1)
        else:
            # treat every channel as its own row -> an independent background per channel
            rows = curves.reshape(batch_size * num_channels, num_points)
            backgrounds_shape = (batch_size, num_channels, 1)

        curves = self.background_noise.apply(rows, context).reshape(batch_size, num_channels, num_points)

        # the background generator stored one value per row; restore the (batch, channels, 1) layout
        if self.background_noise.add_to_context and context is not None and 'backgrounds' in context:
            context['backgrounds'] = context['backgrounds'].reshape(backgrounds_shape)

        return curves
|
|
403
|
+
|
|
404
|
+
class BasicExpIntensityNoise(IntensityNoiseGenerator):
    """
    A composite noise generator that applies Poisson, scaling, shift and background noise to reflectivity curves.

    Processing order in :meth:`apply`: scaling (optional), shift (optional), Poisson (always), background (optional).

    The four noise types are:

    1. **Poisson noise**: Simulates count-based noise common in photon counting experiments.
    2. **Scaling noise**: Applies a scaling transformation to the curves, equivalent to a vertical stretch or compression in logarithmic space.
    3. **Shift noise**: Applies a multiplicative shift to the curves, equivalent to a vertical shift in logarithmic space.
    4. **Background noise**: Adds a constant background value to the curves.

    Args:
        relative_errors (Tuple[float, float], optional): The range of relative errors for Poisson noise. Defaults to (0.001, 0.15).
        abs_errors (float, optional): A small constant added to prevent division by zero in Poisson noise. Defaults to 1e-8.
        scale_range (tuple, optional): The range of scaling factors for scaling noise. Defaults to (-2e-2, 2e-2).
        shift_range (tuple, optional): The range of shift factors for shift noise. Defaults to (-0.1, 0.2e-2).
        background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
        apply_shift (bool, optional): If True, applies shift noise to the curves. Defaults to False.
        apply_scaling (bool, optional): If True, applies scaling noise to the curves. Defaults to False.
        apply_background (bool, optional): If True, applies background noise to the curves. Defaults to False.
        consistent_rel_err (bool, optional): If True, uses a consistent relative error for Poisson noise across all points in a curve. Defaults to False.
        add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary. Defaults to False.
        logdist (bool, optional): If True, samples relative errors for Poisson noise in logarithmic space. Defaults to False.
    """
    def __init__(self,
                 relative_errors: Tuple[float, float] = (0.001, 0.15),
                 abs_errors: float = 1e-8,
                 scale_range: tuple = (-2e-2, 2e-2),
                 shift_range: tuple = (-0.1, 0.2e-2),
                 background_range: tuple = (1.0e-10, 1.0e-8),
                 apply_shift: bool = False,
                 apply_scaling: bool = False,
                 apply_background: bool = False,
                 consistent_rel_err: bool = False,
                 add_to_context: bool = False,
                 logdist: bool = False,
                 ):
        self.poisson_noise = PoissonNoiseGenerator(
            relative_errors=relative_errors,
            abs_errors=abs_errors,
            consistent_rel_err=consistent_rel_err,
            add_to_context=add_to_context,
            logdist=logdist,
        )

        self.scaling_noise = None
        if apply_scaling:
            self.scaling_noise = ScalingNoise(scale_range=scale_range, add_to_context=add_to_context)

        self.shift_noise = None
        if apply_shift:
            self.shift_noise = ShiftNoise(shift_range=shift_range, add_to_context=add_to_context)

        self.background_noise = None
        if apply_background:
            self.background_noise = BackgroundNoise(background_range=background_range, add_to_context=add_to_context)

    def apply(self, curves: Tensor, context: dict = None):
        """applies the specified types of noise to the input curves"""
        # optional pre-Poisson transformations, in this fixed order
        for pre_step in (self.scaling_noise, self.shift_noise):
            if pre_step:
                curves = pre_step(curves, context)

        curves = self.poisson_noise(curves, context)

        background_step = self.background_noise
        if background_step:
            curves = background_step.apply(curves, context)

        return curves
|