reflectorch 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96):
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,311 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
# Public API of this module: the model registry and the abstract base class.
__all__ = ["MULTILAYER_MODELS", "MultilayerModel"]
10
+
11
+
12
class MultilayerModel:
    """Base class for parametrizations of repeating multilayer structures.

    A concrete model declares the ordered column names of its parameter tensor
    in ``PARAMETER_NAMES`` and implements :meth:`to_standard_params`, which
    returns a dict of standard (thickness / roughness / SLD) tensors.

    Class attributes:
        NAME: registry key of the model (empty string on the base class).
        PARAMETER_NAMES: ordered names of the parameter columns; only declared
            (annotated) here, each subclass assigns it.

    Args:
        max_num_layers: stored as ``self.max_num_layers``; the concrete models
            forward it to their conversion function.
    """

    NAME: str = ''
    PARAMETER_NAMES: Tuple[str, ...]

    def __init__(self, max_num_layers: int):
        self.max_num_layers = max_num_layers

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Convert a parametrized model tensor into standard parameters.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

    def from_standard_params(self, params: dict) -> Tensor:
        """Inverse of :meth:`to_standard_params`.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError
24
+
25
+
26
class BasicMultilayerModel1(MultilayerModel):
    """Repeating two-sublayer block model, version 1.

    The columns of the parametrized tensor follow ``PARAMETER_NAMES``; the
    conversion itself is done by :func:`multilayer_model1`.
    """

    NAME = 'repeating_multilayer_v1'

    PARAMETER_NAMES = (
        # envelope over the repeating blocks
        "d_full_rel", "rel_sigmas",
        # repeating block
        "d_block", "s_block_rel", "r_block", "dr",
        # extra layer ("3")
        "d3_rel", "s3_rel", "r3",
        # SiO2 / Si entries
        "d_sio2", "s_sio2", "s_si", "r_sio2", "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Return ``thicknesses``, ``roughnesses`` and ``slds`` tensors."""
        return multilayer_model1(parametrized_model, d_full_rel_max=self.max_num_layers)
