reflectorch-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (83)
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,377 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from functools import lru_cache
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ from torch import Tensor
12
+
13
+ from reflectorch.data_generation.utils import (
14
+ uniform_sampler,
15
+ logdist_sampler,
16
+ triangular_sampler,
17
+ get_slds_from_d_rhos,
18
+ )
19
+
20
+ from reflectorch.data_generation.priors.params import Params
21
+ from reflectorch.data_generation.priors.no_constraints import (
22
+ BasicPriorSampler,
23
+ DEFAULT_ROUGHNESS_RANGE,
24
+ DEFAULT_THICKNESS_RANGE,
25
+ DEFAULT_SLD_RANGE,
26
+ DEFAULT_NUM_LAYERS,
27
+ DEFAULT_DEVICE,
28
+ DEFAULT_DTYPE,
29
+ DEFAULT_SCALED_RANGE,
30
+ DEFAULT_USE_DRHO,
31
+ )
32
+
33
+
34
class UniformSubPriorParams(Params):
    """Parameters class for thicknesses, roughnesses and slds, together with their subprior bounds.

    With ``L`` layers, the flat layout produced by :meth:`as_tensor` (with bounds)
    and consumed by :meth:`from_tensor` is, along the last dimension::

        [thicknesses (L), roughnesses (L + 1), slds (L + 1),
         min_bounds (3L + 2), max_bounds (3L + 2)]

    i.e. ``9L + 6`` values in total.
    """
    __slots__ = ('thicknesses', 'roughnesses', 'slds', 'min_bounds', 'max_bounds')
    PARAM_NAMES = __slots__

    def __init__(self,
                 thicknesses: Tensor,
                 roughnesses: Tensor,
                 slds: Tensor,
                 min_bounds: Tensor,
                 max_bounds: Tensor,
                 ):
        """Store the film parameters (through the base class) and their bounds.

        Args:
            thicknesses (Tensor): layer thicknesses.
            roughnesses (Tensor): interlayer roughnesses.
            slds (Tensor): layer SLDs.
            min_bounds (Tensor): lower subprior bound for every film parameter.
            max_bounds (Tensor): upper subprior bound for every film parameter.
        """
        super().__init__(thicknesses, roughnesses, slds)
        self.min_bounds = min_bounds
        self.max_bounds = max_bounds

    @staticmethod
    def rearrange_context_from_params(
            scaled_params: Tensor, context: Tensor, inference: bool = False, from_params: bool = False
    ):
        """Move the bound columns of ``scaled_params`` into ``context``.

        ``scaled_params`` is treated as three equally sized groups of columns:
        the parameters followed by two groups of bounds.

        Args:
            scaled_params (Tensor): 2D tensor of scaled parameters and/or bounds.
            context (Tensor): tensor the bound columns are appended to (last dimension).
            inference (bool): if True only the extended context is returned.
            from_params (bool): only used when ``inference`` is True; if True the
                first third of ``scaled_params`` (the parameters) is dropped
                before appending, otherwise ``scaled_params`` is appended whole.

        Returns:
            The extended context when ``inference`` is True, otherwise the tuple
            ``(scaled_params_without_bounds, extended_context)``.
        """
        if inference:
            if from_params:
                num_params = scaled_params.shape[1] // 3
                scaled_params = scaled_params[:, num_params:]
            context = torch.cat([context, scaled_params], dim=-1)
            return context

        num_params = scaled_params.shape[1] // 3
        assert num_params * 3 == scaled_params.shape[1]
        scaled_params, bound_context = torch.split(scaled_params, [num_params, 2 * num_params], dim=-1)
        context = torch.cat([context, bound_context], dim=-1)
        return scaled_params, context

    @staticmethod
    def restore_params_from_context(scaled_params: Tensor, context: Tensor):
        """Re-attach the bounds stored at the end of ``context`` to ``scaled_params``.

        The last ``2 * scaled_params.shape[-1]`` columns of ``context`` are
        concatenated after ``scaled_params`` along the last dimension.
        """
        num_params = scaled_params.shape[-1]
        scaled_bounds = context[:, -2 * num_params:]
        scaled_params = torch.cat([scaled_params, scaled_bounds], dim=-1)
        return scaled_params

    @staticmethod
    def input_context_split(t_params):
        """Split ``t_params`` into its parameter columns and its bound columns.

        Returns:
            A tuple ``(params, bounds)`` holding the first third and the
            remaining two thirds of the columns of ``t_params``.
        """
        num_params = t_params.shape[1] // 3
        # The section sizes are derived from the column count, so the split has
        # to run over the columns; torch.split defaults to dim=0 (the batch).
        return torch.split(t_params, [num_params, 2 * num_params], dim=-1)

    def as_tensor(self, use_drho: bool = False, add_bounds: bool = True) -> Tensor:
        """Concatenate the parameters (and optionally the bounds) along the last dimension.

        Args:
            use_drho (bool): if True ``self.d_rhos`` is used in place of ``self.slds``.
            add_bounds (bool): if True ``min_bounds`` and ``max_bounds`` are appended.
        """
        t_list = [self.thicknesses, self.roughnesses]
        if use_drho:
            # d_rhos is not defined in this class; it is expected from the base class.
            t_list.append(self.d_rhos)
        else:
            t_list.append(self.slds)
        if add_bounds:
            t_list += [self.min_bounds, self.max_bounds]
        return torch.cat(t_list, -1)

    @classmethod
    def from_tensor(cls, params: Tensor):
        """Build an instance from a flat tensor with ``9L + 6`` values in the last dimension."""
        layers_num = (params.shape[-1] - 6) // 9
        num_params = 3 * layers_num + 2

        thicknesses, roughnesses, slds, min_bounds, max_bounds = torch.split(
            params,
            [layers_num, layers_num + 1, layers_num + 1, num_params, num_params],
            dim=-1
        )

        return cls(thicknesses, roughnesses, slds, min_bounds, max_bounds)

    @property
    def num_params(self) -> int:
        """Size of the flat representation (bounds included) for ``self.max_layer_num`` layers."""
        return self.layers_num2size(self.max_layer_num)

    @staticmethod
    def size2layers_num(size: int) -> int:
        """Number of layers corresponding to a flat representation of ``size`` values."""
        return (size - 6) // 9

    @staticmethod
    def layers_num2size(layers_num: int) -> int:
        """Size of the flat representation (bounds included) for ``layers_num`` layers."""
        return layers_num * 9 + 6

    def scale_with_q(self, q_ratio: float):
        """Rescale the parameters (base class) and, in place, their bounds.

        The first ``2L + 1`` bound entries (thicknesses and roughnesses) are
        multiplied by ``1 / q_ratio``, the remaining ``L + 1`` entries (SLDs)
        by ``q_ratio ** 2``.
        """
        super().scale_with_q(q_ratio)

        layer_num = self.max_layer_num
        scales = torch.tensor(
            [1 / q_ratio] * (2 * layer_num + 1) + [q_ratio ** 2] * (layer_num + 1),
            device=self.device, dtype=self.dtype
        )

        self.min_bounds *= scales
        self.max_bounds *= scales
