reflectorch 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,370 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ from reflectorch.data_generation.utils import (
5
+ uniform_sampler,
6
+ logdist_sampler,
7
+ )
8
+
9
+ from reflectorch.data_generation.priors.utils import get_max_allowed_roughness
10
+
11
+
12
class SamplerStrategy:
    """Common interface for strategies that draw parameters together with their subprior bounds.

    Concrete strategies override :meth:`sample`; the base class only defines the signature.
    """

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """Draw a batch of parameters and their bounds.

        Raises:
            NotImplementedError: always; subclasses must provide the implementation.
        """
        raise NotImplementedError
21
+
22
+
23
class BasicSamplerStrategy(SamplerStrategy):
    """Sampler strategy that imposes no constraints between parameter values.

    Args:
        logdist (bool, optional): when True, the widths of the subprior intervals are drawn
            uniformly on a logarithmic scale (``logdist_sampler``); otherwise they are drawn
            uniformly (``uniform_sampler``). Defaults to False.
    """

    def __init__(self, logdist: bool = False):
        # The chosen function is used later to draw the subprior interval widths.
        self.widths_sampler_func = logdist_sampler if logdist else uniform_sampler

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """
        Args:
            batch_size (int): number of samples to draw
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): ``(params, min_bounds, max_bounds)``. First the widths W of the subprior
            intervals are drawn, then their centers C, giving bounds C - W/2 and C + W/2; finally the
            parameters are drawn uniformly from [C - W/2, C + W/2].
        """
        return basic_sampler(
            batch_size,
            total_min_bounds, total_max_bounds,
            total_min_delta, total_max_delta,
            self.widths_sampler_func,
        )
60
+
61
+
62
class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
    """Sampler strategy where the roughnesses are constrained not to exceed a fraction of the two neighboring thicknesses

    Args:
        thickness_mask (Tensor): indices in the tensors which correspond to thicknesses
        roughness_mask (Tensor): indices in the tensors which correspond to roughnesses
        logdist (bool, optional): if ``True`` the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
        max_thickness_share (float, optional): fraction of the layer thickness that the roughness should not exceed. Defaults to 0.5.
        max_total_thickness (float, optional): upper limit for the sum of the thickness upper bounds of a sample.
            Samples exceeding it have their thicknesses and thickness bounds rescaled by
            ``constrained_roughness_sampler``. ``None`` disables the limit. Defaults to None.
    """
    def __init__(self,
                 thickness_mask: Tensor,
                 roughness_mask: Tensor,
                 logdist: bool = False,
                 max_thickness_share: float = 0.5,
                 max_total_thickness: float = None,
                 ):
        super().__init__(logdist=logdist)
        self.thickness_mask = thickness_mask
        self.roughness_mask = roughness_mask
        self.max_thickness_share = max_thickness_share
        self.max_total_thickness = max_total_thickness

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """
        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): samples the values of the parameters and their prior bounds *(params, min_bounds, max_bounds)*, the roughnesses being constrained. The widths **W** of the subprior interval are sampled first, then the centers **C** of the subprior interval, such that the prior bounds are **C** - **W** / 2 and **C** + **W** / 2, then the parameters are sampled from [**C** - **W** / 2, **C** + **W** / 2] )
        """
        # The masks are moved to the bounds' device so they can index them.
        device = total_min_bounds.device
        return constrained_roughness_sampler(
            batch_size,
            total_min_bounds,
            total_max_bounds,
            total_min_delta,
            total_max_delta,
            thickness_mask=self.thickness_mask.to(device),
            roughness_mask=self.roughness_mask.to(device),
            widths_sampler_func=self.widths_sampler_func,
            coef_roughness=self.max_thickness_share,
            max_total_thickness=self.max_total_thickness,
        )
