neojax-operators 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. neojax_operators-0.0.1/.gitignore +2 -0
  2. neojax_operators-0.0.1/LICENSE +21 -0
  3. neojax_operators-0.0.1/PKG-INFO +37 -0
  4. neojax_operators-0.0.1/README.md +17 -0
  5. neojax_operators-0.0.1/neojax/__init__.py +1 -0
  6. neojax_operators-0.0.1/neojax/layers/__init__.py +0 -0
  7. neojax_operators-0.0.1/neojax/layers/attention_kernel_integral.py +300 -0
  8. neojax_operators-0.0.1/neojax/layers/base_spectral_conv.py +0 -0
  9. neojax_operators-0.0.1/neojax/layers/channel_mlp.py +0 -0
  10. neojax_operators-0.0.1/neojax/layers/coda_layer.py +0 -0
  11. neojax_operators-0.0.1/neojax/layers/complex.py +0 -0
  12. neojax_operators-0.0.1/neojax/layers/differential_conv.py +0 -0
  13. neojax_operators-0.0.1/neojax/layers/discrete_continuous_convolution.py +0 -0
  14. neojax_operators-0.0.1/neojax/layers/einsum_utils.py +0 -0
  15. neojax_operators-0.0.1/neojax/layers/embeddings.py +0 -0
  16. neojax_operators-0.0.1/neojax/layers/fno_block.py +0 -0
  17. neojax_operators-0.0.1/neojax/layers/fourier_continuation.py +0 -0
  18. neojax_operators-0.0.1/neojax/layers/gno_block.py +0 -0
  19. neojax_operators-0.0.1/neojax/layers/gno_weighting_functions.py +0 -0
  20. neojax_operators-0.0.1/neojax/layers/instance_norm1d.py +17 -0
  21. neojax_operators-0.0.1/neojax/layers/integral_transform.py +0 -0
  22. neojax_operators-0.0.1/neojax/layers/legacy_spectral_convolution.py +0 -0
  23. neojax_operators-0.0.1/neojax/layers/local_no_block.py +0 -0
  24. neojax_operators-0.0.1/neojax/layers/neighbor_search.py +0 -0
  25. neojax_operators-0.0.1/neojax/layers/normalization_layers.py +0 -0
  26. neojax_operators-0.0.1/neojax/layers/padding.py +0 -0
  27. neojax_operators-0.0.1/neojax/layers/resample.py +0 -0
  28. neojax_operators-0.0.1/neojax/layers/rno_block.py +0 -0
  29. neojax_operators-0.0.1/neojax/layers/segment_csr.py +0 -0
  30. neojax_operators-0.0.1/neojax/layers/skip_connections.py +0 -0
  31. neojax_operators-0.0.1/neojax/layers/spectral_convolution.py +0 -0
  32. neojax_operators-0.0.1/neojax/layers/spectral_projection.py +0 -0
  33. neojax_operators-0.0.1/neojax/layers/spherical_convolution.py +0 -0
  34. neojax_operators-0.0.1/neojax/losses/__init__.py +0 -0
  35. neojax_operators-0.0.1/neojax/models/__init__.py +0 -0
  36. neojax_operators-0.0.1/neojax/training/__init__.py +0 -0
  37. neojax_operators-0.0.1/neojax/utils.py +234 -0
  38. neojax_operators-0.0.1/pyproject.toml +39 -0
