reflectorch 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (83)
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,304 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Tuple, Union, List
8
+
9
+ import torch
10
+ from torch import Tensor
11
+
12
+ from reflectorch.data_generation.utils import (
13
+ uniform_sampler,
14
+ logdist_sampler,
15
+ )
16
+
17
+ from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
18
+ from reflectorch.data_generation.priors.base import PriorSampler
19
+ from reflectorch.data_generation.priors.subprior_sampler import UniformSubPriorParams
20
+ from reflectorch.data_generation.priors.params import Params
21
+ from reflectorch.data_generation.priors.utils import get_max_allowed_roughness
22
+ from reflectorch.data_generation.priors.no_constraints import (
23
+ DEFAULT_DEVICE,
24
+ DEFAULT_DTYPE,
25
+ )
26
+
27
+
28
class ExpUniformSubPriorSampler(PriorSampler, ScalerMixin):
    """Prior sampler that draws random sub-prior bounds and parameters inside them.

    For every sample in a batch, a random interval ``[min_bound, max_bound]`` is drawn
    inside the global prior range of each fitted parameter, with a width between the
    configured min/max deltas. The parameter value is then drawn uniformly inside that
    interval. Fixed parameters keep their constant value and zero-width bounds.

    The parameter vector layout (see the ``torch.split`` in ``_split_layers``) is
    ``[thicknesses (L), roughnesses (L + 1), slds (L + 1)]`` for ``L`` layers.

    Args:
        params: one entry per parameter.
            - A number marks a fixed parameter.
            - A pair ``(min, max)`` is a fitted parameter. Its bound width is sampled
              between ``relative_min_bound_width * (max - min)`` and ``max - min``.
            - A 4-tuple ``(min, max, min_delta, max_delta)`` sets the width range
              explicitly.
        device: device of all tensors created by the sampler.
        dtype: dtype of all tensors created by the sampler.
        scaled_range: target range for scaled values. It is stored here; presumably
            consumed by ``ScalerMixin`` (TODO confirm).
        logdist: sample bound widths with ``logdist_sampler`` instead of
            ``uniform_sampler``.
        relative_min_bound_width: relative minimal width for ``(min, max)`` entries.
        smaller_roughnesses: additionally restrict fitted roughness bounds by
            ``get_max_allowed_roughness`` of the sampled thicknesses.
    """

    PARAM_CLS = UniformSubPriorParams

    def __init__(self,
                 params: List[Union[float, Tuple[float, float], Tuple[float, float, float, float]]],
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 scaled_range: Tuple[float, float] = (-1, 1),
                 logdist: bool = False,
                 relative_min_bound_width: float = 1e-4,
                 smaller_roughnesses: bool = True,
                 ):
        self.device = device
        self.dtype = dtype
        self.scaled_range = scaled_range
        self.relative_min_bound_width = relative_min_bound_width
        self.logdist = logdist
        self.smaller_roughnesses = smaller_roughnesses
        self._init_params(*params)

    @property
    def max_num_layers(self) -> int:
        """Number of layers ``L`` derived from the length of the parameter list."""
        return self.num_layers

    def _init_params(self, *params: Union[float, Tuple[float, float], Tuple[float, float, float, float]]):
        """Parse the parameter specification into bound, width and mask tensors.

        Sets ``num_layers``, ``fixed_mask``/``fitted_mask``, ``roughnesses_mask``,
        ``min_bounds``/``max_bounds`` (shape ``(P,)``), ``min_deltas``/``max_deltas``
        (shape ``(1, P)``) and ``fixed_params``, where ``P = len(params)``.
        """
        # L thicknesses + (L + 1) roughnesses + (L + 1) slds = 3L + 2 entries
        self.num_layers = (len(params) - 2) // 3
        self._total_num_params = len(params)

        fixed_mask = []
        bounds = []
        delta_bounds = []
        param_dim = 0

        for param in params:
            if isinstance(param, (float, int)):
                # fixed parameter: degenerate bounds and zero widths
                deltas = (0, 0)
                param = (param, param)
                fixed_mask.append(True)
            else:
                param_dim += 1
                fixed_mask.append(False)

                if len(param) == 4:
                    # explicit (min, max, min_delta, max_delta)
                    param, deltas = param[:2], param[2:]
                else:
                    max_delta = param[1] - param[0]
                    deltas = (max_delta * self.relative_min_bound_width, max_delta)

            bounds.append(param)
            delta_bounds.append(deltas)

        self.fixed_mask = torch.tensor(fixed_mask).to(self.device)
        self.fitted_mask = ~self.fixed_mask
        # roughnesses occupy indices [L, 2L + 1) of the parameter vector
        self.roughnesses_mask = torch.zeros_like(self.fitted_mask)
        self.roughnesses_mask[self.num_layers: self.num_layers * 2 + 1] = True

        self.min_bounds, self.max_bounds = torch.tensor(bounds).to(self.device).to(self.dtype).T
        self.min_deltas, self.max_deltas = map(
            torch.atleast_2d, torch.tensor(delta_bounds).to(self.min_bounds).T
        )
        self._param_dim = param_dim
        self._num_fixed = self.fixed_mask.sum()

        self.fixed_params = self.min_bounds[self.fixed_mask]

    @property
    def param_dim(self) -> int:
        """Number of fitted (non-fixed) parameters."""
        return self._param_dim

    def sample(self, batch_size: int) -> UniformSubPriorParams:
        """Draw ``batch_size`` parameter vectors together with their sampled bounds.

        Bounds are first drawn by ``sample_bounds`` and parameters uniformly inside them.
        With ``smaller_roughnesses`` enabled, the bounds and values of fitted roughnesses
        are then redrawn under the limit ``get_max_allowed_roughness(thicknesses)``.
        """
        min_bounds, max_bounds = self.sample_bounds(batch_size)
        params = self._uniform_between(min_bounds, max_bounds)

        if self.smaller_roughnesses:
            self._resample_fitted_roughnesses(params, min_bounds, max_bounds, batch_size)

        thicknesses, roughnesses, slds = self._split_layers(params)
        return UniformSubPriorParams(thicknesses, roughnesses, slds, min_bounds, max_bounds)

    def _split_layers(self, params: Tensor):
        """Split the last dim into thicknesses (L), roughnesses (L + 1) and slds (L + 1)."""
        num_layers = self.max_num_layers
        return torch.split(params, [num_layers, num_layers + 1, num_layers + 1], dim=-1)

    def _uniform_between(self, min_bounds: Tensor, max_bounds: Tensor) -> Tensor:
        """Sample elementwise uniformly between ``min_bounds`` and ``max_bounds``."""
        return torch.rand(
            *min_bounds.shape, device=self.device, dtype=self.dtype
        ) * (max_bounds - min_bounds) + min_bounds

    def _resample_fitted_roughnesses(self, params: Tensor, min_bounds: Tensor,
                                     max_bounds: Tensor, batch_size: int) -> None:
        """Redraw fitted roughness bounds/values in place, capped by the thicknesses."""
        thicknesses = self._split_layers(params)[0]
        fitted_r_mask = self.fitted_mask & self.roughnesses_mask

        min_roughness = self.min_bounds[fitted_r_mask]
        max_roughness = torch.clamp(
            get_max_allowed_roughness(thicknesses)[..., self.fitted_mask[self.roughnesses_mask]],
            min_roughness,
            self.max_bounds[fitted_r_mask],
        )

        # per-sample copies of the global prior ranges (repeat allocates new memory)
        min_vector = self.min_bounds[None].repeat(batch_size, 1)
        max_vector = self.max_bounds[None].repeat(batch_size, 1)
        max_vector[..., fitted_r_mask] = max_roughness

        assert torch.all(max_vector[..., fitted_r_mask] == max_roughness)

        min_deltas = self.min_deltas.repeat(batch_size, 1)
        max_deltas = self.max_deltas.repeat(batch_size, 1)
        # the roughness interval can never be wider than the (shrunk) allowed range
        max_deltas[..., fitted_r_mask] = torch.clamp_max(
            max_deltas[..., fitted_r_mask],
            max_roughness - min_roughness,
        )

        new_min_bounds, new_max_bounds = self._sample_bounds(
            batch_size, min_vector, max_vector, min_deltas, max_deltas, fitted_r_mask
        )
        min_bounds[..., fitted_r_mask] = new_min_bounds[..., fitted_r_mask]
        max_bounds[..., fitted_r_mask] = new_max_bounds[..., fitted_r_mask]

        params[..., fitted_r_mask] = self._uniform_between(
            min_bounds[..., fitted_r_mask], max_bounds[..., fitted_r_mask]
        )

    def scale_params(self, params: UniformSubPriorParams) -> Tensor:
        """Scale fitted parameters and their bounds.

        Parameters are scaled relative to their own sampled bounds. The bounds are
        scaled relative to the global prior bounds. Returns the concatenation
        ``[params, min_bounds, max_bounds]`` along the last dim.
        """
        params_t = params.as_tensor(add_bounds=False)

        scaled_params = self._scale(params_t, params.min_bounds, params.max_bounds)[..., self.fitted_mask]
        scaled_min_bounds = self._scale(params.min_bounds, self.min_bounds, self.max_bounds)[..., self.fitted_mask]
        scaled_max_bounds = self._scale(params.max_bounds, self.min_bounds, self.max_bounds)[..., self.fitted_mask]

        return torch.cat([scaled_params, scaled_min_bounds, scaled_max_bounds], -1)

    def scale_bounds(self, bounds: Tensor) -> Tensor:
        """Scale bounds relative to the global prior bounds (fitted entries only)."""
        return self._scale(bounds, self.min_bounds, self.max_bounds)[..., self.fitted_mask]

    def restore_params(self, scaled_params: Tensor) -> UniformSubPriorParams:
        """Invert ``scale_params``.

        Expects ``(batch, 3 * param_dim)``: scaled parameters, then scaled min bounds,
        then scaled max bounds. Fixed parameters are re-inserted from ``fixed_params``.
        """
        scaled_params, scaled_min_bounds, scaled_max_bounds = torch.split(
            scaled_params, [self.param_dim, self.param_dim, self.param_dim], dim=1
        )

        fitted_min = self.min_bounds[self.fitted_mask]
        fitted_max = self.max_bounds[self.fitted_mask]
        min_bounds = self._restore(scaled_min_bounds, fitted_min, fitted_max)
        max_bounds = self._restore(scaled_max_bounds, fitted_min, fitted_max)

        restored_params = self._restore(scaled_params, min_bounds, max_bounds)

        params_t = torch.cat(
            [
                self._cat_restored_with_fixed_vector(restored_params),
                self._cat_restored_with_fixed_vector(min_bounds),
                self._cat_restored_with_fixed_vector(max_bounds),
            ], -1
        )

        return UniformSubPriorParams.from_tensor(params_t)

    def _cat_restored_with_fixed_vector(self, restored_t: Tensor) -> Tensor:
        """Expand fitted-only columns to the full vector using ``fixed_params``."""
        return self._cat_fitted_fixed_t(restored_t, self.fixed_params)

    def _cat_fitted_fixed_t(self, fitted_t: Tensor, fixed_t: Tensor,
                            fitted_mask: Union[Tensor, None] = None) -> Tensor:
        """Scatter ``fitted_t`` (batch, n_fitted) and ``fixed_t`` (n_fixed,) into (batch, P).

        ``fitted_mask`` defaults to ``self.fitted_mask``. ``fixed_t`` is broadcast over
        the batch.
        """
        if fitted_mask is None:
            fitted_mask = self.fitted_mask
            fixed_mask = self.fixed_mask
        else:
            fixed_mask = ~fitted_mask

        batch_size = fitted_t.shape[0]

        concat_t = torch.empty(
            batch_size, fitted_mask.numel(), device=fitted_t.device, dtype=fitted_t.dtype
        )
        concat_t[:, fitted_mask] = fitted_t
        concat_t[:, fixed_mask] = fixed_t[None].expand(batch_size, -1)

        return concat_t

    def log_prob(self, params: UniformSubPriorParams) -> Tensor:
        """Return 0 for samples inside their own bounds and ``-inf`` otherwise."""
        log_prob = torch.zeros(params.batch_size, device=params.device, dtype=params.dtype)
        indices = self.get_indices_within_bounds(params)
        log_prob[~indices] = float('-inf')
        return log_prob

    def get_indices_within_bounds(self, params: UniformSubPriorParams) -> Tensor:
        """Boolean mask of samples whose values all lie within their own sampled bounds."""
        t_params = torch.cat([
            params.thicknesses,
            params.roughnesses,
            params.slds
        ], dim=-1)

        return (
            torch.all(t_params >= params.min_bounds, dim=-1) &
            torch.all(t_params <= params.max_bounds, dim=-1)
        )

    def clamp_params(self, params: UniformSubPriorParams) -> UniformSubPriorParams:
        """Clamp parameter values into their own sampled bounds (bounds unchanged)."""
        return UniformSubPriorParams.from_tensor(
            torch.cat([
                torch.clamp(
                    params.as_tensor(add_bounds=False),
                    params.min_bounds, params.max_bounds
                ),
                params.min_bounds, params.max_bounds
            ], dim=1)
        )

    def get_indices_within_domain(self, params: UniformSubPriorParams) -> Tensor:
        """Alias of ``get_indices_within_bounds``."""
        return self.get_indices_within_bounds(params)

    def sample_bounds(self, batch_size: int):
        """Sample ``(min_bounds, max_bounds)`` of shape (batch, P) within the global prior."""
        return self._sample_bounds(
            batch_size,
            self.min_bounds,
            self.max_bounds,
            self.min_deltas,
            self.max_deltas,
            self.fitted_mask,
        )

    def _sample_bounds(self, batch_size, min_vector, max_vector, min_deltas, max_deltas, fitted_mask):
        """Sample random sub-intervals of ``[min_vector, max_vector]``.

        Widths of entries selected by ``fitted_mask`` are drawn between ``min_deltas``
        and ``max_deltas``. Other entries get zero width. Centers are then drawn so that
        every interval stays inside ``[min_vector, max_vector]``.
        """
        widths_sampler_func = logdist_sampler if self.logdist else uniform_sampler

        num_fitted = fitted_mask.sum().item()
        num_fixed = fitted_mask.numel() - num_fitted

        prior_widths = widths_sampler_func(
            min_deltas[..., fitted_mask], max_deltas[..., fitted_mask],
            batch_size, num_fitted,
            device=self.device, dtype=self.dtype
        )
        prior_widths = self._cat_fitted_fixed_t(
            prior_widths, torch.zeros(num_fixed).to(prior_widths), fitted_mask
        )

        # A sampled width can exceed the available range in two cases:
        #  - min_deltas > max_deltas, which happens when `sample` shrinks the roughness
        #    max_deltas to max_roughness - min_roughness (possibly 0). This assumes the
        #    width samplers do not reorder an inverted range; TODO confirm.
        #  - user-given 4-tuple deltas are wider than (max - min).
        # Without this cap the center range below would be inverted and the bounds could
        # leave [min_vector, max_vector]. In the consistent case the cap is a no-op.
        prior_widths = torch.minimum(prior_widths, max_vector - min_vector)

        half_widths = prior_widths / 2
        prior_centers = uniform_sampler(
            min_vector + half_widths, max_vector - half_widths,
            *prior_widths.shape,
            device=self.device, dtype=self.dtype
        )

        return prior_centers - half_widths, prior_centers + half_widths

    @staticmethod
    def scale_bounds_with_q(bounds: Tensor, q_ratio: float) -> Tensor:
        """Apply ``Params.scale_with_q(q_ratio)`` to a bounds tensor and return it squeezed."""
        params = Params.from_tensor(torch.atleast_2d(bounds).clone())
        params.scale_with_q(q_ratio)
        return params.as_tensor().squeeze()

    def clamp_bounds(self, bounds: Tensor) -> Tensor:
        """Clamp bounds into the global prior range (result is at least 2-D)."""
        return torch.clamp(
            torch.atleast_2d(bounds), torch.atleast_2d(self.min_bounds), torch.atleast_2d(self.max_bounds)
        )
