reflectorch 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +23 -0
- reflectorch/data_generation/__init__.py +130 -0
- reflectorch/data_generation/dataset.py +196 -0
- reflectorch/data_generation/likelihoods.py +86 -0
- reflectorch/data_generation/noise.py +371 -0
- reflectorch/data_generation/priors/__init__.py +66 -0
- reflectorch/data_generation/priors/base.py +61 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
- reflectorch/data_generation/priors/independent_priors.py +201 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +110 -0
- reflectorch/data_generation/priors/no_constraints.py +212 -0
- reflectorch/data_generation/priors/parametric_models.py +767 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
- reflectorch/data_generation/priors/params.py +258 -0
- reflectorch/data_generation/priors/sampler_strategies.py +306 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +377 -0
- reflectorch/data_generation/priors/utils.py +124 -0
- reflectorch/data_generation/process_data.py +47 -0
- reflectorch/data_generation/q_generator.py +232 -0
- reflectorch/data_generation/reflectivity/__init__.py +56 -0
- reflectorch/data_generation/reflectivity/abeles.py +81 -0
- reflectorch/data_generation/reflectivity/kinematical.py +58 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +123 -0
- reflectorch/data_generation/scale_curves.py +118 -0
- reflectorch/data_generation/smearing.py +67 -0
- reflectorch/data_generation/utils.py +154 -0
- reflectorch/extensions/__init__.py +6 -0
- reflectorch/extensions/jupyter/__init__.py +12 -0
- reflectorch/extensions/jupyter/callbacks.py +40 -0
- reflectorch/extensions/matplotlib/__init__.py +11 -0
- reflectorch/extensions/matplotlib/losses.py +38 -0
- reflectorch/inference/__init__.py +22 -0
- reflectorch/inference/inference_model.py +734 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +16 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +171 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +37 -0
- reflectorch/ml/basic_trainer.py +286 -0
- reflectorch/ml/callbacks.py +86 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +38 -0
- reflectorch/ml/schedulers.py +246 -0
- reflectorch/ml/trainers.py +126 -0
- reflectorch/ml/utils.py +9 -0
- reflectorch/models/__init__.py +22 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +27 -0
- reflectorch/models/encoders/conv_encoder.py +211 -0
- reflectorch/models/encoders/conv_res_net.py +119 -0
- reflectorch/models/encoders/fno.py +127 -0
- reflectorch/models/encoders/transformers.py +56 -0
- reflectorch/models/networks/__init__.py +18 -0
- reflectorch/models/networks/mlp_networks.py +256 -0
- reflectorch/models/networks/residual_net.py +131 -0
- reflectorch/paths.py +33 -0
- reflectorch/runs/__init__.py +35 -0
- reflectorch/runs/config.py +31 -0
- reflectorch/runs/slurm_utils.py +99 -0
- reflectorch/runs/train.py +85 -0
- reflectorch/runs/utils.py +300 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +74 -0
- reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
- reflectorch-1.0.0.dist-info/METADATA +115 -0
- reflectorch-1.0.0.dist-info/RECORD +83 -0
- reflectorch-1.0.0.dist-info/WHEEL +5 -0
- reflectorch-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,371 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
#
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the GPL license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from typing import Union, Tuple
|
|
8
|
+
from math import log10
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from torch import Tensor
|
|
12
|
+
|
|
13
|
+
from reflectorch.data_generation.process_data import ProcessData
|
|
14
|
+
from reflectorch.data_generation.utils import uniform_sampler
|
|
15
|
+
|
|
16
|
+
# Public names of this module (what ``from ... import *`` exposes): the q-noise
# generators and the intensity-noise generators defined below.
__all__ = [
    "QNoiseGenerator",
    "IntensityNoiseGenerator",
    "QNormalNoiseGenerator",
    "QSystematicShiftGenerator",
    "PoissonNoiseGenerator",
    "MultiplicativeLogNormalNoiseGenerator",
    "ScalingNoise",
    "ShiftNoise",
    "BackgroundNoise",
    "BasicExpIntensityNoise",
    "BasicQNoiseGenerator",
]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class QNoiseGenerator(ProcessData):
    """Base class for q noise generators.

    The base implementation is the identity transform: the q values are handed
    back untouched and ``context`` is ignored.
    """

    def apply(self, qs: Tensor, context: dict = None):
        """Return ``qs`` unchanged (no noise is added by the base class)."""
        return qs
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class QNormalNoiseGenerator(QNoiseGenerator):
    """Q noise generator which adds to each q value of the reflectivity curve a noise sampled from a normal distribution.

    Every q point receives its own, independent noise sample; resulting negative q values are clamped to 0.

    Args:
        std (Union[float, Tuple[float, float]], optional): the standard deviation of the normal distribution (the same for all curves in the batch if provided as a float,
            or uniformly sampled for each curve in the batch if provided as a tuple). Defaults to (0, 1e-3).
        add_to_context (bool, optional): if ``True`` and a context dict is passed, the standard deviations used
            are stored in ``context['q_stds']``. Defaults to False.
    """
    def __init__(self,
                 std: Union[float, Tuple[float, float]] = (0, 1e-3),
                 add_to_context: bool = False
                 ):
        self.std = std
        self.add_to_context = add_to_context

    def apply(self, qs: Tensor, context: dict = None):
        """applies independent noise to every q value"""
        std = self.std

        if isinstance(std, (list, tuple)):
            # one std per curve, broadcast over the q points of that curve;
            # a 1D input is treated as a single curve
            std_shape = (qs.shape[0], 1) if qs.dim() > 1 else (1,)
            std = uniform_sampler(*std, *std_shape, device=qs.device, dtype=qs.dtype)
        else:
            std = torch.empty_like(qs).fill_(std)

        # Sample one value per q point. torch.normal(0., std) would only return
        # std.shape samples, i.e. a single shared offset per curve for a per-curve std.
        noise = torch.randn_like(qs) * std

        if self.add_to_context and context is not None:
            context['q_stds'] = std

        qs = torch.clamp_min_(qs + noise, 0.)

        return qs
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class QSystematicShiftGenerator(QNoiseGenerator):
    """Q noise generator which draws one shift per curve from a normal distribution and adds it to all q values of that curve.

    Resulting negative q values are clamped to 0.

    Args:
        std (float): the standard deviation of the normal distribution
        add_to_context (bool, optional): if ``True`` and a context dict is passed, the sampled shifts
            are stored in ``context['q_shifts']``. Defaults to True.
    """
    def __init__(self, std: float, add_to_context: bool = True):
        self.std = std
        self.add_to_context = add_to_context

    def apply(self, qs: Tensor, context: dict = None):
        """applies systematic shifts to the q values"""
        # a 1D tensor is a single curve -> a single shift; otherwise one shift per row
        shift_shape = (1,) if qs.dim() == 1 else (qs.shape[0], 1)
        stds = self.std * torch.ones(shift_shape, device=qs.device, dtype=qs.dtype)
        shifts = torch.normal(mean=0., std=stds)

        if context is not None and self.add_to_context:
            context['q_shifts'] = shifts

        return torch.clamp_min_(qs + shifts, 0.)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class BasicQNoiseGenerator(QNoiseGenerator):
    """Q noise generator combining a systematic shift with random per-point noise.

    The input is first promoted to at least 2D, then a systematic shift (same change for all
    q points in the curve) is applied, followed by random noise (different change per q point).

    Args:
        shift_std (float, optional): the standard deviation of the normal distribution for systematic q shifts
            (i.e. same change applied to all q points in the curve). Defaults to 1e-3.
        noise_std (Union[float, Tuple[float, float]], optional): the standard deviation of the normal distribution for random q noise
            (i.e. different changes applied to each q point in the curve). The standard deviation is the same
            for all curves in the batch if provided as a float, or uniformly sampled for each curve in the batch if provided as a tuple.
            Defaults to (0, 1e-3).
        add_to_context (bool, optional): forwarded to both sub-generators. Defaults to False.
    """
    def __init__(self,
                 shift_std: float = 1e-3,
                 noise_std: Union[float, Tuple[float, float]] = (0, 1e-3),
                 add_to_context: bool = False,
                 ):
        self.q_shift = QSystematicShiftGenerator(shift_std, add_to_context=add_to_context)
        self.q_noise = QNormalNoiseGenerator(noise_std, add_to_context=add_to_context)

    def apply(self, qs: Tensor, context: dict = None):
        """applies the systematic shift first, then the per-point random noise"""
        batch_qs = torch.atleast_2d(qs)
        for generator in (self.q_shift, self.q_noise):
            batch_qs = generator.apply(batch_qs, context)
        return batch_qs
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class IntensityNoiseGenerator(ProcessData):
    """Base class for intensity noise generators.

    Subclasses must override :meth:`apply`; the base implementation raises ``NotImplementedError``.
    """

    def apply(self, curves: Tensor, context: dict = None):
        """Add noise to ``curves`` (to be implemented by subclasses)."""
        raise NotImplementedError
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class MultiplicativeLogNormalNoiseGenerator(IntensityNoiseGenerator):
    r"""Noise generator which applies noise as :math:`R_n = R * b^{\epsilon}` , where :math:`b` is a base and :math:`\epsilon` is sampled from the normal distribution :math:`\epsilon \sim \mathcal{N}(0, std)` .
    In logarithmic space, this translates to :math:`\log_b(R_n) = \log_b(R) + \epsilon` .

    An independent :math:`\epsilon` is drawn for every point of every curve.

    Args:
        std (Union[float, Tuple[float, float]]): the standard deviation of the normal distribution from which the noise is sampled. The standard deviation is the same
            for all curves in the batch if provided as a float, or uniformly sampled for each curve in the batch if provided as a tuple.
        base (float, optional): the base of the logarithm. Defaults to 10.
        add_to_context (bool, optional): if ``True`` and a context dict is passed, the standard deviations used
            are stored in ``context['std_lognormal']``. Defaults to False.
    """
    def __init__(self, std: Union[float, Tuple[float, float]], base: float = 10, add_to_context: bool = False):
        self.std = std
        self.base = base
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """multiplies every point of the curves by an independent factor ``base ** epsilon``"""
        std = self.std

        if isinstance(std, (list, tuple)):
            # one std per curve, broadcast over the points of that curve;
            # a 1D input is treated as a single curve
            std_shape = (curves.shape[0], 1) if curves.dim() > 1 else (1,)
            std = uniform_sampler(*std, *std_shape, device=curves.device, dtype=curves.dtype)
        else:
            std = torch.ones_like(curves) * std

        # Draw epsilon per point. torch.normal(0., std) would only return std.shape
        # samples, i.e. a single shared factor per curve for a per-curve std.
        noise = self.base ** (torch.randn_like(curves) * std)

        if self.add_to_context and context is not None:
            context['std_lognormal'] = std

        return noise * curves
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class PoissonNoiseGenerator(IntensityNoiseGenerator):
    """Noise generator which applies Poisson noise to the reflectivity curves

    For every point a standard deviation ``sigma = R * rel_err + abs_errors`` is built. The curve is
    rescaled to ``(R / sigma) ** 2`` expected counts, a Poisson sample is drawn and scaled back, so
    the resulting noise has standard deviation ``sigma``.

    Args:
        relative_errors (Tuple[float, float], optional): the range of relative errors to apply to the intensity curves. Defaults to (0.05, 0.35).
        abs_errors (float, optional): a small constant added to the sigmas to prevent division by zero. Defaults to 1e-8.
        add_to_context (bool, optional): If ``True`` and a context dict is passed, the sigmas are stored in ``context['sigmas']``. Defaults to False.
        consistent_rel_err (bool, optional): If ``True``, the same relative error is used for all points in a curve. Defaults to True.
        logdist (bool, optional): If ``True``, the relative errors are sampled uniformly in logarithmic space
            (both for consistent and per-point relative errors). Defaults to False.
    """
    def __init__(self,
                 relative_errors: Tuple[float, float] = (0.05, 0.35),
                 abs_errors: float = 1e-8,
                 add_to_context: bool = False,
                 consistent_rel_err: bool = True,
                 logdist: bool = False,
                 ):
        self.relative_errors = relative_errors
        self.abs_errors = abs_errors
        self.add_to_context = add_to_context
        self.consistent_rel_err = consistent_rel_err
        self.logdist = logdist

    def apply(self, curves: Tensor, context: dict = None):
        """applies Poisson noise to the curves"""
        if self.consistent_rel_err:
            sigmas = self._gen_consistent_sigmas(curves)
        else:
            sigmas = self._gen_sigmas(curves)

        # intensities * curves == (curves / sigmas) ** 2 expected counts; dividing the Poisson
        # sample by intensities gives a noisy curve whose std equals sigmas
        intensities = curves / sigmas ** 2
        curves = torch.poisson(intensities * curves) / intensities

        if self.add_to_context and context is not None:
            context['sigmas'] = sigmas
        return curves

    def _sample_rel_err(self, shape, device=None, dtype=None) -> Tensor:
        """samples relative errors within ``relative_errors``, uniformly in linear or (if ``logdist``) log10 space"""
        low, high = self.relative_errors
        if self.logdist:
            low, high = log10(low), log10(high)
        rel_err = torch.rand(shape, device=device, dtype=dtype) * (high - low) + low
        if self.logdist:
            rel_err = 10 ** rel_err
        return rel_err

    def _gen_consistent_sigmas(self, curves):
        """sigmas built from one relative error per curve (shared by all its points)"""
        # (batch, 1) broadcasts over the points of each curve; a 1D input is a single curve
        shape = (curves.shape[0], 1) if curves.dim() > 1 else (1,)
        rel_err = self._sample_rel_err(shape, device=curves.device, dtype=curves.dtype)
        sigmas = curves * rel_err + self.abs_errors
        return sigmas

    def _gen_sigmas(self, curves):
        """sigmas built from an independent relative error for every point"""
        rel_err = self._sample_rel_err(curves.shape, device=curves.device, dtype=curves.dtype)
        sigmas = curves * rel_err + self.abs_errors
        return sigmas
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
class ScalingNoise(IntensityNoiseGenerator):
    """Noise generator which applies scaling noise to reflectivity curves (equivalent to a vertical stretch or compression of the curve in the logarithmic domain).
    The output is R^(1 + scale_factor), which corresponds in logarithmic domain to (1 + scale_factor) * log(R).

    Args:
        scale_range (tuple, optional): the range of scaling factors (one factor sampled per curve in the batch). Defaults to (-0.2e-2, 0.2e-2).
        add_to_context (bool, optional): if ``True`` and a context dict is passed, the sampled factors
            are stored in ``context['intensity_scales']``. Defaults to False.
    """
    def __init__(self,
                 scale_range: tuple = (-0.2e-2, 0.2e-2),
                 add_to_context: bool = False,
                 ):
        self.scale_range = scale_range
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """raises every curve to the power ``1 + scale`` with one random scale per curve"""
        batch_size = curves.shape[0]
        scales = uniform_sampler(
            *self.scale_range, batch_size, 1, device=curves.device, dtype=curves.dtype
        )

        if context is not None and self.add_to_context:
            context['intensity_scales'] = scales

        exponents = 1 + scales
        return curves ** exponents
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class ShiftNoise(IntensityNoiseGenerator):
    """Noise generator which applies shifting noise to reflectivity curves (equivalent to a vertical shift of the entire curve in the logarithmic domain).
    The output is R * (1 + shift_factor), which corresponds in logarithmic domain to log(R) + log(1 + shift_factor).

    Args:
        shift_range (tuple, optional): the range of shift factors (one factor sampled per curve in the batch). Defaults to (-0.1, 0.2e-2).
        add_to_context (bool, optional): if ``True`` and a context dict is passed, the sampled factors
            are stored in ``context['intensity_shifts']``. Defaults to False.
    """
    def __init__(self,
                 shift_range: tuple = (-0.1, 0.2e-2),
                 add_to_context: bool = False,
                 ):
        self.shift_range = shift_range
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """multiplies every curve by ``1 + shift`` with one random shift per curve"""
        batch_size = curves.shape[0]
        intensity_shifts = uniform_sampler(
            *self.shift_range, batch_size, 1, device=curves.device, dtype=curves.dtype
        )

        if context is not None and self.add_to_context:
            context['intensity_shifts'] = intensity_shifts

        factors = 1 + intensity_shifts
        return curves * factors
