reflectorch 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,390 @@
1
+ from __future__ import annotations
2
+ from typing import Union
3
+
4
+ import torch
5
+ from torch import nn, Tensor, stack, cat
6
+ from reflectorch.models.activations import activation_by_name
7
+ import reflectorch
8
+
9
# Embedding network adapted from the PANPE repository.

# Only IntegralConvEmbedding is exported by `from ... import *`; the kernel and
# block classes and the averaging helpers below are not listed here.
__all__ = [
    "IntegralConvEmbedding",
]
14
+
15
class IntegralConvEmbedding(nn.Module):
    """Embed a curve ``(q, y)`` with one or more integral-kernel blocks and a residual MLP.

    One ``IntegralKernelBlock`` is built per entry of ``z_num``. Each block maps the
    curve to a ``dim_embedding``-sized vector. The per-block vectors are concatenated
    (``dim_embedding * len(z_num)`` features) and passed through a ``ResidualMLP``
    with ``dim_out=dim_embedding``.

    Args:
        z_num: Grid size per kernel. A single int is treated as one kernel.
            ``IntegralKernelBlock`` asserts that every value is even.
        z_range: If given, ``(z_min, z_max)`` is combined with each ``z_num`` into a
            ``(z_min, z_max, z_num)`` tuple, so the block uses ``FastIntegralKernel``.
            If ``None``, the bare int is passed on and the block uses
            ``FullIntegralKernel``.
        in_dim: Number of channels of ``y`` (its last dimension).
        kernel_coef: Width multiplier forwarded to every kernel.
        dim_embedding: Latent size of each block and output size of the final MLP.
        conv_dims: Hidden channels of the convolutional encoder inside each block.
        num_blocks: Number of residual blocks in the final ``ResidualMLP``.
        use_batch_norm: Forwarded to ``ResidualMLP``.
        use_layer_norm: Forwarded to ``ResidualMLP``.
        use_fft: If True, each block also feeds rFFT features to its head.
        activation: Activation name forwarded to each block. It is only used there
            when the block builds a ``FastIntegralKernel``.
        conv_activation: Activation name for the convolutional encoders.
        resnet_activation: Activation name for the final ``ResidualMLP``.
    """

    def __init__(
        self,
        z_num: Union[int, tuple[int, ...]],
        z_range: Union[tuple[float, float], None] = None,
        in_dim: int = 2,
        kernel_coef: int = 16,
        dim_embedding: int = 256,
        conv_dims: tuple[int, ...] = (32, 64, 128),
        num_blocks: int = 4,
        use_batch_norm: bool = False,
        use_layer_norm: bool = True,
        use_fft: bool = False,
        activation: str = "gelu",
        conv_activation: str = "lrelu",
        resnet_activation: str = "relu",
    ) -> None:
        super().__init__()

        # Normalise to a tuple so one code path handles single and multiple kernels.
        if isinstance(z_num, int):
            z_num = (z_num,)
        num_kernel = len(z_num)

        # A tuple spec selects FastIntegralKernel inside IntegralKernelBlock;
        # a bare int selects FullIntegralKernel.
        if z_range is not None:
            zs = [(z_range[0], z_range[1], nz) for nz in z_num]
        else:
            zs = z_num

        self.in_dim = in_dim

        self.kernels = nn.ModuleList(
            [
                IntegralKernelBlock(
                    z,
                    in_dim,
                    kernel_coef=kernel_coef,
                    latent_dim=dim_embedding,
                    conv_dims=conv_dims,
                    use_fft=use_fft,
                    activation=activation,
                    conv_activation=conv_activation,
                )
                for z in zs
            ]
        )

        # Accessed through the package namespace; this relies on
        # reflectorch.models.networks.residual_net having been imported already.
        self.fc = reflectorch.models.networks.residual_net.ResidualMLP(
            dim_in=dim_embedding * num_kernel,
            dim_out=dim_embedding,
            layer_width=2 * dim_embedding,
            num_blocks=num_blocks,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
            activation=resnet_activation,
        )

    def forward(self, q: Tensor, y: Tensor, drop_mask: Union[Tensor, None] = None) -> Tensor:
        """Embed a batch of curves.

        Args:
            q: Abscissa values. The kernels expect shape (batch, num_points).
            y: Curve values. The kernels expect shape (batch, num_points, in_dim).
            drop_mask: Optional boolean mask of shape (batch, num_points). True entries
                are used and False entries are ignored by the kernels.

        Returns:
            The ``ResidualMLP`` output for the concatenated block embeddings. Given
            ``dim_out=dim_embedding``, this is presumably (batch, dim_embedding).
        """
        # Each block returns (batch, dim_embedding); concatenate along features.
        x = cat([kernel(q, y, drop_mask=drop_mask) for kernel in self.kernels], dim=-1)
        x = self.fc(x)

        return x
