reflectorch 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/data_generation/__init__.py +2 -0
- reflectorch/data_generation/priors/parametric_models.py +1 -1
- reflectorch/data_generation/q_generator.py +70 -36
- reflectorch/data_generation/utils.py +1 -0
- reflectorch/inference/inference_model.py +711 -188
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/plotting.py +505 -86
- reflectorch/inference/preprocess_exp/interpolation.py +5 -2
- reflectorch/inference/scipy_fitter.py +19 -5
- reflectorch/ml/trainers.py +9 -0
- reflectorch/models/__init__.py +1 -0
- reflectorch/models/encoders/__init__.py +2 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/mlp_networks.py +10 -4
- reflectorch/runs/utils.py +5 -2
- reflectorch/utils.py +30 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/METADATA +3 -2
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/RECORD +21 -19
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/licenses/LICENSE.txt +0 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.4.0.dist-info}/top_level.txt +0 -0
|
@@ -63,6 +63,8 @@ def refl_fit(
|
|
|
63
63
|
bounds: np.ndarray = None,
|
|
64
64
|
error_bars: np.ndarray = None,
|
|
65
65
|
scale_curve_func=np.log10,
|
|
66
|
+
method: str = 'trf', #'lm', 'trf'
|
|
67
|
+
polishing_max_nfev: int = None,
|
|
66
68
|
reflectivity_kwargs: dict = None,
|
|
67
69
|
**kwargs
|
|
68
70
|
):
|
|
@@ -77,14 +79,23 @@ def refl_fit(
|
|
|
77
79
|
adjusted_bounds[1, i] += epsilon
|
|
78
80
|
|
|
79
81
|
init_params = np.clip(init_params, *adjusted_bounds)
|
|
80
|
-
|
|
82
|
+
if method != 'lm':
|
|
83
|
+
kwargs['bounds'] = adjusted_bounds
|
|
81
84
|
|
|
82
85
|
reflectivity_kwargs = reflectivity_kwargs or {}
|
|
83
86
|
for key, value in reflectivity_kwargs.items():
|
|
84
87
|
if isinstance(value, float):
|
|
85
88
|
reflectivity_kwargs[key] = torch.tensor([[value]], dtype=torch.float64)
|
|
86
89
|
elif isinstance(value, np.ndarray):
|
|
87
|
-
reflectivity_kwargs[key] = torch.tensor(value, dtype=torch.float32).unsqueeze(0)
|
|
90
|
+
reflectivity_kwargs[key] = torch.tensor(value, dtype=torch.float32).unsqueeze(0)
|
|
91
|
+
|
|
92
|
+
curve = np.clip(curve, a_min=1e-12, a_max=None)
|
|
93
|
+
|
|
94
|
+
if error_bars is not None and scale_curve_func == np.log10:
|
|
95
|
+
error_bars = np.clip(error_bars, a_min=1e-20, a_max=None)
|
|
96
|
+
scaled_error_bars = error_bars / (curve * np.log(10))
|
|
97
|
+
else:
|
|
98
|
+
scaled_error_bars = None
|
|
88
99
|
|
|
89
100
|
res = curve_fit(
|
|
90
101
|
f=get_scaled_curve_func(
|
|
@@ -93,10 +104,12 @@ def refl_fit(
|
|
|
93
104
|
reflectivity_kwargs=reflectivity_kwargs,
|
|
94
105
|
),
|
|
95
106
|
xdata=q,
|
|
96
|
-
ydata=scale_curve_func(curve),
|
|
107
|
+
ydata=scale_curve_func(curve).reshape(-1),
|
|
97
108
|
p0=init_params,
|
|
98
|
-
sigma=
|
|
109
|
+
sigma=scaled_error_bars,
|
|
99
110
|
absolute_sigma=True,
|
|
111
|
+
method=method,
|
|
112
|
+
max_nfev=polishing_max_nfev,
|
|
100
113
|
**kwargs
|
|
101
114
|
)
|
|
102
115
|
|
|
@@ -184,7 +197,8 @@ def get_scaled_curve_func(
|
|
|
184
197
|
fitted_curve = fitted_curve_tensor.squeeze().numpy()
|
|
185
198
|
|
|
186
199
|
scaled_curve = scale_curve_func(fitted_curve)
|
|
187
|
-
|
|
200
|
+
|
|
201
|
+
return scaled_curve.reshape(-1)
|
|
188
202
|
|
|
189
203
|
return scaled_curve_func
|
|
190
204
|
|
reflectorch/ml/trainers.py
CHANGED
|
@@ -24,7 +24,9 @@ class BasicBatchData:
|
|
|
24
24
|
scaled_sigmas: Optional[torch.Tensor] = None
|
|
25
25
|
scaled_q_values: Optional[torch.Tensor] = None
|
|
26
26
|
scaled_denoised_curves: Optional[torch.Tensor] = None
|
|
27
|
+
key_padding_mask: Optional[torch.Tensor] = None
|
|
27
28
|
scaled_conditioning_params: Optional[torch.Tensor] = None
|
|
29
|
+
unscaled_q_values: Optional[torch.Tensor] = None
|
|
28
30
|
|
|
29
31
|
class RealTimeSimTrainer(Trainer):
|
|
30
32
|
"""Trainer with functionality to customize the sampled batch of data"""
|
|
@@ -74,6 +76,7 @@ class PointEstimatorTrainer(RealTimeSimTrainer):
|
|
|
74
76
|
scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
|
|
75
77
|
scaled_denoised_curves = get_scaled_or_none('curves', self.loader.curves_scaler.scale)
|
|
76
78
|
scaled_q_values = get_scaled_or_none('q_values', self.loader.q_generator.scale_q) if self.train_with_q_input else None
|
|
79
|
+
key_padding_mask = batch_data.get('key_padding_mask', None)
|
|
77
80
|
|
|
78
81
|
scaled_q_resolutions = get_scaled_or_none('q_resolutions', self.loader.smearing.scale_resolutions) if self.condition_on_q_resolutions else None
|
|
79
82
|
conditioning_params = []
|
|
@@ -92,6 +95,8 @@ class PointEstimatorTrainer(RealTimeSimTrainer):
|
|
|
92
95
|
scaled_q_values=scaled_q_values,
|
|
93
96
|
scaled_denoised_curves=scaled_denoised_curves,
|
|
94
97
|
scaled_conditioning_params=scaled_conditioning_params,
|
|
98
|
+
unscaled_q_values=batch_data['q_values'],
|
|
99
|
+
key_padding_mask=key_padding_mask,
|
|
95
100
|
)
|
|
96
101
|
|
|
97
102
|
def get_loss_dict(self, batch_data: BasicBatchData):
|
|
@@ -100,13 +105,17 @@ class PointEstimatorTrainer(RealTimeSimTrainer):
|
|
|
100
105
|
scaled_curves=batch_data.scaled_curves
|
|
101
106
|
scaled_bounds=batch_data.scaled_bounds
|
|
102
107
|
scaled_q_values=batch_data.scaled_q_values
|
|
108
|
+
key_padding_mask=batch_data.key_padding_mask
|
|
103
109
|
scaled_conditioning_params=batch_data.scaled_conditioning_params
|
|
110
|
+
unscaled_q_values=batch_data.unscaled_q_values
|
|
104
111
|
|
|
105
112
|
predicted_params = self.model(
|
|
106
113
|
curves = scaled_curves,
|
|
107
114
|
bounds = scaled_bounds,
|
|
108
115
|
q_values = scaled_q_values,
|
|
109
116
|
conditioning_params = scaled_conditioning_params,
|
|
117
|
+
key_padding_mask = key_padding_mask,
|
|
118
|
+
unscaled_q_values = unscaled_q_values,
|
|
110
119
|
)
|
|
111
120
|
|
|
112
121
|
if not self.rescale_loss_interval_width:
|
reflectorch/models/__init__.py
CHANGED
|
@@ -4,6 +4,7 @@ from reflectorch.models.encoders.conv_encoder import (
|
|
|
4
4
|
ConvAutoencoder,
|
|
5
5
|
)
|
|
6
6
|
from reflectorch.models.encoders.fno import FnoEncoder, SpectralConv1d
|
|
7
|
+
from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding
|
|
7
8
|
from reflectorch.models.encoders.conv_res_net import ConvResidualNet1D
|
|
8
9
|
|
|
9
10
|
|
|
@@ -14,4 +15,5 @@ __all__ = [
|
|
|
14
15
|
"ConvResidualNet1D",
|
|
15
16
|
"FnoEncoder",
|
|
16
17
|
"SpectralConv1d",
|
|
18
|
+
"IntegralConvEmbedding",
|
|
17
19
|
]
|
|
@@ -0,0 +1,390 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Union
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn, Tensor, stack, cat
|
|
6
|
+
from reflectorch.models.activations import activation_by_name
|
|
7
|
+
import reflectorch
|
|
8
|
+
|
|
9
|
+
###embedding network adapted from the PANPE repository
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"IntegralConvEmbedding",
|
|
13
|
+
]
|
|
14
|
+
|
|
15
|
+
class IntegralConvEmbedding(nn.Module):
    """Embedding network for curves given as (position, value) samples.

    The samples are projected onto one or more regular grids by
    ``IntegralKernelBlock`` instances (one block per entry of ``z_num``).
    Every block yields ``dim_embedding`` features; the per-block features are
    concatenated and reduced back to ``dim_embedding`` by a ``ResidualMLP``.

    Args:
        z_num: grid size, or a tuple of grid sizes (one kernel block each).
        z_range: optional ``(z_min, z_max)``. When given, each block is built
            from a ``(z_min, z_max, n)`` grid specification; otherwise each
            block receives the bare integer ``n``.
        in_dim: number of channels per sample.
        kernel_coef: width multiplier forwarded to every kernel block.
        dim_embedding: output size of each block and of the whole network.
        conv_dims: hidden channels of the per-block convolutional encoder.
        num_blocks: number of residual blocks in the ``ResidualMLP`` head.
        use_batch_norm: batch-norm switch of the ``ResidualMLP`` head.
        use_layer_norm: layer-norm switch of the ``ResidualMLP`` head.
        use_fft: forwarded to every kernel block.
        activation: activation forwarded to every kernel block.
        conv_activation: convolutional activation forwarded to every block.
        resnet_activation: activation of the ``ResidualMLP`` head.
    """

    def __init__(
        self,
        z_num: Union[int, tuple[int, ...]],
        z_range: tuple[float, float] = None,
        in_dim: int = 2,
        kernel_coef: int = 16,
        dim_embedding: int = 256,
        conv_dims: tuple[int, ...] = (32, 64, 128),
        num_blocks: int = 4,
        use_batch_norm: bool = False,
        use_layer_norm: bool = True,
        use_fft: bool = False,
        activation: str = "gelu",
        conv_activation: str = "lrelu",
        resnet_activation: str = "relu",
    ) -> None:
        super().__init__()

        grid_sizes = (z_num,) if isinstance(z_num, int) else z_num

        # Without a range every block gets a bare grid size; with a range it
        # gets a full (z_min, z_max, n) grid specification.
        if z_range is None:
            block_specs = grid_sizes
        else:
            z_min, z_max = z_range[0], z_range[1]
            block_specs = [(z_min, z_max, n) for n in grid_sizes]

        self.in_dim = in_dim

        self.kernels = nn.ModuleList(
            IntegralKernelBlock(
                spec,
                in_dim,
                kernel_coef=kernel_coef,
                latent_dim=dim_embedding,
                conv_dims=conv_dims,
                use_fft=use_fft,
                activation=activation,
                conv_activation=conv_activation,
            )
            for spec in block_specs
        )

        # The head consumes the concatenation of all block outputs.
        self.fc = reflectorch.models.networks.residual_net.ResidualMLP(
            dim_in=dim_embedding * len(grid_sizes),
            dim_out=dim_embedding,
            layer_width=2 * dim_embedding,
            num_blocks=num_blocks,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
            activation=resnet_activation,
        )

    def forward(self, q, y, drop_mask=None) -> Tensor:
        """Embed the samples ``(q, y)``.

        Args:
            q: sample positions, passed as ``x`` to every kernel block
                (presumably shape (batch, num_points) — TODO confirm).
            y: sample values, passed unchanged to every kernel block
                (presumably shape (batch, num_points, in_dim) — TODO confirm).
            drop_mask: optional mask forwarded unchanged to every kernel block.

        Returns:
            Tensor whose last dimension has size ``dim_embedding``.
        """
        block_outputs = [block(q, y, drop_mask=drop_mask) for block in self.kernels]
        return self.fc(cat(block_outputs, dim=-1))
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class IntegralKernelBlock(nn.Module):
    """Encode irregularly sampled ``(x, y)`` data into a ``latent_dim`` vector.

    An integral kernel first maps the samples onto a regular grid of ``z_num``
    points, giving a tensor of shape ``(batch, in_dim, z_num)``. That grid is
    summarized by a ``ConvEncoder`` into ``latent_dim`` features and also
    flattened. With ``use_fft=True``, the real and imaginary parts of its
    orthonormal real FFT are appended as well. Everything is concatenated and
    mapped to ``latent_dim`` by an ``FCBlock``.

    Args:
        z: either an ``int`` grid size, which selects ``FullIntegralKernel``,
            or a ``(z_min, z_max, z_num)`` tuple, which selects
            ``FastIntegralKernel``. The grid size must be even.
        in_dim: number of channels per sample in ``y``.
        kernel_coef: width multiplier forwarded to the integral kernel.
        latent_dim: size of the convolutional summary and of the output.
        conv_dims: hidden channels of the ``ConvEncoder``.
        use_fft: also feed the grid's real FFT to the final ``FCBlock``.
        activation: activation of ``FastIntegralKernel`` (not used by
            ``FullIntegralKernel``).
        conv_activation: activation of the ``ConvEncoder``.

    Examples:
        >>> x = torch.rand(2, 100)
        >>> y = torch.rand(2, 100, 3)
        >>> block = IntegralKernelBlock((0, 1, 10), in_dim=3, latent_dim=32)
        >>> output = block(x, y)
        >>> output.shape
        torch.Size([2, 32])

        >>> block = IntegralKernelBlock(10, in_dim=3, latent_dim=32)
        >>> output = block(x, y)
        >>> output.shape
        torch.Size([2, 32])
    """

    def __init__(
        self,
        z: Union[tuple[float, float, int], int],
        in_dim: int,
        kernel_coef: int = 2,
        latent_dim: int = 32,
        conv_dims: tuple[int, ...] = (32, 64, 128),
        use_fft: bool = False,
        activation: str = "gelu",
        conv_activation: str = "lrelu",
    ):
        super().__init__()

        if isinstance(z, int):
            z_num = z
            kernel = FullIntegralKernel(z_num, in_dim=in_dim, kernel_coef=kernel_coef)
        else:
            kernel = FastIntegralKernel(
                z, in_dim=in_dim, kernel_coef=kernel_coef, activation=activation
            )
            z_num = z[-1]

        # The use_fft feature count below assumes rfft yields z_num // 2 + 1
        # bins, i.e. 2 * (z_num // 2 + 1) == z_num + 2 real values, which
        # holds only for even z_num.
        assert z_num % 2 == 0, "z_num should be even"

        self.kernel = kernel
        self.z_num = z_num
        self.in_dim = in_dim
        self.latent_dim = latent_dim
        self.use_fft = use_fft

        # Flattened grid (in_dim * z_num) + convolutional summary (latent_dim).
        self.fc_in_dim = self.latent_dim + self.in_dim * self.z_num
        if self.use_fft:
            # Real + imaginary rfft parts: in_dim * (z_num + 2) extra values.
            self.fc_in_dim += self.in_dim * 2 + self.in_dim * self.z_num

        # NOTE(review): dim_avpool=8 is a fixed ConvEncoder setting; its exact
        # meaning is defined in reflectorch.models.encoders.conv_encoder.
        self.conv = reflectorch.models.encoders.conv_encoder.ConvEncoder(
            dim_avpool=8,
            hidden_channels=conv_dims,
            in_channels=in_dim,
            dim_embedding=latent_dim,
            activation=conv_activation,
        )
        self.fc = FCBlock(
            in_dim=self.fc_in_dim, hid_dim=self.latent_dim * 2, out_dim=self.latent_dim
        )

    def forward(self, x: Tensor, y: Tensor, drop_mask: Tensor = None) -> Tensor:
        """Encode the samples.

        Args:
            x: sample positions, forwarded to the integral kernel.
            y: sample values, forwarded to the integral kernel.
            drop_mask: optional mask forwarded to the integral kernel.

        Returns:
            Tensor of shape (batch, latent_dim).
        """
        x = self.kernel(x, y, drop_mask=drop_mask)  # (batch, in_dim, z_num)

        assert x.shape == (x.shape[0], self.in_dim, self.z_num)

        xc = self.conv(x)  # (batch, latent_dim)

        assert xc.shape == (x.shape[0], self.latent_dim)

        if self.use_fft:
            # Complex spectrum of shape (batch, in_dim, z_num // 2 + 1).
            fft_x = torch.fft.rfft(x, dim=-1, norm="ortho")

            # Real and imaginary parts side by side: (batch, in_dim, z_num + 2)
            # for even z_num.
            fft_x = torch.cat([fft_x.real, fft_x.imag], -1)

            assert fft_x.shape == (x.shape[0], x.shape[1], self.z_num + 2)

            fft_x = fft_x.flatten(1)  # (batch, in_dim * (z_num + 2))

            # (batch, in_dim * (2 * z_num + 2) + latent_dim)
            x = torch.cat([x.flatten(1), fft_x, xc], -1)
        else:
            x = torch.cat([x.flatten(1), xc], -1)  # (batch, in_dim * z_num + latent_dim)

        assert (
            x.shape[1] == self.fc_in_dim
        ), f"Expected dim {self.fc_in_dim}, got {x.shape[1]}"

        x = self.fc(x)  # (batch, latent_dim)

        return x
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class FastIntegralKernel(nn.Module):
    """Integral kernel that bins each sample onto the nearest point of a fixed grid.

    The grid is ``torch.linspace(z_min, z_max, z_num)``, stored as buffer
    ``z``. For every sample, an ``FCBlock`` computes ``in_dim`` weights from
    ``(x, nearest grid value, y)``. The weighted values ``weights * y`` are
    then averaged per grid point by ``compute_means``.

    Args:
        z: ``(z_min, z_max, z_num)`` passed to ``torch.linspace``. ``z_num``
            must be at least 2 because the grid spacing is read from
            ``z[1] - z[0]``.
        kernel_coef: hidden-width multiplier of the ``FCBlock``
            (hidden width ``kernel_coef * in_dim``).
        in_dim: number of channels per sample in ``y``.
        activation: activation name for the ``FCBlock``.
    """

    def __init__(
        self,
        z: tuple[float, float, int],
        kernel_coef: int = 16,
        in_dim: int = 1,
        activation: str = "gelu",
    ):
        super().__init__()

        z = torch.linspace(*z)

        self.kernel = FCBlock(
            in_dim + 2, kernel_coef * in_dim, in_dim, activation=activation
        )

        self.register_buffer("z", z)

    def _get_z(self, x: Tensor):
        """Return the nearest grid value and its index for every entry of ``x``.

        ``x`` has shape (batch_size, num_x). A position maps to index ``i``
        with ``z[i] - dz/2 < x <= z[i] + dz/2``. Positions outside the grid
        are clamped to the first or last grid point. Without the clamp,
        ``index_select`` fails on out-of-range indices. That would happen even
        for positions (e.g. padding) that ``compute_means`` later excludes via
        ``drop_mask``, because the lookup happens before masking.
        """
        dz = self.z[1] - self.z[0]
        indices = torch.ceil((x - self.z[0] - dz / 2) / dz).to(torch.int64)
        indices = indices.clamp(0, self.z.shape[0] - 1)

        z = torch.index_select(self.z, 0, indices.flatten()).view(*x.shape)

        return z, indices

    def forward(self, x: Tensor, y: Tensor, drop_mask=None):
        """Project the samples onto the grid.

        Args:
            x: sample positions, shape (batch, num_x).
            y: sample values, shape (batch, num_x, in_dim).
            drop_mask: optional boolean tensor of shape (batch, num_x),
                forwarded to ``compute_means`` (``True`` entries are kept).

        Returns:
            Tensor of shape (batch, in_dim, z_num).
        """
        z, indices = self._get_z(x)
        xz = torch.stack([x, z], -1)  # (batch, num_x, 2)
        kernel_input = torch.cat([xz, y], -1)  # (batch, num_x, in_dim + 2)
        output = self.kernel(kernel_input)  # (batch, num_x, in_dim)

        output = compute_means(
            output * y, indices, self.z.shape[-1], drop_mask=drop_mask
        )  # (batch, z_num, in_dim)

        output = output.swapaxes(1, 2)  # (batch, in_dim, z_num)

        return output
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class FullIntegralKernel(nn.Module):
    """Integral kernel whose MLP emits one weight per (grid point, channel) pair.

    Each sample ``(x_i, y_i)`` is mapped by ``self.kernel``
    (Linear -> LayerNorm -> ReLU -> Linear) to ``z_num * in_dim`` weights.
    These multiply ``y_i`` and are averaged over the samples; when a
    ``drop_mask`` is given, only samples where it is ``True`` are averaged.

    Args:
        z_num: number of output grid points.
        kernel_coef: hidden-width multiplier (hidden width ``z_num * kernel_coef``).
        in_dim: number of channels per sample in ``y``.
    """

    def __init__(
        self,
        z_num: int,
        kernel_coef: int = 1,
        in_dim: int = 1,
    ):
        super().__init__()

        self.z_num = z_num
        self.in_dim = in_dim

        hidden_width = z_num * kernel_coef
        self.kernel = nn.Sequential(
            nn.Linear(in_dim + 1, hidden_width),
            nn.LayerNorm(hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, z_num * in_dim),
        )

    def forward(self, x: Tensor, y: Tensor, drop_mask=None):
        """Project the samples onto ``z_num`` outputs.

        Args:
            x: sample positions, shape (batch_size, num_x).
            y: sample values, shape (batch_size, num_x, in_dim).
            drop_mask: optional boolean tensor of shape (batch_size, num_x);
                ``True`` marks samples to include in the average.

        Returns:
            Tensor of shape (batch_size, in_dim, z_num).
        """
        batch_size, num_x = x.shape

        samples = torch.cat([x.unsqueeze(-1), y], -1)  # (batch, num_x, in_dim + 1)
        weights = self.kernel(samples).reshape(
            batch_size, num_x, self.z_num, self.in_dim
        )  # (batch, num_x, z_num, in_dim)

        assert weights.shape == (batch_size, num_x, self.z_num, self.in_dim)
        assert y.shape == (batch_size, num_x, self.in_dim)

        # Broadcast y over the grid axis: (batch, num_x, 1, in_dim).
        weighted = weights * y.unsqueeze(2)  # (batch, num_x, z_num, in_dim)

        if drop_mask is None:
            pooled = weighted.mean(1)
        else:
            pooled = masked_mean(weighted, drop_mask)
        # pooled: (batch, z_num, in_dim)

        assert pooled.shape == (batch_size, self.z_num, self.in_dim), f"{pooled.shape}"

        return pooled.swapaxes(1, 2)  # (batch, in_dim, z_num)
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
class FCBlock(nn.Module):
    """Two-layer perceptron: Linear -> LayerNorm -> activation -> Linear.

    Args:
        in_dim: size of the input features.
        hid_dim: size of the hidden layer.
        out_dim: size of the output features.
        activation: name resolved through ``activation_by_name``.
    """

    def __init__(
        self,
        in_dim: int = 2,
        hid_dim: int = 16,
        out_dim: int = 16,
        activation: str = "gelu",
    ):
        super().__init__()

        # Attribute names determine state_dict keys and creation order
        # determines parameter initialisation order, so both stay fixed.
        self.fc1 = nn.Linear(in_dim, hid_dim)
        self.layer_norm = nn.LayerNorm(hid_dim)
        activation_cls = activation_by_name(activation)
        self.activation = activation_cls()
        self.fc2 = nn.Linear(hid_dim, out_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to the last dimension of ``x``."""
        hidden = self.activation(self.layer_norm(self.fc1(x)))
        return self.fc2(hidden)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def compute_means(x, indices, z: int, drop_mask: Tensor = None):
    """
    Average the rows of 'x' that share the same bin index, separately for each batch element.

    The computation is vectorized with ``scatter_add_`` (no Python loops). The
    accumulators use the dtype and device of 'x', so float64 and half inputs
    work as well as float32.

    Parameters:
        x (torch.Tensor): A floating-point tensor of shape (batch_size, n, d) with the values to average.
        indices (torch.Tensor): An int64 tensor of shape (batch_size, n) with the bin of every row.
            The values in 'indices' should be in the range [0, z-1].
        z (int): The number of bins. This determines the second dimension of the output tensor.
        drop_mask (torch.Tensor): Optional boolean tensor of shape (batch_size, n). Rows whose
            entry is True are KEPT; rows whose entry is False are excluded from both the sums
            and the counts. If None, all rows are used.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, z, d) with the mean of every bin.
            Bins that receive no rows are zeros.

    Example:
        >>> batch_size, n, d, z = 3, 4, 5, 6
        >>> indices = torch.randint(0, z, (batch_size, n))
        >>> x = torch.randn(batch_size, n, d)
        >>> y = compute_means(x, indices, z)
        >>> print(y.shape)
        torch.Size([3, 6, 5])
    """
    batch_size, _, d = x.shape
    drop = drop_mask is not None

    # With a mask, one extra "discard" bin at index z collects excluded rows.
    num_bins = z + int(drop)
    sums = torch.zeros(batch_size, num_bins, d, device=x.device, dtype=x.dtype)
    counts = torch.zeros(batch_size, num_bins, device=x.device, dtype=x.dtype)

    if drop:
        # Route every row that is not kept to the discard bin.
        indices = indices.masked_fill(~drop_mask, z)

    sums.scatter_add_(1, indices.unsqueeze(-1).expand_as(x), x)
    counts.scatter_add_(1, indices, torch.ones_like(indices, dtype=x.dtype))

    if drop:
        # Remove the discard bin.
        sums = sums[:, :-1]
        counts = counts[:, :-1]

    # Empty bins have count 0; clamping yields 0 / 1 = 0 instead of NaN.
    return sums / counts.unsqueeze(-1).clamp(min=1)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def masked_mean(x, mask):
    """
    Average tensor x along dimension 1, ignoring positions where mask is False.

    Args:
        x (torch.Tensor): A tensor of shape (batch, x_size, ...), typically (batch, x_size, z, d).
        mask (torch.Tensor): A boolean mask of shape (batch, x_size); True marks positions to include.

    Returns:
        torch.Tensor: Tensor of shape (batch, ...), e.g. (batch, z, d). Entries of a
            batch element whose mask is all False are zeros.

    Raises:
        TypeError: if mask is not a boolean tensor.
    """
    if mask.dtype != torch.bool:
        raise TypeError("Mask must be a boolean tensor.")

    # Append singleton dims so the mask broadcasts over every trailing dim of x
    # (two for the usual 4-D x, one for 3-D x, ...).
    mask = mask.reshape(*mask.shape, *([1] * (x.dim() - mask.dim())))

    sum_x = (x * mask).sum(dim=1)
    # The integer count is clamped to 1 so fully masked rows give 0 instead of NaN.
    count_x = mask.sum(dim=1).clamp(min=1)

    return sum_x / count_x
|
|
@@ -7,6 +7,7 @@ from torch import nn, cat, split, Tensor
|
|
|
7
7
|
|
|
8
8
|
from reflectorch.models.networks.residual_net import ResidualMLP
|
|
9
9
|
from reflectorch.models.encoders.conv_encoder import ConvEncoder
|
|
10
|
+
from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding
|
|
10
11
|
from reflectorch.models.encoders.fno import FnoEncoder
|
|
11
12
|
from reflectorch.models.activations import activation_by_name
|
|
12
13
|
|
|
@@ -18,7 +19,7 @@ class NetworkWithPriors(nn.Module):
|
|
|
18
19
|
:align: center
|
|
19
20
|
|
|
20
21
|
Args:
|
|
21
|
-
embedding_net_type (str): the type of embedding network, either 'conv' or '
|
|
22
|
+
embedding_net_type (str): the type of embedding network, either 'conv', 'fno' or 'integral_conv'.
|
|
22
23
|
embedding_net_kwargs (dict): dictionary containing the keyword arguments for the embedding network.
|
|
23
24
|
dim_out (int, optional): the dimension of the output produced by the MLP. Defaults to 8.
|
|
24
25
|
dim_conditioning_params (int, optional): the dimension of other parameters the network is conditioned on (e.g. for the smearing coefficient dq/q)
|
|
@@ -66,6 +67,8 @@ class NetworkWithPriors(nn.Module):
|
|
|
66
67
|
self.embedding_net = ConvEncoder(**embedding_net_kwargs)
|
|
67
68
|
elif embedding_net_type == 'fno':
|
|
68
69
|
self.embedding_net = FnoEncoder(**embedding_net_kwargs)
|
|
70
|
+
elif embedding_net_type == 'integral_conv':
|
|
71
|
+
self.embedding_net = IntegralConvEmbedding(**embedding_net_kwargs)
|
|
69
72
|
elif embedding_net_type == 'no_embedding_net':
|
|
70
73
|
self.embedding_net = nn.Identity()
|
|
71
74
|
else:
|
|
@@ -108,7 +111,7 @@ class NetworkWithPriors(nn.Module):
|
|
|
108
111
|
self.embedding_net.load_weights(pretrained_embedding_net)
|
|
109
112
|
|
|
110
113
|
|
|
111
|
-
def forward(self, curves, bounds, q_values=None, conditioning_params=None):
|
|
114
|
+
def forward(self, curves, bounds, q_values=None, conditioning_params=None, key_padding_mask=None, unscaled_q_values=None):
|
|
112
115
|
"""
|
|
113
116
|
Args:
|
|
114
117
|
scaled_curves (torch.Tensor): Input tensor of shape [batch_size, n_points] or [batch_size, n_channels, n_points].
|
|
@@ -121,13 +124,16 @@ class NetworkWithPriors(nn.Module):
|
|
|
121
124
|
curves = curves.unsqueeze(1)
|
|
122
125
|
|
|
123
126
|
additional_channels = []
|
|
124
|
-
if q_values is not None:
|
|
127
|
+
if q_values is not None and not isinstance(self.embedding_net, IntegralConvEmbedding):
|
|
125
128
|
additional_channels.append(q_values.unsqueeze(1))
|
|
126
129
|
|
|
127
130
|
if additional_channels:
|
|
128
131
|
curves = torch.cat([curves] + additional_channels, dim=1) # [batch_size, n_channels, n_points]
|
|
129
132
|
|
|
130
|
-
|
|
133
|
+
if isinstance(self.embedding_net, IntegralConvEmbedding):
|
|
134
|
+
x = self.embedding_net(q=unscaled_q_values.float(), y=curves.permute(0, 2, 1), drop_mask=key_padding_mask)
|
|
135
|
+
else:
|
|
136
|
+
x = self.embedding_net(curves)
|
|
131
137
|
|
|
132
138
|
if self.conditioning == 'concat':
|
|
133
139
|
x = torch.cat([x, bounds] + ([conditioning_params] if conditioning_params is not None else []), dim=-1)
|
reflectorch/runs/utils.py
CHANGED
|
@@ -278,7 +278,7 @@ def load_pretrained(model, model_name: str, saved_models_dir: Path):
|
|
|
278
278
|
raise FileNotFoundError(f'File {str(model_path)} does not exist.')
|
|
279
279
|
|
|
280
280
|
try:
|
|
281
|
-
pretrained = torch.load(model_path)
|
|
281
|
+
pretrained = torch.load(model_path, weights_only=False)
|
|
282
282
|
except Exception as err:
|
|
283
283
|
raise RuntimeError(f'Could not load model from {str(model_path)}') from err
|
|
284
284
|
|
|
@@ -295,6 +295,8 @@ def load_pretrained(model, model_name: str, saved_models_dir: Path):
|
|
|
295
295
|
def init_dset(config: dict):
|
|
296
296
|
"""Initializes the dataset / dataloader object"""
|
|
297
297
|
dset_cls = globals().get(config['cls']) if 'cls' in config else ReflectivityDataLoader
|
|
298
|
+
dset_kwargs = config.get('kwargs', {})
|
|
299
|
+
|
|
298
300
|
prior_sampler = init_from_conf(config['prior_sampler'])
|
|
299
301
|
intensity_noise = init_from_conf(config['intensity_noise'])
|
|
300
302
|
q_generator = init_from_conf(config['q_generator'])
|
|
@@ -309,6 +311,7 @@ def init_dset(config: dict):
|
|
|
309
311
|
curves_scaler=curves_scaler,
|
|
310
312
|
smearing=smearing,
|
|
311
313
|
q_noise=q_noise,
|
|
314
|
+
**dset_kwargs,
|
|
312
315
|
)
|
|
313
316
|
|
|
314
317
|
return dset
|
|
@@ -358,7 +361,7 @@ def convert_pt_to_safetensors(input_dir):
|
|
|
358
361
|
continue
|
|
359
362
|
|
|
360
363
|
print(f"Converting {pt_file_path} to .safetensors format.")
|
|
361
|
-
data_pt = torch.load(pt_file_path)
|
|
364
|
+
data_pt = torch.load(pt_file_path, weights_only=False)
|
|
362
365
|
model_state_dict = data_pt["model"]
|
|
363
366
|
model_state_dict = split_complex_tensors(model_state_dict) #handle tensors with complex dtype which are not natively supported by safetensors
|
|
364
367
|
|
reflectorch/utils.py
CHANGED
|
@@ -66,3 +66,33 @@ def energy_to_wavelength(energy: float):
|
|
|
66
66
|
def wavelength_to_energy(wavelength: float):
|
|
67
67
|
"""Conversion from photon wavelength (angstroms) to photon energy (eV)"""
|
|
68
68
|
return 1.2398 / wavelength * 1e4
|
|
69
|
+
|
|
70
|
+
def get_filtering_mask(Q, R, dR, threshold=0.3, consecutive=3,
                       remove_singles=True, remove_consecutives=True,
                       q_start_trunc=0.1):
    """Build a boolean mask of the reflectivity points to keep.

    A point is "noisy" when its relative error ``|dR / R|`` is at least
    ``threshold``.

    Args:
        Q: momentum-transfer values (array-like, 1-D).
        R: reflectivity values, same length as ``Q``.
        dR: reflectivity uncertainties, same length as ``Q``.
        threshold: relative-error limit for a point to count as noisy.
        consecutive: number of consecutive noisy points (among points with
            ``Q >= q_start_trunc``) that triggers truncation.
        remove_singles: drop every noisy point individually.
        remove_consecutives: drop all points from the start of the first run
            of ``consecutive`` noisy points (with ``Q >= q_start_trunc``) to
            the end of the curve.
        q_start_trunc: points with ``Q`` below this value are skipped by the
            run search; they neither extend nor reset a run.

    Returns:
        np.ndarray of bool, True for points to keep.

    Notes:
        ``R == 0`` yields an infinite (or NaN for ``dR == 0``) relative error.
        Infinite errors count as noisy; NaN errors do not. The corresponding
        numpy divide/invalid warnings are suppressed.
    """
    Q = np.asarray(Q)
    R = np.asarray(R)
    dR = np.asarray(dR)

    with np.errstate(divide='ignore', invalid='ignore'):
        rel_error = np.abs(dR / R)
    noisy = rel_error >= threshold

    # Points to drop; start from the individually noisy points if requested.
    drop = noisy.copy() if remove_singles else np.zeros(Q.shape, dtype=bool)

    if remove_consecutives:
        run = 0
        for i, (q, is_noisy) in enumerate(zip(Q, noisy)):
            if q < q_start_trunc:
                continue
            if is_noisy:
                run += 1
                if run >= consecutive:
                    # Truncate from the first point of the run onwards.
                    drop[i - consecutive + 1:] = True
                    break
            else:
                run = 0

    return ~drop