125
+
126
+
127
class UniformSubPriorSampler(BasicPriorSampler):
    """Sampler of film parameters (thicknesses, roughnesses, slds) together with the subprior bounds enclosing them.

    Args:
        thickness_range (Tuple[float, float], optional): range of the layer thicknesses. Defaults to DEFAULT_THICKNESS_RANGE.
        roughness_range (Tuple[float, float], optional): range of the interlayer roughnesses. Defaults to DEFAULT_ROUGHNESS_RANGE.
        sld_range (Tuple[float, float], optional): range of the layer SLDs. Defaults to DEFAULT_SLD_RANGE.
        num_layers (int, optional): number of layers. Defaults to DEFAULT_NUM_LAYERS.
        use_drho (bool, optional): work with SLD differences between neighboring layers instead of the SLD values. Defaults to DEFAULT_USE_DRHO.
        device (torch.device, optional): Pytorch device. Defaults to DEFAULT_DEVICE.
        dtype (torch.dtype, optional): Pytorch data type. Defaults to DEFAULT_DTYPE.
        scaled_range (Tuple[float, float], optional): range the parameters are scaled to. Defaults to DEFAULT_SCALED_RANGE.
        scale_by_subpriors (bool, optional): scale the film parameters relative to their subprior bounds. Defaults to False.
        smaller_roughnesses (bool, optional): bias the sampled roughnesses towards smaller values. Defaults to False.
        logdist (bool, optional): sample the relative subprior widths uniformly on a logarithmic scale rather than uniformly. Defaults to False.
        relative_min_bound_width (float, optional): lower end of the interval [relative_min_bound_width, 1.0] the relative bound widths are drawn from. Defaults to 1e-2.
    """
    PARAM_CLS = UniformSubPriorParams

    def __init__(self,
                 thickness_range: Tuple[float, float] = DEFAULT_THICKNESS_RANGE,
                 roughness_range: Tuple[float, float] = DEFAULT_ROUGHNESS_RANGE,
                 sld_range: Tuple[float, float] = DEFAULT_SLD_RANGE,
                 num_layers: int = DEFAULT_NUM_LAYERS,
                 use_drho: bool = DEFAULT_USE_DRHO,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 scaled_range: Tuple[float, float] = DEFAULT_SCALED_RANGE,
                 scale_by_subpriors: bool = False,
                 smaller_roughnesses: bool = False,
                 logdist: bool = False,
                 relative_min_bound_width: float = 1e-2,
                 ):
        super().__init__(
            thickness_range, roughness_range, sld_range, num_layers,
            use_drho, device, dtype, scaled_range,
        )

        self.relative_min_bound_width = relative_min_bound_width
        self.logdist = logdist
        self.smaller_roughnesses = smaller_roughnesses
        self.scale_by_subpriors = scale_by_subpriors

    @property
    def max_num_layers(self) -> int:
        """Alias of ``num_layers``."""
        return self.num_layers

    @property
    def param_dim(self) -> int:
        """Number of film parameters: L thicknesses, L + 1 roughnesses, L + 1 slds."""
        return self.max_num_layers * 3 + 2

    @lru_cache()
    def min_vector(self, layers_num, drho: bool = False):
        """Base-class minimum vector repeated three times (parameters, min bounds, max bounds)."""
        base_min = super().min_vector(layers_num, drho)
        return torch.cat([base_min, base_min, base_min], dim=0)

    @lru_cache()
    def max_vector(self, layers_num, drho: bool = False):
        """Base-class maximum vector repeated three times (parameters, min bounds, max bounds)."""
        base_max = super().max_vector(layers_num, drho)
        return torch.cat([base_max, base_max, base_max], dim=0)

    @lru_cache()
    def delta_vector(self, layers_num, drho: bool = False):
        """Difference between max and min vectors, with zero entries replaced by one."""
        span = self.max_vector(layers_num, drho) - self.min_vector(layers_num, drho)
        span[span == 0.] = 1.
        return span

    def scale_params(self, params: UniformSubPriorParams) -> Tensor:
        """Scale parameters and bounds; optionally rescale the parameters relative to their own bounds."""
        scaled = super().scale_params(params)

        if not self.scale_by_subpriors:
            return scaled

        # Overwrite the leading parameter columns with values scaled by the
        # per-sample bounds (inherited `_scale`), leaving the bound columns as is.
        film_t = params.as_tensor(use_drho=self.use_drho, add_bounds=False)
        scaled[:, :self.param_dim] = self._scale(film_t, params.min_bounds, params.max_bounds)
        return scaled

    def restore_params(self, scaled_params: Tensor) -> Params:
        """Invert :meth:`scale_params`."""
        if not self.scale_by_subpriors:
            return super().restore_params(scaled_params)

        dim = self.param_dim
        scaled_film, scaled_lower, scaled_upper = torch.split(scaled_params, [dim, dim, dim], dim=1)

        # The bounds are restored against the base-class (non-tripled) limits ...
        global_min = super().min_vector(self.max_num_layers, self.use_drho)
        global_max = super().max_vector(self.max_num_layers, self.use_drho)
        lower = self._restore(scaled_lower, global_min, global_max)
        upper = self._restore(scaled_upper, global_min, global_max)

        # ... and the parameters against their restored bounds.
        film = self._restore(scaled_film, lower, upper)

        restored = UniformSubPriorParams.from_tensor(torch.cat([film, lower, upper], dim=-1))

        if self.use_drho:
            restored.slds = get_slds_from_d_rhos(restored.slds)
        return restored

    def get_indices_within_bounds(self, params: UniformSubPriorParams) -> Tensor:
        """Boolean tensor marking the samples whose parameters all lie inside their bounds."""
        film_t = torch.cat([params.thicknesses, params.roughnesses, params.slds], dim=-1)

        not_below = torch.all(film_t >= params.min_bounds, dim=-1)
        not_above = torch.all(film_t <= params.max_bounds, dim=-1)
        return not_below & not_above

    def get_indices_within_domain(self, params: UniformSubPriorParams) -> Tensor:
        """Same as :meth:`get_indices_within_bounds`."""
        return self.get_indices_within_bounds(params)

    def clamp_params(self, params: UniformSubPriorParams) -> UniformSubPriorParams:
        """Return new params with the film parameters clamped to their bounds."""
        clamped = torch.clamp(
            params.as_tensor(add_bounds=False), params.min_bounds, params.max_bounds
        )
        flat = torch.cat([clamped, params.min_bounds, params.max_bounds], dim=1)
        return UniformSubPriorParams.from_tensor(flat)

    def sample(self, batch_size: int) -> UniformSubPriorParams:
        """Sample bounds first, then parameters uniformly inside those bounds."""
        min_bounds, max_bounds = self.sample_bounds(batch_size)

        unit_noise = torch.rand(*min_bounds.shape, device=self.device, dtype=self.dtype)
        film_t = unit_noise * (max_bounds - min_bounds) + min_bounds

        n_layers = self.max_num_layers
        thicknesses, roughnesses, slds = torch.split(
            film_t, [n_layers, n_layers + 1, n_layers + 1], dim=-1
        )

        return UniformSubPriorParams(thicknesses, roughnesses, slds, min_bounds, max_bounds)

    def sample_bounds(self, batch_size: int):
        """Sample ``(min_bounds, max_bounds)`` for ``batch_size`` samples.

        Widths are drawn as a fraction in ``[relative_min_bound_width, 1]`` of
        the full parameter range; centers are then drawn so that each interval
        stays within the full range.
        """
        min_vector = super().min_vector(self.num_layers)[None]
        max_vector = super().max_vector(self.num_layers)[None]
        delta_vector = max_vector - min_vector

        widths_sampler_func = logdist_sampler if self.logdist else uniform_sampler

        relative_widths = widths_sampler_func(
            self.relative_min_bound_width, 1.,
            batch_size, delta_vector.shape[1],
            device=self.device, dtype=self.dtype
        )
        prior_widths = relative_widths * delta_vector
        half_widths = prior_widths / 2

        prior_centers = uniform_sampler(
            min_vector + half_widths, max_vector - half_widths,
            *prior_widths.shape,
            device=self.device, dtype=self.dtype
        )

        if self.smaller_roughnesses:
            # Roughness columns follow the L thickness columns.
            rough = slice(self.num_layers, self.num_layers * 2 + 1)
            prior_centers[:, rough] = triangular_sampler(
                min_vector[:, rough] + half_widths[:, rough],
                max_vector[:, rough] - half_widths[:, rough],
                batch_size, self.num_layers + 1,
                device=self.device, dtype=self.dtype
            )

        return prior_centers - half_widths, prior_centers + half_widths

    def scale_bounds(self, bounds: Tensor) -> Tensor:
        """Scale ``bounds`` with the (tripled) min/max vectors of this sampler."""
        # NOTE(review): layers_num is derived as half of the last dimension, and
        # the vectors used here are the tripled ones - confirm this matches the
        # shape callers actually pass in.
        layers_num = bounds.shape[-1] // 2
        lower = self.min_vector(layers_num, drho=self.use_drho).to(bounds)
        upper = self.max_vector(layers_num, drho=self.use_drho).to(bounds)
        return self._scale(bounds, lower, upper)