114
+
115
class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
    """Sampler strategy where the roughnesses are constrained not to exceed a fraction of the two neighboring thicknesses, and the imaginary slds are constrained not to exceed a fraction of the real slds

    Args:
        thickness_mask (Tensor): indices in the tensors which correspond to thicknesses
        roughness_mask (Tensor): indices in the tensors which correspond to roughnesses
        sld_mask (Tensor): indices in the tensors which correspond to real slds
        isld_mask (Tensor): indices in the tensors which correspond to imaginary slds
        logdist (bool, optional): if ``True`` the relative widths of the subprior intervals are sampled uniformly on a logarithmic scale instead of uniformly. Defaults to False.
        max_thickness_share (float, optional): fraction of the layer thickness that the roughness should not exceed. Defaults to 0.5
        max_sld_share (float, optional): fraction of the absolute real sld that the imaginary sld should not exceed. Defaults to 0.2.
        max_total_thickness (float, optional): upper limit for the sum of the thickness upper bounds of a sample.
            Samples exceeding it have their thicknesses and thickness bounds rescaled.
            ``None`` disables the limit. Defaults to None.
    """
    def __init__(self,
                 thickness_mask: Tensor,
                 roughness_mask: Tensor,
                 sld_mask: Tensor,
                 isld_mask: Tensor,
                 logdist: bool = False,
                 max_thickness_share: float = 0.5,
                 max_sld_share: float = 0.2,
                 max_total_thickness: float = None,
                 ):
        super().__init__(logdist=logdist)
        self.thickness_mask = thickness_mask
        self.roughness_mask = roughness_mask
        self.sld_mask = sld_mask
        self.isld_mask = isld_mask
        self.max_thickness_share = max_thickness_share
        self.max_sld_share = max_sld_share
        self.max_total_thickness = max_total_thickness

    def sample(self, batch_size: int,
               total_min_bounds: Tensor,
               total_max_bounds: Tensor,
               total_min_delta: Tensor,
               total_max_delta: Tensor,
               ):
        """
        Args:
            batch_size (int): the batch size
            total_min_bounds (Tensor): minimum values of the parameters
            total_max_bounds (Tensor): maximum values of the parameters
            total_min_delta (Tensor): minimum widths of the subprior intervals
            total_max_delta (Tensor): maximum widths of the subprior intervals

        Returns:
            tuple(Tensor): samples the values of the parameters and their prior bounds *(params, min_bounds, max_bounds)*, the roughnesses and imaginary slds being constrained. The widths **W** of the subprior interval are sampled first, then the centers **C** of the subprior interval, such that the prior bounds are **C** - **W** /2 and **C** + **W** / 2, then the parameters are sampled from [**C** - **W** / 2, **C** + **W** / 2] )
        """
        # The masks are moved to the bounds' device so they can index them.
        device = total_min_bounds.device
        return constrained_roughness_and_isld_sampler(
            batch_size,
            total_min_bounds,
            total_max_bounds,
            total_min_delta,
            total_max_delta,
            thickness_mask=self.thickness_mask.to(device),
            roughness_mask=self.roughness_mask.to(device),
            sld_mask=self.sld_mask.to(device),
            isld_mask=self.isld_mask.to(device),
            widths_sampler_func=self.widths_sampler_func,
            coef_roughness=self.max_thickness_share,
            coef_isld=self.max_sld_share,
            max_total_thickness=self.max_total_thickness,
        )
