reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,246 +1,280 @@
|
|
|
1
|
-
from typing import Tuple, Union
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
|
|
5
|
-
import torch
|
|
6
|
-
from torch import Tensor
|
|
7
|
-
|
|
8
|
-
from reflectorch.data_generation.utils import uniform_sampler
|
|
9
|
-
from reflectorch.data_generation.priors import BasicParams
|
|
10
|
-
from reflectorch.utils import angle_to_q
|
|
11
|
-
from reflectorch.data_generation.priors.no_constraints import DEFAULT_DEVICE, DEFAULT_DTYPE
|
|
12
|
-
|
|
13
|
-
__all__ = [
|
|
14
|
-
"QGenerator",
|
|
15
|
-
"ConstantQ",
|
|
16
|
-
"VariableQ",
|
|
17
|
-
"EquidistantQ",
|
|
18
|
-
"ConstantAngle",
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
self.
|
|
55
|
-
self.
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
self.
|
|
103
|
-
self.
|
|
104
|
-
self.
|
|
105
|
-
self.
|
|
106
|
-
self.
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
1
|
+
from typing import Tuple, Union
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import Tensor
|
|
7
|
+
|
|
8
|
+
from reflectorch.data_generation.utils import uniform_sampler
|
|
9
|
+
from reflectorch.data_generation.priors import BasicParams
|
|
10
|
+
from reflectorch.utils import angle_to_q
|
|
11
|
+
from reflectorch.data_generation.priors.no_constraints import DEFAULT_DEVICE, DEFAULT_DTYPE
|
|
12
|
+
|
|
13
|
+
# Public API of this module: the q generator classes defined below.
__all__ = [
    "QGenerator",
    "ConstantQ",
    "VariableQ",
    "EquidistantQ",
    "ConstantAngle",
    "MaskedVariableQ",
]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class QGenerator(object):
    """Base class for momentum transfer (q) generators.

    Subclasses must override :meth:`get_batch`.
    """

    def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
        """Generate a batch of q values.

        Args:
            batch_size (int): the batch size
            context (dict, optional): extra information shared with the rest of the
                data-generation pipeline. Defaults to None.

        Returns:
            Tensor: generated batch of q values

        Raises:
            NotImplementedError: always; concrete generators must implement this method.
        """
        # Raising (instead of silently returning None) makes a subclass that forgot
        # to override this method fail at the call site rather than further downstream.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement get_batch()"
        )
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ConstantQ(QGenerator):
    """Q generator for reflectivity curves with fixed discretization.

    All curves in a batch share one q grid, which is built once in the constructor.

    Args:
        q (Union[Tensor, np.ndarray, Tuple[float, float, int]], optional): either the q values
            themselves or a tuple (q_min, q_max, num_q) that is expanded with ``torch.linspace``.
            A Tensor is used as given; a tuple/list is created on ``device`` with ``dtype``, and a
            NumPy array is converted to a Tensor on ``device`` with ``dtype``. Defaults to (0., 0.2, 128).
        device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
        dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
        remove_zero (bool, optional): drop one end point of the grid. The last point is dropped
            unless ``fixed_zero`` is True, in which case the first point is dropped instead.
            Defaults to False.
        fixed_zero (bool, optional): only has an effect together with ``remove_zero``; selects
            dropping the first point instead of the last one. Defaults to False.
    """

    def __init__(self,
                 q: Union[Tensor, np.ndarray, Tuple[float, float, int]] = (0., 0.2, 128),
                 device=DEFAULT_DEVICE,
                 dtype=DEFAULT_DTYPE,
                 remove_zero: bool = False,
                 fixed_zero: bool = False,
                 ):
        if isinstance(q, (tuple, list)):
            q = torch.linspace(*q, device=device, dtype=dtype)
        elif isinstance(q, np.ndarray):
            # NumPy arrays have no ``clone``/``expand``; convert so get_batch works.
            q = torch.as_tensor(q, device=device, dtype=dtype)

        if remove_zero:
            q = q[1:] if fixed_zero else q[:-1]

        # Bounds of the grid, used by scale_q.
        self.q_min = q.min().item()
        self.q_max = q.max().item()
        self.q = q

    def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
        """Generate a batch of q values.

        Args:
            batch_size (int): the batch size
            context (dict, optional): unused by this generator. Defaults to None.

        Returns:
            Tensor: the stored q grid broadcast to shape (batch_size, num_q); the result is an
            expanded view of a fresh copy, so the stored grid itself is not shared.
        """
        return self.q.clone()[None].expand(batch_size, self.q.shape[0])

    def scale_q(self, q):
        """Scale q values linearly so that [q_min, q_max] of the grid maps to [-1, 1].

        Note: if the grid has a single point (q_min == q_max) the division is by zero.

        Args:
            q (Tensor): unscaled q values

        Returns:
            Tensor: scaled q values
        """
        scaled_q_01 = (q - self.q_min) / (self.q_max - self.q_min)
        return 2.0 * (scaled_q_01 - 0.5)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class VariableQ(QGenerator):
    """Q generator for reflectivity curves with variable discretization

    Args:
        q_min_range (Tuple[float, float], optional): the range for sampling the minimum q value of the curves, q_min.
            Defaults to (0.01, 0.03).
        q_max_range (Tuple[float, float], optional): the range for sampling the maximum q value of the curves, q_max.
            Defaults to (0.1, 0.5).
        n_q_range (Tuple[int, int], optional): the inclusive range for the number of points in the curves
            (the number of points varies between batches but is constant within a batch). Defaults to (64, 256).
        mode (str, optional): how the points are placed between q_min and q_max of each curve:
            'equidistant' (evenly spaced, both ends included), 'random' (sorted uniform random positions,
            shared by all curves of the batch) or 'logspace' (logarithmically spaced positions from
            1e-4 to 1 of the interval). Defaults to 'equidistant'.
        device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
        dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
    """

    def __init__(self,
                 q_min_range: Tuple[float, float] = (0.01, 0.03),
                 q_max_range: Tuple[float, float] = (0.1, 0.5),
                 n_q_range: Tuple[int, int] = (64, 256),
                 mode: str = 'equidistant',
                 device=DEFAULT_DEVICE,
                 dtype=DEFAULT_DTYPE,
                 ):
        self.q_min_range = q_min_range
        self.q_max_range = q_max_range
        self.n_q_range = n_q_range
        self.mode = mode
        self.device = device
        self.dtype = dtype

    def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
        """Generate a batch of q values (the number of points varies between batches but is constant within a batch).

        Args:
            batch_size (int): the batch size
            context (dict, optional): unused by this generator. Defaults to None.

        Returns:
            Tensor: generated batch of q values, shape (batch_size, n_q)

        Raises:
            ValueError: if ``mode`` is not one of 'equidistant', 'random' or 'logspace'.
        """
        # Per-curve interval bounds, uniform within the configured ranges.
        q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
        q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]

        # One point count for the whole batch; "+ 1" makes the upper bound inclusive.
        n_q = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (1,), device=self.device).item()

        # Relative positions in [0, 1] (logspace: [1e-4, 1]), shared by every curve.
        if self.mode == 'equidistant':
            positions = torch.linspace(0, 1, n_q, device=self.device, dtype=self.dtype)
        elif self.mode == 'random':
            positions = torch.rand(n_q, device=self.device, dtype=self.dtype).sort().values
        elif self.mode == 'logspace':
            # Exponents of 10: 1e-4 .. 1e0. Plain floats avoid requiring tensor-valued bounds.
            positions = torch.logspace(-4.0, 0.0, steps=n_q, dtype=self.dtype, device=self.device)
        else:
            # Previously an unknown mode fell through and raised UnboundLocalError.
            raise ValueError(f"Unknown spacing mode: {self.mode}")

        # Map relative positions onto each curve's [q_min, q_max] interval.
        return q_min[:, None] + positions * (q_max - q_min)[:, None]

    def scale_q(self, q):
        """Scale q values to [-1, 1] using the global bounds q_min_range[0] and q_max_range[1].

        Args:
            q (Tensor): unscaled q values

        Returns:
            Tensor: scaled q values
        """
        scaled_q_01 = (q - self.q_min_range[0]) / (self.q_max_range[1] - self.q_min_range[0])

        return 2.0 * (scaled_q_01 - 0.5)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class ConstantAngle(QGenerator):
    """Q generator for reflectivity curves measured at equidistant angles.

    The incident angles are spaced evenly with ``np.linspace(*angle_range)``, converted to q
    with ``angle_to_q`` once at construction, and then reused for every batch.

    Args:
        angle_range (Tuple[float, float, int], optional): start, stop and number of the incident
            angles. Defaults to (0., 0.2, 257).
        wavelength (float, optional): the beam wavelength in units of angstroms. Defaults to 1.
        device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
        dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
    """
    def __init__(self,
                 angle_range: Tuple[float, float, int] = (0., 0.2, 257),
                 wavelength: float = 1.,
                 device=DEFAULT_DEVICE,
                 dtype=DEFAULT_DTYPE,
                 ):
        angles = np.linspace(*angle_range)
        q_array = angle_to_q(angles, wavelength)
        q_tensor = torch.from_numpy(q_array)
        self.q = q_tensor.to(device).to(dtype)

    def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
        """Return the fixed q grid repeated for every curve of the batch.

        Args:
            batch_size (int): the batch size

        Returns:
            Tensor: generated batch of q values, shape (batch_size, num_angles)
        """
        n_points = self.q.shape[0]
        q_copy = self.q.clone()
        return q_copy.unsqueeze(0).expand(batch_size, n_points)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class EquidistantQ(QGenerator):
    """Q generator producing equidistant grids that start above zero.

    For every curve of a batch a q_max is sampled from ``max_range`` with ``uniform_sampler``.
    The grid is ``q_max * (1/n, 2/n, ..., 1)``, so q = 0 itself is excluded. All curves of a
    batch share the number of points ``n``.

    Args:
        max_range (Tuple[float, float]): bounds passed to ``uniform_sampler`` for sampling q_max.
        num_values (Union[int, Tuple[int, int]]): either a fixed number of points, or a
            (low, high) pair from which the number of points is drawn anew for every batch with
            ``np.random.randint`` (``high`` is exclusive).
        device (optional): the Pytorch device. Defaults to None.
        dtype (optional): the Pytorch data type. Defaults to torch.float64.
    """

    def __init__(self,
                 max_range: Tuple[float, float],
                 num_values: Union[int, Tuple[int, int]],
                 device=None,
                 dtype=torch.float64
                 ):
        self.max_range = max_range
        self._num_values = num_values
        self.device = device
        self.dtype = dtype

    @property
    def num_values(self) -> int:
        """Number of q points for the next batch (re-drawn on every access when a range was given)."""
        # Accept NumPy integers as fixed counts too; previously they were unpacked as a
        # (low, high) pair and raised a TypeError.
        if isinstance(self._num_values, (int, np.integer)):
            return int(self._num_values)
        return np.random.randint(*self._num_values)

    def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
        """Generate a batch of equidistant q values.

        Args:
            batch_size (int): the batch size
            context (dict, optional): unused by this generator. Defaults to None.

        Returns:
            Tensor: q values, one row per curve with ``num_values`` points each.
        """
        num_values = self.num_values
        # NOTE(review): assumes uniform_sampler returns shape (batch_size, 1) here — confirm.
        q_max = uniform_sampler(*self.max_range, batch_size, 1, device=self.device, dtype=self.dtype)
        # num_values + 1 points on [0, 1], dropping the leading 0 -> shape (1, num_values).
        norm_qs = torch.linspace(0, 1, num_values + 1, device=self.device, dtype=self.dtype)[1:][None]
        qs = norm_qs * q_max
        return qs
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class MaskedVariableQ:
    """Q generator producing a fixed-length q grid plus a per-curve validity mask.

    Every curve gets ``n_q_range[1]`` q values between its own q_min and q_max. A random number
    of them, drawn per curve from ``n_q_range`` with an inclusive upper bound, is marked as
    valid. The mask and the number of valid points are written into the ``context`` dict
    passed to :meth:`get_batch`.

    Args:
        q_min_range (tuple, optional): range for sampling q_min of each curve. Defaults to (0.01, 0.03).
        q_max_range (tuple, optional): range for sampling q_max of each curve. Defaults to (0.1, 0.5).
        n_q_range (tuple, optional): inclusive range for the number of valid points per curve. The
            upper bound is also the (padded) length of every curve. Defaults to (64, 256).
        mode (str, optional): point placement: 'equidistant', 'random' (sorted uniform positions,
            drawn independently per curve) or 'mixed' (first half of the batch equidistant, the
            rest random). Defaults to 'equidistant'.
        shuffle_mask (bool, optional): if True, the valid positions are scattered randomly over the
            grid instead of occupying its first entries. Defaults to False.
        total_thickness_constraint (bool, optional): if True and ``context`` contains 'params', the
            number of valid points is raised to at least
            ``1 + (q_max - q_min) * d_total * min_points_per_fringe / (2*pi)``, where d_total is the
            summed layer thickness, and then capped at ``n_q_range[1]``. Defaults to True.
        min_points_per_fringe (int, optional): minimum number of points per Kiessig fringe used
            by the thickness constraint. Defaults to 4.
        device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
        dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
    """

    def __init__(self,
                 q_min_range=(0.01, 0.03),
                 q_max_range=(0.1, 0.5),
                 n_q_range=(64, 256),
                 mode='equidistant',
                 shuffle_mask=False,
                 total_thickness_constraint=True,
                 min_points_per_fringe=4,
                 device=DEFAULT_DEVICE,
                 dtype=DEFAULT_DTYPE):
        self.q_min_range = q_min_range
        self.q_max_range = q_max_range
        self.n_q_range = n_q_range
        self.device = device
        self.dtype = dtype
        self.mode = mode
        self.shuffle_mask = shuffle_mask
        self.total_thickness_constraint = total_thickness_constraint
        self.min_points_per_fringe = min_points_per_fringe

    def get_batch(self, batch_size, context=None):
        """Generate a batch of padded q grids and record their validity mask in ``context``.

        Side effects on ``context``:
            - ``'key_padding_mask'``: bool Tensor (batch_size, n_q_range[1]); True marks a VALID
              point. NOTE(review): this is the opposite polarity of
              ``torch.nn.MultiheadAttention``'s ``key_padding_mask`` — confirm what consumers expect.
            - ``'n_points'``: number of valid points per curve.

        Args:
            batch_size (int): the batch size
            context (dict): mutable dict shared with the data pipeline; may contain 'params',
                whose ``thicknesses`` are used by the thickness constraint. Required.

        Returns:
            Tensor: q values, shape (batch_size, n_q_range[1])

        Raises:
            ValueError: if ``context`` is None or ``mode`` is unknown.
        """
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if context is None:
            raise ValueError("MaskedVariableQ.get_batch requires a context dict")
        if self.mode not in ('equidistant', 'random', 'mixed'):
            raise ValueError(f"Unknown spacing mode: {self.mode}")

        q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
        q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]

        max_n_q = self.n_q_range[1]

        # Relative positions in [0, 1] for every padded slot.
        if self.mode == 'equidistant':
            positions = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(batch_size, max_n_q)
        elif self.mode == 'random':
            positions = torch.rand(batch_size, max_n_q, device=self.device, dtype=self.dtype)
            positions, _ = positions.sort(dim=-1)
        else:  # 'mixed'
            positions = torch.empty(batch_size, max_n_q, device=self.device, dtype=self.dtype)

            half = batch_size // 2  # first half of the batch gets equidistant positions
            eq_pos = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(half, max_n_q)
            positions[:half] = eq_pos

            rand_pos = torch.rand(batch_size - half, max_n_q, device=self.device, dtype=self.dtype)  # rest: sorted random
            rand_pos, _ = rand_pos.sort(dim=-1)
            positions[half:] = rand_pos

        q = q_min[:, None] + positions * (q_max - q_min)[:, None]

        # Number of valid points per curve, upper bound inclusive.
        n_qs = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (batch_size,), device=self.device)

        if 'params' in context and self.total_thickness_constraint:
            # N_points >= 1 + (q_spread * total_thickness * min_points_per_fringe) / (2*pi),
            # i.e. enough points to resolve each Kiessig fringe of period 2*pi / d_total.
            d_total = context['params'].thicknesses.sum(-1).to(self.device)
            limit = 1 + ((q_max - q_min) * d_total * self.min_points_per_fringe) / (2 * np.pi)
            # int64 to match n_qs (torch.randint's default dtype).
            limit = limit.ceil().long()
            n_qs = torch.maximum(n_qs, limit)
            n_qs = torch.clamp(n_qs, max=self.n_q_range[1])

        indices = torch.arange(max_n_q, device=self.device).expand(batch_size, max_n_q)
        valid_mask = indices < n_qs[:, None]  # valid points first, padding on the right

        if self.shuffle_mask:
            # Apply a random permutation per row, spreading padding throughout the grid.
            perm = torch.argsort(torch.rand(batch_size, max_n_q, device=self.device), dim=-1)
            valid_mask = torch.gather(valid_mask, dim=1, index=perm)

        context['key_padding_mask'] = valid_mask
        context['n_points'] = valid_mask.sum(dim=-1)

        return q

    def scale_q(self, q):
        """Scale q values to [-1, 1] using the global bounds q_min_range[0] and q_max_range[1].

        Args:
            q (Tensor): unscaled q values

        Returns:
            Tensor: scaled q values
        """
        scaled_q_01 = (q - self.q_min_range[0]) / (self.q_max_range[1] - self.q_min_range[0])

        return 2.0 * (scaled_q_01 - 0.5)
|