reflectorch 1.3.0-py3-none-any.whl → 1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -63,6 +63,8 @@ def refl_fit(
     bounds: np.ndarray = None,
     error_bars: np.ndarray = None,
     scale_curve_func=np.log10,
+    method: str = 'trf', #'lm', 'trf'
+    polishing_max_nfev: int = None,
     reflectivity_kwargs: dict = None,
     **kwargs
 ):
@@ -77,14 +79,23 @@ def refl_fit(
         adjusted_bounds[1, i] += epsilon

     init_params = np.clip(init_params, *adjusted_bounds)
-    kwargs['bounds'] = adjusted_bounds
+    if method != 'lm':
+        kwargs['bounds'] = adjusted_bounds

     reflectivity_kwargs = reflectivity_kwargs or {}
     for key, value in reflectivity_kwargs.items():
         if isinstance(value, float):
             reflectivity_kwargs[key] = torch.tensor([[value]], dtype=torch.float64)
         elif isinstance(value, np.ndarray):
-            reflectivity_kwargs[key] = torch.tensor(value, dtype=torch.float32).unsqueeze(0)
+            reflectivity_kwargs[key] = torch.tensor(value, dtype=torch.float32).unsqueeze(0)
+
+    curve = np.clip(curve, a_min=1e-12, a_max=None)
+
+    if error_bars is not None and scale_curve_func == np.log10:
+        error_bars = np.clip(error_bars, a_min=1e-20, a_max=None)
+        scaled_error_bars = error_bars / (curve * np.log(10))
+    else:
+        scaled_error_bars = None

     res = curve_fit(
         f=get_scaled_curve_func(
@@ -93,10 +104,12 @@ def refl_fit(
             reflectivity_kwargs=reflectivity_kwargs,
         ),
         xdata=q,
-        ydata=scale_curve_func(curve),
+        ydata=scale_curve_func(curve).reshape(-1),
         p0=init_params,
-        sigma=error_bars if error_bars is not None else None,
+        sigma=scaled_error_bars,
         absolute_sigma=True,
+        method=method,
+        max_nfev=polishing_max_nfev,
         **kwargs
     )

@@ -184,7 +197,8 @@ def get_scaled_curve_func(
         fitted_curve = fitted_curve_tensor.squeeze().numpy()

         scaled_curve = scale_curve_func(fitted_curve)
-        return scaled_curve
+
+        return scaled_curve.reshape(-1)

     return scaled_curve_func
 
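For orientation, the new 'method' and 'polishing_max_nfev' arguments are forwarded to scipy.optimize.curve_fit, which only accepts parameter bounds for its bounded solvers. A self-contained sketch of that underlying scipy behaviour (illustrative data, not reflectorch code):

# Illustration of the scipy.optimize.curve_fit options that refl_fit now exposes.
import numpy as np
from scipy.optimize import curve_fit

def model(x, a, b):
    return a * np.exp(-b * x)

rng = np.random.default_rng(0)
x = np.linspace(0, 4, 50)
y = model(x, 2.5, 1.3) + 0.01 * rng.standard_normal(50)

# 'trf' supports parameter bounds; max_nfev caps the number of function evaluations.
popt_trf, _ = curve_fit(model, x, y, p0=[1.0, 1.0], bounds=([0, 0], [10, 10]),
                        method='trf', max_nfev=200)

# 'lm' (Levenberg-Marquardt) does not accept finite bounds, which is why refl_fit
# only sets kwargs['bounds'] when method != 'lm'.
popt_lm, _ = curve_fit(model, x, y, p0=[1.0, 1.0], method='lm')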
@@ -24,7 +24,9 @@ class BasicBatchData:
     scaled_sigmas: Optional[torch.Tensor] = None
     scaled_q_values: Optional[torch.Tensor] = None
     scaled_denoised_curves: Optional[torch.Tensor] = None
+    key_padding_mask: Optional[torch.Tensor] = None
     scaled_conditioning_params: Optional[torch.Tensor] = None
+    unscaled_q_values: Optional[torch.Tensor] = None

 class RealTimeSimTrainer(Trainer):
     """Trainer with functionality to customize the sampled batch of data"""
@@ -74,6 +76,7 @@ class PointEstimatorTrainer(RealTimeSimTrainer):
         scaled_curves = batch_data['scaled_noisy_curves'].to(torch.float32)
         scaled_denoised_curves = get_scaled_or_none('curves', self.loader.curves_scaler.scale)
         scaled_q_values = get_scaled_or_none('q_values', self.loader.q_generator.scale_q) if self.train_with_q_input else None
+        key_padding_mask = batch_data.get('key_padding_mask', None)

         scaled_q_resolutions = get_scaled_or_none('q_resolutions', self.loader.smearing.scale_resolutions) if self.condition_on_q_resolutions else None
         conditioning_params = []
@@ -92,6 +95,8 @@ class PointEstimatorTrainer(RealTimeSimTrainer):
             scaled_q_values=scaled_q_values,
             scaled_denoised_curves=scaled_denoised_curves,
             scaled_conditioning_params=scaled_conditioning_params,
+            unscaled_q_values=batch_data['q_values'],
+            key_padding_mask=key_padding_mask,
         )

     def get_loss_dict(self, batch_data: BasicBatchData):
@@ -100,13 +105,17 @@ class PointEstimatorTrainer(RealTimeSimTrainer):
         scaled_curves=batch_data.scaled_curves
         scaled_bounds=batch_data.scaled_bounds
         scaled_q_values=batch_data.scaled_q_values
+        key_padding_mask=batch_data.key_padding_mask
         scaled_conditioning_params=batch_data.scaled_conditioning_params
+        unscaled_q_values=batch_data.unscaled_q_values

         predicted_params = self.model(
             curves = scaled_curves,
             bounds = scaled_bounds,
             q_values = scaled_q_values,
             conditioning_params = scaled_conditioning_params,
+            key_padding_mask = key_padding_mask,
+            unscaled_q_values = unscaled_q_values,
         )

         if not self.rescale_loss_interval_width:
@@ -6,6 +6,7 @@ __all__ = [
     "ConvDecoder",
     "ConvAutoencoder",
     "FnoEncoder",
+    "IntegralConvEmbedding",
     "SpectralConv1d",
     "ConvResidualNet1D",
     "ResidualMLP",
@@ -4,6 +4,7 @@ from reflectorch.models.encoders.conv_encoder import (
     ConvAutoencoder,
 )
 from reflectorch.models.encoders.fno import FnoEncoder, SpectralConv1d
+from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding
 from reflectorch.models.encoders.conv_res_net import ConvResidualNet1D


@@ -14,4 +15,5 @@ __all__ = [
     "ConvResidualNet1D",
     "FnoEncoder",
     "SpectralConv1d",
+    "IntegralConvEmbedding",
 ]
@@ -0,0 +1,390 @@
+from __future__ import annotations
+from typing import Union
+
+import torch
+from torch import nn, Tensor, stack, cat
+from reflectorch.models.activations import activation_by_name
+import reflectorch
+
+###embedding network adapted from the PANPE repository
+
+__all__ = [
+    "IntegralConvEmbedding",
+]
+
+class IntegralConvEmbedding(nn.Module):
+    def __init__(
+        self,
+        z_num: Union[int, tuple[int, ...]],
+        z_range: tuple[float, float] = None,
+        in_dim: int = 2,
+        kernel_coef: int = 16,
+        dim_embedding: int = 256,
+        conv_dims: tuple[int, ...] = (32, 64, 128),
+        num_blocks: int = 4,
+        use_batch_norm: bool = False,
+        use_layer_norm: bool = True,
+        use_fft: bool = False,
+        activation: str = "gelu",
+        conv_activation: str = "lrelu",
+        resnet_activation: str = "relu",
+    ) -> None:
+        super().__init__()
+
+        if isinstance(z_num, int):
+            z_num = (z_num,)
+        num_kernel = len(z_num)
+
+        if z_range is not None:
+            zs = [(z_range[0], z_range[1], nz) for nz in z_num]
+        else:
+            zs = z_num
+
+        self.in_dim = in_dim
+
+        self.kernels = nn.ModuleList(
+            [
+                IntegralKernelBlock(
+                    z,
+                    in_dim,
+                    kernel_coef=kernel_coef,
+                    latent_dim=dim_embedding,
+                    conv_dims=conv_dims,
+                    use_fft=use_fft,
+                    activation=activation,
+                    conv_activation=conv_activation,
+                )
+                for z in zs
+            ]
+        )
+
+        self.fc = reflectorch.models.networks.residual_net.ResidualMLP(
+            dim_in=dim_embedding * num_kernel,
+            dim_out=dim_embedding,
+            layer_width=2 * dim_embedding,
+            num_blocks=num_blocks,
+            use_batch_norm=use_batch_norm,
+            use_layer_norm=use_layer_norm,
+            activation=resnet_activation,
+        )
+
+    def forward(self, q, y, drop_mask=None) -> Tensor:
+        x = cat([kernel(q, y, drop_mask=drop_mask) for kernel in self.kernels], dim=-1)
+        x = self.fc(x)
+
+        return x
+
+
+class IntegralKernelBlock(nn.Module):
+    """
+    Examples:
+        >>> x = torch.rand(2, 100)
+        >>> y = torch.rand(2, 100, 3)
+        >>> block = IntegralKernelBlock((0, 1, 10), in_dim=3, latent_dim=32)
+        >>> output = block(x, y)
+        >>> output.shape
+        torch.Size([2, 32])
+
+        >>> block = IntegralKernelBlock(10, in_dim=3, latent_dim=32)
+        >>> output = block(x, y)
+        >>> output.shape
+        torch.Size([2, 32])
+    """
+
+    def __init__(
+        self,
+        z: tuple[float, float, int] or int,
+        in_dim: int,
+        kernel_coef: int = 2,
+        latent_dim: int = 32,
+        conv_dims: tuple[int, ...] = (32, 64, 128),
+        use_fft: bool = False,
+        activation: str = "gelu",
+        conv_activation: str = "lrelu",
+    ):
+        super().__init__()
+
+        if isinstance(z, int):
+            z_num = z
+            kernel = FullIntegralKernel(z_num, in_dim=in_dim, kernel_coef=kernel_coef)
+        else:
+            kernel = FastIntegralKernel(
+                z, in_dim=in_dim, kernel_coef=kernel_coef, activation=activation
+            )
+            z_num = z[-1]
+
+        assert z_num % 2 == 0, "z_num should be even"
+
+        self.kernel = kernel
+        self.z_num = z_num
+        self.in_dim = in_dim
+        self.latent_dim = latent_dim
+        self.use_fft = use_fft
+
+        self.fc_in_dim = self.latent_dim + self.in_dim * self.z_num
+        if self.use_fft:
+            self.fc_in_dim += self.in_dim * 2 + self.in_dim * self.z_num
+
+        self.conv = reflectorch.models.encoders.conv_encoder.ConvEncoder(
+            dim_avpool=8,
+            hidden_channels=conv_dims,
+            in_channels=in_dim,
+            dim_embedding=latent_dim,
+            activation=conv_activation,
+        )
+        self.fc = FCBlock(
+            in_dim=self.fc_in_dim, hid_dim=self.latent_dim * 2, out_dim=self.latent_dim
+        )
+
+    def forward(self, x: Tensor, y: Tensor, drop_mask: Tensor = None) -> Tensor:
+        x = self.kernel(x, y, drop_mask=drop_mask)
+
+        assert x.shape == (x.shape[0], self.in_dim, self.z_num)
+
+        xc = self.conv(x)  # (batch, latent_dim)
+
+        assert xc.shape == (x.shape[0], self.latent_dim)
+
+        if self.use_fft:
+            fft_x = torch.fft.rfft(x, dim=-1, norm="ortho")  # (batch, in_dim, z_num)
+
+            fft_x = torch.cat(
+                [fft_x.real, fft_x.imag], -1
+            )  # (batch, in_dim, 2 * z_num)
+
+            assert fft_x.shape == (x.shape[0], x.shape[1], self.z_num + 2)
+
+            fft_x = fft_x.flatten(1)  # (batch, in_dim * (z_num + 2))
+
+            x = torch.cat(
+                [x.flatten(1), fft_x, xc], -1
+            )  # (batch, in_dim * z_num * 3 + latent_dim)
+        else:
+            x = torch.cat([x.flatten(1), xc], -1)
+
+        assert (
+            x.shape[1] == self.fc_in_dim
+        ), f"Expected dim {self.fc_in_dim}, got {x.shape[1]}"
+
+        x = self.fc(x)  # (batch, latent_dim)
+
+        return x
+
+
+class FastIntegralKernel(nn.Module):
+    def __init__(
+        self,
+        z: tuple[float, float, int],
+        kernel_coef: int = 16,
+        in_dim: int = 1,
+        activation: str = "gelu",
+    ):
+        super().__init__()
+
+        z = torch.linspace(*z)
+
+        self.kernel = FCBlock(
+            in_dim + 2, kernel_coef * in_dim, in_dim, activation=activation
+        )
+
+        self.register_buffer("z", z)
+
+    def _get_z(self, x: Tensor):
+        # x.shape == (batch_size, num_x)
+        dz = self.z[1] - self.z[0]
+        indices = torch.ceil((x - self.z[0] - dz / 2) / dz).to(torch.int64)
+
+        z = torch.index_select(self.z, 0, indices.flatten()).view(*x.shape)
+
+        return z, indices
+
+    def forward(self, x: Tensor, y: Tensor, drop_mask=None):
+        z, indices = self._get_z(x)
+        xz = torch.stack([x, z], -1)
+        kernel_input = torch.cat([xz, y], -1)
+        output = self.kernel(kernel_input)  # (batch, x_num, in_dim)
+
+        output = compute_means(
+            output * y, indices, self.z.shape[-1], drop_mask=drop_mask
+        )  # (batch, z_num, in_dim)
+
+        output = output.swapaxes(1, 2)  # (batch, in_dim, z_num)
+
+        return output
+
+
+class FullIntegralKernel(nn.Module):
+    def __init__(
+        self,
+        z_num: int,
+        kernel_coef: int = 1,
+        in_dim: int = 1,
+    ):
+        super().__init__()
+
+        self.z_num = z_num
+        self.in_dim = in_dim
+
+        self.kernel = nn.Sequential(
+            nn.Linear(in_dim + 1, z_num * kernel_coef),
+            nn.LayerNorm(z_num * kernel_coef),
+            nn.ReLU(),
+            nn.Linear(z_num * kernel_coef, z_num * in_dim),
+        )
+
+    def forward(self, x: Tensor, y: Tensor, drop_mask=None):
+        # x.shape == (batch_size, num_x)
+        # y.shape == (batch_size, num_x, in_dim)
+        # drop_mask.shape == (batch_size, num_x)
+
+        batch_size, num_x = x.shape
+
+        kernel_input = torch.cat([x.unsqueeze(-1), y], -1)  # (batch, x_num, in_dim + 1)
+        x = self.kernel(kernel_input)  # (batch, x_num, z_num * in_dim)
+        x = x.reshape(
+            *x.shape[:-1], self.z_num, self.in_dim
+        )  # (batch, x_num, z_num, in_dim)
+        # permute to get (batch, z_num, x_num, in_dim)
+        x = x.permute(0, 2, 1, 3)
+
+        y = y.unsqueeze(1)  # (batch, 1, x_num, in_dim)
+
+        assert x.shape == (
+            batch_size,
+            self.z_num,
+            num_x,
+            self.in_dim,
+        )  # (batch, z_num, in_dim, x_num)
+        assert y.shape == (
+            batch_size,
+            1,
+            num_x,
+            self.in_dim,
+        )  # (batch, 1, x_num, in_dim)
+
+        if drop_mask is not None:
+            x = x * y
+            x = x.permute(0, 2, 1, 3)  # (batch, x_num, z_num, in_dim)
+            x = masked_mean(x, drop_mask)
+        else:
+            x = (x * y).mean(-2)  # (batch, z_num, in_dim)
+
+        assert x.shape == (batch_size, self.z_num, self.in_dim), f"{x.shape}"
+
+        x = x.swapaxes(1, 2)  # (batch, in_dim, z_num)
+
+        return x
+
+
+class FCBlock(nn.Module):
+    def __init__(
+        self,
+        in_dim: int = 2,
+        hid_dim: int = 16,
+        out_dim: int = 16,
+        activation: str = "gelu",
+    ):
+        super().__init__()
+
+        self.fc1 = nn.Linear(in_dim, hid_dim)
+        self.layer_norm = nn.LayerNorm(hid_dim)
+        self.activation = activation_by_name(activation)()
+        self.fc2 = nn.Linear(hid_dim, out_dim)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.fc1(x)
+        x = self.layer_norm(x)
+        x = self.activation(x)
+        x = self.fc2(x)
+        return x
+        # return self.kernel(x)
+
+
+def compute_means(x, indices, z: int, drop_mask: Tensor = None):
+    """
+    Compute the mean values of tensor 'x' for each unique index in 'indices' across each batch.
+
+    This function calculates the mean of elements in 'x' that correspond to each unique index in 'indices'.
+    The computation is performed for each batch separately, and the function is optimized to avoid Python loops
+    by using advanced PyTorch operations.
+
+    Parameters:
+        x (torch.Tensor): A tensor of shape (batch_size, n, d) containing the values to be averaged.
+            'x' should be a floating-point tensor.
+        indices (torch.Tensor): An integer tensor of shape (batch_size, n) containing the indices.
+            The values in 'indices' should be in the range [0, z-1].
+        z (int): The number of unique indices. This determines the second dimension of the output tensor.
+        drop_mask (torch.Tensor): A boolean tensor of shape (batch_size, n) containing a mask for the indices to drop.
+            If None, all indices are used.
+
+    Returns:
+        torch.Tensor: A tensor of shape (batch_size, z, d) containing the mean values for each index in each batch.
+            If an index does not appear in a batch, its corresponding mean values are zeros.
+
+    Example:
+        >>> batch_size, n, d, z = 3, 4, 5, 6
+        >>> indices = torch.randint(0, z, (batch_size, n))
+        >>> x = torch.randn(batch_size, n, d)
+        >>> y = compute_means(x, indices, z)
+        >>> print(y.shape)
+        torch.Size([3, 6, 5])
+    """
+
+    batch_size, n, d = x.shape
+    device = x.device
+
+    drop = drop_mask is not None
+
+    # Initialize tensors to hold sums and counts
+    sums = torch.zeros(batch_size, z + int(drop), d, device=device)
+    counts = torch.zeros(batch_size, z + int(drop), device=device)
+
+    if drop_mask is not None:
+        # Set the values of the indices to drop to z
+        indices = indices.masked_fill(~drop_mask, z)
+
+    indices_expanded = indices.unsqueeze(-1).expand_as(x)
+    sums.scatter_add_(1, indices_expanded, x)
+    counts.scatter_add_(1, indices, torch.ones_like(indices, dtype=x.dtype))
+
+    if drop:
+        # Remove the z values from the sums and counts
+        sums = sums[:, :-1]
+        counts = counts[:, :-1]
+
+    # Compute the mean and handle division by zero
+    mean = sums / counts.unsqueeze(-1).clamp(min=1)
+
+    return mean
+
+
+def masked_mean(x, mask):
+    """
+    Computes the mean of tensor x along the x_size dimension,
+    while masking out elements where the corresponding value in the mask is False.
+
+    Args:
+        x (torch.Tensor): A tensor of shape (batch, x_size, z, d).
+        mask (torch.Tensor): A boolean mask of shape (batch, x_size).
+
+    Returns:
+        torch.Tensor: The result tensor of shape (batch, z, d) after applying the mask and computing the mean.
+    """
+    if not mask.dtype == torch.bool:
+        raise TypeError("Mask must be a boolean tensor.")
+
+    # Ensure the mask is broadcastable to the shape of x
+    mask = mask.unsqueeze(-1).unsqueeze(-1)
+    masked_x = x * mask
+
+    # Compute the sum and the count of valid (unmasked) elements along the x_size dimension
+    sum_x = masked_x.sum(dim=1)
+    count_x = mask.sum(dim=1)
+
+    # Avoid division by zero
+    count_x[count_x == 0] = 1
+
+    # Compute the mean
+    mean_x = sum_x / count_x
+
+    return mean_x
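A minimal sketch of exercising the new IntegralConvEmbedding in isolation; the shapes follow the IntegralKernelBlock docstring above (q of shape (batch, n_points), y of shape (batch, n_points, in_dim)), and the particular constructor values are illustrative:

# Standalone shape check for the new embedding network (illustrative values).
import torch
from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding

emb = IntegralConvEmbedding(z_num=64, z_range=(0.0, 0.3), in_dim=2, dim_embedding=128)

q = torch.rand(4, 100) * 0.3                 # unscaled q values, shape (batch, n_points)
y = torch.rand(4, 100, 2)                    # per-point input channels, shape (batch, n_points, in_dim)
mask = torch.ones(4, 100, dtype=torch.bool)  # drop_mask: True marks points that are kept (inferred from compute_means)

out = emb(q, y, drop_mask=mask)
print(out.shape)                             # torch.Size([4, 128])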
@@ -7,6 +7,7 @@ from torch import nn, cat, split, Tensor

 from reflectorch.models.networks.residual_net import ResidualMLP
 from reflectorch.models.encoders.conv_encoder import ConvEncoder
+from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding
 from reflectorch.models.encoders.fno import FnoEncoder
 from reflectorch.models.activations import activation_by_name

@@ -18,7 +19,7 @@ class NetworkWithPriors(nn.Module):
         :align: center

     Args:
-        embedding_net_type (str): the type of embedding network, either 'conv' or 'fno'.
+        embedding_net_type (str): the type of embedding network, either 'conv', 'fno' or 'integral_conv'.
         embedding_net_kwargs (dict): dictionary containing the keyword arguments for the embedding network.
         dim_out (int, optional): the dimension of the output produced by the MLP. Defaults to 8.
         dim_conditioning_params (int, optional): the dimension of other parameters the network is conditioned on (e.g. for the smearing coefficient dq/q)
@@ -66,6 +67,8 @@ class NetworkWithPriors(nn.Module):
             self.embedding_net = ConvEncoder(**embedding_net_kwargs)
         elif embedding_net_type == 'fno':
             self.embedding_net = FnoEncoder(**embedding_net_kwargs)
+        elif embedding_net_type == 'integral_conv':
+            self.embedding_net = IntegralConvEmbedding(**embedding_net_kwargs)
         elif embedding_net_type == 'no_embedding_net':
             self.embedding_net = nn.Identity()
         else:
@@ -108,7 +111,7 @@ class NetworkWithPriors(nn.Module):
             self.embedding_net.load_weights(pretrained_embedding_net)


-    def forward(self, curves, bounds, q_values=None, conditioning_params=None):
+    def forward(self, curves, bounds, q_values=None, conditioning_params=None, key_padding_mask=None, unscaled_q_values=None):
         """
         Args:
             scaled_curves (torch.Tensor): Input tensor of shape [batch_size, n_points] or [batch_size, n_channels, n_points].
@@ -121,13 +124,16 @@ class NetworkWithPriors(nn.Module):
             curves = curves.unsqueeze(1)

         additional_channels = []
-        if q_values is not None:
+        if q_values is not None and not isinstance(self.embedding_net, IntegralConvEmbedding):
             additional_channels.append(q_values.unsqueeze(1))

         if additional_channels:
             curves = torch.cat([curves] + additional_channels, dim=1)  # [batch_size, n_channels, n_points]

-        x = self.embedding_net(curves)
+        if isinstance(self.embedding_net, IntegralConvEmbedding):
+            x = self.embedding_net(q=unscaled_q_values.float(), y=curves.permute(0, 2, 1), drop_mask=key_padding_mask)
+        else:
+            x = self.embedding_net(curves)

         if self.conditioning == 'concat':
             x = torch.cat([x, bounds] + ([conditioning_params] if conditioning_params is not None else []), dim=-1)
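For reference, a hedged sketch of the embedding_net_kwargs that would accompany the new 'integral_conv' option, mirroring IntegralConvEmbedding.__init__ from this release; the values themselves are illustrative only:

# Hypothetical configuration for embedding_net_type='integral_conv' (example values).
embedding_net_type = 'integral_conv'
embedding_net_kwargs = dict(
    z_num=64,             # number of integration nodes per kernel block (must be even)
    z_range=(0.0, 0.3),   # q-range covered by the nodes; selects the FastIntegralKernel path
    in_dim=1,             # one input channel (the scaled curve); q enters separately as unscaled_q_values
    dim_embedding=128,    # dimension of the produced embedding
)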
reflectorch/runs/utils.py CHANGED
@@ -278,7 +278,7 @@ def load_pretrained(model, model_name: str, saved_models_dir: Path):
         raise FileNotFoundError(f'File {str(model_path)} does not exist.')

     try:
-        pretrained = torch.load(model_path)
+        pretrained = torch.load(model_path, weights_only=False)
     except Exception as err:
         raise RuntimeError(f'Could not load model from {str(model_path)}') from err

@@ -295,6 +295,8 @@ def load_pretrained(model, model_name: str, saved_models_dir: Path):
 def init_dset(config: dict):
     """Initializes the dataset / dataloader object"""
     dset_cls = globals().get(config['cls']) if 'cls' in config else ReflectivityDataLoader
+    dset_kwargs = config.get('kwargs', {})
+
     prior_sampler = init_from_conf(config['prior_sampler'])
     intensity_noise = init_from_conf(config['intensity_noise'])
     q_generator = init_from_conf(config['q_generator'])
@@ -309,6 +311,7 @@ def init_dset(config: dict):
         curves_scaler=curves_scaler,
         smearing=smearing,
         q_noise=q_noise,
+        **dset_kwargs,
     )

     return dset
@@ -358,7 +361,7 @@ def convert_pt_to_safetensors(input_dir):
             continue

         print(f"Converting {pt_file_path} to .safetensors format.")
-        data_pt = torch.load(pt_file_path)
+        data_pt = torch.load(pt_file_path, weights_only=False)
         model_state_dict = data_pt["model"]
         model_state_dict = split_complex_tensors(model_state_dict)  # handle tensors with complex dtype which are not natively supported by safetensors
reflectorch/utils.py CHANGED
@@ -66,3 +66,33 @@ def energy_to_wavelength(energy: float):
 def wavelength_to_energy(wavelength: float):
     """Conversion from photon wavelength (angstroms) to photon energy (eV)"""
     return 1.2398 / wavelength * 1e4
+
+def get_filtering_mask(Q, R, dR, threshold=0.3, consecutive=3,
+                       remove_singles=True, remove_consecutives=True,
+                       q_start_trunc=0.1):
+    Q, R, dR = Q.copy(), R.copy(), dR.copy()
+    rel_error = np.abs(dR / R)
+
+    # Mask for singles
+    mask_singles = (rel_error >= threshold) if remove_singles else np.zeros_like(Q, dtype=bool)
+
+    # Mask for truncation
+    mask_consecutive = np.zeros_like(Q, dtype=bool)
+    if remove_consecutives:
+        count = 0
+        cutoff_idx = None
+        for i in range(len(Q)):
+            if Q[i] < q_start_trunc:
+                continue
+            if rel_error[i] >= threshold:
+                count += 1
+                if count >= consecutive:
+                    cutoff_idx = i - consecutive + 1
+                    break
+            else:
+                count = 0
+        if cutoff_idx is not None:
+            mask_consecutive[cutoff_idx:] = True
+
+    final_mask = mask_singles | mask_consecutive
+    return ~final_mask
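A short sketch of applying the new helper before fitting; the file name is illustrative, and the returned array is a boolean keep-mask (True for points that pass the relative-error checks):

# Filtering noisy tail points with the new helper (illustrative data file).
import numpy as np
from reflectorch.utils import get_filtering_mask

Q, R, dR = np.loadtxt('measurement.dat', unpack=True)  # q values, reflectivity, error bars

keep = get_filtering_mask(Q, R, dR, threshold=0.3, consecutive=3)
Q_filt, R_filt, dR_filt = Q[keep], R[keep], dR[keep]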
@@ -1,10 +1,9 @@
 Metadata-Version: 2.4
 Name: reflectorch
-Version: 1.3.0
+Version: 1.4.0
 Summary: A Pytorch-based package for the analysis of reflectometry data
 Author-email: Vladimir Starostin <vladimir.starostin@uni-tuebingen.de>, Valentin Munteanu <valentin.munteanu@uni-tuebingen.de>
 Maintainer-email: Valentin Munteanu <valentin.munteanu@uni-tuebingen.de>, Vladimir Starostin <vladimir.starostin@uni-tuebingen.de>, Alexander Hinderhofer <alexander.hinderhofer@uni-tuebingen.de>
-License-Expression: MIT
 Project-URL: Source, https://github.com/schreiber-lab/reflectorch/
 Project-URL: Issues, https://github.com/schreiber-lab/reflectorch/issues
 Project-URL: Documentation, https://schreiber-lab.github.io/reflectorch/
@@ -106,6 +105,8 @@ Configuration files and the corresponding pretrained model weights are hosted on
 <!-- [![Docker](https://img.shields.io/badge/Docker-2496ED.svg?style=flat&logo=docker&logoColor=white)](https://hub.docker.com/)
 Docker images for reflectorch *will* be hosted on Dockerhub. -->

+## Contributing
+If you'd like to contribute to the package, please see our [Contributing Guidelines](CONTRIBUTING.md) for details.

 ## Citation
 If you find our work useful in your research, please cite as follows: