reflectorch 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (41)
  1. reflectorch/data_generation/__init__.py +4 -0
  2. reflectorch/data_generation/dataset.py +27 -7
  3. reflectorch/data_generation/noise.py +115 -9
  4. reflectorch/data_generation/priors/parametric_models.py +91 -16
  5. reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
  6. reflectorch/data_generation/priors/sampler_strategies.py +67 -3
  7. reflectorch/data_generation/q_generator.py +97 -43
  8. reflectorch/data_generation/reflectivity/__init__.py +53 -11
  9. reflectorch/data_generation/reflectivity/kinematical.py +4 -5
  10. reflectorch/data_generation/reflectivity/smearing.py +25 -10
  11. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  12. reflectorch/data_generation/smearing.py +42 -11
  13. reflectorch/data_generation/utils.py +93 -18
  14. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  15. reflectorch/inference/inference_model.py +795 -159
  16. reflectorch/inference/loading_data.py +37 -0
  17. reflectorch/inference/plotting.py +517 -0
  18. reflectorch/inference/preprocess_exp/interpolation.py +5 -2
  19. reflectorch/inference/scipy_fitter.py +98 -7
  20. reflectorch/ml/__init__.py +2 -0
  21. reflectorch/ml/basic_trainer.py +18 -6
  22. reflectorch/ml/callbacks.py +5 -4
  23. reflectorch/ml/loggers.py +25 -0
  24. reflectorch/ml/schedulers.py +116 -0
  25. reflectorch/ml/trainers.py +131 -23
  26. reflectorch/models/__init__.py +2 -1
  27. reflectorch/models/encoders/__init__.py +2 -2
  28. reflectorch/models/encoders/conv_encoder.py +54 -40
  29. reflectorch/models/encoders/fno.py +23 -16
  30. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  31. reflectorch/models/networks/__init__.py +2 -0
  32. reflectorch/models/networks/mlp_networks.py +331 -153
  33. reflectorch/models/networks/residual_net.py +31 -5
  34. reflectorch/runs/train.py +0 -1
  35. reflectorch/runs/utils.py +48 -11
  36. reflectorch/utils.py +30 -0
  37. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/METADATA +20 -17
  38. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/RECORD +41 -36
  39. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/WHEEL +1 -1
  40. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info/licenses}/LICENSE.txt +0 -0
  41. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/top_level.txt +0 -0

reflectorch/models/encoders/conv_encoder.py
@@ -11,7 +11,6 @@ __all__ = [
     "ConvEncoder",
     "ConvDecoder",
     "ConvAutoencoder",
-    "ConvVAE",
 ]
 
 logger = logging.getLogger(__name__)
@@ -23,7 +22,7 @@ class ConvEncoder(nn.Module):
     Args:
         in_channels (int, optional): the number of input channels. Defaults to 1.
         hidden_channels (tuple, optional): the number of intermediate channels of each convolutional layer. Defaults to (32, 64, 128, 256, 512).
-        dim_latent (int, optional): the dimension of the output latent embedding. Defaults to 64.
+        dim_embedding (int, optional): the dimension of the output latent embedding. Defaults to 64.
         dim_avpool (int, optional): the output size of the adaptive average pooling layer. Defaults to 1.
         use_batch_norm (bool, optional): whether to use batch normalization. Defaults to True.
         activation (str, optional): the type of activation function. Defaults to 'relu'.
@@ -31,9 +30,11 @@
     def __init__(self,
                  in_channels: int = 1,
                  hidden_channels: tuple = (32, 64, 128, 256, 512),
-                 dim_latent: int = 64,
+                 kernel_size: int = 3,
+                 dim_embedding: int = 64,
                  dim_avpool: int = 1,
                  use_batch_norm: bool = True,
+                 use_se: bool = False,
                  activation: str = 'relu',
                  ):
         super().__init__()
@@ -44,22 +45,24 @@
 
         for h in hidden_channels:
             layers = [
-                nn.Conv1d(in_channels, out_channels=h, kernel_size=3, stride=2, padding=1),
+                nn.Conv1d(in_channels, out_channels=h, kernel_size=kernel_size, stride=2, padding=kernel_size // 2),
                 activation(),
             ]
 
             if use_batch_norm:
                 layers.insert(1, nn.BatchNorm1d(h))
 
+            if use_se:
+                layers.insert(2, SEBlock(h))
+
             modules.append(nn.Sequential(*layers))
             in_channels = h
 
         self.core = nn.Sequential(*modules)
         self.avpool = nn.AdaptiveAvgPool1d(dim_avpool)
-        self.fc = nn.Linear(hidden_channels[-1] * dim_avpool, dim_latent)
+        self.fc = nn.Linear(hidden_channels[-1] * dim_avpool, dim_embedding)
 
     def forward(self, x):
-        """"""
         if len(x.shape) < 3:
             x = x.unsqueeze(1)
         x = self.core(x)
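
The new padding rule above generalises the previous hard-coded pair (kernel_size=3, padding=1): with padding = kernel_size // 2 and stride 2, the output length is floor((L + 2*(k // 2) - k) / 2) + 1, which for any odd k reduces to floor((L - 1) / 2) + 1, so each stage still halves the input. A quick sketch checking this (illustrative values only):

import torch
from torch import nn

x = torch.randn(4, 1, 128)
for k in (3, 5, 7):
    conv = nn.Conv1d(1, 8, kernel_size=k, stride=2, padding=k // 2)
    # every odd kernel size halves the length: torch.Size([4, 8, 64])
    print(k, conv(x).shape)
# an even kernel size would break the invariant (k=4 gives length 65)
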
@@ -100,6 +103,7 @@ class ConvDecoder(nn.Module):
                  hidden_channels: tuple = (512, 256, 128, 64, 32),
                  dim_latent: int = 64,
                  in_size: int = 8,
+                 kernel_size: int = 3,
                  use_batch_norm: bool = True,
                  activation: str = 'relu',
                  ):
@@ -119,9 +123,9 @@
                 nn.ConvTranspose1d(
                     hidden_channels[i],
                     hidden_channels[i + 1],
-                    kernel_size=3,
+                    kernel_size=kernel_size, #3
                     stride=2,
-                    padding=1,
+                    padding=kernel_size // 2, #1
                     output_padding=1,
                 ),
                 nn.BatchNorm1d(hidden_channels[i + 1]) if use_batch_norm else nn.Identity(),
@@ -134,9 +138,9 @@
         self.final_layer = nn.Sequential(
             nn.ConvTranspose1d(hidden_channels[-1],
                                hidden_channels[-1],
-                               kernel_size=3,
+                               kernel_size=kernel_size, #3
                                stride=2,
-                               padding=1,
+                               padding=kernel_size // 2, #1
                                output_padding=1),
             nn.BatchNorm1d(hidden_channels[-1]) if use_batch_norm else nn.Identity(),
             activation(),
@@ -160,46 +164,56 @@ class ConvAutoencoder(nn.Module):
                  decoder_hidden_channels: tuple = (512, 256, 128, 64, 32),
                  dim_latent: int = 64,
                  dim_avpool: int = 1,
+                 kernel_size: int = 3,
                  use_batch_norm: bool = True,
                  activation: str = 'relu',
                  decoder_in_size: int = 8,
                  **kwargs
                  ):
         super().__init__()
-        self.encoder = ConvEncoder(in_channels, encoder_hidden_channels, dim_latent, dim_avpool, use_batch_norm, activation, **kwargs)
-        self.decoder = ConvDecoder(decoder_hidden_channels, dim_latent, decoder_in_size, use_batch_norm, activation, **kwargs)
+        self.encoder = ConvEncoder(
+            in_channels=in_channels,
+            hidden_channels=encoder_hidden_channels,
+            kernel_size=kernel_size,
+            dim_embedding=dim_latent,
+            dim_avpool=dim_avpool,
+            use_batch_norm=use_batch_norm,
+            activation=activation,
+            **kwargs)
+
+        self.decoder = ConvDecoder(
+            hidden_channels=decoder_hidden_channels,
+            dim_latent=dim_latent,
+            in_size=decoder_in_size,
+            kernel_size=kernel_size,
+            use_batch_norm=use_batch_norm,
+            activation=activation,
+            **kwargs)
 
     def forward(self, x):
         return self.decoder(self.encoder(x))
 
-class ConvVAE(nn.Module):
-    """A 1D convolutional variational autoencoder"""
-    def __init__(self,
-                 in_channels: int = 1,
-                 encoder_hidden_channels: tuple = (32, 64, 128, 256, 512),
-                 decoder_hidden_channels: tuple = (512, 256, 128, 64, 32),
-                 dim_latent: int = 64,
-                 dim_avpool: int = 1,
-                 use_batch_norm: bool = True,
-                 activation: str = 'relu',
-                 decoder_in_size: int = 8,
-                 **kwargs
-                 ):
+class SEBlock(nn.Module):
+    """Squeeze-and-excitation block (https://arxiv.org/abs/1709.01507)"""
+    def __init__(self, in_channels, reduction=16):
         super().__init__()
-        self.encoder = ConvEncoder(in_channels, encoder_hidden_channels, 2*dim_latent, dim_avpool, use_batch_norm, activation, **kwargs)
-        self.decoder = ConvDecoder(decoder_hidden_channels, dim_latent, decoder_in_size, use_batch_norm, activation, **kwargs)
+        self.fc1 = nn.Linear(in_channels, in_channels // reduction, bias=False)
+        self.fc2 = nn.Linear(in_channels // reduction, in_channels, bias=False)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
 
     def forward(self, x):
-        z_mu, z_logvar = self.encoder(x).chunk(2, dim=-1)
-        z = self.reparameterize(z_mu, z_logvar)
-
-        x_r_mu, x_r_logvar = self.decoder(z).chunk(2, dim=-1)
-        x = self.reparameterize(x_r_mu, x_r_logvar)
-
-        return x, (z_mu, z_logvar, x_r_mu, x_r_logvar)
-
-    @staticmethod
-    def reparameterize(mu, logvar):
-        std = torch.exp(0.5 * logvar)
-        eps = torch.randn_like(std).to(std)
-        return mu + eps * std
+        batch_size, channels, _ = x.size()
+
+        # Squeeze
+        se = self.global_avg_pool(x).view(batch_size, channels)
+
+        # Excitation
+        se = self.fc1(se)
+        se = self.relu(se)
+        se = self.fc2(se)
+        se = self.sigmoid(se).view(batch_size, channels, 1)
+
+        # Scale the input feature maps (channel-wise attention)
+        return x * se
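
A minimal usage sketch of the reworked encoder, assuming only the signature shown in these hunks and the usual avpool + flatten + fc tail of forward() (only its first lines appear above); values are illustrative:

import torch
from reflectorch.models.encoders.conv_encoder import ConvEncoder

encoder = ConvEncoder(
    in_channels=1,
    hidden_channels=(32, 64, 128),
    kernel_size=5,       # new in 1.4.0; was hard-coded to 3
    dim_embedding=64,    # renamed from dim_latent
    use_se=True,         # adds an SEBlock to each conv stage
)
curves = torch.randn(16, 128)   # (batch, n_points); forward() unsqueezes a channel dim
embedding = encoder(curves)     # three stride-2 stages: 128 -> 64 -> 32 -> 16, then pool + fc
print(embedding.shape)          # torch.Size([16, 64])

With use_batch_norm=True, the two insert() calls give the per-stage order Conv1d -> BatchNorm1d -> SEBlock -> activation.
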

reflectorch/models/encoders/fno.py
@@ -46,28 +46,30 @@ class FnoEncoder(nn.Module):
         :align: center
 
     Args:
-        ch_in (int): number of input channels
+        in_channels (int): number of input channels
        dim_embedding (int): dimension of the output embedding
        modes (int): number of Fourier modes
        width_fno (int): number of channels of the intermediate representations
        n_fno_blocks (int): number of FNO blocks
        activation (str): the activation function
-        fusion_self_attention (bool): if ``True`` a fusion layer is used after the FNO blocks to produce the final embedding
+        fusion_self_attention (bool): whether to use fusion self attention for merging the tokens (instead of mean)
+        fsa_activation (str): the activation function of the fusion self attention block
     """
     def __init__(
             self,
-            ch_in: int = 2,
+            in_channels: int = 2,
             dim_embedding: int = 128,
             modes: int = 32,
             width_fno: int = 64,
             n_fno_blocks: int = 6,
             activation: str = 'gelu',
             fusion_self_attention: bool = False,
+            fsa_activation: str = 'tanh',
    ):
        super().__init__()
 
 
-        self.ch_in = ch_in
+        self.in_channels = in_channels
         self.dim_embedding = dim_embedding
 
         self.modes = modes
@@ -77,13 +79,17 @@ class FnoEncoder(nn.Module):
         self.fusion_self_attention = fusion_self_attention
 
 
-        self.fc0 = nn.Linear(ch_in, width_fno) #(r(q), q)
-        self.spectral_convs = nn.ModuleList([SpectralConv1d(in_channels=width_fno, out_channels=width_fno, modes=modes) for _ in range(n_fno_blocks)])
-        self.w_convs = nn.ModuleList([nn.Conv1d(in_channels=width_fno, out_channels=width_fno, kernel_size=1) for _ in range(n_fno_blocks)])
+        self.fc0 = nn.Linear(in_channels, width_fno) #(r(q), q)
+        self.spectral_convs = nn.ModuleList([
+            SpectralConv1d(in_channels=width_fno, out_channels=width_fno, modes=modes) for _ in range(n_fno_blocks)
+        ])
+        self.w_convs = nn.ModuleList([
+            nn.Conv1d(in_channels=width_fno, out_channels=width_fno, kernel_size=1) for _ in range(n_fno_blocks)
+        ])
         self.fc_out = nn.Linear(width_fno, dim_embedding)
 
         if fusion_self_attention:
-            self.fusion = FusionSelfAttention(width_fno, 2*width_fno)
+            self.fusion = FusionSelfAttention(embed_dim=width_fno, hidden_dim=2*width_fno, activation=fsa_activation)
 
     def forward(self, x):
         """"""
@@ -109,19 +115,20 @@
 
         return x
 
+
 class FusionSelfAttention(nn.Module):
-    def __init__(self,
-                 embed_dim: int = 64,
-                 hidden_dim: int = 64,
-                 activation=nn.Tanh,
-                 ):
+    def __init__(self, embed_dim: int = 64, hidden_dim: int = 64, activation: str = 'gelu'):
         super().__init__()
+        activation = activation_by_name(activation)()
         self.fuser = nn.Sequential(nn.Linear(embed_dim, hidden_dim),
-                                   activation(),
+                                   activation,
                                    nn.Linear(hidden_dim, 1, bias=False))
 
-    def forward(self, c): # (batch_size x seq_len x embed_dim)
+    def forward(self,
+                c: torch.Tensor, # (batch_size x seq_len x embed_dim)
+                mask: torch.Tensor = None, # (batch_size x seq_len)
+                ):
         a = self.fuser(c)
-        alpha = torch.exp(a)
+        alpha = torch.exp(a)*mask.unsqueeze(-1) if mask is not None else torch.exp(a)
         alpha = alpha/alpha.sum(dim=1, keepdim=True)
         return (alpha*c).sum(dim=1) # (batch_size x embed_dim)
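
A sketch of the new optional mask, assuming FusionSelfAttention is importable from this module: masked positions get zero attention weight, and the normalisation runs over the remaining tokens only.

import torch
from reflectorch.models.encoders.fno import FusionSelfAttention

fusion = FusionSelfAttention(embed_dim=64, hidden_dim=128, activation='gelu')
tokens = torch.randn(2, 10, 64)     # (batch_size, seq_len, embed_dim)
mask = torch.ones(2, 10)
mask[:, 7:] = 0.                    # treat the last three positions as padding
pooled = fusion(tokens, mask=mask)  # padded tokens contribute zero weight to the fused embedding
print(pooled.shape)                 # torch.Size([2, 64])
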

reflectorch/models/encoders/integral_kernel_embedding.py
@@ -0,0 +1,390 @@
+from __future__ import annotations
+from typing import Union
+
+import torch
+from torch import nn, Tensor, stack, cat
+from reflectorch.models.activations import activation_by_name
+import reflectorch
+
+### embedding network adapted from the PANPE repository
+
+__all__ = [
+    "IntegralConvEmbedding",
+]
+
+class IntegralConvEmbedding(nn.Module):
+    def __init__(
+        self,
+        z_num: Union[int, tuple[int, ...]],
+        z_range: tuple[float, float] = None,
+        in_dim: int = 2,
+        kernel_coef: int = 16,
+        dim_embedding: int = 256,
+        conv_dims: tuple[int, ...] = (32, 64, 128),
+        num_blocks: int = 4,
+        use_batch_norm: bool = False,
+        use_layer_norm: bool = True,
+        use_fft: bool = False,
+        activation: str = "gelu",
+        conv_activation: str = "lrelu",
+        resnet_activation: str = "relu",
+    ) -> None:
+        super().__init__()
+
+        if isinstance(z_num, int):
+            z_num = (z_num,)
+        num_kernel = len(z_num)
+
+        if z_range is not None:
+            zs = [(z_range[0], z_range[1], nz) for nz in z_num]
+        else:
+            zs = z_num
+
+        self.in_dim = in_dim
+
+        self.kernels = nn.ModuleList(
+            [
+                IntegralKernelBlock(
+                    z,
+                    in_dim,
+                    kernel_coef=kernel_coef,
+                    latent_dim=dim_embedding,
+                    conv_dims=conv_dims,
+                    use_fft=use_fft,
+                    activation=activation,
+                    conv_activation=conv_activation,
+                )
+                for z in zs
+            ]
+        )
+
+        self.fc = reflectorch.models.networks.residual_net.ResidualMLP(
+            dim_in=dim_embedding * num_kernel,
+            dim_out=dim_embedding,
+            layer_width=2 * dim_embedding,
+            num_blocks=num_blocks,
+            use_batch_norm=use_batch_norm,
+            use_layer_norm=use_layer_norm,
+            activation=resnet_activation,
+        )
+
+    def forward(self, q, y, drop_mask=None) -> Tensor:
+        x = cat([kernel(q, y, drop_mask=drop_mask) for kernel in self.kernels], dim=-1)
+        x = self.fc(x)
+
+        return x
+
+
+class IntegralKernelBlock(nn.Module):
+    """
+    Examples:
+        >>> x = torch.rand(2, 100)
+        >>> y = torch.rand(2, 100, 3)
+        >>> block = IntegralKernelBlock((0, 1, 10), in_dim=3, latent_dim=32)
+        >>> output = block(x, y)
+        >>> output.shape
+        torch.Size([2, 32])
+
+        >>> block = IntegralKernelBlock(10, in_dim=3, latent_dim=32)
+        >>> output = block(x, y)
+        >>> output.shape
+        torch.Size([2, 32])
+    """
+
+    def __init__(
+        self,
+        z: tuple[float, float, int] or int,
+        in_dim: int,
+        kernel_coef: int = 2,
+        latent_dim: int = 32,
+        conv_dims: tuple[int, ...] = (32, 64, 128),
+        use_fft: bool = False,
+        activation: str = "gelu",
+        conv_activation: str = "lrelu",
+    ):
+        super().__init__()
+
+        if isinstance(z, int):
+            z_num = z
+            kernel = FullIntegralKernel(z_num, in_dim=in_dim, kernel_coef=kernel_coef)
+        else:
+            kernel = FastIntegralKernel(
+                z, in_dim=in_dim, kernel_coef=kernel_coef, activation=activation
+            )
+            z_num = z[-1]
+
+        assert z_num % 2 == 0, "z_num should be even"
+
+        self.kernel = kernel
+        self.z_num = z_num
+        self.in_dim = in_dim
+        self.latent_dim = latent_dim
+        self.use_fft = use_fft
+
+        self.fc_in_dim = self.latent_dim + self.in_dim * self.z_num
+        if self.use_fft:
+            self.fc_in_dim += self.in_dim * 2 + self.in_dim * self.z_num
+
+        self.conv = reflectorch.models.encoders.conv_encoder.ConvEncoder(
+            dim_avpool=8,
+            hidden_channels=conv_dims,
+            in_channels=in_dim,
+            dim_embedding=latent_dim,
+            activation=conv_activation,
+        )
+        self.fc = FCBlock(
+            in_dim=self.fc_in_dim, hid_dim=self.latent_dim * 2, out_dim=self.latent_dim
+        )
+
+    def forward(self, x: Tensor, y: Tensor, drop_mask: Tensor = None) -> Tensor:
+        x = self.kernel(x, y, drop_mask=drop_mask)
+
+        assert x.shape == (x.shape[0], self.in_dim, self.z_num)
+
+        xc = self.conv(x)  # (batch, latent_dim)
+
+        assert xc.shape == (x.shape[0], self.latent_dim)
+
+        if self.use_fft:
+            fft_x = torch.fft.rfft(x, dim=-1, norm="ortho")  # (batch, in_dim, z_num)
+
+            fft_x = torch.cat(
+                [fft_x.real, fft_x.imag], -1
+            )  # (batch, in_dim, 2 * z_num)
+
+            assert fft_x.shape == (x.shape[0], x.shape[1], self.z_num + 2)
+
+            fft_x = fft_x.flatten(1)  # (batch, in_dim * (z_num + 2))
+
+            x = torch.cat(
+                [x.flatten(1), fft_x, xc], -1
+            )  # (batch, in_dim * z_num * 3 + latent_dim)
+        else:
+            x = torch.cat([x.flatten(1), xc], -1)
+
+        assert (
+            x.shape[1] == self.fc_in_dim
+        ), f"Expected dim {self.fc_in_dim}, got {x.shape[1]}"
+
+        x = self.fc(x)  # (batch, latent_dim)
+
+        return x
+
+
+class FastIntegralKernel(nn.Module):
+    def __init__(
+        self,
+        z: tuple[float, float, int],
+        kernel_coef: int = 16,
+        in_dim: int = 1,
+        activation: str = "gelu",
+    ):
+        super().__init__()
+
+        z = torch.linspace(*z)
+
+        self.kernel = FCBlock(
+            in_dim + 2, kernel_coef * in_dim, in_dim, activation=activation
+        )
+
+        self.register_buffer("z", z)
+
+    def _get_z(self, x: Tensor):
+        # x.shape == (batch_size, num_x)
+        dz = self.z[1] - self.z[0]
+        indices = torch.ceil((x - self.z[0] - dz / 2) / dz).to(torch.int64)
+
+        z = torch.index_select(self.z, 0, indices.flatten()).view(*x.shape)
+
+        return z, indices
+
+    def forward(self, x: Tensor, y: Tensor, drop_mask=None):
+        z, indices = self._get_z(x)
+        xz = torch.stack([x, z], -1)
+        kernel_input = torch.cat([xz, y], -1)
+        output = self.kernel(kernel_input)  # (batch, x_num, in_dim)
+
+        output = compute_means(
+            output * y, indices, self.z.shape[-1], drop_mask=drop_mask
+        )  # (batch, z_num, in_dim)
+
+        output = output.swapaxes(1, 2)  # (batch, in_dim, z_num)
+
+        return output
+
+
+class FullIntegralKernel(nn.Module):
+    def __init__(
+        self,
+        z_num: int,
+        kernel_coef: int = 1,
+        in_dim: int = 1,
+    ):
+        super().__init__()
+
+        self.z_num = z_num
+        self.in_dim = in_dim
+
+        self.kernel = nn.Sequential(
+            nn.Linear(in_dim + 1, z_num * kernel_coef),
+            nn.LayerNorm(z_num * kernel_coef),
+            nn.ReLU(),
+            nn.Linear(z_num * kernel_coef, z_num * in_dim),
+        )
+
+    def forward(self, x: Tensor, y: Tensor, drop_mask=None):
+        # x.shape == (batch_size, num_x)
+        # y.shape == (batch_size, num_x, in_dim)
+        # drop_mask.shape == (batch_size, num_x)
+
+        batch_size, num_x = x.shape
+
+        kernel_input = torch.cat([x.unsqueeze(-1), y], -1)  # (batch, x_num, in_dim + 1)
+        x = self.kernel(kernel_input)  # (batch, x_num, z_num * in_dim)
+        x = x.reshape(
+            *x.shape[:-1], self.z_num, self.in_dim
+        )  # (batch, x_num, z_num, in_dim)
+        # permute to get (batch, z_num, x_num, in_dim)
+        x = x.permute(0, 2, 1, 3)
+
+        y = y.unsqueeze(1)  # (batch, 1, x_num, in_dim)
+
+        assert x.shape == (
+            batch_size,
+            self.z_num,
+            num_x,
+            self.in_dim,
+        )  # (batch, z_num, in_dim, x_num)
+        assert y.shape == (
+            batch_size,
+            1,
+            num_x,
+            self.in_dim,
+        )  # (batch, 1, x_num, in_dim)
+
+        if drop_mask is not None:
+            x = x * y
+            x = x.permute(0, 2, 1, 3)  # (batch, x_num, z_num, in_dim)
+            x = masked_mean(x, drop_mask)
+        else:
+            x = (x * y).mean(-2)  # (batch, z_num, in_dim)
+
+        assert x.shape == (batch_size, self.z_num, self.in_dim), f"{x.shape}"
+
+        x = x.swapaxes(1, 2)  # (batch, in_dim, z_num)
+
+        return x
+
+
+class FCBlock(nn.Module):
+    def __init__(
+        self,
+        in_dim: int = 2,
+        hid_dim: int = 16,
+        out_dim: int = 16,
+        activation: str = "gelu",
+    ):
+        super().__init__()
+
+        self.fc1 = nn.Linear(in_dim, hid_dim)
+        self.layer_norm = nn.LayerNorm(hid_dim)
+        self.activation = activation_by_name(activation)()
+        self.fc2 = nn.Linear(hid_dim, out_dim)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.fc1(x)
+        x = self.layer_norm(x)
+        x = self.activation(x)
+        x = self.fc2(x)
+        return x
+        # return self.kernel(x)
+
+
+def compute_means(x, indices, z: int, drop_mask: Tensor = None):
+    """
+    Compute the mean values of tensor 'x' for each unique index in 'indices' across each batch.
+
+    This function calculates the mean of elements in 'x' that correspond to each unique index in 'indices'.
+    The computation is performed for each batch separately, and the function is optimized to avoid Python loops
+    by using advanced PyTorch operations.
+
+    Parameters:
+        x (torch.Tensor): A tensor of shape (batch_size, n, d) containing the values to be averaged.
+            'x' should be a floating-point tensor.
+        indices (torch.Tensor): An integer tensor of shape (batch_size, n) containing the indices.
+            The values in 'indices' should be in the range [0, z-1].
+        z (int): The number of unique indices. This determines the second dimension of the output tensor.
+        drop_mask (torch.Tensor): A boolean tensor of shape (batch_size, n) containing a mask for the indices to drop.
+            If None, all indices are used.
+
+    Returns:
+        torch.Tensor: A tensor of shape (batch_size, z, d) containing the mean values for each index in each batch.
+            If an index does not appear in a batch, its corresponding mean values are zeros.
+
+    Example:
+        >>> batch_size, n, d, z = 3, 4, 5, 6
+        >>> indices = torch.randint(0, z, (batch_size, n))
+        >>> x = torch.randn(batch_size, n, d)
+        >>> y = compute_means(x, indices, z)
+        >>> print(y.shape)
+        torch.Size([3, 6, 5])
+    """
+
+    batch_size, n, d = x.shape
+    device = x.device
+
+    drop = drop_mask is not None
+
+    # Initialize tensors to hold sums and counts
+    sums = torch.zeros(batch_size, z + int(drop), d, device=device)
+    counts = torch.zeros(batch_size, z + int(drop), device=device)
+
+    if drop_mask is not None:
+        # Set the values of the indices to drop to z
+        indices = indices.masked_fill(~drop_mask, z)
+
+    indices_expanded = indices.unsqueeze(-1).expand_as(x)
+    sums.scatter_add_(1, indices_expanded, x)
+    counts.scatter_add_(1, indices, torch.ones_like(indices, dtype=x.dtype))
+
+    if drop:
+        # Remove the z values from the sums and counts
+        sums = sums[:, :-1]
+        counts = counts[:, :-1]
+
+    # Compute the mean and handle division by zero
+    mean = sums / counts.unsqueeze(-1).clamp(min=1)
+
+    return mean
+
+
+def masked_mean(x, mask):
+    """
+    Computes the mean of tensor x along the x_size dimension,
+    while masking out elements where the corresponding value in the mask is False.
+
+    Args:
+        x (torch.Tensor): A tensor of shape (batch, x_size, z, d).
+        mask (torch.Tensor): A boolean mask of shape (batch, x_size).
+
+    Returns:
+        torch.Tensor: The result tensor of shape (batch, z, d) after applying the mask and computing the mean.
+    """
+    if not mask.dtype == torch.bool:
+        raise TypeError("Mask must be a boolean tensor.")
+
+    # Ensure the mask is broadcastable to the shape of x
+    mask = mask.unsqueeze(-1).unsqueeze(-1)
+    masked_x = x * mask
+
+    # Compute the sum and the count of valid (unmasked) elements along the x_size dimension
+    sum_x = masked_x.sum(dim=1)
+    count_x = mask.sum(dim=1)
+
+    # Avoid division by zero
+    count_x[count_x == 0] = 1
+
+    # Compute the mean
+    mean_x = sum_x / count_x
+
+    return mean_x
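
A usage sketch for the new embedding network, assuming only the constructor shown above; the q values must lie inside z_range for the FastIntegralKernel binning, and each z_num must be even (illustrative values only):

import torch
from reflectorch.models.encoders.integral_kernel_embedding import IntegralConvEmbedding

emb = IntegralConvEmbedding(
    z_num=(64, 128),       # two kernel branches, one per grid resolution
    z_range=(0.0, 1.0),    # giving a z_range selects the FastIntegralKernel path
    in_dim=2,              # two curve channels (illustrative)
    dim_embedding=256,
)
q = torch.rand(8, 100)     # (batch, num_points), values inside z_range
y = torch.rand(8, 100, 2)  # (batch, num_points, in_dim)
out = emb(q, y)            # branch outputs are concatenated, then fused by a ResidualMLP
print(out.shape)           # torch.Size([8, 256])
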

reflectorch/models/networks/__init__.py
@@ -1,4 +1,5 @@
 from reflectorch.models.networks.mlp_networks import (
+    NetworkWithPriors,
     NetworkWithPriorsConvEmb,
     NetworkWithPriorsFnoEmb,
 )
@@ -7,6 +8,7 @@ from reflectorch.models.networks.residual_net import ResidualMLP
 
 __all__ = [
     "ResidualMLP",
+    "NetworkWithPriors",
     "NetworkWithPriorsConvEmb",
     "NetworkWithPriorsFnoEmb",
 ]
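
The practical effect of these two hunks is that the newly exported class can be imported from the package's networks namespace alongside the existing ones; a one-line sketch:

from reflectorch.models.networks import NetworkWithPriors, NetworkWithPriorsConvEmb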