reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,370 +1,370 @@
1
- import torch
2
- from torch import Tensor
3
-
4
- from reflectorch.data_generation.utils import (
5
- uniform_sampler,
6
- logdist_sampler,
7
- )
8
-
9
- from reflectorch.data_generation.priors.utils import get_max_allowed_roughness
10
-
11
-
12
class SamplerStrategy:
    """Interface for strategies that draw parameters together with their subprior bounds.

    Concrete strategies override :meth:`sample`; the base implementation only raises.
    """

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """Draw ``batch_size`` parameter vectors and their subprior bounds.

        Raises:
            NotImplementedError: always; subclasses must provide the sampling logic.
        """
        raise NotImplementedError
21
-
22
-
23
class BasicSamplerStrategy(SamplerStrategy):
    """Sampler strategy that places no constraints between the parameters.

    Args:
        logdist (bool, optional): when ``True`` the widths of the subprior intervals are drawn
            with ``logdist_sampler`` (uniform on a logarithmic scale); otherwise with
            ``uniform_sampler``. Defaults to False.
    """

    def __init__(self, logdist: bool = False):
        # The selected function is handed to ``basic_sampler`` to draw the interval widths.
        self.widths_sampler_func = logdist_sampler if logdist else uniform_sampler

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """Draw parameters and subprior bounds without inter-parameter constraints.

        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): ``(params, min_bounds, max_bounds)``. Widths W are drawn first, then
            centers C, so the subprior bounds are ``C - W/2`` and ``C + W/2``; the parameters
            are then drawn uniformly from ``[C - W/2, C + W/2]``.
        """
        return basic_sampler(
            batch_size,
            total_min_bounds,
            total_max_bounds,
            total_min_delta,
            total_max_delta,
            widths_sampler_func=self.widths_sampler_func,
        )
60
-
61
-
62
class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
    """Sampler strategy whose roughnesses may not exceed a fraction of the two neighboring thicknesses.

    Args:
        thickness_mask (Tensor): indices in the tensors which correspond to thicknesses
        roughness_mask (Tensor): indices in the tensors which correspond to roughnesses
        logdist (bool, optional): if ``True`` the widths of the subprior intervals are drawn
            uniformly on a logarithmic scale instead of uniformly. Defaults to False.
        max_thickness_share (float, optional): fraction of the layer thickness that the
            roughness should not exceed. Defaults to 0.5.
        max_total_thickness (float, optional): forwarded to ``constrained_roughness_sampler``;
            when set, sampled thickness bounds whose sum exceeds it are scaled down.
            Defaults to None (no limit).
    """

    def __init__(self,
                 thickness_mask: Tensor,
                 roughness_mask: Tensor,
                 logdist: bool = False,
                 max_thickness_share: float = 0.5,
                 max_total_thickness: float = None,
                 ):
        super().__init__(logdist=logdist)
        # Masks are stored as given and moved to the bounds' device on every ``sample`` call.
        self.thickness_mask, self.roughness_mask = thickness_mask, roughness_mask
        self.max_thickness_share = max_thickness_share
        self.max_total_thickness = max_total_thickness

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """Draw parameters and subprior bounds with constrained roughnesses.

        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): ``(params, min_bounds, max_bounds)`` with the roughnesses constrained.
            Widths **W** are drawn first, then centers **C**, giving bounds
            [**C** - **W** / 2, **C** + **W** / 2] from which the parameters are drawn.
        """
        device = total_min_bounds.device
        thickness_idx = self.thickness_mask.to(device)
        roughness_idx = self.roughness_mask.to(device)
        return constrained_roughness_sampler(
            batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
            thickness_mask=thickness_idx,
            roughness_mask=roughness_idx,
            widths_sampler_func=self.widths_sampler_func,
            coef_roughness=self.max_thickness_share,
            max_total_thickness=self.max_total_thickness,
        )