76
+
77
+
78
class IntegralKernelBlock(nn.Module):
    """Embed a curve with a single integral kernel, a conv encoder and an MLP head.

    The kernel is a ``FullIntegralKernel`` when ``z`` is an int, and a
    ``FastIntegralKernel`` over ``torch.linspace(*z)`` otherwise. It maps ``(x, y)``
    to a tensor of shape (batch, in_dim, z_num). A ``ConvEncoder`` turns that into
    ``latent_dim`` features. The head input is built by concatenating:

    * the flattened kernel output,
    * its real/imaginary rFFT, if ``use_fft`` is set,
    * the conv features.

    An ``FCBlock`` then projects the head input to ``latent_dim``.

    Args:
        z: Either ``z_num`` (int) or ``(z_min, z_max, z_num)``. ``z_num`` must be even.
        in_dim: Number of channels of ``y``.
        kernel_coef: Width multiplier forwarded to the kernel.
        latent_dim: Size of the conv features and of the block output.
        conv_dims: Hidden channels of the ``ConvEncoder``.
        use_fft: Whether to append rFFT features before the head.
        activation: Activation name, used only by ``FastIntegralKernel``. The head
            ``FCBlock`` keeps its default activation.
        conv_activation: Activation name for the ``ConvEncoder``.

    Examples:
        >>> x = torch.rand(2, 100)
        >>> y = torch.rand(2, 100, 3)
        >>> block = IntegralKernelBlock((0, 1, 10), in_dim=3, latent_dim=32)
        >>> output = block(x, y)
        >>> output.shape
        torch.Size([2, 32])

        >>> block = IntegralKernelBlock(10, in_dim=3, latent_dim=32)
        >>> output = block(x, y)
        >>> output.shape
        torch.Size([2, 32])
    """

    def __init__(
        self,
        z: Union[tuple[float, float, int], int],
        in_dim: int,
        kernel_coef: int = 2,
        latent_dim: int = 32,
        conv_dims: tuple[int, ...] = (32, 64, 128),
        use_fft: bool = False,
        activation: str = "gelu",
        conv_activation: str = "lrelu",
    ):
        super().__init__()

        if isinstance(z, int):
            z_num = z
            kernel = FullIntegralKernel(z_num, in_dim=in_dim, kernel_coef=kernel_coef)
        else:
            kernel = FastIntegralKernel(
                z, in_dim=in_dim, kernel_coef=kernel_coef, activation=activation
            )
            z_num = z[-1]

        # The rFFT shape check in forward relies on an even length:
        # z_num // 2 + 1 complex bins -> z_num + 2 real values.
        assert z_num % 2 == 0, "z_num should be even"

        self.kernel = kernel
        self.z_num = z_num
        self.in_dim = in_dim
        self.latent_dim = latent_dim
        self.use_fft = use_fft

        # Head input: conv features + flattened kernel output
        # [+ in_dim * (z_num + 2) rFFT values when use_fft].
        self.fc_in_dim = self.latent_dim + self.in_dim * self.z_num
        if self.use_fft:
            self.fc_in_dim += self.in_dim * 2 + self.in_dim * self.z_num

        self.conv = reflectorch.models.encoders.conv_encoder.ConvEncoder(
            dim_avpool=8,
            hidden_channels=conv_dims,
            in_channels=in_dim,
            dim_embedding=latent_dim,
            activation=conv_activation,
        )
        self.fc = FCBlock(
            in_dim=self.fc_in_dim, hid_dim=self.latent_dim * 2, out_dim=self.latent_dim
        )

    def forward(self, x: Tensor, y: Tensor, drop_mask: Union[Tensor, None] = None) -> Tensor:
        """Embed a batch of curves into (batch, latent_dim).

        Args:
            x: Abscissa values, forwarded to the kernel.
            y: Curve values, forwarded to the kernel.
            drop_mask: Optional boolean mask forwarded to the kernel (True = keep).
        """
        x = self.kernel(x, y, drop_mask=drop_mask)

        assert x.shape == (x.shape[0], self.in_dim, self.z_num)

        xc = self.conv(x)  # (batch, latent_dim)

        assert xc.shape == (x.shape[0], self.latent_dim)

        if self.use_fft:
            fft_x = torch.fft.rfft(x, dim=-1, norm="ortho")  # (batch, in_dim, z_num // 2 + 1), complex

            fft_x = torch.cat(
                [fft_x.real, fft_x.imag], -1
            )  # (batch, in_dim, z_num + 2) for even z_num

            assert fft_x.shape == (x.shape[0], x.shape[1], self.z_num + 2)

            fft_x = fft_x.flatten(1)  # (batch, in_dim * (z_num + 2))

            x = torch.cat(
                [x.flatten(1), fft_x, xc], -1
            )  # (batch, in_dim * (2 * z_num + 2) + latent_dim)
        else:
            x = torch.cat([x.flatten(1), xc], -1)  # (batch, in_dim * z_num + latent_dim)

        assert (
            x.shape[1] == self.fc_in_dim
        ), f"Expected dim {self.fc_in_dim}, got {x.shape[1]}"

        x = self.fc(x)  # (batch, latent_dim)

        return x
