reflectorch 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,371 @@
1
+ from functools import lru_cache
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ from reflectorch.data_generation.utils import (
8
+ uniform_sampler,
9
+ logdist_sampler,
10
+ triangular_sampler,
11
+ get_slds_from_d_rhos,
12
+ )
13
+
14
+ from reflectorch.data_generation.priors.params import Params
15
+ from reflectorch.data_generation.priors.no_constraints import (
16
+ BasicPriorSampler,
17
+ DEFAULT_ROUGHNESS_RANGE,
18
+ DEFAULT_THICKNESS_RANGE,
19
+ DEFAULT_SLD_RANGE,
20
+ DEFAULT_NUM_LAYERS,
21
+ DEFAULT_DEVICE,
22
+ DEFAULT_DTYPE,
23
+ DEFAULT_SCALED_RANGE,
24
+ DEFAULT_USE_DRHO,
25
+ )
26
+
27
+
28
class UniformSubPriorParams(Params):
    """Parameters class for thicknesses, roughnesses and slds, together with their subprior bounds.

    For ``L`` layers the flat tensor layout used by ``as_tensor`` / ``from_tensor`` is::

        [thicknesses (L), roughnesses (L + 1), slds (L + 1), min_bounds (3L + 2), max_bounds (3L + 2)]

    i.e. ``9L + 6`` entries in total.
    """
    __slots__ = ('thicknesses', 'roughnesses', 'slds', 'min_bounds', 'max_bounds')
    PARAM_NAMES = __slots__

    def __init__(self,
                 thicknesses: Tensor,
                 roughnesses: Tensor,
                 slds: Tensor,
                 min_bounds: Tensor,
                 max_bounds: Tensor,
                 ):
        super().__init__(thicknesses, roughnesses, slds)
        self.min_bounds = min_bounds
        self.max_bounds = max_bounds

    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor, context: Tensor, inference: bool = False, from_params: bool = False
    ):
        """Move the (scaled) subprior bounds from the parameter tensor into the context tensor.

        Args:
            scaled_params: scaled tensor; in training mode it must hold ``[params | min_bounds | max_bounds]``
                (last dimension divisible by 3).
            context: context tensor to which the bounds are appended along the last dimension.
            inference: if True only the extended context is returned.
            from_params: (inference only) if True, ``scaled_params`` holds ``[params | bounds]`` and only the
                last two thirds (the bounds) are appended; otherwise ``scaled_params`` is appended as is
                (presumably already just the bounds — confirm with callers).

        Returns:
            ``context`` in inference mode, otherwise the tuple ``(scaled_params_without_bounds, context)``.
        """
        if inference:
            if from_params:
                num_params = scaled_params.shape[1] // 3
                scaled_params = scaled_params[:, num_params:]
            context = torch.cat([context, scaled_params], dim=-1)
            return context

        num_params = scaled_params.shape[1] // 3
        assert num_params * 3 == scaled_params.shape[1]
        scaled_params, bound_context = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
        context = torch.cat([context, bound_context], dim=-1)
        return scaled_params, context

    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        """Re-attach the bounds stored in the last ``2 * num_params`` context columns to ``scaled_params``."""
        num_params = scaled_params.shape[-1]
        scaled_bounds = context[:, -2 * num_params:]
        scaled_params = torch.cat([scaled_params, scaled_bounds], dim=-1)
        return scaled_params

    @staticmethod
    def input_context_split(t_params):
        """Split ``[params | min_bounds | max_bounds]`` into ``(params, bounds)`` along the feature dimension.

        The split is done along the last dimension. Previously no ``dim`` was given, so ``torch.split``
        split along dim 0 (the batch dimension), although the section sizes are derived from ``shape[1]``.
        """
        num_params = t_params.shape[1] // 3
        return torch.split(t_params, [num_params, 2 * num_params], dim=-1)

    def as_tensor(self, use_drho: bool = False, add_bounds: bool = True) -> Tensor:
        """Concatenate the parameters (optionally with SLD differences instead of SLDs, and with bounds)."""
        t_list = [self.thicknesses, self.roughnesses]
        if use_drho:
            t_list.append(self.d_rhos)
        else:
            t_list.append(self.slds)
        if add_bounds:
            t_list += [self.min_bounds, self.max_bounds]
        return torch.cat(t_list, -1)

    @classmethod
    def from_tensor(cls, params: Tensor):
        """Build an instance from a ``(..., 9L + 6)`` tensor laid out as described in the class docstring."""
        layers_num = (params.shape[-1] - 6) // 9
        num_params = 3 * layers_num + 2

        thicknesses, roughnesses, slds, min_bounds, max_bounds = torch.split(
            params,
            [layers_num, layers_num + 1, layers_num + 1, num_params, num_params],
            dim=-1
        )

        return cls(thicknesses, roughnesses, slds, min_bounds, max_bounds)

    @property
    def num_params(self) -> int:
        """Total tensor size including bounds: ``9 * max_layer_num + 6``."""
        return self.layers_num2size(self.max_layer_num)

    @staticmethod
    def size2layers_num(size: int) -> int:
        """Inverse of ``layers_num2size``."""
        return (size - 6) // 9

    @staticmethod
    def layers_num2size(layers_num: int) -> int:
        """Flat tensor size (parameters plus both bounds) for ``layers_num`` layers."""
        return layers_num * 9 + 6

    def scale_with_q(self, q_ratio: float):
        """Rescale parameters and bounds for a change of the q range by ``q_ratio``.

        Thickness and roughness bounds are multiplied by ``1 / q_ratio``, SLD bounds by ``q_ratio ** 2``.
        The bounds are modified in place.
        """
        super().scale_with_q(q_ratio)

        layer_num = self.max_layer_num
        scales = torch.tensor(
            [1 / q_ratio] * (2 * layer_num + 1) + [q_ratio ** 2] * (layer_num + 1),
            device=self.device, dtype=self.dtype
        )

        self.min_bounds *= scales
        self.max_bounds *= scales