114
-
115
class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
    """Sampler strategy constraining both the roughnesses and the imaginary SLDs.

    Roughnesses may not exceed a fraction of the two neighboring thicknesses, and imaginary
    SLDs may not exceed a fraction of the (absolute) real SLDs.

    Args:
        thickness_mask (Tensor): indices in the tensors which correspond to thicknesses
        roughness_mask (Tensor): indices in the tensors which correspond to roughnesses
        sld_mask (Tensor): indices in the tensors which correspond to real slds
        isld_mask (Tensor): indices in the tensors which correspond to imaginary slds
        logdist (bool, optional): if ``True`` the widths of the subprior intervals are drawn
            uniformly on a logarithmic scale instead of uniformly. Defaults to False.
        max_thickness_share (float, optional): fraction of the layer thickness that the
            roughness should not exceed. Defaults to 0.5.
        max_sld_share (float, optional): fraction of the real sld that the imaginary sld
            should not exceed. Defaults to 0.2.
        max_total_thickness (float, optional): forwarded to
            ``constrained_roughness_and_isld_sampler``; when set, sampled thickness bounds
            whose sum exceeds it are scaled down. Defaults to None (no limit).
    """

    def __init__(self,
                 thickness_mask: Tensor,
                 roughness_mask: Tensor,
                 sld_mask: Tensor,
                 isld_mask: Tensor,
                 logdist: bool = False,
                 max_thickness_share: float = 0.5,
                 max_sld_share: float = 0.2,
                 max_total_thickness: float = None,
                 ):
        super().__init__(logdist=logdist)
        # Masks are stored as given and moved to the bounds' device on every ``sample`` call.
        self.thickness_mask, self.roughness_mask = thickness_mask, roughness_mask
        self.sld_mask, self.isld_mask = sld_mask, isld_mask
        self.max_thickness_share = max_thickness_share
        self.max_sld_share = max_sld_share
        self.max_total_thickness = max_total_thickness

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """Draw parameters and subprior bounds with constrained roughnesses and imaginary SLDs.

        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): ``(params, min_bounds, max_bounds)`` with the roughnesses and
            imaginary slds constrained. Widths **W** are drawn first, then centers **C**,
            giving bounds [**C** - **W** / 2, **C** + **W** / 2] from which the parameters
            are drawn.
        """
        device = total_min_bounds.device
        thickness_idx = self.thickness_mask.to(device)
        roughness_idx = self.roughness_mask.to(device)
        sld_idx = self.sld_mask.to(device)
        isld_idx = self.isld_mask.to(device)
        return constrained_roughness_and_isld_sampler(
            batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
            thickness_mask=thickness_idx,
            roughness_mask=roughness_idx,
            sld_mask=sld_idx,
            isld_mask=isld_idx,
            widths_sampler_func=self.widths_sampler_func,
            coef_roughness=self.max_thickness_share,
            coef_isld=self.max_sld_share,
            max_total_thickness=self.max_total_thickness,
        )
179
-
180
def basic_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        widths_sampler_func,
):
    """Draw subprior intervals and parameters inside them, with no inter-parameter constraints.

    Sampling happens in three stages:

    1. widths ``W`` via ``widths_sampler_func(total_min_delta, total_max_delta, batch_size, n_params, ...)``;
    2. centers ``C`` uniformly in ``[total_min_bounds + W/2, total_max_bounds - W/2]``;
    3. parameters uniformly in ``[C - W/2, C + W/2]``.

    Args:
        batch_size (int): the batch size
        total_min_bounds (Tensor): minimum values of the parameters (2-D; dim 1 indexes parameters)
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        widths_sampler_func: callable accepting ``(min, max, *shape, device=..., dtype=...)``,
            e.g. ``uniform_sampler`` or ``logdist_sampler``.

    Returns:
        tuple(Tensor): ``(params, min_bounds, max_bounds)``, presumably each of shape
        ``(batch_size, n_params)`` (set by ``widths_sampler_func``; confirm).
    """
    device, dtype = total_min_bounds.device, total_min_bounds.dtype
    # Only the broadcast number of parameters is needed from the bound difference.
    n_params = (total_max_bounds - total_min_bounds).shape[1]

    widths = widths_sampler_func(
        total_min_delta, total_max_delta,
        batch_size, n_params,
        device=device, dtype=dtype,
    )
    half_widths = widths / 2

    # Centers are kept far enough from the global bounds for the whole interval to fit.
    centers = uniform_sampler(
        total_min_bounds + half_widths, total_max_bounds - half_widths,
        *widths.shape,
        device=device, dtype=dtype,
    )

    min_bounds = centers - half_widths
    max_bounds = centers + half_widths

    unit_draws = torch.rand(*min_bounds.shape, device=min_bounds.device, dtype=min_bounds.dtype)
    params = min_bounds + (max_bounds - min_bounds) * unit_draws

    return params, min_bounds, max_bounds