48
+
49
+
50
class BasicMultilayerModel2(MultilayerModel):
    """Repeating two-sublayer block model, version 2.

    Compared to version 1 it adds two parameters controlling a sigmoid that
    attenuates ``dr`` per repeat. The columns of the parametrized tensor follow
    ``PARAMETER_NAMES``; the conversion is done by :func:`multilayer_model2`.
    """

    NAME = 'repeating_multilayer_v2'

    PARAMETER_NAMES = (
        # envelope over the repeating blocks
        "d_full_rel", "rel_sigmas",
        # sigmoid applied to dr
        "dr_sigmoid_rel_pos", "dr_sigmoid_rel_width",
        # repeating block
        "d_block", "s_block_rel", "r_block", "dr",
        # extra layer ("3")
        "d3_rel", "s3_rel", "r3",
        # SiO2 / Si entries
        "d_sio2", "s_sio2", "s_si", "r_sio2", "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Return ``thicknesses``, ``roughnesses`` and ``slds`` tensors."""
        return multilayer_model2(parametrized_model, d_full_rel_max=self.max_num_layers)
74
+
75
+
76
class BasicMultilayerModel3(MultilayerModel):
    """Repeating two-sublayer block model, version 3.

    Compared to version 2 it adds ``d_block1_rel``, which splits each block
    into sublayers of unequal thickness. The columns of the parametrized
    tensor follow ``PARAMETER_NAMES``; the conversion is done by
    :func:`multilayer_model3`.
    """

    NAME = 'repeating_multilayer_v3'

    PARAMETER_NAMES = (
        # envelope over the repeating blocks
        "d_full_rel", "rel_sigmas",
        # sigmoid applied to dr
        "dr_sigmoid_rel_pos", "dr_sigmoid_rel_width",
        # repeating block
        "d_block1_rel", "d_block", "s_block_rel", "r_block", "dr",
        # extra layer ("3")
        "d3_rel", "s3_rel", "r3",
        # SiO2 / Si entries
        "d_sio2", "s_sio2", "s_si", "r_sio2", "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Return ``thicknesses``, ``roughnesses`` and ``slds`` tensors."""
        return multilayer_model3(parametrized_model, d_full_rel_max=self.max_num_layers)
101
+
102
+
103
# Registry of available multilayer models, keyed by each class's NAME
# ('repeating_multilayer_v1', 'repeating_multilayer_v2', 'repeating_multilayer_v3').
MULTILAYER_MODELS = {
    model_cls.NAME: model_cls
    for model_cls in (BasicMultilayerModel1, BasicMultilayerModel2, BasicMultilayerModel3)
}
108
+
109
+
110
def multilayer_model1(parametrized_model: Tensor, d_full_rel_max: int = 50) -> dict:
    """Convert ``repeating_multilayer_v1`` parameters into standard parameters.

    Args:
        parametrized_model: 2D tensor, one row per sample, with 14 columns in the
            order ``d_full_rel, rel_sigmas, d_block, s_block_rel, r_block, dr,
            d3_rel, s3_rel, r3, d_sio2, s_sio2, s_si, r_sio2, r_si``.
        d_full_rel_max: number ``n`` of repeated blocks; every block contributes
            two sublayers (SLDs ``r_block`` and ``r_block + dr``, each of
            thickness ``d_block / 2``).

    Returns:
        dict with
          * ``thicknesses``: ``(batch, 2n + 2)`` -- block sublayers, then
            ``d3_rel * d_block``, then ``d_sio2``;
          * ``roughnesses``: ``(batch, 2n + 3)`` -- ``s_block_rel * d_block`` per
            sublayer, then ``s3_rel * d3``, ``s_sio2``, ``s_si``;
          * ``slds``: ``(batch, 2n + 3)`` -- modulated block SLDs, then ``r3``,
            ``r_sio2``, ``r_si``.

        Block sublayer ``k`` (0-based) is given the position ``2n - k``; its SLD
        is multiplied by ``sigmoid((2 * d_full_rel - position) / rel_sigmas)``.
    """
    n = d_full_rel_max

    (
        d_full_rel, rel_sigmas,
        d_block, s_block_rel, r_block, dr,
        d3_rel, s3_rel, r3,
        d_sio2, s_sio2, s_si, r_sio2, r_si,
    ) = parametrized_model.unbind(-1)

    # Sublayer positions 2n, 2n - 1, ..., 1; broadcast against the (batch, 1) columns.
    positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)
    envelope = torch.sigmoid(-(positions - 2 * d_full_rel[:, None]) / rel_sigmas[:, None])

    # [r_block, r_block + dr] tiled n times -> (batch, 2n), alternating sublayers.
    sld_pair = torch.stack([r_block, r_block + dr], dim=-1)
    block_slds = envelope * sld_pair.repeat(1, n)

    d3 = d3_rel * d_block
    s_block = s_block_rel * d_block

    thicknesses = torch.cat(
        [(d_block / 2)[:, None].repeat(1, 2 * n), torch.stack([d3, d_sio2], dim=-1)],
        dim=-1,
    )
    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, 2 * n), torch.stack([s3_rel * d3, s_sio2, s_si], dim=-1)],
        dim=-1,
    )
    slds = torch.cat([block_slds, torch.stack([r3, r_sio2, r_si], dim=-1)], dim=-1)

    return {"thicknesses": thicknesses, "roughnesses": roughnesses, "slds": slds}