119
+
120
+
121
class UniformSubPriorSampler(BasicPriorSampler):
    """Prior sampler for thicknesses, roughnesses, slds and their subprior bounds

    Args:
        thickness_range (Tuple[float, float], optional): the range of the layer thicknesses. Defaults to DEFAULT_THICKNESS_RANGE.
        roughness_range (Tuple[float, float], optional): the range of the interlayer roughnesses. Defaults to DEFAULT_ROUGHNESS_RANGE.
        sld_range (Tuple[float, float], optional): the range of the layer SLDs. Defaults to DEFAULT_SLD_RANGE.
        num_layers (int, optional): the number of layers. Defaults to DEFAULT_NUM_LAYERS.
        use_drho (bool, optional): whether to use differences in SLD values between neighboring layers instead of the actual SLD values. Defaults to DEFAULT_USE_DRHO.
        device (torch.device, optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
        dtype (torch.dtype, optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
        scaled_range (Tuple[float, float], optional): the range for scaling the parameters. Defaults to DEFAULT_SCALED_RANGE.
        scale_by_subpriors (bool, optional): if True the film parameters are scaled with respect to their subprior bounds. Defaults to False.
        smaller_roughnesses (bool, optional): if True the sampled roughnesses are biased towards smaller values. Defaults to False.
        logdist (bool, optional): if True the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
        relative_min_bound_width (float, optional): defines the interval [relative_min_bound_width, 1.0] from which the relative bound widths for each parameter are sampled. Defaults to 1e-2.

    Film parameter vectors use the layout ``[thicknesses (L), roughnesses (L + 1), slds (L + 1)]``
    (``param_dim = 3L + 2``); full tensors additionally carry ``min_bounds`` and ``max_bounds`` of the same size.
    """
    PARAM_CLS = UniformSubPriorParams

    def __init__(self,
                 thickness_range: Tuple[float, float] = DEFAULT_THICKNESS_RANGE,
                 roughness_range: Tuple[float, float] = DEFAULT_ROUGHNESS_RANGE,
                 sld_range: Tuple[float, float] = DEFAULT_SLD_RANGE,
                 num_layers: int = DEFAULT_NUM_LAYERS,
                 use_drho: bool = DEFAULT_USE_DRHO,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 scaled_range: Tuple[float, float] = DEFAULT_SCALED_RANGE,
                 scale_by_subpriors: bool = False,
                 smaller_roughnesses: bool = False,
                 logdist: bool = False,
                 relative_min_bound_width: float = 1e-2,
                 ):
        super().__init__(
            thickness_range,
            roughness_range,
            sld_range,
            num_layers,
            use_drho,
            device,
            dtype,
            scaled_range,
        )

        self.scale_by_subpriors = scale_by_subpriors
        self.smaller_roughnesses = smaller_roughnesses
        self.logdist = logdist
        self.relative_min_bound_width = relative_min_bound_width

    @property
    def max_num_layers(self) -> int:
        """The (fixed) number of layers of this sampler."""
        return self.num_layers

    @property
    def param_dim(self) -> int:
        """Number of film parameters without bounds: L thicknesses + (L + 1) roughnesses + (L + 1) SLDs."""
        return self.max_num_layers * 3 + 2

    def _get_vector_cache(self) -> dict:
        """Per-instance cache for the tiled min/max/delta vectors.

        Replaces ``functools.lru_cache`` on the methods, whose class-level cache keyed on ``self`` and so
        kept every sampler instance alive. Created lazily, in case the base-class ``__init__`` already
        queries these vectors before our attributes exist.
        """
        cache = getattr(self, '_subprior_vector_cache', None)
        if cache is None:
            cache = {}
            self._subprior_vector_cache = cache
        return cache

    def min_vector(self, layers_num, drho: bool = False):
        """Base-class minimum vector tiled three times (parameters, min bounds, max bounds); cached."""
        cache = self._get_vector_cache()
        key = ('min', layers_num, drho)
        if key not in cache:
            base = super().min_vector(layers_num, drho)
            cache[key] = torch.cat([base, base, base], dim=0)
        return cache[key]

    def scale_params(self, params: UniformSubPriorParams) -> Tensor:
        """Scale ``params`` (including bounds) with the base-class scaling.

        If ``scale_by_subpriors`` is set, the first ``param_dim`` columns (the film parameters) are then
        overwritten by their scaling relative to each sample's own ``[min_bounds, max_bounds]``.
        """
        scaled_params = super().scale_params(params)

        if self.scale_by_subpriors:
            params_t = params.as_tensor(use_drho=self.use_drho, add_bounds=False)
            scaled_params[:, :self.param_dim] = self._scale(params_t, params.min_bounds, params.max_bounds)

        return scaled_params

    def restore_params(self, scaled_params: Tensor) -> Params:
        """Inverse of ``scale_params``."""
        if not self.scale_by_subpriors:
            return super().restore_params(scaled_params)

        scaled_params, scaled_min_bounds, scaled_max_bounds = torch.split(
            scaled_params, [self.param_dim, self.param_dim, self.param_dim], dim=1
        )

        # The bounds are restored against the untiled base-class prior vectors ...
        min_vector = super().min_vector(self.max_num_layers, self.use_drho)
        max_vector = super().max_vector(self.max_num_layers, self.use_drho)

        min_bounds = self._restore(scaled_min_bounds, min_vector, max_vector)
        max_bounds = self._restore(scaled_max_bounds, min_vector, max_vector)

        # ... and the film parameters against their own subprior interval.
        param_t = self._restore(scaled_params, min_bounds, max_bounds)
        param_t = torch.cat([param_t, min_bounds, max_bounds], dim=-1)

        params = UniformSubPriorParams.from_tensor(param_t)

        if self.use_drho:
            params.slds = get_slds_from_d_rhos(params.slds)
        return params

    def max_vector(self, layers_num, drho: bool = False):
        """Base-class maximum vector tiled three times (parameters, min bounds, max bounds); cached."""
        cache = self._get_vector_cache()
        key = ('max', layers_num, drho)
        if key not in cache:
            base = super().max_vector(layers_num, drho)
            cache[key] = torch.cat([base, base, base], dim=0)
        return cache[key]

    def delta_vector(self, layers_num, drho: bool = False):
        """``max_vector - min_vector`` with zero entries replaced by 1 (avoids division by zero); cached."""
        cache = self._get_vector_cache()
        key = ('delta', layers_num, drho)
        if key not in cache:
            delta_vector = self.max_vector(layers_num, drho) - self.min_vector(layers_num, drho)
            delta_vector[delta_vector == 0.] = 1.
            cache[key] = delta_vector
        return cache[key]

    def get_indices_within_bounds(self, params: UniformSubPriorParams) -> Tensor:
        """Boolean mask of samples whose parameters all lie within their own subprior bounds."""
        t_params = torch.cat([
            params.thicknesses,
            params.roughnesses,
            params.slds
        ], dim=-1)

        indices = (
            torch.all(t_params >= params.min_bounds, dim=-1) &
            torch.all(t_params <= params.max_bounds, dim=-1)
        )

        return indices

    def clamp_params(self, params: UniformSubPriorParams) -> UniformSubPriorParams:
        """Return new params with every film parameter clamped into its subprior bounds."""
        params = UniformSubPriorParams.from_tensor(
            torch.cat([
                torch.clamp(
                    params.as_tensor(add_bounds=False),
                    params.min_bounds, params.max_bounds
                ),
                params.min_bounds, params.max_bounds
            ], dim=1)
        )
        return params

    def get_indices_within_domain(self, params: UniformSubPriorParams) -> Tensor:
        """The valid domain of this sampler is exactly the per-sample subprior box."""
        return self.get_indices_within_bounds(params)

    def sample(self, batch_size: int) -> UniformSubPriorParams:
        """Sample subprior bounds, then film parameters uniformly inside those bounds."""
        min_bounds, max_bounds = self.sample_bounds(batch_size)

        params = torch.rand(
            *min_bounds.shape,
            device=self.device,
            dtype=self.dtype
        ) * (max_bounds - min_bounds) + min_bounds

        thicknesses, roughnesses, slds = torch.split(
            params, [self.max_num_layers, self.max_num_layers + 1, self.max_num_layers + 1], dim=-1
        )

        params = UniformSubPriorParams(thicknesses, roughnesses, slds, min_bounds, max_bounds)

        return params

    def sample_bounds(self, batch_size: int):
        """Sample a subprior interval per sample and parameter, contained in the base-class prior.

        Widths are a fraction in ``[relative_min_bound_width, 1]`` of the prior range (sampled with
        ``logdist_sampler`` if ``logdist`` else ``uniform_sampler``). Centres are sampled so the interval
        fits inside the prior. With ``smaller_roughnesses`` the roughness centres use ``triangular_sampler``.

        Returns:
            ``(min_bounds, max_bounds)``, each of shape ``(batch_size, param_dim)``.
        """
        global_min = super().min_vector(self.num_layers)[None]
        global_max = super().max_vector(self.num_layers)[None]
        global_delta = global_max - global_min

        widths_sampler_func = logdist_sampler if self.logdist else uniform_sampler

        prior_widths = widths_sampler_func(
            self.relative_min_bound_width, 1.,
            batch_size, global_delta.shape[1],
            device=self.device, dtype=self.dtype
        ) * global_delta

        prior_centers = uniform_sampler(
            global_min + prior_widths / 2, global_max - prior_widths / 2,
            *prior_widths.shape,
            device=self.device, dtype=self.dtype
        )

        if self.smaller_roughnesses:
            # Roughness columns occupy [L, 2L + 1).
            idx_min, idx_max = self.num_layers, self.num_layers * 2 + 1
            prior_centers[:, idx_min:idx_max] = triangular_sampler(
                global_min[:, idx_min:idx_max] + prior_widths[:, idx_min:idx_max] / 2,
                global_max[:, idx_min:idx_max] - prior_widths[:, idx_min:idx_max] / 2,
                batch_size, self.num_layers + 1,
                device=self.device, dtype=self.dtype
            )

        half_widths = prior_widths / 2
        return prior_centers - half_widths, prior_centers + half_widths

    def scale_bounds(self, bounds: Tensor) -> Tensor:
        """Scale ``bounds`` with the tiled prior vectors."""
        # NOTE(review): ``bounds.shape[-1] // 2`` is passed as a *layer count* to the tiled vectors
        # (length 3 * (3L + 2)); confirm the expected shape of ``bounds`` with callers.
        layers_num = bounds.shape[-1] // 2

        return self._scale(
            bounds,
            self.min_vector(layers_num, drho=self.use_drho).to(bounds),
            self.max_vector(layers_num, drho=self.use_drho).to(bounds),
        )