212
-
213
-
214
def constrained_roughness_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        thickness_mask: Tensor,
        roughness_mask: Tensor,
        widths_sampler_func,
        coef_roughness: float = 0.5,
        max_total_thickness: float = None,
):
    """Sample parameters and subprior bounds with roughnesses limited by the sampled thicknesses.

    Steps:

    1. Draw every parameter with :func:`basic_sampler`.
    2. If ``max_total_thickness`` is given and some rows have a summed thickness upper bound
       above it, scale those rows' thickness params and bounds by
       ``max_total_thickness / summed_upper_bound * r`` with ``r`` in ``[0.99, 1)``. Then clamp
       all thickness entries from below at ``total_min_bounds``.
    3. Re-draw the roughness entries with :func:`basic_sampler` between ``total_min_bounds``
       and ``min(get_max_allowed_roughness(thicknesses, coef_roughness), total_max_bounds)``.

    Args:
        batch_size (int): the batch size
        total_min_bounds (Tensor): minimum values of the parameters
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        thickness_mask (Tensor): indices of the thickness entries
        roughness_mask (Tensor): indices of the roughness entries
        widths_sampler_func: callable used to draw subprior widths
        coef_roughness (float, optional): passed as ``coef`` to ``get_max_allowed_roughness``. Defaults to 0.5.
        max_total_thickness (float, optional): limit on the summed thickness upper bounds. Defaults to None.

    Returns:
        tuple(Tensor): ``(params, min_bounds, max_bounds)``.

    Raises:
        AssertionError: if, for some roughness entry, the allowed upper limit lies below
            ``total_min_bounds`` (only while assertions are enabled).
    """
    params, min_bounds, max_bounds = basic_sampler(
        batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
        widths_sampler_func=widths_sampler_func,
    )

    if max_total_thickness is not None:
        summed_max_thickness = max_bounds[:, thickness_mask].sum(-1)
        too_thick = summed_max_thickness > max_total_thickness

        if too_thick.any():
            eps = 0.01
            # Random factor in [1 - eps, 1): rescaled rows land slightly below the limit
            # (before the clamping below).
            jitter = torch.rand_like(summed_max_thickness) * eps + 1 - eps
            shrink = max_total_thickness / summed_max_thickness * jitter
            # Rows within the limit keep their thicknesses unchanged.
            shrink = torch.where(too_thick, shrink, torch.ones_like(shrink))[:, None]
            thickness_floor = total_min_bounds[:, thickness_mask]

            # Rescale, then keep params and both bounds at or above the global minimum.
            for values in (min_bounds, max_bounds, params):
                values[:, thickness_mask] = torch.clamp_min(
                    values[:, thickness_mask] * shrink,
                    thickness_floor,
                )

    roughness_cap = torch.minimum(
        get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
        total_max_bounds[..., roughness_mask],
    )
    roughness_floor = total_min_bounds[..., roughness_mask]

    assert torch.all(roughness_floor <= roughness_cap)

    # NOTE(review): if total_min_delta exceeds roughness_cap - roughness_floor for some entry,
    # the width range passed below is inverted; confirm how widths_sampler_func handles that.
    new_roughness, roughness_low, roughness_high = basic_sampler(
        batch_size, roughness_floor, roughness_cap,
        total_min_delta[..., roughness_mask],
        torch.minimum(total_max_delta[..., roughness_mask], roughness_cap - roughness_floor),
        widths_sampler_func=widths_sampler_func,
    )

    min_bounds[..., roughness_mask] = roughness_low
    max_bounds[..., roughness_mask] = roughness_high
    params[..., roughness_mask] = new_roughness

    return params, min_bounds, max_bounds
