jinns 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +1 -0
- jinns/data/_DataGenerators.py +93 -90
- jinns/data/_display.py +102 -13
- jinns/experimental/__init__.py +2 -0
- jinns/experimental/_sinuspinn.py +135 -0
- jinns/experimental/_spectralpinn.py +87 -0
- jinns/solver/_rar.py +203 -146
- jinns/solver/_seq2seq.py +2 -2
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/METADATA +1 -1
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/RECORD +13 -11
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/LICENSE +0 -0
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/WHEEL +0 -0
- {jinns-0.8.7.dist-info → jinns-0.8.9.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import equinox as eqx
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from jinns.utils._pinn import PINN
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def almost_zero_init(
    weight: jax.Array, key: jax.Array, stddev: float = 1e-2
) -> jax.Array:
    """
    Return a random array shaped like ``weight`` whose entries are close to 0.

    Parameters
    ----------
    weight
        The array to re-initialise; only its shape and dtype are used, so any
        rank is accepted (not only 2D linear-layer weights).
    key
        A jax random key.
    stddev
        Standard deviation of the centred Gaussian the entries are drawn
        from. Defaults to 1e-2.

    Returns
    -------
    jax.Array
        ``stddev * N(0, 1)`` samples with the shape and dtype of ``weight``.
    """
    return stddev * jax.random.normal(key, shape=weight.shape, dtype=weight.dtype)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _SinusPINN(eqx.Module):
    """
    A specific PINN whose layers are x_sin2x functions whose frequencies are
    determined by another network.

    The auxiliary network (``layers_aux_nn``) maps the input to frequencies;
    every odd-indexed entry of ``layers_pinn`` is then called as
    ``layer(x, freq)`` with one of these frequencies.
    """

    layers_pinn: list
    layers_aux_nn: list

    def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
        """
        Parameters
        ----------
        key
            A jax random key
        list_layers_pinn
            A list as eqx_list in jinns' PINN utility for the main PINN
        list_layers_aux_nn
            A list as eqx_list in jinns' PINN utility for the network which
            outputs the PINN's activation frequencies
        """
        layers_pinn = []
        for entry in list_layers_pinn:
            if len(entry) == 1:
                # single-element entry: the callable is used as is
                layers_pinn.append(entry[0])
            else:
                # [constructor, *args]: build the layer with a fresh subkey
                key, subkey = jax.random.split(key, 2)
                layers_pinn.append(entry[0](*entry[1:], key=subkey))
        self.layers_pinn = layers_pinn

        n_aux = len(list_layers_aux_nn)
        layers_aux_nn = []
        for idx, entry in enumerate(list_layers_aux_nn):
            if len(entry) == 1:
                layers_aux_nn.append(entry[0])
                continue
            # assumes every constructed aux layer exposes .weight and .bias
            # (e.g. eqx.nn.Linear with a bias) - TODO confirm for other layers
            key, subkey = jax.random.split(key, 2)
            layer = entry[0](*entry[1:], key=subkey)
            # almost 0 weights for every constructed layer of the aux network
            key, subkey = jax.random.split(key, 2)
            layer = eqx.tree_at(
                lambda m: m.weight,
                layer,
                almost_zero_init(layer.weight, subkey),
            )
            if idx >= n_aux - 2:
                # entries among the last two of the list (the final linear
                # layer may be followed by a single activation entry):
                # 0.5 bias
                bias = 0.5 * jnp.ones(layer.bias.shape)
            else:
                # the other previous layers: 0 bias
                bias = jnp.zeros(layer.bias.shape)
            layer = eqx.tree_at(lambda m: m.bias, layer, bias)
            layers_aux_nn.append(layer)
        self.layers_aux_nn = layers_aux_nn

    def __call__(self, x):
        """
        Forward pass: compute the frequencies with the auxiliary network, then
        evaluate the main PINN, feeding one frequency to each odd-indexed
        layer (assumes the aux network outputs at least one frequency per
        odd-indexed layer - TODO confirm).
        """
        # JAX arrays are immutable, so no defensive copy of x is needed
        freqs = x
        for layer in self.layers_aux_nn:
            freqs = layer(freqs)
        # keep the frequencies in [1e-4, 5]; passing min/max positionally
        # works with both the old (a_min/a_max) and the new (min/max)
        # jnp.clip signatures
        freq_list = jnp.clip(jnp.square(freqs), 1e-4, 5)

        out = x
        for idx, layer in enumerate(self.layers_pinn):
            if idx % 2 == 0:
                out = layer(out)
            else:
                # every two layers we have an activation requiring a
                # frequency: odd idx 1, 3, 5, ... -> freq_list[0], [1], [2], ...
                out = layer(out, freq_list[idx // 2])
        return out
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class sinusPINN(PINN):
    """
    PINN wrapper around ``_SinusPINN``.

    It MUST inherit from PINN to pass all the checks, HOWEVER we do not
    bother with reimplementing anything: the network is a ``_SinusPINN``
    split into trainable parameters and static structure.
    """

    def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
        """
        Parameters
        ----------
        key
            A jax random key
        list_layers_pinn
            A list as eqx_list in jinns' PINN utility for the main PINN
        list_layers_aux_nn
            A list as eqx_list in jinns' PINN utility for the network which
            outputs the PINN's activation frequencies
        """
        # NOTE(review): placeholder arguments for the parent constructor; the
        # actual network is built below - confirm against PINN.__init__
        super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
        key, subkey = jax.random.split(key, 2)
        _pinn = _SinusPINN(subkey, list_layers_pinn, list_layers_aux_nn)

        # floating-point arrays are the trainable parameters, everything else
        # is kept in the static part
        self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)

    def init_params(self):
        """Return the initial trainable parameters of the network."""
        return self.params

    def __call__(self, x, params):
        """
        Evaluate the network at ``x``.

        ``params`` is either a mapping holding the network parameters under
        the ``"nn_params"`` key, or directly the network parameters. A scalar
        output is returned with a trailing axis of size 1.
        """
        try:
            nn_params = params["nn_params"]
        except (KeyError, TypeError):  # give more flexibility
            nn_params = params
        model = eqx.combine(nn_params, self.static)
        res = model(x)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def create_sinusPINN(key, list_layers_pinn, list_layers_aux_nn):
    """
    Create a ``sinusPINN``.

    Parameters
    ----------
    key
        A jax random key
    list_layers_pinn
        A list as eqx_list in jinns' PINN utility for the main PINN
    list_layers_aux_nn
        A list as eqx_list in jinns' PINN utility for the network which
        outputs the PINN's activation frequencies

    Returns
    -------
    sinusPINN
        The constructed network; get its parameters with ``init_params()``.
    """
    return sinusPINN(key, list_layers_pinn, list_layers_aux_nn)
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import equinox as eqx
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from jinns.utils._pinn import PINN
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def almost_zero_init(
    weight: jax.Array, key: jax.Array, stddev: float = 1e-2
) -> jax.Array:
    """
    Return a random array shaped like ``weight`` whose entries are close to 0.

    Parameters
    ----------
    weight
        The array to re-initialise; only its shape and dtype are used, so any
        rank is accepted (not only 2D linear-layer weights).
    key
        A jax random key.
    stddev
        Standard deviation of the centred Gaussian the entries are drawn
        from. Defaults to 1e-2.

    Returns
    -------
    jax.Array
        ``stddev * N(0, 1)`` samples with the shape and dtype of ``weight``.
    """
    return stddev * jax.random.normal(key, shape=weight.shape, dtype=weight.dtype)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _SpectralPINN(eqx.Module):
    """
    A specific PINN whose architecture is similar to the spectral method for
    the simulation of a spatial field (Chilès and Delfiner, 2012): a single
    layer with a cos() activation function, and a sum for the last layer.
    """

    layers_pinn: list
    nbands: int

    def __init__(self, key, list_layers_pinn, nbands):
        """
        Parameters
        ----------
        key
            A jax random key
        list_layers_pinn
            A list as eqx_list in jinns' PINN utility for the main PINN
        nbands
            Number of spectral bands (i.e., neurons in the single layer of
            the PINN). Must be positive.

        Raises
        ------
        ValueError
            If ``nbands`` is not positive (the output scaling
            ``sqrt(2 / nbands)`` would be undefined).
        """
        if nbands <= 0:
            raise ValueError(f"nbands must be positive, got {nbands}")
        self.nbands = nbands
        layers_pinn = []
        for entry in list_layers_pinn:
            if len(entry) == 1:
                # single-element entry: the callable is used as is
                layers_pinn.append(entry[0])
            else:
                # [constructor, *args]: build the layer with a fresh subkey
                key, subkey = jax.random.split(key, 2)
                layers_pinn.append(entry[0](*entry[1:], key=subkey))
        self.layers_pinn = layers_pinn

    def __call__(self, x):
        """
        Apply the layers in order, then return the sum of all the entries of
        the result scaled by ``sqrt(2 / nbands)`` (a scalar).
        """
        for layer in self.layers_pinn:
            x = layer(x)

        return jnp.sqrt(2 / self.nbands) * jnp.sum(x)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class spectralPINN(PINN):
    """
    PINN wrapper around ``_SpectralPINN``.

    It MUST inherit from PINN to pass all the checks, HOWEVER we do not
    bother with reimplementing anything: the network is a ``_SpectralPINN``
    split into trainable parameters and static structure.
    """

    def __init__(self, key, list_layers_pinn, nbands):
        """
        Parameters
        ----------
        key
            A jax random key
        list_layers_pinn
            A list as eqx_list in jinns' PINN utility for the main PINN
        nbands
            Number of spectral bands (i.e., neurons in the single layer of
            the PINN)
        """
        # NOTE(review): placeholder arguments for the parent constructor; the
        # actual network is built below - confirm against PINN.__init__
        super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
        key, subkey = jax.random.split(key, 2)
        _pinn = _SpectralPINN(subkey, list_layers_pinn, nbands)

        # floating-point arrays are the trainable parameters, everything else
        # (including the integer nbands) is kept in the static part
        self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)

    def init_params(self):
        """Return the initial trainable parameters of the network."""
        return self.params

    def __call__(self, x, params):
        """
        Evaluate the network at ``x``.

        ``params`` is either a mapping holding the network parameters under
        the ``"nn_params"`` key, or directly the network parameters. A scalar
        output is returned with a trailing axis of size 1.
        """
        try:
            nn_params = params["nn_params"]
        except (KeyError, TypeError):  # give more flexibility
            nn_params = params
        model = eqx.combine(nn_params, self.static)
        res = model(x)
        if not res.shape:
            return jnp.expand_dims(res, axis=-1)
        return res
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def create_spectralPINN(key, list_layers_pinn, nbands):
    """
    Create a ``spectralPINN``.

    Parameters
    ----------
    key
        A jax random key
    list_layers_pinn
        A list as eqx_list in jinns' PINN utility for the main PINN
    nbands
        Number of spectral bands (i.e., neurons in the single layer of the
        PINN)

    Returns
    -------
    spectralPINN
        The constructed network; get its parameters with ``init_params()``.
    """
    return spectralPINN(key, list_layers_pinn, nbands)
|