reflectorch 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/data_generation/__init__.py +2 -0
- reflectorch/data_generation/priors/parametric_models.py +1 -1
- reflectorch/data_generation/q_generator.py +70 -36
- reflectorch/data_generation/utils.py +1 -0
- reflectorch/inference/inference_model.py +711 -188
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/plotting.py +505 -86
- reflectorch/inference/preprocess_exp/interpolation.py +5 -2
- reflectorch/inference/scipy_fitter.py +19 -5
- reflectorch/ml/trainers.py +9 -0
- reflectorch/models/__init__.py +1 -0
- reflectorch/models/encoders/__init__.py +2 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/mlp_networks.py +10 -4
- reflectorch/runs/utils.py +5 -2
- reflectorch/utils.py +30 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/METADATA +3 -2
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/RECORD +21 -19
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/licenses/LICENSE.txt +0 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/top_level.txt +0 -0
|
@@ -29,6 +29,7 @@ from reflectorch.data_generation.q_generator import (
|
|
|
29
29
|
ConstantQ,
|
|
30
30
|
VariableQ,
|
|
31
31
|
EquidistantQ,
|
|
32
|
+
MaskedVariableQ,
|
|
32
33
|
)
|
|
33
34
|
from reflectorch.data_generation.noise import (
|
|
34
35
|
QNoiseGenerator,
|
|
@@ -78,6 +79,7 @@ __all__ = [
|
|
|
78
79
|
"ConstantQ",
|
|
79
80
|
"VariableQ",
|
|
80
81
|
"EquidistantQ",
|
|
82
|
+
"MaskedVariableQ",
|
|
81
83
|
"QNoiseGenerator",
|
|
82
84
|
"IntensityNoiseGenerator",
|
|
83
85
|
"MultiplicativeLogNormalNoiseGenerator",
|
|
@@ -37,7 +37,7 @@ class ParametricModel(object):
|
|
|
37
37
|
self.max_num_layers = max_num_layers
|
|
38
38
|
self._sampler_strategy = self._init_sampler_strategy(**kwargs)
|
|
39
39
|
|
|
40
|
-
def _init_sampler_strategy(self, **kwargs):
|
|
40
|
+
def _init_sampler_strategy(self, nuisance_params_dim: int = 0, **kwargs):
|
|
41
41
|
return BasicSamplerStrategy(**kwargs)
|
|
42
42
|
|
|
43
43
|
@property
|
|
@@ -16,6 +16,7 @@ __all__ = [
|
|
|
16
16
|
"VariableQ",
|
|
17
17
|
"EquidistantQ",
|
|
18
18
|
"ConstantAngle",
|
|
19
|
+
"MaskedVariableQ",
|
|
19
20
|
]
|
|
20
21
|
|
|
21
22
|
|
|
@@ -124,6 +125,11 @@ class VariableQ(QGenerator):
|
|
|
124
125
|
q = torch.linspace(0, 1, n_q, device=self.device, dtype=self.dtype)
|
|
125
126
|
elif self.mode == 'random':
|
|
126
127
|
q = torch.rand(n_q, device=self.device, dtype=self.dtype).sort().values
|
|
128
|
+
elif self.mode == 'logspace':
|
|
129
|
+
q = torch.logspace(
|
|
130
|
+
start=torch.log10(torch.tensor(1e-4, dtype=self.dtype, device=self.device)),
|
|
131
|
+
end=torch.log10(torch.tensor(1.0, dtype=self.dtype, device=self.device)),
|
|
132
|
+
steps=n_q, dtype=self.dtype, device=self.device)
|
|
127
133
|
|
|
128
134
|
q = q_min[:, None] + q * (q_max - q_min)[:, None]
|
|
129
135
|
|
|
@@ -198,49 +204,77 @@ class EquidistantQ(QGenerator):
|
|
|
198
204
|
return qs
|
|
199
205
|
|
|
200
206
|
|
|
201
|
-
class
|
|
207
|
+
class MaskedVariableQ:
|
|
202
208
|
def __init__(self,
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
self.
|
|
209
|
+
q_min_range=(0.01, 0.03),
|
|
210
|
+
q_max_range=(0.1, 0.5),
|
|
211
|
+
n_q_range=(64, 256),
|
|
212
|
+
mode='equidistant',
|
|
213
|
+
shuffle_mask=False,
|
|
214
|
+
total_thickness_constraint=True,
|
|
215
|
+
min_points_per_fringe=4,
|
|
216
|
+
device=DEFAULT_DEVICE,
|
|
217
|
+
dtype=DEFAULT_DTYPE):
|
|
218
|
+
self.q_min_range = q_min_range
|
|
219
|
+
self.q_max_range = q_max_range
|
|
220
|
+
self.n_q_range = n_q_range
|
|
213
221
|
self.device = device
|
|
214
222
|
self.dtype = dtype
|
|
215
|
-
|
|
216
|
-
|
|
223
|
+
self.mode = mode
|
|
224
|
+
self.shuffle_mask = shuffle_mask
|
|
225
|
+
self.total_thickness_constraint = total_thickness_constraint
|
|
226
|
+
self.min_points_per_fringe = min_points_per_fringe
|
|
227
|
+
|
|
228
|
+
def get_batch(self, batch_size, context):
|
|
217
229
|
assert context is not None
|
|
218
230
|
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
assert total_thickness.shape[0] == batch_size
|
|
223
|
-
|
|
224
|
-
min_dqs = torch.clamp(
|
|
225
|
-
2 * np.pi / total_thickness / self.min_dq_ratio, self._dq_range[0], self._dq_range[1] * 0.9
|
|
226
|
-
)
|
|
227
|
-
|
|
228
|
-
dqs = torch.rand_like(min_dqs) * (self._dq_range[1] - min_dqs) + min_dqs
|
|
229
|
-
|
|
230
|
-
num_q_values = torch.clamp(self.q_max // dqs, *self._num_values).to(torch.int)
|
|
231
|
+
q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
|
|
232
|
+
q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]
|
|
231
233
|
|
|
232
|
-
|
|
234
|
+
max_n_q = self.n_q_range[1]
|
|
233
235
|
|
|
234
|
-
|
|
235
|
-
|
|
236
|
+
if self.mode == 'equidistant':
|
|
237
|
+
positions = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(batch_size, max_n_q)
|
|
238
|
+
elif self.mode == 'random':
|
|
239
|
+
positions = torch.rand(batch_size, max_n_q, device=self.device, dtype=self.dtype)
|
|
240
|
+
positions, _ = positions.sort(dim=-1)
|
|
241
|
+
elif self.mode == 'mixed':
|
|
242
|
+
positions = torch.empty(batch_size, max_n_q, device=self.device, dtype=self.dtype)
|
|
243
|
+
|
|
244
|
+
half = batch_size // 2 # half batch gets equidistant
|
|
245
|
+
eq_pos = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(half, max_n_q)
|
|
246
|
+
positions[:half] = eq_pos
|
|
247
|
+
|
|
248
|
+
rand_pos = torch.rand(batch_size - half, max_n_q, device=self.device, dtype=self.dtype) # other half gets sorted random
|
|
249
|
+
rand_pos, _ = rand_pos.sort(dim=-1)
|
|
250
|
+
positions[half:] = rand_pos
|
|
251
|
+
else:
|
|
252
|
+
raise ValueError(f"Unknown spacing mode: {self.mode}")
|
|
253
|
+
|
|
254
|
+
q = q_min[:, None] + positions * (q_max - q_min)[:, None]
|
|
255
|
+
|
|
256
|
+
n_qs = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (batch_size,), device=self.device)
|
|
257
|
+
|
|
258
|
+
if 'params' in context and self.total_thickness_constraint: ### N_points > 1 + (Q_spread * total_thickness * min_np_per_kiessing_fringe) / (2*pi)
|
|
259
|
+
d_total = context['params'].thicknesses.sum(-1)
|
|
260
|
+
limit = 1 + ((q_max - q_min) * d_total * self.min_points_per_fringe) / (2*np.pi)
|
|
261
|
+
limit = limit.ceil().int()
|
|
262
|
+
n_qs = torch.maximum(n_qs, limit)
|
|
263
|
+
n_qs = torch.clamp(n_qs, max=self.n_q_range[1])
|
|
264
|
+
|
|
265
|
+
indices = torch.arange(max_n_q, device=self.device).expand(batch_size, max_n_q)
|
|
266
|
+
valid_mask = indices < n_qs[:, None] # right side padding
|
|
267
|
+
|
|
268
|
+
if self.shuffle_mask: # shuffle valid positions (inter-spread padding)
|
|
269
|
+
perm = torch.argsort(torch.rand(batch_size, max_n_q, device=self.device), dim=-1)
|
|
270
|
+
valid_mask = torch.gather(valid_mask, dim=1, index=perm)
|
|
236
271
|
|
|
237
|
-
|
|
272
|
+
context['key_padding_mask'] = valid_mask
|
|
273
|
+
context['n_points'] = valid_mask.sum(dim=-1)
|
|
238
274
|
|
|
275
|
+
return q
|
|
276
|
+
|
|
277
|
+
def scale_q(self, q):
|
|
278
|
+
scaled_q_01 = (q - self.q_min_range[0]) / (self.q_max_range[1] - self.q_min_range[0])
|
|
239
279
|
|
|
240
|
-
|
|
241
|
-
batch_size = num_q_values.shape[0]
|
|
242
|
-
dqs = (q_max / num_q_values)[:, None]
|
|
243
|
-
q_values = torch.arange(1, num_q_values.max().item() + 1)[None].repeat(batch_size, 1) * dqs
|
|
244
|
-
mask = (q_values > q_max + dqs / 2)
|
|
245
|
-
q_values[mask] = 0.
|
|
246
|
-
return q_values, mask
|
|
280
|
+
return 2.0 * (scaled_q_01 - 0.5)
|
|
@@ -163,6 +163,7 @@ def get_density_profiles(
|
|
|
163
163
|
else:
|
|
164
164
|
if ambient_sld.ndim == 1:
|
|
165
165
|
ambient_sld = ambient_sld.unsqueeze(-1)
|
|
166
|
+
ambient_sld = ambient_sld.expand(bs, 1)
|
|
166
167
|
|
|
167
168
|
slds_all = torch.cat([ambient_sld, slds], dim=-1) # new dimension: n+2
|
|
168
169
|
d_rhos = torch.diff(slds_all, dim=-1) # (bs, n+1)
|