@@ -0,0 +1,201 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Union, Tuple
8
+ from math import log
9
+
10
+ import torch
11
+ from torch import Tensor
12
+
13
+ from reflectorch.data_generation.priors.base import PriorSampler
14
+ from reflectorch.data_generation.priors.params import Params
15
+ from reflectorch.data_generation.utils import get_param_labels
16
+
17
# Public API of this module: the single-parameter priors and the sampler combining them.
__all__ = [
    'SingleParamPrior',
    'UniformParamPrior',
    'GaussianParamPrior',
    'TruncatedGaussianParamPrior',
    'SimplePriorSampler',
]
25
+
26
+
27
class SingleParamPrior:
    """Interface for a prior over one scalar parameter.

    Subclasses implement sampling, log-probability and a scale/restore pair.
    ``to_conf`` exports the class name plus every public instance attribute.
    """

    def sample(self, batch_num: int, device=None, dtype=None):
        """Draw ``batch_num`` values (to be implemented by subclasses)."""
        raise NotImplementedError

    def log_prob(self, params: Tensor):
        """Elementwise log-probability of ``params`` (to be implemented by subclasses)."""
        raise NotImplementedError

    def to_conf(self):
        """Return ``{'cls': <class name>, 'kwargs': <public instance attributes>}``.

        Attributes whose name starts with ``_`` are left out.
        """
        public_attrs = {}
        for attr_name, attr_value in vars(self).items():
            if attr_name.startswith('_'):
                continue
            public_attrs[attr_name] = attr_value
        return {'cls': type(self).__name__, 'kwargs': public_attrs}

    def scale(self, params: Tensor) -> Tensor:
        """Map values to the network's scaled space (to be implemented by subclasses)."""
        raise NotImplementedError

    def restore(self, params: Tensor) -> Tensor:
        """Inverse of ``scale`` (to be implemented by subclasses)."""
        raise NotImplementedError
