reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,370 +1,370 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from torch import Tensor
|
|
3
|
-
|
|
4
|
-
from reflectorch.data_generation.utils import (
|
|
5
|
-
uniform_sampler,
|
|
6
|
-
logdist_sampler,
|
|
7
|
-
)
|
|
8
|
-
|
|
9
|
-
from reflectorch.data_generation.priors.utils import get_max_allowed_roughness
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class SamplerStrategy(object):
    """Base class for sampler strategies.

    Subclasses are expected to override :meth:`sample`; the implementation
    provided here only raises ``NotImplementedError``.
    """

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """Sample parameter values together with their prior bounds.

        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Raises:
            NotImplementedError: always, in this base class.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(
            "{} does not implement sample()".format(type(self).__name__)
        )
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class BasicSamplerStrategy(SamplerStrategy):
    """Sampler strategy with no constraints on the values of the parameters

    Args:
        logdist (bool, optional): if True the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
    """

    def __init__(self, logdist: bool = False):
        # ``logdist`` only selects which sampler draws the interval widths.
        self.widths_sampler_func = logdist_sampler if logdist else uniform_sampler

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """
        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): samples the values of the parameters and their prior bounds (params, min_bounds, max_bounds). The widths W of the subprior interval are sampled first, then the centers C of the subprior interval, such that the prior bounds are C-W/2 and C+W/2, then the parameters are sampled from [C-W/2, C+W/2]
        """
        return basic_sampler(
            batch_size,
            total_min_bounds,
            total_max_bounds,
            total_min_delta,
            total_max_delta,
            self.widths_sampler_func,
        )
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
    """Sampler strategy where the roughnesses are constrained not to exceed a fraction of the two neighboring thicknesses

    Args:
        thickness_mask (Tensor): indices in the tensors which correspond to thicknesses
        roughness_mask (Tensor): indices in the tensors which correspond to roughnesses
        logdist (bool, optional): if ``True`` the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
        max_thickness_share (float, optional): fraction of the layer thickness that the roughness should not exceed. Defaults to 0.5.
        max_total_thickness (float, optional): if not ``None``, samples whose summed upper thickness bounds exceed this value get their thickness values and thickness bounds scaled down (and then clamped from below to the total minimum thickness bounds). Defaults to None.
    """

    def __init__(self,
                 thickness_mask: Tensor,
                 roughness_mask: Tensor,
                 logdist: bool = False,
                 max_thickness_share: float = 0.5,
                 max_total_thickness: float = None,
                 ):
        super().__init__(logdist=logdist)
        self.thickness_mask = thickness_mask
        self.roughness_mask = roughness_mask
        self.max_thickness_share = max_thickness_share
        self.max_total_thickness = max_total_thickness

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """
        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): samples the values of the parameters and their prior bounds *(params, min_bounds, max_bounds)*, the roughnesses being constrained. The widths **W** of the subprior interval are sampled first, then the centers **C** of the subprior interval, such that the prior bounds are **C** - **W** / 2 and **C** + **W** / 2, then the parameters are sampled from [**C** - **W** / 2, **C** + **W** / 2]
        """
        # The masks are moved to the device of the bounds on every call, so the
        # strategy can be constructed independently of the sampling device.
        device = total_min_bounds.device
        return constrained_roughness_sampler(
            batch_size,
            total_min_bounds,
            total_max_bounds,
            total_min_delta,
            total_max_delta,
            thickness_mask=self.thickness_mask.to(device),
            roughness_mask=self.roughness_mask.to(device),
            widths_sampler_func=self.widths_sampler_func,
            coef_roughness=self.max_thickness_share,
            max_total_thickness=self.max_total_thickness,
        )
|
|
114
|
-
|
|
115
|
-
class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
    """Sampler strategy where the roughnesses are constrained not to exceed a fraction of the two neighboring thicknesses, and the imaginary slds are constrained not to exceed a fraction of the real slds

    Args:
        thickness_mask (Tensor): indices in the tensors which correspond to thicknesses
        roughness_mask (Tensor): indices in the tensors which correspond to roughnesses
        sld_mask (Tensor): indices in the tensors which correspond to real slds
        isld_mask (Tensor): indices in the tensors which correspond to imaginary slds
        logdist (bool, optional): if ``True`` the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
        max_thickness_share (float, optional): fraction of the layer thickness that the roughness should not exceed. Defaults to 0.5.
        max_sld_share (float, optional): fraction of the real sld that the imaginary sld should not exceed. Defaults to 0.2.
        max_total_thickness (float, optional): if not ``None``, samples whose summed upper thickness bounds exceed this value get their thickness values and thickness bounds scaled down (and then clamped from below to the total minimum thickness bounds). Defaults to None.
    """

    def __init__(self,
                 thickness_mask: Tensor,
                 roughness_mask: Tensor,
                 sld_mask: Tensor,
                 isld_mask: Tensor,
                 logdist: bool = False,
                 max_thickness_share: float = 0.5,
                 max_sld_share: float = 0.2,
                 max_total_thickness: float = None,
                 ):
        super().__init__(logdist=logdist)
        self.thickness_mask = thickness_mask
        self.roughness_mask = roughness_mask
        self.sld_mask = sld_mask
        self.isld_mask = isld_mask
        self.max_thickness_share = max_thickness_share
        self.max_sld_share = max_sld_share
        self.max_total_thickness = max_total_thickness

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """
        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): samples the values of the parameters and their prior bounds *(params, min_bounds, max_bounds)*, the roughnesses and imaginary slds being constrained. The widths **W** of the subprior interval are sampled first, then the centers **C** of the subprior interval, such that the prior bounds are **C** - **W** / 2 and **C** + **W** / 2, then the parameters are sampled from [**C** - **W** / 2, **C** + **W** / 2]
        """
        # The masks are moved to the device of the bounds on every call, so the
        # strategy can be constructed independently of the sampling device.
        device = total_min_bounds.device
        return constrained_roughness_and_isld_sampler(
            batch_size,
            total_min_bounds,
            total_max_bounds,
            total_min_delta,
            total_max_delta,
            thickness_mask=self.thickness_mask.to(device),
            roughness_mask=self.roughness_mask.to(device),
            sld_mask=self.sld_mask.to(device),
            isld_mask=self.isld_mask.to(device),
            widths_sampler_func=self.widths_sampler_func,
            coef_roughness=self.max_thickness_share,
            coef_isld=self.max_sld_share,
            max_total_thickness=self.max_total_thickness,
        )
|
|
179
|
-
|
|
180
|
-
def basic_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        widths_sampler_func,
):
    """Sample subprior intervals and parameter values lying inside them.

    The widths W of the subprior intervals are sampled first, then the centers C,
    such that the prior bounds are C - W / 2 and C + W / 2; the parameters are
    then drawn between these two bounds.

    Args:
        batch_size (int): the batch size
        total_min_bounds (Tensor): minimum values of the parameters
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        widths_sampler_func: callable drawing the interval widths; it is called as
            ``widths_sampler_func(total_min_delta, total_max_delta, batch_size, num_params, device=..., dtype=...)``

    Returns:
        tuple(Tensor): the sampled parameters and their prior bounds (params, min_bounds, max_bounds)
    """
    device, dtype = total_min_bounds.device, total_min_bounds.dtype

    # The difference is only needed for its shape: the number of parameters is
    # read from dimension 1, so the (broadcast) bounds must be at least 2D.
    num_params = (total_max_bounds - total_min_bounds).shape[1]

    prior_widths = widths_sampler_func(
        total_min_delta, total_max_delta,
        batch_size, num_params,
        device=device, dtype=dtype
    )
    half_widths = prior_widths / 2

    # Centers are restricted so that the whole interval C +/- W/2 stays inside
    # [total_min_bounds, total_max_bounds].
    prior_centers = uniform_sampler(
        total_min_bounds + half_widths, total_max_bounds - half_widths,
        *prior_widths.shape,
        device=device, dtype=dtype
    )

    min_bounds, max_bounds = prior_centers - half_widths, prior_centers + half_widths

    # Rescale a uniform draw from torch.rand onto the sampled interval.
    params = torch.rand(
        *min_bounds.shape,
        device=min_bounds.device,
        dtype=min_bounds.dtype
    ) * (max_bounds - min_bounds) + min_bounds

    return params, min_bounds, max_bounds
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
def constrained_roughness_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        thickness_mask: Tensor,
        roughness_mask: Tensor,
        widths_sampler_func,
        coef_roughness: float = 0.5,
        max_total_thickness: float = None,
):
    """Sample parameters and prior bounds, with the roughnesses limited by the sampled thicknesses.

    All parameters are first drawn with ``basic_sampler``. The roughness entries
    are then drawn again, with their upper limit being the smaller of
    ``get_max_allowed_roughness(thicknesses, coef_roughness)`` and the total
    maximum roughness bounds.

    Args:
        batch_size (int): the batch size
        total_min_bounds (Tensor): minimum values of the parameters
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        thickness_mask (Tensor): indices in the tensors which correspond to thicknesses
        roughness_mask (Tensor): indices in the tensors which correspond to roughnesses
        widths_sampler_func: callable drawing the interval widths, forwarded to ``basic_sampler``
        coef_roughness (float, optional): coefficient passed as ``coef`` to ``get_max_allowed_roughness``. Defaults to 0.5.
        max_total_thickness (float, optional): if not ``None``, rows whose summed upper thickness bounds exceed this value have their thickness values and bounds scaled down. Defaults to None.

    Returns:
        tuple(Tensor): the sampled parameters and their prior bounds (params, min_bounds, max_bounds)

    Raises:
        AssertionError: if a minimum roughness bound exceeds the maximum roughness allowed for the sampled thicknesses.
    """
    params, min_bounds, max_bounds = basic_sampler(
        batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
        widths_sampler_func=widths_sampler_func,
    )

    if max_total_thickness is not None:
        total_thickness = max_bounds[:, thickness_mask].sum(-1)
        indices = total_thickness > max_total_thickness

        if indices.any():
            # Scale to slightly below the limit: the random factor lies in [1 - eps, 1).
            eps = 0.01
            rand_scale = torch.rand_like(total_thickness) * eps + 1 - eps
            scale_coef = max_total_thickness / total_thickness * rand_scale
            # Rows that already satisfy the limit are left untouched.
            scale_coef[~indices] = 1.0
            min_bounds[:, thickness_mask] *= scale_coef[:, None]
            max_bounds[:, thickness_mask] *= scale_coef[:, None]
            params[:, thickness_mask] *= scale_coef[:, None]

            # NOTE(review): the nesting of these clamps under ``if indices.any()``
            # is an assumption to be confirmed against the upstream file.
            # Scaling may push values below the total minimum thickness bounds.
            min_bounds[:, thickness_mask] = torch.clamp_min(
                min_bounds[:, thickness_mask],
                total_min_bounds[:, thickness_mask],
            )

            max_bounds[:, thickness_mask] = torch.clamp_min(
                max_bounds[:, thickness_mask],
                total_min_bounds[:, thickness_mask],
            )

            params[:, thickness_mask] = torch.clamp_min(
                params[:, thickness_mask],
                total_min_bounds[:, thickness_mask],
            )

    max_roughness = torch.minimum(
        get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
        total_max_bounds[..., roughness_mask]
    )
    min_roughness = total_min_bounds[..., roughness_mask]

    assert torch.all(min_roughness <= max_roughness), (
        "minimum roughness bounds exceed the maximum roughness allowed for the sampled thicknesses"
    )

    min_roughness_delta = total_min_delta[..., roughness_mask]
    # The interval width cannot exceed the roughness range that is still available.
    max_roughness_delta = torch.minimum(total_max_delta[..., roughness_mask], max_roughness - min_roughness)

    roughnesses, min_r_bounds, max_r_bounds = basic_sampler(
        batch_size, min_roughness, max_roughness,
        min_roughness_delta, max_roughness_delta,
        widths_sampler_func=widths_sampler_func
    )

    # Overwrite the unconstrained roughness samples from the first draw.
    min_bounds[..., roughness_mask], max_bounds[..., roughness_mask] = min_r_bounds, max_r_bounds
    params[..., roughness_mask] = roughnesses

    return params, min_bounds, max_bounds
|
|
280
|
-
|
|
281
|
-
def constrained_roughness_and_isld_sampler(
|
|
282
|
-
batch_size: int,
|
|
283
|
-
total_min_bounds: Tensor,
|
|
284
|
-
total_max_bounds: Tensor,
|
|
285
|
-
total_min_delta: Tensor,
|
|
286
|
-
total_max_delta: Tensor,
|
|
287
|
-
thickness_mask: Tensor,
|
|
288
|
-
roughness_mask: Tensor,
|
|
289
|
-
sld_mask: Tensor,
|
|
290
|
-
isld_mask: Tensor,
|
|
291
|
-
widths_sampler_func,
|
|
292
|
-
coef_roughness: float = 0.5,
|
|
293
|
-
coef_isld: float = 0.2,
|
|
294
|
-
max_total_thickness: float = None,
|
|
295
|
-
):
|
|
296
|
-
params, min_bounds, max_bounds = basic_sampler(
|
|
297
|
-
batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
|
|
298
|
-
widths_sampler_func=widths_sampler_func,
|
|
299
|
-
)
|
|
300
|
-
|
|
301
|
-
if max_total_thickness is not None:
|
|
302
|
-
total_thickness = max_bounds[:, thickness_mask].sum(-1)
|
|
303
|
-
indices = total_thickness > max_total_thickness
|
|
304
|
-
|
|
305
|
-
if indices.any():
|
|
306
|
-
eps = 0.01
|
|
307
|
-
rand_scale = torch.rand_like(total_thickness) * eps + 1 - eps
|
|
308
|
-
scale_coef = max_total_thickness / total_thickness * rand_scale
|
|
309
|
-
scale_coef[~indices] = 1.0
|
|
310
|
-
min_bounds[:, thickness_mask] *= scale_coef[:, None]
|
|
311
|
-
max_bounds[:, thickness_mask] *= scale_coef[:, None]
|
|
312
|
-
params[:, thickness_mask] *= scale_coef[:, None]
|
|
313
|
-
|
|
314
|
-
min_bounds[:, thickness_mask] = torch.clamp_min(
|
|
315
|
-
min_bounds[:, thickness_mask],
|
|
316
|
-
total_min_bounds[:, thickness_mask],
|
|
317
|
-
)
|
|
318
|
-
|
|
319
|
-
max_bounds[:, thickness_mask] = torch.clamp_min(
|
|
320
|
-
max_bounds[:, thickness_mask],
|
|
321
|
-
total_min_bounds[:, thickness_mask],
|
|
322
|
-
)
|
|
323
|
-
|
|
324
|
-
params[:, thickness_mask] = torch.clamp_min(
|
|
325
|
-
params[:, thickness_mask],
|
|
326
|
-
total_min_bounds[:, thickness_mask],
|
|
327
|
-
)
|
|
328
|
-
|
|
329
|
-
max_roughness = torch.minimum(
|
|
330
|
-
get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
|
|
331
|
-
total_max_bounds[..., roughness_mask]
|
|
332
|
-
)
|
|
333
|
-
min_roughness = total_min_bounds[..., roughness_mask]
|
|
334
|
-
|
|
335
|
-
assert torch.all(min_roughness <= max_roughness)
|
|
336
|
-
|
|
337
|
-
min_roughness_delta = total_min_delta[..., roughness_mask]
|
|
338
|
-
max_roughness_delta = torch.minimum(total_max_delta[..., roughness_mask], max_roughness - min_roughness)
|
|
339
|
-
|
|
340
|
-
roughnesses, min_r_bounds, max_r_bounds = basic_sampler(
|
|
341
|
-
batch_size, min_roughness, max_roughness,
|
|
342
|
-
min_roughness_delta, max_roughness_delta,
|
|
343
|
-
widths_sampler_func=widths_sampler_func
|
|
344
|
-
)
|
|
345
|
-
|
|
346
|
-
min_bounds[..., roughness_mask], max_bounds[..., roughness_mask] = min_r_bounds, max_r_bounds
|
|
347
|
-
params[..., roughness_mask] = roughnesses
|
|
348
|
-
|
|
349
|
-
max_isld = torch.minimum(
|
|
350
|
-
torch.abs(params[..., sld_mask]) * coef_isld,
|
|
351
|
-
total_max_bounds[..., isld_mask]
|
|
352
|
-
)
|
|
353
|
-
min_isld = total_min_bounds[..., isld_mask]
|
|
354
|
-
|
|
355
|
-
assert torch.all(min_isld <= max_isld)
|
|
356
|
-
|
|
357
|
-
min_isld_delta = total_min_delta[..., isld_mask]
|
|
358
|
-
max_isld_delta = torch.minimum(total_max_delta[..., isld_mask], max_isld - min_isld)
|
|
359
|
-
|
|
360
|
-
islds, min_isld_bounds, max_isld_bounds = basic_sampler(
|
|
361
|
-
batch_size, min_isld, max_isld,
|
|
362
|
-
min_isld_delta, max_isld_delta,
|
|
363
|
-
widths_sampler_func=widths_sampler_func
|
|
364
|
-
)
|
|
365
|
-
|
|
366
|
-
min_bounds[..., isld_mask], max_bounds[..., isld_mask] = min_isld_bounds, max_isld_bounds
|
|
367
|
-
params[..., isld_mask] = islds
|
|
368
|
-
|
|
369
|
-
|
|
1
|
+
import torch
|
|
2
|
+
from torch import Tensor
|
|
3
|
+
|
|
4
|
+
from reflectorch.data_generation.utils import (
|
|
5
|
+
uniform_sampler,
|
|
6
|
+
logdist_sampler,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
from reflectorch.data_generation.priors.utils import get_max_allowed_roughness
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SamplerStrategy(object):
    """Base class for sampler strategies.

    Subclasses are expected to override :meth:`sample`; the implementation
    provided here only raises ``NotImplementedError``.
    """

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """Sample parameter values together with their prior bounds.

        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Raises:
            NotImplementedError: always, in this base class.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(
            "{} does not implement sample()".format(type(self).__name__)
        )
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class BasicSamplerStrategy(SamplerStrategy):
    """Sampler strategy with no constraints on the values of the parameters

    Args:
        logdist (bool, optional): if True the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
    """

    def __init__(self, logdist: bool = False):
        # ``logdist`` only selects which sampler draws the interval widths.
        self.widths_sampler_func = logdist_sampler if logdist else uniform_sampler

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """
        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): samples the values of the parameters and their prior bounds (params, min_bounds, max_bounds). The widths W of the subprior interval are sampled first, then the centers C of the subprior interval, such that the prior bounds are C-W/2 and C+W/2, then the parameters are sampled from [C-W/2, C+W/2]
        """
        return basic_sampler(
            batch_size,
            total_min_bounds,
            total_max_bounds,
            total_min_delta,
            total_max_delta,
            self.widths_sampler_func,
        )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
    """Sampler strategy where the roughnesses are constrained not to exceed a fraction of the two neighboring thicknesses

    Args:
        thickness_mask (Tensor): indices in the tensors which correspond to thicknesses
        roughness_mask (Tensor): indices in the tensors which correspond to roughnesses
        logdist (bool, optional): if ``True`` the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
        max_thickness_share (float, optional): fraction of the layer thickness that the roughness should not exceed. Defaults to 0.5.
        max_total_thickness (float, optional): if not ``None``, samples whose summed upper thickness bounds exceed this value get their thickness values and thickness bounds scaled down (and then clamped from below to the total minimum thickness bounds). Defaults to None.
    """

    def __init__(self,
                 thickness_mask: Tensor,
                 roughness_mask: Tensor,
                 logdist: bool = False,
                 max_thickness_share: float = 0.5,
                 max_total_thickness: float = None,
                 ):
        super().__init__(logdist=logdist)
        self.thickness_mask = thickness_mask
        self.roughness_mask = roughness_mask
        self.max_thickness_share = max_thickness_share
        self.max_total_thickness = max_total_thickness

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """
        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): samples the values of the parameters and their prior bounds *(params, min_bounds, max_bounds)*, the roughnesses being constrained. The widths **W** of the subprior interval are sampled first, then the centers **C** of the subprior interval, such that the prior bounds are **C** - **W** / 2 and **C** + **W** / 2, then the parameters are sampled from [**C** - **W** / 2, **C** + **W** / 2]
        """
        # The masks are moved to the device of the bounds on every call, so the
        # strategy can be constructed independently of the sampling device.
        device = total_min_bounds.device
        return constrained_roughness_sampler(
            batch_size,
            total_min_bounds,
            total_max_bounds,
            total_min_delta,
            total_max_delta,
            thickness_mask=self.thickness_mask.to(device),
            roughness_mask=self.roughness_mask.to(device),
            widths_sampler_func=self.widths_sampler_func,
            coef_roughness=self.max_thickness_share,
            max_total_thickness=self.max_total_thickness,
        )
