reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -128
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -280
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -223
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -1374
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +36 -36
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +523 -516
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -19
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -262
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -200
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -15
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -19
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -434
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -404
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +97 -97
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.4.0.dist-info/RECORD +0 -88
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/models/encoders/fno.py
@@ -1,134 +1,134 @@

Lines 1-133 of this hunk are removed and re-added with identical text (the underlying change, e.g. line endings or whitespace, is not visible in the rendered diff); line 134 is unchanged context. The content is therefore shown once:

import torch
import torch.nn as nn
import torch.nn.functional as F

from reflectorch.models.activations import activation_by_name

class SpectralConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, modes):
        super().__init__()

        """
        1D Fourier layer. It does FFT, linear transform, and Inverse FFT.
        """

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes #Number of Fourier modes to multiply, at most floor(N/2) + 1

        self.scale = (1 / (in_channels*out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, modes, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        batchsize = x.shape[0]
        #Compute Fourier coeffcients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft(x)

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1)//2 + 1, device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :self.modes] = self.compl_mul1d(x_ft[:, :, :self.modes], self.weights1)

        #Return to physical space
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return x


class FnoEncoder(nn.Module):
    """An embedding network based on the Fourier Neural Operator (FNO) architecture

    .. image:: ../documentation/fig_reflectometry_embedding_networks.png
        :width: 400px
        :align: center

    Args:
        in_channels (int): number of input channels
        dim_embedding (int): dimension of the output embedding
        modes (int): number of Fourier modes
        width_fno (int): number of channels of the intermediate representations
        n_fno_blocks (int): number of FNO blocks
        activation (str): the activation function
        fusion_self_attention (bool): whether to use fusion self attention for merging the tokens (instead of mean)
        fsa_activation (str): the activation function of the fusion self attention block
    """
    def __init__(
            self,
            in_channels: int = 2,
            dim_embedding: int = 128,
            modes: int = 32,
            width_fno: int = 64,
            n_fno_blocks: int = 6,
            activation: str = 'gelu',
            fusion_self_attention: bool = False,
            fsa_activation: str = 'tanh',
    ):
        super().__init__()


        self.in_channels = in_channels
        self.dim_embedding = dim_embedding

        self.modes = modes
        self.width_fno = width_fno
        self.n_fno_blocks = n_fno_blocks
        self.activation = activation_by_name(activation)()
        self.fusion_self_attention = fusion_self_attention


        self.fc0 = nn.Linear(in_channels, width_fno) #(r(q), q)
        self.spectral_convs = nn.ModuleList([
            SpectralConv1d(in_channels=width_fno, out_channels=width_fno, modes=modes) for _ in range(n_fno_blocks)
        ])
        self.w_convs = nn.ModuleList([
            nn.Conv1d(in_channels=width_fno, out_channels=width_fno, kernel_size=1) for _ in range(n_fno_blocks)
        ])
        self.fc_out = nn.Linear(width_fno, dim_embedding)

        if fusion_self_attention:
            self.fusion = FusionSelfAttention(embed_dim=width_fno, hidden_dim=2*width_fno, activation=fsa_activation)

    def forward(self, x):
        """"""

        x = x.permute(0, 2, 1) #(B, D, S) -> (B, S, D)
        x = self.fc0(x)
        x = x.permute(0, 2, 1) #(B, S, D) -> (B, D, S)

        for i in range(self.n_fno_blocks):
            x1 = self.spectral_convs[i](x)
            x2 = self.w_convs[i](x)

            x = x1 + x2
            x = self.activation(x)

        if self.fusion_self_attention:
            x = x.permute(0, 2, 1)
            x = self.fusion(x)
        else:
            x = x.mean(dim=-1)

        x = self.fc_out(x)

        return x


class FusionSelfAttention(nn.Module):
    def __init__(self, embed_dim: int = 64, hidden_dim: int = 64, activation: str = 'gelu'):
        super().__init__()
        activation = activation_by_name(activation)()
        self.fuser = nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                   activation,
                                   nn.Linear(hidden_dim, 1, bias=False))

    def forward(self,
                c: torch.Tensor, # (batch_size x seq_len x embed_dim)
                mask: torch.Tensor = None, # (batch_size x seq_len)
                ):
        a = self.fuser(c)
        alpha = torch.exp(a)*mask.unsqueeze(-1) if mask is not None else torch.exp(a)
        alpha = alpha/alpha.sum(dim=1, keepdim=True)
        return (alpha*c).sum(dim=1) # (batch_size x embed_dim)
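
To make the SpectralConv1d contract concrete, here is a minimal shape check. This is an illustrative sketch, not part of the package: the import path is assumed from the file list above (reflectorch/models/encoders/fno.py), and all tensor sizes are arbitrary.

import torch
from reflectorch.models.encoders.fno import SpectralConv1d  # path assumed from the file list

layer = SpectralConv1d(in_channels=4, out_channels=4, modes=16)
x = torch.rand(2, 4, 100)      # (batch, channels, points); rfft keeps 100//2 + 1 = 51 modes
y = layer(x)                   # only the first 16 modes are mixed; the rest stay zero
assert y.shape == (2, 4, 100)  # irfft(n=100) returns to the original length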
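
And an end-to-end pass through FnoEncoder, again as a hedged sketch: the import path is assumed as above, the (r(q), q) channel layout is read off the "#(r(q), q)" comment on fc0, the curves are random placeholders rather than real measurements, and the constructor arguments simply restate the defaults from the hunk.

import torch
from reflectorch.models.encoders.fno import FnoEncoder  # path assumed from the file list

# 8 placeholder curves with 128 q-points; channel 0 stands in for r(q), channel 1 for q.
# Input layout is (batch, channels, points), matching the permutes in forward().
curves = torch.rand(8, 2, 128)

encoder = FnoEncoder(in_channels=2, dim_embedding=128, modes=32,
                     width_fno=64, n_fno_blocks=6)
print(encoder(curves).shape)      # torch.Size([8, 128]) -- tokens mean-pooled over q

# With fusion_self_attention=True, FusionSelfAttention replaces the mean pooling.
fsa_encoder = FnoEncoder(fusion_self_attention=True)
print(fsa_encoder(curves).shape)  # torch.Size([8, 128])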