reflectorch 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of reflectorch has been flagged as potentially problematic; consult the package registry's advisory page for details.

Files changed (83)
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,311 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ __all__ = [
7
+ "MULTILAYER_MODELS",
8
+ "MultilayerModel",
9
+ ]
10
+
11
+
12
class MultilayerModel(object):
    """Abstract base for parametrized repeating-multilayer models.

    A concrete model defines ``NAME`` (its registry key) and
    ``PARAMETER_NAMES`` (the expected column order of the parametrization
    tensor) and implements the conversion to standard film parameters.
    """

    # Registry key of the model; overridden by each subclass.
    NAME: str = ''
    # Ordered column names of the parametrization tensor.
    PARAMETER_NAMES: Tuple[str, ...]

    def __init__(self, max_num_layers: int):
        # Maximum number of repeated blocks this model instance represents.
        self.max_num_layers = max_num_layers

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Convert a batch of parametrizations to standard parameters.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def from_standard_params(self, params: dict) -> Tensor:
        """Inverse conversion; not provided by the basic models."""
        raise NotImplementedError
24
+
25
+
26
class BasicMultilayerModel1(MultilayerModel):
    """Repeating multilayer parametrization, version 1.

    Columns of the parametrization tensor follow ``PARAMETER_NAMES``.
    """

    NAME = 'repeating_multilayer_v1'

    # Column order of the parametrization tensor (14 parameters).
    PARAMETER_NAMES = tuple(
        "d_full_rel rel_sigmas d_block s_block_rel r_block dr "
        "d3_rel s3_rel r3 d_sio2 s_sio2 s_si r_sio2 r_si".split()
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Delegate to :func:`multilayer_model1` with this model's layer count."""
        return multilayer_model1(parametrized_model, self.max_num_layers)
48
+
49
+
50
class BasicMultilayerModel2(MultilayerModel):
    """Repeating multilayer parametrization, version 2.

    Extends version 1 with a second sigmoid envelope
    (``dr_sigmoid_rel_pos`` / ``dr_sigmoid_rel_width``) applied to the
    SLD contrast of the repeated blocks.
    """

    NAME = 'repeating_multilayer_v2'

    # Column order of the parametrization tensor (16 parameters).
    PARAMETER_NAMES = tuple(
        "d_full_rel rel_sigmas dr_sigmoid_rel_pos dr_sigmoid_rel_width "
        "d_block s_block_rel r_block dr d3_rel s3_rel r3 "
        "d_sio2 s_sio2 s_si r_sio2 r_si".split()
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Delegate to :func:`multilayer_model2` with this model's layer count."""
        return multilayer_model2(parametrized_model, self.max_num_layers)
74
+
75
+
76
class BasicMultilayerModel3(MultilayerModel):
    """Repeating multilayer parametrization, version 3.

    Extends version 2 with an asymmetric block split controlled by
    ``d_block1_rel`` (relative thickness of the first sub-layer).
    """

    NAME = 'repeating_multilayer_v3'

    # Column order of the parametrization tensor (17 parameters).
    PARAMETER_NAMES = tuple(
        "d_full_rel rel_sigmas dr_sigmoid_rel_pos dr_sigmoid_rel_width "
        "d_block1_rel d_block s_block_rel r_block dr d3_rel s3_rel r3 "
        "d_sio2 s_sio2 s_si r_sio2 r_si".split()
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Delegate to :func:`multilayer_model3` with this model's layer count."""
        return multilayer_model3(parametrized_model, self.max_num_layers)
101
+
102
+
103
# Registry of available multilayer models, keyed by each class's NAME.
MULTILAYER_MODELS = {
    cls.NAME: cls
    for cls in (BasicMultilayerModel1, BasicMultilayerModel2, BasicMultilayerModel3)
}
108
+
109
+
110
def multilayer_model1(parametrized_model: Tensor, d_full_rel_max: int = 50) -> dict:
    """Convert v1 repeating-multilayer parametrizations to standard parameters.

    Args:
        parametrized_model: tensor of shape (batch, 14) whose columns follow
            ``BasicMultilayerModel1.PARAMETER_NAMES``.
        d_full_rel_max: number ``n`` of repeated two-layer blocks generated.

    Returns:
        dict with ``thicknesses`` of shape (batch, 2n + 2), and
        ``roughnesses`` / ``slds`` of shape (batch, 2n + 3).
    """
    num_blocks = d_full_rel_max

    (d_full_rel, rel_sigmas, d_block, s_block_rel, r_block, dr,
     d3_rel, s3_rel, r3, d_sio2, s_sio2, s_si, r_sio2, r_si) = parametrized_model.T

    batch_size = parametrized_model.shape[0]

    # Descending positions 2n, 2n-1, ..., 1, one row per batch element.
    positions = torch.arange(2 * num_blocks, dtype=dr.dtype, device=dr.device)
    r_positions = 2 * num_blocks - positions[None].repeat(batch_size, 1)

    # Smooth sigmoid envelope that switches the repeated blocks off beyond
    # the (relative) total film extent d_full_rel.
    r_modulations = torch.sigmoid(
        -(r_positions - 2 * d_full_rel.unsqueeze(-1)) / rel_sigmas.unsqueeze(-1)
    )

    r_block = r_block.unsqueeze(1).repeat(1, num_blocks)
    dr = dr.unsqueeze(1).repeat(1, num_blocks)

    # Alternate base SLD and base + contrast per block, then apply the envelope.
    sld_blocks = r_modulations * torch.stack([r_block, r_block + dr], -1).flatten(1)

    # Third (capping) layer thickness is relative to the block thickness.
    d3 = d3_rel * d_block

    half_block = (d_block / 2).unsqueeze(1).repeat(1, 2 * num_blocks)
    thicknesses = torch.cat([half_block, d3.unsqueeze(1), d_sio2.unsqueeze(1)], -1)

    # Block roughness is relative to the block thickness.
    s_block = (s_block_rel * d_block).unsqueeze(1).repeat(1, 2 * num_blocks)
    roughnesses = torch.cat(
        [s_block, (s3_rel * d3).unsqueeze(1), s_sio2.unsqueeze(1), s_si.unsqueeze(1)], -1
    )

    slds = torch.cat(
        [sld_blocks, r3.unsqueeze(1), r_sio2.unsqueeze(1), r_si.unsqueeze(1)], -1
    )

    return {
        'thicknesses': thicknesses,
        'roughnesses': roughnesses,
        'slds': slds,
    }
165
+
166
+
167
def multilayer_model2(parametrized_model: Tensor, d_full_rel_max: int = 50) -> dict:
    """Convert v2 repeating-multilayer parametrizations to standard parameters.

    Same as :func:`multilayer_model1`, but with an additional sigmoid envelope
    that attenuates the per-block SLD contrast ``dr`` along the film.

    Args:
        parametrized_model: tensor of shape (batch, 16) whose columns follow
            ``BasicMultilayerModel2.PARAMETER_NAMES``.
        d_full_rel_max: number ``n`` of repeated two-layer blocks generated.

    Returns:
        dict with ``thicknesses`` of shape (batch, 2n + 2), and
        ``roughnesses`` / ``slds`` of shape (batch, 2n + 3).
    """
    n = d_full_rel_max

    # Unpack the 16 parameter columns (batch along dim 0 before transpose).
    (
        d_full_rel,
        rel_sigmas,
        dr_sigmoid_rel_pos,
        dr_sigmoid_rel_width,
        d_block,
        s_block_rel,
        r_block,
        dr,
        d3_rel,
        s3_rel,
        r3,
        d_sio2,
        s_sio2,
        s_si,
        r_sio2,
        r_si,
    ) = parametrized_model.T

    batch_size = parametrized_model.shape[0]

    # Descending positions 2n, 2n-1, ..., 1, one row per batch element.
    r_positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)[None].repeat(batch_size, 1)

    # Envelope switching the repeated blocks off beyond the film extent.
    r_modulations = torch.sigmoid(-(r_positions - 2 * d_full_rel[..., None]) / rel_sigmas[..., None])

    r_block = r_block[:, None].repeat(1, n)
    dr = dr[:, None].repeat(1, n)

    # Every other interface position: one per repeated block.
    dr_positions = r_positions[:, ::2]

    # Second envelope attenuating the SLD contrast along the film; its center
    # and width are relative quantities (dr_sigmoid_rel_pos / _rel_width).
    dr_modulations = torch.sigmoid(
        -(dr_positions - (2 * d_full_rel * dr_sigmoid_rel_pos)[..., None]) / dr_sigmoid_rel_width[..., None]
    )

    # Per-block contrast after attenuation (must happen before stacking).
    dr = dr * dr_modulations

    # Alternate base SLD and base + contrast per block, then apply the envelope.
    sld_blocks = torch.stack([r_block, r_block + dr], -1).flatten(1)

    sld_blocks = r_modulations * sld_blocks

    # Third (capping) layer thickness is relative to the block thickness.
    d3 = d3_rel * d_block

    thicknesses = torch.cat(
        [(d_block / 2)[:, None].repeat(1, n * 2), d3[:, None], d_sio2[:, None]], -1
    )

    # Block roughness is relative to the block thickness.
    s_block = s_block_rel * d_block

    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, n * 2), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )

    slds = torch.cat(
        [sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1
    )

    params = dict(
        thicknesses=thicknesses,
        roughnesses=roughnesses,
        slds=slds
    )
    return params
232
+
233
+
234
def multilayer_model3(parametrized_model: Tensor, d_full_rel_max: int = 30) -> dict:
    """Convert v3 repeating-multilayer parametrizations to standard parameters.

    Same as :func:`multilayer_model2`, but each repeated block is split
    asymmetrically into two sub-layers via ``d_block1_rel``, and the contrast
    attenuation is anchored relative to the film end rather than an absolute
    relative position.

    Args:
        parametrized_model: tensor of shape (batch, 17) whose columns follow
            ``BasicMultilayerModel3.PARAMETER_NAMES``.
        d_full_rel_max: number ``n`` of repeated two-layer blocks generated.

    Returns:
        dict with ``thicknesses`` of shape (batch, 2n + 2), and
        ``roughnesses`` / ``slds`` of shape (batch, 2n + 3).
    """
    n = d_full_rel_max

    # Unpack the 17 parameter columns.
    (
        d_full_rel,
        rel_sigmas,
        dr_sigmoid_rel_pos,
        dr_sigmoid_rel_width,
        d_block1_rel,
        d_block,
        s_block_rel,
        r_block,
        dr,
        d3_rel,
        s3_rel,
        r3,
        d_sio2,
        s_sio2,
        s_si,
        r_sio2,
        r_si,
    ) = parametrized_model.T

    batch_size = parametrized_model.shape[0]

    # Descending positions 2n, 2n-1, ..., 1, one row per batch element.
    r_positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)[None].repeat(batch_size, 1)

    # Envelope switching the repeated blocks off beyond the film extent.
    r_modulations = torch.sigmoid(
        -(
            r_positions - 2 * d_full_rel[..., None]
        ) / rel_sigmas[..., None]
    )

    # Every other interface position: one per repeated block.
    dr_positions = r_positions[:, ::2]

    # NOTE: uses the pre-broadcast (batch,) ``dr`` — this must stay before the
    # ``dr = dr[..., None].repeat(1, n)`` reassignment below.
    dr_modulations = dr[..., None] * (1 - torch.sigmoid(
        -(
            dr_positions - 2 * d_full_rel[..., None] + 2 * dr_sigmoid_rel_pos[..., None]
        ) / dr_sigmoid_rel_width[..., None]
    ))

    r_block = r_block[..., None].repeat(1, n)
    dr = dr[..., None].repeat(1, n)

    # Two sub-layer SLDs per block; the contrast attenuation is distributed
    # between the sub-layers according to the split ratio d_block1_rel.
    sld_blocks = torch.stack(
        [
            r_block + dr_modulations * (1 - d_block1_rel[..., None]),
            r_block + dr - dr_modulations * d_block1_rel[..., None]
        ], -1).flatten(1)

    sld_blocks = r_modulations * sld_blocks

    # Third (capping) layer thickness is relative to the block thickness.
    d3 = d3_rel * d_block

    # Asymmetric split of the block thickness into the two sub-layers.
    d1, d2 = d_block * d_block1_rel, d_block * (1 - d_block1_rel)

    thickness_blocks = torch.stack([d1[:, None].repeat(1, n), d2[:, None].repeat(1, n)], -1).flatten(1)

    thicknesses = torch.cat(
        [thickness_blocks, d3[:, None], d_sio2[:, None]], -1
    )

    # Block roughness is relative to the block thickness.
    s_block = s_block_rel * d_block

    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, n * 2), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )

    slds = torch.cat(
        [sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1
    )

    params = dict(
        thicknesses=thicknesses,
        roughnesses=roughnesses,
        slds=slds
    )
    return params
@@ -0,0 +1,110 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Tuple, Dict
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch import Tensor
12
+
13
+ from reflectorch.data_generation.priors.base import PriorSampler
14
+ from reflectorch.data_generation.priors.params import Params
15
+ from reflectorch.data_generation.priors.no_constraints import (
16
+ DEFAULT_DEVICE,
17
+ DEFAULT_DTYPE,
18
+ )
19
+
20
+ from reflectorch.data_generation.priors.multilayer_models import MULTILAYER_MODELS, MultilayerModel
21
+ from reflectorch.utils import to_t
22
+
23
+
24
class MultilayerStructureParams(Params):
    """Params subclass marking standard parameters produced by a multilayer model."""
    pass
26
+
27
+
28
class SimpleMultilayerSampler(PriorSampler):
    """Prior sampler over a parametrized multilayer model.

    Samples uniformly within per-parameter bounds and converts the samples to
    standard (thickness / roughness / sld) parameters via the selected
    multilayer model from ``MULTILAYER_MODELS``.
    """

    # Params subclass returned by restore_params / to_standard_params.
    PARAM_CLS = MultilayerStructureParams

    def __init__(self,
                 params: Dict[str, Tuple[float, float]],
                 model_name: str,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 max_num_layers: int = 50,
                 ):
        """Initialize the sampler.

        Args:
            params: mapping parameter name -> (min, max) bound; must contain
                every name in the chosen model's PARAMETER_NAMES (raises
                KeyError otherwise).
            model_name: key into MULTILAYER_MODELS selecting the model.
            device: torch device for sampled tensors.
            dtype: torch dtype for sampled tensors.
            max_num_layers: number of repeated blocks passed to the model.
        """
        self.multilayer_model: MultilayerModel = MULTILAYER_MODELS[model_name](max_num_layers)
        self.device = device
        self.dtype = dtype
        self.num_layers = max_num_layers
        # Order the bounds to match the model's expected column order.
        ordered_bounds = [params[k] for k in self.multilayer_model.PARAMETER_NAMES]
        self._np_bounds = np.array(ordered_bounds).T
        # (num_params, 2).T -> (2, num_params); [:, None] keeps a broadcastable
        # leading dim so unpacking yields (1, num_params) min / max row vectors.
        self.min_bounds, self.max_bounds = torch.tensor(ordered_bounds, device=device, dtype=dtype).T[:, None]
        self._param_dim = len(params)

    @property
    def max_num_layers(self) -> int:
        # Number of repeated blocks used by the underlying multilayer model.
        return self.num_layers

    @property
    def param_dim(self) -> int:
        # Dimensionality of the parametrization (number of sampled parameters).
        return self._param_dim

    def sample(self, batch_size: int) -> MultilayerStructureParams:
        """Draw ``batch_size`` parameter sets in the standard representation."""
        return self.optimized_sample(batch_size)[0]

    def optimized_sample(self, batch_size: int) -> Tuple[MultilayerStructureParams, Tensor]:
        """Draw samples, returning both standard params and the unit-scaled tensor.

        The scaled tensor is uniform in [0, 1) per parameter; restore_params
        maps it affinely into the configured bounds.
        """
        scaled_params = torch.rand(
            batch_size,
            self.min_bounds.shape[-1],
            device=self.min_bounds.device,
            dtype=self.min_bounds.dtype,
        )

        targets = self.restore_params(scaled_params)

        return targets, scaled_params

    def get_np_bounds(self) -> np.ndarray:
        """Return a copy of the (2, num_params) min/max bounds as a numpy array."""
        return np.array(self._np_bounds)

    def restore_np_params(self, params: np.ndarray) -> dict:
        """Convert a numpy parametrization to standard parameters as numpy arrays."""
        p = self.multilayer_model.to_standard_params(
            torch.atleast_2d(to_t(params))
        )

        return {
            'thickness': p['thicknesses'].squeeze().cpu().numpy(),
            'roughness': p['roughnesses'].squeeze().cpu().numpy(),
            'sld': p['slds'].squeeze().cpu().numpy()
        }

    def restore_params2parametrized(self, scaled_params: Tensor) -> Tensor:
        """Affinely map unit-scaled parameters into [min_bounds, max_bounds]."""
        return scaled_params * (self.max_bounds - self.min_bounds) + self.min_bounds

    def restore_params(self, scaled_params: Tensor) -> MultilayerStructureParams:
        """Map unit-scaled parameters to standard multilayer parameters."""
        return self.to_standard_params(self.restore_params2parametrized(scaled_params))

    def to_standard_params(self, params: Tensor) -> MultilayerStructureParams:
        """Wrap the model's standard-parameter dict in MultilayerStructureParams."""
        return MultilayerStructureParams(**self.multilayer_model.to_standard_params(params))

    def scale_params(self, params: Params) -> Tensor:
        # Inverse mapping (standard -> scaled) is not supported by this sampler.
        raise NotImplementedError

    def log_prob(self, params: Params) -> Tensor:
        # Density evaluation is not supported by this sampler.
        raise NotImplementedError

    def get_indices_within_domain(self, params: Params) -> Tensor:
        # Domain membership test is not supported by this sampler.
        raise NotImplementedError

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        # Bounds membership test is not supported by this sampler.
        raise NotImplementedError

    def filter_params(self, params: Params) -> Params:
        # NOTE(review): relies on get_indices_within_domain, which raises
        # NotImplementedError above — calling this will always fail as written.
        indices = self.get_indices_within_domain(params)
        return params[indices]

    def clamp_params(self, params: Params) -> Params:
        # Clamping to bounds is not supported by this sampler.
        raise NotImplementedError
@@ -0,0 +1,212 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ from functools import lru_cache
9
+ from typing import Tuple
10
+ from math import sqrt
11
+
12
+ import torch
13
+ from torch import Tensor
14
+
15
+ from reflectorch.data_generation.utils import (
16
+ get_slds_from_d_rhos,
17
+ uniform_sampler,
18
+ )
19
+
20
+ from reflectorch.data_generation.priors.utils import (
21
+ get_allowed_roughness_indices,
22
+ generate_roughnesses,
23
+ params_within_bounds,
24
+ )
25
+ from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
26
+ from reflectorch.data_generation.priors.base import PriorSampler
27
+ from reflectorch.data_generation.priors.params import Params
28
+
29
__all__ = [
    "BasicPriorSampler",
    "DEFAULT_ROUGHNESS_RANGE",
    "DEFAULT_THICKNESS_RANGE",
    "DEFAULT_SLD_RANGE",
    "DEFAULT_NUM_LAYERS",
    "DEFAULT_DEVICE",
    "DEFAULT_DTYPE",
    "DEFAULT_SCALED_RANGE",
    "DEFAULT_USE_DRHO",
]

# Default parameter ranges for the prior (units follow the package convention;
# presumably angstroms for lengths — TODO confirm against the docs).
DEFAULT_THICKNESS_RANGE: Tuple[float, float] = (1., 500.)
DEFAULT_ROUGHNESS_RANGE: Tuple[float, float] = (0., 50.)
DEFAULT_SLD_RANGE: Tuple[float, float] = (-10., 30.)
DEFAULT_NUM_LAYERS: int = 5
# If True, SLDs are parametrized as inter-layer differences (d_rho).
DEFAULT_USE_DRHO: bool = False
# NOTE(review): hard-coded CUDA default raises at tensor creation on
# CPU-only machines; callers must pass device explicitly there.
DEFAULT_DEVICE: torch.device = torch.device('cuda')
DEFAULT_DTYPE: torch.dtype = torch.float64
# Scaled range (-sqrt(3), sqrt(3)) gives unit variance for a uniform variable.
DEFAULT_SCALED_RANGE: Tuple[float, float] = (-sqrt(3.), sqrt(3.))
49
+
50
+
51
class BasicPriorSampler(PriorSampler, ScalerMixin):
    """Box-uniform prior sampler for thicknesses, roughnesses and SLDs.

    Parameters are laid out as a flat tensor of
    ``num_layers`` thicknesses + ``num_layers + 1`` roughnesses +
    ``num_layers + 1`` SLDs, with scaling handled by ScalerMixin.
    """
    def __init__(self,
                 thickness_range: Tuple[float, float] = DEFAULT_THICKNESS_RANGE,
                 roughness_range: Tuple[float, float] = DEFAULT_ROUGHNESS_RANGE,
                 sld_range: Tuple[float, float] = DEFAULT_SLD_RANGE,
                 num_layers: int = DEFAULT_NUM_LAYERS,
                 use_drho: bool = DEFAULT_USE_DRHO,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 scaled_range: Tuple[float, float] = DEFAULT_SCALED_RANGE,
                 restrict_roughnesses: bool = True,
                 ):
        # NOTE(review): min_vector / max_vector / delta_vector below are
        # @lru_cache-decorated instance methods, so the cache keys include
        # ``self`` and keep every instance alive for the process lifetime
        # (ruff B019); also mutating the *_range attributes after caching
        # would return stale vectors.
        self.logger = logging.getLogger(__name__)
        self.thickness_range = thickness_range
        self.roughness_range = roughness_range
        self.sld_range = sld_range
        self.num_layers = num_layers
        self.device = device
        self.dtype = dtype
        self.scaled_range = scaled_range
        # If True, SLD slots hold inter-layer differences instead of values.
        self.use_drho = use_drho
        # If True, roughnesses are constrained relative to thicknesses.
        self.restrict_roughnesses = restrict_roughnesses

    @property
    def max_num_layers(self) -> int:
        # This sampler uses a fixed layer count.
        return self.num_layers

    @lru_cache()
    def min_vector(self, layers_num: int, drho: bool = False) -> Tensor:
        """Lower-bound vector for the flat parameter layout (cached)."""
        if drho:
            # Smallest possible difference between two SLDs in range.
            sld_min = self.sld_range[0] - self.sld_range[1]
        else:
            sld_min = self.sld_range[0]

        return torch.tensor(
            [self.thickness_range[0]] * layers_num +
            [self.roughness_range[0]] * (layers_num + 1) +
            [sld_min] * (layers_num + 1),
            device=self.device,
            dtype=self.dtype
        )

    @lru_cache()
    def max_vector(self, layers_num: int, drho: bool = False) -> Tensor:
        """Upper-bound vector for the flat parameter layout (cached)."""
        if drho:
            # Largest possible difference between two SLDs in range.
            sld_max = self.sld_range[1] - self.sld_range[0]
        else:
            sld_max = self.sld_range[1]
        return torch.tensor(
            [self.thickness_range[1]] * layers_num +
            [self.roughness_range[1]] * (layers_num + 1) +
            [sld_max] * (layers_num + 1),
            device=self.device,
            dtype=self.dtype
        )

    @lru_cache()
    def delta_vector(self, layers_num: int, drho: bool = False) -> Tensor:
        """Scaling deltas derived from the min/max vectors (from ScalerMixin)."""
        return self._get_delta_vector(self.min_vector(layers_num, drho), self.max_vector(layers_num, drho))

    def restore_params(self, scaled_params: Tensor) -> Params:
        """Map scaled parameters back to physical Params.

        Infers the layer count from the flat tensor width; converts d_rho
        back to absolute SLDs when use_drho is set.
        """
        layers_num = self.PARAM_CLS.size2layers_num(scaled_params.shape[-1])

        params_t = self._restore(
            scaled_params,
            self.min_vector(layers_num, drho=self.use_drho).to(scaled_params),
            self.max_vector(layers_num, drho=self.use_drho).to(scaled_params),
        )

        params = self.PARAM_CLS.from_tensor(params_t)

        if self.use_drho:
            params.slds = get_slds_from_d_rhos(params.slds)

        return params

    def scale_params(self, params: Params) -> Tensor:
        """Map physical Params into the scaled representation."""
        layers_num = params.max_layer_num

        return self._scale(
            params.as_tensor(use_drho=self.use_drho),
            self.min_vector(layers_num, drho=self.use_drho).to(params.thicknesses),
            self.max_vector(layers_num, drho=self.use_drho).to(params.thicknesses),
        )

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        """Boolean mask of batch items lying inside the box bounds."""
        layer_num = params.max_layer_num

        return params_within_bounds(
            params.as_tensor(),
            self.min_vector(layer_num),
            self.max_vector(layer_num),
        )

    def get_indices_within_domain(self, params: Params) -> Tensor:
        """Mask inside bounds, optionally intersected with the roughness constraint."""
        if self.restrict_roughnesses:
            indices = (
                self.get_indices_within_bounds(params) &
                self.get_allowed_roughness_indices(params)
            )
        else:
            indices = self.get_indices_within_bounds(params)
        return indices

    def clamp_params(self, params: Params) -> Params:
        """Clamp each parameter elementwise into its box bounds."""
        layer_num = params.max_layer_num
        params = params.as_tensor()
        params = torch.clamp(
            params,
            self.min_vector(layer_num),
            self.max_vector(layer_num),
        )
        params = Params.from_tensor(params)
        return params

    @staticmethod
    def get_allowed_roughness_indices(params: Params) -> Tensor:
        """Mask of items whose roughnesses satisfy the thickness-based constraint."""
        return get_allowed_roughness_indices(params.thicknesses, params.roughnesses)

    def log_prob(self, params: Params) -> Tensor:
        """Unnormalized log-density: 0 inside the bounds, -inf outside."""
        # so far we ignore non-uniform distribution of roughnesses and slds.
        log_prob = torch.zeros(params.batch_size, device=params.device, dtype=params.dtype)
        indices = self.get_indices_within_bounds(params)
        log_prob[~indices] = float('-inf')
        return log_prob

    def sample(self, batch_size: int) -> Params:
        """Draw a batch of Params from the prior."""
        slds = self.generate_slds(batch_size)
        thicknesses = self.generate_thicknesses(batch_size)
        # Roughnesses may depend on thicknesses (see generate_roughnesses).
        roughnesses = self.generate_roughnesses(thicknesses)

        params = Params(thicknesses, roughnesses, slds)

        return params

    def generate_slds(self, batch_size: int) -> Tensor:
        """Uniform SLDs of shape (batch_size, num_layers + 1)."""
        return uniform_sampler(
            *self.sld_range, batch_size,
            self.num_layers + 1,
            device=self.device,
            dtype=self.dtype
        )

    def generate_thicknesses(self, batch_size: int) -> Tensor:
        """Uniform thicknesses of shape (batch_size, num_layers)."""
        return uniform_sampler(
            *self.thickness_range, batch_size,
            self.num_layers,
            device=self.device,
            dtype=self.dtype
        )

    def generate_roughnesses(self, thicknesses: Tensor) -> Tensor:
        """Roughnesses of shape (batch, num_layers + 1).

        Constrained by the neighboring thicknesses when restrict_roughnesses
        is set; plain uniform otherwise.
        """
        if self.restrict_roughnesses:
            return generate_roughnesses(thicknesses, self.roughness_range)
        else:
            return uniform_sampler(
                *self.roughness_range, thicknesses.shape[0],
                self.num_layers + 1,
                device=self.device,
                dtype=self.dtype
            )
+ )