46
+
47
+
48
class SimplePriorSampler(PriorSampler):
    """Sampler treating every parameter as independent with its own 1-D prior.

    Each positional entry of ``params`` is either a ``SingleParamPrior`` or a
    ``(min, max)`` pair, which is wrapped into a ``UniformParamPrior``. Sampled columns
    are stacked along the last dimension and converted with ``PARAM_CLS.from_tensor``.

    Args:
        *params: one prior (or ``(min, max)`` pair) per parameter.
        device: forwarded to every ``SingleParamPrior.sample`` call.
        dtype: forwarded to every ``SingleParamPrior.sample`` call.
        param_cls: parameter container class overriding ``PARAM_CLS``.
        scaling_prior: if truthy, ``scale_params``/``restore_params`` are delegated to
            it. Its ``PARAM_CLS`` is used when ``param_cls`` is None.
    """

    def __init__(self,
                 *params: Union[Tuple[float, float], SingleParamPrior],
                 device=None, dtype=None, param_cls=None, scaling_prior: PriorSampler = None,
                 ):
        if param_cls is not None:
            self.PARAM_CLS = param_cls
        elif scaling_prior is not None:
            self.PARAM_CLS = scaling_prior.PARAM_CLS

        # PARAM_CLS has to be resolved before _init_params validates the length.
        self.param_priors: Tuple[SingleParamPrior, ...] = self._init_params(params)
        self._num_layers = self.PARAM_CLS.size2layers_num(len(params))
        self.device, self.dtype = device, dtype
        self.scaling_prior = scaling_prior

    @property
    def max_num_layers(self) -> int:
        """Number of layers implied by the number of parameter priors."""
        return self._num_layers

    def _init_params(self, params: Tuple[Union[Tuple[float, float], SingleParamPrior], ...]):
        """Check the parameter count and wrap ``(min, max)`` pairs as uniform priors."""
        param_cls = self.PARAM_CLS
        assert len(params) == param_cls.layers_num2size(param_cls.size2layers_num(len(params)))

        wrapped = []
        for param in params:
            if not isinstance(param, SingleParamPrior):
                param = UniformParamPrior(*param)
            wrapped.append(param)
        return tuple(wrapped)

    def sample(self, batch_size: int) -> Params:
        """Sample every parameter prior ``batch_size`` times and build ``PARAM_CLS``."""
        columns = [
            prior.sample(batch_size, device=self.device, dtype=self.dtype)
            for prior in self.param_priors
        ]
        return self.PARAM_CLS.from_tensor(torch.stack(columns, dim=-1))

    def scale_params(self, params: Params) -> Tensor:
        """Scale each parameter column with its own prior (or via ``scaling_prior``)."""
        if self.scaling_prior:
            return self.scaling_prior.scale_params(params)

        columns = params.as_tensor().T
        return torch.stack(
            [prior.scale(column) for column, prior in zip(columns, self.param_priors)], dim=-1
        )

    def restore_params(self, scaled_params: Tensor) -> Params:
        """Inverse of ``scale_params``."""
        if self.scaling_prior:
            return self.scaling_prior.restore_params(scaled_params)

        restored = [prior.restore(column) for column, prior in zip(scaled_params.T, self.param_priors)]
        return self.PARAM_CLS.from_tensor(torch.stack(restored, dim=-1))

    def log_prob(self, params: Params) -> Tensor:
        """Sum of the per-parameter log-probabilities for each sample."""
        per_param = [
            prior.log_prob(column)
            for column, prior in zip(params.as_tensor().T, self.param_priors)
        ]
        return torch.stack(per_param, dim=-1).sum(dim=1)

    def get_indices_within_domain(self, params: Params) -> Tensor:
        """Boolean mask of samples with a finite log-probability."""
        return torch.isfinite(self.log_prob(params))

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        """Alias of ``get_indices_within_domain``."""
        return self.get_indices_within_domain(params)

    def __repr__(self):
        labels = get_param_labels(self.PARAM_CLS.size2layers_num(len(self.param_priors)))
        lines = [f'{label}: {prior}' for label, prior in zip(labels, self.param_priors)]
        prior_str = '\n\t'.join(lines)
        return f'SimplePriorSampler(\n\t{prior_str}\n)'