172
+
173
+
174
class FastIntegralKernel(nn.Module):
    """Integral kernel evaluated on a fixed, uniformly spaced grid ``z``.

    Each input point ``x`` is snapped to its nearest grid node. An ``FCBlock`` maps
    ``(x, z_node, y)`` to ``in_dim`` weights. These multiply ``y``, and the weighted
    values are averaged per grid node, giving one value per (node, channel).

    Args:
        z: ``(z_min, z_max, z_num)``, passed to ``torch.linspace`` to build the grid
            (registered as the buffer ``z``).
        kernel_coef: Hidden-width multiplier of the kernel MLP
            (hidden size = ``kernel_coef * in_dim``).
        in_dim: Number of channels of ``y``.
        activation: Activation name for the kernel MLP.
    """

    def __init__(
        self,
        z: tuple[float, float, int],
        kernel_coef: int = 16,
        in_dim: int = 1,
        activation: str = "gelu",
    ):
        super().__init__()

        z = torch.linspace(*z)

        # Input features per point: x, its grid node z, and the in_dim channels of y.
        self.kernel = FCBlock(
            in_dim + 2, kernel_coef * in_dim, in_dim, activation=activation
        )

        self.register_buffer("z", z)

    def _get_z(self, x: Tensor):
        """Snap every ``x`` to its nearest grid node.

        Args:
            x: Tensor of shape (batch_size, num_x).

        Returns:
            ``(z, indices)``: the node values and their int64 indices, both shaped
            like ``x``. Exact midpoints go to the lower node. Points outside the
            grid are assigned to the nearest edge node.
        """
        dz = self.z[1] - self.z[0]
        # Shifting by half a step before ``ceil`` rounds to the nearest node.
        indices = torch.ceil((x - self.z[0] - dz / 2) / dz).to(torch.int64)
        # Points outside [z_min - dz/2, z_max + dz/2] would otherwise give negative
        # or too-large indices. Examples are padded q values hidden by drop_mask, or
        # data beyond the grid. The lookup below would then fail, on CUDA with a
        # device-side assert.
        indices = indices.clamp(0, self.z.shape[0] - 1)

        z = self.z[indices]

        return z, indices

    def forward(self, x: Tensor, y: Tensor, drop_mask=None):
        """Apply the kernel and average per grid node.

        Args:
            x: Abscissa values of shape (batch, num_x).
            y: Curve values of shape (batch, num_x, in_dim).
            drop_mask: Optional boolean mask of shape (batch, num_x). True entries are
                averaged and False entries are excluded (see ``compute_means``).

        Returns:
            Tensor of shape (batch, in_dim, z_num).
        """
        z, indices = self._get_z(x)
        xz = torch.stack([x, z], -1)  # (batch, num_x, 2)
        kernel_input = torch.cat([xz, y], -1)  # (batch, num_x, in_dim + 2)
        output = self.kernel(kernel_input)  # (batch, num_x, in_dim)

        output = compute_means(
            output * y, indices, self.z.shape[-1], drop_mask=drop_mask
        )  # (batch, z_num, in_dim)

        output = output.swapaxes(1, 2)  # (batch, in_dim, z_num)

        return output