165
+
166
+
167
def multilayer_model2(parametrized_model: Tensor, d_full_rel_max: int = 50) -> dict:
    """Convert ``repeating_multilayer_v2`` parameters into standard parameters.

    Same layout as :func:`multilayer_model1`, except that ``dr`` is attenuated
    per block by an extra sigmoid.

    Args:
        parametrized_model: 2D tensor, one row per sample, with 16 columns in the
            order ``d_full_rel, rel_sigmas, dr_sigmoid_rel_pos,
            dr_sigmoid_rel_width, d_block, s_block_rel, r_block, dr, d3_rel,
            s3_rel, r3, d_sio2, s_sio2, s_si, r_sio2, r_si``.
        d_full_rel_max: number ``n`` of repeated two-sublayer blocks.

    Returns:
        dict with ``thicknesses`` ``(batch, 2n + 2)``, ``roughnesses`` and
        ``slds`` (both ``(batch, 2n + 3)``).

        Block sublayer ``k`` has position ``2n - k``. Block ``j`` uses
        ``dr * sigmoid((2 * d_full_rel * dr_sigmoid_rel_pos - (2n - 2j)) /
        dr_sigmoid_rel_width)`` as its SLD step, and all block SLDs are scaled by
        the same envelope as in version 1.
    """
    n = d_full_rel_max

    (
        d_full_rel, rel_sigmas, dr_sigmoid_rel_pos, dr_sigmoid_rel_width,
        d_block, s_block_rel, r_block, dr,
        d3_rel, s3_rel, r3,
        d_sio2, s_sio2, s_si, r_sio2, r_si,
    ) = parametrized_model.unbind(-1)

    # Sublayer positions 2n, 2n - 1, ..., 1; broadcast against the (batch, 1) columns.
    positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)
    envelope = torch.sigmoid(-(positions - 2 * d_full_rel[:, None]) / rel_sigmas[:, None])

    # One gate value per block, evaluated at every other position (2n, 2n - 2, ..., 2).
    dr_gate = torch.sigmoid(
        -(positions[::2] - (2 * d_full_rel * dr_sigmoid_rel_pos)[:, None])
        / dr_sigmoid_rel_width[:, None]
    )
    gated_dr = dr[:, None] * dr_gate  # (batch, n)

    # Interleave [r_block, r_block + gated_dr] per block -> (batch, 2n).
    block_slds = envelope * torch.stack(
        [r_block[:, None].expand_as(gated_dr), r_block[:, None] + gated_dr], dim=-1
    ).flatten(1)

    d3 = d3_rel * d_block
    s_block = s_block_rel * d_block

    thicknesses = torch.cat(
        [(d_block / 2)[:, None].repeat(1, 2 * n), torch.stack([d3, d_sio2], dim=-1)],
        dim=-1,
    )
    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, 2 * n), torch.stack([s3_rel * d3, s_sio2, s_si], dim=-1)],
        dim=-1,
    )
    slds = torch.cat([block_slds, torch.stack([r3, r_sio2, r_si], dim=-1)], dim=-1)

    return {"thicknesses": thicknesses, "roughnesses": roughnesses, "slds": slds}
