reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -128
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -280
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -223
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -1374
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +36 -36
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +523 -516
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -19
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -262
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -200
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -15
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -19
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -434
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -404
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +97 -97
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.4.0.dist-info/RECORD +0 -88
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,195 +1,195 @@
|
|
|
1
|
-
from typing import Union, Tuple
|
|
2
|
-
from math import log
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
from torch import Tensor
|
|
6
|
-
|
|
7
|
-
from reflectorch.data_generation.priors.base import PriorSampler
|
|
8
|
-
from reflectorch.data_generation.priors.params import Params
|
|
9
|
-
from reflectorch.data_generation.utils import get_param_labels
|
|
10
|
-
|
|
11
|
-
# Public names exported by this module: the single-parameter prior
# distributions and the sampler that combines one prior per parameter.
__all__ = [
    'SingleParamPrior',
    'UniformParamPrior',
    'GaussianParamPrior',
    'TruncatedGaussianParamPrior',
    'SimplePriorSampler',
]
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class SingleParamPrior:
    """Interface for a prior distribution over a single scalar parameter.

    Subclasses provide sampling, a log-density, and a pair of inverse
    transforms (``scale`` / ``restore``) between raw and scaled values.
    Every method here except ``to_conf`` only raises ``NotImplementedError``.
    """

    def sample(self, batch_num: int, device=None, dtype=None):
        """Draw ``batch_num`` values from the prior (subclass responsibility)."""
        raise NotImplementedError

    def log_prob(self, params: Tensor):
        """Return the log-density evaluated at ``params`` (subclass responsibility)."""
        raise NotImplementedError

    def to_conf(self):
        """Describe this prior as ``{'cls': <class name>, 'kwargs': <public attributes>}``.

        Every instance attribute whose name does not start with an underscore
        is copied into ``kwargs``; underscore-prefixed attributes are treated
        as private and left out.
        """
        public_attrs = dict(
            (name, value) for name, value in vars(self).items() if not name.startswith('_')
        )
        return dict(cls=type(self).__name__, kwargs=public_attrs)

    def scale(self, params: Tensor) -> Tensor:
        """Map raw parameter values to the scaled space (subclass responsibility)."""
        raise NotImplementedError

    def restore(self, params: Tensor) -> Tensor:
        """Invert :meth:`scale`, mapping scaled values back (subclass responsibility)."""
        raise NotImplementedError
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
class SimplePriorSampler(PriorSampler):
    """Prior sampler made of one independent prior per parameter.

    Each positional argument describes one parameter, in the order expected by
    ``PARAM_CLS``. It is either a ready-made :class:`SingleParamPrior` or a
    ``(min_value, max_value)`` tuple, which is wrapped into a
    :class:`UniformParamPrior`.

    Args:
        *params: one prior, or ``(min, max)`` bounds, per parameter.
        device: device forwarded to every per-parameter ``sample`` call.
        dtype: dtype forwarded to every per-parameter ``sample`` call.
        param_cls: parameter container class; overrides ``PARAM_CLS`` when given.
        scaling_prior: if given, ``scale_params`` / ``restore_params`` are
            delegated to it. Its ``PARAM_CLS`` is used when ``param_cls`` is
            not given.
    """

    def __init__(self,
                 *params: Union[Tuple[float, float], SingleParamPrior],
                 device=None, dtype=None, param_cls=None, scaling_prior: PriorSampler = None,
                 ):
        # An explicit param_cls wins; otherwise inherit the scaling prior's.
        # If neither is given, the class-level PARAM_CLS applies (presumably
        # defined on PriorSampler -- not visible here).
        if param_cls is not None:
            self.PARAM_CLS = param_cls
        elif scaling_prior is not None:
            self.PARAM_CLS = scaling_prior.PARAM_CLS
        self.param_priors: Tuple[SingleParamPrior, ...] = self._init_params(params)
        self._num_layers = self.PARAM_CLS.size2layers_num(len(params))
        self.device = device
        self.dtype = dtype
        self.scaling_prior = scaling_prior

    @property
    def max_num_layers(self) -> int:
        """Number of layers implied by the number of parameter priors."""
        return self._num_layers

    def _init_params(self, params: Tuple[Union[Tuple[float, float], SingleParamPrior], ...]):
        """Validate the parameter count and wrap ``(min, max)`` tuples into uniform priors.

        Raises:
            AssertionError: if ``len(params)`` does not round-trip through
                ``PARAM_CLS.size2layers_num`` / ``layers_num2size``.
        """
        num_layers = self.PARAM_CLS.size2layers_num(len(params))
        assert len(params) == self.PARAM_CLS.layers_num2size(num_layers), (
            f'Got {len(params)} parameter priors, which does not match '
            f'a valid parameter count for {num_layers} layers'
        )
        params = tuple(
            param if isinstance(param, SingleParamPrior) else UniformParamPrior(*param)
            for param in params
        )
        return params

    def sample(self, batch_size: int) -> Params:
        """Draw ``batch_size`` values from every prior and pack them into ``PARAM_CLS``."""
        # Each prior returns a 1-D tensor of length batch_size; stacking on the
        # last dim gives one column per parameter.
        t_params = torch.stack([
            param.sample(batch_size, device=self.device, dtype=self.dtype) for param in self.param_priors
        ], -1)
        params = self.PARAM_CLS.from_tensor(t_params)
        return params

    def scale_params(self, params: Params) -> Tensor:
        """Scale every parameter column with its own prior (or via ``scaling_prior``)."""
        if self.scaling_prior is not None:
            return self.scaling_prior.scale_params(params)

        # Assumes as_tensor() uses the same (batch, num_params) layout that
        # sample() produces -- TODO confirm.
        t_params = params.as_tensor()

        scaled_params = torch.stack(
            [param_prior.scale(param) for param, param_prior in zip(t_params.T, self.param_priors)], -1
        )

        return scaled_params

    def restore_params(self, scaled_params: Tensor) -> Params:
        """Invert :meth:`scale_params` and wrap the result in ``PARAM_CLS``."""
        if self.scaling_prior is not None:
            return self.scaling_prior.restore_params(scaled_params)
        t_params = torch.stack(
            [param_prior.restore(param) for param, param_prior in zip(scaled_params.T, self.param_priors)], -1
        )
        return self.PARAM_CLS.from_tensor(t_params)

    def log_prob(self, params: Params) -> Tensor:
        """Joint log-probability per sample: the sum of per-parameter log-probs (priors are independent)."""
        t_params = params.as_tensor()
        log_probs = torch.stack(
            [param_prior.log_prob(param) for param, param_prior in zip(t_params.T, self.param_priors)], -1
        )
        return log_probs.sum(1)

    def get_indices_within_domain(self, params: Params) -> Tensor:
        """Boolean mask of samples whose joint log-probability is finite."""
        log_probs = self.log_prob(params)
        return torch.isfinite(log_probs)

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        """Alias of :meth:`get_indices_within_domain`."""
        return self.get_indices_within_domain(params)

    def __repr__(self):
        labels = get_param_labels(self.max_num_layers)
        prior_str = '\n\t'.join(f'{label}: {param_prior}' for label, param_prior in zip(labels, self.param_priors))
        return f'SimplePriorSampler(\n\t{prior_str}\n)'
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
class UniformParamPrior(SingleParamPrior):
    """Uniform prior on the closed interval ``[min_value, max_value]``.

    Args:
        min_value: lower bound of the support.
        max_value: upper bound of the support; must be greater than ``min_value``.
        device: default device for :meth:`sample`.
        dtype: default dtype for :meth:`sample`.
    """

    def __init__(self, min_value: float, max_value: float, device=None, dtype=None):
        assert min_value < max_value
        self.min_value, self.max_value, self.delta = min_value, max_value, max_value - min_value
        # Log-density of U(min, max) inside the support: -log(max - min).
        self._lob_prob_const = - log(self.delta)
        self.device = device
        self.dtype = dtype

    def to_conf(self):
        """Return ``{'cls': ..., 'kwargs': ...}`` holding only constructor arguments.

        ``delta`` is derived from the bounds and is not an ``__init__``
        argument. It is dropped so that ``UniformParamPrior(**conf['kwargs'])``
        does not raise ``TypeError``.
        """
        conf = super().to_conf()
        conf['kwargs'].pop('delta', None)
        return conf

    def sample(self, batch_num: int, device=None, dtype=None):
        """Draw ``batch_num`` values uniformly from ``[min_value, max_value)``."""
        params = torch.rand(
            batch_num, device=(device or self.device), dtype=(dtype or self.dtype)
        ) * self.delta + self.min_value
        return params

    def log_prob(self, params: Tensor):
        """Return ``-log(delta)`` inside the bounds and ``-inf`` outside, elementwise."""
        log_probs = torch.full_like(params, self._lob_prob_const)
        log_probs[(params < self.min_value) | (params > self.max_value)] = - float('inf')
        return log_probs

    def scale(self, params: Tensor) -> Tensor:
        """Map ``[min_value, max_value]`` linearly onto ``[0, 1]``."""
        return (params - self.min_value) / self.delta

    def restore(self, params: Tensor) -> Tensor:
        """Invert :meth:`scale`."""
        return params * self.delta + self.min_value

    def __repr__(self):
        return f'UniformParamPrior(min={self.min_value}, max={self.max_value})'
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
class GaussianParamPrior(SingleParamPrior):
    """Normal prior ``N(mean, std**2)`` on a single parameter.

    ``scale`` standardises values and further divides by ``_GAUSS_SCALE_CONST``.
    A value ``_GAUSS_SCALE_CONST`` standard deviations from the mean therefore
    maps to +/-1.
    """

    _GAUSS_SCALE_CONST: float = 4.

    def __init__(self, mean: float, std: float, device=None, dtype=None):
        assert std > 0
        self.mean, self.std = mean, std
        self.device, self.dtype = device, dtype

    def sample(self, batch_num: int, device=None, dtype=None):
        """Draw ``batch_num`` normally distributed values."""
        shape = (batch_num,)
        return torch.normal(
            self.mean, self.std, shape,
            device=(device or self.device), dtype=(dtype or self.dtype),
        )

    def log_prob(self, params: Tensor):
        """Unnormalised Gaussian log-density (the additive constant is omitted)."""
        z = (params - self.mean) / self.std
        return - z ** 2 / 2

    def scale(self, params: Tensor) -> Tensor:
        """Standardise and shrink by ``_GAUSS_SCALE_CONST``."""
        width = self.std * self._GAUSS_SCALE_CONST
        return (params - self.mean) / width

    def restore(self, params: Tensor) -> Tensor:
        """Invert :meth:`scale`."""
        return params * self.std * self._GAUSS_SCALE_CONST + self.mean

    def __repr__(self):
        return f'GaussianParamPrior(mean={self.mean}, std={self.std})'
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
class TruncatedGaussianParamPrior(GaussianParamPrior):
    """Gaussian prior truncated to non-negative values (``params >= 0``).

    Sampling is done by rejection: negative draws are redrawn from the
    untruncated Gaussian until none remain. ``log_prob`` is unnormalised and
    equals ``-inf`` for negative values.
    """

    # Upper bound on rejection rounds in ``sample``. It guards against an
    # effectively endless loop when almost all of the mass lies below zero.
    _MAX_RESAMPLE_ROUNDS: int = 100000

    def sample(self, batch_num: int, device=None, dtype=None) -> Tensor:
        """Draw ``batch_num`` non-negative values by rejection sampling.

        Raises:
            RuntimeError: if negative draws remain after ``_MAX_RESAMPLE_ROUNDS``
                redraw rounds.
        """
        device = device or self.device
        dtype = dtype or self.dtype
        params = torch.normal(self.mean, self.std, (batch_num,), device=device, dtype=dtype)
        negative_params = params < 0.
        rounds = 0
        # Iterative form of the former recursive redraw: it makes the same
        # sequence of torch.normal calls, but without growing the call stack.
        while True:
            num_negative_params = int(negative_params.sum().item())
            if not num_negative_params:
                return params
            if rounds >= self._MAX_RESAMPLE_ROUNDS:
                raise RuntimeError(
                    f'TruncatedGaussianParamPrior(mean={self.mean}, std={self.std}): '
                    f'{num_negative_params} samples still negative after {rounds} resampling rounds'
                )
            params[negative_params] = torch.normal(
                self.mean, self.std, (num_negative_params,), device=device, dtype=dtype
            )
            negative_params = params < 0.
            rounds += 1

    def log_prob(self, params: Tensor) -> Tensor:
        """Unnormalised log-density: Gaussian term for ``params >= 0``, ``-inf`` below zero."""
        # ignore constant
        log_probs: Tensor = - ((params - self.mean) / self.std) ** 2 / 2
        log_probs[params < 0.] = - float('inf')
        return log_probs

    def __repr__(self):
        return f'TruncatedGaussianParamPrior(mean={self.mean}, std={self.std})'