@@ -0,0 +1,2 @@
1
+ .vscode/*
2
+
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Paul Gekeler
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,37 @@
1
+ Metadata-Version: 2.4
2
+ Name: neojax-operators
3
+ Version: 0.0.1
4
+ Summary: Neural Operators in JAX
5
+ Project-URL: Homepage, https://github.com/paulgekeler/neojax
6
+ Author-email: Paul Gekeler <pgekeler@gmail.com>
7
+ License: MIT
8
+ License-File: LICENSE
9
+ Keywords: deep-learning,fno,jax,neural-operators,scientific-computing
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: Intended Audience :: Science/Research
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
15
+ Requires-Python: >=3.9
16
+ Requires-Dist: equinox
17
+ Requires-Dist: jax
18
+ Requires-Dist: jaxlib
19
+ Description-Content-Type: text/markdown
20
+
21
+ This is neojax-operators (**Ne**ural **O**perators in **jax**), a port of the popular [neuraloperator](https://github.com/neuraloperator/neuraloperator) PyTorch library to JAX. It is built on top of [JAX](https://github.com/jax-ml/jax) and [Equinox](https://github.com/patrick-kidger/equinox) and provides the same API as the original PyTorch library.
22
+
23
+ #### Installation
24
+ Install the Python package from PyPI:
25
+ ```bash
26
+ pip3 install neojax-operators
27
+ ```
28
+
29
+ #### Quickstart
30
+ neojax-operators exposes the same API as neuraloperator and can therefore be used as a drop-in replacement:
31
+ ```python
32
+ from neojax.models import FNO
33
+ operator = FNO(n_modes=(64, 64),
34
+ hidden_channels=64,
35
+ in_channels=2,
36
+ out_channels=1)
37
+ ```
@@ -0,0 +1,17 @@
1
+ This is neojax-operators (**Ne**ural **O**perators in **jax**), a port of the popular [neuraloperator](https://github.com/neuraloperator/neuraloperator) PyTorch library to JAX. It is built on top of [JAX](https://github.com/jax-ml/jax) and [Equinox](https://github.com/patrick-kidger/equinox) and provides the same API as the original PyTorch library.
2
+
3
+ #### Installation
4
+ Install the Python package from PyPI:
5
+ ```bash
6
+ pip3 install neojax-operators
7
+ ```
8
+
9
+ #### Quickstart
10
+ neojax-operators exposes the same API as neuraloperator and can therefore be used as a drop-in replacement:
11
+ ```python
12
+ from neojax.models import FNO
13
+ operator = FNO(n_modes=(64, 64),
14
+ hidden_channels=64,
15
+ in_channels=2,
16
+ out_channels=1)
17
+ ```
@@ -0,0 +1 @@
1
+ __version__ = "0.0.1"
File without changes
@@ -0,0 +1,300 @@
1
+ from typing import Callable, Any
2
+ import math
3
+ import jax
4
+ from jax import Array
5
+ import jax.numpy as jnp
6
+ import equinox as eqx
7
+ from .instance_norm1d import InstanceNorm1d
8
+
9
+
10
class AttentionKernelIntegral(eqx.Module):
    """Kernel integral transform with attention

    Computes \\int_{Omega} k(x, y) * f(y) dy,
    where:
        k(x, y) = \\sum_{c=1}^d q_c(x) * k_c(y), q(x) = [q_1(x); ...; q_d(x)], k(y) = [k_1(y); ...; k_d(y)]
        f(y) = v(y)
    More specifically, this module supports using just one input function (self-attention) or
    two input functions (cross-attention) to compute the kernel integral transform.

    1. Self-attention:
        input function u(.), sampling grid D_x = {x_i}_{i=1}^N
        query function: q(x_i) = u(x_i) W_q
        key function: k(x_i) = u(x_i) W_k
        value function: v(x_i) = u(x_i) W_v

    2. Cross-attention:
        first input function u_qry(.), sampling grid D_x = {x_i}_{i=1}^N
        second input function u_src(.), sampling grid D_y = {y_j}_{j=1}^M, D_y can be different from D_x
        query function: q(x_i) = u_qry(x_i) W_q
        key function: k(y_j) = u_src(y_j) W_k
        value function: v(y_j) = u_src(y_j) W_v

    Self-attention can be considered as a special case of cross-attention, where u = u_qry = u_src and D_x = D_y.

    The kernel integral transform is numerically computed as:
        \\int_{Omega} k(x, y) * f(y) dy \\approx \\sum_{j=1}^M k(x, y_j) * f(y_j) * w(y_j)
    For uniform quadrature, the weights w(y_j) = 1/M.
    For non-uniform quadrature, the weights w(y_j) are passed as the ``weights`` argument of
    ``forward`` / ``__call__``.

    Parameters
    ----------
    key : Array
        PRNG key used to initialize all linear projections
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    n_heads : int
        Number of attention heads in multi-head attention
    head_n_channels : int
        Dimension of each attention head, determines how many function bases to use for the kernel
        k(x, y) = \\sum_{c=1}^d q_c(x) * k_c(y), head_n_channels controls the d
    project_query : bool, optional
        Whether to project the query function with pointwise linear layer
        (this is sometimes not needed when using cross-attention), by default True.
        If False, the query input must already have n_heads * head_n_channels channels.
    """

    # Static hyper-parameters (not pytree leaves, so they stay Python values under jit).
    n_heads: int = eqx.field(static=True)
    head_n_channels: int = eqx.field(static=True)
    in_channels: int = eqx.field(static=True)
    out_channels: int = eqx.field(static=True)
    project_query: bool = eqx.field(static=True)
    init_gain: float = eqx.field(static=True)
    diagonal_weight: float = eqx.field(static=True)
    # Pointwise projections: eqx.nn.Linear, or eqx.nn.Identity for to_q / to_out.
    to_q: eqx.Module
    to_k: eqx.nn.Linear
    to_v: eqx.nn.Linear
    to_out: eqx.Module
    # Non-affine normalization of each key / value head vector.
    k_norm: InstanceNorm1d
    v_norm: InstanceNorm1d

    def __init__(
        self,
        key: Array,
        in_channels: int,
        out_channels: int,
        n_heads: int,
        head_n_channels: int,
        project_query: bool = True,
    ) -> None:
        q_key, k_key, v_key, out_key, init_key = jax.random.split(key, 5)
        self.n_heads = n_heads
        self.head_n_channels = head_n_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.project_query = project_query
        self.init_gain = 1 / math.sqrt(head_n_channels)
        self.diagonal_weight = self.init_gain

        inner_channels = head_n_channels * n_heads
        if project_query:
            self.to_q = eqx.nn.Linear(
                in_channels, inner_channels, use_bias=False, key=q_key
            )
        else:
            self.to_q = eqx.nn.Identity()
        self.to_k = eqx.nn.Linear(
            in_channels, inner_channels, use_bias=False, key=k_key
        )
        self.to_v = eqx.nn.Linear(
            in_channels, inner_channels, use_bias=False, key=v_key
        )

        self.k_norm = InstanceNorm1d(head_n_channels)
        self.v_norm = InstanceNorm1d(head_n_channels)

        # Only project the concatenated heads when the width actually changes.
        self.to_out = (
            eqx.nn.Linear(inner_channels, out_channels, key=out_key)
            if inner_channels != out_channels
            else eqx.nn.Identity()
        )

        # Replace the default Linear initialization by the per-head scheme.
        self.to_q, self.to_k, self.to_v = self.initialize_qkv_weights(init_key)

    def init_weight(
        self,
        weight: eqx.nn.Linear,
        init_fn: Callable[..., Array],
        key: Array,
    ) -> eqx.nn.Linear:
        """Re-initialize a bias-free projection layer head by head.

        Each head's block of rows ``W[h * d:(h + 1) * d, :]`` (``d = head_n_channels``)
        is drawn independently with ``init_fn(key, shape, dtype)``. When
        ``head_n_channels == in_channels`` every block additionally receives a
        diagonal bias:

            W_h = init_fn(W_h) + I * diagonal_weight

        Parameters
        ----------
        weight : eqx.nn.Linear
            Projection layer whose ``weight`` has shape
            ``(n_heads * head_n_channels, in_channels)``.
        init_fn : callable
            A ``jax.nn.initializers``-style initializer.
        key : Array
            PRNG key, split into one key per head.

        Returns
        -------
        eqx.nn.Linear
            A copy of ``weight`` carrying the new matrix. JAX arrays are immutable,
            so the layer cannot be updated in place.
        """
        w = weight.weight
        head_shape = (self.head_n_channels, w.shape[-1])
        head_keys = jax.random.split(key, self.n_heads)
        new_w = jnp.concatenate(
            [init_fn(head_key, head_shape, w.dtype) for head_key in head_keys],
            axis=0,
        )
        if self.head_n_channels == self.in_channels:
            diagonal_bias = self.diagonal_weight * jnp.eye(w.shape[-1], dtype=w.dtype)
            # One identity block per head, stacked along the output axis.
            new_w = new_w + jnp.tile(diagonal_bias, (self.n_heads, 1))
        return eqx.tree_at(lambda layer: layer.weight, weight, new_w)

    def initialize_qkv_weights(self, key: Array):
        """
        Initialize the weights for q, k, v projection matrix with a small gain and add a diagonal bias,
        this technique has been found useful for scale-sensitive problem that has not been normalized
        see Table 8 in https://arxiv.org/pdf/2105.14995.pdf

        The initializer is Xavier-uniform with gain ``init_gain``. ``variance_scaling``
        with scale ``gain**2`` in "fan_avg"/"uniform" mode yields the bound
        ``gain * sqrt(6 / (fan_in + fan_out))``.

        Returns
        -------
        tuple
            ``(to_q, to_k, to_v)`` re-initialized layers. ``to_q`` is returned unchanged
            when ``project_query`` is False (it is an ``eqx.nn.Identity`` then).
        """
        init_fn = jax.nn.initializers.variance_scaling(
            self.init_gain**2, "fan_avg", "uniform"
        )
        q_key, k_key, v_key = jax.random.split(key, 3)
        to_q = (
            self.init_weight(self.to_q, init_fn, q_key)
            if self.project_query
            else self.to_q
        )
        to_k = self.init_weight(self.to_k, init_fn, k_key)
        to_v = self.init_weight(self.to_v, init_fn, v_key)
        return to_q, to_k, to_v

    @staticmethod
    def _pointwise(layer, x: Array) -> Array:
        """Apply a vector-to-vector layer to the last axis of ``x`` (any leading shape)."""
        leading_shape = x.shape[:-1]
        flat = jnp.reshape(x, (-1, x.shape[-1]))
        out = jax.vmap(layer)(flat)
        return jnp.reshape(out, (*leading_shape, out.shape[-1]))

    def _split_heads(self, x: Array) -> Array:
        """[batch, points, n_heads * head_n_channels] -> [batch, n_heads, points, head_n_channels]."""
        batch_size, num_points = x.shape[:2]
        x = jnp.reshape(x, (batch_size, num_points, self.n_heads, self.head_n_channels))
        return jnp.transpose(x, (0, 2, 1, 3))

    def _head_freqs(self, positional_embedding_module, coords: Array, batch_size: int) -> Array:
        """Compute rotary frequencies for ``coords`` and broadcast them over the head axis.

        NOTE(review): assumes ``positional_embedding_module.forward`` returns an array of
        shape (batch, points, dim) or (points, dim). Confirm against the embedding module.
        """
        freqs = positional_embedding_module.forward(coords)
        freqs = jnp.expand_dims(freqs, -3)  # insert the head axis before the point axis
        return jnp.broadcast_to(freqs, (batch_size, self.n_heads) + freqs.shape[-2:])

    def normalize_wrt_domain(self, u: Array, norm_fn) -> Array:
        """
        Normalize the transformed function with ``norm_fn`` and keep its shape.

        ``u`` is [batch_size, n_heads, num_grid_points, head_n_channels]. Every
        head_n_channels-vector is normalized independently, which matches the torch
        reference that reshapes to [batch_size*n_heads, num_grid_points, head_n_channels].
        The array is flattened to 2D rows because ``InstanceNorm1d`` vmaps a
        LayerNorm of shape (head_n_channels,) over the leading axis.
        """
        rows = jnp.reshape(u, (-1, self.head_n_channels))
        rows = norm_fn(rows)
        return jnp.reshape(rows, u.shape)

    def forward(
        self,
        u_src: Array,
        pos_src: Array,
        positional_embedding_module: Any = None,  # positional encoding module for encoding q/k
        u_qry=None,
        pos_qry=None,
        weights=None,
        associative=True,  # can be much faster if num_grid_points is larger than the channel number c
        return_kernel=False,
    ):
        """
        Computes kernel integral transform with attention

        Parameters
        ----------
        u_src: input function (used to compute key and value in attention),
            array of shape [batch_size, num_grid_points_src, channels]
        pos_src: coordinate of the source function's sampling points y,
            array of shape [batch_size, num_grid_points_src, pos_dim]
        positional_embedding_module: rotary positional embedding module for encoding query/key (q/k).
            It must provide ``forward(coords)`` and ``apply_1d_rotary_pos_emb`` /
            ``apply_2d_rotary_pos_emb``. Only pos_dim 1 and 2 are supported.
        u_qry: query function,
            array of shape [batch_size, num_grid_points_query, channels], if not provided, u_qry = u_src
        pos_qry: coordinate of query points x,
            array of shape [batch_size, num_grid_points_query, pos_dim], if not provided, pos_qry = pos_src
        weights : quadrature weight w(y_j) for the kernel integral: u(x_i) = sum_{j} k(x_i, y_j) f(y_j) w(y_j),
            array of shape [batch_size, num_grid_points_src], if not provided assumed to be 1/num_grid_points_src
        associative: if True, use associativity of matrix multiplication, first multiply K^T V, then multiply Q,
            much faster when num_grid_points is larger than the channel number (which is usually the case)
        return_kernel: if True, return the kernel matrix (for analyzing the kernel)

        Output
        ----------
        u: Output function given on the query points x,
            array of shape [batch_size, num_grid_points_query, out_channels].
            If return_kernel is True, ``(u, kxy)`` is returned, where kxy is
            [batch_size, n_heads, num_grid_points_query, num_grid_points_src].

        Raises
        ------
        ValueError
            If exactly one of u_qry / pos_qry is given, if return_kernel is combined with
            associative=True, or if the positional embedding is used with pos_dim >= 3.
        """
        if u_qry is None:
            u_qry = u_src  # go back to self attention
            if pos_qry is not None:
                raise ValueError(
                    "Query coordinates are provided but query function is not provided"
                )
        elif pos_qry is None:
            raise ValueError(
                "Query coordinates are required if query function is provided"
            )

        if return_kernel and associative:
            raise ValueError("Cannot get kernel matrix when associative is set to True")

        batch_size, num_grid_points = u_src.shape[:2]  # source points
        pos_dim = pos_src.shape[-1]

        # Pointwise projections, then split channels into heads: [b, h, n, d]
        q = self._split_heads(self._pointwise(self.to_q, u_qry))
        k = self._split_heads(self._pointwise(self.to_k, u_src))
        v = self._split_heads(self._pointwise(self.to_v, u_src))

        k = self.normalize_wrt_domain(k, self.k_norm)
        v = self.normalize_wrt_domain(v, self.v_norm)

        if positional_embedding_module is not None:
            if pos_dim == 2:
                k_freqs_1 = self._head_freqs(
                    positional_embedding_module, pos_src[..., 0], batch_size
                )
                k_freqs_2 = self._head_freqs(
                    positional_embedding_module, pos_src[..., 1], batch_size
                )
                if pos_qry is None:
                    q_freqs_1, q_freqs_2 = k_freqs_1, k_freqs_2
                else:
                    q_freqs_1 = self._head_freqs(
                        positional_embedding_module, pos_qry[..., 0], batch_size
                    )
                    q_freqs_2 = self._head_freqs(
                        positional_embedding_module, pos_qry[..., 1], batch_size
                    )
                q = positional_embedding_module.apply_2d_rotary_pos_emb(
                    q, q_freqs_1, q_freqs_2
                )
                k = positional_embedding_module.apply_2d_rotary_pos_emb(
                    k, k_freqs_1, k_freqs_2
                )
            elif pos_dim == 1:
                k_freqs = self._head_freqs(
                    positional_embedding_module, pos_src[..., 0], batch_size
                )
                if pos_qry is None:
                    q_freqs = k_freqs
                else:
                    q_freqs = self._head_freqs(
                        positional_embedding_module, pos_qry[..., 0], batch_size
                    )
                q = positional_embedding_module.apply_1d_rotary_pos_emb(q, q_freqs)
                k = positional_embedding_module.apply_1d_rotary_pos_emb(k, k_freqs)
            else:
                raise ValueError(
                    "Currently doesnt support relative embedding >= 3 dimensions"
                )

        if weights is not None:
            weights = jnp.reshape(weights, (batch_size, 1, num_grid_points, 1))
        else:
            weights = 1.0 / num_grid_points
        # Quadrature weights w(y_j) belong to the source points, so they scale v
        # before the sum over j, not the output, whose point axis is the queries.
        v = v * weights

        kxy = None
        if associative:
            dots = jnp.matmul(jnp.swapaxes(k, -1, -2), v)  # [b, h, d, d]
            u = jnp.matmul(q, dots)  # [b, h, n_qry, d]
        else:
            # this is more efficient when num_grid_points<<channels
            kxy = jnp.matmul(q, jnp.swapaxes(k, -1, -2))  # [b, h, n_qry, n_src]
            u = jnp.matmul(kxy, v)

        # The output lives on the query grid, which may differ from the source grid.
        num_query_points = u.shape[2]
        u = jnp.transpose(u, (0, 2, 1, 3))
        u = jnp.reshape(
            u, (batch_size, num_query_points, self.n_heads * self.head_n_channels)
        )
        u = self._pointwise(self.to_out, u)
        if return_kernel:
            return u, kxy
        return u

    def __call__(self, *args, **kwargs):
        """Alias for ``forward`` so the module is callable like any Equinox layer."""
        return self.forward(*args, **kwargs)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
@@ -0,0 +1,17 @@
1
+ import jax
2
+ from jax import Array
3
+ import equinox as eqx
4
+
5
+
6
class InstanceNorm1d(eqx.Module):
    """Normalize every 1D signal along the last axis, without learnable affine parameters.

    Each length-``shape_per_channel`` vector along the last axis is normalized
    independently by a weight-free, bias-free ``eqx.nn.LayerNorm``. Any number of
    leading axes is accepted, e.g. ``(channels, length)`` or
    ``(batch, channels, length)``.
    """

    ln: eqx.nn.LayerNorm

    def __init__(self, shape_per_channel: int) -> None:
        # shape_per_channel is the length of the 1d signal
        # we disable learnable affine parameters to match torch neuralop config
        self.ln = eqx.nn.LayerNorm(
            shape=(shape_per_channel,), use_weight=False, use_bias=False
        )

    def __call__(self, x: Array) -> Array:
        # eqx.nn.LayerNorm only accepts inputs of exactly `shape`, so collapse all
        # leading axes into one, vmap over it, and restore the original shape.
        signals = x.reshape((-1, x.shape[-1]))
        return jax.vmap(self.ln)(signals).reshape(x.shape)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
@@ -0,0 +1,234 @@
1
+ from typing import Optional, Union
2
+ from math import prod
3
+ from pathlib import Path
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from jax import Array
7
+ import equinox as eqx
8
+
9
+ # Only import wandb and use if installed
10
+ wandb_available = False
11
+ try:
12
+ import wandb
13
+
14
+ wandb_available = True
15
+ except ModuleNotFoundError:
16
+ wandb_available = False
17
+
18
+
19
def count_model_params(model: eqx.Module) -> int:
    """Returns the total number of parameters of an equinox model

    Only inexact (floating-point or complex) array leaves are counted, as selected
    by ``eqx.is_inexact_array``.

    Notes
    -----
    One complex number is counted as two parameters (we count real and imaginary parts)
    """
    trainable_params = eqx.filter(model, eqx.is_inexact_array)
    # eqx.filter replaces filtered-out leaves with None, which tree_leaves skips.
    # Arrays of different shapes cannot be vmapped together, so sum them one by one.
    return sum(
        leaf.size * (2 if jnp.iscomplexobj(leaf) else 1)
        for leaf in jax.tree_util.tree_leaves(trainable_params)
    )
29
+
30
+
31
def count_array_params(arr: Array, dims=None) -> int:
    """Returns the number of parameters (elements) in a single array, optionally, along certain dimensions only

    Parameters
    ----------
    arr : Array
    dims : int list or None, default is None
        if not None, the dimensions to consider when counting the number of parameters (elements)

    Returns
    -------
    int
        Product of the selected axis sizes, doubled for complex dtypes.

    Notes
    -----
    One complex number is counted as two parameters (we count real and imaginary parts)
    """
    if dims is None:
        sizes = arr.shape
    else:
        sizes = [arr.shape[d] for d in dims]
    n_params = prod(sizes)
    # iscomplexobj inspects the dtype. jnp.iscomplex is element-wise and would make
    # this `if` ambiguous for any array with more than one element.
    if jnp.iscomplexobj(arr):
        return 2 * n_params
    return n_params
52
+
53
+
54
def wandb_login(
    api_key_file: str = "../config/wandb_api_key.txt", key: Optional[str] = None
) -> None:
    """Log in to Weights & Biases.

    Parameters
    ----------
    api_key_file : str
        File holding the API key. It is only read when ``key`` is None and the
        ``WANDB_API_KEY`` environment variable is unset (see ``get_wandb_api_key``).
    key : str or None
        Explicit API key; takes precedence over the environment and the file.

    Raises
    ------
    ImportError
        If the optional ``wandb`` package is not installed.
    """
    # wandb is an optional dependency; fail with an actionable message instead of
    # a NameError on the unbound `wandb` module.
    if not wandb_available:
        raise ImportError(
            "wandb is not installed; install it with `pip install wandb` to use wandb_login"
        )
    if key is None:
        key = get_wandb_api_key(api_key_file)

    wandb.login(key=key)
61
+
62
+
63
def set_wandb_api_key(api_key_file: str = "../config/wandb_api_key.txt") -> None:
    """Export ``WANDB_API_KEY`` from a key file unless it is already set.

    An existing ``WANDB_API_KEY`` environment variable is left untouched. Otherwise
    ``api_key_file`` is read and its contents, with surrounding whitespace stripped,
    are exported. A missing file propagates the ``open`` error.
    """
    import os

    if "WANDB_API_KEY" in os.environ:
        return
    with open(api_key_file, "r") as key_file:
        contents = key_file.read()
    os.environ["WANDB_API_KEY"] = contents.strip()
72
+
73
+
74
def get_wandb_api_key(api_key_file: str = "../config/wandb_api_key.txt") -> str:
    """Return the Weights & Biases API key.

    The ``WANDB_API_KEY`` environment variable wins when present. Otherwise the
    contents of ``api_key_file`` are returned with surrounding whitespace stripped.
    """
    import os

    env_key = os.environ.get("WANDB_API_KEY")
    if env_key is not None:
        return env_key
    with open(api_key_file, "r") as key_file:
        return key_file.read().strip()
83
+
84
+
85
# Define the function to compute the spectrum
def spectrum_2d(signal: Array, n_observations: int, normalize: bool = True) -> Array:
    """This function computes the spectrum of a 2D signal using the Fast Fourier Transform (FFT).

    Parameters
    ----------
    signal : array whose size is T * n_observations * n_observations
        A 2D discretized signal, e.g. of shape (T, n_observations, n_observations),
        (T, n_observations * n_observations) or flattened to
        (T * n_observations * n_observations,). T is the number of time steps
        (or any number of channels) and n_observations is the spatial size.
    n_observations: an integer
        Number of discretized points. Basically the resolution of the signal.
    normalize: bool
        whether to apply normalization to the output of the 2D FFT.
        If True, uses a full ``fft2`` normalized by ``1/n_observations``
        (actually ``1/sqrt(n_observations * n_observations)``). If False, uses an
        unnormalized ``rfft2``, which keeps only the non-negative frequencies of the last axis.

    Returns
    --------
    spectrum: an array
        A 1D array of shape (n_observations,). Entry ``j - 1`` is the energy
        ``|c|**2`` of all coefficients in bin ``j``, averaged over T.
        The spectrum is computed using a square approximation to radial
        binning: the bin of a coefficient with wavenumbers (k_x, k_y) is
        ``|k_x| + |k_y|``, i.e. its anti-diagonal counted from the top-left
        corner of the 2d FFT output. Only the top-left
        (n_observations // 2 + 1) square is binned (symmetric components are
        removed), and the zero-frequency bin is excluded.
    """
    # -1 lets T be inferred, so both stacked and fully flattened inputs work.
    signal = jnp.reshape(signal, (-1, n_observations, n_observations))

    if normalize:
        signal = jnp.fft.fft2(signal, norm="ortho")
    else:
        signal = jnp.fft.rfft2(
            signal, s=(n_observations, n_observations), norm="backward"
        )

    # 2d integer wavenumbers following numpy fft ordering (0, 1, ..., then negatives);
    # correct for both even and odd n_observations.
    k_max = n_observations // 2
    wavenumbers = jnp.concatenate(
        (
            jnp.arange(0, n_observations - k_max),
            jnp.arange(-k_max, 0),
        )
    )
    k_x = wavenumbers[:, None]  # varies along rows
    k_y = wavenumbers[None, :]  # varies along columns

    # Square binning: sum of absolute wavenumbers
    sum_k = jnp.abs(k_x) + jnp.abs(k_y)

    # Remove symmetric components from wavenumbers: keep only the top-left
    # (k_max + 1) x (k_max + 1) block, mark everything else as -1 (no bin).
    k_max1 = k_max + 1
    rows = jnp.arange(n_observations)[:, None]
    cols = jnp.arange(n_observations)[None, :]
    index = jnp.where((rows < k_max1) & (cols < k_max1), sum_k, -1)
    # rfft2 returns only n_observations // 2 + 1 columns
    index = index[:, : signal.shape[-1]]

    # Averaging over T before binning equals binning per step then averaging (linearity).
    energy = jnp.mean(jnp.abs(signal) ** 2, axis=0)

    # Bins 1..n_observations go to slots 0..n_observations-1; everything else
    # (the DC bin 0 and removed components) goes to a discarded extra slot.
    valid = (index >= 1) & (index <= n_observations)
    bin_ids = jnp.where(valid, index - 1, n_observations)
    spectrum = jnp.bincount(
        bin_ids.ravel(), weights=energy.ravel(), length=n_observations + 1
    )
    return spectrum[:n_observations]
155
+
156
+
157
# Scalar type accepted wherever a scaling factor is expected
# (used in the annotations of validate_scaling_factor).
Number = Union[float, int]
158
+
159
+
160
def validate_scaling_factor(
    scaling_factor: "Union[None, Number, list[Number], list[list[Number]]]",
    n_dim: int,
    n_layers: Optional[int] = None,
) -> Union[None, list[float], list[list[float]]]:
    """Normalize a scaling factor into a list (or list of lists) of floats.

    Parameters
    ----------
    scaling_factor : None OR float OR list[float] OR list[list[float]]
    n_dim : int
        Number of dimensions; scalar factors are repeated this many times.
    n_layers : int or None; defaults to None
        If None, a scalar yields a single list (rather than a list of lists)
        with the factor repeated ``n_dim`` times; otherwise ``n_layers`` such lists.

    Returns
    -------
    None, list[float] or list[list[float]]
        - scalar -> ``[f] * n_dim`` (or ``n_layers`` independent copies of it)
        - flat list of length ``n_dim`` with ``n_layers=None`` -> dimension-wise factors
        - other flat list -> one ``[f] * n_dim`` row per entry
        - list of numeric lists -> copied with every entry converted to float
        - None, or any other form (including empty lists) -> None
    """
    if scaling_factor is None:
        return None
    if isinstance(scaling_factor, (float, int)):
        if n_layers is None:
            return [float(scaling_factor)] * n_dim
        # Build independent inner lists; `[row] * n_layers` would alias one list
        # n_layers times, so mutating one layer would change every layer.
        return [[float(scaling_factor)] * n_dim for _ in range(n_layers)]

    if not isinstance(scaling_factor, list) or len(scaling_factor) == 0:
        return None

    if all(isinstance(s, (float, int)) for s in scaling_factor):
        if n_layers is None and len(scaling_factor) == n_dim:
            # this is a dim-wise scaling
            return [float(s) for s in scaling_factor]
        return [[float(s)] * n_dim for s in scaling_factor]

    if all(
        isinstance(s, list) and all(isinstance(s_sub, (int, float)) for s_sub in s)
        for s in scaling_factor
    ):
        return [[float(s_sub) for s_sub in s] for s in scaling_factor]

    return None
207
+
208
+
209
def compute_rank(arr: Array) -> Array:
    """Return the numerical matrix rank of ``arr`` via ``jnp.linalg.matrix_rank``."""
    return jnp.linalg.matrix_rank(arr)
213
+
214
+
215
def compute_stable_rank(arr: Array) -> Array:
    """Return the stable rank ``||A||_F**2 / ||A||_2**2`` of the matrix ``arr``."""
    frobenius_sq = jnp.linalg.norm(arr, ord="fro") ** 2
    spectral_sq = jnp.linalg.norm(arr, ord=2) ** 2
    return frobenius_sq / spectral_sq
222
+
223
+
224
def compute_explained_variance(frequency_max: int, s: Array) -> Array:
    """Fraction of the variance of ``s`` explained by its first ``frequency_max`` entries.

    Entries from ``frequency_max`` onward are zeroed to form the truncated vector
    ``s_current``, and the result is ``1 - var(s - s_current) / var(s)``.

    Parameters
    ----------
    frequency_max : int
        Number of leading entries (e.g. singular values) kept.
    s : Array
        1D array of values; any array-like accepted by ``jnp.asarray``.
    """
    s = jnp.asarray(s)
    # JAX arrays are immutable, so build the truncated copy functionally
    # instead of assigning into a slice.
    s_current = s.at[frequency_max:].set(0)
    return 1 - jnp.var(s - s_current) / jnp.var(s)
230
+
231
+
232
def get_project_root() -> Path:
    """Return the grandparent directory of this module's file."""
    this_file = Path(__file__)
    return this_file.parent.parent
@@ -0,0 +1,39 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "neojax-operators"
7
+ version = "0.0.1"
8
+ description = "Neural Operators in JAX"
9
+ readme = "README.md"
10
+ requires-python = ">=3.9"
11
+ license = { text = "MIT" }
12
+ authors = [
13
+ { name = "Paul Gekeler", email = "pgekeler@gmail.com" },
14
+ ]
15
+ keywords = [
16
+ "jax",
17
+ "neural-operators",
18
+ "fno",
19
+ "deep-learning",
20
+ "scientific-computing",
21
+ ]
22
+ classifiers = [
23
+ "Development Status :: 3 - Alpha",
24
+ "Intended Audience :: Science/Research",
25
+ "License :: OSI Approved :: MIT License",
26
+ "Programming Language :: Python :: 3",
27
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
28
+ ]
29
+ dependencies = [
30
+ "jax",
31
+ "jaxlib",
32
+ "equinox",
33
+ ]
34
+
35
+ [tool.hatch.build.targets.wheel]
36
+ packages = ["neojax"]
37
+
38
+ [project.urls]
39
+ Homepage = "https://github.com/paulgekeler/neojax"