121
+
122
+
123
class UniformParamPrior(SingleParamPrior):
    """Uniform prior on ``[min_value, max_value]`` for a single parameter.

    ``scale`` maps the interval onto ``[0, 1]`` and ``restore`` inverts it.

    Args:
        min_value: lower end of the support; must be smaller than ``max_value``.
        max_value: upper end of the support.
        device: default used by ``sample`` when none is passed.
        dtype: default used by ``sample`` when none is passed.
    """

    def __init__(self, min_value: float, max_value: float, device=None, dtype=None):
        assert min_value < max_value
        self.min_value, self.max_value = min_value, max_value
        # Kept private: `to_conf` exports every public attribute as an __init__ kwarg,
        # and __init__ accepts no `delta` argument.
        self._delta = max_value - min_value
        self._log_prob_const = - log(self._delta)
        self.device = device
        self.dtype = dtype

    @property
    def delta(self) -> float:
        """Width of the support, ``max_value - min_value``."""
        return self._delta

    def sample(self, batch_num: int, device=None, dtype=None):
        """Draw ``batch_num`` uniform values.

        ``None`` device/dtype fall back to the instance defaults. An explicit ``is None``
        check is used so that a falsy device such as CUDA ordinal ``0`` is honoured.
        """
        device = self.device if device is None else device
        dtype = self.dtype if dtype is None else dtype
        return torch.rand(batch_num, device=device, dtype=dtype) * self._delta + self.min_value

    def log_prob(self, params: Tensor):
        """Return ``-log(delta)`` inside the support and ``-inf`` outside (elementwise)."""
        log_probs = torch.full_like(params, self._log_prob_const)
        log_probs[(params < self.min_value) | (params > self.max_value)] = - float('inf')
        return log_probs

    def scale(self, params: Tensor) -> Tensor:
        """Map ``[min_value, max_value]`` onto ``[0, 1]``."""
        return (params - self.min_value) / self._delta

    def restore(self, params: Tensor) -> Tensor:
        """Inverse of ``scale``."""
        return params * self._delta + self.min_value

    def __repr__(self):
        return f'UniformParamPrior(min={self.min_value}, max={self.max_value})'