|
|
114
|
+
|
|
115
|
+
class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
    """Sampler strategy where the roughnesses are constrained not to exceed a fraction of the two neighboring thicknesses, and the imaginary slds are constrained not to exceed a fraction of the real slds

    Args:
        thickness_mask (Tensor): indices in the tensors which correspond to thicknesses
        roughness_mask (Tensor): indices in the tensors which correspond to roughnesses
        sld_mask (Tensor): indices in the tensors which correspond to real slds
        isld_mask (Tensor): indices in the tensors which correspond to imaginary slds
        logdist (bool, optional): if ``True`` the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
        max_thickness_share (float, optional): fraction of the layer thickness that the roughness should not exceed. Defaults to 0.5.
        max_sld_share (float, optional): fraction of the real sld that the imaginary sld should not exceed. Defaults to 0.2.
        max_total_thickness (float, optional): if not ``None``, samples whose summed upper thickness bounds exceed this value get their thickness values and thickness bounds scaled down (and then clamped from below to the total minimum thickness bounds). Defaults to None.
    """

    def __init__(self,
                 thickness_mask: Tensor,
                 roughness_mask: Tensor,
                 sld_mask: Tensor,
                 isld_mask: Tensor,
                 logdist: bool = False,
                 max_thickness_share: float = 0.5,
                 max_sld_share: float = 0.2,
                 max_total_thickness: float = None,
                 ):
        super().__init__(logdist=logdist)
        self.thickness_mask = thickness_mask
        self.roughness_mask = roughness_mask
        self.sld_mask = sld_mask
        self.isld_mask = isld_mask
        self.max_thickness_share = max_thickness_share
        self.max_sld_share = max_sld_share
        self.max_total_thickness = max_total_thickness

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """
        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): samples the values of the parameters and their prior bounds *(params, min_bounds, max_bounds)*, the roughnesses and imaginary slds being constrained. The widths **W** of the subprior interval are sampled first, then the centers **C** of the subprior interval, such that the prior bounds are **C** - **W** / 2 and **C** + **W** / 2, then the parameters are sampled from [**C** - **W** / 2, **C** + **W** / 2]
        """
        # The masks are moved to the device of the bounds on every call, so the
        # strategy can be constructed independently of the sampling device.
        device = total_min_bounds.device
        return constrained_roughness_and_isld_sampler(
            batch_size,
            total_min_bounds,
            total_max_bounds,
            total_min_delta,
            total_max_delta,
            thickness_mask=self.thickness_mask.to(device),
            roughness_mask=self.roughness_mask.to(device),
            sld_mask=self.sld_mask.to(device),
            isld_mask=self.isld_mask.to(device),
            widths_sampler_func=self.widths_sampler_func,
            coef_roughness=self.max_thickness_share,
            coef_isld=self.max_sld_share,
            max_total_thickness=self.max_total_thickness,
        )
