reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/models/encoders/fno.py
@@ -1,134 +1,134 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from reflectorch.models.activations import activation_by_name

class SpectralConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, modes):
        super().__init__()

        """
        1D Fourier layer. It does FFT, linear transform, and Inverse FFT.
        """

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes  # number of Fourier modes to multiply, at most floor(N/2) + 1

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, modes, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        # (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft(x)

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1) // 2 + 1, device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :self.modes] = self.compl_mul1d(x_ft[:, :, :self.modes], self.weights1)

        # Return to physical space
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return x


class FnoEncoder(nn.Module):
    """An embedding network based on the Fourier Neural Operator (FNO) architecture

    .. image:: ../documentation/fig_reflectometry_embedding_networks.png
        :width: 400px
        :align: center

    Args:
        in_channels (int): number of input channels
        dim_embedding (int): dimension of the output embedding
        modes (int): number of Fourier modes
        width_fno (int): number of channels of the intermediate representations
        n_fno_blocks (int): number of FNO blocks
        activation (str): the activation function
        fusion_self_attention (bool): whether to use fusion self-attention for merging the tokens (instead of the mean)
        fsa_activation (str): the activation function of the fusion self-attention block
    """
    def __init__(
        self,
        in_channels: int = 2,
        dim_embedding: int = 128,
        modes: int = 32,
        width_fno: int = 64,
        n_fno_blocks: int = 6,
        activation: str = 'gelu',
        fusion_self_attention: bool = False,
        fsa_activation: str = 'tanh',
    ):
        super().__init__()

        self.in_channels = in_channels
        self.dim_embedding = dim_embedding

        self.modes = modes
        self.width_fno = width_fno
        self.n_fno_blocks = n_fno_blocks
        self.activation = activation_by_name(activation)()
        self.fusion_self_attention = fusion_self_attention

        self.fc0 = nn.Linear(in_channels, width_fno)  # (r(q), q)
        self.spectral_convs = nn.ModuleList([
            SpectralConv1d(in_channels=width_fno, out_channels=width_fno, modes=modes) for _ in range(n_fno_blocks)
        ])
        self.w_convs = nn.ModuleList([
            nn.Conv1d(in_channels=width_fno, out_channels=width_fno, kernel_size=1) for _ in range(n_fno_blocks)
        ])
        self.fc_out = nn.Linear(width_fno, dim_embedding)

        if fusion_self_attention:
            self.fusion = FusionSelfAttention(embed_dim=width_fno, hidden_dim=2 * width_fno, activation=fsa_activation)

    def forward(self, x):
        """"""
        x = x.permute(0, 2, 1)  # (B, D, S) -> (B, S, D)
        x = self.fc0(x)
        x = x.permute(0, 2, 1)  # (B, S, D) -> (B, D, S)

        for i in range(self.n_fno_blocks):
            x1 = self.spectral_convs[i](x)
            x2 = self.w_convs[i](x)

            x = x1 + x2
            x = self.activation(x)

        if self.fusion_self_attention:
            x = x.permute(0, 2, 1)
            x = self.fusion(x)
        else:
            x = x.mean(dim=-1)

        x = self.fc_out(x)

        return x


class FusionSelfAttention(nn.Module):
    def __init__(self, embed_dim: int = 64, hidden_dim: int = 64, activation: str = 'gelu'):
        super().__init__()
        activation = activation_by_name(activation)()
        self.fuser = nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                   activation,
                                   nn.Linear(hidden_dim, 1, bias=False))

    def forward(self,
                c: torch.Tensor,            # (batch_size x seq_len x embed_dim)
                mask: torch.Tensor = None,  # (batch_size x seq_len)
                ):
        a = self.fuser(c)
        alpha = torch.exp(a) * mask.unsqueeze(-1) if mask is not None else torch.exp(a)
        alpha = alpha / alpha.sum(dim=1, keepdim=True)
        return (alpha * c).sum(dim=1)  # (batch_size x embed_dim)
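For orientation, a minimal usage sketch of the FnoEncoder above. The input layout (batch, channels, q-points) follows from the permutes in forward, and the import path is assumed from the file location shown in this diff.

import torch
from reflectorch.models.encoders.fno import FnoEncoder  # module path assumed from the file list

# two input channels, e.g. (r(q), q) per the fc0 comment above
encoder = FnoEncoder(in_channels=2, dim_embedding=128, modes=32,
                     width_fno=64, n_fno_blocks=6)
curves = torch.rand(8, 2, 256)   # (batch, channels, number of q points)
embedding = encoder(curves)      # rfft keeps 256 // 2 + 1 = 129 coefficients; the first 32 modes are mixed
print(embedding.shape)           # torch.Size([8, 128])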
reflectorch/models/encoders/integral_kernel_embedding.py (new file)
@@ -0,0 +1,390 @@

from __future__ import annotations
from typing import Union

import torch
from torch import nn, Tensor, stack, cat
from reflectorch.models.activations import activation_by_name
import reflectorch

###embedding network adapted from the PANPE repository

__all__ = [
    "IntegralConvEmbedding",
]

class IntegralConvEmbedding(nn.Module):
    def __init__(
        self,
        z_num: Union[int, tuple[int, ...]],
        z_range: tuple[float, float] = None,
        in_dim: int = 2,
        kernel_coef: int = 16,
        dim_embedding: int = 256,
        conv_dims: tuple[int, ...] = (32, 64, 128),
        num_blocks: int = 4,
        use_batch_norm: bool = False,
        use_layer_norm: bool = True,
        use_fft: bool = False,
        activation: str = "gelu",
        conv_activation: str = "lrelu",
        resnet_activation: str = "relu",
    ) -> None:
        super().__init__()

        if isinstance(z_num, int):
            z_num = (z_num,)
        num_kernel = len(z_num)

        if z_range is not None:
            zs = [(z_range[0], z_range[1], nz) for nz in z_num]
        else:
            zs = z_num

        self.in_dim = in_dim

        self.kernels = nn.ModuleList(
            [
                IntegralKernelBlock(
                    z,
                    in_dim,
                    kernel_coef=kernel_coef,
                    latent_dim=dim_embedding,
                    conv_dims=conv_dims,
                    use_fft=use_fft,
                    activation=activation,
                    conv_activation=conv_activation,
                )
                for z in zs
            ]
        )

        self.fc = reflectorch.models.networks.residual_net.ResidualMLP(
            dim_in=dim_embedding * num_kernel,
            dim_out=dim_embedding,
            layer_width=2 * dim_embedding,
            num_blocks=num_blocks,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
            activation=resnet_activation,
        )

    def forward(self, q, y, drop_mask=None) -> Tensor:
        x = cat([kernel(q, y, drop_mask=drop_mask) for kernel in self.kernels], dim=-1)
        x = self.fc(x)

        return x


class IntegralKernelBlock(nn.Module):
    """
    Examples:
        >>> x = torch.rand(2, 100)
        >>> y = torch.rand(2, 100, 3)
        >>> block = IntegralKernelBlock((0, 1, 10), in_dim=3, latent_dim=32)
        >>> output = block(x, y)
        >>> output.shape
        torch.Size([2, 32])

        >>> block = IntegralKernelBlock(10, in_dim=3, latent_dim=32)
        >>> output = block(x, y)
        >>> output.shape
        torch.Size([2, 32])
    """

    def __init__(
        self,
        z: tuple[float, float, int] or int,
        in_dim: int,
        kernel_coef: int = 2,
        latent_dim: int = 32,
        conv_dims: tuple[int, ...] = (32, 64, 128),
        use_fft: bool = False,
        activation: str = "gelu",
        conv_activation: str = "lrelu",
    ):
        super().__init__()

        if isinstance(z, int):
            z_num = z
            kernel = FullIntegralKernel(z_num, in_dim=in_dim, kernel_coef=kernel_coef)
        else:
            kernel = FastIntegralKernel(
                z, in_dim=in_dim, kernel_coef=kernel_coef, activation=activation
            )
            z_num = z[-1]

        assert z_num % 2 == 0, "z_num should be even"

        self.kernel = kernel
        self.z_num = z_num
        self.in_dim = in_dim
        self.latent_dim = latent_dim
        self.use_fft = use_fft

        self.fc_in_dim = self.latent_dim + self.in_dim * self.z_num
        if self.use_fft:
            self.fc_in_dim += self.in_dim * 2 + self.in_dim * self.z_num

        self.conv = reflectorch.models.encoders.conv_encoder.ConvEncoder(
            dim_avpool=8,
            hidden_channels=conv_dims,
            in_channels=in_dim,
            dim_embedding=latent_dim,
            activation=conv_activation,
        )
        self.fc = FCBlock(
            in_dim=self.fc_in_dim, hid_dim=self.latent_dim * 2, out_dim=self.latent_dim
        )

    def forward(self, x: Tensor, y: Tensor, drop_mask: Tensor = None) -> Tensor:
        x = self.kernel(x, y, drop_mask=drop_mask)

        assert x.shape == (x.shape[0], self.in_dim, self.z_num)

        xc = self.conv(x)  # (batch, latent_dim)

        assert xc.shape == (x.shape[0], self.latent_dim)

        if self.use_fft:
            fft_x = torch.fft.rfft(x, dim=-1, norm="ortho")  # (batch, in_dim, z_num)

            fft_x = torch.cat(
                [fft_x.real, fft_x.imag], -1
            )  # (batch, in_dim, 2 * z_num)

            assert fft_x.shape == (x.shape[0], x.shape[1], self.z_num + 2)

            fft_x = fft_x.flatten(1)  # (batch, in_dim * (z_num + 2))

            x = torch.cat(
                [x.flatten(1), fft_x, xc], -1
            )  # (batch, in_dim * z_num * 3 + latent_dim)
        else:
            x = torch.cat([x.flatten(1), xc], -1)

        assert (
            x.shape[1] == self.fc_in_dim
        ), f"Expected dim {self.fc_in_dim}, got {x.shape[1]}"

        x = self.fc(x)  # (batch, latent_dim)

        return x


class FastIntegralKernel(nn.Module):
    def __init__(
        self,
        z: tuple[float, float, int],
        kernel_coef: int = 16,
        in_dim: int = 1,
        activation: str = "gelu",
    ):
        super().__init__()

        z = torch.linspace(*z)

        self.kernel = FCBlock(
            in_dim + 2, kernel_coef * in_dim, in_dim, activation=activation
        )

        self.register_buffer("z", z)

    def _get_z(self, x: Tensor):
        # x.shape == (batch_size, num_x)
        dz = self.z[1] - self.z[0]
        indices = torch.ceil((x - self.z[0] - dz / 2) / dz).to(torch.int64)

        z = torch.index_select(self.z, 0, indices.flatten()).view(*x.shape)

        return z, indices

    def forward(self, x: Tensor, y: Tensor, drop_mask=None):
        z, indices = self._get_z(x)
        xz = torch.stack([x, z], -1)
        kernel_input = torch.cat([xz, y], -1)
        output = self.kernel(kernel_input)  # (batch, x_num, in_dim)

        output = compute_means(
            output * y, indices, self.z.shape[-1], drop_mask=drop_mask
        )  # (batch, z_num, in_dim)

        output = output.swapaxes(1, 2)  # (batch, in_dim, z_num)

        return output


class FullIntegralKernel(nn.Module):
    def __init__(
        self,
        z_num: int,
        kernel_coef: int = 1,
        in_dim: int = 1,
    ):
        super().__init__()

        self.z_num = z_num
        self.in_dim = in_dim

        self.kernel = nn.Sequential(
            nn.Linear(in_dim + 1, z_num * kernel_coef),
            nn.LayerNorm(z_num * kernel_coef),
            nn.ReLU(),
            nn.Linear(z_num * kernel_coef, z_num * in_dim),
        )

    def forward(self, x: Tensor, y: Tensor, drop_mask=None):
        # x.shape == (batch_size, num_x)
        # y.shape == (batch_size, num_x, in_dim)
        # drop_mask.shape == (batch_size, num_x)

        batch_size, num_x = x.shape

        kernel_input = torch.cat([x.unsqueeze(-1), y], -1)  # (batch, x_num, in_dim + 1)
        x = self.kernel(kernel_input)  # (batch, x_num, z_num * in_dim)
        x = x.reshape(
            *x.shape[:-1], self.z_num, self.in_dim
        )  # (batch, x_num, z_num, in_dim)
        # permute to get (batch, z_num, x_num, in_dim)
        x = x.permute(0, 2, 1, 3)

        y = y.unsqueeze(1)  # (batch, 1, x_num, in_dim)

        assert x.shape == (
            batch_size,
            self.z_num,
            num_x,
            self.in_dim,
        )  # (batch, z_num, in_dim, x_num)
        assert y.shape == (
            batch_size,
            1,
            num_x,
            self.in_dim,
        )  # (batch, 1, x_num, in_dim)

        if drop_mask is not None:
            x = x * y
            x = x.permute(0, 2, 1, 3)  # (batch, x_num, z_num, in_dim)
            x = masked_mean(x, drop_mask)
        else:
            x = (x * y).mean(-2)  # (batch, z_num, in_dim)

        assert x.shape == (batch_size, self.z_num, self.in_dim), f"{x.shape}"

        x = x.swapaxes(1, 2)  # (batch, in_dim, z_num)

        return x


class FCBlock(nn.Module):
    def __init__(
        self,
        in_dim: int = 2,
        hid_dim: int = 16,
        out_dim: int = 16,
        activation: str = "gelu",
    ):
        super().__init__()

        self.fc1 = nn.Linear(in_dim, hid_dim)
        self.layer_norm = nn.LayerNorm(hid_dim)
        self.activation = activation_by_name(activation)()
        self.fc2 = nn.Linear(hid_dim, out_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.layer_norm(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x
        # return self.kernel(x)


def compute_means(x, indices, z: int, drop_mask: Tensor = None):
    """
    Compute the mean values of tensor 'x' for each unique index in 'indices' across each batch.

    This function calculates the mean of elements in 'x' that correspond to each unique index in 'indices'.
    The computation is performed for each batch separately, and the function is optimized to avoid Python loops
    by using advanced PyTorch operations.

    Parameters:
        x (torch.Tensor): A tensor of shape (batch_size, n, d) containing the values to be averaged.
            'x' should be a floating-point tensor.
        indices (torch.Tensor): An integer tensor of shape (batch_size, n) containing the indices.
            The values in 'indices' should be in the range [0, z-1].
        z (int): The number of unique indices. This determines the second dimension of the output tensor.
        drop_mask (torch.Tensor): A boolean tensor of shape (batch_size, n) containing a mask for the indices to drop.
            If None, all indices are used.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, z, d) containing the mean values for each index in each batch.
            If an index does not appear in a batch, its corresponding mean values are zeros.

    Example:
        >>> batch_size, n, d, z = 3, 4, 5, 6
        >>> indices = torch.randint(0, z, (batch_size, n))
        >>> x = torch.randn(batch_size, n, d)
        >>> y = compute_means(x, indices, z)
        >>> print(y.shape)
        torch.Size([3, 6, 5])
    """

    batch_size, n, d = x.shape
    device = x.device

    drop = drop_mask is not None

    # Initialize tensors to hold sums and counts
    sums = torch.zeros(batch_size, z + int(drop), d, device=device)
    counts = torch.zeros(batch_size, z + int(drop), device=device)

    if drop_mask is not None:
        # Set the values of the indices to drop to z
        indices = indices.masked_fill(~drop_mask, z)

    indices_expanded = indices.unsqueeze(-1).expand_as(x)
    sums.scatter_add_(1, indices_expanded, x)
    counts.scatter_add_(1, indices, torch.ones_like(indices, dtype=x.dtype))

    if drop:
        # Remove the z values from the sums and counts
        sums = sums[:, :-1]
        counts = counts[:, :-1]

    # Compute the mean and handle division by zero
    mean = sums / counts.unsqueeze(-1).clamp(min=1)

    return mean


def masked_mean(x, mask):
    """
    Computes the mean of tensor x along the x_size dimension,
    while masking out elements where the corresponding value in the mask is False.

    Args:
        x (torch.Tensor): A tensor of shape (batch, x_size, z, d).
        mask (torch.Tensor): A boolean mask of shape (batch, x_size).

    Returns:
        torch.Tensor: The result tensor of shape (batch, z, d) after applying the mask and computing the mean.
    """
    if not mask.dtype == torch.bool:
        raise TypeError("Mask must be a boolean tensor.")

    # Ensure the mask is broadcastable to the shape of x
    mask = mask.unsqueeze(-1).unsqueeze(-1)
    masked_x = x * mask

    # Compute the sum and the count of valid (unmasked) elements along the x_size dimension
    sum_x = masked_x.sum(dim=1)
    count_x = mask.sum(dim=1)

    # Avoid division by zero
    count_x[count_x == 0] = 1

    # Compute the mean
    mean_x = sum_x / count_x

    return mean_x
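A hedged usage sketch for the new IntegralConvEmbedding: the shapes follow the comments in FullIntegralKernel.forward (q is (batch, n), y is (batch, n, in_dim)), and it assumes the ConvEncoder and ResidualMLP helpers accept the keyword arguments this file passes them.

import torch
from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding

# one FastIntegralKernel on 64 evenly spaced z-nodes over q in [0, 1);
# z_num must be even (asserted in IntegralKernelBlock)
net = IntegralConvEmbedding(z_num=64, z_range=(0.0, 1.0), in_dim=2, dim_embedding=256)
q = torch.rand(8, 128)             # (batch, n_points): scattering-vector grid
y = torch.rand(8, 128, 2)          # (batch, n_points, in_dim): per-point features
mask = torch.rand(8, 128) > 0.1    # optional boolean drop_mask; True marks points that are kept
embedding = net(q, y, drop_mask=mask)
print(embedding.shape)             # torch.Size([8, 256])

Passing a tuple such as z_num=(32, 64) would instantiate one kernel block per entry and concatenate their latent vectors before the final ResidualMLP.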
reflectorch/models/networks/__init__.py
@@ -1,14 +1,14 @@

from reflectorch.models.networks.mlp_networks import (
    NetworkWithPriors,
    NetworkWithPriorsConvEmb,
    NetworkWithPriorsFnoEmb,
)
from reflectorch.models.networks.residual_net import ResidualMLP


__all__ = [
    "ResidualMLP",
    "NetworkWithPriors",
    "NetworkWithPriorsConvEmb",
    "NetworkWithPriorsFnoEmb",
]
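Since __all__ re-exports these names, downstream code can import them from the subpackage directly, e.g.:

from reflectorch.models.networks import NetworkWithPriors, ResidualMLP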