reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of reflectorch might be problematic.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246 (full diff shown below)
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,246 +1,280 @@
 from typing import Tuple, Union
 
 import numpy as np
 
 import torch
 from torch import Tensor
 
 from reflectorch.data_generation.utils import uniform_sampler
 from reflectorch.data_generation.priors import BasicParams
 from reflectorch.utils import angle_to_q
 from reflectorch.data_generation.priors.no_constraints import DEFAULT_DEVICE, DEFAULT_DTYPE
 
 __all__ = [
     "QGenerator",
     "ConstantQ",
     "VariableQ",
     "EquidistantQ",
     "ConstantAngle",
+    "MaskedVariableQ",
 ]
 
 
 class QGenerator(object):
     """Base class for momentum transfer (q) generators"""
     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
         pass
 
 
 class ConstantQ(QGenerator):
     """Q generator for reflectivity curves with fixed discretization
 
     Args:
         q (Union[Tensor, Tuple[float, float, int]], optional): tuple (q_min, q_max, num_q) defining the minimum q value, maximum q value and the number of q points. Defaults to (0., 0.2, 128).
         device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
         dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
         remove_zero (bool, optional): do not include the upper end of the interval. Defaults to False.
         fixed_zero (bool, optional): do not include the lower end of the interval. Defaults to False.
     """
 
     def __init__(self,
                  q: Union[Tensor, Tuple[float, float, int]] = (0., 0.2, 128),
                  device=DEFAULT_DEVICE,
                  dtype=DEFAULT_DTYPE,
                  remove_zero: bool = False,
                  fixed_zero: bool = False,
                  ):
         if isinstance(q, (tuple, list)):
             q = torch.linspace(*q, device=device, dtype=dtype)
             if remove_zero:
                 if fixed_zero:
                     q = q[1:]
                 else:
                     q = q[:-1]
         self.q_min = q.min().item()
         self.q_max = q.max().item()
         self.q = q
 
     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
         """generate a batch of q values
 
         Args:
             batch_size (int): the batch size
 
         Returns:
             Tensor: generated batch of q values
         """
         return self.q.clone()[None].expand(batch_size, self.q.shape[0])
 
     def scale_q(self, q):
         """Scales the q values to the range [-1, 1].
 
         Args:
             q (Tensor): unscaled q values
 
         Returns:
             Tensor: scaled q values
         """
         scaled_q_01 = (q - self.q_min) / (self.q_max - self.q_min)
         return 2.0 * (scaled_q_01 - 0.5)
 
 
 class VariableQ(QGenerator):
     """Q generator for reflectivity curves with variable discretization
 
     Args:
         q_min_range (list, optional): the range for sampling the minimum q value of the curves, q_min. Defaults to [0.01, 0.03].
         q_max_range (list, optional): the range for sampling the maximum q value of the curves, q_max. Defaults to [0.1, 0.5].
         n_q_range (list, optional): the range for the number of points in the curves (equidistantly sampled between q_min and q_max,
             the number of points varies between batches but is constant within a batch). Defaults to [64, 256].
         device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
         dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
     """
 
     def __init__(self,
                  q_min_range: Tuple[float, float] = (0.01, 0.03),
                  q_max_range: Tuple[float, float] = (0.1, 0.5),
                  n_q_range: Tuple[int, int] = (64, 256),
                  mode: str = 'equidistant',
                  device=DEFAULT_DEVICE,
                  dtype=DEFAULT_DTYPE,
                  ):
         self.q_min_range = q_min_range
         self.q_max_range = q_max_range
         self.n_q_range = n_q_range
         self.mode = mode
         self.device = device
         self.dtype = dtype
 
     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
         """generate a batch of q values (the number of points varies between batches but is constant within a batch)
 
         Args:
             batch_size (int): the batch size
 
         Returns:
             Tensor: generated batch of q values
         """
 
         q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
         q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]
 
         n_q = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (1,), device=self.device).item()
 
         if self.mode == 'equidistant':
             q = torch.linspace(0, 1, n_q, device=self.device, dtype=self.dtype)
         elif self.mode == 'random':
             q = torch.rand(n_q, device=self.device, dtype=self.dtype).sort().values
+        elif self.mode == 'logspace':
+            q = torch.logspace(
+                start=torch.log10(torch.tensor(1e-4, dtype=self.dtype, device=self.device)),
+                end=torch.log10(torch.tensor(1.0, dtype=self.dtype, device=self.device)),
+                steps=n_q, dtype=self.dtype, device=self.device)
 
         q = q_min[:, None] + q * (q_max - q_min)[:, None]
 
         return q
 
     def scale_q(self, q):
         """scales the q values to the range [-1, 1]
 
         Args:
             q (Tensor): unscaled q values
 
         Returns:
             Tensor: scaled q values
         """
         scaled_q_01 = (q - self.q_min_range[0]) / (self.q_max_range[1] - self.q_min_range[0])
 
         return 2.0 * (scaled_q_01 - 0.5)
 
 
 class ConstantAngle(QGenerator):
     """Q generator for reflectivity curves measured at equidistant angles
 
     Args:
         angle_range (Tuple[float, float, int], optional): the range of the incident angles. Defaults to (0., 0.2, 257).
         wavelength (float, optional): the beam wavelength in units of angstroms. Defaults to 1.
         device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
         dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
     """
     def __init__(self,
                  angle_range: Tuple[float, float, int] = (0., 0.2, 257),
                  wavelength: float = 1.,
                  device=DEFAULT_DEVICE,
                  dtype=DEFAULT_DTYPE,
                  ):
         self.q = torch.from_numpy(angle_to_q(np.linspace(*angle_range), wavelength)).to(device).to(dtype)
 
     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
         """generate a batch of q values
 
         Args:
             batch_size (int): the batch size
 
         Returns:
             Tensor: generated batch of q values
         """
         return self.q.clone()[None].expand(batch_size, self.q.shape[0])
 
 
 class EquidistantQ(QGenerator):
     def __init__(self,
                  max_range: Tuple[float, float],
                  num_values: Union[int, Tuple[int, int]],
                  device=None,
                  dtype=torch.float64
                  ):
         self.max_range = max_range
         self._num_values = num_values
         self.device = device
         self.dtype = dtype
 
     @property
     def num_values(self) -> int:
         if isinstance(self._num_values, int):
             return self._num_values
         return np.random.randint(*self._num_values)
 
     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
         num_values = self.num_values
         q_max = uniform_sampler(*self.max_range, batch_size, 1, device=self.device, dtype=self.dtype)
         norm_qs = torch.linspace(0, 1, num_values + 1, device=self.device, dtype=self.dtype)[1:][None]
         qs = norm_qs * q_max
         return qs
 
 
-class TransformerQ(QGenerator):
-    def __init__(self,
-                 q_max: float = 0.2,
-                 num_values: Union[int, Tuple[int, int]] = (30, 512),
-                 min_dq_ratio: float = 5.,
-                 device=None,
-                 dtype=torch.float64,
-                 ):
-        self.min_dq_ratio = min_dq_ratio
-        self.q_max = q_max
-        self._dq_range = q_max / num_values[1], q_max / num_values[0]
-        self._num_values = num_values
-        self.device = device
-        self.dtype = dtype
-
-    def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
-        assert context is not None
-
-        params: BasicParams = context['params']
-        total_thickness = params.thicknesses.sum(-1)
-
-        assert total_thickness.shape[0] == batch_size
-
-        min_dqs = torch.clamp(
-            2 * np.pi / total_thickness / self.min_dq_ratio, self._dq_range[0], self._dq_range[1] * 0.9
-        )
-
-        dqs = torch.rand_like(min_dqs) * (self._dq_range[1] - min_dqs) + min_dqs
-
-        num_q_values = torch.clamp(self.q_max // dqs, *self._num_values).to(torch.int)
-
-        q_values, mask = generate_q_padding_mask(num_q_values, self.q_max)
-
-        context['tgt_key_padding_mask'] = mask
-        context['num_q_values'] = num_q_values
-
-        return q_values
-
-
-def generate_q_padding_mask(num_q_values: Tensor, q_max: float):
-    batch_size = num_q_values.shape[0]
-    dqs = (q_max / num_q_values)[:, None]
-    q_values = torch.arange(1, num_q_values.max().item() + 1)[None].repeat(batch_size, 1) * dqs
-    mask = (q_values > q_max + dqs / 2)
-    q_values[mask] = 0.
-    return q_values, mask
+class MaskedVariableQ:
+    def __init__(self,
+                 q_min_range=(0.01, 0.03),
+                 q_max_range=(0.1, 0.5),
+                 n_q_range=(64, 256),
+                 mode='equidistant',
+                 shuffle_mask=False,
+                 total_thickness_constraint=True,
+                 min_points_per_fringe=4,
+                 device=DEFAULT_DEVICE,
+                 dtype=DEFAULT_DTYPE):
+        self.q_min_range = q_min_range
+        self.q_max_range = q_max_range
+        self.n_q_range = n_q_range
+        self.device = device
+        self.dtype = dtype
+        self.mode = mode
+        self.shuffle_mask = shuffle_mask
+        self.total_thickness_constraint = total_thickness_constraint
+        self.min_points_per_fringe = min_points_per_fringe
+
+    def get_batch(self, batch_size, context):
+        assert context is not None
+
+        q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
+        q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]
+
+        max_n_q = self.n_q_range[1]
+
+        if self.mode == 'equidistant':
+            positions = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(batch_size, max_n_q)
+        elif self.mode == 'random':
+            positions = torch.rand(batch_size, max_n_q, device=self.device, dtype=self.dtype)
+            positions, _ = positions.sort(dim=-1)
+        elif self.mode == 'mixed':
+            positions = torch.empty(batch_size, max_n_q, device=self.device, dtype=self.dtype)
+
+            half = batch_size // 2  # half the batch gets equidistant positions
+            eq_pos = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(half, max_n_q)
+            positions[:half] = eq_pos
+
+            rand_pos = torch.rand(batch_size - half, max_n_q, device=self.device, dtype=self.dtype)  # the other half gets sorted random positions
+            rand_pos, _ = rand_pos.sort(dim=-1)
+            positions[half:] = rand_pos
+        else:
+            raise ValueError(f"Unknown spacing mode: {self.mode}")
+
+        q = q_min[:, None] + positions * (q_max - q_min)[:, None]
+
+        n_qs = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (batch_size,), device=self.device)
+
+        if 'params' in context and self.total_thickness_constraint:  # N_points > 1 + (q_spread * total_thickness * min_points_per_Kiessig_fringe) / (2*pi)
+            d_total = context['params'].thicknesses.sum(-1)
+            limit = 1 + ((q_max - q_min) * d_total * self.min_points_per_fringe) / (2 * np.pi)
+            limit = limit.ceil().int()
+            n_qs = torch.maximum(n_qs, limit)
+            n_qs = torch.clamp(n_qs, max=self.n_q_range[1])
+
+        indices = torch.arange(max_n_q, device=self.device).expand(batch_size, max_n_q)
+        valid_mask = indices < n_qs[:, None]  # right-side padding
+
+        if self.shuffle_mask:  # shuffle the valid positions (interspersed padding)
+            perm = torch.argsort(torch.rand(batch_size, max_n_q, device=self.device), dim=-1)
+            valid_mask = torch.gather(valid_mask, dim=1, index=perm)
+
+        context['key_padding_mask'] = valid_mask
+        context['n_points'] = valid_mask.sum(dim=-1)
+
+        return q
+
+    def scale_q(self, q):
+        scaled_q_01 = (q - self.q_min_range[0]) / (self.q_max_range[1] - self.q_min_range[0])
+
+        return 2.0 * (scaled_q_01 - 0.5)
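
Usage note (editor's sketch, not part of the diff): the 1.5.0 file above adds a 'logspace' spacing mode to VariableQ and replaces TransformerQ / generate_q_padding_mask with MaskedVariableQ, which gives each curve in a batch its own number of valid q points and reports the padding layout through the context dict ('key_padding_mask' and 'n_points'). The snippet below exercises both classes as they appear in the diff; the batch size, device, and dtype are illustrative assumptions, and the context is deliberately left without a 'params' entry so the total-thickness branch is skipped.

    import torch
    from reflectorch.data_generation.q_generator import VariableQ, MaskedVariableQ

    # Variable discretization with the new 'logspace' point spacing.
    q_gen = VariableQ(q_min_range=(0.01, 0.03), q_max_range=(0.1, 0.5),
                      n_q_range=(64, 256), mode='logspace',
                      device='cpu', dtype=torch.float32)
    q_batch = q_gen.get_batch(batch_size=8)    # shape (8, n_q); n_q drawn once per batch
    q_scaled = q_gen.scale_q(q_batch)          # mapped to [-1, 1]

    # Masked variable discretization: a fixed-width (8, 256) grid in which each
    # curve owns a different number of valid points, marked by the padding mask.
    masked_gen = MaskedVariableQ(mode='mixed', shuffle_mask=True,
                                 device='cpu', dtype=torch.float32)
    context = {}                               # no 'params' key here, so the
                                               # total-thickness constraint is skipped
    q = masked_gen.get_batch(batch_size=8, context=context)
    valid = context['key_padding_mask']        # True where a q point is valid
    n_points = context['n_points']             # number of valid points per curve

When a 'params' entry with per-layer thicknesses is present, the total-thickness constraint enforces N_points >= 1 + (q_max - q_min) * d_total * min_points_per_fringe / (2*pi), i.e. at least min_points_per_fringe samples for each Kiessig fringe, whose period in q is 2*pi/d_total. As a worked example, a film with d_total = 300 angstroms sampled from q_min = 0.02 to q_max = 0.40 inverse angstroms with the default min_points_per_fringe = 4 gives a floor of ceil(1 + 0.38 * 300 * 4 / (2*pi)) = 74 points.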