326
+
327
+
328
class NarrowSldUniformSubPriorSampler(UniformSubPriorSampler):
    """Prior sampler for thicknesses, roughnesses, slds and their subprior bounds. The subprior bound widths for SLDs are restricted to be lower than a specified value.

    Args:
        thickness_range (Tuple[float, float], optional): the range of the layer thicknesses. Defaults to DEFAULT_THICKNESS_RANGE.
        roughness_range (Tuple[float, float], optional): the range of the interlayer roughnesses. Defaults to DEFAULT_ROUGHNESS_RANGE.
        sld_range (Tuple[float, float], optional): the range of the layer SLDs. Defaults to DEFAULT_SLD_RANGE.
        num_layers (int, optional): the number of layers. Defaults to DEFAULT_NUM_LAYERS.
        use_drho (bool, optional): whether to use differences in SLD values between neighboring layers instead of the actual SLD values. Defaults to DEFAULT_USE_DRHO.
        device (torch.device, optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
        dtype (torch.dtype, optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
        scaled_range (Tuple[float, float], optional): the range for scaling the parameters. Defaults to DEFAULT_SCALED_RANGE.
        scale_by_subpriors (bool, optional): if True the film parameters are scaled with respect to their subprior bounds. Defaults to False.
        max_sld_prior_width (float, optional): the maximum width of the subprior intervals of the SLDs. Defaults to 10.
    """
    def __init__(self,
                 thickness_range: Tuple[float, float] = DEFAULT_THICKNESS_RANGE,
                 roughness_range: Tuple[float, float] = DEFAULT_ROUGHNESS_RANGE,
                 sld_range: Tuple[float, float] = DEFAULT_SLD_RANGE,
                 num_layers: int = DEFAULT_NUM_LAYERS,
                 use_drho: bool = DEFAULT_USE_DRHO,
                 device: torch.device = DEFAULT_DEVICE,
                 dtype: torch.dtype = DEFAULT_DTYPE,
                 scaled_range: Tuple[float, float] = DEFAULT_SCALED_RANGE,
                 scale_by_subpriors: bool = False,
                 max_sld_prior_width: float = 10.,
                 ):
        super().__init__(
            thickness_range,
            roughness_range,
            sld_range,
            num_layers,
            use_drho,
            device,
            dtype,
            scaled_range,
            scale_by_subpriors,
        )

        self.max_sld_prior_width = max_sld_prior_width

    def sample_bounds(self, batch_size: int):
        """Sample ``(min_bounds, max_bounds)`` with the SLD interval widths capped.

        Widths are drawn uniformly between ``relative_min_bound_width`` times the
        maximum width and the maximum width itself, where the maximum width is
        the full parameter range for thicknesses and roughnesses and
        ``max_sld_prior_width`` for the SLDs.
        """
        # The base-class vectors are used directly (not the tripled versions
        # defined by UniformSubPriorSampler).
        min_vector, max_vector = (
            BasicPriorSampler.min_vector(self, self.num_layers),
            BasicPriorSampler.max_vector(self, self.num_layers),
        )

        delta_vector = max_vector - min_vector
        # The parameter layout is [L thicknesses, L + 1 roughnesses, L + 1 slds],
        # so the SLD entries are the trailing num_layers + 1 slots. Slicing with
        # -num_layers would miss the first SLD and, for num_layers == 0, would
        # overwrite the whole vector.
        delta_vector[-(self.num_layers + 1):] = self.max_sld_prior_width

        prior_widths = uniform_sampler(
            delta_vector * self.relative_min_bound_width, delta_vector,
            batch_size, min_vector.shape[0],
            device=self.device, dtype=self.dtype
        )

        prior_centers = uniform_sampler(
            min_vector + prior_widths / 2, max_vector - prior_widths / 2,
            *prior_widths.shape,
            device=self.device, dtype=self.dtype
        )

        return prior_centers - prior_widths / 2, prior_centers + prior_widths / 2