179
+
180
def basic_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        widths_sampler_func,
):
    """Draw parameters together with per-sample subprior bounds, without constraints.

    Three random draws happen, in this order:
      1. subprior widths ``W`` via ``widths_sampler_func`` between ``total_min_delta`` and ``total_max_delta``;
      2. subprior centers ``C`` uniformly in ``[total_min_bounds + W/2, total_max_bounds - W/2]``;
      3. parameters uniformly in ``[C - W/2, C + W/2]``.

    Args:
        batch_size (int): number of samples to draw
        total_min_bounds (Tensor): minimum values of the parameters
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        widths_sampler_func: callable used to draw the widths; it is called as
            ``f(low, high, batch_size, num_params, device=..., dtype=...)``

    Returns:
        tuple(Tensor): ``(params, min_bounds, max_bounds)``
    """
    device, dtype = total_min_bounds.device, total_min_bounds.dtype
    # Number of parameters = size of dim 1 of the (broadcast) bound difference.
    num_params = (total_max_bounds - total_min_bounds).shape[1]

    widths = widths_sampler_func(
        total_min_delta, total_max_delta,
        batch_size, num_params,
        device=device, dtype=dtype,
    )
    half_widths = widths / 2

    # Centers are kept far enough from the edges that the whole interval fits in the total range.
    centers = uniform_sampler(
        total_min_bounds + half_widths, total_max_bounds - half_widths,
        *widths.shape,
        device=device, dtype=dtype,
    )

    lower = centers - half_widths
    upper = centers + half_widths

    unit_draw = torch.rand(*lower.shape, device=lower.device, dtype=lower.dtype)
    params = unit_draw * (upper - lower) + lower

    return params, lower, upper
212
+
213
+
214
def _rescale_to_max_total_thickness(
        params: Tensor,
        min_bounds: Tensor,
        max_bounds: Tensor,
        total_min_bounds: Tensor,
        thickness_mask: Tensor,
        max_total_thickness: float,
) -> None:
    """In place, shrink the thickness entries of samples whose summed max thickness bound exceeds the limit."""
    total_thickness = max_bounds[:, thickness_mask].sum(-1)
    exceeds = total_thickness > max_total_thickness

    if not exceeds.any():
        return

    eps = 0.01
    # Random factor in [1 - eps, 1): rescaled samples land slightly below the limit
    # (before clamping) instead of all exactly at it.
    rand_scale = torch.rand_like(total_thickness) * eps + 1 - eps
    scale_coef = max_total_thickness / total_thickness * rand_scale
    scale_coef[~exceeds] = 1.0  # samples within the limit are left unchanged

    for tensor in (min_bounds, max_bounds, params):
        tensor[:, thickness_mask] *= scale_coef[:, None]

    # Rescaling must not push thicknesses below the global lower bounds.
    lower_limit = total_min_bounds[:, thickness_mask]
    for tensor in (min_bounds, max_bounds, params):
        tensor[:, thickness_mask] = torch.clamp_min(tensor[:, thickness_mask], lower_limit)


def _resample_constrained_roughnesses(
        batch_size: int,
        params: Tensor,
        min_bounds: Tensor,
        max_bounds: Tensor,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        thickness_mask: Tensor,
        roughness_mask: Tensor,
        widths_sampler_func,
        coef_roughness: float,
) -> None:
    """In place, redraw roughnesses and their bounds below the limit derived from the sampled thicknesses."""
    max_roughness = torch.minimum(
        get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
        total_max_bounds[..., roughness_mask],
    )
    min_roughness = total_min_bounds[..., roughness_mask]

    assert torch.all(min_roughness <= max_roughness)

    min_roughness_delta = total_min_delta[..., roughness_mask]
    # Subprior widths cannot exceed the (per-sample) allowed roughness range.
    max_roughness_delta = torch.minimum(total_max_delta[..., roughness_mask], max_roughness - min_roughness)

    roughnesses, min_r_bounds, max_r_bounds = basic_sampler(
        batch_size, min_roughness, max_roughness,
        min_roughness_delta, max_roughness_delta,
        widths_sampler_func=widths_sampler_func,
    )

    min_bounds[..., roughness_mask] = min_r_bounds
    max_bounds[..., roughness_mask] = max_r_bounds
    params[..., roughness_mask] = roughnesses


