reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -128
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -280
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -223
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -1374
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +36 -36
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +523 -516
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -19
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -262
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -200
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -15
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -19
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -434
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -404
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +97 -97
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.4.0.dist-info/RECORD +0 -88
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,104 +1,104 @@
|
|
|
1
|
-
from typing import Tuple, Dict
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import torch
|
|
5
|
-
from torch import Tensor
|
|
6
|
-
|
|
7
|
-
from reflectorch.data_generation.priors.base import PriorSampler
|
|
8
|
-
from reflectorch.data_generation.priors.params import Params
|
|
9
|
-
from reflectorch.data_generation.priors.no_constraints import (
|
|
10
|
-
DEFAULT_DEVICE,
|
|
11
|
-
DEFAULT_DTYPE,
|
|
12
|
-
)
|
|
13
|
-
|
|
14
|
-
from reflectorch.data_generation.priors.multilayer_models import MULTILAYER_MODELS, MultilayerModel
|
|
15
|
-
from reflectorch.utils import to_t
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class MultilayerStructureParams(Params):
|
|
19
|
-
pass
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class SimpleMultilayerSampler(PriorSampler):
|
|
23
|
-
PARAM_CLS = MultilayerStructureParams
|
|
24
|
-
|
|
25
|
-
def __init__(self,
|
|
26
|
-
params: Dict[str, Tuple[float, float]],
|
|
27
|
-
model_name: str,
|
|
28
|
-
device: torch.device = DEFAULT_DEVICE,
|
|
29
|
-
dtype: torch.dtype = DEFAULT_DTYPE,
|
|
30
|
-
max_num_layers: int = 50,
|
|
31
|
-
):
|
|
32
|
-
self.multilayer_model: MultilayerModel = MULTILAYER_MODELS[model_name](max_num_layers)
|
|
33
|
-
self.device = device
|
|
34
|
-
self.dtype = dtype
|
|
35
|
-
self.num_layers = max_num_layers
|
|
36
|
-
ordered_bounds = [params[k] for k in self.multilayer_model.PARAMETER_NAMES]
|
|
37
|
-
self._np_bounds = np.array(ordered_bounds).T
|
|
38
|
-
self.min_bounds, self.max_bounds = torch.tensor(ordered_bounds, device=device, dtype=dtype).T[:, None]
|
|
39
|
-
self._param_dim = len(params)
|
|
40
|
-
|
|
41
|
-
@property
|
|
42
|
-
def max_num_layers(self) -> int:
|
|
43
|
-
return self.num_layers
|
|
44
|
-
|
|
45
|
-
@property
|
|
46
|
-
def param_dim(self) -> int:
|
|
47
|
-
return self._param_dim
|
|
48
|
-
|
|
49
|
-
def sample(self, batch_size: int) -> MultilayerStructureParams:
|
|
50
|
-
return self.optimized_sample(batch_size)[0]
|
|
51
|
-
|
|
52
|
-
def optimized_sample(self, batch_size: int) -> Tuple[MultilayerStructureParams, Tensor]:
|
|
53
|
-
scaled_params = torch.rand(
|
|
54
|
-
batch_size,
|
|
55
|
-
self.min_bounds.shape[-1],
|
|
56
|
-
device=self.min_bounds.device,
|
|
57
|
-
dtype=self.min_bounds.dtype,
|
|
58
|
-
)
|
|
59
|
-
|
|
60
|
-
targets = self.restore_params(scaled_params)
|
|
61
|
-
|
|
62
|
-
return targets, scaled_params
|
|
63
|
-
|
|
64
|
-
def get_np_bounds(self):
|
|
65
|
-
return np.array(self._np_bounds)
|
|
66
|
-
|
|
67
|
-
def restore_np_params(self, params: np.ndarray):
|
|
68
|
-
p = self.multilayer_model.to_standard_params(
|
|
69
|
-
torch.atleast_2d(to_t(params))
|
|
70
|
-
)
|
|
71
|
-
|
|
72
|
-
return {
|
|
73
|
-
'thickness': p['thicknesses'].squeeze().cpu().numpy(),
|
|
74
|
-
'roughness': p['roughnesses'].squeeze().cpu().numpy(),
|
|
75
|
-
'sld': p['slds'].squeeze().cpu().numpy()
|
|
76
|
-
}
|
|
77
|
-
|
|
78
|
-
def restore_params2parametrized(self, scaled_params: Tensor) -> Tensor:
|
|
79
|
-
return scaled_params * (self.max_bounds - self.min_bounds) + self.min_bounds
|
|
80
|
-
|
|
81
|
-
def restore_params(self, scaled_params: Tensor) -> MultilayerStructureParams:
|
|
82
|
-
return self.to_standard_params(self.restore_params2parametrized(scaled_params))
|
|
83
|
-
|
|
84
|
-
def to_standard_params(self, params: Tensor) -> MultilayerStructureParams:
|
|
85
|
-
return MultilayerStructureParams(**self.multilayer_model.to_standard_params(params))
|
|
86
|
-
|
|
87
|
-
def scale_params(self, params: Params) -> Tensor:
|
|
88
|
-
raise NotImplementedError
|
|
89
|
-
|
|
90
|
-
def log_prob(self, params: Params) -> Tensor:
|
|
91
|
-
raise NotImplementedError
|
|
92
|
-
|
|
93
|
-
def get_indices_within_domain(self, params: Params) -> Tensor:
|
|
94
|
-
raise NotImplementedError
|
|
95
|
-
|
|
96
|
-
def get_indices_within_bounds(self, params: Params) -> Tensor:
|
|
97
|
-
raise NotImplementedError
|
|
98
|
-
|
|
99
|
-
def filter_params(self, params: Params) -> Params:
|
|
100
|
-
indices = self.get_indices_within_domain(params)
|
|
101
|
-
return params[indices]
|
|
102
|
-
|
|
103
|
-
def clamp_params(self, params: Params) -> Params:
|
|
104
|
-
raise NotImplementedError
|
|
1
|
+
from typing import Tuple, Dict
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from reflectorch.data_generation.priors.base import PriorSampler
|
|
8
|
+
from reflectorch.data_generation.priors.params import Params
|
|
9
|
+
from reflectorch.data_generation.priors.no_constraints import (
|
|
10
|
+
DEFAULT_DEVICE,
|
|
11
|
+
DEFAULT_DTYPE,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from reflectorch.data_generation.priors.multilayer_models import MULTILAYER_MODELS, MultilayerModel
|
|
15
|
+
from reflectorch.utils import to_t
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class MultilayerStructureParams(Params):
|
|
19
|
+
pass
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SimpleMultilayerSampler(PriorSampler):
|
|
23
|
+
PARAM_CLS = MultilayerStructureParams
|
|
24
|
+
|
|
25
|
+
def __init__(self,
|
|
26
|
+
params: Dict[str, Tuple[float, float]],
|
|
27
|
+
model_name: str,
|
|
28
|
+
device: torch.device = DEFAULT_DEVICE,
|
|
29
|
+
dtype: torch.dtype = DEFAULT_DTYPE,
|
|
30
|
+
max_num_layers: int = 50,
|
|
31
|
+
):
|
|
32
|
+
self.multilayer_model: MultilayerModel = MULTILAYER_MODELS[model_name](max_num_layers)
|
|
33
|
+
self.device = device
|
|
34
|
+
self.dtype = dtype
|
|
35
|
+
self.num_layers = max_num_layers
|
|
36
|
+
ordered_bounds = [params[k] for k in self.multilayer_model.PARAMETER_NAMES]
|
|
37
|
+
self._np_bounds = np.array(ordered_bounds).T
|
|
38
|
+
self.min_bounds, self.max_bounds = torch.tensor(ordered_bounds, device=device, dtype=dtype).T[:, None]
|
|
39
|
+
self._param_dim = len(params)
|
|
40
|
+
|
|
41
|
+
@property
|
|
42
|
+
def max_num_layers(self) -> int:
|
|
43
|
+
return self.num_layers
|
|
44
|
+
|
|
45
|
+
@property
|
|
46
|
+
def param_dim(self) -> int:
|
|
47
|
+
return self._param_dim
|
|
48
|
+
|
|
49
|
+
def sample(self, batch_size: int) -> MultilayerStructureParams:
|
|
50
|
+
return self.optimized_sample(batch_size)[0]
|
|
51
|
+
|
|
52
|
+
def optimized_sample(self, batch_size: int) -> Tuple[MultilayerStructureParams, Tensor]:
|
|
53
|
+
scaled_params = torch.rand(
|
|
54
|
+
batch_size,
|
|
55
|
+
self.min_bounds.shape[-1],
|
|
56
|
+
device=self.min_bounds.device,
|
|
57
|
+
dtype=self.min_bounds.dtype,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
targets = self.restore_params(scaled_params)
|
|
61
|
+
|
|
62
|
+
return targets, scaled_params
|
|
63
|
+
|
|
64
|
+
def get_np_bounds(self):
|
|
65
|
+
return np.array(self._np_bounds)
|
|
66
|
+
|
|
67
|
+
def restore_np_params(self, params: np.ndarray):
|
|
68
|
+
p = self.multilayer_model.to_standard_params(
|
|
69
|
+
torch.atleast_2d(to_t(params))
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
return {
|
|
73
|
+
'thickness': p['thicknesses'].squeeze().cpu().numpy(),
|
|
74
|
+
'roughness': p['roughnesses'].squeeze().cpu().numpy(),
|
|
75
|
+
'sld': p['slds'].squeeze().cpu().numpy()
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
def restore_params2parametrized(self, scaled_params: Tensor) -> Tensor:
|
|
79
|
+
return scaled_params * (self.max_bounds - self.min_bounds) + self.min_bounds
|
|
80
|
+
|
|
81
|
+
def restore_params(self, scaled_params: Tensor) -> MultilayerStructureParams:
|
|
82
|
+
return self.to_standard_params(self.restore_params2parametrized(scaled_params))
|
|
83
|
+
|
|
84
|
+
def to_standard_params(self, params: Tensor) -> MultilayerStructureParams:
|
|
85
|
+
return MultilayerStructureParams(**self.multilayer_model.to_standard_params(params))
|
|
86
|
+
|
|
87
|
+
def scale_params(self, params: Params) -> Tensor:
|
|
88
|
+
raise NotImplementedError
|
|
89
|
+
|
|
90
|
+
def log_prob(self, params: Params) -> Tensor:
|
|
91
|
+
raise NotImplementedError
|
|
92
|
+
|
|
93
|
+
def get_indices_within_domain(self, params: Params) -> Tensor:
|
|
94
|
+
raise NotImplementedError
|
|
95
|
+
|
|
96
|
+
def get_indices_within_bounds(self, params: Params) -> Tensor:
|
|
97
|
+
raise NotImplementedError
|
|
98
|
+
|
|
99
|
+
def filter_params(self, params: Params) -> Params:
|
|
100
|
+
indices = self.get_indices_within_domain(params)
|
|
101
|
+
return params[indices]
|
|
102
|
+
|
|
103
|
+
def clamp_params(self, params: Params) -> Params:
|
|
104
|
+
raise NotImplementedError
|
|
@@ -1,206 +1,206 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
from functools import lru_cache
|
|
3
|
-
from typing import Tuple
|
|
4
|
-
from math import sqrt
|
|
5
|
-
|
|
6
|
-
import torch
|
|
7
|
-
from torch import Tensor
|
|
8
|
-
|
|
9
|
-
from reflectorch.data_generation.utils import (
|
|
10
|
-
get_slds_from_d_rhos,
|
|
11
|
-
uniform_sampler,
|
|
12
|
-
)
|
|
13
|
-
|
|
14
|
-
from reflectorch.data_generation.priors.utils import (
|
|
15
|
-
get_allowed_roughness_indices,
|
|
16
|
-
generate_roughnesses,
|
|
17
|
-
params_within_bounds,
|
|
18
|
-
)
|
|
19
|
-
from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
|
|
20
|
-
from reflectorch.data_generation.priors.base import PriorSampler
|
|
21
|
-
from reflectorch.data_generation.priors.params import Params
|
|
22
|
-
|
|
23
|
-
__all__ = [
|
|
24
|
-
"BasicPriorSampler",
|
|
25
|
-
"DEFAULT_ROUGHNESS_RANGE",
|
|
26
|
-
"DEFAULT_THICKNESS_RANGE",
|
|
27
|
-
"DEFAULT_SLD_RANGE",
|
|
28
|
-
"DEFAULT_NUM_LAYERS",
|
|
29
|
-
"DEFAULT_DEVICE",
|
|
30
|
-
"DEFAULT_DTYPE",
|
|
31
|
-
"DEFAULT_SCALED_RANGE",
|
|
32
|
-
"DEFAULT_USE_DRHO",
|
|
33
|
-
]
|
|
34
|
-
|
|
35
|
-
DEFAULT_THICKNESS_RANGE: Tuple[float, float] = (1., 500.)
|
|
36
|
-
DEFAULT_ROUGHNESS_RANGE: Tuple[float, float] = (0., 50.)
|
|
37
|
-
DEFAULT_SLD_RANGE: Tuple[float, float] = (-10., 30.)
|
|
38
|
-
DEFAULT_NUM_LAYERS: int = 5
|
|
39
|
-
DEFAULT_USE_DRHO: bool = False
|
|
40
|
-
DEFAULT_DEVICE: torch.device = torch.device('cuda')
|
|
41
|
-
DEFAULT_DTYPE: torch.dtype = torch.float64
|
|
42
|
-
DEFAULT_SCALED_RANGE: Tuple[float, float] = (-sqrt(3.), sqrt(3.))
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
class BasicPriorSampler(PriorSampler, ScalerMixin):
|
|
46
|
-
"""Prior samplers for thicknesses, roughnesses and slds"""
|
|
47
|
-
def __init__(self,
|
|
48
|
-
thickness_range: Tuple[float, float] = DEFAULT_THICKNESS_RANGE,
|
|
49
|
-
roughness_range: Tuple[float, float] = DEFAULT_ROUGHNESS_RANGE,
|
|
50
|
-
sld_range: Tuple[float, float] = DEFAULT_SLD_RANGE,
|
|
51
|
-
num_layers: int = DEFAULT_NUM_LAYERS,
|
|
52
|
-
use_drho: bool = DEFAULT_USE_DRHO,
|
|
53
|
-
device: torch.device = DEFAULT_DEVICE,
|
|
54
|
-
dtype: torch.dtype = DEFAULT_DTYPE,
|
|
55
|
-
scaled_range: Tuple[float, float] = DEFAULT_SCALED_RANGE,
|
|
56
|
-
restrict_roughnesses: bool = True,
|
|
57
|
-
):
|
|
58
|
-
self.logger = logging.getLogger(__name__)
|
|
59
|
-
self.thickness_range = thickness_range
|
|
60
|
-
self.roughness_range = roughness_range
|
|
61
|
-
self.sld_range = sld_range
|
|
62
|
-
self.num_layers = num_layers
|
|
63
|
-
self.device = device
|
|
64
|
-
self.dtype = dtype
|
|
65
|
-
self.scaled_range = scaled_range
|
|
66
|
-
self.use_drho = use_drho
|
|
67
|
-
self.restrict_roughnesses = restrict_roughnesses
|
|
68
|
-
|
|
69
|
-
@property
|
|
70
|
-
def max_num_layers(self) -> int:
|
|
71
|
-
return self.num_layers
|
|
72
|
-
|
|
73
|
-
@lru_cache()
|
|
74
|
-
def min_vector(self, layers_num, drho: bool = False):
|
|
75
|
-
if drho:
|
|
76
|
-
sld_min = self.sld_range[0] - self.sld_range[1]
|
|
77
|
-
else:
|
|
78
|
-
sld_min = self.sld_range[0]
|
|
79
|
-
|
|
80
|
-
return torch.tensor(
|
|
81
|
-
[self.thickness_range[0]] * layers_num +
|
|
82
|
-
[self.roughness_range[0]] * (layers_num + 1) +
|
|
83
|
-
[sld_min] * (layers_num + 1),
|
|
84
|
-
device=self.device,
|
|
85
|
-
dtype=self.dtype
|
|
86
|
-
)
|
|
87
|
-
|
|
88
|
-
@lru_cache()
|
|
89
|
-
def max_vector(self, layers_num, drho: bool = False):
|
|
90
|
-
if drho:
|
|
91
|
-
sld_max = self.sld_range[1] - self.sld_range[0]
|
|
92
|
-
else:
|
|
93
|
-
sld_max = self.sld_range[1]
|
|
94
|
-
return torch.tensor(
|
|
95
|
-
[self.thickness_range[1]] * layers_num +
|
|
96
|
-
[self.roughness_range[1]] * (layers_num + 1) +
|
|
97
|
-
[sld_max] * (layers_num + 1),
|
|
98
|
-
device=self.device,
|
|
99
|
-
dtype=self.dtype
|
|
100
|
-
)
|
|
101
|
-
|
|
102
|
-
@lru_cache()
|
|
103
|
-
def delta_vector(self, layers_num, drho: bool = False):
|
|
104
|
-
return self._get_delta_vector(self.min_vector(layers_num, drho), self.max_vector(layers_num, drho))
|
|
105
|
-
|
|
106
|
-
def restore_params(self, scaled_params: Tensor) -> Params:
|
|
107
|
-
layers_num = self.PARAM_CLS.size2layers_num(scaled_params.shape[-1])
|
|
108
|
-
|
|
109
|
-
params_t = self._restore(
|
|
110
|
-
scaled_params,
|
|
111
|
-
self.min_vector(layers_num, drho=self.use_drho).to(scaled_params),
|
|
112
|
-
self.max_vector(layers_num, drho=self.use_drho).to(scaled_params),
|
|
113
|
-
)
|
|
114
|
-
|
|
115
|
-
params = self.PARAM_CLS.from_tensor(params_t)
|
|
116
|
-
|
|
117
|
-
if self.use_drho:
|
|
118
|
-
params.slds = get_slds_from_d_rhos(params.slds)
|
|
119
|
-
|
|
120
|
-
return params
|
|
121
|
-
|
|
122
|
-
def scale_params(self, params: Params) -> Tensor:
|
|
123
|
-
layers_num = params.max_layer_num
|
|
124
|
-
|
|
125
|
-
return self._scale(
|
|
126
|
-
params.as_tensor(use_drho=self.use_drho),
|
|
127
|
-
self.min_vector(layers_num, drho=self.use_drho).to(params.thicknesses),
|
|
128
|
-
self.max_vector(layers_num, drho=self.use_drho).to(params.thicknesses),
|
|
129
|
-
)
|
|
130
|
-
|
|
131
|
-
def get_indices_within_bounds(self, params: Params) -> Tensor:
|
|
132
|
-
layer_num = params.max_layer_num
|
|
133
|
-
|
|
134
|
-
return params_within_bounds(
|
|
135
|
-
params.as_tensor(),
|
|
136
|
-
self.min_vector(layer_num),
|
|
137
|
-
self.max_vector(layer_num),
|
|
138
|
-
)
|
|
139
|
-
|
|
140
|
-
def get_indices_within_domain(self, params: Params) -> Tensor:
|
|
141
|
-
if self.restrict_roughnesses:
|
|
142
|
-
indices = (
|
|
143
|
-
self.get_indices_within_bounds(params) &
|
|
144
|
-
self.get_allowed_roughness_indices(params)
|
|
145
|
-
)
|
|
146
|
-
else:
|
|
147
|
-
indices = self.get_indices_within_bounds(params)
|
|
148
|
-
return indices
|
|
149
|
-
|
|
150
|
-
def clamp_params(self, params: Params) -> Params:
|
|
151
|
-
layer_num = params.max_layer_num
|
|
152
|
-
params = params.as_tensor()
|
|
153
|
-
params = torch.clamp(
|
|
154
|
-
params,
|
|
155
|
-
self.min_vector(layer_num),
|
|
156
|
-
self.max_vector(layer_num),
|
|
157
|
-
)
|
|
158
|
-
params = Params.from_tensor(params)
|
|
159
|
-
return params
|
|
160
|
-
|
|
161
|
-
@staticmethod
|
|
162
|
-
def get_allowed_roughness_indices(params: Params) -> Tensor:
|
|
163
|
-
return get_allowed_roughness_indices(params.thicknesses, params.roughnesses)
|
|
164
|
-
|
|
165
|
-
def log_prob(self, params: Params) -> Tensor:
|
|
166
|
-
# so far we ignore non-uniform distribution of roughnesses and slds.
|
|
167
|
-
log_prob = torch.zeros(params.batch_size, device=params.device, dtype=params.dtype)
|
|
168
|
-
indices = self.get_indices_within_bounds(params)
|
|
169
|
-
log_prob[~indices] = float('-inf')
|
|
170
|
-
return log_prob
|
|
171
|
-
|
|
172
|
-
def sample(self, batch_size: int) -> Params:
|
|
173
|
-
slds = self.generate_slds(batch_size)
|
|
174
|
-
thicknesses = self.generate_thicknesses(batch_size)
|
|
175
|
-
roughnesses = self.generate_roughnesses(thicknesses)
|
|
176
|
-
|
|
177
|
-
params = Params(thicknesses, roughnesses, slds)
|
|
178
|
-
|
|
179
|
-
return params
|
|
180
|
-
|
|
181
|
-
def generate_slds(self, batch_size: int):
|
|
182
|
-
return uniform_sampler(
|
|
183
|
-
*self.sld_range, batch_size,
|
|
184
|
-
self.num_layers + 1,
|
|
185
|
-
device=self.device,
|
|
186
|
-
dtype=self.dtype
|
|
187
|
-
)
|
|
188
|
-
|
|
189
|
-
def generate_thicknesses(self, batch_size: int):
|
|
190
|
-
return uniform_sampler(
|
|
191
|
-
*self.thickness_range, batch_size,
|
|
192
|
-
self.num_layers,
|
|
193
|
-
device=self.device,
|
|
194
|
-
dtype=self.dtype
|
|
195
|
-
)
|
|
196
|
-
|
|
197
|
-
def generate_roughnesses(self, thicknesses: Tensor) -> Tensor:
|
|
198
|
-
if self.restrict_roughnesses:
|
|
199
|
-
return generate_roughnesses(thicknesses, self.roughness_range)
|
|
200
|
-
else:
|
|
201
|
-
return uniform_sampler(
|
|
202
|
-
*self.roughness_range, thicknesses.shape[0],
|
|
203
|
-
self.num_layers + 1,
|
|
204
|
-
device=self.device,
|
|
205
|
-
dtype=self.dtype
|
|
206
|
-
)
|
|
1
|
+
import logging
|
|
2
|
+
from functools import lru_cache
|
|
3
|
+
from typing import Tuple
|
|
4
|
+
from math import sqrt
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
from reflectorch.data_generation.utils import (
|
|
10
|
+
get_slds_from_d_rhos,
|
|
11
|
+
uniform_sampler,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from reflectorch.data_generation.priors.utils import (
|
|
15
|
+
get_allowed_roughness_indices,
|
|
16
|
+
generate_roughnesses,
|
|
17
|
+
params_within_bounds,
|
|
18
|
+
)
|
|
19
|
+
from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
|
|
20
|
+
from reflectorch.data_generation.priors.base import PriorSampler
|
|
21
|
+
from reflectorch.data_generation.priors.params import Params
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"BasicPriorSampler",
|
|
25
|
+
"DEFAULT_ROUGHNESS_RANGE",
|
|
26
|
+
"DEFAULT_THICKNESS_RANGE",
|
|
27
|
+
"DEFAULT_SLD_RANGE",
|
|
28
|
+
"DEFAULT_NUM_LAYERS",
|
|
29
|
+
"DEFAULT_DEVICE",
|
|
30
|
+
"DEFAULT_DTYPE",
|
|
31
|
+
"DEFAULT_SCALED_RANGE",
|
|
32
|
+
"DEFAULT_USE_DRHO",
|
|
33
|
+
]
|
|
34
|
+
|
|
35
|
+
DEFAULT_THICKNESS_RANGE: Tuple[float, float] = (1., 500.)
|
|
36
|
+
DEFAULT_ROUGHNESS_RANGE: Tuple[float, float] = (0., 50.)
|
|
37
|
+
DEFAULT_SLD_RANGE: Tuple[float, float] = (-10., 30.)
|
|
38
|
+
DEFAULT_NUM_LAYERS: int = 5
|
|
39
|
+
DEFAULT_USE_DRHO: bool = False
|
|
40
|
+
DEFAULT_DEVICE: torch.device = torch.device('cuda')
|
|
41
|
+
DEFAULT_DTYPE: torch.dtype = torch.float64
|
|
42
|
+
DEFAULT_SCALED_RANGE: Tuple[float, float] = (-sqrt(3.), sqrt(3.))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class BasicPriorSampler(PriorSampler, ScalerMixin):
|
|
46
|
+
"""Prior samplers for thicknesses, roughnesses and slds"""
|
|
47
|
+
def __init__(self,
|
|
48
|
+
thickness_range: Tuple[float, float] = DEFAULT_THICKNESS_RANGE,
|
|
49
|
+
roughness_range: Tuple[float, float] = DEFAULT_ROUGHNESS_RANGE,
|
|
50
|
+
sld_range: Tuple[float, float] = DEFAULT_SLD_RANGE,
|
|
51
|
+
num_layers: int = DEFAULT_NUM_LAYERS,
|
|
52
|
+
use_drho: bool = DEFAULT_USE_DRHO,
|
|
53
|
+
device: torch.device = DEFAULT_DEVICE,
|
|
54
|
+
dtype: torch.dtype = DEFAULT_DTYPE,
|
|
55
|
+
scaled_range: Tuple[float, float] = DEFAULT_SCALED_RANGE,
|
|
56
|
+
restrict_roughnesses: bool = True,
|
|
57
|
+
):
|
|
58
|
+
self.logger = logging.getLogger(__name__)
|
|
59
|
+
self.thickness_range = thickness_range
|
|
60
|
+
self.roughness_range = roughness_range
|
|
61
|
+
self.sld_range = sld_range
|
|
62
|
+
self.num_layers = num_layers
|
|
63
|
+
self.device = device
|
|
64
|
+
self.dtype = dtype
|
|
65
|
+
self.scaled_range = scaled_range
|
|
66
|
+
self.use_drho = use_drho
|
|
67
|
+
self.restrict_roughnesses = restrict_roughnesses
|
|
68
|
+
|
|
69
|
+
@property
|
|
70
|
+
def max_num_layers(self) -> int:
|
|
71
|
+
return self.num_layers
|
|
72
|
+
|
|
73
|
+
@lru_cache()
|
|
74
|
+
def min_vector(self, layers_num, drho: bool = False):
|
|
75
|
+
if drho:
|
|
76
|
+
sld_min = self.sld_range[0] - self.sld_range[1]
|
|
77
|
+
else:
|
|
78
|
+
sld_min = self.sld_range[0]
|
|
79
|
+
|
|
80
|
+
return torch.tensor(
|
|
81
|
+
[self.thickness_range[0]] * layers_num +
|
|
82
|
+
[self.roughness_range[0]] * (layers_num + 1) +
|
|
83
|
+
[sld_min] * (layers_num + 1),
|
|
84
|
+
device=self.device,
|
|
85
|
+
dtype=self.dtype
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
@lru_cache()
|
|
89
|
+
def max_vector(self, layers_num, drho: bool = False):
|
|
90
|
+
if drho:
|
|
91
|
+
sld_max = self.sld_range[1] - self.sld_range[0]
|
|
92
|
+
else:
|
|
93
|
+
sld_max = self.sld_range[1]
|
|
94
|
+
return torch.tensor(
|
|
95
|
+
[self.thickness_range[1]] * layers_num +
|
|
96
|
+
[self.roughness_range[1]] * (layers_num + 1) +
|
|
97
|
+
[sld_max] * (layers_num + 1),
|
|
98
|
+
device=self.device,
|
|
99
|
+
dtype=self.dtype
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
@lru_cache()
|
|
103
|
+
def delta_vector(self, layers_num, drho: bool = False):
|
|
104
|
+
return self._get_delta_vector(self.min_vector(layers_num, drho), self.max_vector(layers_num, drho))
|
|
105
|
+
|
|
106
|
+
def restore_params(self, scaled_params: Tensor) -> Params:
|
|
107
|
+
layers_num = self.PARAM_CLS.size2layers_num(scaled_params.shape[-1])
|
|
108
|
+
|
|
109
|
+
params_t = self._restore(
|
|
110
|
+
scaled_params,
|
|
111
|
+
self.min_vector(layers_num, drho=self.use_drho).to(scaled_params),
|
|
112
|
+
self.max_vector(layers_num, drho=self.use_drho).to(scaled_params),
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
params = self.PARAM_CLS.from_tensor(params_t)
|
|
116
|
+
|
|
117
|
+
if self.use_drho:
|
|
118
|
+
params.slds = get_slds_from_d_rhos(params.slds)
|
|
119
|
+
|
|
120
|
+
return params
|
|
121
|
+
|
|
122
|
+
def scale_params(self, params: Params) -> Tensor:
|
|
123
|
+
layers_num = params.max_layer_num
|
|
124
|
+
|
|
125
|
+
return self._scale(
|
|
126
|
+
params.as_tensor(use_drho=self.use_drho),
|
|
127
|
+
self.min_vector(layers_num, drho=self.use_drho).to(params.thicknesses),
|
|
128
|
+
self.max_vector(layers_num, drho=self.use_drho).to(params.thicknesses),
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
def get_indices_within_bounds(self, params: Params) -> Tensor:
|
|
132
|
+
layer_num = params.max_layer_num
|
|
133
|
+
|
|
134
|
+
return params_within_bounds(
|
|
135
|
+
params.as_tensor(),
|
|
136
|
+
self.min_vector(layer_num),
|
|
137
|
+
self.max_vector(layer_num),
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
def get_indices_within_domain(self, params: Params) -> Tensor:
|
|
141
|
+
if self.restrict_roughnesses:
|
|
142
|
+
indices = (
|
|
143
|
+
self.get_indices_within_bounds(params) &
|
|
144
|
+
self.get_allowed_roughness_indices(params)
|
|
145
|
+
)
|
|
146
|
+
else:
|
|
147
|
+
indices = self.get_indices_within_bounds(params)
|
|
148
|
+
return indices
|
|
149
|
+
|
|
150
|
+
def clamp_params(self, params: Params) -> Params:
|
|
151
|
+
layer_num = params.max_layer_num
|
|
152
|
+
params = params.as_tensor()
|
|
153
|
+
params = torch.clamp(
|
|
154
|
+
params,
|
|
155
|
+
self.min_vector(layer_num),
|
|
156
|
+
self.max_vector(layer_num),
|
|
157
|
+
)
|
|
158
|
+
params = Params.from_tensor(params)
|
|
159
|
+
return params
|
|
160
|
+
|
|
161
|
+
@staticmethod
|
|
162
|
+
def get_allowed_roughness_indices(params: Params) -> Tensor:
|
|
163
|
+
return get_allowed_roughness_indices(params.thicknesses, params.roughnesses)
|
|
164
|
+
|
|
165
|
+
def log_prob(self, params: Params) -> Tensor:
|
|
166
|
+
# so far we ignore non-uniform distribution of roughnesses and slds.
|
|
167
|
+
log_prob = torch.zeros(params.batch_size, device=params.device, dtype=params.dtype)
|
|
168
|
+
indices = self.get_indices_within_bounds(params)
|
|
169
|
+
log_prob[~indices] = float('-inf')
|
|
170
|
+
return log_prob
|
|
171
|
+
|
|
172
|
+
def sample(self, batch_size: int) -> Params:
|
|
173
|
+
slds = self.generate_slds(batch_size)
|
|
174
|
+
thicknesses = self.generate_thicknesses(batch_size)
|
|
175
|
+
roughnesses = self.generate_roughnesses(thicknesses)
|
|
176
|
+
|
|
177
|
+
params = Params(thicknesses, roughnesses, slds)
|
|
178
|
+
|
|
179
|
+
return params
|
|
180
|
+
|
|
181
|
+
def generate_slds(self, batch_size: int):
|
|
182
|
+
return uniform_sampler(
|
|
183
|
+
*self.sld_range, batch_size,
|
|
184
|
+
self.num_layers + 1,
|
|
185
|
+
device=self.device,
|
|
186
|
+
dtype=self.dtype
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
def generate_thicknesses(self, batch_size: int):
|
|
190
|
+
return uniform_sampler(
|
|
191
|
+
*self.thickness_range, batch_size,
|
|
192
|
+
self.num_layers,
|
|
193
|
+
device=self.device,
|
|
194
|
+
dtype=self.dtype
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
def generate_roughnesses(self, thicknesses: Tensor) -> Tensor:
|
|
198
|
+
if self.restrict_roughnesses:
|
|
199
|
+
return generate_roughnesses(thicknesses, self.roughness_range)
|
|
200
|
+
else:
|
|
201
|
+
return uniform_sampler(
|
|
202
|
+
*self.roughness_range, thicknesses.shape[0],
|
|
203
|
+
self.num_layers + 1,
|
|
204
|
+
device=self.device,
|
|
205
|
+
dtype=self.dtype
|
|
206
|
+
)
|