150
+
151
+
152
class GaussianParamPrior(SingleParamPrior):
    """Gaussian prior ``N(mean, std**2)`` for a single parameter.

    ``scale`` maps ``mean +/- _GAUSS_SCALE_CONST * std`` onto ``[-1, 1]``.

    Args:
        mean: mean of the distribution.
        std: standard deviation; must be positive.
        device: default used by ``sample`` when none is passed.
        dtype: default used by ``sample`` when none is passed.
    """

    _GAUSS_SCALE_CONST: float = 4.

    def __init__(self, mean: float, std: float, device=None, dtype=None):
        assert std > 0
        self.mean = mean
        self.std = std
        self.device = device
        self.dtype = dtype

    def sample(self, batch_num: int, device=None, dtype=None):
        """Draw ``batch_num`` normal values.

        ``None`` device/dtype fall back to the instance defaults. An explicit ``is None``
        check is used so that a falsy device such as CUDA ordinal ``0`` is honoured.
        """
        return torch.normal(
            self.mean, self.std, (batch_num,),
            device=self.device if device is None else device,
            dtype=self.dtype if dtype is None else dtype,
        )

    def log_prob(self, params: Tensor):
        """Unnormalised log-density ``-((x - mean) / std)**2 / 2`` (constant omitted)."""
        return - ((params - self.mean) / self.std) ** 2 / 2

    def scale(self, params: Tensor) -> Tensor:
        """Standardise and divide by ``_GAUSS_SCALE_CONST``."""
        return (params - self.mean) / (self.std * self._GAUSS_SCALE_CONST)

    def restore(self, params: Tensor) -> Tensor:
        """Inverse of ``scale``."""
        return params * self.std * self._GAUSS_SCALE_CONST + self.mean

    def __repr__(self):
        return f'GaussianParamPrior(mean={self.mean}, std={self.std})'