321
+
322
class NarrowSldUniformSubPriorSampler(UniformSubPriorSampler):
    """Prior sampler for thicknesses, roughnesses, slds and their subprior bounds. The subprior bound widths for SLDs are restricted to be lower than a specified value.

    Args:
        thickness_range, roughness_range, sld_range, num_layers, use_drho, device, dtype, scaled_range,
        scale_by_subpriors: as for ``UniformSubPriorSampler``.
        max_sld_prior_width (float, optional): upper limit of the subprior width of every SLD. Defaults to 10.
        relative_min_bound_width (float, optional): widths are sampled in
            ``[relative_min_bound_width * max_width, max_width]``. Defaults to 1e-2.

    Note:
        ``sample_bounds`` below always samples widths uniformly; the parent's ``smaller_roughnesses`` and
        ``logdist`` options are not used here.
    """
    def __init__(self,
                 thickness_range: Tuple[float, float] = DEFAULT_THICKNESS_RANGE,
                 roughness_range: Tuple[float, float] = DEFAULT_ROUGHNESS_RANGE,
                 sld_range: Tuple[float, float] = DEFAULT_SLD_RANGE,
                 num_layers: int = DEFAULT_NUM_LAYERS,
                 use_drho: bool = DEFAULT_USE_DRHO,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 scaled_range: Tuple[float, float] = DEFAULT_SCALED_RANGE,
                 scale_by_subpriors: bool = False,
                 max_sld_prior_width: float = 10.,
                 relative_min_bound_width: float = 1e-2,
                 ):
        super().__init__(
            thickness_range,
            roughness_range,
            sld_range,
            num_layers,
            use_drho,
            device,
            dtype,
            scaled_range,
            scale_by_subpriors,
            relative_min_bound_width=relative_min_bound_width,
        )

        self.max_sld_prior_width = max_sld_prior_width

    def sample_bounds(self, batch_size: int):
        """Sample subprior intervals inside the prior, with SLD widths capped at ``max_sld_prior_width``.

        Returns:
            ``(min_bounds, max_bounds)``, each of shape ``(batch_size, 3L + 2)``.
        """
        min_vector = BasicPriorSampler.min_vector(self, self.num_layers)
        max_vector = BasicPriorSampler.max_vector(self, self.num_layers)

        max_widths = max_vector - min_vector

        # Layout is [thicknesses (L), roughnesses (L + 1), slds (L + 1)], so the SLDs are the last L + 1
        # entries (previously only the last L were capped). Clamp rather than overwrite, so an SLD width
        # never exceeds the actual prior range (which would make the centre interval below inverted).
        num_slds = self.num_layers + 1
        max_widths[-num_slds:] = torch.clamp(max_widths[-num_slds:], max=self.max_sld_prior_width)

        prior_widths = uniform_sampler(
            max_widths * self.relative_min_bound_width, max_widths,
            batch_size, min_vector.shape[0],
            device=self.device, dtype=self.dtype
        )

        prior_centers = uniform_sampler(
            min_vector + prior_widths / 2, max_vector - prior_widths / 2,
            *prior_widths.shape,
            device=self.device, dtype=self.dtype
        )

        return prior_centers - prior_widths / 2, prior_centers + prior_widths / 2
