reflectorch 1.0.0 (reflectorch-1.0.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (83):
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,371 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Union, Tuple
8
+ from math import log10
9
+
10
+ import torch
11
+ from torch import Tensor
12
+
13
+ from reflectorch.data_generation.process_data import ProcessData
14
+ from reflectorch.data_generation.utils import uniform_sampler
15
+
16
# Public API of the noise module (order kept as originally declared).
__all__ = [
    # base classes
    "QNoiseGenerator",
    "IntensityNoiseGenerator",
    # q-axis noise
    "QNormalNoiseGenerator",
    "QSystematicShiftGenerator",
    # intensity noise
    "PoissonNoiseGenerator",
    "MultiplicativeLogNormalNoiseGenerator",
    "ScalingNoise",
    "ShiftNoise",
    "BackgroundNoise",
    # composite generators
    "BasicExpIntensityNoise",
    "BasicQNoiseGenerator",
]
29
+
30
+
31
class QNoiseGenerator(ProcessData):
    """Base class for q noise generators.

    The base implementation is the identity transform; subclasses override
    :meth:`apply` to perturb the q values.
    """

    def apply(self, qs: Tensor, context: dict = None):
        """Return ``qs`` unchanged (``context`` is ignored)."""
        result = qs
        return result
35
+
36
+
37
class QNormalNoiseGenerator(QNoiseGenerator):
    """Q noise generator which adds to each q value of the reflectivity curve an independent
    noise sampled from a zero-mean normal distribution.

    Args:
        std (Union[float, Tuple[float, float]], optional): the standard deviation of the normal distribution
            (the same for all curves in the batch if provided as a float, or uniformly sampled for each curve
            in the batch if provided as a tuple). Defaults to (0, 1e-3).
        add_to_context (bool, optional): if ``True`` and a context dict is passed to :meth:`apply`,
            the standard deviations used are stored in ``context['q_stds']``. Defaults to False.
    """
    def __init__(self,
                 std: Union[float, Tuple[float, float]] = (0, 1e-3),
                 add_to_context: bool = False
                 ):
        self.std = std
        self.add_to_context = add_to_context

    def apply(self, qs: Tensor, context: dict = None):
        """applies independent noise to every q value; negative results are clamped to 0.

        The output has the same shape as ``qs``. A 1D ``qs`` is treated as a single curve.
        """
        # Work on a (batch, num_q) view so that a per-curve std of shape (batch, 1)
        # broadcasts along the q axis (a 1D input becomes a batch of one curve).
        qs_2d = qs if qs.dim() > 1 else qs.unsqueeze(0)
        std = self.std

        if isinstance(std, (list, tuple)):
            # one standard deviation per curve, shape (batch, 1)
            std = uniform_sampler(*std, qs_2d.shape[0], 1, device=qs.device, dtype=qs.dtype)
        else:
            std = torch.empty_like(qs_2d).fill_(std)

        # torch.normal(mean=float, std=Tensor) returns a tensor shaped like ``std``; expanding
        # the per-curve std over all q points draws an independent sample for every q value
        # instead of a single (shift-like) value per curve.
        noise = torch.normal(mean=0., std=std.expand_as(qs_2d))

        if self.add_to_context and context is not None:
            context['q_stds'] = std

        noisy_qs = torch.clamp(qs_2d + noise, min=0.)

        return noisy_qs.reshape(qs.shape)
68
+
69
+
70
class QSystematicShiftGenerator(QNoiseGenerator):
    """Q noise generator which samples one q shift per curve from a zero-mean normal distribution
    and adds that shift to all q values of the curve.

    Args:
        std (float): the standard deviation of the normal distribution
        add_to_context (bool, optional): if ``True`` and a context dict is passed to :meth:`apply`,
            the sampled shifts are stored in ``context['q_shifts']``. Defaults to True.
    """
    def __init__(self, std: float, add_to_context: bool = True):
        self.std = std
        self.add_to_context = add_to_context

    def apply(self, qs: Tensor, context: dict = None):
        """applies systematic shifts to the q values; negative results are clamped to 0"""
        # a 1D input gets a single shift; otherwise one shift per row, broadcast along the last axis
        shift_shape = (1,) if qs.dim() == 1 else (qs.shape[0], 1)
        std_per_curve = self.std * torch.ones(*shift_shape, device=qs.device, dtype=qs.dtype)
        q_shifts = torch.normal(mean=0., std=std_per_curve)

        if context is not None and self.add_to_context:
            context['q_shifts'] = q_shifts

        shifted = qs + q_shifts
        return torch.clamp_min_(shifted, 0.)
97
+
98
+
99
class BasicQNoiseGenerator(QNoiseGenerator):
    """Q noise generator combining a systematic shift (same change for every q point of a curve)
    with random noise (an independent change per q point).

    Args:
        shift_std (float, optional): the standard deviation of the normal distribution for systematic
            q shifts. Defaults to 1e-3.
        noise_std (Union[float, Tuple[float, float]], optional): the standard deviation of the normal
            distribution for random q noise. The same for all curves in the batch if provided as a float,
            or uniformly sampled for each curve in the batch if provided as a tuple. Defaults to (0, 1e-3).
        add_to_context (bool, optional): forwarded to both underlying generators. Defaults to False.
    """
    def __init__(self,
                 shift_std: float = 1e-3,
                 noise_std: Union[float, Tuple[float, float]] = (0, 1e-3),
                 add_to_context: bool = False,
                 ):
        self.q_shift = QSystematicShiftGenerator(shift_std, add_to_context=add_to_context)
        self.q_noise = QNormalNoiseGenerator(noise_std, add_to_context=add_to_context)

    def apply(self, qs: Tensor, context: dict = None):
        """applies the systematic shift first, then the random q noise (input promoted to at least 2D)"""
        batched_qs = torch.atleast_2d(qs)
        shifted_qs = self.q_shift.apply(batched_qs, context)
        return self.q_noise.apply(shifted_qs, context)
124
+
125
+
126
class IntensityNoiseGenerator(ProcessData):
    """Base class for intensity noise generators.

    Subclasses must implement :meth:`apply`.
    """

    def apply(self, curves: Tensor, context: dict = None):
        """Abstract hook: subclasses return the noisy curves."""
        raise NotImplementedError
130
+
131
+
132
class MultiplicativeLogNormalNoiseGenerator(IntensityNoiseGenerator):
    r"""Noise generator which applies noise as :math:`R_n = R * b^{\epsilon}` , where :math:`b` is a base and :math:`\epsilon` is sampled from the normal distribution :math:`\epsilon \sim \mathcal{N}(0, std)` .
    In logarithmic space, this translates to :math:`\log_b(R_n) = \log_b(R) + \epsilon` .

    An independent :math:`\epsilon` is drawn for every point of every curve; only the
    standard deviation is shared (globally or per curve).

    Args:
        std (Union[float, Tuple[float, float]]): the standard deviation of the normal distribution from which the noise is sampled. The standard deviation is the same
            for all curves in the batch if provided as a float, or uniformly sampled for each curve in the batch if provided as a tuple.
        base (float, optional): the base of the logarithm. Defaults to 10.
        add_to_context (bool, optional): if ``True`` and a context dict is passed to :meth:`apply`,
            the standard deviations used are stored in ``context['std_lognormal']``. Defaults to False.
    """
    def __init__(self, std: Union[float, Tuple[float, float]], base: float = 10, add_to_context: bool = False):
        self.std = std
        self.base = base
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves"""
        std = self.std

        if isinstance(std, (list, tuple)):
            # one standard deviation per curve, shape (batch, 1);
            # assumes curves is (batch, num_points) — TODO confirm for other ranks
            std = uniform_sampler(*std, curves.shape[0], 1, device=curves.device, dtype=curves.dtype)
        else:
            std = torch.ones_like(curves) * std

        # torch.normal(mean=float, std=Tensor) returns a tensor shaped like ``std``; expanding
        # a per-curve std over the curve gives one epsilon per point (as in the float branch)
        # rather than a single multiplicative factor per curve.
        epsilon = torch.normal(mean=0., std=std.expand_as(curves))
        noise = self.base ** epsilon

        if self.add_to_context and context is not None:
            context['std_lognormal'] = std

        return noise * curves
162
+
163
+
164
class PoissonNoiseGenerator(IntensityNoiseGenerator):
    """Noise generator which applies Poisson noise to the reflectivity curves.

    For each point a standard deviation ``sigma = rel_err * R + abs_errors`` is built, the curve is
    treated as count data with ``I = R / sigma**2`` and the noisy curve is ``Poisson(I * R) / I``,
    which has mean ``R`` and variance ``sigma**2``.

    Args:
        relative_errors (Tuple[float, float], optional): the range of relative errors to apply to the intensity curves. Defaults to (0.05, 0.35).
        abs_errors (float, optional): a small constant added to every sigma. Defaults to 1e-8.
        add_to_context (bool, optional): if ``True`` and a context dict is passed to :meth:`apply`,
            the sigmas are stored in ``context['sigmas']``. Defaults to False.
        consistent_rel_err (bool, optional): If ``True``, the same relative error is used for all points in a curve. Defaults to True.
        logdist (bool, optional): If ``True``, the relative errors are sampled log-uniformly (uniformly in logarithmic space)
            instead of uniformly; applies to both values of ``consistent_rel_err``. Defaults to False.
    """
    def __init__(self,
                 relative_errors: Tuple[float, float] = (0.05, 0.35),
                 abs_errors: float = 1e-8,
                 add_to_context: bool = False,
                 consistent_rel_err: bool = True,
                 logdist: bool = False,
                 ):
        self.relative_errors = relative_errors
        self.abs_errors = abs_errors
        self.add_to_context = add_to_context
        self.consistent_rel_err = consistent_rel_err
        self.logdist = logdist

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves"""
        if self.consistent_rel_err:
            sigmas = self._gen_consistent_sigmas(curves)
        else:
            sigmas = self._gen_sigmas(curves)

        intensities = curves / sigmas ** 2
        noisy_curves = torch.poisson(intensities * curves) / intensities
        # A zero-valued point gives intensity 0 and therefore 0 / 0 = NaN above;
        # such points carry no signal, so keep them at their original value.
        curves = torch.where(intensities == 0, curves, noisy_curves)

        if self.add_to_context and context is not None:
            context['sigmas'] = sigmas
        return curves

    def _rel_err_from_uniform(self, uniform: Tensor) -> Tensor:
        """maps uniform samples in [0, 1) to relative errors in ``relative_errors`` (linear or log-uniform)"""
        low, high = self.relative_errors[0], self.relative_errors[1]
        if not self.logdist:
            return uniform * (high - low) + low
        log_low = log10(low)
        return 10 ** (uniform * (log10(high) - log_low) + log_low)

    def _gen_consistent_sigmas(self, curves):
        """one relative error per curve (broadcast over all its points)"""
        uniform = torch.rand(curves.shape[0], device=curves.device, dtype=curves.dtype)
        rel_err = self._rel_err_from_uniform(uniform)
        sigmas = curves * rel_err[:, None] + self.abs_errors
        return sigmas

    def _gen_sigmas(self, curves):
        """an independent relative error for every point"""
        rel_err = self._rel_err_from_uniform(torch.rand_like(curves))
        sigmas = curves * rel_err + self.abs_errors
        return sigmas
220
+
221
+
222
class ScalingNoise(IntensityNoiseGenerator):
    """Noise generator which stretches or compresses reflectivity curves vertically in the logarithmic domain.

    Each curve is mapped to ``R ** (1 + scale_factor)``, i.e. ``(1 + scale_factor) * log(R)`` in log space,
    with one ``scale_factor`` drawn uniformly per curve.

    Args:
        scale_range (tuple, optional): the range of scaling factors (one factor sampled per curve in the batch). Defaults to (-0.2e-2, 0.2e-2).
        add_to_context (bool, optional): if ``True`` and a context dict is passed to :meth:`apply`,
            the sampled factors are stored in ``context['intensity_scales']``. Defaults to False.
    """
    def __init__(self,
                 scale_range: tuple = (-0.2e-2, 0.2e-2),
                 add_to_context: bool = False,
                 ):
        self.scale_range = scale_range
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves"""
        num_curves = curves.shape[0]
        factors = uniform_sampler(
            *self.scale_range, num_curves, 1, device=curves.device, dtype=curves.dtype
        )

        if context is not None and self.add_to_context:
            context['intensity_scales'] = factors

        exponent = 1 + factors
        return curves ** exponent
248
+
249
+
250
class ShiftNoise(IntensityNoiseGenerator):
    """Noise generator which applies shifting noise to reflectivity curves (equivalent to a vertical shift of the entire curve in the logarithmic domain).

    The output is ``R * (1 + shift_factor)``, which corresponds in the logarithmic domain to
    ``log(R) + log(1 + shift_factor)``; one ``shift_factor`` is drawn uniformly per curve.

    Args:
        shift_range (tuple, optional): the range of shift factors (one factor sampled per curve in the batch). Defaults to (-0.1, 0.2e-2).
        add_to_context (bool, optional): if ``True`` and a context dict is passed to :meth:`apply`,
            the sampled factors are stored in ``context['intensity_shifts']``. Defaults to False.
    """
    def __init__(self,
                 shift_range: tuple = (-0.1, 0.2e-2),
                 add_to_context: bool = False,
                 ):
        self.shift_range = shift_range
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None):
        """applies noise to the curves"""
        # one factor per curve, shape (batch, 1), broadcast over the points of the curve
        intensity_shifts = uniform_sampler(
            *self.shift_range, curves.shape[0], 1,
            device=curves.device, dtype=curves.dtype
        )
        if self.add_to_context and context is not None:
            context['intensity_shifts'] = intensity_shifts

        curves = curves * (1 + intensity_shifts)

        return curves
275
+
276
class BackgroundNoise(IntensityNoiseGenerator):
    """
    Noise generator which adds a constant background to reflectivity curves.

    One background value is drawn uniformly per curve and added to every point of that curve.

    Args:
        background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
        add_to_context (bool, optional): if ``True`` and a context dict is passed to :meth:`apply`,
            the sampled values are stored in ``context['backgrounds']``. Defaults to False.
    """
    def __init__(self,
                 background_range: tuple = (1.0e-10, 1.0e-8),
                 add_to_context: bool = False,
                 ):
        self.background_range = background_range
        self.add_to_context = add_to_context

    def apply(self, curves: Tensor, context: dict = None) -> Tensor:
        """applies background noise to the curves"""
        offsets = uniform_sampler(
            *self.background_range,
            curves.shape[0],
            1,
            device=curves.device,
            dtype=curves.dtype,
        )

        if context is not None and self.add_to_context:
            context['backgrounds'] = offsets

        return curves + offsets
302
+
303
+
304
class BasicExpIntensityNoise(IntensityNoiseGenerator):
    """
    A composite noise generator that applies Poisson, scaling, shift and background noise to reflectivity curves.

    This class combines four types of noise:

    1. **Poisson noise**: Simulates count-based noise common in photon counting experiments.
    2. **Scaling noise**: Applies a scaling transformation to the curves, equivalent to a vertical stretch or compression in logarithmic space.
    3. **Shift noise**: Applies a multiplicative shift to the curves, equivalent to a vertical shift in logarithmic space.
    4. **Background noise**: Adds a constant background value to the curves.

    Poisson noise is always applied; the other three are optional. The enabled components are applied
    in the order scaling -> shift -> Poisson -> background. The Poisson sigmas are therefore computed
    from the already scaled/shifted curve, and the background is added after the Poisson noise.

    Args:
        relative_errors (Tuple[float, float], optional): The range of relative errors for Poisson noise. Defaults to (0.001, 0.15).
        abs_errors (float, optional): A small constant added to the Poisson noise sigmas. Defaults to 1e-8.
        scale_range (tuple, optional): The range of scaling factors for scaling noise. Defaults to (-2e-2, 2e-2).
        shift_range (tuple, optional): The range of shift factors for shift noise. Defaults to (-0.1, 0.2e-2).
        background_range (tuple, optional): The range from which the background value is sampled. Defaults to (1.0e-10, 1.0e-8).
        apply_shift (bool, optional): If True, applies shift noise to the curves. Defaults to False.
        apply_scaling (bool, optional): If True, applies scaling noise to the curves. Defaults to False.
        apply_background (bool, optional): If True, applies background noise to the curves. Defaults to False.
        consistent_rel_err (bool, optional): If True, uses a consistent relative error for Poisson noise across all points in a curve. Defaults to False.
        add_to_context (bool, optional): If True, adds generated noise parameters to the context dictionary. Defaults to False.
        logdist (bool, optional): If True, samples relative errors for Poisson noise in logarithmic space. Defaults to False.
    """
    def __init__(self,
                 relative_errors: Tuple[float, float] = (0.001, 0.15),
                 abs_errors: float = 1e-8,
                 scale_range: tuple = (-2e-2, 2e-2),
                 shift_range: tuple = (-0.1, 0.2e-2),
                 background_range: tuple = (1.0e-10, 1.0e-8),
                 apply_shift: bool = False,
                 apply_scaling: bool = False,
                 apply_background: bool = False,
                 consistent_rel_err: bool = False,
                 add_to_context: bool = False,
                 logdist: bool = False,
                 ):
        self.poisson_noise = PoissonNoiseGenerator(
            relative_errors=relative_errors,
            abs_errors=abs_errors,
            consistent_rel_err=consistent_rel_err,
            add_to_context=add_to_context,
            logdist=logdist,
        )
        # optional components are None when disabled
        self.scaling_noise = ScalingNoise(
            scale_range=scale_range, add_to_context=add_to_context
        ) if apply_scaling else None

        self.shift_noise = ShiftNoise(
            shift_range=shift_range, add_to_context=add_to_context
        ) if apply_shift else None

        self.background_noise = BackgroundNoise(
            background_range=background_range, add_to_context=add_to_context
        ) if apply_background else None

    def apply(self, curves: Tensor, context: dict = None):
        """applies the enabled types of noise to the input curves (scaling -> shift -> Poisson -> background)"""
        # explicit None checks: disabled components are stored as None, so do not rely on object truthiness
        if self.scaling_noise is not None:
            curves = self.scaling_noise(curves, context)
        if self.shift_noise is not None:
            curves = self.shift_noise(curves, context)
        curves = self.poisson_noise(curves, context)

        if self.background_noise is not None:
            # NOTE(review): the other components are invoked via __call__ while this one calls .apply
            # directly; confirm ProcessData.__call__ delegates to apply before unifying.
            curves = self.background_noise.apply(curves, context)

        return curves
@@ -0,0 +1,66 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from reflectorch.data_generation.priors.params import Params
8
+ from reflectorch.data_generation.priors.base import PriorSampler
9
+ from reflectorch.data_generation.priors.no_constraints import BasicPriorSampler
10
+ from reflectorch.data_generation.priors.independent_priors import (
11
+ SingleParamPrior,
12
+ SimplePriorSampler,
13
+ UniformParamPrior,
14
+ GaussianParamPrior,
15
+ TruncatedGaussianParamPrior
16
+ )
17
+
18
+ from reflectorch.data_generation.priors.subprior_sampler import (
19
+ UniformSubPriorParams,
20
+ UniformSubPriorSampler,
21
+ NarrowSldUniformSubPriorSampler,
22
+ )
23
+ from reflectorch.data_generation.priors.exp_subprior_sampler import ExpUniformSubPriorSampler
24
+ from reflectorch.data_generation.priors.multilayer_structures import (
25
+ SimpleMultilayerSampler,
26
+ MultilayerStructureParams,
27
+ )
28
+ from reflectorch.data_generation.priors.parametric_models import (
29
+ ParametricModel,
30
+ MULTILAYER_MODELS,
31
+ )
32
+ from reflectorch.data_generation.priors.parametric_subpriors import (
33
+ SubpriorParametricSampler,
34
+ BasicParams,
35
+ )
36
+ from reflectorch.data_generation.priors.sampler_strategies import (
37
+ SamplerStrategy,
38
+ BasicSamplerStrategy,
39
+ ConstrainedRoughnessSamplerStrategy,
40
+ ConstrainedRoughnessAndImgSldSamplerStrategy,
41
+ )
42
+
43
# Names re-exported by this package, grouped by the module they are imported from
# (order kept as originally declared).
__all__ = [
    # independent_priors
    "SingleParamPrior",
    "SimplePriorSampler",
    "UniformParamPrior",
    "GaussianParamPrior",
    "TruncatedGaussianParamPrior",
    # params / base / no_constraints
    "Params",
    "PriorSampler",
    "BasicPriorSampler",
    # subprior_sampler / exp_subprior_sampler
    "UniformSubPriorParams",
    "UniformSubPriorSampler",
    "NarrowSldUniformSubPriorSampler",
    "ExpUniformSubPriorSampler",
    # multilayer_structures
    "SimpleMultilayerSampler",
    "MultilayerStructureParams",
    # parametric_subpriors / parametric_models
    "SubpriorParametricSampler",
    "BasicParams",
    "ParametricModel",
    "MULTILAYER_MODELS",
    # sampler_strategies
    "SamplerStrategy",
    "BasicSamplerStrategy",
    "ConstrainedRoughnessSamplerStrategy",
    "ConstrainedRoughnessAndImgSldSamplerStrategy",
]
@@ -0,0 +1,61 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import Tensor
8
+
9
+ from reflectorch.data_generation.priors.params import Params
10
+
11
# Public API of this module: only the abstract prior sampler base class.
__all__ = ["PriorSampler"]
14
+
15
+
16
class PriorSampler(object):
    """Base class for prior samplers.

    Subclasses provide the number of layers and the sampling / scaling logic; most methods here
    raise ``NotImplementedError``. ``PARAM_CLS`` is the parameter container class whose
    ``layers_num2size`` determines :attr:`param_dim`.
    """

    PARAM_CLS = Params

    @property
    def param_dim(self) -> int:
        """number of parameters (the parameter dimensionality) for ``max_num_layers`` layers"""
        layer_count = self.max_num_layers
        return self.PARAM_CLS.layers_num2size(layer_count)

    @property
    def max_num_layers(self) -> int:
        """maximum number of layers (must be provided by subclasses)"""
        raise NotImplementedError

    def sample(self, batch_size: int) -> Params:
        """draw a batch of ``batch_size`` parameter sets (abstract)"""
        raise NotImplementedError

    def scale_params(self, params: Params) -> Tensor:
        """map parameters to a ML-friendly range (abstract)"""
        raise NotImplementedError

    def restore_params(self, scaled_params: Tensor) -> Params:
        """inverse of :meth:`scale_params`: map scaled parameters back to their original range (abstract)"""
        raise NotImplementedError

    def log_prob(self, params: Params) -> Tensor:
        """abstract; expected to return the log-probability of ``params`` under the prior"""
        raise NotImplementedError

    def get_indices_within_domain(self, params: Params) -> Tensor:
        """abstract; returns an index/mask tensor used by :meth:`filter_params` to select ``params``"""
        raise NotImplementedError

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        """abstract; bounds counterpart of :meth:`get_indices_within_domain`"""
        raise NotImplementedError

    def filter_params(self, params: Params) -> Params:
        """keep only the entries of ``params`` selected by :meth:`get_indices_within_domain`"""
        return params[self.get_indices_within_domain(params)]

    def clamp_params(self, params: Params) -> Params:
        """abstract; clamp ``params`` (implemented by subclasses)"""
        raise NotImplementedError

    def __repr__(self):
        # every instance attribute, with its str() truncated to 10 characters
        fields = [f'{name}={str(value)[:10]}' for name, value in vars(self).items()]
        return f'{self.__class__.__name__}({", ".join(fields)})'