280
-
281
def constrained_roughness_and_isld_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        thickness_mask: Tensor,
        roughness_mask: Tensor,
        sld_mask: Tensor,
        isld_mask: Tensor,
        widths_sampler_func,
        coef_roughness: float = 0.5,
        coef_isld: float = 0.2,
        max_total_thickness: float = None,
):
    """Sample parameters with constrained roughnesses and constrained imaginary SLDs.

    First applies exactly the thickness rescaling and roughness constraints of
    :func:`constrained_roughness_sampler`. Then it re-draws the imaginary SLD entries with
    :func:`basic_sampler` between ``total_min_bounds`` and
    ``min(|real sld| * coef_isld, total_max_bounds)``.

    Args:
        batch_size (int): the batch size
        total_min_bounds (Tensor): minimum values of the parameters
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        thickness_mask (Tensor): indices of the thickness entries
        roughness_mask (Tensor): indices of the roughness entries
        sld_mask (Tensor): indices of the real sld entries
        isld_mask (Tensor): indices of the imaginary sld entries; assumed to pair
            element-wise with ``sld_mask`` (same length and order) -- confirm with callers
        widths_sampler_func: callable used to draw subprior widths
        coef_roughness (float, optional): roughness coefficient passed on to
            ``constrained_roughness_sampler``. Defaults to 0.5.
        coef_isld (float, optional): maximum ratio of imaginary sld to |real sld|. Defaults to 0.2.
        max_total_thickness (float, optional): limit on the summed thickness upper bounds,
            passed on to ``constrained_roughness_sampler``. Defaults to None.

    Returns:
        tuple(Tensor): ``(params, min_bounds, max_bounds)``.

    Raises:
        AssertionError: if a roughness or imaginary-sld upper limit lies below the
            corresponding ``total_min_bounds`` (only while assertions are enabled).
    """
    # The thickness and roughness handling is identical to constrained_roughness_sampler,
    # so delegate instead of duplicating it (the sequence of random draws is unchanged).
    params, min_bounds, max_bounds = constrained_roughness_sampler(
        batch_size,
        total_min_bounds,
        total_max_bounds,
        total_min_delta,
        total_max_delta,
        thickness_mask=thickness_mask,
        roughness_mask=roughness_mask,
        widths_sampler_func=widths_sampler_func,
        coef_roughness=coef_roughness,
        max_total_thickness=max_total_thickness,
    )

    # Imaginary SLD upper limit: a fraction of the magnitude of the sampled real SLD.
    max_isld = torch.minimum(
        torch.abs(params[..., sld_mask]) * coef_isld,
        total_max_bounds[..., isld_mask],
    )
    min_isld = total_min_bounds[..., isld_mask]

    assert torch.all(min_isld <= max_isld)

    min_isld_delta = total_min_delta[..., isld_mask]
    # Subprior widths cannot exceed the span that is actually allowed.
    max_isld_delta = torch.minimum(total_max_delta[..., isld_mask], max_isld - min_isld)

    islds, min_isld_bounds, max_isld_bounds = basic_sampler(
        batch_size, min_isld, max_isld,
        min_isld_delta, max_isld_delta,
        widths_sampler_func=widths_sampler_func,
    )

    min_bounds[..., isld_mask] = min_isld_bounds
    max_bounds[..., isld_mask] = max_isld_bounds
    params[..., isld_mask] = islds

    return params, min_bounds, max_bounds
