reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of reflectorch has been flagged as a potentially problematic release.
Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -128
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -280
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -223
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -1374
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +36 -36
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +523 -516
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -19
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -262
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -200
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -15
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -19
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -434
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -404
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +97 -97
  91. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  94. reflectorch-1.4.0.dist-info/RECORD +0 -88
  95. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
  96. {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,104 +1,104 @@ (all 104 lines removed and re-added with identical content; the file reads:)

from typing import Tuple, Dict

import numpy as np
import torch
from torch import Tensor

from reflectorch.data_generation.priors.base import PriorSampler
from reflectorch.data_generation.priors.params import Params
from reflectorch.data_generation.priors.no_constraints import (
    DEFAULT_DEVICE,
    DEFAULT_DTYPE,
)

from reflectorch.data_generation.priors.multilayer_models import MULTILAYER_MODELS, MultilayerModel
from reflectorch.utils import to_t


class MultilayerStructureParams(Params):
    pass


class SimpleMultilayerSampler(PriorSampler):
    PARAM_CLS = MultilayerStructureParams

    def __init__(self,
                 params: Dict[str, Tuple[float, float]],
                 model_name: str,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 max_num_layers: int = 50,
                 ):
        self.multilayer_model: MultilayerModel = MULTILAYER_MODELS[model_name](max_num_layers)
        self.device = device
        self.dtype = dtype
        self.num_layers = max_num_layers
        ordered_bounds = [params[k] for k in self.multilayer_model.PARAMETER_NAMES]
        self._np_bounds = np.array(ordered_bounds).T
        self.min_bounds, self.max_bounds = torch.tensor(ordered_bounds, device=device, dtype=dtype).T[:, None]
        self._param_dim = len(params)

    @property
    def max_num_layers(self) -> int:
        return self.num_layers

    @property
    def param_dim(self) -> int:
        return self._param_dim

    def sample(self, batch_size: int) -> MultilayerStructureParams:
        return self.optimized_sample(batch_size)[0]

    def optimized_sample(self, batch_size: int) -> Tuple[MultilayerStructureParams, Tensor]:
        scaled_params = torch.rand(
            batch_size,
            self.min_bounds.shape[-1],
            device=self.min_bounds.device,
            dtype=self.min_bounds.dtype,
        )

        targets = self.restore_params(scaled_params)

        return targets, scaled_params

    def get_np_bounds(self):
        return np.array(self._np_bounds)

    def restore_np_params(self, params: np.ndarray):
        p = self.multilayer_model.to_standard_params(
            torch.atleast_2d(to_t(params))
        )

        return {
            'thickness': p['thicknesses'].squeeze().cpu().numpy(),
            'roughness': p['roughnesses'].squeeze().cpu().numpy(),
            'sld': p['slds'].squeeze().cpu().numpy()
        }

    def restore_params2parametrized(self, scaled_params: Tensor) -> Tensor:
        return scaled_params * (self.max_bounds - self.min_bounds) + self.min_bounds

    def restore_params(self, scaled_params: Tensor) -> MultilayerStructureParams:
        return self.to_standard_params(self.restore_params2parametrized(scaled_params))

    def to_standard_params(self, params: Tensor) -> MultilayerStructureParams:
        return MultilayerStructureParams(**self.multilayer_model.to_standard_params(params))

    def scale_params(self, params: Params) -> Tensor:
        raise NotImplementedError

    def log_prob(self, params: Params) -> Tensor:
        raise NotImplementedError

    def get_indices_within_domain(self, params: Params) -> Tensor:
        raise NotImplementedError

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        raise NotImplementedError

    def filter_params(self, params: Params) -> Params:
        indices = self.get_indices_within_domain(params)
        return params[indices]

    def clamp_params(self, params: Params) -> Params:
        raise NotImplementedError
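
The hunk above, which the file list pairs with reflectorch/data_generation/priors/multilayer_structures.py (+104 -104), defines SimpleMultilayerSampler: it draws uniform samples on the unit cube and rescales them into user-supplied per-parameter bounds of a registered multilayer model. The sketch below is a minimal, hypothetical usage example; the registry key, the model's PARAMETER_NAMES, the placeholder (0., 1.) bounds, and the import path are assumptions, since none of them are spelled out in this diff.

import torch

from reflectorch.data_generation.priors.multilayer_models import MULTILAYER_MODELS
from reflectorch.data_generation.priors.multilayer_structures import SimpleMultilayerSampler

# Pick a registered multilayer model; the registry keys are not shown in this diff,
# so the first entry is used purely as a placeholder.
model_name = next(iter(MULTILAYER_MODELS))
parameter_names = MULTILAYER_MODELS[model_name](50).PARAMETER_NAMES

# Placeholder (min, max) bounds for every parameter the chosen model expects.
bounds = {name: (0., 1.) for name in parameter_names}

sampler = SimpleMultilayerSampler(
    params=bounds,
    model_name=model_name,
    device=torch.device('cpu'),  # override the CUDA default when no GPU is available
    dtype=torch.float64,
)

# optimized_sample returns the restored parameters together with the raw uniform draws.
params, scaled = sampler.optimized_sample(batch_size=8)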
@@ -1,206 +1,206 @@ (all 206 lines removed and re-added with identical content; the file reads:)

import logging
from functools import lru_cache
from typing import Tuple
from math import sqrt

import torch
from torch import Tensor

from reflectorch.data_generation.utils import (
    get_slds_from_d_rhos,
    uniform_sampler,
)

from reflectorch.data_generation.priors.utils import (
    get_allowed_roughness_indices,
    generate_roughnesses,
    params_within_bounds,
)
from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
from reflectorch.data_generation.priors.base import PriorSampler
from reflectorch.data_generation.priors.params import Params

__all__ = [
    "BasicPriorSampler",
    "DEFAULT_ROUGHNESS_RANGE",
    "DEFAULT_THICKNESS_RANGE",
    "DEFAULT_SLD_RANGE",
    "DEFAULT_NUM_LAYERS",
    "DEFAULT_DEVICE",
    "DEFAULT_DTYPE",
    "DEFAULT_SCALED_RANGE",
    "DEFAULT_USE_DRHO",
]

DEFAULT_THICKNESS_RANGE: Tuple[float, float] = (1., 500.)
DEFAULT_ROUGHNESS_RANGE: Tuple[float, float] = (0., 50.)
DEFAULT_SLD_RANGE: Tuple[float, float] = (-10., 30.)
DEFAULT_NUM_LAYERS: int = 5
DEFAULT_USE_DRHO: bool = False
DEFAULT_DEVICE: torch.device = torch.device('cuda')
DEFAULT_DTYPE: torch.dtype = torch.float64
DEFAULT_SCALED_RANGE: Tuple[float, float] = (-sqrt(3.), sqrt(3.))


class BasicPriorSampler(PriorSampler, ScalerMixin):
    """Prior samplers for thicknesses, roughnesses and slds"""
    def __init__(self,
                 thickness_range: Tuple[float, float] = DEFAULT_THICKNESS_RANGE,
                 roughness_range: Tuple[float, float] = DEFAULT_ROUGHNESS_RANGE,
                 sld_range: Tuple[float, float] = DEFAULT_SLD_RANGE,
                 num_layers: int = DEFAULT_NUM_LAYERS,
                 use_drho: bool = DEFAULT_USE_DRHO,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 scaled_range: Tuple[float, float] = DEFAULT_SCALED_RANGE,
                 restrict_roughnesses: bool = True,
                 ):
        self.logger = logging.getLogger(__name__)
        self.thickness_range = thickness_range
        self.roughness_range = roughness_range
        self.sld_range = sld_range
        self.num_layers = num_layers
        self.device = device
        self.dtype = dtype
        self.scaled_range = scaled_range
        self.use_drho = use_drho
        self.restrict_roughnesses = restrict_roughnesses

    @property
    def max_num_layers(self) -> int:
        return self.num_layers

    @lru_cache()
    def min_vector(self, layers_num, drho: bool = False):
        if drho:
            sld_min = self.sld_range[0] - self.sld_range[1]
        else:
            sld_min = self.sld_range[0]

        return torch.tensor(
            [self.thickness_range[0]] * layers_num +
            [self.roughness_range[0]] * (layers_num + 1) +
            [sld_min] * (layers_num + 1),
            device=self.device,
            dtype=self.dtype
        )

    @lru_cache()
    def max_vector(self, layers_num, drho: bool = False):
        if drho:
            sld_max = self.sld_range[1] - self.sld_range[0]
        else:
            sld_max = self.sld_range[1]
        return torch.tensor(
            [self.thickness_range[1]] * layers_num +
            [self.roughness_range[1]] * (layers_num + 1) +
            [sld_max] * (layers_num + 1),
            device=self.device,
            dtype=self.dtype
        )

    @lru_cache()
    def delta_vector(self, layers_num, drho: bool = False):
        return self._get_delta_vector(self.min_vector(layers_num, drho), self.max_vector(layers_num, drho))

    def restore_params(self, scaled_params: Tensor) -> Params:
        layers_num = self.PARAM_CLS.size2layers_num(scaled_params.shape[-1])

        params_t = self._restore(
            scaled_params,
            self.min_vector(layers_num, drho=self.use_drho).to(scaled_params),
            self.max_vector(layers_num, drho=self.use_drho).to(scaled_params),
        )

        params = self.PARAM_CLS.from_tensor(params_t)

        if self.use_drho:
            params.slds = get_slds_from_d_rhos(params.slds)

        return params

    def scale_params(self, params: Params) -> Tensor:
        layers_num = params.max_layer_num

        return self._scale(
            params.as_tensor(use_drho=self.use_drho),
            self.min_vector(layers_num, drho=self.use_drho).to(params.thicknesses),
            self.max_vector(layers_num, drho=self.use_drho).to(params.thicknesses),
        )

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        layer_num = params.max_layer_num

        return params_within_bounds(
            params.as_tensor(),
            self.min_vector(layer_num),
            self.max_vector(layer_num),
        )

    def get_indices_within_domain(self, params: Params) -> Tensor:
        if self.restrict_roughnesses:
            indices = (
                self.get_indices_within_bounds(params) &
                self.get_allowed_roughness_indices(params)
            )
        else:
            indices = self.get_indices_within_bounds(params)
        return indices

    def clamp_params(self, params: Params) -> Params:
        layer_num = params.max_layer_num
        params = params.as_tensor()
        params = torch.clamp(
            params,
            self.min_vector(layer_num),
            self.max_vector(layer_num),
        )
        params = Params.from_tensor(params)
        return params

    @staticmethod
    def get_allowed_roughness_indices(params: Params) -> Tensor:
        return get_allowed_roughness_indices(params.thicknesses, params.roughnesses)

    def log_prob(self, params: Params) -> Tensor:
        # so far we ignore non-uniform distribution of roughnesses and slds.
        log_prob = torch.zeros(params.batch_size, device=params.device, dtype=params.dtype)
        indices = self.get_indices_within_bounds(params)
        log_prob[~indices] = float('-inf')
        return log_prob

    def sample(self, batch_size: int) -> Params:
        slds = self.generate_slds(batch_size)
        thicknesses = self.generate_thicknesses(batch_size)
        roughnesses = self.generate_roughnesses(thicknesses)

        params = Params(thicknesses, roughnesses, slds)

        return params

    def generate_slds(self, batch_size: int):
        return uniform_sampler(
            *self.sld_range, batch_size,
            self.num_layers + 1,
            device=self.device,
            dtype=self.dtype
        )

    def generate_thicknesses(self, batch_size: int):
        return uniform_sampler(
            *self.thickness_range, batch_size,
            self.num_layers,
            device=self.device,
            dtype=self.dtype
        )

    def generate_roughnesses(self, thicknesses: Tensor) -> Tensor:
        if self.restrict_roughnesses:
            return generate_roughnesses(thicknesses, self.roughness_range)
        else:
            return uniform_sampler(
                *self.roughness_range, thicknesses.shape[0],
                self.num_layers + 1,
                device=self.device,
                dtype=self.dtype
            )
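
The hunk above, which the file list pairs with reflectorch/data_generation/priors/no_constraints.py (+206 -206, the module that also provides the DEFAULT_DEVICE and DEFAULT_DTYPE imported by the previous file), defines BasicPriorSampler: box-uniform sampling of thicknesses, roughnesses and SLDs, plus a scale/restore round trip between physical values and the scaled range. The sketch below is a minimal, hypothetical round trip under the assumption that the class is importable from that module; the CPU device simply overrides the 'cuda' default for machines without a GPU.

import torch

from reflectorch.data_generation.priors.no_constraints import BasicPriorSampler

sampler = BasicPriorSampler(
    thickness_range=(1., 500.),
    roughness_range=(0., 50.),
    sld_range=(-10., 30.),
    num_layers=2,
    device=torch.device('cpu'),  # DEFAULT_DEVICE is 'cuda'; use CPU when no GPU is present
    dtype=torch.float64,
)

# With num_layers=2 each sample has 2 thicknesses, 3 roughnesses and 3 SLDs (num_layers + 1).
params = sampler.sample(batch_size=16)

# Map physical values into the configured scaled_range (default (-sqrt(3), sqrt(3))) and back.
scaled = sampler.scale_params(params)
restored = sampler.restore_params(scaled)

# Boolean mask of samples that satisfy both the box bounds and the roughness restriction.
within_domain = sampler.get_indices_within_domain(params)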