reflectorch 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -0
- reflectorch/data_generation/__init__.py +128 -0
- reflectorch/data_generation/dataset.py +216 -0
- reflectorch/data_generation/likelihoods.py +80 -0
- reflectorch/data_generation/noise.py +471 -0
- reflectorch/data_generation/priors/__init__.py +60 -0
- reflectorch/data_generation/priors/base.py +55 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
- reflectorch/data_generation/priors/independent_priors.py +195 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +104 -0
- reflectorch/data_generation/priors/no_constraints.py +206 -0
- reflectorch/data_generation/priors/parametric_models.py +842 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
- reflectorch/data_generation/priors/params.py +252 -0
- reflectorch/data_generation/priors/sampler_strategies.py +370 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +371 -0
- reflectorch/data_generation/priors/utils.py +118 -0
- reflectorch/data_generation/process_data.py +41 -0
- reflectorch/data_generation/q_generator.py +280 -0
- reflectorch/data_generation/reflectivity/__init__.py +102 -0
- reflectorch/data_generation/reflectivity/abeles.py +97 -0
- reflectorch/data_generation/reflectivity/kinematical.py +71 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +138 -0
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/scale_curves.py +112 -0
- reflectorch/data_generation/smearing.py +99 -0
- reflectorch/data_generation/utils.py +223 -0
- reflectorch/extensions/__init__.py +0 -0
- reflectorch/extensions/jupyter/__init__.py +11 -0
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -0
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -0
- reflectorch/extensions/matplotlib/losses.py +32 -0
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/__init__.py +28 -0
- reflectorch/inference/inference_model.py +848 -0
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +55 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/plotting.py +524 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +19 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +364 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +32 -0
- reflectorch/ml/basic_trainer.py +292 -0
- reflectorch/ml/callbacks.py +81 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +56 -0
- reflectorch/ml/schedulers.py +356 -0
- reflectorch/ml/trainers.py +201 -0
- reflectorch/ml/utils.py +2 -0
- reflectorch/models/__init__.py +16 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +19 -0
- reflectorch/models/encoders/conv_encoder.py +219 -0
- reflectorch/models/encoders/conv_res_net.py +115 -0
- reflectorch/models/encoders/fno.py +134 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -0
- reflectorch/models/networks/mlp_networks.py +434 -0
- reflectorch/models/networks/residual_net.py +157 -0
- reflectorch/paths.py +29 -0
- reflectorch/runs/__init__.py +31 -0
- reflectorch/runs/config.py +25 -0
- reflectorch/runs/slurm_utils.py +93 -0
- reflectorch/runs/train.py +78 -0
- reflectorch/runs/utils.py +405 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +98 -0
- reflectorch-1.5.1.dist-info/METADATA +151 -0
- reflectorch-1.5.1.dist-info/RECORD +96 -0
- reflectorch-1.5.1.dist-info/WHEEL +5 -0
- reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
- reflectorch-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,371 @@
|
|
|
1
|
+
from functools import lru_cache
|
|
2
|
+
from typing import Tuple
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from reflectorch.data_generation.utils import (
|
|
8
|
+
uniform_sampler,
|
|
9
|
+
logdist_sampler,
|
|
10
|
+
triangular_sampler,
|
|
11
|
+
get_slds_from_d_rhos,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from reflectorch.data_generation.priors.params import Params
|
|
15
|
+
from reflectorch.data_generation.priors.no_constraints import (
|
|
16
|
+
BasicPriorSampler,
|
|
17
|
+
DEFAULT_ROUGHNESS_RANGE,
|
|
18
|
+
DEFAULT_THICKNESS_RANGE,
|
|
19
|
+
DEFAULT_SLD_RANGE,
|
|
20
|
+
DEFAULT_NUM_LAYERS,
|
|
21
|
+
DEFAULT_DEVICE,
|
|
22
|
+
DEFAULT_DTYPE,
|
|
23
|
+
DEFAULT_SCALED_RANGE,
|
|
24
|
+
DEFAULT_USE_DRHO,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class UniformSubPriorParams(Params):
    """Parameters class for thicknesses, roughnesses and slds, together with their subprior bounds.

    For ``L`` layers the flat tensor layout used by :meth:`as_tensor` / :meth:`from_tensor` is::

        [thicknesses (L), roughnesses (L + 1), slds (L + 1),
         min_bounds (3L + 2), max_bounds (3L + 2)]

    i.e. ``9L + 6`` values along the last dimension.
    """
    __slots__ = ('thicknesses', 'roughnesses', 'slds', 'min_bounds', 'max_bounds')
    PARAM_NAMES = __slots__

    def __init__(self,
                 thicknesses: Tensor,
                 roughnesses: Tensor,
                 slds: Tensor,
                 min_bounds: Tensor,
                 max_bounds: Tensor,
                 ):
        super().__init__(thicknesses, roughnesses, slds)
        # Lower / upper subprior bounds, laid out like the concatenation
        # [thicknesses, roughnesses, slds] (see as_tensor / from_tensor).
        self.min_bounds = min_bounds
        self.max_bounds = max_bounds

    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor, context: Tensor, inference: bool = False, from_params: bool = False
    ):
        """Move the subprior bounds out of the parameter tensor and append them to ``context``.

        In training mode ``scaled_params`` must have ``3 * n`` columns: ``n`` parameter values
        followed by ``2 * n`` bound values.

        Args:
            scaled_params: scaled parameters (optionally) concatenated with scaled bounds.
            context: tensor to which the bounds are appended along the last dimension.
            inference: if True, only the updated ``context`` is returned.
            from_params: only used when ``inference`` is True. If True, ``scaled_params`` still
                contains the leading ``n`` parameter columns, which are dropped first. Otherwise
                the whole ``scaled_params`` tensor is treated as bounds.

        Returns:
            ``context`` in inference mode, otherwise the tuple ``(scaled_params, context)``, where
            ``scaled_params`` holds only the ``n`` parameter columns.
        """
        if inference:
            if from_params:
                num_params = scaled_params.shape[1] // 3
                scaled_params = scaled_params[:, num_params:]
            context = torch.cat([context, scaled_params], dim=-1)
            return context

        num_params = scaled_params.shape[1] // 3
        assert num_params * 3 == scaled_params.shape[1]
        scaled_params, bound_context = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
        context = torch.cat([context, bound_context], dim=-1)
        return scaled_params, context

    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        """Re-attach the bounds stored in the last ``2 * n`` context columns to ``scaled_params``.

        ``n`` is the number of columns of ``scaled_params``. This is the inverse of
        :meth:`rearrange_context_from_params` in training mode.
        """
        num_params = scaled_params.shape[-1]
        scaled_bounds = context[:, -2 * num_params:]
        scaled_params = torch.cat([scaled_params, scaled_bounds], dim=-1)
        return scaled_params

    @staticmethod
    def input_context_split(t_params):
        """Split a ``(batch, 3 * n)`` tensor into its parameter part and its bounds part.

        Returns:
            A tuple of two tensors whose last dimensions have sizes ``n`` and ``2 * n``.
        """
        num_params = t_params.shape[1] // 3
        # Split along the feature dimension that ``num_params`` was derived from. Without an
        # explicit ``dim``, torch.split would slice the batch dimension (dim 0) instead.
        return torch.split(t_params, [num_params, 2 * num_params], dim=-1)

    def as_tensor(self, use_drho: bool = False, add_bounds: bool = True) -> Tensor:
        """Concatenate the parameters (and, optionally, the bounds) along the last dimension.

        Args:
            use_drho: if True, ``self.d_rhos`` is used instead of ``self.slds``
                (``d_rhos`` is provided by the ``Params`` base class).
            add_bounds: if True, ``min_bounds`` and ``max_bounds`` are appended.
        """
        t_list = [self.thicknesses, self.roughnesses]
        if use_drho:
            t_list.append(self.d_rhos)
        else:
            t_list.append(self.slds)
        if add_bounds:
            t_list += [self.min_bounds, self.max_bounds]
        return torch.cat(t_list, -1)

    @classmethod
    def from_tensor(cls, params: Tensor):
        """Build an instance from a flat tensor with ``9L + 6`` values along the last dimension."""
        layers_num = cls.size2layers_num(params.shape[-1])
        # number of film parameters: L thicknesses + (L + 1) roughnesses + (L + 1) slds
        num_params = 3 * layers_num + 2

        thicknesses, roughnesses, slds, min_bounds, max_bounds = torch.split(
            params,
            [layers_num, layers_num + 1, layers_num + 1, num_params, num_params],
            dim=-1
        )

        return cls(thicknesses, roughnesses, slds, min_bounds, max_bounds)

    @property
    def num_params(self) -> int:
        """Flat tensor size for ``self.max_layer_num`` layers (``max_layer_num`` comes from the base class)."""
        return self.layers_num2size(self.max_layer_num)

    @staticmethod
    def size2layers_num(size: int) -> int:
        """Inverse of :meth:`layers_num2size`."""
        return (size - 6) // 9

    @staticmethod
    def layers_num2size(layers_num: int) -> int:
        """Flat tensor size (parameters plus both bounds) for ``layers_num`` layers."""
        return layers_num * 9 + 6

    def scale_with_q(self, q_ratio: float):
        """Rescale parameters (via the base class) and bounds in place for a q-range ratio.

        Thickness and roughness bounds are multiplied by ``1 / q_ratio`` and SLD bounds by
        ``q_ratio ** 2``.
        """
        super().scale_with_q(q_ratio)

        layer_num = self.max_layer_num
        scales = torch.tensor(
            [1 / q_ratio] * (2 * layer_num + 1) + [q_ratio ** 2] * (layer_num + 1),
            device=self.device, dtype=self.dtype
        )

        self.min_bounds *= scales
        self.max_bounds *= scales
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class UniformSubPriorSampler(BasicPriorSampler):
    """Sample film parameters together with per-sample subprior bounds.

    For each sample, a random sub-interval (the "subprior") of every parameter's global range is
    drawn first. The parameter value is then drawn uniformly inside that sub-interval.

    Args:
        thickness_range (Tuple[float, float], optional): the range of the layer thicknesses. Defaults to DEFAULT_THICKNESS_RANGE.
        roughness_range (Tuple[float, float], optional): the range of the interlayer roughnesses. Defaults to DEFAULT_ROUGHNESS_RANGE.
        sld_range (Tuple[float, float], optional): the range of the layer SLDs. Defaults to DEFAULT_SLD_RANGE.
        num_layers (int, optional): the number of layers. Defaults to DEFAULT_NUM_LAYERS.
        use_drho (bool, optional): whether to use differences in SLD values between neighboring layers instead of the actual SLD values. Defaults to DEFAULT_USE_DRHO.
        device (torch.device, optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
        dtype (torch.dtype, optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
        scaled_range (Tuple[float, float], optional): the range for scaling the parameters. Defaults to DEFAULT_SCALED_RANGE.
        scale_by_subpriors (bool, optional): if True the film parameters are scaled with respect to their subprior bounds. Defaults to False.
        smaller_roughnesses (bool, optional): if True the sampled roughnesses are biased towards smaller values. Defaults to False.
        logdist (bool, optional): if True the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
        relative_min_bound_width (float, optional): defines the interval [relative_min_bound_width, 1.0] from which the relative bound widths for each parameter are sampled. Defaults to 1e-2.
    """
    PARAM_CLS = UniformSubPriorParams

    def __init__(self,
                 thickness_range: Tuple[float, float] = DEFAULT_THICKNESS_RANGE,
                 roughness_range: Tuple[float, float] = DEFAULT_ROUGHNESS_RANGE,
                 sld_range: Tuple[float, float] = DEFAULT_SLD_RANGE,
                 num_layers: int = DEFAULT_NUM_LAYERS,
                 use_drho: bool = DEFAULT_USE_DRHO,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 scaled_range: Tuple[float, float] = DEFAULT_SCALED_RANGE,
                 scale_by_subpriors: bool = False,
                 smaller_roughnesses: bool = False,
                 logdist: bool = False,
                 relative_min_bound_width: float = 1e-2,
                 ):
        super().__init__(
            thickness_range, roughness_range, sld_range, num_layers,
            use_drho, device, dtype, scaled_range,
        )

        self.scale_by_subpriors = scale_by_subpriors
        self.smaller_roughnesses = smaller_roughnesses
        self.logdist = logdist
        self.relative_min_bound_width = relative_min_bound_width

    @property
    def max_num_layers(self) -> int:
        """Maximum number of layers (identical to ``num_layers`` here)."""
        return self.num_layers

    @property
    def param_dim(self) -> int:
        """Number of film parameters: L thicknesses, L + 1 roughnesses and L + 1 slds."""
        return 3 * self.max_num_layers + 2

    @lru_cache()
    def min_vector(self, layers_num, drho: bool = False):
        """Lower limits for [params, min_bounds, max_bounds]: the base vector repeated three times."""
        base = super().min_vector(layers_num, drho)
        return torch.cat([base] * 3, dim=0)

    def scale_params(self, params: UniformSubPriorParams) -> Tensor:
        """Scale params and bounds; optionally rescale the params relative to their own subpriors."""
        scaled = super().scale_params(params)
        if not self.scale_by_subpriors:
            return scaled

        raw_values = params.as_tensor(use_drho=self.use_drho, add_bounds=False)
        scaled[:, :self.param_dim] = self._scale(raw_values, params.min_bounds, params.max_bounds)
        return scaled

    def restore_params(self, scaled_params: Tensor) -> Params:
        """Invert :meth:`scale_params`."""
        if not self.scale_by_subpriors:
            return super().restore_params(scaled_params)

        size = self.param_dim
        scaled_values, scaled_lo, scaled_hi = torch.split(scaled_params, [size] * 3, dim=1)

        n_layers = self.max_num_layers
        global_min = super().min_vector(n_layers, self.use_drho)
        global_max = super().max_vector(n_layers, self.use_drho)

        # Bounds were scaled w.r.t. the global range; the values w.r.t. their own bounds.
        lo = self._restore(scaled_lo, global_min, global_max)
        hi = self._restore(scaled_hi, global_min, global_max)
        values = self._restore(scaled_values, lo, hi)

        restored = UniformSubPriorParams.from_tensor(torch.cat([values, lo, hi], dim=-1))

        if self.use_drho:
            restored.slds = get_slds_from_d_rhos(restored.slds)
        return restored

    @lru_cache()
    def max_vector(self, layers_num, drho: bool = False):
        """Upper limits for [params, min_bounds, max_bounds]: the base vector repeated three times."""
        base = super().max_vector(layers_num, drho)
        return torch.cat([base] * 3, dim=0)

    @lru_cache()
    def delta_vector(self, layers_num, drho: bool = False):
        """Range widths; zero-width entries are replaced by 1 so they can be used as divisors."""
        span = self.max_vector(layers_num, drho) - self.min_vector(layers_num, drho)
        span[span == 0.] = 1.
        return span

    def get_indices_within_bounds(self, params: UniformSubPriorParams) -> Tensor:
        """Boolean mask over the batch: True where every parameter lies inside its subprior."""
        values = torch.cat([params.thicknesses, params.roughnesses, params.slds], dim=-1)
        above_min = torch.all(values >= params.min_bounds, dim=-1)
        below_max = torch.all(values <= params.max_bounds, dim=-1)
        return above_min & below_max

    def clamp_params(self, params: UniformSubPriorParams) -> UniformSubPriorParams:
        """Return new params clamped to their subprior bounds (bounds are kept unchanged)."""
        clamped = torch.clamp(params.as_tensor(add_bounds=False), params.min_bounds, params.max_bounds)
        flat = torch.cat([clamped, params.min_bounds, params.max_bounds], dim=1)
        return UniformSubPriorParams.from_tensor(flat)

    def get_indices_within_domain(self, params: UniformSubPriorParams) -> Tensor:
        """Alias of :meth:`get_indices_within_bounds`."""
        return self.get_indices_within_bounds(params)

    def sample(self, batch_size: int) -> UniformSubPriorParams:
        """Draw subprior bounds, then draw each parameter uniformly inside its bounds."""
        lo, hi = self.sample_bounds(batch_size)

        unit = torch.rand(*lo.shape, device=self.device, dtype=self.dtype)
        values = unit * (hi - lo) + lo

        n = self.max_num_layers
        thicknesses, roughnesses, slds = torch.split(values, [n, n + 1, n + 1], dim=-1)

        return UniformSubPriorParams(thicknesses, roughnesses, slds, lo, hi)

    def sample_bounds(self, batch_size: int):
        """Sample ``(min_bounds, max_bounds)``, each of shape ``(batch_size, param_dim)``.

        Widths are a random fraction (in ``[relative_min_bound_width, 1]``) of the global range.
        Centers are drawn so that the whole interval stays within the global range.
        """
        global_min = super().min_vector(self.num_layers)[None]
        global_max = super().max_vector(self.num_layers)[None]
        span = global_max - global_min

        sample_fraction = logdist_sampler if self.logdist else uniform_sampler
        widths = sample_fraction(
            self.relative_min_bound_width, 1.,
            batch_size, span.shape[1],
            device=self.device, dtype=self.dtype,
        ) * span
        half = widths / 2

        centers = uniform_sampler(
            global_min + half, global_max - half,
            *widths.shape,
            device=self.device, dtype=self.dtype,
        )

        if self.smaller_roughnesses:
            # roughness columns: [num_layers, 2 * num_layers + 1)
            rough = slice(self.num_layers, self.num_layers * 2 + 1)
            centers[:, rough] = triangular_sampler(
                global_min[:, rough] + half[:, rough],
                global_max[:, rough] - half[:, rough],
                batch_size, self.num_layers + 1,
                device=self.device, dtype=self.dtype,
            )

        return centers - half, centers + half

    def scale_bounds(self, bounds: Tensor) -> Tensor:
        """Scale ``bounds`` w.r.t. this sampler's (tripled) min/max vectors."""
        n_layers = bounds.shape[-1] // 2
        lower = self.min_vector(n_layers, drho=self.use_drho).to(bounds)
        upper = self.max_vector(n_layers, drho=self.use_drho).to(bounds)
        return self._scale(bounds, lower, upper)
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
class NarrowSldUniformSubPriorSampler(UniformSubPriorSampler):
    """Prior sampler for thicknesses, roughnesses, slds and their subprior bounds.

    The subprior bound widths for SLDs are restricted to be no larger than ``max_sld_prior_width``.

    Args:
        thickness_range, roughness_range, sld_range, num_layers, use_drho, device, dtype,
        scaled_range, scale_by_subpriors: forwarded to :class:`UniformSubPriorSampler`.
        max_sld_prior_width (float, optional): upper cap on the subprior width of the SLD
            entries. Defaults to 10.

    Note:
        :meth:`sample_bounds` always samples widths uniformly and does not read ``logdist``
        or ``smaller_roughnesses``.
    """
    def __init__(self,
                 thickness_range: Tuple[float, float] = DEFAULT_THICKNESS_RANGE,
                 roughness_range: Tuple[float, float] = DEFAULT_ROUGHNESS_RANGE,
                 sld_range: Tuple[float, float] = DEFAULT_SLD_RANGE,
                 num_layers: int = DEFAULT_NUM_LAYERS,
                 use_drho: bool = DEFAULT_USE_DRHO,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 scaled_range: Tuple[float, float] = DEFAULT_SCALED_RANGE,
                 scale_by_subpriors: bool = False,
                 max_sld_prior_width: float = 10.,
                 ):
        super().__init__(
            thickness_range,
            roughness_range,
            sld_range,
            num_layers,
            use_drho,
            device,
            dtype,
            scaled_range,
            scale_by_subpriors,
        )

        self.max_sld_prior_width = max_sld_prior_width

    def sample_bounds(self, batch_size: int):
        """Sample ``(min_bounds, max_bounds)`` with SLD widths capped at ``max_sld_prior_width``.

        Widths are drawn uniformly in ``[relative_min_bound_width * delta, delta]``, where
        ``delta`` is the global range width (capped for the SLD entries). Centers are drawn so
        that each interval stays inside the global range.
        """
        min_vector, max_vector = (
            BasicPriorSampler.min_vector(self, self.num_layers),
            BasicPriorSampler.max_vector(self, self.num_layers),
        )

        delta_vector = max_vector - min_vector

        # Use an explicit start index: ``delta_vector[-0:]`` would select the whole vector
        # when ``num_layers == 0`` and overwrite the thickness/roughness widths as well.
        # NOTE(review): this covers the last ``num_layers`` entries, i.e. all but the first of
        # the ``num_layers + 1`` SLD entries -- confirm this exclusion is intended.
        sld_start = delta_vector.shape[0] - self.num_layers
        # Cap (never widen) the SLD widths. A width larger than the global range would make
        # the center interval below inverted and push the bounds outside the range.
        delta_vector[sld_start:] = torch.clamp(delta_vector[sld_start:], max=self.max_sld_prior_width)

        prior_widths = uniform_sampler(
            delta_vector * self.relative_min_bound_width, delta_vector,
            batch_size, min_vector.shape[0],
            device=self.device, dtype=self.dtype
        )

        prior_centers = uniform_sampler(
            min_vector + prior_widths / 2, max_vector - prior_widths / 2,
            *prior_widths.shape,
            device=self.device, dtype=self.dtype
        )

        return prior_centers - prior_widths / 2, prior_centers + prior_widths / 2
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from reflectorch.data_generation.utils import (
|
|
7
|
+
get_d_rhos,
|
|
8
|
+
uniform_sampler,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def get_max_allowed_roughness(thicknesses: Tensor, mask: Tensor = None, coef: float = 0.5):
    """Upper limit for each interface roughness derived from the adjacent layer thicknesses.

    Roughness column ``i`` (of ``layers_num + 1``) may not exceed ``coef * thicknesses[:, i]``
    (when ``i < layers_num``) nor ``coef * thicknesses[:, i - 1]`` (when ``i >= 1``).
    Thicknesses selected by the mask impose no limit. Unconstrained entries are ``inf``.

    Args:
        thicknesses: tensor of shape ``(batch_size, layers_num)``.
        mask: optional boolean mask over the SLD columns; its first ``layers_num`` columns
            mark thicknesses to ignore.
        coef: fraction of the thickness allowed as roughness.

    Returns:
        Tensor of shape ``(batch_size, layers_num + 1)``.
    """
    batch_size, layers_num = thicknesses.shape

    limit = thicknesses * coef
    if mask is not None:
        limit[get_thickness_mask_from_sld_mask(mask)] = float('inf')

    unbounded = torch.full(
        (batch_size, 1), float('inf'), device=thicknesses.device, dtype=thicknesses.dtype
    )
    # limit coming from thickness i (the last roughness column has no such thickness)
    from_same_index = torch.cat([limit, unbounded], dim=-1)
    # limit coming from thickness i - 1 (the first roughness column has no such thickness)
    from_previous_index = torch.cat([unbounded, limit], dim=-1)
    return torch.minimum(from_same_index, from_previous_index)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_allowed_contrast_indices(slds: Tensor, min_contrast: float, mask: Tensor = None) -> Tensor:
    """Per-sample flag: True when every SLD difference has magnitude at least ``min_contrast``.

    Entries selected by ``mask`` are treated as always satisfying the condition.
    The reduction is over the last dimension.
    """
    has_contrast = get_d_rhos(slds).abs() >= min_contrast
    if mask is not None:
        has_contrast = has_contrast | mask
    return has_contrast.all(dim=-1)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def params_within_bounds(params_t: Tensor, min_t: Tensor, max_t: Tensor, mask: Tensor = None) -> Tensor:
    """Per-sample flag: True when every parameter lies in ``[min_t, max_t]`` (inclusive).

    ``min_t`` / ``max_t`` get a leading batch axis prepended before broadcasting against
    ``params_t``. Entries selected by ``mask`` are always accepted. The reduction is over the
    last dimension.
    """
    lower_ok = min_t[None] <= params_t
    upper_ok = params_t <= max_t[None]
    inside = lower_ok & upper_ok
    if mask is not None:
        inside = inside | mask
    return inside.all(dim=-1)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_allowed_roughness_indices(thicknesses: Tensor, roughnesses: Tensor, mask: Tensor = None) -> Tensor:
    """Per-sample flag: True when no roughness exceeds its thickness-derived upper limit.

    Limits come from ``get_max_allowed_roughness`` with its default ``coef``. Entries selected
    by ``mask`` are always accepted. The reduction is over the last dimension.
    """
    within_limit = roughnesses <= get_max_allowed_roughness(thicknesses, mask)
    if mask is not None:
        within_limit = within_limit | mask
    return within_limit.all(dim=-1)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_thickness_mask_from_sld_mask(mask: Tensor):
    """Return the SLD mask without its last column, aligning it with the thickness columns.

    Indexing is on a 2-D ``(batch, columns)`` tensor; the result is a view of ``mask``.
    """
    n_columns = mask.shape[1]
    return mask[:, :n_columns - 1]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def generate_roughnesses(thicknesses: Tensor, roughness_range: Tuple[float, float], mask: Tensor = None):
    """Sample ``layers_num + 1`` roughnesses per sample, respecting thickness-based limits.

    The upper bound of each roughness is the thickness-derived limit, clipped to
    ``roughness_range[1]``. The lower bound is ``roughness_range[0]``. Masked entries are
    set to zero.
    """
    batch_size, layers_num = thicknesses.shape

    upper = get_max_allowed_roughness(thicknesses, mask).clamp_(max=roughness_range[1])

    roughnesses = uniform_sampler(
        roughness_range[0], upper, batch_size, layers_num + 1,
        device=thicknesses.device, dtype=thicknesses.dtype,
    )

    if mask is not None:
        roughnesses[mask] = 0.

    return roughnesses
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def generate_thicknesses(
        thickness_range: Tuple[float, float],
        batch_size: int,
        layers_num: int,
        device: torch.device,
        dtype: torch.dtype,
        mask: Tensor = None
):
    """Sample a ``(batch_size, layers_num)`` thickness tensor uniformly within ``thickness_range``.

    Thicknesses selected by the SLD mask (all but its last column) are set to zero.
    """
    sampled = uniform_sampler(*thickness_range, batch_size, layers_num, device=device, dtype=dtype)

    if mask is None:
        return sampled

    sampled[get_thickness_mask_from_sld_mask(mask)] = 0.
    return sampled
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def generate_slds_with_min_contrast(
        sld_range: Tuple[float, float],
        batch_size: int,
        layers_num: int,
        min_contrast: float,
        device: torch.device,
        dtype: torch.dtype,
        mask: Tensor = None,
        *,
        _depth: int = 0
):
    """Sample ``layers_num + 1`` SLDs per sample by rejection sampling on the minimal contrast.

    Samples that fail ``get_allowed_contrast_indices`` are redrawn recursively until every row
    passes. Masked entries are set to zero. ``_depth`` counts the recursion level but is not
    used to bound it.
    """
    slds = uniform_sampler(*sld_range, batch_size, layers_num + 1, device=device, dtype=dtype)

    if mask is not None:
        slds[mask] = 0.

    rejected = ~get_allowed_contrast_indices(slds, min_contrast, mask)
    n_rejected = rejected.sum(0).item()

    if not n_rejected:
        return slds

    # redraw only the rejected rows, carrying their mask rows along
    sub_mask = mask if mask is None else mask[rejected]
    slds[rejected] = generate_slds_with_min_contrast(
        sld_range, n_rejected, layers_num, min_contrast, device, dtype, sub_mask, _depth=_depth + 1
    )
    return slds
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
# Public API of this module.
__all__ = ["ProcessData", "ProcessPipeline"]
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ProcessData(object):
    """Base class for composable processing steps.

    A step is invoked as ``step(args, context)``, which delegates to :meth:`apply`; the base
    implementation returns ``args`` unchanged. Two steps can be chained with ``+`` to build a
    ``ProcessPipeline``.
    """

    def __add__(self, other):
        """Chain ``self`` and ``other`` into a ``ProcessPipeline``.

        For operands that are not ``ProcessData``, ``NotImplemented`` is returned so Python
        raises ``TypeError`` (or tries ``other.__radd__``) instead of silently yielding ``None``.
        """
        if isinstance(other, ProcessData):
            return ProcessPipeline(self, other)
        return NotImplemented

    def apply(self, args: Any, context: dict = None):
        """Process ``args``; the base implementation is the identity."""
        return args

    def __call__(self, args: Any, context: dict = None):
        return self.apply(args, context)

    def __repr__(self):
        return f'{self.__class__.__name__}()'
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ProcessPipeline(ProcessData):
    """Ordered sequence of processing steps.

    :meth:`apply` feeds the output of each step into the next one, passing the same
    ``context`` to every step.
    """

    def __init__(self, *processes):
        self._processes = list(processes)

    def apply(self, args: Any, context: dict = None):
        for process in self._processes:
            args = process(args, context)
        return args

    def __add__(self, other):
        """Return a new pipeline extended by ``other``.

        A pipeline operand contributes all of its steps (flattened); any other ``ProcessData``
        is appended as a single step. For other operand types ``NotImplemented`` is returned so
        Python raises ``TypeError`` instead of silently yielding ``None``.
        """
        if isinstance(other, ProcessPipeline):
            return ProcessPipeline(*self._processes, *other._processes)
        if isinstance(other, ProcessData):
            return ProcessPipeline(*self._processes, other)
        return NotImplemented

    def __repr__(self):
        processes = ", ".join(repr(p) for p in self._processes)
        return f'ProcessPipeline({processes})'
|