|
|
1
|
+
from typing import Union, Tuple
|
|
2
|
+
from math import log
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from reflectorch.data_generation.priors.base import PriorSampler
|
|
8
|
+
from reflectorch.data_generation.priors.params import Params
|
|
9
|
+
from reflectorch.data_generation.utils import get_param_labels
|
|
10
|
+
|
|
11
|
+
# Public names exported by this module: the single-parameter prior
# distributions and the sampler that combines one prior per parameter.
__all__ = [
    'SingleParamPrior',
    'UniformParamPrior',
    'GaussianParamPrior',
    'TruncatedGaussianParamPrior',
    'SimplePriorSampler',
]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SingleParamPrior:
    """Interface for a prior distribution over a single scalar parameter.

    Subclasses provide sampling, a log-density, and a pair of inverse
    transforms (``scale`` / ``restore``) between raw and scaled values.
    Every method here except ``to_conf`` only raises ``NotImplementedError``.
    """

    def sample(self, batch_num: int, device=None, dtype=None):
        """Draw ``batch_num`` values from the prior (subclass responsibility)."""
        raise NotImplementedError

    def log_prob(self, params: Tensor):
        """Return the log-density evaluated at ``params`` (subclass responsibility)."""
        raise NotImplementedError

    def to_conf(self):
        """Describe this prior as ``{'cls': <class name>, 'kwargs': <public attributes>}``.

        Every instance attribute whose name does not start with an underscore
        is copied into ``kwargs``; underscore-prefixed attributes are treated
        as private and left out.
        """
        public_attrs = dict(
            (name, value) for name, value in vars(self).items() if not name.startswith('_')
        )
        return dict(cls=type(self).__name__, kwargs=public_attrs)

    def scale(self, params: Tensor) -> Tensor:
        """Map raw parameter values to the scaled space (subclass responsibility)."""
        raise NotImplementedError

    def restore(self, params: Tensor) -> Tensor:
        """Invert :meth:`scale`, mapping scaled values back (subclass responsibility)."""
        raise NotImplementedError
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class SimplePriorSampler(PriorSampler):
    """Prior sampler made of one independent prior per parameter.

    Each positional argument describes one parameter, in the order expected by
    ``PARAM_CLS``. It is either a ready-made :class:`SingleParamPrior` or a
    ``(min_value, max_value)`` tuple, which is wrapped into a
    :class:`UniformParamPrior`.

    Args:
        *params: one prior, or ``(min, max)`` bounds, per parameter.
        device: device forwarded to every per-parameter ``sample`` call.
        dtype: dtype forwarded to every per-parameter ``sample`` call.
        param_cls: parameter container class; overrides ``PARAM_CLS`` when given.
        scaling_prior: if given, ``scale_params`` / ``restore_params`` are
            delegated to it. Its ``PARAM_CLS`` is used when ``param_cls`` is
            not given.
    """

    def __init__(self,
                 *params: Union[Tuple[float, float], SingleParamPrior],
                 device=None, dtype=None, param_cls=None, scaling_prior: PriorSampler = None,
                 ):
        # An explicit param_cls wins; otherwise inherit the scaling prior's.
        # If neither is given, the class-level PARAM_CLS applies (presumably
        # defined on PriorSampler -- not visible here).
        if param_cls is not None:
            self.PARAM_CLS = param_cls
        elif scaling_prior is not None:
            self.PARAM_CLS = scaling_prior.PARAM_CLS
        self.param_priors: Tuple[SingleParamPrior, ...] = self._init_params(params)
        self._num_layers = self.PARAM_CLS.size2layers_num(len(params))
        self.device = device
        self.dtype = dtype
        self.scaling_prior = scaling_prior

    @property
    def max_num_layers(self) -> int:
        """Number of layers implied by the number of parameter priors."""
        return self._num_layers

    def _init_params(self, params: Tuple[Union[Tuple[float, float], SingleParamPrior], ...]):
        """Validate the parameter count and wrap ``(min, max)`` tuples into uniform priors.

        Raises:
            AssertionError: if ``len(params)`` does not round-trip through
                ``PARAM_CLS.size2layers_num`` / ``layers_num2size``.
        """
        num_layers = self.PARAM_CLS.size2layers_num(len(params))
        assert len(params) == self.PARAM_CLS.layers_num2size(num_layers), (
            f'Got {len(params)} parameter priors, which does not match '
            f'a valid parameter count for {num_layers} layers'
        )
        params = tuple(
            param if isinstance(param, SingleParamPrior) else UniformParamPrior(*param)
            for param in params
        )
        return params

    def sample(self, batch_size: int) -> Params:
        """Draw ``batch_size`` values from every prior and pack them into ``PARAM_CLS``."""
        # Each prior returns a 1-D tensor of length batch_size; stacking on the
        # last dim gives one column per parameter.
        t_params = torch.stack([
            param.sample(batch_size, device=self.device, dtype=self.dtype) for param in self.param_priors
        ], -1)
        params = self.PARAM_CLS.from_tensor(t_params)
        return params

    def scale_params(self, params: Params) -> Tensor:
        """Scale every parameter column with its own prior (or via ``scaling_prior``)."""
        if self.scaling_prior is not None:
            return self.scaling_prior.scale_params(params)

        # Assumes as_tensor() uses the same (batch, num_params) layout that
        # sample() produces -- TODO confirm.
        t_params = params.as_tensor()

        scaled_params = torch.stack(
            [param_prior.scale(param) for param, param_prior in zip(t_params.T, self.param_priors)], -1
        )

        return scaled_params

    def restore_params(self, scaled_params: Tensor) -> Params:
        """Invert :meth:`scale_params` and wrap the result in ``PARAM_CLS``."""
        if self.scaling_prior is not None:
            return self.scaling_prior.restore_params(scaled_params)
        t_params = torch.stack(
            [param_prior.restore(param) for param, param_prior in zip(scaled_params.T, self.param_priors)], -1
        )
        return self.PARAM_CLS.from_tensor(t_params)

    def log_prob(self, params: Params) -> Tensor:
        """Joint log-probability per sample: the sum of per-parameter log-probs (priors are independent)."""
        t_params = params.as_tensor()
        log_probs = torch.stack(
            [param_prior.log_prob(param) for param, param_prior in zip(t_params.T, self.param_priors)], -1
        )
        return log_probs.sum(1)

    def get_indices_within_domain(self, params: Params) -> Tensor:
        """Boolean mask of samples whose joint log-probability is finite."""
        log_probs = self.log_prob(params)
        return torch.isfinite(log_probs)

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        """Alias of :meth:`get_indices_within_domain`."""
        return self.get_indices_within_domain(params)

    def __repr__(self):
        labels = get_param_labels(self.max_num_layers)
        prior_str = '\n\t'.join(f'{label}: {param_prior}' for label, param_prior in zip(labels, self.param_priors))
        return f'SimplePriorSampler(\n\t{prior_str}\n)'
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class UniformParamPrior(SingleParamPrior):
    """Uniform prior on the closed interval ``[min_value, max_value]``.

    Args:
        min_value: lower bound of the support.
        max_value: upper bound of the support; must be greater than ``min_value``.
        device: default device for :meth:`sample`.
        dtype: default dtype for :meth:`sample`.
    """

    def __init__(self, min_value: float, max_value: float, device=None, dtype=None):
        assert min_value < max_value
        self.min_value, self.max_value, self.delta = min_value, max_value, max_value - min_value
        # Log-density of U(min, max) inside the support: -log(max - min).
        self._lob_prob_const = - log(self.delta)
        self.device = device
        self.dtype = dtype

    def to_conf(self):
        """Return ``{'cls': ..., 'kwargs': ...}`` holding only constructor arguments.

        ``delta`` is derived from the bounds and is not an ``__init__``
        argument. It is dropped so that ``UniformParamPrior(**conf['kwargs'])``
        does not raise ``TypeError``.
        """
        conf = super().to_conf()
        conf['kwargs'].pop('delta', None)
        return conf

    def sample(self, batch_num: int, device=None, dtype=None):
        """Draw ``batch_num`` values uniformly from ``[min_value, max_value)``."""
        params = torch.rand(
            batch_num, device=(device or self.device), dtype=(dtype or self.dtype)
        ) * self.delta + self.min_value
        return params

    def log_prob(self, params: Tensor):
        """Return ``-log(delta)`` inside the bounds and ``-inf`` outside, elementwise."""
        log_probs = torch.full_like(params, self._lob_prob_const)
        log_probs[(params < self.min_value) | (params > self.max_value)] = - float('inf')
        return log_probs

    def scale(self, params: Tensor) -> Tensor:
        """Map ``[min_value, max_value]`` linearly onto ``[0, 1]``."""
        return (params - self.min_value) / self.delta

    def restore(self, params: Tensor) -> Tensor:
        """Invert :meth:`scale`."""
        return params * self.delta + self.min_value

    def __repr__(self):
        return f'UniformParamPrior(min={self.min_value}, max={self.max_value})'
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class GaussianParamPrior(SingleParamPrior):
    """Normal prior ``N(mean, std**2)`` on a single parameter.

    ``scale`` standardises values and further divides by ``_GAUSS_SCALE_CONST``.
    A value ``_GAUSS_SCALE_CONST`` standard deviations from the mean therefore
    maps to +/-1.
    """

    _GAUSS_SCALE_CONST: float = 4.

    def __init__(self, mean: float, std: float, device=None, dtype=None):
        assert std > 0
        self.mean, self.std = mean, std
        self.device, self.dtype = device, dtype

    def sample(self, batch_num: int, device=None, dtype=None):
        """Draw ``batch_num`` normally distributed values."""
        shape = (batch_num,)
        return torch.normal(
            self.mean, self.std, shape,
            device=(device or self.device), dtype=(dtype or self.dtype),
        )

    def log_prob(self, params: Tensor):
        """Unnormalised Gaussian log-density (the additive constant is omitted)."""
        z = (params - self.mean) / self.std
        return - z ** 2 / 2

    def scale(self, params: Tensor) -> Tensor:
        """Standardise and shrink by ``_GAUSS_SCALE_CONST``."""
        width = self.std * self._GAUSS_SCALE_CONST
        return (params - self.mean) / width

    def restore(self, params: Tensor) -> Tensor:
        """Invert :meth:`scale`."""
        return params * self.std * self._GAUSS_SCALE_CONST + self.mean

    def __repr__(self):
        return f'GaussianParamPrior(mean={self.mean}, std={self.std})'
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class TruncatedGaussianParamPrior(GaussianParamPrior):
    """Gaussian prior truncated to non-negative values (``params >= 0``).

    Sampling is done by rejection: negative draws are redrawn from the
    untruncated Gaussian until none remain. ``log_prob`` is unnormalised and
    equals ``-inf`` for negative values.
    """

    # Upper bound on rejection rounds in ``sample``. It guards against an
    # effectively endless loop when almost all of the mass lies below zero.
    _MAX_RESAMPLE_ROUNDS: int = 100000

    def sample(self, batch_num: int, device=None, dtype=None) -> Tensor:
        """Draw ``batch_num`` non-negative values by rejection sampling.

        Raises:
            RuntimeError: if negative draws remain after ``_MAX_RESAMPLE_ROUNDS``
                redraw rounds.
        """
        device = device or self.device
        dtype = dtype or self.dtype
        params = torch.normal(self.mean, self.std, (batch_num,), device=device, dtype=dtype)
        negative_params = params < 0.
        rounds = 0
        # Iterative form of the former recursive redraw: it makes the same
        # sequence of torch.normal calls, but without growing the call stack.
        while True:
            num_negative_params = int(negative_params.sum().item())
            if not num_negative_params:
                return params
            if rounds >= self._MAX_RESAMPLE_ROUNDS:
                raise RuntimeError(
                    f'TruncatedGaussianParamPrior(mean={self.mean}, std={self.std}): '
                    f'{num_negative_params} samples still negative after {rounds} resampling rounds'
                )
            params[negative_params] = torch.normal(
                self.mean, self.std, (num_negative_params,), device=device, dtype=dtype
            )
            negative_params = params < 0.
            rounds += 1

    def log_prob(self, params: Tensor) -> Tensor:
        """Unnormalised log-density: Gaussian term for ``params >= 0``, ``-inf`` below zero."""
        # ignore constant
        log_probs: Tensor = - ((params - self.mean) / self.std) ** 2 / 2
        log_probs[params < 0.] = - float('inf')
        return log_probs

    def __repr__(self):
        return f'TruncatedGaussianParamPrior(mean={self.mean}, std={self.std})'
|