def constrained_roughness_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        thickness_mask: Tensor,
        roughness_mask: Tensor,
        widths_sampler_func,
        coef_roughness: float = 0.5,
        max_total_thickness: float = None,
):
    """Sample parameters and subprior bounds with roughnesses limited by the sampled thicknesses.

    Steps:
      1. all parameters and bounds are drawn with :func:`basic_sampler`;
      2. if ``max_total_thickness`` is not None, samples whose summed thickness upper bounds exceed it
         have their thickness entries rescaled (``_rescale_to_max_total_thickness``);
      3. roughnesses are redrawn with :func:`basic_sampler` between ``total_min_bounds`` and
         ``min(get_max_allowed_roughness(thicknesses, coef_roughness), total_max_bounds)``.

    Args:
        batch_size (int): the batch size
        total_min_bounds (Tensor): minimum values of the parameters (indexed as 2-D, ``[:, mask]``)
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        thickness_mask (Tensor): selects the thickness entries
        roughness_mask (Tensor): selects the roughness entries
        widths_sampler_func: callable used to draw subprior widths
        coef_roughness (float, optional): passed as ``coef`` to ``get_max_allowed_roughness``. Defaults to 0.5.
        max_total_thickness (float, optional): limit on the summed thickness upper bounds. Defaults to None.

    Returns:
        tuple(Tensor): ``(params, min_bounds, max_bounds)``

    Raises:
        AssertionError: if the roughness lower bound exceeds the allowed maximum for some sample.
    """
    params, min_bounds, max_bounds = basic_sampler(
        batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
        widths_sampler_func=widths_sampler_func,
    )

    if max_total_thickness is not None:
        _rescale_to_max_total_thickness(
            params, min_bounds, max_bounds, total_min_bounds, thickness_mask, max_total_thickness,
        )

    _resample_constrained_roughnesses(
        batch_size, params, min_bounds, max_bounds,
        total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
        thickness_mask, roughness_mask, widths_sampler_func, coef_roughness,
    )

    return params, min_bounds, max_bounds
280
+
281
def constrained_roughness_and_isld_sampler(
        batch_size: int,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        total_min_delta: Tensor,
        total_max_delta: Tensor,
        thickness_mask: Tensor,
        roughness_mask: Tensor,
        sld_mask: Tensor,
        isld_mask: Tensor,
        widths_sampler_func,
        coef_roughness: float = 0.5,
        coef_isld: float = 0.2,
        max_total_thickness: float = None,
):
    """Like ``constrained_roughness_sampler``, but also limits the imaginary slds by the real slds.

    After thicknesses and roughnesses are drawn exactly as in ``constrained_roughness_sampler``,
    the imaginary slds (and their bounds) are redrawn with ``basic_sampler`` between
    ``total_min_bounds`` and ``min(coef_isld * |real sld|, total_max_bounds)``.

    Args:
        batch_size (int): the batch size
        total_min_bounds (Tensor): minimum values of the parameters
        total_max_bounds (Tensor): maximum values of the parameters
        total_min_delta (Tensor): minimum widths of the subprior intervals
        total_max_delta (Tensor): maximum widths of the subprior intervals
        thickness_mask (Tensor): selects the thickness entries
        roughness_mask (Tensor): selects the roughness entries
        sld_mask (Tensor): selects the real sld entries
        isld_mask (Tensor): selects the imaginary sld entries
        widths_sampler_func: callable used to draw subprior widths
        coef_roughness (float, optional): roughness coefficient. Defaults to 0.5.
        coef_isld (float, optional): maximum ratio of imaginary sld to absolute real sld. Defaults to 0.2.
        max_total_thickness (float, optional): limit on the summed thickness upper bounds. Defaults to None.

    Returns:
        tuple(Tensor): ``(params, min_bounds, max_bounds)``

    Raises:
        AssertionError: if a roughness or imaginary-sld lower bound exceeds its allowed maximum.
    """
    # Thickness and roughness handling is shared with the roughness-only sampler.
    params, min_bounds, max_bounds = constrained_roughness_sampler(
        batch_size,
        total_min_bounds,
        total_max_bounds,
        total_min_delta,
        total_max_delta,
        thickness_mask=thickness_mask,
        roughness_mask=roughness_mask,
        widths_sampler_func=widths_sampler_func,
        coef_roughness=coef_roughness,
        max_total_thickness=max_total_thickness,
    )

    max_isld = torch.minimum(
        torch.abs(params[..., sld_mask]) * coef_isld,
        total_max_bounds[..., isld_mask],
    )
    min_isld = total_min_bounds[..., isld_mask]

    assert torch.all(min_isld <= max_isld)

    min_isld_delta = total_min_delta[..., isld_mask]
    # Subprior widths cannot exceed the (per-sample) allowed imaginary-sld range.
    max_isld_delta = torch.minimum(total_max_delta[..., isld_mask], max_isld - min_isld)

    islds, min_isld_bounds, max_isld_bounds = basic_sampler(
        batch_size, min_isld, max_isld,
        min_isld_delta, max_isld_delta,
        widths_sampler_func=widths_sampler_func,
    )

    min_bounds[..., isld_mask] = min_isld_bounds
    max_bounds[..., isld_mask] = max_isld_bounds
    params[..., isld_mask] = islds

    return params, min_bounds, max_bounds