181
+
182
+
183
class TruncatedGaussianParamPrior(GaussianParamPrior):
    """Gaussian prior restricted to non-negative values (rejection sampling)."""

    # Upper limit on redraw rounds in `sample`. The former recursive version failed with
    # RecursionError (a RuntimeError subclass) in the same pathological regime.
    _MAX_RESAMPLING_ROUNDS: int = 10_000

    def sample(self, batch_num: int, device=None, dtype=None) -> Tensor:
        """Draw ``batch_num`` non-negative values from the truncated Gaussian.

        Negative draws are redrawn until none remain. A loop replaces the former
        recursion, so a prior with most of its mass below zero no longer exhausts the
        Python stack.

        Raises:
            RuntimeError: if negative values remain after ``_MAX_RESAMPLING_ROUNDS``
                redraw rounds.
        """
        device = self.device if device is None else device
        dtype = self.dtype if dtype is None else dtype

        params = torch.normal(self.mean, self.std, (batch_num,), device=device, dtype=dtype)
        negative_params = params < 0.

        for _ in range(self._MAX_RESAMPLING_ROUNDS):
            num_negative_params = int(negative_params.sum().item())
            if not num_negative_params:
                return params
            params[negative_params] = torch.normal(
                self.mean, self.std, (num_negative_params,), device=device, dtype=dtype
            )
            negative_params = params < 0.

        if negative_params.any():
            raise RuntimeError(
                f'TruncatedGaussianParamPrior(mean={self.mean}, std={self.std}): '
                f'could not draw non-negative samples after {self._MAX_RESAMPLING_ROUNDS} rounds'
            )
        return params

    def log_prob(self, params: Tensor) -> Tensor:
        """Unnormalised Gaussian log-density, ``-inf`` for negative values."""
        log_probs: Tensor = - ((params - self.mean) / self.std) ** 2 / 2
        log_probs[params < 0.] = - float('inf')
        return log_probs

    def __repr__(self):
        return f'TruncatedGaussianParamPrior(mean={self.mean}, std={self.std})'