@@ -0,0 +1,118 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from reflectorch.data_generation.utils import (
7
+ get_d_rhos,
8
+ uniform_sampler,
9
+ )
10
+
11
+
12
def get_max_allowed_roughness(thicknesses: Tensor, mask: Tensor = None, coef: float = 0.5):
    """Return the largest roughness permitted at every interface.

    Interface ``i`` touches layers ``i - 1`` and ``i``; its roughness may not exceed ``coef`` times the
    thickness of either of them (outermost interfaces have a single neighbouring layer).

    Args:
        thicknesses: tensor of shape ``(batch, L)``.
        mask: optional boolean SLD mask; thicknesses selected by its first ``L`` columns impose no limit.
        coef: fraction of a layer's thickness a neighbouring roughness may reach.

    Returns:
        Tensor of shape ``(batch, L + 1)``; ``inf`` where no constraint applies.
    """
    n_batch, n_layers = thicknesses.shape

    per_layer_limit = thicknesses * coef
    if mask is not None:
        per_layer_limit[get_thickness_mask_from_sld_mask(mask)] = float('inf')

    limits = torch.full(
        (n_batch, n_layers + 1), float('inf'),
        device=thicknesses.device, dtype=thicknesses.dtype,
    )
    # layer j limits interface j (above) ...
    limits[:, :-1] = per_layer_limit
    # ... and interface j + 1 (below); keep the tighter of the two.
    limits[:, 1:] = torch.minimum(limits[:, 1:], per_layer_limit)
    return limits
