reflectorch 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -0
- reflectorch/data_generation/__init__.py +128 -0
- reflectorch/data_generation/dataset.py +216 -0
- reflectorch/data_generation/likelihoods.py +80 -0
- reflectorch/data_generation/noise.py +471 -0
- reflectorch/data_generation/priors/__init__.py +60 -0
- reflectorch/data_generation/priors/base.py +55 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
- reflectorch/data_generation/priors/independent_priors.py +195 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +104 -0
- reflectorch/data_generation/priors/no_constraints.py +206 -0
- reflectorch/data_generation/priors/parametric_models.py +842 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
- reflectorch/data_generation/priors/params.py +252 -0
- reflectorch/data_generation/priors/sampler_strategies.py +370 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +371 -0
- reflectorch/data_generation/priors/utils.py +118 -0
- reflectorch/data_generation/process_data.py +41 -0
- reflectorch/data_generation/q_generator.py +280 -0
- reflectorch/data_generation/reflectivity/__init__.py +102 -0
- reflectorch/data_generation/reflectivity/abeles.py +97 -0
- reflectorch/data_generation/reflectivity/kinematical.py +71 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +138 -0
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/scale_curves.py +112 -0
- reflectorch/data_generation/smearing.py +99 -0
- reflectorch/data_generation/utils.py +223 -0
- reflectorch/extensions/__init__.py +0 -0
- reflectorch/extensions/jupyter/__init__.py +11 -0
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -0
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -0
- reflectorch/extensions/matplotlib/losses.py +32 -0
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/__init__.py +28 -0
- reflectorch/inference/inference_model.py +848 -0
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +55 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/plotting.py +524 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +19 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +364 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +32 -0
- reflectorch/ml/basic_trainer.py +292 -0
- reflectorch/ml/callbacks.py +81 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +56 -0
- reflectorch/ml/schedulers.py +356 -0
- reflectorch/ml/trainers.py +201 -0
- reflectorch/ml/utils.py +2 -0
- reflectorch/models/__init__.py +16 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +19 -0
- reflectorch/models/encoders/conv_encoder.py +219 -0
- reflectorch/models/encoders/conv_res_net.py +115 -0
- reflectorch/models/encoders/fno.py +134 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -0
- reflectorch/models/networks/mlp_networks.py +434 -0
- reflectorch/models/networks/residual_net.py +157 -0
- reflectorch/paths.py +29 -0
- reflectorch/runs/__init__.py +31 -0
- reflectorch/runs/config.py +25 -0
- reflectorch/runs/slurm_utils.py +93 -0
- reflectorch/runs/train.py +78 -0
- reflectorch/runs/utils.py +405 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +98 -0
- reflectorch-1.5.1.dist-info/METADATA +151 -0
- reflectorch-1.5.1.dist-info/RECORD +96 -0
- reflectorch-1.5.1.dist-info/WHEEL +5 -0
- reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
- reflectorch-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
# Public names of this module: the model registry and the base model class.
__all__ = ["MULTILAYER_MODELS", "MultilayerModel"]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MultilayerModel:
    """Base class for parametrizations of multilayer structures.

    Concrete subclasses set ``NAME`` (their registry key) and ``PARAMETER_NAMES``
    (the ordered columns of the parametrized tensor) and override the
    conversion methods below.

    Args:
        max_num_layers: stored unchanged on the instance; concrete subclasses
            forward it to their conversion function.
    """

    NAME: str = ''
    # Only declared here; each concrete subclass assigns the actual tuple.
    PARAMETER_NAMES: Tuple[str, ...]

    def __init__(self, max_num_layers: int):
        self.max_num_layers = max_num_layers

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Convert a parametrized tensor into standard layer parameters.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    def from_standard_params(self, params: dict) -> Tensor:
        """Inverse of :meth:`to_standard_params`.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class BasicMultilayerModel1(MultilayerModel):
    """Repeating two-sublayer block model, version 1.

    Columns of the parametrized tensor follow ``PARAMETER_NAMES``; the
    conversion is delegated to :func:`multilayer_model1`.
    """

    NAME = 'repeating_multilayer_v1'

    PARAMETER_NAMES = (
        "d_full_rel", "rel_sigmas",
        "d_block", "s_block_rel", "r_block", "dr",
        "d3_rel", "s3_rel", "r3",
        "d_sio2", "s_sio2", "s_si", "r_sio2", "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Return ``thicknesses`` / ``roughnesses`` / ``slds`` for the given parameters."""
        return multilayer_model1(parametrized_model, self.max_num_layers)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class BasicMultilayerModel2(MultilayerModel):
    """Repeating two-sublayer block model, version 2.

    Adds a second sigmoid envelope (``dr_sigmoid_rel_pos`` /
    ``dr_sigmoid_rel_width``) acting on the sublayer contrast ``dr``. Columns
    follow ``PARAMETER_NAMES``; conversion is delegated to
    :func:`multilayer_model2`.
    """

    NAME = 'repeating_multilayer_v2'

    PARAMETER_NAMES = (
        "d_full_rel", "rel_sigmas",
        "dr_sigmoid_rel_pos", "dr_sigmoid_rel_width",
        "d_block", "s_block_rel", "r_block", "dr",
        "d3_rel", "s3_rel", "r3",
        "d_sio2", "s_sio2", "s_si", "r_sio2", "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Return ``thicknesses`` / ``roughnesses`` / ``slds`` for the given parameters."""
        return multilayer_model2(parametrized_model, self.max_num_layers)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class BasicMultilayerModel3(MultilayerModel):
    """Repeating two-sublayer block model, version 3.

    Like version 2, but the two sublayers can have unequal thicknesses
    (split by ``d_block1_rel``). Columns follow ``PARAMETER_NAMES``;
    conversion is delegated to :func:`multilayer_model3`.
    """

    NAME = 'repeating_multilayer_v3'

    PARAMETER_NAMES = (
        "d_full_rel", "rel_sigmas",
        "dr_sigmoid_rel_pos", "dr_sigmoid_rel_width",
        "d_block1_rel",
        "d_block", "s_block_rel", "r_block", "dr",
        "d3_rel", "s3_rel", "r3",
        "d_sio2", "s_sio2", "s_si", "r_sio2", "r_si",
    )

    def to_standard_params(self, parametrized_model: Tensor) -> dict:
        """Return ``thicknesses`` / ``roughnesses`` / ``slds`` for the given parameters."""
        return multilayer_model3(parametrized_model, self.max_num_layers)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# Registry of available multilayer parametrizations, keyed by each class's NAME
# ('repeating_multilayer_v1', ..._v2, ..._v3), in definition order.
MULTILAYER_MODELS = {
    model_cls.NAME: model_cls
    for model_cls in (BasicMultilayerModel1, BasicMultilayerModel2, BasicMultilayerModel3)
}
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def multilayer_model1(parametrized_model: Tensor, d_full_rel_max: int = 50) -> dict:
    """Expand the ``repeating_multilayer_v1`` parametrization into layer parameters.

    The stack consists of ``n = d_full_rel_max`` repetitions of a two-sublayer
    block with SLDs ``r_block`` and ``r_block + dr``, each ``d_block / 2`` thick.
    The blocks are followed by a layer of thickness ``d3_rel * d_block`` and a layer of
    thickness ``d_sio2``. A final roughness/SLD entry (``s_si`` / ``r_si``) has
    no thickness, presumably the substrate. Every block SLD is multiplied by a
    sigmoid envelope over the sublayer index, centred at ``2 * d_full_rel``
    with width ``rel_sigmas``.

    Args:
        parametrized_model: 2-D tensor ``(batch, 14)`` whose columns are, in order,
            d_full_rel, rel_sigmas, d_block, s_block_rel, r_block, dr, d3_rel,
            s3_rel, r3, d_sio2, s_sio2, s_si, r_sio2, r_si.
        d_full_rel_max: number ``n`` of repeated blocks.

    Returns:
        dict with ``thicknesses`` ``(batch, 2n + 2)``, ``roughnesses``
        ``(batch, 2n + 3)`` and ``slds`` ``(batch, 2n + 3)``.
    """
    num_blocks = d_full_rel_max
    (
        d_full_rel, rel_sigmas,
        d_block, s_block_rel, r_block, dr,
        d3_rel, s3_rel, r3,
        d_sio2, s_sio2, s_si, r_sio2, r_si,
    ) = parametrized_model.T

    batch_size = parametrized_model.shape[0]
    num_sublayers = 2 * num_blocks

    # Sublayer positions run 2n, 2n-1, ..., 1 along the stack.
    index = torch.arange(num_sublayers, dtype=dr.dtype, device=dr.device)
    r_positions = num_sublayers - index[None].repeat(batch_size, 1)

    envelope = torch.sigmoid(-(r_positions - 2 * d_full_rel[..., None]) / rel_sigmas[..., None])

    # Interleave (r_block, r_block + dr) for every block, then apply the envelope.
    low = r_block[:, None].repeat(1, num_blocks)
    high = low + dr[:, None].repeat(1, num_blocks)
    sld_blocks = envelope * torch.stack([low, high], -1).flatten(1)

    d3 = d3_rel * d_block
    s_block = s_block_rel * d_block

    thicknesses = torch.cat(
        [(d_block / 2)[:, None].repeat(1, num_sublayers), d3[:, None], d_sio2[:, None]], -1
    )
    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, num_sublayers), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )
    slds = torch.cat([sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1)

    return {'thicknesses': thicknesses, 'roughnesses': roughnesses, 'slds': slds}
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def multilayer_model2(parametrized_model: Tensor, d_full_rel_max: int = 50) -> dict:
    """Expand the ``repeating_multilayer_v2`` parametrization into layer parameters.

    Layout is ``n = d_full_rel_max`` two-sublayer blocks (each sublayer
    ``d_block / 2`` thick). These are followed by layers of thickness ``d3_rel * d_block``
    and ``d_sio2``, plus a final roughness/SLD entry without thickness (presumably
    the substrate). Within each block the sublayer SLDs are ``r_block`` and
    ``r_block + dr * m``. Here ``m`` is a per-block sigmoid centred at
    ``2 * d_full_rel * dr_sigmoid_rel_pos`` with width ``dr_sigmoid_rel_width``.
    All block SLDs are then multiplied by a sigmoid envelope centred at
    ``2 * d_full_rel`` with width ``rel_sigmas``.

    Args:
        parametrized_model: 2-D tensor ``(batch, 16)`` whose columns are, in order,
            d_full_rel, rel_sigmas, dr_sigmoid_rel_pos, dr_sigmoid_rel_width,
            d_block, s_block_rel, r_block, dr, d3_rel, s3_rel, r3, d_sio2,
            s_sio2, s_si, r_sio2, r_si.
        d_full_rel_max: number ``n`` of repeated blocks.

    Returns:
        dict with ``thicknesses`` ``(batch, 2n + 2)``, ``roughnesses``
        ``(batch, 2n + 3)`` and ``slds`` ``(batch, 2n + 3)``.
    """
    num_blocks = d_full_rel_max
    (
        d_full_rel, rel_sigmas,
        dr_sigmoid_rel_pos, dr_sigmoid_rel_width,
        d_block, s_block_rel, r_block, dr,
        d3_rel, s3_rel, r3,
        d_sio2, s_sio2, s_si, r_sio2, r_si,
    ) = parametrized_model.T

    batch_size = parametrized_model.shape[0]
    num_sublayers = 2 * num_blocks

    # Sublayer positions run 2n, 2n-1, ..., 1 along the stack.
    index = torch.arange(num_sublayers, dtype=dr.dtype, device=dr.device)
    r_positions = num_sublayers - index[None].repeat(batch_size, 1)

    envelope = torch.sigmoid(-(r_positions - 2 * d_full_rel[..., None]) / rel_sigmas[..., None])

    # Contrast envelope, evaluated once per block (every other sublayer position).
    dr_centre = (2 * d_full_rel * dr_sigmoid_rel_pos)[..., None]
    dr_envelope = torch.sigmoid(-(r_positions[:, ::2] - dr_centre) / dr_sigmoid_rel_width[..., None])

    low = r_block[:, None].repeat(1, num_blocks)
    contrast = dr[:, None].repeat(1, num_blocks) * dr_envelope
    sld_blocks = envelope * torch.stack([low, low + contrast], -1).flatten(1)

    d3 = d3_rel * d_block
    s_block = s_block_rel * d_block

    thicknesses = torch.cat(
        [(d_block / 2)[:, None].repeat(1, num_sublayers), d3[:, None], d_sio2[:, None]], -1
    )
    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, num_sublayers), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )
    slds = torch.cat([sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1)

    return {'thicknesses': thicknesses, 'roughnesses': roughnesses, 'slds': slds}
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def multilayer_model3(parametrized_model: Tensor, d_full_rel_max: int = 30) -> dict:
    """Expand the ``repeating_multilayer_v3`` parametrization into layer parameters.

    Layout is ``n = d_full_rel_max`` two-sublayer blocks with thicknesses
    ``d_block * d_block1_rel`` and ``d_block * (1 - d_block1_rel)``. These are
    followed by layers of thickness ``d3_rel * d_block`` and ``d_sio2``, plus a
    final roughness/SLD entry without thickness (presumably the substrate).
    Block SLDs start as ``r_block`` / ``r_block + dr``. A per-block factor
    ``dr_modulations`` then pulls both sublayers towards their
    thickness-weighted mean. All block SLDs are finally multiplied by a
    sigmoid envelope centred at ``2 * d_full_rel`` with width ``rel_sigmas``.

    Args:
        parametrized_model: 2-D tensor ``(batch, 17)`` whose columns are, in order,
            d_full_rel, rel_sigmas, dr_sigmoid_rel_pos, dr_sigmoid_rel_width,
            d_block1_rel, d_block, s_block_rel, r_block, dr, d3_rel, s3_rel, r3,
            d_sio2, s_sio2, s_si, r_sio2, r_si.
        d_full_rel_max: number ``n`` of repeated blocks.

    Returns:
        dict with ``thicknesses`` ``(batch, 2n + 2)``, ``roughnesses``
        ``(batch, 2n + 3)`` and ``slds`` ``(batch, 2n + 3)``.
    """
    n = d_full_rel_max

    (
        d_full_rel, rel_sigmas, dr_sigmoid_rel_pos, dr_sigmoid_rel_width,
        d_block1_rel, d_block, s_block_rel, r_block, dr,
        d3_rel, s3_rel, r3,
        d_sio2, s_sio2, s_si, r_sio2, r_si,
    ) = parametrized_model.T

    batch_size = parametrized_model.shape[0]

    # Sublayer positions run 2n, 2n-1, ..., 1 along the stack.
    r_positions = 2 * n - torch.arange(2 * n, dtype=dr.dtype, device=dr.device)[None].repeat(batch_size, 1)

    r_modulations = torch.sigmoid(
        -(r_positions - 2 * d_full_rel[..., None]) / rel_sigmas[..., None]
    )

    # One position per block (every other sublayer).
    dr_positions = r_positions[:, ::2]

    # Per-block amount (between 0 and dr) by which the contrast is smeared out.
    dr_modulations = dr[..., None] * (1 - torch.sigmoid(
        -(
            dr_positions - 2 * d_full_rel[..., None] + 2 * dr_sigmoid_rel_pos[..., None]
        ) / dr_sigmoid_rel_width[..., None]
    ))

    r_block = r_block[..., None].repeat(1, n)
    dr = dr[..., None].repeat(1, n)

    # When dr_modulations == dr, both sublayers equal r_block + dr * (1 - d_block1_rel).
    # That is the thickness-weighted mean of r_block (thickness fraction d_block1_rel)
    # and r_block + dr (fraction 1 - d_block1_rel).
    sld_blocks = torch.stack(
        [
            r_block + dr_modulations * (1 - d_block1_rel[..., None]),
            r_block + dr - dr_modulations * d_block1_rel[..., None],
        ],
        -1,
    ).flatten(1)

    sld_blocks = r_modulations * sld_blocks

    d3 = d3_rel * d_block

    d1, d2 = d_block * d_block1_rel, d_block * (1 - d_block1_rel)

    thickness_blocks = torch.stack([d1[:, None].repeat(1, n), d2[:, None].repeat(1, n)], -1).flatten(1)

    thicknesses = torch.cat(
        [thickness_blocks, d3[:, None], d_sio2[:, None]], -1
    )

    s_block = s_block_rel * d_block

    roughnesses = torch.cat(
        [s_block[:, None].repeat(1, n * 2), (s3_rel * d3)[:, None], s_sio2[:, None], s_si[:, None]], -1
    )

    slds = torch.cat(
        [sld_blocks, r3[:, None], r_sio2[:, None], r_si[:, None]], -1
    )

    params = dict(
        thicknesses=thicknesses,
        roughnesses=roughnesses,
        slds=slds
    )
    return params
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from typing import Tuple, Dict
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from reflectorch.data_generation.priors.base import PriorSampler
|
|
8
|
+
from reflectorch.data_generation.priors.params import Params
|
|
9
|
+
from reflectorch.data_generation.priors.no_constraints import (
|
|
10
|
+
DEFAULT_DEVICE,
|
|
11
|
+
DEFAULT_DTYPE,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from reflectorch.data_generation.priors.multilayer_models import MULTILAYER_MODELS, MultilayerModel
|
|
15
|
+
from reflectorch.utils import to_t
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class MultilayerStructureParams(Params):
    """Parameter container produced by ``SimpleMultilayerSampler``.

    Adds nothing to :class:`Params`; it is the ``PARAM_CLS`` of
    ``SimpleMultilayerSampler``.
    """
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SimpleMultilayerSampler(PriorSampler):
    """Uniform prior over the parameters of a registered multilayer model.

    Every parameter in the chosen model's ``PARAMETER_NAMES`` is drawn
    independently and uniformly from its ``(min, max)`` bounds. The sampled
    vector is then converted to thicknesses/roughnesses/SLDs by the model.

    Args:
        params: mapping from parameter name to ``(min, max)`` bounds. It must
            contain every name in the model's ``PARAMETER_NAMES``; any other
            keys are ignored.
        model_name: key into ``MULTILAYER_MODELS``.
        device: device of the bound tensors (and thus of sampled tensors).
        dtype: dtype of the bound tensors (and thus of sampled tensors).
        max_num_layers: forwarded to the multilayer model's constructor.

    Raises:
        KeyError: if ``model_name`` is not registered or a required bound is missing.
    """
    PARAM_CLS = MultilayerStructureParams

    def __init__(self,
                 params: Dict[str, Tuple[float, float]],
                 model_name: str,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 max_num_layers: int = 50,
                 ):
        try:
            model_cls = MULTILAYER_MODELS[model_name]
        except KeyError:
            raise KeyError(
                f"Unknown multilayer model {model_name!r}; available: {sorted(MULTILAYER_MODELS)}"
            ) from None
        self.multilayer_model: MultilayerModel = model_cls(max_num_layers)
        self.device = device
        self.dtype = dtype
        self.num_layers = max_num_layers

        names = self.multilayer_model.PARAMETER_NAMES
        missing = [k for k in names if k not in params]
        if missing:
            raise KeyError(f"Missing bounds for parameters {missing} required by model {model_name!r}")
        ordered_bounds = [params[k] for k in names]

        # Shape (2, P): row 0 holds the lower bounds, row 1 the upper bounds.
        self._np_bounds = np.array(ordered_bounds).T
        # (P, 2) -> (2, 1, P): unpacks into (1, P) min/max rows that broadcast over the batch.
        self.min_bounds, self.max_bounds = torch.tensor(ordered_bounds, device=device, dtype=dtype).T[:, None]
        # One entry per model parameter, matching the sampled vector width.
        # len(params) over-counts when params carries unused keys.
        self._param_dim = len(ordered_bounds)

    @property
    def max_num_layers(self) -> int:
        return self.num_layers

    @property
    def param_dim(self) -> int:
        return self._param_dim

    def sample(self, batch_size: int) -> MultilayerStructureParams:
        """Draw ``batch_size`` parameter sets (restored to standard parameters)."""
        return self.optimized_sample(batch_size)[0]

    def optimized_sample(self, batch_size: int) -> Tuple[MultilayerStructureParams, Tensor]:
        """Draw samples and also return the underlying scaled values in [0, 1)."""
        scaled_params = torch.rand(
            batch_size,
            self.min_bounds.shape[-1],
            device=self.min_bounds.device,
            dtype=self.min_bounds.dtype,
        )

        targets = self.restore_params(scaled_params)

        return targets, scaled_params

    def get_np_bounds(self):
        """Return a copy of the ``(2, P)`` numpy bounds array."""
        return np.array(self._np_bounds)

    def restore_np_params(self, params: np.ndarray):
        """Convert parametrized values to numpy thickness/roughness/sld arrays.

        ``squeeze`` drops size-1 dimensions, so a single parameter vector yields 1-D arrays.
        """
        p = self.multilayer_model.to_standard_params(
            torch.atleast_2d(to_t(params))
        )

        return {
            'thickness': p['thicknesses'].squeeze().cpu().numpy(),
            'roughness': p['roughnesses'].squeeze().cpu().numpy(),
            'sld': p['slds'].squeeze().cpu().numpy()
        }

    def restore_params2parametrized(self, scaled_params: Tensor) -> Tensor:
        """Map values in [0, 1] linearly onto ``[min_bounds, max_bounds]``."""
        return scaled_params * (self.max_bounds - self.min_bounds) + self.min_bounds

    def restore_params(self, scaled_params: Tensor) -> MultilayerStructureParams:
        return self.to_standard_params(self.restore_params2parametrized(scaled_params))

    def to_standard_params(self, params: Tensor) -> MultilayerStructureParams:
        return MultilayerStructureParams(**self.multilayer_model.to_standard_params(params))

    def scale_params(self, params: Params) -> Tensor:
        raise NotImplementedError

    def log_prob(self, params: Params) -> Tensor:
        raise NotImplementedError

    def get_indices_within_domain(self, params: Params) -> Tensor:
        raise NotImplementedError

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        raise NotImplementedError

    def filter_params(self, params: Params) -> Params:
        # Relies on get_indices_within_domain, which is not implemented here.
        indices = self.get_indices_within_domain(params)
        return params[indices]

    def clamp_params(self, params: Params) -> Params:
        raise NotImplementedError
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from functools import lru_cache
|
|
3
|
+
from typing import Tuple
|
|
4
|
+
from math import sqrt
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
from reflectorch.data_generation.utils import (
|
|
10
|
+
get_slds_from_d_rhos,
|
|
11
|
+
uniform_sampler,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from reflectorch.data_generation.priors.utils import (
|
|
15
|
+
get_allowed_roughness_indices,
|
|
16
|
+
generate_roughnesses,
|
|
17
|
+
params_within_bounds,
|
|
18
|
+
)
|
|
19
|
+
from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
|
|
20
|
+
from reflectorch.data_generation.priors.base import PriorSampler
|
|
21
|
+
from reflectorch.data_generation.priors.params import Params
|
|
22
|
+
|
|
23
|
+
# Names exported by ``from ... import *``: the sampler class and its default settings.
__all__ = [
    "BasicPriorSampler",
    "DEFAULT_ROUGHNESS_RANGE", "DEFAULT_THICKNESS_RANGE", "DEFAULT_SLD_RANGE",
    "DEFAULT_NUM_LAYERS", "DEFAULT_DEVICE", "DEFAULT_DTYPE",
    "DEFAULT_SCALED_RANGE", "DEFAULT_USE_DRHO",
]
|
|
34
|
+
|
|
35
|
+
# Default (min, max) sampling bounds.
DEFAULT_THICKNESS_RANGE: Tuple[float, float] = (1.0, 500.0)
DEFAULT_ROUGHNESS_RANGE: Tuple[float, float] = (0.0, 50.0)
DEFAULT_SLD_RANGE: Tuple[float, float] = (-10.0, 30.0)

# Default number of thickness entries; roughnesses and SLDs get one more each.
DEFAULT_NUM_LAYERS: int = 5
# Whether BasicPriorSampler uses the d-rho SLD representation by default.
DEFAULT_USE_DRHO: bool = False

# NOTE(review): CUDA is the unconditional default; tensors created with it fail on CPU-only machines.
DEFAULT_DEVICE: torch.device = torch.device('cuda')
DEFAULT_DTYPE: torch.dtype = torch.float64

# +-sqrt(3) is the support of a zero-mean, unit-variance uniform distribution.
DEFAULT_SCALED_RANGE: Tuple[float, float] = (-sqrt(3.), sqrt(3.))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class BasicPriorSampler(PriorSampler, ScalerMixin):
    """Prior samplers for thicknesses, roughnesses and slds.

    Draws ``num_layers`` thicknesses and ``num_layers + 1`` roughnesses and
    SLDs uniformly from their ranges. With ``restrict_roughnesses=True``,
    roughnesses come from :func:`generate_roughnesses` applied to the sampled
    thicknesses. In that case :meth:`get_indices_within_domain` also applies
    :func:`get_allowed_roughness_indices`.

    Args:
        thickness_range: ``(min, max)`` bounds for thicknesses.
        roughness_range: ``(min, max)`` bounds for roughnesses.
        sld_range: ``(min, max)`` bounds for SLDs.
        num_layers: number of thickness entries per sample.
        use_drho: if True, scaling works on ``params.as_tensor(use_drho=True)``.
            The SLD bounds become ``±(sld_max - sld_min)``, and restored SLDs
            pass through :func:`get_slds_from_d_rhos`.
        device: device of the bound vectors and sampled tensors.
        dtype: dtype of the bound vectors and sampled tensors.
        scaled_range: stored on the instance (presumably consumed by ScalerMixin — confirm).
        restrict_roughnesses: see above.
    """
    def __init__(self,
                 thickness_range: Tuple[float, float] = DEFAULT_THICKNESS_RANGE,
                 roughness_range: Tuple[float, float] = DEFAULT_ROUGHNESS_RANGE,
                 sld_range: Tuple[float, float] = DEFAULT_SLD_RANGE,
                 num_layers: int = DEFAULT_NUM_LAYERS,
                 use_drho: bool = DEFAULT_USE_DRHO,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 scaled_range: Tuple[float, float] = DEFAULT_SCALED_RANGE,
                 restrict_roughnesses: bool = True,
                 ):
        self.logger = logging.getLogger(__name__)
        self.thickness_range = thickness_range
        self.roughness_range = roughness_range
        self.sld_range = sld_range
        self.num_layers = num_layers
        self.device = device
        self.dtype = dtype
        self.scaled_range = scaled_range
        self.use_drho = use_drho
        self.restrict_roughnesses = restrict_roughnesses
        # Per-instance cache of min/max/delta bound vectors (see _cached_vector).
        self._bounds_vector_cache = {}

    @property
    def max_num_layers(self) -> int:
        return self.num_layers

    def _cached_vector(self, kind: str, layers_num, drho: bool, build):
        """Return a bound vector from the per-instance cache, building it on a miss.

        This replaces ``functools.lru_cache`` on the methods, which kept every
        sampler instance alive in a module-global cache (flake8-bugbear B019).
        The key includes the current ranges, device and dtype, so changing them
        after construction does not return stale vectors.
        """
        cache = getattr(self, '_bounds_vector_cache', None)
        if cache is None:
            # Subclasses that skip __init__ still get a cache.
            cache = self._bounds_vector_cache = {}
        key = (
            kind, layers_num, drho,
            tuple(self.thickness_range), tuple(self.roughness_range), tuple(self.sld_range),
            self.device, self.dtype,
        )
        try:
            return cache[key]
        except KeyError:
            vector = cache[key] = build()
            return vector

    def _bounds_vector(self, layers_num, side: int, drho: bool) -> Tensor:
        """Build the flat ``[thicknesses | roughnesses | slds]`` bound vector.

        ``side`` is 0 for the lower and 1 for the upper bound.
        """
        sld = self.sld_range[side]
        if drho:
            # In the d-rho representation the SLD bounds span ±(sld_max - sld_min).
            sld = sld - self.sld_range[1 - side]
        return torch.tensor(
            [self.thickness_range[side]] * layers_num
            + [self.roughness_range[side]] * (layers_num + 1)
            + [sld] * (layers_num + 1),
            device=self.device,
            dtype=self.dtype,
        )

    def min_vector(self, layers_num, drho: bool = False):
        """Lower bounds for a ``layers_num``-layer parameter tensor (cached per instance)."""
        return self._cached_vector('min', layers_num, drho, lambda: self._bounds_vector(layers_num, 0, drho))

    def max_vector(self, layers_num, drho: bool = False):
        """Upper bounds for a ``layers_num``-layer parameter tensor (cached per instance)."""
        return self._cached_vector('max', layers_num, drho, lambda: self._bounds_vector(layers_num, 1, drho))

    def delta_vector(self, layers_num, drho: bool = False):
        """``ScalerMixin._get_delta_vector`` of the min/max bounds (cached per instance)."""
        return self._cached_vector(
            'delta', layers_num, drho,
            lambda: self._get_delta_vector(self.min_vector(layers_num, drho), self.max_vector(layers_num, drho)),
        )

    def restore_params(self, scaled_params: Tensor) -> Params:
        layers_num = self.PARAM_CLS.size2layers_num(scaled_params.shape[-1])

        params_t = self._restore(
            scaled_params,
            self.min_vector(layers_num, drho=self.use_drho).to(scaled_params),
            self.max_vector(layers_num, drho=self.use_drho).to(scaled_params),
        )

        params = self.PARAM_CLS.from_tensor(params_t)

        if self.use_drho:
            params.slds = get_slds_from_d_rhos(params.slds)

        return params

    def scale_params(self, params: Params) -> Tensor:
        layers_num = params.max_layer_num

        return self._scale(
            params.as_tensor(use_drho=self.use_drho),
            self.min_vector(layers_num, drho=self.use_drho).to(params.thicknesses),
            self.max_vector(layers_num, drho=self.use_drho).to(params.thicknesses),
        )

    def get_indices_within_bounds(self, params: Params) -> Tensor:
        """Boolean mask of samples whose (absolute-SLD) parameters lie within the bounds."""
        layer_num = params.max_layer_num
        params_t = params.as_tensor()

        # Align bounds with the params' device/dtype, as scale_params does.
        return params_within_bounds(
            params_t,
            self.min_vector(layer_num).to(params_t),
            self.max_vector(layer_num).to(params_t),
        )

    def get_indices_within_domain(self, params: Params) -> Tensor:
        if self.restrict_roughnesses:
            indices = (
                self.get_indices_within_bounds(params) &
                self.get_allowed_roughness_indices(params)
            )
        else:
            indices = self.get_indices_within_bounds(params)
        return indices

    def clamp_params(self, params: Params) -> Params:
        """Clamp every parameter to the (absolute-SLD) bounds."""
        layer_num = params.max_layer_num
        params_t = params.as_tensor()
        clamped = torch.clamp(
            params_t,
            self.min_vector(layer_num).to(params_t),
            self.max_vector(layer_num).to(params_t),
        )
        # Use PARAM_CLS, consistent with restore_params.
        return self.PARAM_CLS.from_tensor(clamped)

    @staticmethod
    def get_allowed_roughness_indices(params: Params) -> Tensor:
        return get_allowed_roughness_indices(params.thicknesses, params.roughnesses)

    def log_prob(self, params: Params) -> Tensor:
        # so far we ignore non-uniform distribution of roughnesses and slds.
        log_prob = torch.zeros(params.batch_size, device=params.device, dtype=params.dtype)
        indices = self.get_indices_within_bounds(params)
        log_prob[~indices] = float('-inf')
        return log_prob

    def sample(self, batch_size: int) -> Params:
        slds = self.generate_slds(batch_size)
        thicknesses = self.generate_thicknesses(batch_size)
        roughnesses = self.generate_roughnesses(thicknesses)

        params = Params(thicknesses, roughnesses, slds)

        return params

    def generate_slds(self, batch_size: int):
        return uniform_sampler(
            *self.sld_range, batch_size,
            self.num_layers + 1,
            device=self.device,
            dtype=self.dtype
        )

    def generate_thicknesses(self, batch_size: int):
        return uniform_sampler(
            *self.thickness_range, batch_size,
            self.num_layers,
            device=self.device,
            dtype=self.dtype
        )

    def generate_roughnesses(self, thicknesses: Tensor) -> Tensor:
        if self.restrict_roughnesses:
            # Module-level helper (not this method): roughnesses constrained by thicknesses.
            return generate_roughnesses(thicknesses, self.roughness_range)
        else:
            return uniform_sampler(
                *self.roughness_range, thicknesses.shape[0],
                self.num_layers + 1,
                device=self.device,
                dtype=self.dtype
            )
|