232
+
233
+
234
def multilayer_model3(parametrized_model: Tensor, d_full_rel_max: int = 30) -> dict:
    """Convert ``repeating_multilayer_v3`` parameters into standard parameters.

    Args:
        parametrized_model: 2D tensor, one row per sample, with 17 columns in the
            order ``d_full_rel, rel_sigmas, dr_sigmoid_rel_pos,
            dr_sigmoid_rel_width, d_block1_rel, d_block, s_block_rel, r_block,
            dr, d3_rel, s3_rel, r3, d_sio2, s_sio2, s_si, r_sio2, r_si``.
        d_full_rel_max: number ``n`` of repeated two-sublayer blocks. Note the
            default is 30 here, while the v1/v2 functions default to 50;
            ``BasicMultilayerModel3`` always passes its ``max_num_layers``.

    Returns:
        dict with
          * ``thicknesses``: ``(batch, 2n + 2)`` -- alternating
            ``d_block * d_block1_rel`` / ``d_block * (1 - d_block1_rel)``, then
            ``d3_rel * d_block``, then ``d_sio2``;
          * ``roughnesses``: ``(batch, 2n + 3)`` -- ``s_block_rel * d_block`` per
            sublayer, then ``s3_rel * d3``, ``s_sio2``, ``s_si``;
          * ``slds``: ``(batch, 2n + 3)`` -- modulated block SLDs, then ``r3``,
            ``r_sio2``, ``r_si``.

    Unlike v2, ``dr_sigmoid_rel_pos`` enters the dr sigmoid additively (as
    ``+ 2 * dr_sigmoid_rel_pos``), not multiplied by ``d_full_rel``.
    """
    n = d_full_rel_max

    (
        d_full_rel, rel_sigmas, dr_sigmoid_rel_pos, dr_sigmoid_rel_width,
        d_block1_rel, d_block, s_block_rel, r_block, dr,
        d3_rel, s3_rel, r3,
        d_sio2, s_sio2, s_si, r_sio2, r_si,
    ) = parametrized_model.unbind(-1)

    # Sublayer positions 2n, 2n - 1, ..., 1; broadcast against the (batch, 1) columns.
    positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)
    envelope = torch.sigmoid(-(positions - 2 * d_full_rel[:, None]) / rel_sigmas[:, None])

    # Per-block share of dr, evaluated at every other position (2n, 2n - 2, ..., 2).
    dr_profile = dr[:, None] * (1 - torch.sigmoid(
        -(positions[::2] - 2 * d_full_rel[:, None] + 2 * dr_sigmoid_rel_pos[:, None])
        / dr_sigmoid_rel_width[:, None]
    ))  # (batch, n)

    # The dr_profile contribution is split between the two sublayers by d_block1_rel.
    first_sld = r_block[:, None] + dr_profile * (1 - d_block1_rel[:, None])
    second_sld = (r_block + dr)[:, None] - dr_profile * d_block1_rel[:, None]
    block_slds = envelope * torch.stack([first_sld, second_sld], dim=-1).flatten(1)

    d3 = d3_rel * d_block
    d1, d2 = d_block * d_block1_rel, d_block * (1 - d_block1_rel)
    s_block = s_block_rel * d_block

    # [d1, d2] tiled n times -> (batch, 2n), alternating sublayers.
    thickness_blocks = torch.stack([d1, d2], dim=-1).repeat(1, n)

    thicknesses = torch.cat([thickness_blocks, d3[:, None], d_sio2[:, None]], dim=-1)
    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, n * 2), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]],
        dim=-1,
    )
    slds = torch.cat([block_slds, r3[:, None], r_sio2[:, None], r_si[:, None]], dim=-1)

    return dict(thicknesses=thicknesses, roughnesses=roughnesses, slds=slds)
@@ -0,0 +1,104 @@
1
+ from typing import Tuple, Dict
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ from reflectorch.data_generation.priors.base import PriorSampler
8
+ from reflectorch.data_generation.priors.params import Params
9
+ from reflectorch.data_generation.priors.no_constraints import (
10
+ DEFAULT_DEVICE,
11
+ DEFAULT_DTYPE,
12
+ )
13
+
14
+ from reflectorch.data_generation.priors.multilayer_models import MULTILAYER_MODELS, MultilayerModel
15
+ from reflectorch.utils import to_t
16
+
17
+
18
class MultilayerStructureParams(Params):
    """Params subclass used as ``PARAM_CLS`` by ``SimpleMultilayerSampler``.

    Adds no attributes or behaviour to :class:`Params`.
    """