26
+
27
+
28
def get_allowed_contrast_indices(slds: Tensor, min_contrast: float, mask: Tensor = None) -> Tensor:
    """Per-sample flag: True when every ``|get_d_rhos(slds)|`` entry is at least ``min_contrast``.

    Entries selected by ``mask`` are exempt from the check. The reduction is over the last dimension.
    """
    contrast_ok = get_d_rhos(slds).abs() >= min_contrast
    if mask is not None:
        contrast_ok = contrast_ok | mask
    return contrast_ok.all(dim=-1)
35
+
36
+
37
def params_within_bounds(params_t: Tensor, min_t: Tensor, max_t: Tensor, mask: Tensor = None) -> Tensor:
    """Per-sample flag: True when every parameter lies in ``[min_t, max_t]`` (inclusive).

    ``min_t`` / ``max_t`` get a leading batch axis added before the comparison. Entries selected by
    ``mask`` always count as inside. The reduction is over the last dimension.
    """
    above_min = params_t >= min_t[None]
    below_max = params_t <= max_t[None]
    inside = above_min & below_max
    if mask is not None:
        inside = inside | mask
    return inside.all(dim=-1)
43
+
44
+
45
def get_allowed_roughness_indices(thicknesses: Tensor, roughnesses: Tensor, mask: Tensor = None) -> Tensor:
    """Per-sample flag: True when no roughness exceeds ``get_max_allowed_roughness(thicknesses, mask)``.

    Entries selected by ``mask`` are exempt. The reduction is over the last dimension.
    """
    within_limit = roughnesses <= get_max_allowed_roughness(thicknesses, mask)
    if mask is not None:
        within_limit = within_limit | mask
    return within_limit.all(dim=-1)