@@ -0,0 +1,124 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Tuple
8
+
9
+ import torch
10
+ from torch import Tensor
11
+
12
+ from reflectorch.data_generation.utils import (
13
+ get_d_rhos,
14
+ uniform_sampler,
15
+ )
16
+
17
+
18
def get_max_allowed_roughness(thicknesses: Tensor, mask: Tensor = None, coef: float = 0.5):
    """Return the largest roughness allowed at every interface.

    Each interface roughness is limited to ``coef`` times the thickness of the
    layers meeting at that interface; the first and last interface have a
    single neighboring layer. Layers selected through ``mask`` impose no limit.

    Args:
        thicknesses (Tensor): layer thicknesses of shape (batch_size, layers_num).
        mask (Tensor, optional): SLD mask; its thickness part marks layers to ignore.
        coef (float): fraction of the layer thickness a roughness may reach.

    Returns:
        Tensor of shape (batch_size, layers_num + 1).
    """
    inf = float('inf')
    batch_size, layers_num = thicknesses.shape

    limit = thicknesses * coef
    if mask is not None:
        limit[get_thickness_mask_from_sld_mask(mask)] = inf

    max_roughness = torch.ones(
        batch_size, layers_num + 1, device=thicknesses.device, dtype=thicknesses.dtype
    ) * inf

    # Limit from the layer on one side of each interface ...
    max_roughness[:, :-1] = limit
    # ... combined with the limit from the layer on the other side.
    max_roughness[:, 1:] = torch.minimum(max_roughness[:, 1:], limit)
    return max_roughness