214
+
215
+
216
class FullIntegralKernel(nn.Module):
    """Integral kernel with a learned weight for every output node.

    An MLP maps each input point ``(x, y)`` to ``z_num * in_dim`` weights. These
    multiply ``y`` and are averaged over the input points, yielding a tensor of
    shape (batch, in_dim, z_num).

    Args:
        z_num: Number of output nodes.
        kernel_coef: Hidden-width multiplier (hidden size = ``z_num * kernel_coef``).
        in_dim: Number of channels of ``y``.
    """

    def __init__(
        self,
        z_num: int,
        kernel_coef: int = 1,
        in_dim: int = 1,
    ):
        super().__init__()

        self.z_num = z_num
        self.in_dim = in_dim

        hidden = z_num * kernel_coef
        # Layer order fixes the state_dict keys (kernel.0 / kernel.1 / kernel.3)
        # and the order in which parameters are initialised.
        self.kernel = nn.Sequential(
            nn.Linear(in_dim + 1, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, z_num * in_dim),
        )

    def forward(self, x: Tensor, y: Tensor, drop_mask=None):
        """Weight ``y`` point-wise and average over the points.

        Args:
            x: Abscissa values of shape (batch, num_x).
            y: Curve values of shape (batch, num_x, in_dim).
            drop_mask: Optional boolean mask of shape (batch, num_x). When given, the
                average is taken with ``masked_mean`` (True entries are kept).

        Returns:
            Tensor of shape (batch, in_dim, z_num).
        """
        batch_size, num_x = x.shape

        features = torch.cat([x.unsqueeze(-1), y], -1)  # (batch, num_x, in_dim + 1)
        weights = self.kernel(features)  # (batch, num_x, z_num * in_dim)
        weights = weights.reshape(*weights.shape[:-1], self.z_num, self.in_dim)
        weights = weights.permute(0, 2, 1, 3)  # (batch, z_num, num_x, in_dim)

        values = y.unsqueeze(1)  # (batch, 1, num_x, in_dim)

        assert weights.shape == (batch_size, self.z_num, num_x, self.in_dim)
        assert values.shape == (batch_size, 1, num_x, self.in_dim)

        if drop_mask is None:
            pooled = (weights * values).mean(-2)  # average over num_x
        else:
            # masked_mean reduces over dim 1, so move num_x back there first.
            weighted = (weights * values).permute(0, 2, 1, 3)  # (batch, num_x, z_num, in_dim)
            pooled = masked_mean(weighted, drop_mask)

        assert pooled.shape == (batch_size, self.z_num, self.in_dim), f"{pooled.shape}"

        return pooled.swapaxes(1, 2)  # (batch, in_dim, z_num)