368
-
369
-
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ from reflectorch.data_generation.utils import (
5
+ uniform_sampler,
6
+ logdist_sampler,
7
+ )
8
+
9
+ from reflectorch.data_generation.priors.utils import get_max_allowed_roughness
10
+
11
+
12
class SamplerStrategy:
    """Interface for strategies that draw parameters together with their subprior bounds.

    Concrete strategies override :meth:`sample`; the base implementation only raises.
    """

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """Draw ``batch_size`` parameter vectors and their subprior bounds.

        Raises:
            NotImplementedError: always; subclasses must provide the sampling logic.
        """
        raise NotImplementedError
21
+
22
+
23
class BasicSamplerStrategy(SamplerStrategy):
    """Sampler strategy that places no constraints between the parameters.

    Args:
        logdist (bool, optional): when ``True`` the widths of the subprior intervals are drawn
            with ``logdist_sampler`` (uniform on a logarithmic scale); otherwise with
            ``uniform_sampler``. Defaults to False.
    """

    def __init__(self, logdist: bool = False):
        # The selected function is handed to ``basic_sampler`` to draw the interval widths.
        self.widths_sampler_func = logdist_sampler if logdist else uniform_sampler

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """Draw parameters and subprior bounds without inter-parameter constraints.

        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): ``(params, min_bounds, max_bounds)``. Widths W are drawn first, then
            centers C, so the subprior bounds are ``C - W/2`` and ``C + W/2``; the parameters
            are then drawn uniformly from ``[C - W/2, C + W/2]``.
        """
        return basic_sampler(
            batch_size,
            total_min_bounds,
            total_max_bounds,
            total_min_delta,
            total_max_delta,
            widths_sampler_func=self.widths_sampler_func,
        )
60
+
61
+
62
class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
    """Sampler strategy whose roughnesses may not exceed a fraction of the two neighboring thicknesses.

    Args:
        thickness_mask (Tensor): indices in the tensors which correspond to thicknesses
        roughness_mask (Tensor): indices in the tensors which correspond to roughnesses
        logdist (bool, optional): if ``True`` the widths of the subprior intervals are drawn
            uniformly on a logarithmic scale instead of uniformly. Defaults to False.
        max_thickness_share (float, optional): fraction of the layer thickness that the
            roughness should not exceed. Defaults to 0.5.
        max_total_thickness (float, optional): forwarded to ``constrained_roughness_sampler``;
            when set, sampled thickness bounds whose sum exceeds it are scaled down.
            Defaults to None (no limit).
    """

    def __init__(self,
                 thickness_mask: Tensor,
                 roughness_mask: Tensor,
                 logdist: bool = False,
                 max_thickness_share: float = 0.5,
                 max_total_thickness: float = None,
                 ):
        super().__init__(logdist=logdist)
        # Masks are stored as given and moved to the bounds' device on every ``sample`` call.
        self.thickness_mask, self.roughness_mask = thickness_mask, roughness_mask
        self.max_thickness_share = max_thickness_share
        self.max_total_thickness = max_total_thickness

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """Draw parameters and subprior bounds with constrained roughnesses.

        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): ``(params, min_bounds, max_bounds)`` with the roughnesses constrained.
            Widths **W** are drawn first, then centers **C**, giving bounds
            [**C** - **W** / 2, **C** + **W** / 2] from which the parameters are drawn.
        """
        device = total_min_bounds.device
        thickness_idx = self.thickness_mask.to(device)
        roughness_idx = self.roughness_mask.to(device)
        return constrained_roughness_sampler(
            batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
            thickness_mask=thickness_idx,
            roughness_mask=roughness_idx,
            widths_sampler_func=self.widths_sampler_func,
            coef_roughness=self.max_thickness_share,
            max_total_thickness=self.max_total_thickness,
        )