52
+
53
+
54
def get_thickness_mask_from_sld_mask(mask: Tensor):
    """Drop the last column of a 2D SLD mask so that it lines up with the thickness columns."""
    n_sld_columns = mask.shape[1]
    return mask[:, :n_sld_columns - 1]
56
+
57
+
58
def generate_roughnesses(thicknesses: Tensor, roughness_range: Tuple[float, float], mask: Tensor = None):
    """Sample ``L + 1`` interface roughnesses per sample, limited by the neighbouring thicknesses.

    The upper limit is ``get_max_allowed_roughness(thicknesses, mask)`` capped at ``roughness_range[1]``;
    the lower limit is ``roughness_range[0]``. Masked roughnesses are set to 0.

    NOTE(review): if a thickness-derived limit falls below ``roughness_range[0]`` the sampler receives an
    inverted interval; the outcome then depends on ``uniform_sampler`` — confirm this is intended.
    """
    n_batch, n_layers = thicknesses.shape
    upper = get_max_allowed_roughness(thicknesses, mask).clamp_(max=roughness_range[1])

    roughnesses = uniform_sampler(
        roughness_range[0], upper, n_batch, n_layers + 1,
        device=thicknesses.device, dtype=thicknesses.dtype
    )

    if mask is not None:
        roughnesses[mask] = 0.

    return roughnesses
72
+
73
+
74
def generate_thicknesses(
        thickness_range: Tuple[float, float],
        batch_size: int,
        layers_num: int,
        device: torch.device,
        dtype: torch.dtype,
        mask: Tensor = None
):
    """Sample ``(batch_size, layers_num)`` thicknesses uniformly from ``thickness_range``.

    Thicknesses selected by the SLD ``mask`` (all but its last column) are set to 0.
    """
    thicknesses = uniform_sampler(*thickness_range, batch_size, layers_num, device=device, dtype=dtype)
    if mask is None:
        return thicknesses
    thicknesses[get_thickness_mask_from_sld_mask(mask)] = 0.
    return thicknesses