20
+
21
+
22
class SimpleMultilayerSampler(PriorSampler):
    """Uniform prior over the parameters of a repeating-multilayer model.

    Each parameter of the selected model is drawn uniformly between its bounds,
    and the draws are converted to standard parameters by the model's
    ``to_standard_params``.

    Args:
        params: mapping from parameter name to ``(min, max)`` bounds. Every name
            in the model's ``PARAMETER_NAMES`` must be present; other keys are
            ignored.
        model_name: key into ``MULTILAYER_MODELS``.
        device: device of the bound tensors (and therefore of the samples).
        dtype: dtype of the bound tensors (and therefore of the samples).
        max_num_layers: forwarded to the multilayer model constructor.

    Raises:
        KeyError: if ``model_name`` is unknown or a required bound is missing.
    """

    PARAM_CLS = MultilayerStructureParams

    def __init__(self,
                 params: Dict[str, Tuple[float, float]],
                 model_name: str,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 max_num_layers: int = 50,
                 ):
        self.multilayer_model: MultilayerModel = MULTILAYER_MODELS[model_name](max_num_layers)
        self.device = device
        self.dtype = dtype
        self.num_layers = max_num_layers

        parameter_names = self.multilayer_model.PARAMETER_NAMES
        missing = [name for name in parameter_names if name not in params]
        if missing:
            raise KeyError(
                f"Missing bounds for {missing} required by multilayer model '{model_name}'"
            )
        ordered_bounds = [params[name] for name in parameter_names]

        # (2, num_params): row 0 holds the minima, row 1 the maxima.
        self._np_bounds = np.array(ordered_bounds).T
        # Each of shape (1, num_params) so they broadcast over a batch.
        self.min_bounds, self.max_bounds = torch.tensor(ordered_bounds, device=device, dtype=dtype).T[:, None]
        # Count the model's parameters, not len(params): extra keys in ``params``
        # are not sampled, so they must not inflate the dimension.
        self._param_dim = len(ordered_bounds)

    @property
    def max_num_layers(self) -> int:
        """The ``max_num_layers`` given at construction."""
        return self.num_layers

    @property
    def param_dim(self) -> int:
        """Number of parametrized-model parameters that are sampled."""
        return self._param_dim

    def sample(self, batch_size: int) -> MultilayerStructureParams:
        """Draw ``batch_size`` samples in standard-parameter form."""
        return self.optimized_sample(batch_size)[0]

    def optimized_sample(self, batch_size: int) -> Tuple[MultilayerStructureParams, Tensor]:
        """Draw samples and also return the underlying scaled values.

        The scaled values are ``torch.rand`` draws in ``[0, 1)`` with shape
        ``(batch_size, param_dim)``, on the bounds' device and dtype.
        """
        scaled_params = torch.rand(
            batch_size,
            self.min_bounds.shape[-1],
            device=self.min_bounds.device,
            dtype=self.min_bounds.dtype,
        )

        targets = self.restore_params(scaled_params)

        return targets, scaled_params

    def get_np_bounds(self):
        """Return a copy of the ``(2, param_dim)`` bounds array (minima, maxima)."""
        return np.array(self._np_bounds)

    def restore_np_params(self, params: np.ndarray):
        """Convert unscaled model parameters (NumPy) to standard parameters (NumPy).

        The input is converted with ``to_t`` and promoted to 2D; outputs are
        squeezed, so a single sample yields 1D arrays.
        """
        p = self.multilayer_model.to_standard_params(
            torch.atleast_2d(to_t(params))
        )

        return {
            'thickness': p['thicknesses'].squeeze().cpu().numpy(),
            'roughness': p['roughnesses'].squeeze().cpu().numpy(),
            'sld': p['slds'].squeeze().cpu().numpy()
        }

    def restore_params2parametrized(self, scaled_params: Tensor) -> Tensor:
        """Map scaled values linearly from ``[0, 1]`` onto ``[min, max]`` per parameter."""
        return scaled_params * (self.max_bounds - self.min_bounds) + self.min_bounds

    def restore_params(self, scaled_params: Tensor) -> MultilayerStructureParams:
        """Convert scaled values all the way to standard parameters."""
        return self.to_standard_params(self.restore_params2parametrized(scaled_params))

    def to_standard_params(self, params: Tensor) -> MultilayerStructureParams:
        """Wrap the model's standard-parameter dict in ``MultilayerStructureParams``."""
        return MultilayerStructureParams(**self.multilayer_model.to_standard_params(params))

    def scale_params(self, params: Params) -> Tensor:
        raise NotImplementedError

    def log_prob(self, params: Params) -> Tensor:
        raise NotImplementedError

    def get_indices_within_domain(self, params: Params) -> Tensor:
        raise NotImplementedError

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        raise NotImplementedError

    def filter_params(self, params: Params) -> Params:
        """Keep samples inside the domain.

        Currently always raises NotImplementedError, because
        ``get_indices_within_domain`` is not implemented.
        """
        indices = self.get_indices_within_domain(params)
        return params[indices]

    def clamp_params(self, params: Params) -> Params:
        raise NotImplementedError
