reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of reflectorch might be problematic.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -128
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -280
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -223
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -1374
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +36 -36
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +523 -516
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -19
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -262
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -200
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -15
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -19
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -434
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -404
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +97 -97
  91. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  94. reflectorch-1.4.0.dist-info/RECORD +0 -88
  95. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
  96. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,471 +1,471 @@
from typing import List, Union, Tuple
from math import log10

import torch
from torch import Tensor

from reflectorch.data_generation.process_data import ProcessData
from reflectorch.data_generation.utils import logdist_sampler, uniform_sampler

__all__ = [
    "QNoiseGenerator",
    "IntensityNoiseGenerator",
    "QNormalNoiseGenerator",
    "QSystematicShiftGenerator",
    "PoissonNoiseGenerator",
    "MultiplicativeLogNormalNoiseGenerator",
    "ScalingNoise",
    "ShiftNoise",
    "BackgroundNoise",
    "BasicExpIntensityNoise",
    "GaussianExpIntensityNoise",
    "BasicQNoiseGenerator",
]


class QNoiseGenerator(ProcessData):
    """Base class for q noise generators"""
    def apply(self, qs: Tensor, context: dict = None):
        return qs


class QNormalNoiseGenerator(QNoiseGenerator):
    """Q noise generator which adds to each q value of the reflectivity curve a noise sampled from a normal distribution.

    Args:
        std (Union[float, Tuple[float, float]], optional): the standard deviation of the normal distribution (the same for all curves in the batch if provided as a float,
            or uniformly sampled for each curve in the batch if provided as a tuple)
    """
    def __init__(self,
                 std: Union[float, Tuple[float, float]] = (0, 1e-3),
                 add_to_context: bool = False
                 ):
        self.std = std
        self.add_to_context = add_to_context

    def apply(self, qs: Tensor, context: dict = None):
        """applies noise to the q values"""
        std = self.std

        if isinstance(std, (list, tuple)):
            std = uniform_sampler(*std, qs.shape[0], 1, device=qs.device, dtype=qs.dtype)
        else:
            std = torch.empty_like(qs).fill_(std)

        noise = torch.normal(mean=0., std=std)

        if self.add_to_context and context is not None:
            context['q_stds'] = std

        qs = torch.clamp_min_(qs + noise, 0.)

        return qs


class QSystematicShiftGenerator(QNoiseGenerator):
    """Q noise generator which samples a q shift (for each curve in the batch) from a normal distribution adds it to all q values of the curve

    Args:
        std (float): the standard deviation of the normal distribution
    """
    def __init__(self, std: float, add_to_context: bool = True):
        self.std = std
        self.add_to_context = add_to_context

    def apply(self, qs: Tensor, context: dict = None):
        """applies systematic shifts to the q values"""
        if len(qs.shape) == 1:
            shape = (1,)
        else:
            shape = (qs.shape[0], 1)

        shifts = torch.normal(
            mean=0., std=self.std * torch.ones(*shape, device=qs.device, dtype=qs.dtype)
        )

        if self.add_to_context and context is not None:
            context['q_shifts'] = shifts

        qs = torch.clamp_min_(qs + shifts, 0.)

        return qs


class BasicQNoiseGenerator(QNoiseGenerator):
    """Q noise generator which applies both systematic shifts (same change for all q points in the curve) and random noise (different changes per q point in the curve)

    Args:
        shift_std (float, optional): the standard deviation of the normal distribution for systematic q shifts
            (i.e. same change applied to all q points in the curve). Defaults to 1e-3.
        noise_std (Union[float, Tuple[float, float]], optional): the standard deviation of the normal distribution for random q noise
            (i.e. different changes applied to each q point in the curve). The standard deviation is the same
            for all curves in the batch if provided as a float, or uniformly sampled for each curve in the batch if provided as a tuple.
            Defaults to (0, 1e-3).
    """
    def __init__(self,
                 apply_systematic_shifts: bool = True,
                 shift_std: float = 1e-3,
                 apply_gaussian_noise: bool = False,
                 noise_std: Union[float, Tuple[float, float]] = (0, 1e-3),
                 add_to_context: bool = False,
                 ):
        self.q_shift = QSystematicShiftGenerator(shift_std, add_to_context=add_to_context) if apply_systematic_shifts else None
        self.q_noise = QNormalNoiseGenerator(noise_std, add_to_context=add_to_context) if apply_gaussian_noise else None

    def apply(self, qs: Tensor, context: dict = None):
        """applies noise to the q values"""
        qs = torch.atleast_2d(qs)
        if self.q_shift:
            qs = self.q_shift.apply(qs, context)
        if self.q_noise:
            qs = self.q_noise.apply(qs, context)
        return qs


class IntensityNoiseGenerator(ProcessData):
    """Base class for intensity noise generators"""
    def apply(self, curves: Tensor, context: dict = None):
        raise NotImplementedError


class MultiplicativeLogNormalNoiseGenerator(IntensityNoiseGenerator):
    """Noise generator which applies noise as :math:`R_n = R * b^{\epsilon}` , where :math:`b` is a base and :math:`\epsilon` is sampled from the normal distribution :math:`\epsilon \sim \mathcal{N}(0, std)` .
    In logarithmic space, this translates to :math:`\log_b(R_n) = \log_b(R) + \epsilon` .


    Args:
        std (Union[float, Tuple[float, float]]): the standard deviation of the normal distribution from which the noise is sampled. The standard deviation is the same
            for all curves in the batch if provided as a float, or uniformly sampled for each curve in the batch if provided as a tuple.
        base (float, optional): the base of the logarithm. Defaults to 10.
    """
    def __init__(self, std: Union[float, Tuple[float, float]], base: float = 10, add_to_context: bool = False):
        self.std = std
        self.base = base
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves"""
        std = self.std

        if isinstance(std, (list, tuple)):
            std = uniform_sampler(*std, curves.shape[0], 1, device=curves.device, dtype=curves.dtype)
        else:
            std = torch.ones_like(curves) * std

        noise = self.base ** torch.normal(mean=0., std=std)

        if self.add_to_context and context is not None:
            context['std_lognormal'] = std

        return noise * curves

class GaussianNoiseGenerator(IntensityNoiseGenerator):
    """Noise generator which applies noise as R_n = R + eps, with eps~N(0, sigmas) and sigmas = relative_errors * R

    Args:
        relative_errors (Union[float, Tuple[float, float], List[Tuple[float, float]]])
        consistent_relative_errors (bool): If True the relative_error is the same for all point of a curve, otherwise it is sampled uniformly.
    """

    def __init__(self, relative_errors: Union[float, Tuple[float, float], List[float], List[Tuple[float, float]]],
                 consistent_rel_err: bool = False,
                 add_to_context: bool = False):
        self.relative_errors = relative_errors
        self.consistent_rel_err = consistent_rel_err
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """Applies Gaussian noise to the curves."""
        relative_errors = self.relative_errors
        num_channels = curves.shape[1] if curves.dim() == 3 else 1

        if isinstance(relative_errors, float):
            relative_errors = torch.ones_like(curves) * relative_errors

        elif isinstance(relative_errors, (list, tuple)) and isinstance(relative_errors[0], float):
            if self.consistent_rel_err:
                relative_errors = uniform_sampler(*relative_errors, curves.shape[0], num_channels, device=curves.device, dtype=curves.dtype)
                if num_channels > 1:
                    relative_errors = relative_errors.unsqueeze(-1)
            else:
                relative_errors = uniform_sampler(*relative_errors, *curves.shape, device=curves.device, dtype=curves.dtype)

        else:
            if self.consistent_rel_err:
                relative_errors = torch.stack([uniform_sampler(*item, curves.shape[0], 1, device=curves.device, dtype=curves.dtype) for item in relative_errors], dim=1)
            else:
                relative_errors = torch.stack([uniform_sampler(*item, curves.shape[0], curves.shape[-1], device=curves.device, dtype=curves.dtype) for item in relative_errors], dim=1)

        sigmas = relative_errors * curves
        noise = torch.normal(mean=0., std=sigmas).clamp_min_(0.0)

        if self.add_to_context and context is not None:
            context['relative_errors'] = relative_errors
            context['sigmas'] = sigmas

        return curves + noise

class PoissonNoiseGenerator(IntensityNoiseGenerator):
    """Noise generator which applies Poisson noise to the reflectivity curves

    Args:
        relative_errors (Tuple[float, float], optional): the range of relative errors to apply to the intensity curves. Defaults to (0.05, 0.35).
        abs_errors (float, optional): a small constant added to prevent division by zero. Defaults to 1e-8.
        consistent_rel_err (bool, optional): If ``True``, the same relative error is used for all points in a curve.
        logdist (bool, optional): If ``True``, the relative errors in are sampled in logarithmic space. Defaults to False.
    """
    def __init__(self,
                 relative_errors: Tuple[float, float] = (0.05, 0.35),
                 abs_errors: float = 1e-8,
                 add_to_context: bool = False,
                 consistent_rel_err: bool = True,
                 logdist: bool = False,
                 ):
        self.relative_errors = relative_errors
        self.abs_errors = abs_errors
        self.add_to_context = add_to_context
        self.consistent_rel_err = consistent_rel_err
        self.logdist = logdist

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves"""
        if self.consistent_rel_err:
            sigmas = self._gen_consistent_sigmas(curves)
        else:
            sigmas = self._gen_sigmas(curves)

        intensities = curves / sigmas ** 2
        curves = torch.poisson(intensities * curves) / intensities

        if self.add_to_context and context is not None:
            context['sigmas'] = sigmas
        return curves

    def _gen_consistent_sigmas(self, curves):
        rel_err = torch.rand(curves.shape[0], device=curves.device, dtype=curves.dtype) * (
            self.relative_errors[1] - self.relative_errors[0]
        ) + self.relative_errors[0]
        sigmas = curves * rel_err[:, None] + self.abs_errors
        return sigmas

    def _gen_sigmas(self, curves):
        if not self.logdist:
            rel_err = torch.rand_like(curves) * (
                self.relative_errors[1] - self.relative_errors[0]
            ) + self.relative_errors[0]
        else:
            rel_err = torch.rand_like(curves) * (
                log10(self.relative_errors[1]) - log10(self.relative_errors[0])
            ) + log10(self.relative_errors[0])
            rel_err = 10 ** rel_err

        sigmas = curves * rel_err + self.abs_errors
        return sigmas


class ScalingNoise(IntensityNoiseGenerator):
    """Noise generator which applies scaling noise to reflectivity curves (equivalent to a vertical stretch or compression of the curve in the logarithmic domain).
    The output is R^(1 + scale_factor), which corresponds in logarithmic domain to (1 + scale_factor) * log(R).

    Args:
        scale_range (tuple, optional): the range of scaling factors (one factor sampled per curve in the batch). Defaults to (-0.2e-2, 0.2e-2).
    """
    def __init__(self,
                 scale_range: tuple = (-0.2e-2, 0.2e-2),
                 add_to_context: bool = False,
                 ):
        self.scale_range = scale_range
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves"""
        scales = uniform_sampler(
            *self.scale_range, curves.shape[0], 1,
            device=curves.device, dtype=curves.dtype
        )
        if self.add_to_context and context is not None:
            context['intensity_scales'] = scales

        curves = curves ** (1 + scales)

        return curves


class ShiftNoise(IntensityNoiseGenerator):
    def __init__(self,
                 shift_range: tuple = (-0.1, 0.2e-2),
                 add_to_context: bool = False,
                 ):
        """Noise generator which applies shifting noise to reflectivity curves (equivalent to a vertical shift of the entire curve in the logarithmic domain).
        The output is R * (1 + shift_factor), which corresponds in logarithmic domain to log(R) + log(1 + shift_factor).
        Args:
            shift_range (tuple, optional): the range of shift factors (one factor sampled per curve in the batch). Defaults to (-0.1, 0.2e-2).
        """
        self.shift_range = shift_range
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves"""
        intensity_shifts = uniform_sampler(
            *self.shift_range, curves.shape[0], 1,
            device=curves.device, dtype=curves.dtype
        )
        if self.add_to_context and context is not None:
            context['intensity_shifts'] = intensity_shifts

        curves = curves * (1 + intensity_shifts)

        return curves

class BackgroundNoise(IntensityNoiseGenerator):
    """
    Noise generator which adds a constant background to reflectivity curves.

    Args:
        background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
        add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary. Defaults to False.
    """
    def __init__(self,
                 background_range: tuple = (1.0e-10, 1.0e-8),
                 add_to_context: bool = False,
                 ):
        self.background_range = background_range
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None) -> Tensor:
        """applies background noise to the curves"""
        backgrounds = logdist_sampler(
            *self.background_range, curves.shape[0], 1,
            device=curves.device, dtype=curves.dtype
        )
        if self.add_to_context and context is not None:
            context['backgrounds'] = backgrounds

        curves = curves + backgrounds

        return curves

class GaussianExpIntensityNoise(IntensityNoiseGenerator):
    """
    A composite noise generator that applies Gaussian, shift and background noise to reflectivity curves.

    This class combines three types of noise:
    1. Gaussian noise: Applies Gaussian noise (to account for count-based Poisson noise as well as other sources of error)
    2. Shift noise: Applies a multiplicative shift to the curves, equivalent to a vertical shift in logarithmic space.
    3. Background noise: Adds a constant background value to the curves.

    Args:
        relative_errors (Union[float, Tuple[float, float], List[Tuple[float, float]]]): The range of relative errors for Gaussian noise. Defaults to (0.001, 0.15).
        consistent_rel_err (bool, optional): If True, uses a consistent relative error for Gaussian noise across all points in a curve. Defaults to False.
        shift_range (tuple, optional): The range of shift factors for shift noise. Defaults to (-0.1, 0.2e-2).
        background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
        apply_shift (bool, optional): If True, applies shift noise to the curves. Defaults to False.
        apply_background (bool, optional): If True, applies background noise to the curves. Defaults to False.
        same_background_across_channels(bool, optional): If True, the same background is applied to all channels of a multi-channel curve. Defaults to False.
        add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary. Defaults to False.
    """
    def __init__(self,
                 relative_errors: Tuple[float, float] = (0.001, 0.15),
                 consistent_rel_err: bool = False,
                 apply_shift: bool = False,
                 shift_range: tuple = (-0.1, 0.2e-2),
                 apply_background: bool = False,
                 background_range: tuple = (1.0e-10, 1.0e-8),
                 same_background_across_channels: bool = False,
                 add_to_context: bool = False,
                 ):

        self.gaussian_noise = GaussianNoiseGenerator(
            relative_errors=relative_errors,
            consistent_rel_err=consistent_rel_err,
            add_to_context=add_to_context,
        )

        self.shift_noise = ShiftNoise(
            shift_range=shift_range, add_to_context=add_to_context
        ) if apply_shift else None

        self.background_noise = BackgroundNoise(
            background_range=background_range, add_to_context=add_to_context
        ) if apply_background else None

    def apply(self, curves: Tensor, context: dict = None):
        """applies the specified types of noise to the input curves"""
        if self.shift_noise:
            curves = self.shift_noise(curves, context)

        if self.background_noise:
            curves = self.background_noise.apply(curves, context)

        curves = self.gaussian_noise(curves, context)

        return curves

class BasicExpIntensityNoise(IntensityNoiseGenerator):
    """
    A composite noise generator that applies Poisson, scaling, shift and background noise to reflectivity curves.

    This class combines four types of noise:

    1. **Poisson noise**: Simulates count-based noise common in photon counting experiments.
    2. **Scaling noise**: Applies a scaling transformation to the curves, equivalent to a vertical stretch or compression in logarithmic space.
    3. **Shift noise**: Applies a multiplicative shift to the curves, equivalent to a vertical shift in logarithmic space.
    4. **Background noise**: Adds a constant background value to the curves.

    Args:
        relative_errors (Tuple[float, float], optional): The range of relative errors for Poisson noise. Defaults to (0.001, 0.15).
        abs_errors (float, optional): A small constant added to prevent division by zero in Poisson noise. Defaults to 1e-8.
        scale_range (tuple, optional): The range of scaling factors for scaling noise. Defaults to (-2e-2, 2e-2).
        shift_range (tuple, optional): The range of shift factors for shift noise. Defaults to (-0.1, 0.2e-2).
        background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
        apply_shift (bool, optional): If True, applies shift noise to the curves. Defaults to False.
        apply_scaling (bool, optional): If True, applies scaling noise to the curves. Defaults to False.
        apply_background (bool, optional): If True, applies background noise to the curves. Defaults to False.
        consistent_rel_err (bool, optional): If True, uses a consistent relative error for Poisson noise across all points in a curve. Defaults to False.
        add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary. Defaults to False.
        logdist (bool, optional): If True, samples relative errors for Poisson noise in logarithmic space. Defaults to False.
    """
    def __init__(self,
                 relative_errors: Tuple[float, float] = (0.001, 0.15),
                 abs_errors: float = 1e-8,
                 scale_range: tuple = (-2e-2, 2e-2),
                 shift_range: tuple = (-0.1, 0.2e-2),
                 background_range: tuple = (1.0e-10, 1.0e-8),
                 apply_shift: bool = False,
                 apply_scaling: bool = False,
                 apply_background: bool = False,
                 consistent_rel_err: bool = False,
                 add_to_context: bool = False,
                 logdist: bool = False,
                 ):
        self.poisson_noise = PoissonNoiseGenerator(
            relative_errors=relative_errors,
            abs_errors=abs_errors,
            consistent_rel_err=consistent_rel_err,
            add_to_context=add_to_context,
            logdist=logdist,
        )
        self.scaling_noise = ScalingNoise(
            scale_range=scale_range, add_to_context=add_to_context
        ) if apply_scaling else None

        self.shift_noise = ShiftNoise(
            shift_range=shift_range, add_to_context=add_to_context
        ) if apply_shift else None

        self.background_noise = BackgroundNoise(
            background_range=background_range, add_to_context=add_to_context
        ) if apply_background else None

    def apply(self, curves: Tensor, context: dict = None):
        """applies the specified types of noise to the input curves"""
        if self.scaling_noise:
            curves = self.scaling_noise(curves, context)
        if self.shift_noise:
            curves = self.shift_noise(curves, context)
        curves = self.poisson_noise(curves, context)

        if self.background_noise:
            curves = self.background_noise.apply(curves, context)

        return curves
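
The listing above shows the constructor signatures and apply() methods of the noise generators in reflectorch/data_generation/noise.py. For orientation only, here is a minimal usage sketch (not part of the diff): the import path is assumed from the file location and the module's __all__ list, and the q grid, reflectivity curves, and parameter values are synthetic stand-ins chosen purely for illustration.

import torch

from reflectorch.data_generation.noise import BasicQNoiseGenerator, BasicExpIntensityNoise

# synthetic batch: 4 reflectivity curves sampled on 128 q points (illustrative values only)
q = torch.linspace(0.02, 0.15, 128).repeat(4, 1)
curves = torch.exp(-50.0 * q)

context = {}

# q-axis noise: one systematic shift per curve plus per-point Gaussian jitter
q_noise = BasicQNoiseGenerator(
    apply_systematic_shifts=True,
    shift_std=1e-3,
    apply_gaussian_noise=True,
    noise_std=(0.0, 1e-3),
    add_to_context=True,
)
noisy_q = q_noise.apply(q, context)

# intensity noise: Poisson counting noise plus a vertical shift and a constant background
intensity_noise = BasicExpIntensityNoise(
    relative_errors=(0.001, 0.15),
    apply_shift=True,
    apply_background=True,
    add_to_context=True,
)
noisy_curves = intensity_noise.apply(curves, context)

# with add_to_context=True, the generators record the noise parameters they sampled
print(sorted(context.keys()))
# ['backgrounds', 'intensity_shifts', 'q_shifts', 'q_stds', 'sigmas']

The context dictionary is optional; passing one simply collects the sampled noise parameters (shifts, backgrounds, sigmas, and so on) for later inspection.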