|
|
179
|
+
|
|
180
|
+
def basic_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        widths_sampler_func,
):
    """Sample subprior intervals inside the total ranges, then parameters inside them.

    Interval widths are drawn first via ``widths_sampler_func``, then interval
    centers are drawn so that each interval stays within
    ``[total_min_bounds, total_max_bounds]``, and finally the parameters are
    drawn uniformly from the resulting intervals.

    Args:
        batch_size (int): the batch size, forwarded to ``widths_sampler_func``
        total_min_bounds (Tensor): minimum values of the parameters
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        widths_sampler_func: callable invoked as
            ``widths_sampler_func(total_min_delta, total_max_delta, batch_size,
            num_params, device=..., dtype=...)`` returning the interval widths

    Returns:
        tuple(Tensor): *(params, min_bounds, max_bounds)*
    """
    device, dtype = total_min_bounds.device, total_min_bounds.dtype

    # Only the second dimension of the (broadcast) total range is needed here.
    # assumes the bounds broadcast to a 2-D (batch, num_params) tensor - TODO confirm
    num_params = (total_max_bounds - total_min_bounds).shape[1]

    widths = widths_sampler_func(
        total_min_delta, total_max_delta,
        batch_size, num_params,
        device=device, dtype=dtype
    )
    half_widths = widths / 2

    # Centers are restricted so that center +/- half width stays in the total range.
    centers = uniform_sampler(
        total_min_bounds + half_widths, total_max_bounds - half_widths,
        *widths.shape,
        device=device, dtype=dtype
    )

    min_bounds = centers - half_widths
    max_bounds = centers + half_widths

    # Uniform draw in [0, 1) mapped onto [min_bounds, max_bounds).
    unit_draw = torch.rand(
        *min_bounds.shape,
        device=min_bounds.device,
        dtype=min_bounds.dtype
    )
    params = unit_draw * (max_bounds - min_bounds) + min_bounds

    return params, min_bounds, max_bounds
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def constrained_roughness_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        thickness_mask: Tensor,
        roughness_mask: Tensor,
        widths_sampler_func,
        coef_roughness: float = 0.5,
        max_total_thickness: float = None,
):
    """Sample parameters and prior bounds with roughnesses limited by the thicknesses.

    All parameters are first drawn with ``basic_sampler``. If
    ``max_total_thickness`` is given, rows whose summed upper thickness bounds
    exceed it are scaled down. The roughness columns are then re-drawn within
    an upper limit derived from the sampled thicknesses.

    Args:
        batch_size (int): the batch size
        total_min_bounds (Tensor): minimum values of the parameters
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        thickness_mask (Tensor): mask selecting the thickness columns
        roughness_mask (Tensor): mask selecting the roughness columns
        widths_sampler_func: callable used by ``basic_sampler`` to draw interval widths
        coef_roughness (float): ``coef`` argument passed to ``get_max_allowed_roughness``
        max_total_thickness (float): optional cap on the summed upper thickness
            bounds; ``None`` disables the rescaling step

    Returns:
        tuple(Tensor): *(params, min_bounds, max_bounds)*
    """
    params, min_bounds, max_bounds = basic_sampler(
        batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
        widths_sampler_func=widths_sampler_func,
    )

    if max_total_thickness is not None:
        summed_upper = max_bounds[:, thickness_mask].sum(-1)
        too_thick = summed_upper > max_total_thickness

        if too_thick.any():
            eps = 0.01
            # Random factor in [1 - eps, 1) so rescaled rows end up slightly below the cap.
            jitter = torch.rand_like(summed_upper) * eps + 1 - eps
            shrink = max_total_thickness / summed_upper * jitter
            # Rows already within the cap are left untouched.
            shrink[~too_thick] = 1.0
            shrink_column = shrink[:, None]

            sampled_tensors = (min_bounds, max_bounds, params)

            for tensor in sampled_tensors:
                tensor[:, thickness_mask] *= shrink_column

            # Scaling must not push thicknesses below the total lower limits.
            thickness_floor = total_min_bounds[:, thickness_mask]
            for tensor in sampled_tensors:
                tensor[:, thickness_mask] = torch.clamp_min(
                    tensor[:, thickness_mask],
                    thickness_floor,
                )

    # The roughness ceiling is the tighter of the thickness-derived limit and the total limit.
    roughness_ceiling = torch.minimum(
        get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
        total_max_bounds[..., roughness_mask]
    )
    roughness_floor = total_min_bounds[..., roughness_mask]

    assert torch.all(roughness_floor <= roughness_ceiling)

    # Interval widths cannot exceed the range that is still available.
    width_floor = total_min_delta[..., roughness_mask]
    width_ceiling = torch.minimum(
        total_max_delta[..., roughness_mask],
        roughness_ceiling - roughness_floor,
    )

    roughnesses, lower_r, upper_r = basic_sampler(
        batch_size, roughness_floor, roughness_ceiling,
        width_floor, width_ceiling,
        widths_sampler_func=widths_sampler_func
    )

    # Overwrite the unconstrained roughness columns with the constrained draw.
    min_bounds[..., roughness_mask] = lower_r
    max_bounds[..., roughness_mask] = upper_r
    params[..., roughness_mask] = roughnesses

    return params, min_bounds, max_bounds