@@ -0,0 +1,65 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+
7
class ScalerMixin:
    """Provides functionality to multiple inheritance classes for scaling the parameters to a specified range and restoring them to the original range.

    The ``scaled_range`` property must be set before ``_scale`` / ``_restore`` are used,
    since it defines ``_length`` and ``_bias``.
    """
    @staticmethod
    def _get_delta_vector(min_vector: Tensor, max_vector: Tensor):
        """Return ``max_vector - min_vector`` with zero entries replaced by 1 (avoids division by zero)."""
        delta_vector = max_vector - min_vector  # fresh tensor, so the in-place patch below is safe
        delta_vector[delta_vector == 0.] = 1.
        return delta_vector

    def _scale(self, params_t: Tensor, min_vector: Tensor, max_vector: Tensor):
        """scale the parameters to a specific range
        Args:
            params_t (Tensor): the values of the parameters
            min_vector (Tensor): minimum possible values of each parameter
            max_vector (Tensor): maximum possible values of each parameter

        Returns:
            Tensor: the scaled parameters; ``min_vector`` maps to ``scaled_range[0]`` and
            ``max_vector`` to ``scaled_range[1]``
        """
        if params_t.dim() == 2:
            # Promote 1-D bounds to (1, n) so they broadcast against a batch of parameters.
            min_vector = torch.atleast_2d(min_vector)
            max_vector = torch.atleast_2d(max_vector)

        delta_vector = self._get_delta_vector(min_vector, max_vector)
        return (params_t - min_vector) / delta_vector * self._length + self._bias

    def _restore(self, scaled_params: Tensor, min_vector: Tensor, max_vector: Tensor):
        """restores the parameters to their original range
        Args:
            scaled_params: (Tensor): the scaled parameters
            min_vector (Tensor): minimum possible values of each parameter
            max_vector (Tensor): maximum possible values of each parameter

        Returns:
            Tensor: the restored parameters (inverse of ``_scale``)
        """
        if scaled_params.dim() == 2:
            min_vector = torch.atleast_2d(min_vector)
            max_vector = torch.atleast_2d(max_vector)

        delta_vector = self._get_delta_vector(min_vector, max_vector)
        return (scaled_params - self._bias) / self._length * delta_vector + min_vector

    @property
    def scaled_range(self) -> Tuple[float, float]:
        """The ``(low, high)`` range used for scaling the parameters."""
        return self._scaled_range

    @scaled_range.setter
    def scaled_range(self, scaled_range: Tuple[float, float]):
        """sets the range used for scaling the parameters"""
        self._scaled_range = scaled_range
        self._length = scaled_range[1] - scaled_range[0]
        self._init_bias = (scaled_range[0] + scaled_range[1]) / 2
        # Midpoint minus half the length, i.e. the lower end of the range.
        self._bias = (self._init_bias - 0.5 * self._length)