277
+
278
+
279
class FCBlock(nn.Module):
    """Two-layer perceptron: Linear -> LayerNorm -> activation -> Linear.

    Args:
        in_dim: Input feature size.
        hid_dim: Hidden feature size (also the LayerNorm size).
        out_dim: Output feature size.
        activation: Name looked up with ``activation_by_name``. The returned class is
            instantiated without arguments.
    """

    def __init__(
        self,
        in_dim: int = 2,
        hid_dim: int = 16,
        out_dim: int = 16,
        activation: str = "gelu",
    ):
        super().__init__()

        # Registration order (fc1, layer_norm, activation, fc2) is kept stable so
        # parameter initialisation and state_dict keys are unaffected.
        self.fc1 = nn.Linear(in_dim, hid_dim)
        self.layer_norm = nn.LayerNorm(hid_dim)
        activation_cls = activation_by_name(activation)
        self.activation = activation_cls()
        self.fc2 = nn.Linear(hid_dim, out_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block along the last dimension of ``x``."""
        hidden = self.activation(self.layer_norm(self.fc1(x)))
        return self.fc2(hidden)
301
+
302
+
303
def compute_means(x, indices, z: int, drop_mask: Tensor = None):
    """
    Compute the mean of the rows of 'x' that share the same index, per batch element.

    Rows are accumulated with ``scatter_add_``, so no Python loop is needed.

    Parameters:
        x (torch.Tensor): Floating-point tensor of shape (batch_size, n, d) with the
            values to average.
        indices (torch.Tensor): Integer tensor of shape (batch_size, n). Values must be
            in the range [0, z-1].
        z (int): Number of bins; this is the second dimension of the output.
        drop_mask (torch.Tensor): Optional boolean tensor of shape (batch_size, n).
            Entries that are True are averaged. Entries that are False are excluded
            from both sums and counts. If None, all entries are used.

    Returns:
        torch.Tensor: Tensor of shape (batch_size, z, d) with the dtype and device of
        'x'. Bins with no contributing entries are zero.

    Example:
        >>> batch_size, n, d, z = 3, 4, 5, 6
        >>> indices = torch.randint(0, z, (batch_size, n))
        >>> x = torch.randn(batch_size, n, d)
        >>> y = compute_means(x, indices, z)
        >>> print(y.shape)
        torch.Size([3, 6, 5])
    """
    batch_size, _, d = x.shape

    # With a mask, one extra "trash" bin (index z) collects the excluded entries;
    # it is sliced off before the division.
    num_bins = z + 1 if drop_mask is not None else z

    # Accumulators must share x's dtype: scatter_add_ requires self and src to have
    # the same dtype. Without this, float64 / float16 / bfloat16 inputs raise.
    sums = x.new_zeros(batch_size, num_bins, d)
    counts = x.new_zeros(batch_size, num_bins)

    if drop_mask is not None:
        indices = indices.masked_fill(~drop_mask, z)

    sums.scatter_add_(1, indices.unsqueeze(-1).expand_as(x), x)
    counts.scatter_add_(1, indices, torch.ones_like(indices, dtype=x.dtype))

    if drop_mask is not None:
        sums = sums[:, :z]
        counts = counts[:, :z]

    # clamp(min=1) turns empty bins into 0 / 1 = 0 instead of NaN.
    return sums / counts.unsqueeze(-1).clamp(min=1)
359
+
360
+
361
def masked_mean(x, mask):
    """
    Compute the mean of ``x`` along dim 1, ignoring entries where ``mask`` is False.

    Args:
        x (torch.Tensor): Tensor of shape (batch, x_size, ...), e.g. (batch, x_size, z, d).
        mask (torch.Tensor): Boolean mask of shape (batch, x_size). True entries are averaged.

    Returns:
        torch.Tensor: ``x`` averaged over dim 1, e.g. shape (batch, z, d). If every
        entry of a batch element is masked out, its result is zero.

    Raises:
        TypeError: If ``mask`` is not a boolean tensor.
    """
    if mask.dtype != torch.bool:
        raise TypeError("Mask must be a boolean tensor.")

    # Append trailing singleton dims so the (batch, x_size) mask broadcasts against x.
    mask = mask.reshape(*mask.shape, *([1] * (x.dim() - mask.dim())))

    # masked_fill (rather than x * mask) makes masked entries contribute exactly 0,
    # even if they hold NaN/inf padding (0 * nan == nan would leak into the sum).
    sum_x = x.masked_fill(~mask, 0).sum(dim=1)

    # Number of kept entries; at least 1 to avoid dividing by zero.
    count_x = mask.sum(dim=1).clamp(min=1)

    return sum_x / count_x
@@ -0,0 +1,14 @@
1
+ from reflectorch.models.networks.mlp_networks import (
2
+ NetworkWithPriors,
3
+ NetworkWithPriorsConvEmb,
4
+ NetworkWithPriorsFnoEmb,
5
+ )
6
+ from reflectorch.models.networks.residual_net import ResidualMLP
7
+
8
+
9
# Names re-exported by `from reflectorch.models.networks import *`.
__all__ = [
    "ResidualMLP",
    "NetworkWithPriors",
    "NetworkWithPriorsConvEmb",
    "NetworkWithPriorsFnoEmb",
]