114
+
115
class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
    """Sampler strategy constraining both the roughnesses and the imaginary SLDs.

    Roughnesses may not exceed a fraction of the two neighboring thicknesses, and imaginary
    SLDs may not exceed a fraction of the (absolute) real SLDs.

    Args:
        thickness_mask (Tensor): indices in the tensors which correspond to thicknesses
        roughness_mask (Tensor): indices in the tensors which correspond to roughnesses
        sld_mask (Tensor): indices in the tensors which correspond to real slds
        isld_mask (Tensor): indices in the tensors which correspond to imaginary slds
        logdist (bool, optional): if ``True`` the widths of the subprior intervals are drawn
            uniformly on a logarithmic scale instead of uniformly. Defaults to False.
        max_thickness_share (float, optional): fraction of the layer thickness that the
            roughness should not exceed. Defaults to 0.5.
        max_sld_share (float, optional): fraction of the real sld that the imaginary sld
            should not exceed. Defaults to 0.2.
        max_total_thickness (float, optional): forwarded to
            ``constrained_roughness_and_isld_sampler``; when set, sampled thickness bounds
            whose sum exceeds it are scaled down. Defaults to None (no limit).
    """

    def __init__(self,
                 thickness_mask: Tensor,
                 roughness_mask: Tensor,
                 sld_mask: Tensor,
                 isld_mask: Tensor,
                 logdist: bool = False,
                 max_thickness_share: float = 0.5,
                 max_sld_share: float = 0.2,
                 max_total_thickness: float = None,
                 ):
        super().__init__(logdist=logdist)
        # Masks are stored as given and moved to the bounds' device on every ``sample`` call.
        self.thickness_mask, self.roughness_mask = thickness_mask, roughness_mask
        self.sld_mask, self.isld_mask = sld_mask, isld_mask
        self.max_thickness_share = max_thickness_share
        self.max_sld_share = max_sld_share
        self.max_total_thickness = max_total_thickness

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """Draw parameters and subprior bounds with constrained roughnesses and imaginary SLDs.

        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): ``(params, min_bounds, max_bounds)`` with the roughnesses and
            imaginary slds constrained. Widths **W** are drawn first, then centers **C**,
            giving bounds [**C** - **W** / 2, **C** + **W** / 2] from which the parameters
            are drawn.
        """
        device = total_min_bounds.device
        thickness_idx = self.thickness_mask.to(device)
        roughness_idx = self.roughness_mask.to(device)
        sld_idx = self.sld_mask.to(device)
        isld_idx = self.isld_mask.to(device)
        return constrained_roughness_and_isld_sampler(
            batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
            thickness_mask=thickness_idx,
            roughness_mask=roughness_idx,
            sld_mask=sld_idx,
            isld_mask=isld_idx,
            widths_sampler_func=self.widths_sampler_func,
            coef_roughness=self.max_thickness_share,
            coef_isld=self.max_sld_share,
            max_total_thickness=self.max_total_thickness,
        )
179
+
180
def basic_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        widths_sampler_func,
):
    """Draw subprior intervals and parameters inside them, with no inter-parameter constraints.

    Sampling happens in three stages:

    1. widths ``W`` via ``widths_sampler_func(total_min_delta, total_max_delta, batch_size, n_params, ...)``;
    2. centers ``C`` uniformly in ``[total_min_bounds + W/2, total_max_bounds - W/2]``;
    3. parameters uniformly in ``[C - W/2, C + W/2]``.

    Args:
        batch_size (int): the batch size
        total_min_bounds (Tensor): minimum values of the parameters (2-D; dim 1 indexes parameters)
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        widths_sampler_func: callable accepting ``(min, max, *shape, device=..., dtype=...)``,
            e.g. ``uniform_sampler`` or ``logdist_sampler``.

    Returns:
        tuple(Tensor): ``(params, min_bounds, max_bounds)``, presumably each of shape
        ``(batch_size, n_params)`` (set by ``widths_sampler_func``; confirm).
    """
    device, dtype = total_min_bounds.device, total_min_bounds.dtype
    # Only the broadcast number of parameters is needed from the bound difference.
    n_params = (total_max_bounds - total_min_bounds).shape[1]

    widths = widths_sampler_func(
        total_min_delta, total_max_delta,
        batch_size, n_params,
        device=device, dtype=dtype,
    )
    half_widths = widths / 2

    # Centers are kept far enough from the global bounds for the whole interval to fit.
    centers = uniform_sampler(
        total_min_bounds + half_widths, total_max_bounds - half_widths,
        *widths.shape,
        device=device, dtype=dtype,
    )

    min_bounds = centers - half_widths
    max_bounds = centers + half_widths

    unit_draws = torch.rand(*min_bounds.shape, device=min_bounds.device, dtype=min_bounds.dtype)
    params = min_bounds + (max_bounds - min_bounds) * unit_draws

    return params, min_bounds, max_bounds