32
+
33
+
34
def get_allowed_contrast_indices(slds: Tensor, min_contrast: float, mask: Tensor = None) -> Tensor:
    """Flag the samples whose SLD contrasts are all at least ``min_contrast``.

    The contrasts are the absolute values returned by ``get_d_rhos(slds)``.
    Entries selected by ``mask`` always count as allowed.

    Returns:
        Boolean tensor reduced over the last dimension.
    """
    enough_contrast = get_d_rhos(slds).abs() >= min_contrast
    if mask is not None:
        enough_contrast = enough_contrast | mask
    return torch.all(enough_contrast, -1)
41
+
42
+
43
def params_within_bounds(params_t: Tensor, min_t: Tensor, max_t: Tensor, mask: Tensor = None) -> Tensor:
    """Flag the samples whose parameters all lie in ``[min_t, max_t]``.

    ``min_t`` and ``max_t`` get a leading dimension added so they broadcast
    against ``params_t``. Entries selected by ``mask`` always count as valid.

    Returns:
        Boolean tensor reduced over the last dimension.
    """
    lower, upper = min_t[None], max_t[None]
    inside = (params_t >= lower) & (params_t <= upper)
    if mask is None:
        return torch.all(inside, -1)
    return torch.all(inside | mask, -1)
49
+
50
+
51
def get_allowed_roughness_indices(thicknesses: Tensor, roughnesses: Tensor, mask: Tensor = None) -> Tensor:
    """Flag the samples whose roughnesses all respect the thickness-derived limit.

    The limit comes from ``get_max_allowed_roughness(thicknesses, mask)``.
    Entries selected by ``mask`` always count as allowed.

    Returns:
        Boolean tensor reduced over the last dimension.
    """
    small_enough = roughnesses <= get_max_allowed_roughness(thicknesses, mask)
    if mask is not None:
        small_enough = small_enough | mask
    return torch.all(small_enough, -1)