@@ -0,0 +1,206 @@
1
+ import logging
2
+ from functools import lru_cache
3
+ from typing import Tuple
4
+ from math import sqrt
5
+
6
+ import torch
7
+ from torch import Tensor
8
+
9
+ from reflectorch.data_generation.utils import (
10
+ get_slds_from_d_rhos,
11
+ uniform_sampler,
12
+ )
13
+
14
+ from reflectorch.data_generation.priors.utils import (
15
+ get_allowed_roughness_indices,
16
+ generate_roughnesses,
17
+ params_within_bounds,
18
+ )
19
+ from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
20
+ from reflectorch.data_generation.priors.base import PriorSampler
21
+ from reflectorch.data_generation.priors.params import Params
22
+
23
# Public names exported by ``from ... import *``.
__all__ = [
    "BasicPriorSampler",
    "DEFAULT_ROUGHNESS_RANGE", "DEFAULT_THICKNESS_RANGE", "DEFAULT_SLD_RANGE",
    "DEFAULT_NUM_LAYERS", "DEFAULT_DEVICE", "DEFAULT_DTYPE",
    "DEFAULT_SCALED_RANGE", "DEFAULT_USE_DRHO",
]
34
+
35
# Default (min, max) bounds used by BasicPriorSampler.
DEFAULT_THICKNESS_RANGE: Tuple[float, float] = (1.0, 500.0)
DEFAULT_ROUGHNESS_RANGE: Tuple[float, float] = (0.0, 50.0)
DEFAULT_SLD_RANGE: Tuple[float, float] = (-10.0, 30.0)

DEFAULT_NUM_LAYERS: int = 5
DEFAULT_USE_DRHO: bool = False

# NOTE(review): a CUDA default means tensors built with the default device
# cannot be allocated on CPU-only machines; callers there must pass a device.
DEFAULT_DEVICE: torch.device = torch.device("cuda")
DEFAULT_DTYPE: torch.dtype = torch.float64

