reflectorch 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic; see the package's registry page for more details.

Files changed (39):
  1. reflectorch/data_generation/__init__.py +2 -0
  2. reflectorch/data_generation/dataset.py +27 -7
  3. reflectorch/data_generation/noise.py +115 -9
  4. reflectorch/data_generation/priors/parametric_models.py +90 -15
  5. reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
  6. reflectorch/data_generation/priors/sampler_strategies.py +67 -3
  7. reflectorch/data_generation/q_generator.py +31 -11
  8. reflectorch/data_generation/reflectivity/__init__.py +56 -14
  9. reflectorch/data_generation/reflectivity/abeles.py +31 -16
  10. reflectorch/data_generation/reflectivity/kinematical.py +5 -6
  11. reflectorch/data_generation/reflectivity/memory_eff.py +1 -1
  12. reflectorch/data_generation/reflectivity/smearing.py +25 -10
  13. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  14. reflectorch/data_generation/smearing.py +42 -11
  15. reflectorch/data_generation/utils.py +92 -18
  16. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  17. reflectorch/inference/inference_model.py +220 -105
  18. reflectorch/inference/plotting.py +98 -0
  19. reflectorch/inference/scipy_fitter.py +84 -7
  20. reflectorch/ml/__init__.py +2 -0
  21. reflectorch/ml/basic_trainer.py +18 -6
  22. reflectorch/ml/callbacks.py +5 -4
  23. reflectorch/ml/loggers.py +25 -0
  24. reflectorch/ml/schedulers.py +116 -0
  25. reflectorch/ml/trainers.py +122 -23
  26. reflectorch/models/__init__.py +1 -1
  27. reflectorch/models/encoders/__init__.py +0 -2
  28. reflectorch/models/encoders/conv_encoder.py +54 -40
  29. reflectorch/models/encoders/fno.py +23 -16
  30. reflectorch/models/networks/__init__.py +2 -0
  31. reflectorch/models/networks/mlp_networks.py +324 -152
  32. reflectorch/models/networks/residual_net.py +31 -5
  33. reflectorch/runs/train.py +0 -1
  34. reflectorch/runs/utils.py +43 -9
  35. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/METADATA +19 -17
  36. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/RECORD +39 -36
  37. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/WHEEL +1 -1
  38. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info/licenses}/LICENSE.txt +0 -0
  39. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,6 @@ __all__ = [
11
11
  "ConvEncoder",
12
12
  "ConvDecoder",
13
13
  "ConvAutoencoder",
14
- "ConvVAE",
15
14
  ]
16
15
 
17
16
  logger = logging.getLogger(__name__)
@@ -23,7 +22,7 @@ class ConvEncoder(nn.Module):
23
22
  Args:
24
23
  in_channels (int, optional): the number of input channels. Defaults to 1.
25
24
  hidden_channels (tuple, optional): the number of intermediate channels of each convolutional layer. Defaults to (32, 64, 128, 256, 512).
26
- dim_latent (int, optional): the dimension of the output latent embedding. Defaults to 64.
25
+ dim_embedding (int, optional): the dimension of the output latent embedding. Defaults to 64.
27
26
  dim_avpool (int, optional): the output size of the adaptive average pooling layer. Defaults to 1.
28
27
  use_batch_norm (bool, optional): whether to use batch normalization. Defaults to True.
29
28
  activation (str, optional): the type of activation function. Defaults to 'relu'.
@@ -31,9 +30,11 @@ class ConvEncoder(nn.Module):
31
30
  def __init__(self,
32
31
  in_channels: int = 1,
33
32
  hidden_channels: tuple = (32, 64, 128, 256, 512),
34
- dim_latent: int = 64,
33
+ kernel_size: int = 3,
34
+ dim_embedding: int = 64,
35
35
  dim_avpool: int = 1,
36
36
  use_batch_norm: bool = True,
37
+ use_se: bool = False,
37
38
  activation: str = 'relu',
38
39
  ):
39
40
  super().__init__()
@@ -44,22 +45,24 @@ class ConvEncoder(nn.Module):
44
45
 