58
+
59
+
60
def get_thickness_mask_from_sld_mask(mask: Tensor):
    """Return the SLD mask without its last column, i.e. one column per thickness."""
    thickness_mask = mask[:, :-1]
    return thickness_mask
62
+
63
+
64
def generate_roughnesses(thicknesses: Tensor, roughness_range: Tuple[float, float], mask: Tensor = None):
    """Sample interface roughnesses compatible with the given thicknesses.

    The upper sampling limit of each interface is the thickness-derived limit
    from ``get_max_allowed_roughness`` capped at ``roughness_range[1]``; the
    lower limit is ``roughness_range[0]``. Masked entries are set to zero.

    Returns:
        Tensor of shape (batch_size, layers_num + 1).
    """
    batch_size, layers_num = thicknesses.shape

    upper_limit = get_max_allowed_roughness(thicknesses, mask)
    upper_limit = torch.clamp_(upper_limit, max=roughness_range[1])

    sampled = uniform_sampler(
        roughness_range[0], upper_limit, batch_size, layers_num + 1,
        device=thicknesses.device, dtype=thicknesses.dtype
    )

    if mask is None:
        return sampled

    sampled[mask] = 0.
    return sampled
78
+
79
+
80
def generate_thicknesses(
        thickness_range: Tuple[float, float],
        batch_size: int,
        layers_num: int,
        device: torch.device,
        dtype: torch.dtype,
        mask: Tensor = None
):
    """Sample layer thicknesses with ``uniform_sampler`` over ``thickness_range``.

    The request is for ``batch_size`` x ``layers_num`` values. When an SLD
    ``mask`` is given, the thicknesses selected by its thickness part are set
    to zero.
    """
    sampled = uniform_sampler(
        *thickness_range, batch_size, layers_num, device=device, dtype=dtype
    )
    if mask is None:
        return sampled

    sampled[get_thickness_mask_from_sld_mask(mask)] = 0.
    return sampled