# Symmetric interval [-sqrt(3), sqrt(3)]; a uniform distribution on it has
# zero mean and unit variance.
DEFAULT_SCALED_RANGE: Tuple[float, float] = (-sqrt(3.0), sqrt(3.0))
43
+
44
+
45
class BasicPriorSampler(PriorSampler, ScalerMixin):
    """Prior samplers for thicknesses, roughnesses and slds.

    A parameter tensor for ``layers_num`` layers has the layout used by
    :meth:`min_vector` / :meth:`max_vector`: ``layers_num`` thicknesses, then
    ``layers_num + 1`` roughnesses, then ``layers_num + 1`` SLDs.

    Args:
        thickness_range: ``(min, max)`` layer thickness.
        roughness_range: ``(min, max)`` roughness.
        sld_range: ``(min, max)`` SLD.
        num_layers: number of layers produced by :meth:`sample`.
        use_drho: if True, :meth:`scale_params` / :meth:`restore_params` work on
            the ``use_drho`` tensor form of the SLDs (converted back with
            ``get_slds_from_d_rhos``), with SLD bounds
            ``+/-(sld_range[1] - sld_range[0])``.
        device: device of sampled tensors and bound vectors.
        dtype: dtype of sampled tensors and bound vectors.
        scaled_range: stored on the instance; presumably consumed by
            ``ScalerMixin`` -- TODO confirm.
        restrict_roughnesses: if True, roughnesses are sampled with
            ``generate_roughnesses(thicknesses, roughness_range)`` and
            :meth:`get_indices_within_domain` additionally requires
            :meth:`get_allowed_roughness_indices`.
    """

    def __init__(self,
                 thickness_range: Tuple[float, float] = DEFAULT_THICKNESS_RANGE,
                 roughness_range: Tuple[float, float] = DEFAULT_ROUGHNESS_RANGE,
                 sld_range: Tuple[float, float] = DEFAULT_SLD_RANGE,
                 num_layers: int = DEFAULT_NUM_LAYERS,
                 use_drho: bool = DEFAULT_USE_DRHO,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 scaled_range: Tuple[float, float] = DEFAULT_SCALED_RANGE,
                 restrict_roughnesses: bool = True,
                 ):
        self.logger = logging.getLogger(__name__)
        self.thickness_range = thickness_range
        self.roughness_range = roughness_range
        self.sld_range = sld_range
        self.num_layers = num_layers
        self.device = device
        self.dtype = dtype
        self.scaled_range = scaled_range
        self.use_drho = use_drho
        self.restrict_roughnesses = restrict_roughnesses

    @property
    def max_num_layers(self) -> int:
        """Number of layers produced by :meth:`sample`."""
        return self.num_layers

    # NOTE(review): functools.lru_cache on a method keys on ``self`` and keeps every
    # calling instance alive while its entries are cached (one cache of maxsize
    # 128 shared by all instances). Cached tensors are returned by reference, so
    # callers must not modify them in place. The decorator is kept so that any
    # external ``.cache_clear()`` calls keep working -- TODO confirm before
    # switching to a per-instance cache.
    @lru_cache()
    def min_vector(self, layers_num, drho: bool = False):
        """Lower-bound vector for a ``layers_num``-layer parameter tensor.

        With ``drho=True`` the SLD entries are ``sld_range[0] - sld_range[1]``.
        Created on ``self.device`` with ``self.dtype``.
        """
        if drho:
            sld_min = self.sld_range[0] - self.sld_range[1]
        else:
            sld_min = self.sld_range[0]

        return torch.tensor(
            [self.thickness_range[0]] * layers_num +
            [self.roughness_range[0]] * (layers_num + 1) +
            [sld_min] * (layers_num + 1),
            device=self.device,
            dtype=self.dtype
        )

    @lru_cache()
    def max_vector(self, layers_num, drho: bool = False):
        """Upper-bound vector for a ``layers_num``-layer parameter tensor.

        With ``drho=True`` the SLD entries are ``sld_range[1] - sld_range[0]``.
        Created on ``self.device`` with ``self.dtype``.
        """
        if drho:
            sld_max = self.sld_range[1] - self.sld_range[0]
        else:
            sld_max = self.sld_range[1]
        return torch.tensor(
            [self.thickness_range[1]] * layers_num +
            [self.roughness_range[1]] * (layers_num + 1) +
            [sld_max] * (layers_num + 1),
            device=self.device,
            dtype=self.dtype
        )

    @lru_cache()
    def delta_vector(self, layers_num, drho: bool = False):
        """Delta vector computed by ``ScalerMixin._get_delta_vector`` from the bounds."""
        return self._get_delta_vector(self.min_vector(layers_num, drho), self.max_vector(layers_num, drho))

    def restore_params(self, scaled_params: Tensor) -> Params:
        """Map scaled parameters back to physical ``PARAM_CLS`` parameters.

        Bounds are moved to the device/dtype of ``scaled_params``. If
        ``use_drho`` is set, SLDs are converted with ``get_slds_from_d_rhos``.
        """
        layers_num = self.PARAM_CLS.size2layers_num(scaled_params.shape[-1])

        params_t = self._restore(
            scaled_params,
            self.min_vector(layers_num, drho=self.use_drho).to(scaled_params),
            self.max_vector(layers_num, drho=self.use_drho).to(scaled_params),
        )

        params = self.PARAM_CLS.from_tensor(params_t)

        if self.use_drho:
            params.slds = get_slds_from_d_rhos(params.slds)

        return params

    def scale_params(self, params: Params) -> Tensor:
        """Scale physical parameters with ``ScalerMixin._scale`` (inverse of :meth:`restore_params`)."""
        layers_num = params.max_layer_num

        return self._scale(
            params.as_tensor(use_drho=self.use_drho),
            self.min_vector(layers_num, drho=self.use_drho).to(params.thicknesses),
            self.max_vector(layers_num, drho=self.use_drho).to(params.thicknesses),
        )

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        """Per-sample check that the (absolute-SLD) parameters lie within the bounds.

        Presumably returns a boolean mask (it is combined with ``&`` and ``~``
        elsewhere in this class). The bound vectors are moved to the parameter
        tensor's device/dtype, as in :meth:`scale_params`, so parameters on a
        device other than ``self.device`` can be checked.
        """
        layer_num = params.max_layer_num
        params_t = params.as_tensor()

        return params_within_bounds(
            params_t,
            self.min_vector(layer_num).to(params_t),
            self.max_vector(layer_num).to(params_t),
        )

    def get_indices_within_domain(self, params: Params) -> Tensor:
        """Bounds check, plus the roughness restriction when ``restrict_roughnesses`` is set."""
        if self.restrict_roughnesses:
            indices = (
                self.get_indices_within_bounds(params) &
                self.get_allowed_roughness_indices(params)
            )
        else:
            indices = self.get_indices_within_bounds(params)
        return indices

    def clamp_params(self, params: Params) -> Params:
        """Clamp every parameter into its (absolute-SLD) bounds.

        Bounds follow the parameter tensor's device/dtype, and the result is
        rebuilt with ``self.PARAM_CLS`` (as in :meth:`restore_params`) rather
        than the hard-coded base ``Params`` class.
        """
        layer_num = params.max_layer_num
        params_t = params.as_tensor()
        clamped = torch.clamp(
            params_t,
            self.min_vector(layer_num).to(params_t),
            self.max_vector(layer_num).to(params_t),
        )
        return self.PARAM_CLS.from_tensor(clamped)

    @staticmethod
    def get_allowed_roughness_indices(params: Params) -> Tensor:
        """Delegate to the module-level ``get_allowed_roughness_indices`` helper."""
        return get_allowed_roughness_indices(params.thicknesses, params.roughnesses)

    def log_prob(self, params: Params) -> Tensor:
        """Unnormalised log-prior: 0 inside the bounds, ``-inf`` outside."""
        # so far we ignore non-uniform distribution of roughnesses and slds.
        log_prob = torch.zeros(params.batch_size, device=params.device, dtype=params.dtype)
        indices = self.get_indices_within_bounds(params)
        log_prob[~indices] = float('-inf')
        return log_prob

    def sample(self, batch_size: int) -> Params:
        """Draw ``batch_size`` parameter sets (SLDs and thicknesses uniform)."""
        slds = self.generate_slds(batch_size)
        thicknesses = self.generate_thicknesses(batch_size)
        roughnesses = self.generate_roughnesses(thicknesses)

        params = Params(thicknesses, roughnesses, slds)

        return params

    def generate_slds(self, batch_size: int):
        """Uniform SLDs, ``num_layers + 1`` per sample."""
        return uniform_sampler(
            *self.sld_range, batch_size,
            self.num_layers + 1,
            device=self.device,
            dtype=self.dtype
        )

    def generate_thicknesses(self, batch_size: int):
        """Uniform thicknesses, ``num_layers`` per sample."""
        return uniform_sampler(
            *self.thickness_range, batch_size,
            self.num_layers,
            device=self.device,
            dtype=self.dtype
        )

    def generate_roughnesses(self, thicknesses: Tensor) -> Tensor:
        """Roughnesses for the given thicknesses.

        With ``restrict_roughnesses`` the module helper ``generate_roughnesses``
        receives the thicknesses; otherwise ``num_layers + 1`` uniform draws per
        sample are returned.
        """
        if self.restrict_roughnesses:
            return generate_roughnesses(thicknesses, self.roughness_range)
        else:
            return uniform_sampler(
                *self.roughness_range, thicknesses.shape[0],
                self.num_layers + 1,
                device=self.device,
                dtype=self.dtype
            )