45
46
  for h in hidden_channels:
46
47
  layers = [
47
- nn.Conv1d(in_channels, out_channels=h, kernel_size=3, stride=2, padding=1),
48
+ nn.Conv1d(in_channels, out_channels=h, kernel_size=kernel_size, stride=2, padding=kernel_size // 2),
48
49
  activation(),
49
50
  ]
50
51
 
51
52
  if use_batch_norm:
52
53
  layers.insert(1, nn.BatchNorm1d(h))
53
54
 
55
+ if use_se:
56
+ layers.insert(2, SEBlock(h))
57
+
54
58
  modules.append(nn.Sequential(*layers))
55
59
  in_channels = h
56
60
 
57
61
  self.core = nn.Sequential(*modules)
58
62
  self.avpool = nn.AdaptiveAvgPool1d(dim_avpool)
59
- self.fc = nn.Linear(hidden_channels[-1] * dim_avpool, dim_latent)
63
+ self.fc = nn.Linear(hidden_channels[-1] * dim_avpool, dim_embedding)
60
64
 
61
65
  def forward(self, x):
62
- """"""
63
66
  if len(x.shape) < 3:
64
67
  x = x.unsqueeze(1)
65
68
  x = self.core(x)
@@ -100,6 +103,7 @@ class ConvDecoder(nn.Module):
100
103
  hidden_channels: tuple = (512, 256, 128, 64, 32),
101
104
  dim_latent: int = 64,
102
105
  in_size: int = 8,
106
+ kernel_size: int = 3,
103
107
  use_batch_norm: bool = True,
104
108
  activation: str = 'relu',
105
109
  ):
@@ -119,9 +123,9 @@ class ConvDecoder(nn.Module):
119
123
  nn.ConvTranspose1d(
120
124
  hidden_channels[i],
121
125
  hidden_channels[i + 1],
122
- kernel_size=3,
126
+ kernel_size=kernel_size, #3
123
127
  stride=2,
124
- padding=1,
128
+ padding=kernel_size // 2, #1
125
129
  output_padding=1,
126
130
  ),
127
131
  nn.BatchNorm1d(hidden_channels[i + 1]) if use_batch_norm else nn.Identity(),
@@ -134,9 +138,9 @@ class ConvDecoder(nn.Module):
134
138
  self.final_layer = nn.Sequential(
135
139
  nn.ConvTranspose1d(hidden_channels[-1],
136
140
  hidden_channels[-1],
137
- kernel_size=3,
141
+ kernel_size=kernel_size, #3
138
142
  stride=2,
139
- padding=1,
143
+ padding=kernel_size // 2, #1
140
144
  output_padding=1),
141
145
  nn.BatchNorm1d(hidden_channels[-1]) if use_batch_norm else nn.Identity(),
142
146
  activation(),
@@ -160,46 +164,56 @@ class ConvAutoencoder(nn.Module):
160
164
  decoder_hidden_channels: tuple = (512, 256, 128, 64, 32),
161
165
  dim_latent: int = 64,
162
166
  dim_avpool: int = 1,
167
+ kernel_size: int = 3,
163
168
  use_batch_norm: bool = True,
164
169
  activation: str = 'relu',
165
170
  decoder_in_size: int = 8,
166
171
  **kwargs
167
172
  ):
168
173
  super().__init__()
169
- self.encoder = ConvEncoder(in_channels, encoder_hidden_channels, dim_latent, dim_avpool, use_batch_norm, activation, **kwargs)
170
- self.decoder = ConvDecoder(decoder_hidden_channels, dim_latent, decoder_in_size, use_batch_norm, activation, **kwargs)
174
+ self.encoder = ConvEncoder(
175
+ in_channels=in_channels,
176
+ hidden_channels=encoder_hidden_channels,
177
+ kernel_size=kernel_size,
178
+ dim_embedding=dim_latent,
179
+ dim_avpool=dim_avpool,
180
+ use_batch_norm=use_batch_norm,
181
+ activation=activation,
182
+ **kwargs)
183
+
184
+ self.decoder = ConvDecoder(
185
+ hidden_channels=decoder_hidden_channels,
186
+ dim_latent=dim_latent,
187
+ in_size=decoder_in_size,
188
+ kernel_size=kernel_size,
189
+ use_batch_norm=use_batch_norm,
190
+ activation=activation,
191
+ **kwargs)
171
192
 