|
|
@@ -1,10 +1,9 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: reflectorch
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.4.0
|
|
4
4
|
Summary: A Pytorch-based package for the analysis of reflectometry data
|
|
5
5
|
Author-email: Vladimir Starostin <vladimir.starostin@uni-tuebingen.de>, Valentin Munteanu <valentin.munteanu@uni-tuebingen.de>
|
|
6
6
|
Maintainer-email: Valentin Munteanu <valentin.munteanu@uni-tuebingen.de>, Vladimir Starostin <vladimir.starostin@uni-tuebingen.de>, Alexander Hinderhofer <alexander.hinderhofer@uni-tuebingen.de>
|
|
7
|
-
License-Expression: MIT
|
|
8
7
|
Project-URL: Source, https://github.com/schreiber-lab/reflectorch/
|
|
9
8
|
Project-URL: Issues, https://github.com/schreiber-lab/reflectorch/issues
|
|
10
9
|
Project-URL: Documentation, https://schreiber-lab.github.io/reflectorch/
|
|
@@ -106,6 +105,8 @@ Configuration files and the corresponding pretrained model weights are hosted on
|
|
|
106
105
|
<!-- [](https://hub.docker.com/)
|
|
107
106
|
Docker images for reflectorch *will* be hosted on Dockerhub. -->
|
|
108
107
|
|
|
108
|
+
## Contributing
|
|
109
|
+
If you'd like to contribute to the package, please see our [Contributing Guidelines](CONTRIBUTING.md) for details.
|
|
109
110
|
|
|
110
111
|
## Citation
|
|
111
112
|
If you find our work useful in your research, please cite as follows:
|