|
|
275
|
+
|
|
276
|
+
class BackgroundNoise(IntensityNoiseGenerator):
    """
    Noise generator which adds a constant background to reflectivity curves.

    One background value is sampled uniformly per curve and added to all of its points.

    Args:
        background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
        add_to_context (bool, optional): if ``True`` and a context dict is passed, the sampled backgrounds
            are stored in ``context['backgrounds']``. Defaults to False.
    """
    def __init__(self,
                 background_range: tuple = (1.0e-10, 1.0e-8),
                 add_to_context: bool = False,
                 ):
        self.background_range = background_range
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None) -> Tensor:
        """adds a per-curve constant background to the curves"""
        batch_size = curves.shape[0]
        backgrounds = uniform_sampler(
            *self.background_range, batch_size, 1, device=curves.device, dtype=curves.dtype
        )

        if context is not None and self.add_to_context:
            context['backgrounds'] = backgrounds

        return curves + backgrounds
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
class BasicExpIntensityNoise(IntensityNoiseGenerator):
    """
    A composite noise generator that applies Poisson, scaling, shift and background noise to reflectivity curves.

    This class combines four types of noise:

    1. **Poisson noise**: Simulates count-based noise common in photon counting experiments.
    2. **Scaling noise**: Applies a scaling transformation to the curves, equivalent to a vertical stretch or compression in logarithmic space.
    3. **Shift noise**: Applies a multiplicative shift to the curves, equivalent to a vertical shift in logarithmic space.
    4. **Background noise**: Adds a constant background value to the curves.

    Poisson noise is always applied. The other three are optional.
    They are applied in the order: scaling, shift, Poisson, background.

    Args:
        relative_errors (Tuple[float, float], optional): The range of relative errors for Poisson noise. Defaults to (0.001, 0.15).
        abs_errors (float, optional): A small constant added to prevent division by zero in Poisson noise. Defaults to 1e-8.
        scale_range (tuple, optional): The range of scaling factors for scaling noise. Defaults to (-2e-2, 2e-2).
        shift_range (tuple, optional): The range of shift factors for shift noise. Defaults to (-0.1, 0.2e-2).
        background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
        apply_shift (bool, optional): If True, applies shift noise to the curves. Defaults to False.
        apply_scaling (bool, optional): If True, applies scaling noise to the curves. Defaults to False.
        apply_background (bool, optional): If True, applies background noise to the curves. Defaults to False.
        consistent_rel_err (bool, optional): If True, uses a consistent relative error for Poisson noise across all points in a curve. Defaults to False.
        add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary. Defaults to False.
        logdist (bool, optional): If True, samples relative errors for Poisson noise in logarithmic space. Defaults to False.
    """
    def __init__(self,
                 relative_errors: Tuple[float, float] = (0.001, 0.15),
                 abs_errors: float = 1e-8,
                 scale_range: tuple = (-2e-2, 2e-2),
                 shift_range: tuple = (-0.1, 0.2e-2),
                 background_range: tuple = (1.0e-10, 1.0e-8),
                 apply_shift: bool = False,
                 apply_scaling: bool = False,
                 apply_background: bool = False,
                 consistent_rel_err: bool = False,
                 add_to_context: bool = False,
                 logdist: bool = False,
                 ):
        self.poisson_noise = PoissonNoiseGenerator(
            relative_errors=relative_errors,
            abs_errors=abs_errors,
            consistent_rel_err=consistent_rel_err,
            add_to_context=add_to_context,
            logdist=logdist,
        )

        # optional stages stay None when disabled
        self.scaling_noise = None
        if apply_scaling:
            self.scaling_noise = ScalingNoise(scale_range=scale_range, add_to_context=add_to_context)

        self.shift_noise = None
        if apply_shift:
            self.shift_noise = ShiftNoise(shift_range=shift_range, add_to_context=add_to_context)

        self.background_noise = None
        if apply_background:
            self.background_noise = BackgroundNoise(
                background_range=background_range, add_to_context=add_to_context
            )

    def apply(self, curves: Tensor, context: dict = None):
        """applies the enabled noise stages to the input curves"""
        for stage in (self.scaling_noise, self.shift_noise):
            if stage:
                curves = stage(curves, context)

        curves = self.poisson_noise(curves, context)

        if self.background_noise:
            curves = self.background_noise.apply(curves, context)

        return curves
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
#
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the GPL license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from reflectorch.data_generation.priors.params import Params
|
|
8
|
+
from reflectorch.data_generation.priors.base import PriorSampler
|
|
9
|
+
from reflectorch.data_generation.priors.no_constraints import BasicPriorSampler
|
|
10
|
+
from reflectorch.data_generation.priors.independent_priors import (
|
|
11
|
+
SingleParamPrior,
|
|
12
|
+
SimplePriorSampler,
|
|
13
|
+
UniformParamPrior,
|
|
14
|
+
GaussianParamPrior,
|
|
15
|
+
TruncatedGaussianParamPrior
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from reflectorch.data_generation.priors.subprior_sampler import (
|
|
19
|
+
UniformSubPriorParams,
|
|
20
|
+
UniformSubPriorSampler,
|
|
21
|
+
NarrowSldUniformSubPriorSampler,
|
|
22
|
+
)
|
|
23
|
+
from reflectorch.data_generation.priors.exp_subprior_sampler import ExpUniformSubPriorSampler
|
|
24
|
+
from reflectorch.data_generation.priors.multilayer_structures import (
|
|
25
|
+
SimpleMultilayerSampler,
|
|
26
|
+
MultilayerStructureParams,
|
|
27
|
+
)
|
|
28
|
+
from reflectorch.data_generation.priors.parametric_models import (
|
|
29
|
+
ParametricModel,
|
|
30
|
+
MULTILAYER_MODELS,
|
|
31
|
+
)
|
|
32
|
+
from reflectorch.data_generation.priors.parametric_subpriors import (
|
|
33
|
+
SubpriorParametricSampler,
|
|
34
|
+
BasicParams,
|
|
35
|
+
)
|
|
36
|
+
from reflectorch.data_generation.priors.sampler_strategies import (
|
|
37
|
+
SamplerStrategy,
|
|
38
|
+
BasicSamplerStrategy,
|
|
39
|
+
ConstrainedRoughnessSamplerStrategy,
|
|
40
|
+
ConstrainedRoughnessAndImgSldSamplerStrategy,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
# Public API of the ``priors`` package: names re-exported from its submodules
# (params, base, no_constraints, independent_priors, subprior samplers,
# multilayer structures, parametric models/subpriors and sampler strategies).
__all__ = [
    "SingleParamPrior",
    "SimplePriorSampler",
    "UniformParamPrior",
    "GaussianParamPrior",
    "TruncatedGaussianParamPrior",
    "Params",
    "PriorSampler",
    "BasicPriorSampler",
    "UniformSubPriorParams",
    "UniformSubPriorSampler",
    "NarrowSldUniformSubPriorSampler",
    "ExpUniformSubPriorSampler",
    "SimpleMultilayerSampler",
    "MultilayerStructureParams",
    "SubpriorParametricSampler",
    "BasicParams",
    "ParametricModel",
    "MULTILAYER_MODELS",
    "SamplerStrategy",
    "BasicSamplerStrategy",
    "ConstrainedRoughnessSamplerStrategy",
    "ConstrainedRoughnessAndImgSldSamplerStrategy",
]
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
#
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the GPL license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
from reflectorch.data_generation.priors.params import Params
|
|
10
|
+
|
|
11
|
+
# Only the abstract base class is exported from this module.
__all__ = [
    "PriorSampler",
]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PriorSampler(object):
    """Base class for prior samplers.

    Apart from ``param_dim``, ``filter_params`` and ``__repr__``, every method raises
    ``NotImplementedError`` here and is meant to be provided by subclasses.
    """

    # parameter container class; its ``layers_num2size`` is used to derive ``param_dim``
    PARAM_CLS = Params

    @property
    def param_dim(self) -> int:
        """gets the number of parameters (i.e. the parameter dimensionality) for ``max_num_layers`` layers"""
        num_layers = self.max_num_layers
        return self.PARAM_CLS.layers_num2size(num_layers)

    @property
    def max_num_layers(self) -> int:
        """gets the number of layers (must be provided by subclasses)"""
        raise NotImplementedError

    def sample(self, batch_size: int) -> Params:
        """sample a batch of ``batch_size`` parameters"""
        raise NotImplementedError

    def scale_params(self, params: Params) -> Tensor:
        """scale the parameters to a ML-friendly range"""
        raise NotImplementedError

    def restore_params(self, scaled_params: Tensor) -> Params:
        """restore scaled parameters to their original range"""
        raise NotImplementedError

    def log_prob(self, params: Params) -> Tensor:
        """log-probability of ``params`` under the prior (not implemented in the base class)"""
        raise NotImplementedError

    def get_indices_within_domain(self, params: Params) -> Tensor:
        """indices of ``params`` lying inside the prior domain (not implemented in the base class)"""
        raise NotImplementedError

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        """indices of ``params`` lying inside the prior bounds (not implemented in the base class)"""
        raise NotImplementedError

    def filter_params(self, params: Params) -> Params:
        """keep only the entries of ``params`` selected by :meth:`get_indices_within_domain`"""
        return params[self.get_indices_within_domain(params)]

    def clamp_params(self, params: Params) -> Params:
        """clamp ``params`` into the prior bounds (not implemented in the base class)"""
        raise NotImplementedError

    def __repr__(self):
        # every instance attribute, with its string form truncated to 10 characters
        fields = [f'{name}={str(value)[:10]}' for name, value in vars(self).items()]
        joined = ', '.join(fields)
        return f'{type(self).__name__}({joined})'
|