neojax-operators 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neojax_operators-0.0.1/.gitignore +2 -0
- neojax_operators-0.0.1/LICENSE +21 -0
- neojax_operators-0.0.1/PKG-INFO +37 -0
- neojax_operators-0.0.1/README.md +17 -0
- neojax_operators-0.0.1/neojax/__init__.py +1 -0
- neojax_operators-0.0.1/neojax/layers/__init__.py +0 -0
- neojax_operators-0.0.1/neojax/layers/attention_kernel_integral.py +300 -0
- neojax_operators-0.0.1/neojax/layers/base_spectral_conv.py +0 -0
- neojax_operators-0.0.1/neojax/layers/channel_mlp.py +0 -0
- neojax_operators-0.0.1/neojax/layers/coda_layer.py +0 -0
- neojax_operators-0.0.1/neojax/layers/complex.py +0 -0
- neojax_operators-0.0.1/neojax/layers/differential_conv.py +0 -0
- neojax_operators-0.0.1/neojax/layers/discrete_continuous_convolution.py +0 -0
- neojax_operators-0.0.1/neojax/layers/einsum_utils.py +0 -0
- neojax_operators-0.0.1/neojax/layers/embeddings.py +0 -0
- neojax_operators-0.0.1/neojax/layers/fno_block.py +0 -0
- neojax_operators-0.0.1/neojax/layers/fourier_continuation.py +0 -0
- neojax_operators-0.0.1/neojax/layers/gno_block.py +0 -0
- neojax_operators-0.0.1/neojax/layers/gno_weighting_functions.py +0 -0
- neojax_operators-0.0.1/neojax/layers/instance_norm1d.py +17 -0
- neojax_operators-0.0.1/neojax/layers/integral_transform.py +0 -0
- neojax_operators-0.0.1/neojax/layers/legacy_spectral_convolution.py +0 -0
- neojax_operators-0.0.1/neojax/layers/local_no_block.py +0 -0
- neojax_operators-0.0.1/neojax/layers/neighbor_search.py +0 -0
- neojax_operators-0.0.1/neojax/layers/normalization_layers.py +0 -0
- neojax_operators-0.0.1/neojax/layers/padding.py +0 -0
- neojax_operators-0.0.1/neojax/layers/resample.py +0 -0
- neojax_operators-0.0.1/neojax/layers/rno_block.py +0 -0
- neojax_operators-0.0.1/neojax/layers/segment_csr.py +0 -0
- neojax_operators-0.0.1/neojax/layers/skip_connections.py +0 -0
- neojax_operators-0.0.1/neojax/layers/spectral_convolution.py +0 -0
- neojax_operators-0.0.1/neojax/layers/spectral_projection.py +0 -0
- neojax_operators-0.0.1/neojax/layers/spherical_convolution.py +0 -0
- neojax_operators-0.0.1/neojax/losses/__init__.py +0 -0
- neojax_operators-0.0.1/neojax/models/__init__.py +0 -0
- neojax_operators-0.0.1/neojax/training/__init__.py +0 -0
- neojax_operators-0.0.1/neojax/utils.py +234 -0
- neojax_operators-0.0.1/pyproject.toml +39 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Paul Gekeler
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: neojax-operators
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Neural Operators in JAX
|
|
5
|
+
Project-URL: Homepage, https://github.com/paulgekeler/neojax
|
|
6
|
+
Author-email: Paul Gekeler <pgekeler@gmail.com>
|
|
7
|
+
License: MIT
|
|
8
|
+
License-File: LICENSE
|
|
9
|
+
Keywords: deep-learning,fno,jax,neural-operators,scientific-computing
|
|
10
|
+
Classifier: Development Status :: 3 - Alpha
|
|
11
|
+
Classifier: Intended Audience :: Science/Research
|
|
12
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
15
|
+
Requires-Python: >=3.9
|
|
16
|
+
Requires-Dist: equinox
|
|
17
|
+
Requires-Dist: jax
|
|
18
|
+
Requires-Dist: jaxlib
|
|
19
|
+
Description-Content-Type: text/markdown
|
|
20
|
+
|
|
21
|
+
This is neojax-operators (**Ne**ural **O**perators in **jax**), a port of the popular [neural operators](https://github.com/neuraloperator/neuraloperator) pytorch library to jax. It is built on top of [jax](https://github.com/jax-ml/jax) and [equinox](https://github.com/patrick-kidger/equinox) and provides the same API as the original pytorch library.
|
|
22
|
+
|
|
23
|
+
#### Installation
|
|
24
|
+
Install the python package via pypi
|
|
25
|
+
```bash
|
|
26
|
+
pip3 install neojax-operators
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
#### Quickstart
|
|
30
|
+
neojax-operators exposes the same API as neuraloperators and can therefore be used as a drop-in replacement:
|
|
31
|
+
```python
|
|
32
|
+
from neojax.models import FNO
|
|
33
|
+
operator = FNO(n_modes=(64, 64),
|
|
34
|
+
hidden_channels=64,
|
|
35
|
+
in_channels=2,
|
|
36
|
+
out_channels=1)
|
|
37
|
+
```
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
neojax-operators (**Ne**ural **O**perators in **jax**) is a port of the popular [neuraloperator](https://github.com/neuraloperator/neuraloperator) PyTorch library to JAX. It is built on top of [JAX](https://github.com/jax-ml/jax) and [Equinox](https://github.com/patrick-kidger/equinox) and provides the same API as the original PyTorch library.
|
|
2
|
+
|
|
3
|
+
#### Installation
|
|
4
|
+
Install the Python package from PyPI:
|
|
5
|
+
```bash
|
|
6
|
+
pip3 install neojax-operators
|
|
7
|
+
```
|
|
8
|
+
|
|
9
|
+
#### Quickstart
|
|
10
|
+
neojax-operators exposes the same API as neuraloperator and can therefore be used as a drop-in replacement:
|
|
11
|
+
```python
|
|
12
|
+
from neojax.models import FNO
|
|
13
|
+
operator = FNO(n_modes=(64, 64),
|
|
14
|
+
hidden_channels=64,
|
|
15
|
+
in_channels=2,
|
|
16
|
+
out_channels=1)
|
|
17
|
+
```
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.0.1"
|
|
File without changes
|
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
from typing import Callable, Any
|
|
2
|
+
import math
|
|
3
|
+
import jax
|
|
4
|
+
from jax import Array
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import equinox as eqx
|
|
7
|
+
from .instance_norm1d import InstanceNorm1d
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AttentionKernelIntegral(eqx.Module):
    """Kernel integral transform with attention

    Computes \\int_{Omega} k(x, y) * f(y) dy,
    where:
        k(x, y) = \\sum_{c=1}^d q_c(x) * k_c(y), q(x) = [q_1(x); ...; q_d(x)], k(y) = [k_1(y); ...; k_d(y)]
        f(y) = v(y)
    More specifically, this module supports using just one input function (self-attention) or
    two input functions (cross-attention) to compute the kernel integral transform.

    1. Self-attention:
        input function u(.), sampling grid D_x = {x_i}_{i=1}^N
        query function: q(x_i) = u(x_i) W_q
        key function: k(x_i) = u(x_i) W_k
        value function: v(x_i) = u(x_i) W_v

    2. Cross-attention:
        first input function u_qry(.), sampling grid D_x = {x_i}_{i=1}^N
        second input function u_src(.), sampling grid D_y = {y_j}_{j=1}^M, D_y can be different from D_x
        query function: q(x_i) = u_qry(x_i) W_q
        key function: k(y_j) = u_src(y_j) W_k
        value function: v(y_j) = u_src(y_j) W_v

    Self-attention can be considered as a special case of cross-attention, where u = u_qry = u_src and D_x = D_y.

    The kernel integral transform will be numerically computed as:
        \\int_{Omega} k(x, y) * f(y) dy \\approx \\sum_{j=1}^M k(x, y_j) * f(y_j) * w(y_j)
    For uniform quadrature, the weights w(y_j) = 1/M.
    For non-uniform quadrature, the weights w(y_j) are specified as an input to the forward function.

    Parameters
    ----------
    key : Array
        JAX PRNG key used to initialize the projection layers
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    n_heads : int
        Number of attention heads in multi-head attention
    head_n_channels : int
        Dimension of each attention head, determines how many function bases to use for the kernel
        k(x, y) = \\sum_{c=1}^d q_c(x) * k_c(y), head_n_channels controls the d
    project_query : bool, optional
        Whether to project the query function with pointwise linear layer
        (this is sometimes not needed when using cross-attention), by default True

    Notes
    -----
    The dimension of the domain (``pos_dim``) is not a constructor argument;
    it is read from ``pos_src.shape[-1]`` when the module is called.
    """

    # Learnable sub-modules (to_q / to_out may be eqx.nn.Identity).
    to_q: eqx.Module
    to_k: eqx.nn.Linear
    to_v: eqx.nn.Linear
    to_out: eqx.Module
    k_norm: InstanceNorm1d
    v_norm: InstanceNorm1d
    # Hyper-parameters are static so they are not treated as PyTree leaves.
    n_heads: int = eqx.field(static=True)
    head_n_channels: int = eqx.field(static=True)
    in_channels: int = eqx.field(static=True)
    out_channels: int = eqx.field(static=True)
    project_query: bool = eqx.field(static=True)
    init_gain: float = eqx.field(static=True)
    diagonal_weight: float = eqx.field(static=True)

    def __init__(
        self,
        key: Array,
        in_channels: int,
        out_channels: int,
        n_heads: int,
        head_n_channels: int,
        project_query: bool = True,
    ) -> None:
        # One key per linear layer plus one for the per-head re-initialization.
        l1_key, l2_key, l3_key, l4_key, init_key = jax.random.split(key, 5)
        self.n_heads = n_heads
        self.head_n_channels = head_n_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        qkv_channels = head_n_channels * n_heads

        self.project_query = project_query
        if project_query:
            self.to_q = eqx.nn.Linear(
                in_channels, qkv_channels, use_bias=False, key=l1_key
            )
        else:
            self.to_q = eqx.nn.Identity()

        self.to_k = eqx.nn.Linear(
            in_channels, qkv_channels, use_bias=False, key=l2_key
        )

        self.k_norm = InstanceNorm1d(head_n_channels)
        self.v_norm = InstanceNorm1d(head_n_channels)

        self.to_v = eqx.nn.Linear(
            in_channels, qkv_channels, use_bias=False, key=l3_key
        )

        # The output projection is only needed when the channel counts differ.
        self.to_out = (
            eqx.nn.Linear(qkv_channels, out_channels, key=l4_key)
            if qkv_channels != out_channels
            else eqx.nn.Identity()
        )

        self.init_gain = 1 / math.sqrt(head_n_channels)
        self.diagonal_weight = self.init_gain
        self.initialize_qkv_weights(init_key)

    def init_weight(
        self,
        weight: eqx.nn.Linear,
        init_fn: Callable[..., Array],
        key: Array = None,
    ) -> eqx.nn.Linear:
        """
        Initialization for the projection matrix
        basically initialize the weights for each head with the given initialization function and gain,
        to add the diagonal bias, it requires input channels = head_n_channels
            W = init_gain * init_fn(...) + I * diagonal_weight

        Parameters
        ----------
        weight : eqx.nn.Linear
            The projection layer whose weight matrix is re-initialized.
        init_fn : callable
            A JAX initializer called as ``init_fn(key, shape, dtype)``
            (e.g. ``jax.nn.initializers.xavier_uniform()``).
        key : Array
            PRNG key; it is split into one sub-key per head.

        Returns
        -------
        eqx.nn.Linear
            A copy of ``weight`` with the new matrix. JAX arrays are immutable,
            so the layer cannot be modified in place.

        Raises
        ------
        ValueError
            If ``key`` is not provided.
        """
        if key is None:
            raise ValueError(
                "A PRNG key is required to initialize the projection weights"
            )

        matrix = weight.weight
        n_inputs = matrix.shape[-1]
        head_shape = (self.head_n_channels, n_inputs)

        head_blocks = []
        for head_key in jax.random.split(key, self.n_heads):
            # Scaling the sample by the gain mirrors torch's ``gain`` argument.
            block = self.init_gain * init_fn(head_key, head_shape, matrix.dtype)
            if self.head_n_channels == self.in_channels:
                block = block + self.diagonal_weight * jnp.eye(
                    n_inputs, dtype=matrix.dtype
                )
            head_blocks.append(block)

        new_matrix = jnp.concatenate(head_blocks, axis=0)
        return eqx.tree_at(lambda layer: layer.weight, weight, new_matrix)

    def initialize_qkv_weights(self, key: Array = None) -> None:
        """
        Initialize the weights for q, k, v projection matrix with a small gain and add a diagonal bias,
        this technique has been found useful for scale-sensitive problem that has not been normalized
        see Table 8 in https://arxiv.org/pdf/2105.14995.pdf

        This method assigns to ``self`` and is therefore only usable while the
        module is being constructed (it is called from ``__init__``).

        Raises
        ------
        ValueError
            If ``key`` is not provided.
        """
        if key is None:
            raise ValueError(
                "A PRNG key is required to initialize the projection weights"
            )
        q_key, k_key, v_key = jax.random.split(key, 3)
        init_fn = jax.nn.initializers.xavier_uniform()

        if self.project_query:
            self.to_q = self.init_weight(self.to_q, init_fn, q_key)
        self.to_k = self.init_weight(self.to_k, init_fn, k_key)
        self.to_v = self.init_weight(self.to_v, init_fn, v_key)

    def normalize_wrt_domain(self, u: Array, norm_fn):
        """
        Normalize the input function with respect to the domain,
        reshape the tensor to [batch_size*n_heads, num_grid_points, head_n_channels]
        The second dimension is equal to the number of grid points that discretize the domain

        ``norm_fn`` is applied separately to every
        [num_grid_points, head_n_channels] slice, and the result is reshaped
        back to [batch_size, n_heads, num_grid_points, head_n_channels].
        """
        # u: the input or transformed function
        batch_size = u.shape[0]
        u = jnp.reshape(u, (batch_size * self.n_heads, -1, self.head_n_channels))
        # vmap over the merged batch/head axis so norm_fn sees one 2-D slice
        u = jax.vmap(norm_fn)(u)
        return jnp.reshape(u, (batch_size, self.n_heads, -1, self.head_n_channels))

    def _project(self, layer, u: Array) -> Array:
        """Apply a pointwise layer to the last axis of a [batch, points, channels] array."""
        return jax.vmap(jax.vmap(layer))(u)

    def _split_heads(self, u: Array, batch_size: int) -> Array:
        """Reshape [batch, points, heads*channels] to [batch, heads, points, channels]."""
        u = jnp.reshape(u, (batch_size, -1, self.n_heads, self.head_n_channels))
        return jnp.transpose(u, (0, 2, 1, 3))

    def _expand_heads(self, freqs: Array) -> Array:
        """Insert a head axis at position 1 and repeat it n_heads times."""
        # assumes freqs is [batch, points, dim] -- TODO confirm against the embedding module
        return jnp.repeat(jnp.expand_dims(freqs, 1), self.n_heads, axis=1)

    def forward(
        self,
        u_src,
        pos_src,
        positional_embedding_module: Any = None,  # positional encoding module for encoding q/k
        u_qry=None,
        pos_qry=None,
        weights=None,
        associative=True,  # can be much faster if num_grid_points is larger than the channel number c
        return_kernel=False,
    ):
        """
        Computes kernel integral transform with attention

        Parameters
        ----------
        u_src: input function (used to compute key and value in attention),
            tensor of shape [batch_size, num_grid_points_src, channels]
        pos_src: coordinate of the second source of function's sampling points y,
            tensor of shape [batch_size, num_grid_points_src, pos_dim]
        positional_embedding_module: positional embedding module for encoding query/key (q/k),
            an object providing ``forward``, ``apply_1d_rotary_pos_emb`` and
            ``apply_2d_rotary_pos_emb``
        u_qry: query function,
            tensor of shape [batch_size, num_grid_points_query, channels], if not provided, u_qry = u_src
        pos_qry: coordinate of query points x,
            tensor of shape [batch_size, num_grid_points_query, pos_dim], if not provided, pos_qry = pos_src
        weights : quadrature weight w(y_j) for the kernel integral: u(x_i) = sum_{j} k(x_i, y_j) f(y_j) w(y_j),
            tensor of shape [batch_size, num_grid_points_src], if not provided assume to be 1/num_grid_points_src
        associative: if True, use associativity of matrix multiplication, first multiply K^T V, then multiply Q,
            much faster when num_grid_points is larger than the channel number (which is usually the case)
        return_kernel: if True, return the kernel matrix (for analyzing the kernel)

        Output
        ----------
        u: Output function given on the query points x.
           When ``return_kernel`` is True, the tuple ``(u, kxy)`` is returned,
           where ``kxy`` is the kernel matrix without quadrature weights.

        Raises
        ------
        ValueError
            If exactly one of ``u_qry`` / ``pos_qry`` is given, if
            ``return_kernel`` and ``associative`` are both True, or if a
            positional embedding is requested for ``pos_dim`` other than 1 or 2.
        """

        if u_qry is None:
            u_qry = u_src  # go back to self attention
            if pos_qry is not None:
                raise ValueError(
                    "Query coordinates are provided but query function is not provided"
                )
        else:
            if pos_qry is None:
                raise ValueError(
                    "Query coordinates are required if query function is provided"
                )

        if return_kernel and associative:
            raise ValueError("Cannot get kernel matrix when associative is set to True")

        batch_size, num_grid_points = u_src.shape[
            :2
        ]  # batch size and number of (source) grid points
        pos_dim = pos_src.shape[-1]  # position dimension

        q = self._split_heads(self._project(self.to_q, u_qry), batch_size)
        k = self._split_heads(self._project(self.to_k, u_src), batch_size)
        v = self._split_heads(self._project(self.to_v, u_src), batch_size)

        k = self.normalize_wrt_domain(k, self.k_norm)
        v = self.normalize_wrt_domain(v, self.v_norm)

        if positional_embedding_module is not None:
            if pos_dim == 2:
                k_freqs_1 = self._expand_heads(
                    positional_embedding_module.forward(pos_src[..., 0])
                )
                k_freqs_2 = self._expand_heads(
                    positional_embedding_module.forward(pos_src[..., 1])
                )

                if pos_qry is None:
                    q_freqs_1 = k_freqs_1
                    q_freqs_2 = k_freqs_2
                else:
                    q_freqs_1 = self._expand_heads(
                        positional_embedding_module.forward(pos_qry[..., 0])
                    )
                    q_freqs_2 = self._expand_heads(
                        positional_embedding_module.forward(pos_qry[..., 1])
                    )

                q = positional_embedding_module.apply_2d_rotary_pos_emb(
                    q, q_freqs_1, q_freqs_2
                )
                k = positional_embedding_module.apply_2d_rotary_pos_emb(
                    k, k_freqs_1, k_freqs_2
                )
            elif pos_dim == 1:
                # Only the head axis is repeated; the batch axis is already
                # present in the frequencies and must not be tiled again.
                k_freqs = self._expand_heads(
                    positional_embedding_module.forward(pos_src[..., 0])
                )

                if pos_qry is None:
                    q_freqs = k_freqs
                else:
                    q_freqs = self._expand_heads(
                        positional_embedding_module.forward(pos_qry[..., 0])
                    )

                q = positional_embedding_module.apply_1d_rotary_pos_emb(q, q_freqs)
                k = positional_embedding_module.apply_1d_rotary_pos_emb(k, k_freqs)
            else:
                raise ValueError(
                    "Currently doesnt support relative embedding >= 3 dimensions"
                )

        if weights is not None:
            weights = jnp.reshape(weights, (batch_size, 1, num_grid_points, 1))
        else:
            weights = 1.0 / num_grid_points

        # The quadrature weights belong to the source points y_j, so they are
        # applied to the values before summing over j. For uniform weights this
        # is identical to scaling the result.
        v = v * weights

        k_t = jnp.swapaxes(k, -1, -2)
        if associative:
            dots = jnp.matmul(k_t, v)
            u = jnp.matmul(q, dots)
        else:
            # this is more efficient when num_grid_points<<channels
            kxy = jnp.matmul(q, k_t)
            u = jnp.matmul(kxy, v)

        # Merge the heads back; -1 keeps the query grid size, which may differ
        # from the source grid size in cross-attention.
        u = jnp.reshape(
            jnp.transpose(u, (0, 2, 1, 3)),
            (batch_size, -1, self.n_heads * self.head_n_channels),
        )
        u = self._project(self.to_out, u)
        if return_kernel:
            return u, kxy
        return u

    def __call__(
        self,
        u_src,
        pos_src,
        positional_embedding_module: Any = None,
        u_qry=None,
        pos_qry=None,
        weights=None,
        associative=True,
        return_kernel=False,
    ):
        """Call :meth:`forward` with the same arguments."""
        return self.forward(
            u_src,
            pos_src,
            positional_embedding_module=positional_embedding_module,
            u_qry=u_qry,
            pos_qry=pos_qry,
            weights=weights,
            associative=associative,
            return_kernel=return_kernel,
        )
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
from jax import Array
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class InstanceNorm1d(eqx.Module):
    """Normalize every 1d signal (the last axis of the input) independently.

    The normalization is an ``eqx.nn.LayerNorm`` without learnable affine
    parameters, applied to each length-``shape_per_channel`` vector.
    """

    # Layer norm over a single signal of length ``shape_per_channel``.
    ln: eqx.nn.LayerNorm

    def __init__(self, shape_per_channel: int) -> None:
        """
        Parameters
        ----------
        shape_per_channel : int
            Length of the 1d signal, i.e. the size of the last input axis.
        """
        # we disable learnable affine parameters to match torch neuralop config
        self.ln = eqx.nn.LayerNorm(
            shape=(shape_per_channel,), use_weight=False, use_bias=False
        )

    def __call__(self, x: Array) -> Array:
        """Normalize ``x`` along its last axis.

        Parameters
        ----------
        x : Array
            Array whose last axis has length ``shape_per_channel``. Any number
            of leading axes (channels, batch, ...) is accepted.

        Returns
        -------
        Array
            Array of the same shape as ``x``.
        """
        # Collapse all leading axes so one vmap covers every signal, then
        # restore the original shape.
        signals = jnp_reshape(x, (-1, x.shape[-1]))
        return jnp_reshape(jax.vmap(self.ln)(signals), x.shape)


def jnp_reshape(x: Array, shape) -> Array:
    """Reshape ``x`` via its own ``reshape`` method (avoids a new import)."""
    return x.reshape(shape)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
from typing import Optional, Union
|
|
2
|
+
from math import prod
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
from jax import Array
|
|
7
|
+
import equinox as eqx
|
|
8
|
+
|
|
9
|
+
# Only import wandb and use if installed
|
|
10
|
+
wandb_available = False
|
|
11
|
+
try:
|
|
12
|
+
import wandb
|
|
13
|
+
|
|
14
|
+
wandb_available = True
|
|
15
|
+
except ModuleNotFoundError:
|
|
16
|
+
wandb_available = False
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def count_model_params(model: eqx.Module) -> int:
    """Returns the total number of parameters of an equinox model

    Only the inexact (floating point or complex) array leaves of ``model``
    are counted.

    Parameters
    ----------
    model : eqx.Module
        The model (any PyTree) whose array leaves are counted.

    Returns
    -------
    int
        Total number of parameters.

    Notes
    -----
    One complex number is counted as two parameters (we count real and imaginary parts)
    """
    trainable_params = eqx.filter(model, eqx.is_inexact_array)
    # eqx.filter replaces the non-matching leaves with None, which
    # tree_leaves drops, so only the arrays remain.
    total_params = 0
    for leaf in jax.tree_util.tree_leaves(trainable_params):
        n_params = int(leaf.size)
        total_params += 2 * n_params if jnp.iscomplexobj(leaf) else n_params
    return total_params
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def count_array_params(arr: Array, dims=None) -> int:
    """Returns the number of parameters (elements) in a single array, optionally, along certain dimensions only

    Parameters
    ----------
    arr : Array
    dims : int list or None, default is None
        if not None, the dimensions to consider when counting the number of parameters (elements)

    Returns
    -------
    int
        Product of the selected dimension sizes, doubled for complex arrays.

    Notes
    -----
    One complex number is counted as two parameters (we count real and imaginary parts)
    """
    if dims is None:
        sizes = list(arr.shape)
    else:
        sizes = [arr.shape[d] for d in dims]
    n_params = prod(sizes)
    # iscomplexobj inspects the dtype; iscomplex is elementwise and only
    # reports entries with a non-zero imaginary part.
    if jnp.iscomplexobj(arr):
        return 2 * n_params
    return n_params
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def wandb_login(
    api_key_file: str = "../config/wandb_api_key.txt", key: Optional[str] = None
) -> None:
    """Log in to Weights & Biases.

    Parameters
    ----------
    api_key_file : str
        File the API key is read from when ``key`` is not given and the
        ``WANDB_API_KEY`` environment variable is not set.
    key : str or None
        API key to use. If None, it is looked up with ``get_wandb_api_key``.

    Raises
    ------
    ImportError
        If the optional ``wandb`` package is not installed.
    """
    # wandb is an optional dependency; fail with a clear message instead of a
    # NameError on the unbound module name.
    if not wandb_available:
        raise ImportError(
            "wandb is not installed; install it to use wandb_login"
        )

    if key is None:
        key = get_wandb_api_key(api_key_file)

    wandb.login(key=key)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def set_wandb_api_key(api_key_file: str = "../config/wandb_api_key.txt") -> None:
    """Make sure the ``WANDB_API_KEY`` environment variable is set.

    An existing value is left untouched. Otherwise the key is read from
    ``api_key_file``, stripped of surrounding whitespace, and exported.
    """
    import os

    # Nothing to do when the variable is already defined.
    if "WANDB_API_KEY" in os.environ:
        return

    with open(api_key_file, "r") as key_file:
        contents = key_file.read()
    os.environ["WANDB_API_KEY"] = contents.strip()
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def get_wandb_api_key(api_key_file: str = "../config/wandb_api_key.txt") -> str:
    """Return the Weights & Biases API key.

    The ``WANDB_API_KEY`` environment variable takes precedence. When it is
    not set, the key is read from ``api_key_file`` and returned with
    surrounding whitespace removed.
    """
    import os

    env_key = os.environ.get("WANDB_API_KEY")
    if env_key is not None:
        return env_key

    with open(api_key_file, "r") as key_file:
        return key_file.read().strip()
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Define the function to compute the spectrum
|
|
86
|
+
def spectrum_2d(signal: Array, n_observations: int, normalize: bool = True) -> Array:
    """This function computes the spectrum of a 2D signal using the Fast Fourier Transform (FFT).

    Parameters
    ----------
    signal : Array
        A batch of 2D discretized signals whose leading axis has size T and
        whose remaining entries can be reshaped to
        (n_observations, n_observations), e.g. an array of shape
        (T, n_observations * n_observations).

        T can be any number of time steps or channels and
        n_observations * n_observations is the spatial resolution.
    n_observations: an integer
        Number of discretized points. Basically the resolution of the signal.
    normalize: bool
        whether to apply normalization to the output of the 2D FFT.
        If True, normalizes the outputs by ``1/n_observations``
        (actually ``1/sqrt(n_observations * n_observations)``).

    Returns
    --------
    spectrum: an array
        A 1D array of shape (n_observations,) holding the spectrum averaged
        over the T leading entries.
        Only the non-redundant top-left block of the 2d FFT output (indices
        0..n_observations // 2 along both axes) is used. The coefficient at
        index (i, j) of that block contributes its squared magnitude to bin
        r (r = 1, ..., n_observations) when sqrt(i**2 + j**2) equals r.
    """
    T = signal.shape[0]
    signal = jnp.reshape(signal, (T, n_observations, n_observations))

    if normalize:
        coefficients = jnp.fft.fft2(signal, norm="ortho")
    else:
        coefficients = jnp.fft.rfft2(
            signal, s=(n_observations, n_observations), norm="backward"
        )

    # Within the retained block the fft wavenumbers are 0, 1, ..., k_max up to
    # sign, so their squares are simply the squared indices.
    k_max = n_observations // 2
    k_max1 = k_max + 1
    modes = jnp.arange(k_max1)
    radius = jnp.sqrt(modes[:, None] ** 2 + modes[None, :] ** 2)

    # Squared magnitude of the non-redundant coefficients.
    power = jnp.abs(coefficients[:, :k_max1, :k_max1]) ** 2

    # JAX arrays are immutable, so instead of filling the spectrum column by
    # column every bin is selected with a mask and summed in one contraction.
    bins = jnp.arange(1, n_observations + 1)
    in_bin = (radius[None, :, :] == bins[:, None, None]).astype(power.dtype)
    spectrum = jnp.einsum("tij,bij->tb", power, in_bin)

    return jnp.mean(spectrum, axis=0)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
Number = Union[float, int]


def validate_scaling_factor(
    scaling_factor: Union[None, Number, list[Number], list[list[Number]]],
    n_dim: int,
    n_layers: Optional[int] = None,
) -> Union[None, list[float], list[list[float]]]:
    """Normalize a scaling factor specification to lists of floats.

    Parameters
    ----------
    scaling_factor : None OR float OR list[float] Or list[list[float]]
    n_dim : int
    n_layers : int or None; defaults to None
        If None, return a single list (rather than a list of lists)
        with `factor` repeated `dim` times.

    Returns
    -------
    None OR list[float] OR list[list[float]]
        * None if ``scaling_factor`` is None or has an unsupported form
          (including an empty list).
        * For a single number: ``n_dim`` copies of it, or ``n_layers`` such
          lists when ``n_layers`` is given.
        * For a flat list of numbers: the list itself (as floats) when
          ``n_layers`` is None and its length equals ``n_dim``; otherwise one
          list of ``n_dim`` copies per entry.
        * For a list of lists of numbers: the same nesting, as floats.
    """
    if scaling_factor is None:
        return None

    if isinstance(scaling_factor, (float, int)):
        factor = float(scaling_factor)
        if n_layers is None:
            return [factor] * n_dim
        # Build a fresh inner list per layer so the layers do not alias each
        # other (``[[...]] * n_layers`` would repeat one shared list).
        return [[factor] * n_dim for _ in range(n_layers)]

    if not isinstance(scaling_factor, list) or len(scaling_factor) == 0:
        return None

    if all(isinstance(s, (float, int)) for s in scaling_factor):
        if n_layers is None and len(scaling_factor) == n_dim:
            # this is a dim-wise scaling
            return [float(s) for s in scaling_factor]
        return [[float(s)] * n_dim for s in scaling_factor]

    if all(isinstance(s, list) for s in scaling_factor):
        if all(
            isinstance(s_sub, (int, float)) for s in scaling_factor for s_sub in s
        ):
            # Convert to floats, as every other branch does.
            return [[float(s_sub) for s_sub in s] for s in scaling_factor]

    return None
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def compute_rank(arr: Array) -> Array:
    """Return the matrix rank of ``arr`` as computed by ``jnp.linalg.matrix_rank``."""
    return jnp.linalg.matrix_rank(arr)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def compute_stable_rank(arr: Array) -> Array:
    """Return the stable rank of ``arr``.

    The stable rank is the squared Frobenius norm divided by the squared
    spectral (2-) norm of the matrix.
    """
    frobenius = jnp.linalg.norm(arr, ord="fro")
    spectral = jnp.linalg.norm(arr, ord=2)
    return frobenius**2 / spectral**2
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def compute_explained_variance(frequency_max: int, s: Array) -> Array:
    """Compute the explained variance based on frequency_max and singular values (s).

    Parameters
    ----------
    frequency_max : int
        Number of leading entries of ``s`` that are kept; every entry from
        this index on is set to zero.
    s : Array
        1D array of singular values.

    Returns
    -------
    Array
        ``1 - var(s - s_truncated) / var(s)``.
    """
    s = jnp.asarray(s)
    # JAX arrays are immutable: build the truncated copy with a functional
    # update rather than slice assignment.
    s_current = s.at[frequency_max:].set(0)
    return 1 - jnp.var(s - s_current) / jnp.var(s)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def get_project_root() -> Path:
    """Return the directory two levels above this file."""
    this_file = Path(__file__)
    package_dir = this_file.parent
    return package_dir.parent
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "neojax-operators"
|
|
7
|
+
version = "0.0.1"
|
|
8
|
+
description = "Neural Operators in JAX"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.9"
|
|
11
|
+
license = { text = "MIT" }
|
|
12
|
+
authors = [
|
|
13
|
+
{ name = "Paul Gekeler", email = "pgekeler@gmail.com" },
|
|
14
|
+
]
|
|
15
|
+
keywords = [
|
|
16
|
+
"jax",
|
|
17
|
+
"neural-operators",
|
|
18
|
+
"fno",
|
|
19
|
+
"deep-learning",
|
|
20
|
+
"scientific-computing",
|
|
21
|
+
]
|
|
22
|
+
classifiers = [
|
|
23
|
+
"Development Status :: 3 - Alpha",
|
|
24
|
+
"Intended Audience :: Science/Research",
|
|
25
|
+
"License :: OSI Approved :: MIT License",
|
|
26
|
+
"Programming Language :: Python :: 3",
|
|
27
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
28
|
+
]
|
|
29
|
+
dependencies = [
|
|
30
|
+
"jax",
|
|
31
|
+
"jaxlib",
|
|
32
|
+
"equinox",
|
|
33
|
+
]
|
|
34
|
+
|
|
35
|
+
[tool.hatch.build.targets.wheel]
|
|
36
|
+
packages = ["neojax"]
|
|
37
|
+
|
|
38
|
+
[project.urls]
|
|
39
|
+
Homepage = "https://github.com/paulgekeler/neojax"
|