212
+
213
+
214
def constrained_roughness_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        thickness_mask: Tensor,
        roughness_mask: Tensor,
        widths_sampler_func,
        coef_roughness: float = 0.5,
        max_total_thickness: float = None,
):
    """Sample parameters and subprior bounds with roughnesses limited by the sampled thicknesses.

    Steps:

    1. Draw every parameter with :func:`basic_sampler`.
    2. If ``max_total_thickness`` is given and some rows have a summed thickness upper bound
       above it, scale those rows' thickness params and bounds by
       ``max_total_thickness / summed_upper_bound * r`` with ``r`` in ``[0.99, 1)``. Then clamp
       all thickness entries from below at ``total_min_bounds``.
    3. Re-draw the roughness entries with :func:`basic_sampler` between ``total_min_bounds``
       and ``min(get_max_allowed_roughness(thicknesses, coef_roughness), total_max_bounds)``.

    Args:
        batch_size (int): the batch size
        total_min_bounds (Tensor): minimum values of the parameters
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        thickness_mask (Tensor): indices of the thickness entries
        roughness_mask (Tensor): indices of the roughness entries
        widths_sampler_func: callable used to draw subprior widths
        coef_roughness (float, optional): passed as ``coef`` to ``get_max_allowed_roughness``. Defaults to 0.5.
        max_total_thickness (float, optional): limit on the summed thickness upper bounds. Defaults to None.

    Returns:
        tuple(Tensor): ``(params, min_bounds, max_bounds)``.

    Raises:
        AssertionError: if, for some roughness entry, the allowed upper limit lies below
            ``total_min_bounds`` (only while assertions are enabled).
    """
    params, min_bounds, max_bounds = basic_sampler(
        batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
        widths_sampler_func=widths_sampler_func,
    )

    if max_total_thickness is not None:
        summed_max_thickness = max_bounds[:, thickness_mask].sum(-1)
        too_thick = summed_max_thickness > max_total_thickness

        if too_thick.any():
            eps = 0.01
            # Random factor in [1 - eps, 1): rescaled rows land slightly below the limit
            # (before the clamping below).
            jitter = torch.rand_like(summed_max_thickness) * eps + 1 - eps
            shrink = max_total_thickness / summed_max_thickness * jitter
            # Rows within the limit keep their thicknesses unchanged.
            shrink = torch.where(too_thick, shrink, torch.ones_like(shrink))[:, None]
            thickness_floor = total_min_bounds[:, thickness_mask]

            # Rescale, then keep params and both bounds at or above the global minimum.
            for values in (min_bounds, max_bounds, params):
                values[:, thickness_mask] = torch.clamp_min(
                    values[:, thickness_mask] * shrink,
                    thickness_floor,
                )

    roughness_cap = torch.minimum(
        get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
        total_max_bounds[..., roughness_mask],
    )
    roughness_floor = total_min_bounds[..., roughness_mask]

    assert torch.all(roughness_floor <= roughness_cap)

    # NOTE(review): if total_min_delta exceeds roughness_cap - roughness_floor for some entry,
    # the width range passed below is inverted; confirm how widths_sampler_func handles that.
    new_roughness, roughness_low, roughness_high = basic_sampler(
        batch_size, roughness_floor, roughness_cap,
        total_min_delta[..., roughness_mask],
        torch.minimum(total_max_delta[..., roughness_mask], roughness_cap - roughness_floor),
        widths_sampler_func=widths_sampler_func,
    )

    min_bounds[..., roughness_mask] = roughness_low
    max_bounds[..., roughness_mask] = roughness_high
    params[..., roughness_mask] = new_roughness

    return params, min_bounds, max_bounds
