reflectorch 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of reflectorch might be problematic.
- reflectorch/data_generation/__init__.py +2 -0
- reflectorch/data_generation/dataset.py +27 -7
- reflectorch/data_generation/noise.py +115 -9
- reflectorch/data_generation/priors/parametric_models.py +90 -15
- reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
- reflectorch/data_generation/priors/sampler_strategies.py +67 -3
- reflectorch/data_generation/q_generator.py +31 -11
- reflectorch/data_generation/reflectivity/__init__.py +53 -11
- reflectorch/data_generation/reflectivity/kinematical.py +4 -5
- reflectorch/data_generation/reflectivity/smearing.py +25 -10
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/smearing.py +42 -11
- reflectorch/data_generation/utils.py +92 -18
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/inference_model.py +216 -103
- reflectorch/inference/plotting.py +98 -0
- reflectorch/inference/scipy_fitter.py +84 -7
- reflectorch/ml/__init__.py +2 -0
- reflectorch/ml/basic_trainer.py +18 -6
- reflectorch/ml/callbacks.py +5 -4
- reflectorch/ml/loggers.py +25 -0
- reflectorch/ml/schedulers.py +116 -0
- reflectorch/ml/trainers.py +122 -23
- reflectorch/models/__init__.py +1 -1
- reflectorch/models/encoders/__init__.py +0 -2
- reflectorch/models/encoders/conv_encoder.py +54 -40
- reflectorch/models/encoders/fno.py +23 -16
- reflectorch/models/networks/__init__.py +2 -0
- reflectorch/models/networks/mlp_networks.py +324 -152
- reflectorch/models/networks/residual_net.py +31 -5
- reflectorch/runs/train.py +0 -1
- reflectorch/runs/utils.py +43 -9
- {reflectorch-1.2.1.dist-info → reflectorch-1.3.0.dist-info}/METADATA +19 -17
- {reflectorch-1.2.1.dist-info → reflectorch-1.3.0.dist-info}/RECORD +37 -34
- {reflectorch-1.2.1.dist-info → reflectorch-1.3.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.2.1.dist-info → reflectorch-1.3.0.dist-info/licenses}/LICENSE.txt +0 -0
- {reflectorch-1.2.1.dist-info → reflectorch-1.3.0.dist-info}/top_level.txt +0 -0
reflectorch/models/encoders/conv_encoder.py (content truncated in the source diff view is marked with `…`):

```diff
@@ -11,7 +11,6 @@ __all__ = [
     "ConvEncoder",
     "ConvDecoder",
     "ConvAutoencoder",
-    "ConvVAE",
 ]
 
 logger = logging.getLogger(__name__)
@@ -23,7 +22,7 @@ class ConvEncoder(nn.Module):
     Args:
         in_channels (int, optional): the number of input channels. Defaults to 1.
         hidden_channels (tuple, optional): the number of intermediate channels of each convolutional layer. Defaults to (32, 64, 128, 256, 512).
-        …
+        dim_embedding (int, optional): the dimension of the output latent embedding. Defaults to 64.
         dim_avpool (int, optional): the output size of the adaptive average pooling layer. Defaults to 1.
         use_batch_norm (bool, optional): whether to use batch normalization. Defaults to True.
         activation (str, optional): the type of activation function. Defaults to 'relu'.
@@ -31,9 +30,11 @@ class ConvEncoder(nn.Module):
     def __init__(self,
                  in_channels: int = 1,
                  hidden_channels: tuple = (32, 64, 128, 256, 512),
-                 …
+                 kernel_size: int = 3,
+                 dim_embedding: int = 64,
                  dim_avpool: int = 1,
                  use_batch_norm: bool = True,
+                 use_se: bool = False,
                  activation: str = 'relu',
                  ):
         super().__init__()
@@ -44,22 +45,24 @@ class ConvEncoder(nn.Module):
 
         for h in hidden_channels:
             layers = [
-                nn.Conv1d(in_channels, out_channels=h, kernel_size=…
+                nn.Conv1d(in_channels, out_channels=h, kernel_size=kernel_size, stride=2, padding=kernel_size // 2),
                 activation(),
             ]
 
             if use_batch_norm:
                 layers.insert(1, nn.BatchNorm1d(h))
 
+            if use_se:
+                layers.insert(2, SEBlock(h))
+
             modules.append(nn.Sequential(*layers))
             in_channels = h
 
         self.core = nn.Sequential(*modules)
         self.avpool = nn.AdaptiveAvgPool1d(dim_avpool)
-        self.fc = nn.Linear(hidden_channels[-1] * dim_avpool, …
+        self.fc = nn.Linear(hidden_channels[-1] * dim_avpool, dim_embedding)
 
     def forward(self, x):
-        """"""
         if len(x.shape) < 3:
             x = x.unsqueeze(1)
         x = self.core(x)
```
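The `ConvEncoder` changes make the convolution kernel size configurable (with `padding=kernel_size // 2` preserving the stride-2 halving for odd kernels) and add optional squeeze-and-excitation blocks. A minimal usage sketch based on the constructor shown above; the batch size and curve length are illustrative assumptions, and the final embedding shape assumes the forward pass finishes with the average pooling and fully connected layers:

```python
import torch
from reflectorch.models.encoders.conv_encoder import ConvEncoder

# Sketch only: exercises the arguments added in 1.3.0 (kernel_size, use_se).
encoder = ConvEncoder(
    in_channels=1,
    hidden_channels=(32, 64, 128, 256, 512),
    kernel_size=5,     # new: configurable kernel size
    dim_embedding=64,
    use_se=True,       # new: inserts an SEBlock into each conv block (see below)
)

curves = torch.rand(8, 256)   # batch of 8 curves, 256 points each (illustrative)
embedding = encoder(curves)   # unsqueezed to (8, 1, 256) internally
print(embedding.shape)        # expected: torch.Size([8, 64])
```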
reflectorch/models/encoders/conv_encoder.py (continued):

```diff
@@ -100,6 +103,7 @@ class ConvDecoder(nn.Module):
                  hidden_channels: tuple = (512, 256, 128, 64, 32),
                  dim_latent: int = 64,
                  in_size: int = 8,
+                 kernel_size: int = 3,
                  use_batch_norm: bool = True,
                  activation: str = 'relu',
                  ):
@@ -119,9 +123,9 @@ class ConvDecoder(nn.Module):
                 nn.ConvTranspose1d(
                     hidden_channels[i],
                     hidden_channels[i + 1],
-                    kernel_size=3,
+                    kernel_size=kernel_size, #3
                     stride=2,
-                    padding=1,
+                    padding=kernel_size // 2, #1
                     output_padding=1,
                 ),
                 nn.BatchNorm1d(hidden_channels[i + 1]) if use_batch_norm else nn.Identity(),
@@ -134,9 +138,9 @@ class ConvDecoder(nn.Module):
         self.final_layer = nn.Sequential(
             nn.ConvTranspose1d(hidden_channels[-1],
                                hidden_channels[-1],
-                               kernel_size=3,
+                               kernel_size=kernel_size, #3
                                stride=2,
-                               padding=1,
+                               padding=kernel_size // 2, #1
                                output_padding=1),
             nn.BatchNorm1d(hidden_channels[-1]) if use_batch_norm else nn.Identity(),
             activation(),
@@ -160,46 +164,56 @@ class ConvAutoencoder(nn.Module):
                  decoder_hidden_channels: tuple = (512, 256, 128, 64, 32),
                  dim_latent: int = 64,
                  dim_avpool: int = 1,
+                 kernel_size: int = 3,
                  use_batch_norm: bool = True,
                  activation: str = 'relu',
                  decoder_in_size: int = 8,
                  **kwargs
                  ):
         super().__init__()
-        self.encoder = ConvEncoder(…
-        …
+        self.encoder = ConvEncoder(
+            in_channels=in_channels,
+            hidden_channels=encoder_hidden_channels,
+            kernel_size=kernel_size,
+            dim_embedding=dim_latent,
+            dim_avpool=dim_avpool,
+            use_batch_norm=use_batch_norm,
+            activation=activation,
+            **kwargs)
+
+        self.decoder = ConvDecoder(
+            hidden_channels=decoder_hidden_channels,
+            dim_latent=dim_latent,
+            in_size=decoder_in_size,
+            kernel_size=kernel_size,
+            use_batch_norm=use_batch_norm,
+            activation=activation,
+            **kwargs)
 
     def forward(self, x):
         return self.decoder(self.encoder(x))
 
-class …
-    """…
-    def __init__(self,
-                 in_channels: int = 1,
-                 encoder_hidden_channels: tuple = (32, 64, 128, 256, 512),
-                 decoder_hidden_channels: tuple = (512, 256, 128, 64, 32),
-                 dim_latent: int = 64,
-                 dim_avpool: int = 1,
-                 use_batch_norm: bool = True,
-                 activation: str = 'relu',
-                 decoder_in_size: int = 8,
-                 **kwargs
-                 ):
+class SEBlock(nn.Module):
+    """Squeeze-and-excitation block (https://arxiv.org/abs/1709.01507) """
+    def __init__(self, in_channels, reduction=16):
         super().__init__()
-        self.…
-        self.…
+        self.fc1 = nn.Linear(in_channels, in_channels // reduction, bias=False)
+        self.fc2 = nn.Linear(in_channels // reduction, in_channels, bias=False)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
 
     def forward(self, x):
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        return …
+        batch_size, channels, _ = x.size()
+
+        #Squeeze
+        se = self.global_avg_pool(x).view(batch_size, channels)
+
+        #Excitation
+        se = self.fc1(se)
+        se = self.relu(se)
+        se = self.fc2(se)
+        se = self.sigmoid(se).view(batch_size, channels, 1)
+
+        #Scale the input feature maps (channel-wise attention)
+        return x * se
```
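`SEBlock` is new in 1.3.0 and reproduced in full above, so its behavior can be checked directly: it computes one sigmoid gate per channel from that channel's global average and rescales the whole channel by it. A small sanity check (the import path is assumed from the file layout; `SEBlock` is not in `__all__`, but a direct module import works):

```python
import torch
from reflectorch.models.encoders.conv_encoder import SEBlock

se = SEBlock(in_channels=64, reduction=16)
x = torch.rand(2, 64, 100)    # (batch, channels, length)
y = se(x)
print(y.shape)                # torch.Size([2, 64, 100]): shape is preserved

# Each channel is scaled by a single gate in (0, 1), so the ratio y / x
# is constant along the length dimension:
gate = (y / x)[0, 0]
print(torch.allclose(gate, gate[0].expand_as(gate)))  # True
```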
reflectorch/models/encoders/fno.py:

```diff
@@ -46,28 +46,30 @@ class FnoEncoder(nn.Module):
         :align: center
 
     Args:
-        …
+        in_channels (int): number of input channels
         dim_embedding (int): dimension of the output embedding
         modes (int): number of Fourier modes
        width_fno (int): number of channels of the intermediate representations
         n_fno_blocks (int): number of FNO blocks
         activation (str): the activation function
-        fusion_self_attention (bool): …
+        fusion_self_attention (bool): whether to use fusion self attention for merging the tokens (instead of mean)
+        fsa_activation (str): the activation function of the fusion self attention block
     """
     def __init__(
             self,
-            …
+            in_channels: int = 2,
             dim_embedding: int = 128,
             modes: int = 32,
             width_fno: int = 64,
             n_fno_blocks: int = 6,
             activation: str = 'gelu',
             fusion_self_attention: bool = False,
+            fsa_activation: str = 'tanh',
     ):
         super().__init__()
 
 
-        self.…
+        self.in_channels = in_channels
         self.dim_embedding = dim_embedding
 
         self.modes = modes
@@ -77,13 +79,17 @@ class FnoEncoder(nn.Module):
         self.fusion_self_attention = fusion_self_attention
 
 
-        self.fc0 = nn.Linear(…
-        self.spectral_convs = nn.ModuleList([…
-        …
+        self.fc0 = nn.Linear(in_channels, width_fno) #(r(q), q)
+        self.spectral_convs = nn.ModuleList([
+            SpectralConv1d(in_channels=width_fno, out_channels=width_fno, modes=modes) for _ in range(n_fno_blocks)
+        ])
+        self.w_convs = nn.ModuleList([
+            nn.Conv1d(in_channels=width_fno, out_channels=width_fno, kernel_size=1) for _ in range(n_fno_blocks)
+        ])
         self.fc_out = nn.Linear(width_fno, dim_embedding)
 
         if fusion_self_attention:
-            self.fusion = FusionSelfAttention(width_fno, 2*width_fno)
+            self.fusion = FusionSelfAttention(embed_dim=width_fno, hidden_dim=2*width_fno, activation=fsa_activation)
 
     def forward(self, x):
         """"""
```
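`FnoEncoder` now takes `in_channels` explicitly (the inline `#(r(q), q)` comment indicates a two-channel input of reflectivity values and their q positions) and builds a pointwise `nn.Conv1d` alongside each `SpectralConv1d`. The forward pass is not part of this hunk; in the standard FNO formulation (Li et al., https://arxiv.org/abs/2010.08895) each pair would be combined as `x = activation(spectral_conv(x) + w_conv(x))`, but that is an assumption here. A hedged construction sketch:

```python
from reflectorch.models.encoders.fno import FnoEncoder

# Sketch only: exercises the arguments added in 1.3.0 (in_channels, fsa_activation).
encoder = FnoEncoder(
    in_channels=2,                # channel 0 = r(q), channel 1 = q, per the inline comment
    dim_embedding=128,
    modes=32,
    width_fno=64,
    n_fno_blocks=6,
    activation='gelu',
    fusion_self_attention=True,   # merge tokens with FusionSelfAttention instead of a mean
    fsa_activation='tanh',        # new: activation inside the fusion block
)
```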
reflectorch/models/encoders/fno.py (continued):

```diff
@@ -109,19 +115,20 @@ class FnoEncoder(nn.Module):
 
         return x
 
+
 class FusionSelfAttention(nn.Module):
-    def __init__(self,
-                 embed_dim: int = 64,
-                 hidden_dim: int = 64,
-                 activation=nn.Tanh,
-                 ):
+    def __init__(self, embed_dim: int = 64, hidden_dim: int = 64, activation: str = 'gelu'):
         super().__init__()
+        activation = activation_by_name(activation)()
         self.fuser = nn.Sequential(nn.Linear(embed_dim, hidden_dim),
-                                   activation…
+                                   activation,
                                    nn.Linear(hidden_dim, 1, bias=False))
 
-    def forward(self, …
+    def forward(self,
+                c: torch.Tensor, # (batch_size x seq_len x embed_dim)
+                mask: torch.Tensor = None, # (batch_size x seq_len)
+                ):
         a = self.fuser(c)
-        alpha = torch.exp(a)
+        alpha = torch.exp(a)*mask.unsqueeze(-1) if mask is not None else torch.exp(a)
         alpha = alpha/alpha.sum(dim=1, keepdim=True)
         return (alpha*c).sum(dim=1) # (batch_size x embed_dim)
```
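The new optional `mask` argument turns the pooling weights into a masked softmax: `exp(a)` is zeroed at padded positions before normalization, so the fused embedding is a convex combination of unmasked tokens only. A small numerical check of the formula used above:

```python
import torch

a = torch.tensor([[[1.0], [2.0], [3.0]]])   # (batch=1, seq_len=3, 1) attention scores
mask = torch.tensor([[1.0, 1.0, 0.0]])      # third position is padding

alpha = torch.exp(a) * mask.unsqueeze(-1)   # zero weight at masked positions
alpha = alpha / alpha.sum(dim=1, keepdim=True)

# Identical to a softmax restricted to the unmasked positions:
print(alpha[0, :2, 0])                      # tensor([0.2689, 0.7311])
print(torch.softmax(a[0, :2, 0], dim=0))    # tensor([0.2689, 0.7311])
print(alpha[0, 2, 0])                       # tensor(0.)
```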
reflectorch/models/networks/__init__.py:

```diff
@@ -1,4 +1,5 @@
 from reflectorch.models.networks.mlp_networks import (
+    NetworkWithPriors,
     NetworkWithPriorsConvEmb,
     NetworkWithPriorsFnoEmb,
 )
@@ -7,6 +8,7 @@ from reflectorch.models.networks.residual_net import ResidualMLP
 
 __all__ = [
     "ResidualMLP",
+    "NetworkWithPriors",
     "NetworkWithPriorsConvEmb",
     "NetworkWithPriorsFnoEmb",
 ]
```
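With these exports, the new `NetworkWithPriors` class is importable from the package-level namespace alongside the existing networks (its definition lives in `mlp_networks.py` and is not shown in this diff):

```python
from reflectorch.models.networks import (
    NetworkWithPriors,        # new in 1.3.0
    NetworkWithPriorsConvEmb,
    NetworkWithPriorsFnoEmb,
    ResidualMLP,
)
```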