94
+
95
+
96
def generate_slds_with_min_contrast(
        sld_range: Tuple[float, float],
        batch_size: int,
        layers_num: int,
        min_contrast: float,
        device: torch.device,
        dtype: torch.dtype,
        mask: Tensor = None,
        *,
        _depth: int = 0
):
    """Sample SLDs by rejection sampling until every sample satisfies the minimum contrast.

    A batch of ``layers_num + 1`` SLDs per sample is requested from
    ``uniform_sampler``; masked entries are zeroed. Samples rejected by
    ``get_allowed_contrast_indices`` are re-drawn through a recursive call.

    Args:
        sld_range (Tuple[float, float]): range the SLDs are sampled from.
        batch_size (int): number of samples.
        layers_num (int): number of layers.
        min_contrast (float): minimum contrast required by the acceptance test.
        device (torch.device): device of the sampled tensor.
        dtype (torch.dtype): data type of the sampled tensor.
        mask (Tensor, optional): boolean mask of entries to zero and exempt.
        _depth (int): current recursion depth (internal).
    """
    # NOTE(review): _depth is passed along but never checked, so the recursion
    # is unbounded when the contrast requirement is rarely or never satisfied.
    candidates = uniform_sampler(
        *sld_range, batch_size, layers_num + 1, device=device, dtype=dtype
    )

    if mask is not None:
        candidates[mask] = 0.

    rejected = ~get_allowed_contrast_indices(candidates, min_contrast, mask)
    num_rejected = rejected.sum(0).item()

    if not num_rejected:
        return candidates

    # Re-draw only the rejected rows, carrying along their part of the mask.
    rejected_mask = mask[rejected] if mask is not None else None
    candidates[rejected] = generate_slds_with_min_contrast(
        sld_range, num_rejected, layers_num, min_contrast, device, dtype, rejected_mask, _depth=_depth + 1
    )
    return candidates
@@ -0,0 +1,47 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ #
4
+ # This source code is licensed under the GPL license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Any
8
+
9
# Public names of this module.
__all__ = [
    "ProcessData",
    "ProcessPipeline",
]
13
+
14
+
15
class ProcessData(object):
    """Base class of a composable processing step.

    The default step returns its input unchanged. Steps are chained with
    ``+``, which produces a :class:`ProcessPipeline`.
    """

    def __add__(self, other):
        """Chain this step with ``other`` into a :class:`ProcessPipeline`."""
        if isinstance(other, ProcessData):
            return ProcessPipeline(self, other)
        # Signal an unsupported operand so Python raises TypeError instead of
        # the expression silently evaluating to None.
        return NotImplemented

    def apply(self, args: Any, context: dict = None):
        """Process ``args``; the base implementation is the identity."""
        return args

    def __call__(self, args: Any, context: dict = None):
        """Shortcut for :meth:`apply`."""
        return self.apply(args, context)

    def __repr__(self):
        return f'{self.__class__.__name__}()'


class ProcessPipeline(ProcessData):
    """Sequence of processing steps applied one after another."""

    def __init__(self, *processes):
        self._processes = list(processes)

    def apply(self, args: Any, context: dict = None):
        """Feed ``args`` through every step in order, passing ``context`` to each."""
        for process in self._processes:
            args = process(args, context)
        return args

    def __add__(self, other):
        """Return a new pipeline extended by ``other``.

        Another pipeline contributes its individual steps; a single step is
        appended as is.
        """
        if isinstance(other, ProcessPipeline):
            return ProcessPipeline(*self._processes, *other._processes)
        elif isinstance(other, ProcessData):
            return ProcessPipeline(*self._processes, other)
        # Unsupported operand: let Python raise TypeError rather than return None.
        return NotImplemented

    def __repr__(self):
        processes = ", ".join(repr(p) for p in self._processes)
        return f'ProcessPipeline({processes})'