280
+
281
def constrained_roughness_and_isld_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        thickness_mask: Tensor,
        roughness_mask: Tensor,
        sld_mask: Tensor,
        isld_mask: Tensor,
        widths_sampler_func,
        coef_roughness: float = 0.5,
        coef_isld: float = 0.2,
        max_total_thickness: float = None,
):
    """Sample parameters with constrained roughnesses and constrained imaginary SLDs.

    First applies exactly the thickness rescaling and roughness constraints of
    :func:`constrained_roughness_sampler`. Then it re-draws the imaginary SLD entries with
    :func:`basic_sampler` between ``total_min_bounds`` and
    ``min(|real sld| * coef_isld, total_max_bounds)``.

    Args:
        batch_size (int): the batch size
        total_min_bounds (Tensor): minimum values of the parameters
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        thickness_mask (Tensor): indices of the thickness entries
        roughness_mask (Tensor): indices of the roughness entries
        sld_mask (Tensor): indices of the real sld entries
        isld_mask (Tensor): indices of the imaginary sld entries; assumed to pair
            element-wise with ``sld_mask`` (same length and order) -- confirm with callers
        widths_sampler_func: callable used to draw subprior widths
        coef_roughness (float, optional): roughness coefficient passed on to
            ``constrained_roughness_sampler``. Defaults to 0.5.
        coef_isld (float, optional): maximum ratio of imaginary sld to |real sld|. Defaults to 0.2.
        max_total_thickness (float, optional): limit on the summed thickness upper bounds,
            passed on to ``constrained_roughness_sampler``. Defaults to None.

    Returns:
        tuple(Tensor): ``(params, min_bounds, max_bounds)``.

    Raises:
        AssertionError: if a roughness or imaginary-sld upper limit lies below the
            corresponding ``total_min_bounds`` (only while assertions are enabled).
    """
    # The thickness and roughness handling is identical to constrained_roughness_sampler,
    # so delegate instead of duplicating it (the sequence of random draws is unchanged).
    params, min_bounds, max_bounds = constrained_roughness_sampler(
        batch_size,
        total_min_bounds,
        total_max_bounds,
        total_min_delta,
        total_max_delta,
        thickness_mask=thickness_mask,
        roughness_mask=roughness_mask,
        widths_sampler_func=widths_sampler_func,
        coef_roughness=coef_roughness,
        max_total_thickness=max_total_thickness,
    )

    # Imaginary SLD upper limit: a fraction of the magnitude of the sampled real SLD.
    max_isld = torch.minimum(
        torch.abs(params[..., sld_mask]) * coef_isld,
        total_max_bounds[..., isld_mask],
    )
    min_isld = total_min_bounds[..., isld_mask]

    assert torch.all(min_isld <= max_isld)

    min_isld_delta = total_min_delta[..., isld_mask]
    # Subprior widths cannot exceed the span that is actually allowed.
    max_isld_delta = torch.minimum(total_max_delta[..., isld_mask], max_isld - min_isld)

    islds, min_isld_bounds, max_isld_bounds = basic_sampler(
        batch_size, min_isld, max_isld,
        min_isld_delta, max_isld_delta,
        widths_sampler_func=widths_sampler_func,
    )

    min_bounds[..., isld_mask] = min_isld_bounds
    max_bounds[..., isld_mask] = max_isld_bounds
    params[..., isld_mask] = islds

    return params, min_bounds, max_bounds