jinns-0.8.7-py3-none-any.whl → jinns-0.8.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,135 @@
+ import jax
+ import equinox as eqx
+ import jax.numpy as jnp
+ from jinns.utils._pinn import PINN
+
+
+ def almost_zero_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
+     out, in_ = weight.shape
+     stddev = 1e-2
+     return stddev * jax.random.normal(key, shape=(out, in_))
+
+
+ class _SinusPINN(eqx.Module):
+     """
+     A specific PINN whose layers are x_sin2x activation functions whose
+     frequencies are determined by another network
+     """
+
+     layers_pinn: list
+     layers_aux_nn: list
+
+     def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
+         """
+         Parameters
+         ----------
+         key
+             A jax random key
+         list_layers_pinn
+             A list, in the eqx_list format of jinns' PINN utility, for the main PINN
+         list_layers_aux_nn
+             A list, in the eqx_list format of jinns' PINN utility, for the
+             network which outputs the PINN's activation frequencies
+         """
+         self.layers_pinn = []
+         for l in list_layers_pinn:
+             if len(l) == 1:
+                 self.layers_pinn.append(l[0])
+             else:
+                 key, subkey = jax.random.split(key, 2)
+                 self.layers_pinn.append(l[0](*l[1:], key=subkey))
+         self.layers_aux_nn = []
+         for idx, l in enumerate(list_layers_aux_nn):
+             if len(l) == 1:
+                 self.layers_aux_nn.append(l[0])
+             else:
+                 key, subkey = jax.random.split(key, 2)
+                 linear_layer = l[0](*l[1:], key=subkey)
+                 key, subkey = jax.random.split(key, 2)
+                 linear_layer = eqx.tree_at(
+                     lambda l: l.weight,
+                     linear_layer,
+                     almost_zero_init(linear_layer.weight, subkey),
+                 )
+                 if (idx == len(list_layers_aux_nn) - 1) or (
+                     idx == len(list_layers_aux_nn) - 2
+                 ):
+                     # for the last linear layer: almost 0 weights and 0.5 bias
+                     linear_layer = eqx.tree_at(
+                         lambda l: l.bias,
+                         linear_layer,
+                         0.5 * jnp.ones(linear_layer.bias.shape),
+                     )
+                 else:
+                     # for the other layers:
+                     # almost 0 weights and 0 bias
+                     linear_layer = eqx.tree_at(
+                         lambda l: l.bias,
+                         linear_layer,
+                         jnp.zeros(linear_layer.bias.shape),
+                     )
+                 self.layers_aux_nn.append(linear_layer)
+
+         ## init to zero the frequency network except last biases
+         # key, subkey = jax.random.split(key, 2)
+         # _pinn = init_linear_weight(_pinn, almost_zero_init, subkey)
+         # key, subkey = jax.random.split(key, 2)
+         # _pinn = init_linear_bias(_pinn, zero_init, subkey)
+         # print(_pinn)
+         # print(jax.tree_util.tree_leaves(_pinn, is_leaf=lambda
+         # p:not isinstance(p,eqx.nn.Linear))[0].layers_aux_nn[-1].bias)
+         # _pinn = eqx.tree_at(lambda p:_pinn.layers_aux_nn[-1].bias, 0.5 *
+         # jnp.ones(_pinn.layers_aux_nn[-1].bias.shape))
+         # #, is_leaf=lambda
+         # #p:not isinstance(p, eqx.nn.Linear))
+
+     def __call__(self, x):
+         x_ = x.copy()
+         # forward pass in the network which determines the frequencies
+         for layer in self.layers_aux_nn:
+             x_ = layer(x_)
+         freq_list = jnp.clip(jnp.square(x_), a_min=1e-4, a_max=5)
+         x_ = x.copy()
+         # forward pass through the actual PINN
+         for idx, layer in enumerate(self.layers_pinn):
+             if idx % 2 == 0:
+                 # currently the layers alternate: even indices are standard
+                 # layers, odd indices are activations requiring a frequency
+                 x_ = layer(x_)
+             else:
+                 x_ = layer(x_, freq_list[(idx - 1) // 2])
+         return x_
+
+
+ class sinusPINN(PINN):
+     """
+     MUST inherit from PINN to pass all the checks
+
+     HOWEVER we do not bother with reimplementing anything
+     """
+
+     def __init__(self, key, list_layers_pinn, list_layers_aux_nn):
+         super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
+         key, subkey = jax.random.split(key, 2)
+         _pinn = _SinusPINN(subkey, list_layers_pinn, list_layers_aux_nn)
+
+         self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
+
+     def init_params(self):
+         return self.params
+
+     def __call__(self, x, params):
+         try:
+             model = eqx.combine(params["nn_params"], self.static)
+         except (KeyError, TypeError):  # give more flexibility in the params structure
+             model = eqx.combine(params, self.static)
+         res = model(x)
+         if not res.shape:
+             return jnp.expand_dims(res, axis=-1)
+         return res
+
+
+ def create_sinusPINN(key, list_layers_pinn, list_layers_aux_nn):
+     """Create a sinusPINN from a JAX random key and the two layer lists."""
+     u = sinusPINN(key, list_layers_pinn, list_layers_aux_nn)
+     return u
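
Both constructors consume eqx_list-style layer lists, as the loops above show: a length-one tuple holds a ready-made callable used as-is, while a longer tuple holds a layer constructor followed by its positional arguments and receives a fresh PRNG key. The sketch below is not taken from the package: the FreqActivation module (standing in for the x_sin2x layers, assumed here to compute x * sin(2 * freq * x)) and the layer widths are illustrative assumptions, and create_sinusPINN is assumed importable from this new module.

import jax
import jax.numpy as jnp
import equinox as eqx


class FreqActivation(eqx.Module):
    # hypothetical x_sin2x-style layer: _SinusPINN calls it as layer(x, freq)
    def __call__(self, x, freq):
        return x * jnp.sin(2 * freq * x)


key = jax.random.PRNGKey(0)

# eqx_list-style lists: length-one tuples are callables used as-is,
# longer tuples are (constructor, *args) and get a PRNG key
list_layers_pinn = [
    (eqx.nn.Linear, 2, 32),
    (FreqActivation(),),
    (eqx.nn.Linear, 32, 32),
    (FreqActivation(),),
    (eqx.nn.Linear, 32, 1),
]
list_layers_aux_nn = [
    (eqx.nn.Linear, 2, 32),
    (jax.nn.tanh,),
    (eqx.nn.Linear, 32, 2),  # one frequency per FreqActivation above
]

u = create_sinusPINN(key, list_layers_pinn, list_layers_aux_nn)
params = u.init_params()
out = u(jnp.array([0.5, 0.5]), params)  # output of shape (1,)

With the frequency network initialised to almost-zero weights and a 0.5 bias on its last linear layer, the initial frequencies are close to clip(0.5 ** 2, 1e-4, 5) = 0.25 for every frequency-modulated layer, which matches the initialisation logic in _SinusPINN.__init__.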
@@ -0,0 +1,87 @@
+ import jax
+ import equinox as eqx
+ import jax.numpy as jnp
+ from jinns.utils._pinn import PINN
+
+
+ def almost_zero_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
+     out, in_ = weight.shape
+     stddev = 1e-2
+     return stddev * jax.random.normal(key, shape=(out, in_))
+
+
+ class _SpectralPINN(eqx.Module):
+     """
+     A specific PINN whose architecture mimics the spectral method for the simulation of a spatial field
+     (Chilès and Delfiner, 2012): a single layer with a cos() activation function and a sum as the last layer
+     """
+
+     layers_pinn: list
+     nbands: int
+
+     def __init__(self, key, list_layers_pinn, nbands):
+         """
+         Parameters
+         ----------
+         key
+             A jax random key
+         list_layers_pinn
+             A list, in the eqx_list format of jinns' PINN utility, for the main PINN
+         nbands
+             Number of spectral bands (i.e., neurons in the single layer of the PINN)
+         """
+         self.nbands = nbands
+         self.layers_pinn = []
+         for l in list_layers_pinn:
+             if len(l) == 1:
+                 self.layers_pinn.append(l[0])
+             else:
+                 key, subkey = jax.random.split(key, 2)
+                 self.layers_pinn.append(l[0](*l[1:], key=subkey))
+
+     def __call__(self, x):
+         # forward pass through the actual PINN
+         for layer in self.layers_pinn:
+             x = layer(x)
+
+         return jnp.sqrt(2 / self.nbands) * jnp.sum(x)
+
+
+ class spectralPINN(PINN):
+     """
+     MUST inherit from PINN to pass all the checks
+
+     HOWEVER we do not bother with reimplementing anything
+     """
+
+     def __init__(self, key, list_layers_pinn, nbands):
+         super().__init__({}, jnp.s_[...], "statio_PDE", None, None, None)
+         key, subkey = jax.random.split(key, 2)
+         _pinn = _SpectralPINN(subkey, list_layers_pinn, nbands)
+
+         self.params, self.static = eqx.partition(_pinn, eqx.is_inexact_array)
64
+ def init_params(self):
65
+ return self.params
66
+
67
+ def __call__(self, x, params):
68
+ try:
69
+ model = eqx.combine(params["nn_params"], self.static)
70
+ except (KeyError, TypeError) as e: # give more flexibility
71
+ model = eqx.combine(params, self.static)
72
+ # model = eqx.tree_at(lambda m:
73
+ # m.layers_pinn[0].bias,
74
+ # model,
75
+ # model.layers_pinn[0].bias % (2 *
76
+ # jnp.pi)
77
+ # )
78
+ res = model(x)
79
+ if not res.shape:
80
+ return jnp.expand_dims(res, axis=-1)
81
+ return res
82
+
83
+
84
+ def create_spectralPINN(key, list_layers_pinn, nbands):
85
+ """ """
86
+ u = spectralPINN(key, list_layers_pinn, nbands)
87
+ return u
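
When the layer list is a single linear layer followed by cos(), the forward pass of _SpectralPINN computes sqrt(2/N) * sum_j cos(omega_j . x + phi_j), i.e. the spectral-method approximation of a stationary random field cited in the docstring, with the frequencies omega_j and phases phi_j stored as the weights and biases of that linear layer (the commented-out tree_at block above would wrap those biases into [0, 2*pi)). A minimal usage sketch, again not taken from the package: the two-dimensional input, the band count and the exact layer list are assumptions, and create_spectralPINN is assumed importable from this new module.

import jax
import jax.numpy as jnp
import equinox as eqx

key = jax.random.PRNGKey(0)
nbands = 64

# assumed eqx_list: one linear layer holding the frequencies/phases,
# followed by the cos() activation; _SpectralPINN then sums over the bands
list_layers_pinn = [
    (eqx.nn.Linear, 2, nbands),
    (jnp.cos,),
]

u = create_spectralPINN(key, list_layers_pinn, nbands)
params = u.init_params()
out = u(jnp.array([0.3, 0.7]), params)  # scalar field value, returned with shape (1,)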