88
+
89
+
90
def generate_slds_with_min_contrast(
        sld_range: Tuple[float, float],
        batch_size: int,
        layers_num: int,
        min_contrast: float,
        device: torch.device,
        dtype: torch.dtype,
        mask: Tensor = None,
        *,
        _depth: int = 0,
        max_attempts: int = 10_000,
):
    """Sample SLDs whose contrasts all pass ``get_allowed_contrast_indices(..., min_contrast, mask)``.

    Rejection sampling: draw ``layers_num + 1`` SLDs per sample uniformly from ``sld_range``, zero the
    masked entries, and redraw every failing sample until all samples pass. This was formerly recursive
    with no depth check, so a low acceptance rate or an unattainable ``min_contrast`` ended in
    ``RecursionError``. The loop below consumes the random stream in the same order.

    Args:
        sld_range: ``(low, high)`` range of the sampled SLDs.
        batch_size: number of samples.
        layers_num: number of layers; ``layers_num + 1`` SLDs are sampled per sample.
        min_contrast: required minimal absolute SLD contrast.
        device, dtype: PyTorch device and dtype of the result.
        mask: optional boolean mask; masked SLDs are fixed to 0 and exempt from the contrast check.
        _depth: resampling rounds already spent (kept for backward compatibility); counts towards
            ``max_attempts``.
        max_attempts: maximal number of resampling rounds.

    Returns:
        Tensor of SLDs with ``batch_size`` rows.

    Raises:
        RuntimeError: if some samples still fail after ``max_attempts`` rounds (the previously raised
            ``RecursionError`` is a subclass of ``RuntimeError``).
    """
    def _draw(num_samples: int, sample_mask: Tensor):
        # Fresh uniform SLDs; masked entries are fixed to zero.
        samples = uniform_sampler(
            *sld_range, num_samples, layers_num + 1, device=device, dtype=dtype
        )
        if sample_mask is not None:
            samples[sample_mask] = 0.
        return samples

    slds = _draw(batch_size, mask)
    rejected = ~get_allowed_contrast_indices(slds, min_contrast, mask)
    attempts = _depth

    while bool(rejected.any()):
        if attempts >= max_attempts:
            raise RuntimeError(
                f"generate_slds_with_min_contrast: {int(rejected.sum().item())} sample(s) still violate "
                f"min_contrast={min_contrast} after {max_attempts} resampling rounds; check that the "
                f"contrast is attainable within sld_range={tuple(sld_range)}."
            )
        attempts += 1
        rejected_mask = None if mask is None else mask[rejected]
        slds[rejected] = _draw(int(rejected.sum().item()), rejected_mask)
        rejected = ~get_allowed_contrast_indices(slds, min_contrast, mask)

    return slds
@@ -0,0 +1,41 @@
1
+ from typing import Any
2
+
3
# Public API of this module.
__all__ = ["ProcessData", "ProcessPipeline"]
7
+
8
+
9
class ProcessData(object):
    """A single data-processing step; the base implementation returns its input unchanged.

    Steps are callable (``step(args, context)`` delegates to ``apply``) and can be chained with ``+``
    into a ``ProcessPipeline``.
    """

    def __add__(self, other):
        """Chain ``self`` and ``other`` into a ``ProcessPipeline``.

        Non-``ProcessData`` operands yield ``NotImplemented`` so Python raises ``TypeError`` (or tries the
        other operand's ``__radd__``) instead of silently returning ``None``.
        """
        if isinstance(other, ProcessData):
            return ProcessPipeline(self, other)
        return NotImplemented

    def apply(self, args: Any, context: dict = None):
        """Process ``args``; identity in the base class. ``context`` is an optional shared dict."""
        return args

    def __call__(self, args: Any, context: dict = None):
        return self.apply(args, context)

    def __repr__(self):
        return f'{self.__class__.__name__}()'
22
+
23
+
24
class ProcessPipeline(ProcessData):
    """A sequence of processing steps applied one after another."""

    def __init__(self, *processes):
        self._processes = list(processes)

    def apply(self, args: Any, context: dict = None):
        """Feed ``args`` through every step in order, passing the same ``context`` to each."""
        for process in self._processes:
            args = process(args, context)
        return args

    def __add__(self, other):
        """Append ``other`` (a pipeline's steps are flattened in).

        Non-``ProcessData`` operands yield ``NotImplemented`` so Python raises ``TypeError`` instead of
        silently returning ``None``.
        """
        if isinstance(other, ProcessPipeline):
            return ProcessPipeline(*self._processes, *other._processes)
        if isinstance(other, ProcessData):
            return ProcessPipeline(*self._processes, other)
        return NotImplemented

    def __repr__(self):
        processes = ", ".join(repr(p) for p in self._processes)
        return f'ProcessPipeline({processes})'