172
193
  def forward(self, x):
173
194
  return self.decoder(self.encoder(x))
174
195
 
175
- class ConvVAE(nn.Module):
176
- """A 1D convolutional variational autoencoder"""
177
- def __init__(self,
178
- in_channels: int = 1,
179
- encoder_hidden_channels: tuple = (32, 64, 128, 256, 512),
180
- decoder_hidden_channels: tuple = (512, 256, 128, 64, 32),
181
- dim_latent: int = 64,
182
- dim_avpool: int = 1,
183
- use_batch_norm: bool = True,
184
- activation: str = 'relu',
185
- decoder_in_size: int = 8,
186
- **kwargs
187
- ):
196
+ class SEBlock(nn.Module):
197
+ """Squeeze-and-excitation block (https://arxiv.org/abs/1709.01507) """
198
+ def __init__(self, in_channels, reduction=16):
188
199
  super().__init__()
189
- self.encoder = ConvEncoder(in_channels, encoder_hidden_channels, 2*dim_latent, dim_avpool, use_batch_norm, activation, **kwargs)
190
- self.decoder = ConvDecoder(decoder_hidden_channels, dim_latent, decoder_in_size, use_batch_norm, activation, **kwargs)
200
+ self.fc1 = nn.Linear(in_channels, in_channels // reduction, bias=False)
201
+ self.fc2 = nn.Linear(in_channels // reduction, in_channels, bias=False)
202
+ self.relu = nn.ReLU()
203
+ self.sigmoid = nn.Sigmoid()
204
+ self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
191
205
 
192
206
  def forward(self, x):
193
- z_mu, z_logvar = self.encoder(x).chunk(2, dim=-1)
194
- z = self.reparameterize(z_mu, z_logvar)
195
-
196
- x_r_mu, x_r_logvar = self.decoder(z).chunk(2, dim=-1)
197
- x = self.reparameterize(x_r_mu, x_r_logvar)
198
-
199
- return x, (z_mu, z_logvar, x_r_mu, x_r_logvar)
200
-
201
- @staticmethod
202
- def reparameterize(mu, logvar):
203
- std = torch.exp(0.5 * logvar)
204
- eps = torch.randn_like(std).to(std)
205
- return mu + eps * std
207
+ batch_size, channels, _ = x.size()
208
+
209
+ #Squeeze
210
+ se = self.global_avg_pool(x).view(batch_size, channels)
211
+
212
+ #Excitation
213
+ se = self.fc1(se)
214
+ se = self.relu(se)
215
+ se = self.fc2(se)
216
+ se = self.sigmoid(se).view(batch_size, channels, 1)
217
+
218
+ #Scale the input feature maps (channel-wise attention)
219
+ return x * se
@@ -46,28 +46,30 @@ class FnoEncoder(nn.Module):
46
46
  :align: center
47
47
 
48
48
  Args:
49
- ch_in (int): number of input channels
49
+ in_channels (int): number of input channels
50
50
  dim_embedding (int): dimension of the output embedding
51
51
  modes (int): number of Fourier modes
52
52
  width_fno (int): number of channels of the intermediate representations
53
53
  n_fno_blocks (int): number of FNO blocks
54
54
  activation (str): the activation function
55
- fusion_self_attention (bool): if ``True`` a fusion layer is used after the FNO blocks to produce the final embedding
55
+ fusion_self_attention (bool): whether to use fusion self attention for merging the tokens (instead of mean)
56
+ fsa_activation (str): the activation function of the fusion self attention block
56
57
  """
57
58
  def __init__(
58
59
  self,
59
- ch_in: int = 2,
60
+ in_channels: int = 2,
60
61
  dim_embedding: int = 128,
61
62
  modes: int = 32,
62
63
  width_fno: int = 64,
63
64
  n_fno_blocks: int = 6,
64
65
  activation: str = 'gelu',
65
66
  fusion_self_attention: bool = False,
67
+ fsa_activation: str = 'tanh',
66
68
  ):
67
69
  super().__init__()
68
70
 
69
71
 
70
- self.ch_in = ch_in
72
+ self.in_channels = in_channels
71
73
  self.dim_embedding = dim_embedding
72
74
 
73
75
  self.modes = modes
@@ -77,13 +79,17 @@ class FnoEncoder(nn.Module):
77
79
  self.fusion_self_attention = fusion_self_attention
78
80
 
79
81
 
80
- self.fc0 = nn.Linear(ch_in, width_fno) #(r(q), q)
81
- self.spectral_convs = nn.ModuleList([SpectralConv1d(in_channels=width_fno, out_channels=width_fno, modes=modes) for _ in range(n_fno_blocks)])
82
- self.w_convs = nn.ModuleList([nn.Conv1d(in_channels=width_fno, out_channels=width_fno, kernel_size=1) for _ in range(n_fno_blocks)])
82
+ self.fc0 = nn.Linear(in_channels, width_fno) #(r(q), q)
83
+ self.spectral_convs = nn.ModuleList([
84
+ SpectralConv1d(in_channels=width_fno, out_channels=width_fno, modes=modes) for _ in range(n_fno_blocks)
85
+ ])
86
+ self.w_convs = nn.ModuleList([
87
+ nn.Conv1d(in_channels=width_fno, out_channels=width_fno, kernel_size=1) for _ in range(n_fno_blocks)
88
+ ])
83
89
  self.fc_out = nn.Linear(width_fno, dim_embedding)
84
90
 
85
91
  if fusion_self_attention:
86
- self.fusion = FusionSelfAttention(width_fno, 2*width_fno)
92
+ self.fusion = FusionSelfAttention(embed_dim=width_fno, hidden_dim=2*width_fno, activation=fsa_activation)
87
93
 
88
94
  def forward(self, x):
89
95
  """"""
@@ -109,19 +115,20 @@ class FnoEncoder(nn.Module):
109
115
 
110
116
  return x
111
117
 
118
+
112
119
  class FusionSelfAttention(nn.Module):
113
- def __init__(self,
114
- embed_dim: int = 64,
115
- hidden_dim: int = 64,
116
- activation=nn.Tanh,
117
- ):
120
+ def __init__(self, embed_dim: int = 64, hidden_dim: int = 64, activation: str = 'gelu'):
118
121
  super().__init__()
122
+ activation = activation_by_name(activation)()
119
123
  self.fuser = nn.Sequential(nn.Linear(embed_dim, hidden_dim),
120
- activation(),
124
+ activation,
121
125
  nn.Linear(hidden_dim, 1, bias=False))
122
126
 
123
- def forward(self, c): # (batch_size x seq_len x embed_dim)
127
+ def forward(self,
128
+ c: torch.Tensor, # (batch_size x seq_len x embed_dim)
129
+ mask: torch.Tensor = None, # (batch_size x seq_len)
130
+ ):
124
131
  a = self.fuser(c)
125
- alpha = torch.exp(a)
132
+ alpha = torch.exp(a)*mask.unsqueeze(-1) if mask is not None else torch.exp(a)
126
133
  alpha = alpha/alpha.sum(dim=1, keepdim=True)
127
134
  return (alpha*c).sum(dim=1) # (batch_size x embed_dim)
@@ -1,4 +1,5 @@
1
1
  from reflectorch.models.networks.mlp_networks import (
2
+ NetworkWithPriors,
2
3
  NetworkWithPriorsConvEmb,
3
4
  NetworkWithPriorsFnoEmb,
4
5
  )
@@ -7,6 +8,7 @@ from reflectorch.models.networks.residual_net import ResidualMLP
7
8
 
8
9
  __all__ = [
9
10
  "ResidualMLP",
11
+ "NetworkWithPriors",
10
12
  "NetworkWithPriorsConvEmb",
11
13
  "NetworkWithPriorsFnoEmb",
12
14
  ]