reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,298 +1,298 @@
1
- from typing import Tuple, Union, List
2
-
3
- import torch
4
- from torch import Tensor
5
-
6
- from reflectorch.data_generation.utils import (
7
- uniform_sampler,
8
- logdist_sampler,
9
- )
10
-
11
- from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
12
- from reflectorch.data_generation.priors.base import PriorSampler
13
- from reflectorch.data_generation.priors.subprior_sampler import UniformSubPriorParams
14
- from reflectorch.data_generation.priors.params import Params
15
- from reflectorch.data_generation.priors.utils import get_max_allowed_roughness
16
- from reflectorch.data_generation.priors.no_constraints import (
17
- DEFAULT_DEVICE,
18
- DEFAULT_DTYPE,
19
- )
20
-
21
-
22
- class ExpUniformSubPriorSampler(PriorSampler, ScalerMixin):
23
- PARAM_CLS = UniformSubPriorParams
24
-
25
- def __init__(self,
26
- params: List[Union[float, Tuple[float, float], Tuple[float, float, float, float]]],
27
- device: torch.device = DEFAULT_DEVICE,
28
- dtype: torch.dtype = DEFAULT_DTYPE,
29
- scaled_range: Tuple[float, float] = (-1, 1),
30
- logdist: bool = False,
31
- relative_min_bound_width: float = 1e-4,
32
- smaller_roughnesses: bool = True,
33
- ):
34
- self.device = device
35
- self.dtype = dtype
36
- self.scaled_range = scaled_range
37
- self.relative_min_bound_width = relative_min_bound_width
38
- self.logdist = logdist
39
- self.smaller_roughnesses = smaller_roughnesses
40
- self._init_params(*params)
41
-
42
- @property
43
- def max_num_layers(self) -> int:
44
- return self.num_layers
45
-
46
- def _init_params(self, *params: Union[float, Tuple[float, float], Tuple[float, float, float, float]]):
47
- self.num_layers = (len(params) - 2) // 3
48
- self._total_num_params = len(params)
49
-
50
- fixed_mask = []
51
- bounds = []
52
- delta_bounds = []
53
- param_dim = 0
54
-
55
- for param in params:
56
- if isinstance(param, (float, int)):
57
- deltas = (0, 0)
58
- param = (param, param)
59
- fixed_mask.append(True)
60
- else:
61
- param_dim += 1
62
- fixed_mask.append(False)
63
-
64
- if len(param) == 4:
65
- param, deltas = param[:2], param[2:]
66
- else:
67
- max_delta = param[1] - param[0]
68
- deltas = (max_delta * self.relative_min_bound_width, max_delta)
69
-
70
- bounds.append(param)
71
- delta_bounds.append(deltas)
72
-
73
- self.fixed_mask = torch.tensor(fixed_mask).to(self.device)
74
- self.fitted_mask = ~self.fixed_mask
75
- self.roughnesses_mask = torch.zeros_like(self.fitted_mask)
76
- self.roughnesses_mask[self.num_layers: self.num_layers * 2 + 1] = True
77
-
78
- self.min_bounds, self.max_bounds = torch.tensor(bounds).to(self.device).to(self.dtype).T
79
- self.min_deltas, self.max_deltas = map(torch.atleast_2d, torch.tensor(delta_bounds).to(self.min_bounds).T)
80
- self._param_dim = param_dim
81
- self._num_fixed = self.fixed_mask.sum()
82
-
83
- self.fixed_params = self.min_bounds[self.fixed_mask]
84
-
85
- @property
86
- def param_dim(self) -> int:
87
- return self._param_dim
88
-
89
- def sample(self, batch_size: int) -> UniformSubPriorParams:
90
- min_bounds, max_bounds = self.sample_bounds(batch_size)
91
-
92
- params = torch.rand(
93
- *min_bounds.shape,
94
- device=self.device,
95
- dtype=self.dtype
96
- ) * (max_bounds - min_bounds) + min_bounds
97
-
98
- thicknesses, roughnesses, slds = torch.split(
99
- params, [self.max_num_layers, self.max_num_layers + 1, self.max_num_layers + 1], dim=-1
100
- )
101
-
102
- if self.smaller_roughnesses:
103
- fitted_r_mask = self.fitted_mask.clone()
104
- fitted_r_mask[~self.roughnesses_mask] = False
105
-
106
- min_roughness = self.min_bounds[fitted_r_mask]
107
- max_roughness = torch.clamp(
108
- get_max_allowed_roughness(thicknesses)[..., self.fitted_mask[self.roughnesses_mask]],
109
- min_roughness,
110
- self.max_bounds[fitted_r_mask]
111
- )
112
-
113
- min_vector = self.min_bounds.clone()[None].repeat(batch_size, 1)
114
- max_vector = self.max_bounds.clone()[None].repeat(batch_size, 1)
115
-
116
- max_vector[..., fitted_r_mask] = max_roughness
117
-
118
- assert torch.all(max_vector[..., fitted_r_mask] == max_roughness)
119
-
120
- min_deltas = self.min_deltas.clone().repeat(batch_size, 1)
121
- max_deltas = self.max_deltas.clone().repeat(batch_size, 1)
122
-
123
- max_deltas[..., fitted_r_mask] = torch.clamp_max(
124
- max_deltas[..., fitted_r_mask],
125
- max_roughness - min_roughness,
126
- )
127
-
128
- fitted_mask = torch.zeros_like(self.fitted_mask)
129
- fitted_mask[fitted_r_mask] = True
130
-
131
- updated_min_bounds, updated_max_bounds = self._sample_bounds(
132
- batch_size, min_vector, max_vector, min_deltas, max_deltas, fitted_mask
133
- )
134
-
135
- min_bounds[..., fitted_mask], max_bounds[..., fitted_mask] = (
136
- updated_min_bounds[..., fitted_mask], updated_max_bounds[..., fitted_mask]
137
- )
138
-
139
- params[..., fitted_mask] = torch.rand(
140
- batch_size, fitted_mask.sum().item(),
141
- device=self.device,
142
- dtype=self.dtype
143
- ) * (max_bounds[..., fitted_mask] - min_bounds[..., fitted_mask]) + min_bounds[..., fitted_mask]
144
-
145
- thicknesses, roughnesses, slds = torch.split(
146
- params, [self.max_num_layers, self.max_num_layers + 1, self.max_num_layers + 1], dim=-1
147
- )
148
-
149
- params = UniformSubPriorParams(thicknesses, roughnesses, slds, min_bounds, max_bounds)
150
-
151
- return params
152
-
153
- def scale_params(self, params: UniformSubPriorParams) -> Tensor:
154
- params_t = params.as_tensor(add_bounds=False)
155
-
156
- scaled_params = self._scale(params_t, params.min_bounds, params.max_bounds)[..., self.fitted_mask]
157
-
158
- scaled_min_bounds = self._scale(params.min_bounds, self.min_bounds, self.max_bounds)[..., self.fitted_mask]
159
-
160
- scaled_max_bounds = self._scale(params.max_bounds, self.min_bounds, self.max_bounds)[..., self.fitted_mask]
161
-
162
- scaled_params = torch.cat([scaled_params, scaled_min_bounds, scaled_max_bounds], -1)
163
-
164
- return scaled_params
165
-
166
- def scale_bounds(self, bounds: Tensor) -> Tensor:
167
- return self._scale(bounds, self.min_bounds, self.max_bounds)[..., self.fitted_mask]
168
-
169
- def restore_params(self, scaled_params: Tensor) -> UniformSubPriorParams:
170
- scaled_params, scaled_min_bounds, scaled_max_bounds = torch.split(
171
- scaled_params, [self.param_dim, self.param_dim, self.param_dim], dim=1
172
- )
173
-
174
- min_bounds = self._restore(
175
- scaled_min_bounds, self.min_bounds[self.fitted_mask], self.max_bounds[self.fitted_mask]
176
- )
177
- max_bounds = self._restore(
178
- scaled_max_bounds, self.min_bounds[self.fitted_mask], self.max_bounds[self.fitted_mask]
179
- )
180
-
181
- restored_params = self._restore(scaled_params, min_bounds, max_bounds)
182
-
183
- params_t = torch.cat(
184
- [
185
- self._cat_restored_with_fixed_vector(restored_params),
186
- self._cat_restored_with_fixed_vector(min_bounds),
187
- self._cat_restored_with_fixed_vector(max_bounds),
188
- ], -1
189
- )
190
-
191
- params = UniformSubPriorParams.from_tensor(params_t)
192
-
193
- return params
194
-
195
- def _cat_restored_with_fixed_vector(self, restored_t: Tensor) -> Tensor:
196
- return self._cat_fitted_fixed_t(restored_t, self.fixed_params)
197
-
198
- def _cat_fitted_fixed_t(self, fitted_t: Tensor, fixed_t: Tensor, fitted_mask: Tensor = None) -> Tensor:
199
- if fitted_mask is None:
200
- fitted_mask = self.fitted_mask
201
- fixed_mask = self.fixed_mask
202
- else:
203
- fixed_mask = ~fitted_mask
204
-
205
- total_num_params = self.fitted_mask.sum().item() + self.fixed_mask.sum().item()
206
-
207
- batch_size = fitted_t.shape[0]
208
-
209
- concat_t = torch.empty(
210
- batch_size, total_num_params, device=fitted_t.device, dtype=fitted_t.dtype
211
- )
212
- concat_t[:, fitted_mask] = fitted_t
213
- concat_t[:, fixed_mask] = fixed_t[None].expand(batch_size, -1)
214
-
215
- return concat_t
216
-
217
- def log_prob(self, params: UniformSubPriorParams) -> Tensor:
218
- log_prob = torch.zeros(params.batch_size, device=params.device, dtype=params.dtype)
219
- indices = self.get_indices_within_bounds(params)
220
- log_prob[~indices] = float('-inf')
221
- return log_prob
222
-
223
- def get_indices_within_bounds(self, params: UniformSubPriorParams) -> Tensor:
224
- t_params = torch.cat([
225
- params.thicknesses,
226
- params.roughnesses,
227
- params.slds
228
- ], dim=-1)
229
-
230
- indices = (
231
- torch.all(t_params >= params.min_bounds, dim=-1) &
232
- torch.all(t_params <= params.max_bounds, dim=-1)
233
- )
234
-
235
- return indices
236
-
237
- def clamp_params(self, params: UniformSubPriorParams) -> UniformSubPriorParams:
238
- params = UniformSubPriorParams.from_tensor(
239
- torch.cat([
240
- torch.clamp(
241
- params.as_tensor(add_bounds=False),
242
- params.min_bounds, params.max_bounds
243
- ),
244
- params.min_bounds, params.max_bounds
245
- ], dim=1)
246
- )
247
- return params
248
-
249
- def get_indices_within_domain(self, params: UniformSubPriorParams) -> Tensor:
250
- return self.get_indices_within_bounds(params)
251
-
252
- def sample_bounds(self, batch_size: int):
253
- return self._sample_bounds(
254
- batch_size,
255
- self.min_bounds,
256
- self.max_bounds,
257
- self.min_deltas,
258
- self.max_deltas,
259
- self.fitted_mask,
260
- )
261
-
262
- def _sample_bounds(self, batch_size, min_vector, max_vector, min_deltas, max_deltas, fitted_mask):
263
- if self.logdist:
264
- widths_sampler_func = logdist_sampler
265
- else:
266
- widths_sampler_func = uniform_sampler
267
-
268
- num_fitted = fitted_mask.sum().item()
269
- num_fixed = fitted_mask.numel() - num_fitted
270
-
271
- prior_widths = widths_sampler_func(
272
- min_deltas[..., fitted_mask], max_deltas[..., fitted_mask],
273
- batch_size, num_fitted,
274
- device=self.device, dtype=self.dtype
275
- )
276
-
277
- prior_widths = self._cat_fitted_fixed_t(prior_widths, torch.zeros(num_fixed).to(prior_widths), fitted_mask)
278
-
279
- prior_centers = uniform_sampler(
280
- min_vector + prior_widths / 2, max_vector - prior_widths / 2,
281
- *prior_widths.shape,
282
- device=self.device, dtype=self.dtype
283
- )
284
-
285
- min_bounds, max_bounds = prior_centers - prior_widths / 2, prior_centers + prior_widths / 2
286
-
287
- return min_bounds, max_bounds
288
-
289
- @staticmethod
290
- def scale_bounds_with_q(bounds: Tensor, q_ratio: float) -> Tensor:
291
- params = Params.from_tensor(torch.atleast_2d(bounds).clone())
292
- params.scale_with_q(q_ratio)
293
- return params.as_tensor().squeeze()
294
-
295
- def clamp_bounds(self, bounds: Tensor) -> Tensor:
296
- return torch.clamp(
297
- torch.atleast_2d(bounds), torch.atleast_2d(self.min_bounds), torch.atleast_2d(self.max_bounds)
298
- )
1
+ from typing import Tuple, Union, List
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from reflectorch.data_generation.utils import (
7
+ uniform_sampler,
8
+ logdist_sampler,
9
+ )
10
+
11
+ from reflectorch.data_generation.priors.scaler_mixin import ScalerMixin
12
+ from reflectorch.data_generation.priors.base import PriorSampler
13
+ from reflectorch.data_generation.priors.subprior_sampler import UniformSubPriorParams
14
+ from reflectorch.data_generation.priors.params import Params
15
+ from reflectorch.data_generation.priors.utils import get_max_allowed_roughness
16
+ from reflectorch.data_generation.priors.no_constraints import (
17
+ DEFAULT_DEVICE,
18
+ DEFAULT_DTYPE,
19
+ )
20
+
21
+
22
+ class ExpUniformSubPriorSampler(PriorSampler, ScalerMixin):
23
+ PARAM_CLS = UniformSubPriorParams
24
+
25
+ def __init__(self,
26
+ params: List[Union[float, Tuple[float, float], Tuple[float, float, float, float]]],
27
+ device: torch.device = DEFAULT_DEVICE,
28
+ dtype: torch.dtype = DEFAULT_DTYPE,
29
+ scaled_range: Tuple[float, float] = (-1, 1),
30
+ logdist: bool = False,
31
+ relative_min_bound_width: float = 1e-4,
32
+ smaller_roughnesses: bool = True,
33
+ ):
34
+ self.device = device
35
+ self.dtype = dtype
36
+ self.scaled_range = scaled_range
37
+ self.relative_min_bound_width = relative_min_bound_width
38
+ self.logdist = logdist
39
+ self.smaller_roughnesses = smaller_roughnesses
40
+ self._init_params(*params)
41
+
42
+ @property
43
+ def max_num_layers(self) -> int:
44
+ return self.num_layers
45
+
46
+ def _init_params(self, *params: Union[float, Tuple[float, float], Tuple[float, float, float, float]]):
47
+ self.num_layers = (len(params) - 2) // 3
48
+ self._total_num_params = len(params)
49
+
50
+ fixed_mask = []
51
+ bounds = []
52
+ delta_bounds = []
53
+ param_dim = 0
54
+
55
+ for param in params:
56
+ if isinstance(param, (float, int)):
57
+ deltas = (0, 0)
58
+ param = (param, param)
59
+ fixed_mask.append(True)
60
+ else:
61
+ param_dim += 1
62
+ fixed_mask.append(False)
63
+
64
+ if len(param) == 4:
65
+ param, deltas = param[:2], param[2:]
66
+ else:
67
+ max_delta = param[1] - param[0]
68
+ deltas = (max_delta * self.relative_min_bound_width, max_delta)
69
+
70
+ bounds.append(param)
71
+ delta_bounds.append(deltas)
72
+
73
+ self.fixed_mask = torch.tensor(fixed_mask).to(self.device)
74
+ self.fitted_mask = ~self.fixed_mask
75
+ self.roughnesses_mask = torch.zeros_like(self.fitted_mask)
76
+ self.roughnesses_mask[self.num_layers: self.num_layers * 2 + 1] = True
77
+
78
+ self.min_bounds, self.max_bounds = torch.tensor(bounds).to(self.device).to(self.dtype).T
79
+ self.min_deltas, self.max_deltas = map(torch.atleast_2d, torch.tensor(delta_bounds).to(self.min_bounds).T)
80
+ self._param_dim = param_dim
81
+ self._num_fixed = self.fixed_mask.sum()
82
+
83
+ self.fixed_params = self.min_bounds[self.fixed_mask]
84
+
85
+ @property
86
+ def param_dim(self) -> int:
87
+ return self._param_dim
88
+
89
+ def sample(self, batch_size: int) -> UniformSubPriorParams:
90
+ min_bounds, max_bounds = self.sample_bounds(batch_size)
91
+
92
+ params = torch.rand(
93
+ *min_bounds.shape,
94
+ device=self.device,
95
+ dtype=self.dtype
96
+ ) * (max_bounds - min_bounds) + min_bounds
97
+
98
+ thicknesses, roughnesses, slds = torch.split(
99
+ params, [self.max_num_layers, self.max_num_layers + 1, self.max_num_layers + 1], dim=-1
100
+ )
101
+
102
+ if self.smaller_roughnesses:
103
+ fitted_r_mask = self.fitted_mask.clone()
104
+ fitted_r_mask[~self.roughnesses_mask] = False
105
+
106
+ min_roughness = self.min_bounds[fitted_r_mask]
107
+ max_roughness = torch.clamp(
108
+ get_max_allowed_roughness(thicknesses)[..., self.fitted_mask[self.roughnesses_mask]],
109
+ min_roughness,
110
+ self.max_bounds[fitted_r_mask]
111
+ )
112
+
113
+ min_vector = self.min_bounds.clone()[None].repeat(batch_size, 1)
114
+ max_vector = self.max_bounds.clone()[None].repeat(batch_size, 1)
115
+
116
+ max_vector[..., fitted_r_mask] = max_roughness
117
+
118
+ assert torch.all(max_vector[..., fitted_r_mask] == max_roughness)
119
+
120
+ min_deltas = self.min_deltas.clone().repeat(batch_size, 1)
121
+ max_deltas = self.max_deltas.clone().repeat(batch_size, 1)
122
+
123
+ max_deltas[..., fitted_r_mask] = torch.clamp_max(
124
+ max_deltas[..., fitted_r_mask],
125
+ max_roughness - min_roughness,
126
+ )
127
+
128
+ fitted_mask = torch.zeros_like(self.fitted_mask)
129
+ fitted_mask[fitted_r_mask] = True
130
+
131
+ updated_min_bounds, updated_max_bounds = self._sample_bounds(
132
+ batch_size, min_vector, max_vector, min_deltas, max_deltas, fitted_mask
133
+ )
134
+
135
+ min_bounds[..., fitted_mask], max_bounds[..., fitted_mask] = (
136
+ updated_min_bounds[..., fitted_mask], updated_max_bounds[..., fitted_mask]
137
+ )
138
+
139
+ params[..., fitted_mask] = torch.rand(
140
+ batch_size, fitted_mask.sum().item(),
141
+ device=self.device,
142
+ dtype=self.dtype
143
+ ) * (max_bounds[..., fitted_mask] - min_bounds[..., fitted_mask]) + min_bounds[..., fitted_mask]
144
+
145
+ thicknesses, roughnesses, slds = torch.split(
146
+ params, [self.max_num_layers, self.max_num_layers + 1, self.max_num_layers + 1], dim=-1
147
+ )
148
+
149
+ params = UniformSubPriorParams(thicknesses, roughnesses, slds, min_bounds, max_bounds)
150
+
151
+ return params
152
+
153
+ def scale_params(self, params: UniformSubPriorParams) -> Tensor:
154
+ params_t = params.as_tensor(add_bounds=False)
155
+
156
+ scaled_params = self._scale(params_t, params.min_bounds, params.max_bounds)[..., self.fitted_mask]
157
+
158
+ scaled_min_bounds = self._scale(params.min_bounds, self.min_bounds, self.max_bounds)[..., self.fitted_mask]
159
+
160
+ scaled_max_bounds = self._scale(params.max_bounds, self.min_bounds, self.max_bounds)[..., self.fitted_mask]
161
+
162
+ scaled_params = torch.cat([scaled_params, scaled_min_bounds, scaled_max_bounds], -1)
163
+
164
+ return scaled_params
165
+
166
+ def scale_bounds(self, bounds: Tensor) -> Tensor:
167
+ return self._scale(bounds, self.min_bounds, self.max_bounds)[..., self.fitted_mask]
168
+
169
+ def restore_params(self, scaled_params: Tensor) -> UniformSubPriorParams:
170
+ scaled_params, scaled_min_bounds, scaled_max_bounds = torch.split(
171
+ scaled_params, [self.param_dim, self.param_dim, self.param_dim], dim=1
172
+ )
173
+
174
+ min_bounds = self._restore(
175
+ scaled_min_bounds, self.min_bounds[self.fitted_mask], self.max_bounds[self.fitted_mask]
176
+ )
177
+ max_bounds = self._restore(
178
+ scaled_max_bounds, self.min_bounds[self.fitted_mask], self.max_bounds[self.fitted_mask]
179
+ )
180
+
181
+ restored_params = self._restore(scaled_params, min_bounds, max_bounds)
182
+
183
+ params_t = torch.cat(
184
+ [
185
+ self._cat_restored_with_fixed_vector(restored_params),
186
+ self._cat_restored_with_fixed_vector(min_bounds),
187
+ self._cat_restored_with_fixed_vector(max_bounds),
188
+ ], -1
189
+ )
190
+
191
+ params = UniformSubPriorParams.from_tensor(params_t)
192
+
193
+ return params
194
+
195
+ def _cat_restored_with_fixed_vector(self, restored_t: Tensor) -> Tensor:
196
+ return self._cat_fitted_fixed_t(restored_t, self.fixed_params)
197
+
198
+ def _cat_fitted_fixed_t(self, fitted_t: Tensor, fixed_t: Tensor, fitted_mask: Tensor = None) -> Tensor:
199
+ if fitted_mask is None:
200
+ fitted_mask = self.fitted_mask
201
+ fixed_mask = self.fixed_mask
202
+ else:
203
+ fixed_mask = ~fitted_mask
204
+
205
+ total_num_params = self.fitted_mask.sum().item() + self.fixed_mask.sum().item()
206
+
207
+ batch_size = fitted_t.shape[0]
208
+
209
+ concat_t = torch.empty(
210
+ batch_size, total_num_params, device=fitted_t.device, dtype=fitted_t.dtype
211
+ )
212
+ concat_t[:, fitted_mask] = fitted_t
213
+ concat_t[:, fixed_mask] = fixed_t[None].expand(batch_size, -1)
214
+
215
+ return concat_t
216
+
217
+ def log_prob(self, params: UniformSubPriorParams) -> Tensor:
218
+ log_prob = torch.zeros(params.batch_size, device=params.device, dtype=params.dtype)
219
+ indices = self.get_indices_within_bounds(params)
220
+ log_prob[~indices] = float('-inf')
221
+ return log_prob
222
+
223
+ def get_indices_within_bounds(self, params: UniformSubPriorParams) -> Tensor:
224
+ t_params = torch.cat([
225
+ params.thicknesses,
226
+ params.roughnesses,
227
+ params.slds
228
+ ], dim=-1)
229
+
230
+ indices = (
231
+ torch.all(t_params >= params.min_bounds, dim=-1) &
232
+ torch.all(t_params <= params.max_bounds, dim=-1)
233
+ )
234
+
235
+ return indices
236
+
237
+ def clamp_params(self, params: UniformSubPriorParams) -> UniformSubPriorParams:
238
+ params = UniformSubPriorParams.from_tensor(
239
+ torch.cat([
240
+ torch.clamp(
241
+ params.as_tensor(add_bounds=False),
242
+ params.min_bounds, params.max_bounds
243
+ ),
244
+ params.min_bounds, params.max_bounds
245
+ ], dim=1)
246
+ )
247
+ return params
248
+
249
+ def get_indices_within_domain(self, params: UniformSubPriorParams) -> Tensor:
250
+ return self.get_indices_within_bounds(params)
251
+
252
+ def sample_bounds(self, batch_size: int):
253
+ return self._sample_bounds(
254
+ batch_size,
255
+ self.min_bounds,
256
+ self.max_bounds,
257
+ self.min_deltas,
258
+ self.max_deltas,
259
+ self.fitted_mask,
260
+ )
261
+
262
+ def _sample_bounds(self, batch_size, min_vector, max_vector, min_deltas, max_deltas, fitted_mask):
263
+ if self.logdist:
264
+ widths_sampler_func = logdist_sampler
265
+ else:
266
+ widths_sampler_func = uniform_sampler
267
+
268
+ num_fitted = fitted_mask.sum().item()
269
+ num_fixed = fitted_mask.numel() - num_fitted
270
+
271
+ prior_widths = widths_sampler_func(
272
+ min_deltas[..., fitted_mask], max_deltas[..., fitted_mask],
273
+ batch_size, num_fitted,
274
+ device=self.device, dtype=self.dtype
275
+ )
276
+
277
+ prior_widths = self._cat_fitted_fixed_t(prior_widths, torch.zeros(num_fixed).to(prior_widths), fitted_mask)
278
+
279
+ prior_centers = uniform_sampler(
280
+ min_vector + prior_widths / 2, max_vector - prior_widths / 2,
281
+ *prior_widths.shape,
282
+ device=self.device, dtype=self.dtype
283
+ )
284
+
285
+ min_bounds, max_bounds = prior_centers - prior_widths / 2, prior_centers + prior_widths / 2
286
+
287
+ return min_bounds, max_bounds
288
+
289
+ @staticmethod
290
+ def scale_bounds_with_q(bounds: Tensor, q_ratio: float) -> Tensor:
291
+ params = Params.from_tensor(torch.atleast_2d(bounds).clone())
292
+ params.scale_with_q(q_ratio)
293
+ return params.as_tensor().squeeze()
294
+
295
+ def clamp_bounds(self, bounds: Tensor) -> Tensor:
296
+ return torch.clamp(
297
+ torch.atleast_2d(bounds), torch.atleast_2d(self.min_bounds), torch.atleast_2d(self.max_bounds)
298
+ )