|
|
280
|
+
|
|
281
|
+
def constrained_roughness_and_isld_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        thickness_mask: Tensor,
        roughness_mask: Tensor,
        sld_mask: Tensor,
        isld_mask: Tensor,
        widths_sampler_func,
        coef_roughness: float = 0.5,
        coef_isld: float = 0.2,
        max_total_thickness: float = None,
):
    """Sample parameters and prior bounds with constrained roughnesses and imaginary SLDs.

    The sampling of all parameters, the optional rescaling to
    ``max_total_thickness`` and the constrained re-draw of the roughnesses are
    delegated to ``constrained_roughness_sampler``. The imaginary-SLD columns
    are then re-drawn so that they do not exceed ``coef_isld`` times the
    absolute value of the sampled SLDs (nor their total upper limits).

    Args:
        batch_size (int): the batch size
        total_min_bounds (Tensor): minimum values of the parameters
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        thickness_mask (Tensor): mask selecting the thickness columns
        roughness_mask (Tensor): mask selecting the roughness columns
        sld_mask (Tensor): mask selecting the SLD columns
        isld_mask (Tensor): mask selecting the imaginary-SLD columns
        widths_sampler_func: callable used by ``basic_sampler`` to draw interval widths
        coef_roughness (float): ``coef`` argument passed to ``get_max_allowed_roughness``
        coef_isld (float): maximum imaginary SLD as a share of the absolute SLD
        max_total_thickness (float): optional cap on the summed upper thickness
            bounds; ``None`` disables the rescaling step

    Returns:
        tuple(Tensor): *(params, min_bounds, max_bounds)*

    Raises:
        AssertionError: if a lower limit exceeds the corresponding constrained
            upper limit (for the roughnesses or the imaginary SLDs).
    """
    # The thickness/roughness stage is identical to constrained_roughness_sampler,
    # so reuse it instead of keeping a second copy of the same statements.
    params, min_bounds, max_bounds = constrained_roughness_sampler(
        batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
        thickness_mask=thickness_mask,
        roughness_mask=roughness_mask,
        widths_sampler_func=widths_sampler_func,
        coef_roughness=coef_roughness,
        max_total_thickness=max_total_thickness,
    )

    # The imaginary-SLD ceiling is the tighter of the SLD-derived limit and the total limit.
    max_isld = torch.minimum(
        torch.abs(params[..., sld_mask]) * coef_isld,
        total_max_bounds[..., isld_mask]
    )
    min_isld = total_min_bounds[..., isld_mask]

    assert torch.all(min_isld <= max_isld)

    # Interval widths cannot exceed the range that is still available.
    min_isld_delta = total_min_delta[..., isld_mask]
    max_isld_delta = torch.minimum(total_max_delta[..., isld_mask], max_isld - min_isld)

    islds, min_isld_bounds, max_isld_bounds = basic_sampler(
        batch_size, min_isld, max_isld,
        min_isld_delta, max_isld_delta,
        widths_sampler_func=widths_sampler_func
    )

    # Overwrite the unconstrained imaginary-SLD columns with the constrained draw.
    min_bounds[..., isld_mask], max_bounds[..., isld_mask] = min_isld_bounds, max_isld_bounds
    params[..., isld_mask] = islds

    return params, min_bounds, max_bounds
|