reflectorch 1.0.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of reflectorch might be problematic.

Files changed (83)
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0
reflectorch/models/encoders/conv_encoder.py
@@ -0,0 +1,211 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
+ from pathlib import Path
+
+ import torch
+ from torch import nn, load
+
+ from reflectorch.models.activations import activation_by_name
+ from reflectorch.paths import SAVED_MODELS_DIR
+
+ __all__ = [
+     "ConvEncoder",
+     "ConvDecoder",
+     "ConvAutoencoder",
+     "ConvVAE",
+ ]
+
+ logger = logging.getLogger(__name__)
+
+
+ class ConvEncoder(nn.Module):
+     """A 1D CNN encoder / embedding network
+
+     Args:
+         in_channels (int, optional): the number of input channels. Defaults to 1.
+         hidden_channels (tuple, optional): the number of intermediate channels of each convolutional layer. Defaults to (32, 64, 128, 256, 512).
+         dim_latent (int, optional): the dimension of the output latent embedding. Defaults to 64.
+         dim_avpool (int, optional): the output size of the adaptive average pooling layer. Defaults to 1.
+         use_batch_norm (bool, optional): whether to use batch normalization. Defaults to True.
+         activation (str, optional): the type of activation function. Defaults to 'relu'.
+     """
+     def __init__(self,
+                  in_channels: int = 1,
+                  hidden_channels: tuple = (32, 64, 128, 256, 512),
+                  dim_latent: int = 64,
+                  dim_avpool: int = 1,
+                  use_batch_norm: bool = True,
+                  activation: str = 'relu',
+                  ):
+         super().__init__()
+
+         modules = []
+
+         activation = activation_by_name(activation)
+
+         for h in hidden_channels:
+             layers = [
+                 nn.Conv1d(in_channels, out_channels=h, kernel_size=3, stride=2, padding=1),
+                 activation(),
+             ]
+
+             if use_batch_norm:
+                 layers.insert(1, nn.BatchNorm1d(h))
+
+             modules.append(nn.Sequential(*layers))
+             in_channels = h
+
+         self.core = nn.Sequential(*modules)
+         self.avpool = nn.AdaptiveAvgPool1d(dim_avpool)
+         self.fc = nn.Linear(hidden_channels[-1] * dim_avpool, dim_latent)
+
+     def forward(self, x):
+         """"""
+         if len(x.shape) < 3:  # accept (batch, length) input by adding a channel dimension
+             x = x.unsqueeze(1)
+         x = self.core(x)
+         x = self.avpool(x).view(x.size(0), -1)
+         x = self.fc(x)
+         return x
+
+     def load_weights(self, path: "str | Path" = None, strict: bool = False):
+         if not path:
+             return
+
+         if isinstance(path, str):
+             if not path.endswith('.pt'):
+                 path = path + '.pt'
+             path = SAVED_MODELS_DIR / path
+
+         if not path.is_file():
+             logger.error(f'File {str(path)} not found.')
+             return
+         try:
+             state_dict = load(path)
+             self.load_state_dict(state_dict, strict=strict)
+         except Exception as err:
+             logger.exception(err)
+
+
+ class ConvDecoder(nn.Module):
+     """A 1D CNN decoder
+
+     Args:
+         hidden_channels (tuple, optional): the number of intermediate channels of each convolutional layer. Defaults to (512, 256, 128, 64, 32).
+         dim_latent (int, optional): the dimension of the input latent embedding. Defaults to 64.
+         in_size (int, optional): the initial size for upscaling. Defaults to 8.
+         use_batch_norm (bool, optional): whether to use batch normalization. Defaults to True.
+         activation (str, optional): the type of activation function. Defaults to 'relu'.
+     """
+     def __init__(self,
+                  hidden_channels: tuple = (512, 256, 128, 64, 32),
+                  dim_latent: int = 64,
+                  in_size: int = 8,
+                  use_batch_norm: bool = True,
+                  activation: str = 'relu',
+                  ):
+         super().__init__()
+
+         self.in_size = in_size
+         modules = []
+
+         self.decoder_input = nn.Linear(dim_latent, hidden_channels[0] * in_size)
+
+         activation = activation_by_name(activation)
+
+         for i in range(len(hidden_channels) - 1):
+             modules.append(
+                 nn.Sequential(
+                     nn.ConvTranspose1d(
+                         hidden_channels[i],
+                         hidden_channels[i + 1],
+                         kernel_size=3,
+                         stride=2,
+                         padding=1,
+                         output_padding=1,
+                     ),
+                     nn.BatchNorm1d(hidden_channels[i + 1]) if use_batch_norm else nn.Identity(),
+                     activation(),
+                 )
+             )
+
+         self.decoder = nn.Sequential(*modules)
+
+         self.final_layer = nn.Sequential(
+             nn.ConvTranspose1d(hidden_channels[-1],
+                                hidden_channels[-1],
+                                kernel_size=3,
+                                stride=2,
+                                padding=1,
+                                output_padding=1),
+             nn.BatchNorm1d(hidden_channels[-1]) if use_batch_norm else nn.Identity(),
+             activation(),
+             nn.Conv1d(hidden_channels[-1], out_channels=1,
+                       kernel_size=3, padding=1)
+         )
+
+     def forward(self, x):
+         batch_size = x.shape[0]
+         x = self.decoder_input(x).view(batch_size, -1, self.in_size)
+         x = self.decoder(x)
+         x = self.final_layer(x).flatten(1)
+         return x
+
+
+ class ConvAutoencoder(nn.Module):
+     """A 1D convolutional denoising autoencoder"""
+     def __init__(self,
+                  in_channels: int = 1,
+                  encoder_hidden_channels: tuple = (32, 64, 128, 256, 512),
+                  decoder_hidden_channels: tuple = (512, 256, 128, 64, 32),
+                  dim_latent: int = 64,
+                  dim_avpool: int = 1,
+                  use_batch_norm: bool = True,
+                  activation: str = 'relu',
+                  decoder_in_size: int = 8,
+                  **kwargs
+                  ):
+         super().__init__()
+         self.encoder = ConvEncoder(in_channels, encoder_hidden_channels, dim_latent, dim_avpool, use_batch_norm, activation, **kwargs)
+         self.decoder = ConvDecoder(decoder_hidden_channels, dim_latent, decoder_in_size, use_batch_norm, activation, **kwargs)
+
+     def forward(self, x):
+         return self.decoder(self.encoder(x))
+
+
+ class ConvVAE(nn.Module):
+     """A 1D convolutional variational autoencoder"""
+     def __init__(self,
+                  in_channels: int = 1,
+                  encoder_hidden_channels: tuple = (32, 64, 128, 256, 512),
+                  decoder_hidden_channels: tuple = (512, 256, 128, 64, 32),
+                  dim_latent: int = 64,
+                  dim_avpool: int = 1,
+                  use_batch_norm: bool = True,
+                  activation: str = 'relu',
+                  decoder_in_size: int = 8,
+                  **kwargs
+                  ):
+         super().__init__()
+         # the encoder outputs 2 * dim_latent values: the latent mean and log-variance
+         self.encoder = ConvEncoder(in_channels, encoder_hidden_channels, 2 * dim_latent, dim_avpool, use_batch_norm, activation, **kwargs)
+         self.decoder = ConvDecoder(decoder_hidden_channels, dim_latent, decoder_in_size, use_batch_norm, activation, **kwargs)
+
+     def forward(self, x):
+         z_mu, z_logvar = self.encoder(x).chunk(2, dim=-1)
+         z = self.reparameterize(z_mu, z_logvar)
+
+         # the decoder output is likewise split into reconstruction mean and log-variance
+         x_r_mu, x_r_logvar = self.decoder(z).chunk(2, dim=-1)
+         x = self.reparameterize(x_r_mu, x_r_logvar)
+
+         return x, (z_mu, z_logvar, x_r_mu, x_r_logvar)
+
+     @staticmethod
+     def reparameterize(mu, logvar):
+         # reparameterization trick: sample mu + sigma * eps with eps ~ N(0, 1)
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(std)
+         return mu + eps * std
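For orientation, a minimal smoke test of these classes (an illustrative sketch using the default constructor arguments, not one of the packaged files): the encoder halves the sequence length at each of its five stride-2 stages before pooling, and the decoder upsamples its 8-point seed by the same factor of 2^5, so a 256-point curve round-trips through the autoencoder at its original length.

    import torch
    from reflectorch.models.encoders.conv_encoder import ConvEncoder, ConvAutoencoder

    curves = torch.rand(16, 256)      # batch of 16 curves, 256 points each
    encoder = ConvEncoder()           # default latent dimension: 64
    print(encoder(curves).shape)      # torch.Size([16, 64])

    autoencoder = ConvAutoencoder()   # decoder upsamples 8 -> 256
    print(autoencoder(curves).shape)  # torch.Size([16, 256])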
reflectorch/models/encoders/conv_res_net.py
@@ -0,0 +1,119 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.nn import init
+
+ __all__ = [
+     'ConvResidualNet1D',
+ ]
+
+
+ class ConvResidualBlock1D(nn.Module):
+     def __init__(
+         self,
+         channels,
+         activation=F.gelu,
+         dropout_probability=0.0,
+         use_batch_norm=False,
+         zero_initialization=True,
+         kernel_size: int = 3,
+         dilation: int = 1,
+         padding: int = 1,
+     ):
+         super().__init__()
+         self.activation = activation
+
+         self.use_batch_norm = use_batch_norm
+
+         if use_batch_norm:
+             self.batch_norm_layers = nn.ModuleList(
+                 [nn.BatchNorm1d(channels, eps=1e-3) for _ in range(2)]
+             )
+         self.conv_layers = nn.ModuleList(
+             [nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding, dilation=dilation)
+              for _ in range(2)]
+         )
+         self.dropout = nn.Dropout(p=dropout_probability)
+
+         if zero_initialization:
+             init.uniform_(self.conv_layers[-1].weight, -1e-3, 1e-3)
+             init.uniform_(self.conv_layers[-1].bias, -1e-3, 1e-3)
+
+     def forward(self, inputs):
+         temps = inputs
+         if self.use_batch_norm:
+             temps = self.batch_norm_layers[0](temps)
+         temps = self.activation(temps)
+         temps = self.conv_layers[0](temps)
+         if self.use_batch_norm:
+             temps = self.batch_norm_layers[1](temps)
+
+         temps = self.activation(temps)
+         temps = self.dropout(temps)
+         temps = self.conv_layers[1](temps)
+
+         return inputs + temps
+
+
+ class ConvResidualNet1D(nn.Module):
+     def __init__(
+         self,
+         in_channels: int = 1,
+         out_channels: int = 64,
+         hidden_channels: int = 128,
+         num_blocks=5,
+         activation=F.gelu,
+         dropout_probability=0.0,
+         use_batch_norm=True,
+         kernel_size: int = 3,
+         dilation: int = 1,
+         padding: int = 1,
+         avpool: int = 8,
+     ):
+         super().__init__()
+
+         self.hidden_channels = hidden_channels
+
+         self.initial_layer = nn.Conv1d(
+             in_channels=in_channels,
+             out_channels=hidden_channels,
+             kernel_size=1,
+             padding=0,
+         )
+         self.blocks = nn.ModuleList(
+             [
+                 ConvResidualBlock1D(
+                     channels=hidden_channels,
+                     activation=activation,
+                     dropout_probability=dropout_probability,
+                     use_batch_norm=use_batch_norm,
+                     kernel_size=kernel_size,
+                     dilation=dilation,
+                     padding=padding,
+                 )
+                 for _ in range(num_blocks)
+             ]
+         )
+
+         self.avpool = nn.AdaptiveAvgPool1d(avpool)
+
+         self.final_layer = nn.Linear(
+             hidden_channels * avpool, out_channels
+         )
+
+     def forward(self, x):
+         temps = self.initial_layer(x.unsqueeze(1))
+
+         for block in self.blocks:
+             temps = block(temps)
+
+         temps = self.avpool(temps).view(temps.size(0), -1)
+         outputs = self.final_layer(temps)
+
+         return outputs
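A similar shape check for the residual encoder (again illustrative, not part of the diff): forward() adds the channel dimension itself, and with the default kernel_size=3, dilation=1, padding=1 every residual block preserves the sequence length, so the residual addition is valid for any input length.

    import torch
    from reflectorch.models.encoders.conv_res_net import ConvResidualNet1D

    net = ConvResidualNet1D()     # defaults: 5 blocks, 128 hidden channels, out_channels=64
    x = torch.rand(8, 200)        # (batch, length); the channel dim is added internally
    print(net(x).shape)           # torch.Size([8, 64])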
reflectorch/models/encoders/fno.py
@@ -0,0 +1,127 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from reflectorch.models.activations import activation_by_name
+
+
+ class SpectralConv1d(nn.Module):
+     """1D Fourier layer: FFT, linear transform of the lowest modes, inverse FFT."""
+     def __init__(self, in_channels, out_channels, modes):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.modes = modes  # number of Fourier modes to multiply, at most floor(N/2) + 1
+
+         self.scale = 1 / (in_channels * out_channels)
+         self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, modes, dtype=torch.cfloat))
+
+     def compl_mul1d(self, input, weights):
+         # complex multiplication: (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
+         return torch.einsum("bix,iox->box", input, weights)
+
+     def forward(self, x):
+         batchsize = x.shape[0]
+         # compute the Fourier coefficients of the input
+         x_ft = torch.fft.rfft(x)
+
+         # multiply the relevant (lowest-frequency) Fourier modes
+         out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1) // 2 + 1, device=x.device, dtype=torch.cfloat)
+         out_ft[:, :, :self.modes] = self.compl_mul1d(x_ft[:, :, :self.modes], self.weights1)
+
+         # return to physical space
+         x = torch.fft.irfft(out_ft, n=x.size(-1))
+         return x
+
+
+ class FnoEncoder(nn.Module):
+     """An embedding network based on the Fourier Neural Operator (FNO) architecture
+
+     .. image:: ../documentation/fig_reflectometry_embedding_networks.png
+         :width: 400px
+         :align: center
+
+     Args:
+         ch_in (int): number of input channels
+         dim_embedding (int): dimension of the output embedding
+         modes (int): number of Fourier modes
+         width_fno (int): number of channels of the intermediate representations
+         n_fno_blocks (int): number of FNO blocks
+         activation (str): the activation function
+         fusion_self_attention (bool): if ``True``, a fusion layer is used after the FNO blocks to produce the final embedding
+     """
+     def __init__(
+         self,
+         ch_in: int = 2,
+         dim_embedding: int = 128,
+         modes: int = 32,
+         width_fno: int = 64,
+         n_fno_blocks: int = 6,
+         activation: str = 'gelu',
+         fusion_self_attention: bool = False,
+     ):
+         super().__init__()
+
+         self.ch_in = ch_in
+         self.dim_embedding = dim_embedding
+
+         self.modes = modes
+         self.width_fno = width_fno
+         self.n_fno_blocks = n_fno_blocks
+         self.activation = activation_by_name(activation)()
+         self.fusion_self_attention = fusion_self_attention
+
+         self.fc0 = nn.Linear(ch_in, width_fno)  # lifts the input channels (r(q), q) to width_fno
+         self.spectral_convs = nn.ModuleList([SpectralConv1d(in_channels=width_fno, out_channels=width_fno, modes=modes) for _ in range(n_fno_blocks)])
+         self.w_convs = nn.ModuleList([nn.Conv1d(in_channels=width_fno, out_channels=width_fno, kernel_size=1) for _ in range(n_fno_blocks)])
+         self.fc_out = nn.Linear(width_fno, dim_embedding)
+
+         if fusion_self_attention:
+             self.fusion = FusionSelfAttention(width_fno, 2 * width_fno)
+
+     def forward(self, x):
+         """"""
+         x = x.permute(0, 2, 1)  # (B, D, S) -> (B, S, D)
+         x = self.fc0(x)
+         x = x.permute(0, 2, 1)  # (B, S, D) -> (B, D, S)
+
+         for i in range(self.n_fno_blocks):
+             # each FNO block: spectral convolution plus a pointwise (kernel-size-1) convolution
+             x1 = self.spectral_convs[i](x)
+             x2 = self.w_convs[i](x)
+
+             x = x1 + x2
+             x = self.activation(x)
+
+         if self.fusion_self_attention:
+             x = x.permute(0, 2, 1)
+             x = self.fusion(x)
+         else:
+             x = x.mean(dim=-1)
+
+         x = self.fc_out(x)
+
+         return x
+
+
+ class FusionSelfAttention(nn.Module):
+     def __init__(self,
+                  embed_dim: int = 64,
+                  hidden_dim: int = 64,
+                  activation=nn.Tanh,
+                  ):
+         super().__init__()
+         self.fuser = nn.Sequential(nn.Linear(embed_dim, hidden_dim),
+                                    activation(),
+                                    nn.Linear(hidden_dim, 1, bias=False))
+
+     def forward(self, c):  # (batch_size, seq_len, embed_dim)
+         # softmax attention weights over the sequence dimension
+         a = self.fuser(c)
+         alpha = torch.exp(a)
+         alpha = alpha / alpha.sum(dim=1, keepdim=True)
+         return (alpha * c).sum(dim=1)  # (batch_size, embed_dim)
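Since SpectralConv1d keeps only the `modes` lowest Fourier coefficients, `modes` must not exceed floor(S/2) + 1 for inputs of length S. A hypothetical call with the defaults (not from the package), where the two input channels hold the reflectivity curve and its q grid:

    import torch
    from reflectorch.models.encoders.fno import FnoEncoder

    fno = FnoEncoder()                 # defaults: ch_in=2, dim_embedding=128, modes=32
    r_and_q = torch.rand(4, 2, 256)    # (batch, channels, seq); 32 <= 256 // 2 + 1
    print(fno(r_and_q).shape)          # torch.Size([4, 128])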
reflectorch/models/encoders/transformers.py
@@ -0,0 +1,56 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ import torch
+ from torch import nn
+
+
+ class TransformerEncoder(nn.Module):
+     def __init__(
+         self,
+         dim: int = 64,
+         nhead: int = 8,
+         num_encoder_layers: int = 4,
+         num_decoder_layers: int = 2,
+         dim_feedforward: int = 512,
+         dropout: float = 0.01,
+         activation: str = 'gelu',
+         in_dim: int = 2,
+         out_dim: int = None,
+     ):
+         super().__init__()
+
+         self.in_projector = nn.Linear(in_dim, dim)
+
+         self.dim = dim
+
+         self.transformer = nn.Transformer(
+             dim, nhead=nhead,
+             num_encoder_layers=num_encoder_layers,
+             num_decoder_layers=num_decoder_layers,
+             dim_feedforward=dim_feedforward,
+             dropout=dropout,
+             activation=activation,
+         )
+
+         if out_dim:
+             self.out_projector = nn.Linear(dim, out_dim)
+         else:
+             self.out_projector = None
+
+     def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None, **kwargs):
+         # (batch, in_dim, seq) -> (batch, seq, dim) -> (seq, batch, dim); nn.Transformer here is sequence-first
+         src = self.in_projector(src.transpose(1, 2)).transpose(0, 1)
+
+         res = self.transformer(
+             src, tgt, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, **kwargs
+         )
+
+         if self.out_projector:
+             res = self.out_projector(res).squeeze(-1)
+
+         return res.squeeze(0)
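Because batch_first is not set, this nn.Transformer expects sequence-first tensors: src enters as (batch, in_dim, seq) and is projected to (seq, batch, dim), while tgt must already be (tgt_len, batch, dim). A hypothetical single-query call under those assumptions (the calling convention is my illustration, not taken from the package):

    import torch
    from reflectorch.models.encoders.transformers import TransformerEncoder

    model = TransformerEncoder(dim=64, in_dim=2, out_dim=1)
    src = torch.rand(8, 2, 128)   # (batch, in_dim, seq)
    tgt = torch.rand(1, 8, 64)    # (tgt_len=1, batch, dim), sequence-first
    print(model(src, tgt).shape)  # torch.Size([8]) after squeezing out_dim and tgt_len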
reflectorch/models/networks/__init__.py
@@ -0,0 +1,18 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from reflectorch.models.networks.mlp_networks import (
+     NetworkWithPriorsConvEmb,
+     NetworkWithPriorsFnoEmb,
+ )
+ from reflectorch.models.networks.residual_net import ResidualMLP
+
+
+ __all__ = [
+     "ResidualMLP",
+     "NetworkWithPriorsConvEmb",
+     "NetworkWithPriorsFnoEmb",
+ ]