jinns 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_DataGenerators.py +2 -2
- jinns/loss/_DynamicLoss.py +2 -2
- jinns/loss/_LossODE.py +1 -1
- jinns/loss/_LossPDE.py +75 -38
- jinns/loss/_boundary_conditions.py +2 -2
- jinns/loss/_loss_utils.py +21 -15
- jinns/loss/_operators.py +0 -2
- jinns/nn/__init__.py +7 -0
- jinns/nn/_hyperpinn.py +397 -0
- jinns/nn/_mlp.py +192 -0
- jinns/nn/_pinn.py +190 -0
- jinns/nn/_ppinn.py +203 -0
- jinns/{utils → nn}/_save_load.py +39 -23
- jinns/nn/_spinn.py +106 -0
- jinns/nn/_spinn_mlp.py +196 -0
- jinns/plot/_plot.py +3 -3
- jinns/solver/_rar.py +3 -3
- jinns/solver/_solve.py +23 -9
- jinns/utils/__init__.py +0 -5
- jinns/utils/_types.py +4 -4
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/METADATA +9 -9
- jinns-1.3.0.dist-info/RECORD +44 -0
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/WHEEL +1 -1
- jinns/utils/_hyperpinn.py +0 -420
- jinns/utils/_pinn.py +0 -324
- jinns/utils/_ppinn.py +0 -227
- jinns/utils/_spinn.py +0 -249
- jinns-1.2.0.dist-info/RECORD +0 -41
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/AUTHORS +0 -0
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/LICENSE +0 -0
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/top_level.txt +0 -0
jinns/utils/_spinn.py
DELETED
|
@@ -1,249 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Implements utility function to create Separable PINNs
|
|
3
|
-
https://arxiv.org/abs/2211.08761
|
|
4
|
-
"""
|
|
5
|
-
|
|
6
|
-
from dataclasses import InitVar
|
|
7
|
-
from typing import Callable, Literal
|
|
8
|
-
import jax
|
|
9
|
-
import jax.numpy as jnp
|
|
10
|
-
import equinox as eqx
|
|
11
|
-
from jaxtyping import Key, Array, Float, PyTree
|
|
12
|
-
|
|
13
|
-
from jinns.parameters._params import Params, ParamsDict
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class _SPINN(eqx.Module):
    """
    Construct a Separable PINN as proposed in
    Cho et al., _Separable Physics-Informed Neural Networks_, NeurIPS, 2023

    One independent MLP is built for each of the `d` separated input
    dimensions; each MLP receives a single input coordinate.

    Parameters
    ----------
    key : InitVar[Key]
        A jax random key for the layer initializations. It is split once per
        parametrized layer, so each separated MLP gets its own initialization.
    d : int
        The number of dimensions to treat separately, including time `t` if
        used for non-stationary equations.
    eqx_list : InitVar[tuple[tuple[Callable, int, int] | Callable, ...]]
        A tuple describing the successive equinox modules and activation
        functions of the architecture (the same architecture is used for each
        of the `d` MLPs). Each item is either:

        - a tuple whose first element is an eqx module (or callable) and
          whose other elements are the arguments required to build it (eg.
          the size of the layer). The `key` argument need not be given;
        - a 1-tuple `(f,)` or a bare callable `f` (eg. an activation
          function), used as is.

        Since each MLP receives a single coordinate, the first layer must
        have input dimension 1. Thus typical example is `eqx_list=
        ((eqx.nn.Linear, 1, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, r * m)
        )`.
    """

    d: int = eqx.field(static=True, kw_only=True)

    key: InitVar[Key] = eqx.field(kw_only=True)
    eqx_list: InitVar[tuple[tuple[Callable, int, int] | Callable, ...]] = eqx.field(
        kw_only=True
    )

    # Same list object as `separated_mlp[-1]` (the last MLP built); kept as a
    # field so the pytree structure of the module is unchanged.
    layers: list = eqx.field(init=False)
    # One list of layers per separated dimension.
    separated_mlp: list = eqx.field(init=False)

    def __post_init__(self, key, eqx_list):
        self.separated_mlp = []
        for _ in range(self.d):
            self.layers = []
            for layer_spec in eqx_list:
                if not isinstance(layer_spec, (tuple, list)):
                    # bare callable, eg. `jax.nn.tanh`: used as is
                    self.layers.append(layer_spec)
                elif len(layer_spec) == 1:
                    # 1-tuple, eg. `(jax.nn.tanh,)`: used as is
                    self.layers.append(layer_spec[0])
                else:
                    # parametrized layer: build it with a fresh subkey
                    key, subkey = jax.random.split(key, 2)
                    self.layers.append(layer_spec[0](*layer_spec[1:], key=subkey))
            self.separated_mlp.append(self.layers)

    def __call__(
        self, inputs: Float[Array, "dim"] | Float[Array, "dim+1"]
    ) -> Float[Array, "d embed_dim*output_dim"]:
        """
        Apply each separated MLP to its own input coordinate.

        Parameters
        ----------
        inputs
            A single (non-batched) input point; its `k`-th coordinate is fed
            to the `k`-th MLP.

        Returns
        -------
        The outputs of the `d` MLPs stacked along the first axis, ie. an
        array of shape `(d, n_out)` where `n_out` is the output size of the
        last layer.
        """
        outputs = []
        for k in range(self.d):
            t_ = inputs[k : k + 1]  # slice keeps a (1,) shape for the 1st layer
            for layer in self.separated_mlp[k]:
                t_ = layer(t_)
            outputs.append(t_)
        return jnp.asarray(outputs)
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
class SPINN(eqx.Module):
    """
    A SPINN object compatible with the rest of jinns.
    This is typically created with `create_SPINN`.

    Parameters
    ----------
    d : int
        The number of dimensions to treat separately, including time `t` if
        used for non-stationary equations.
    r : int
        The dimension of the embedding used for each output component.
    eq_type : str
        One of `"ODE"`, `"statio_PDE"`, `"nonstatio_PDE"` (stored as a static
        attribute).
    m : int
        The output dimension of the network. The separated MLPs are expected
        to output `r * m` values.
    spinn_mlp : InitVar[eqx.Module]
        The separated MLPs (typically a `_SPINN`). Its inexact array leaves
        are stored in `params`, the remaining parts in `static`.
    """

    d: int = eqx.field(static=True, kw_only=True)
    r: int = eqx.field(static=True, kw_only=True)
    eq_type: str = eqx.field(static=True, kw_only=True)
    m: int = eqx.field(static=True, kw_only=True)

    spinn_mlp: InitVar[eqx.Module] = eqx.field(kw_only=True)

    params: PyTree = eqx.field(init=False)
    static: PyTree = eqx.field(init=False, static=True)

    def __post_init__(self, spinn_mlp):
        self.params, self.static = eqx.partition(spinn_mlp, eqx.is_inexact_array)

    @property
    def init_params(self) -> PyTree:
        """
        Returns an initial set of parameters
        """
        return self.params

    def __call__(
        self,
        t_x: Float[Array, "batch_size 1+dim"],
        params: Params | ParamsDict | PyTree,
    ) -> Float[Array, "output_dim"]:
        """
        Evaluate the SPINN on some inputs with some params.

        Parameters
        ----------
        t_x
            A batch of input points, one row per point; the `k`-th column is
            fed to the `k`-th separated MLP.
        params
            Either an object with a `nn_params` attribute (eg. `Params`) or
            directly the network parameters pytree.

        Returns
        -------
        The separable reconstruction on the cartesian grid built from the
        input columns: an array of shape `(batch_size,) * d + (m,)`. The
        trailing output axis is always present, also when `m == 1`.
        """
        try:
            spinn = eqx.combine(params.nn_params, self.static)
        except (KeyError, AttributeError, TypeError):
            # `params` is not Params-like: treat it as the parameters pytree
            spinn = eqx.combine(params, self.static)
        # assumes a `_SPINN` MLP: shape (batch_size, d, r * m)
        res = jax.vmap(spinn)(t_x)
        n_dims = res.shape[1]

        # eg. for 3 dimensions: "az, bz, cz -> abc". Each dimension gets its
        # own letter and the embedding axis `z` is summed over, which builds
        # the outer-product (separable) reconstruction on the input grid.
        in_subscripts = ", ".join(f"{chr(97 + k)}z" for k in range(n_dims))
        out_subscripts = "".join(chr(97 + k) for k in range(n_dims))
        formula = f"{in_subscripts} -> {out_subscripts}"

        # One contraction per output component, each using its own block of
        # `r` embedding dimensions. Stacking on the last axis always adds the
        # output axis of size `m`, giving (..., 1) for non-vectorial solutions.
        return jnp.stack(
            [
                jnp.einsum(
                    formula,
                    *(
                        res[:, k, j * self.r : (j + 1) * self.r]
                        for k in range(n_dims)
                    ),
                )
                for j in range(self.m)
            ],
            axis=-1,
        )
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
def create_SPINN(
    key: Key,
    d: int,
    r: int,
    eqx_list: tuple[tuple[Callable, int, int] | Callable, ...],
    eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
    m: int = 1,
) -> tuple[SPINN, PyTree]:
    """
    Utility function to create a SPINN neural network with the equinox
    library.

    *Note* that a SPINN is not vmapped and expects the
    same batch size for each of its input axis. It directly outputs a solution
    on the cartesian grid of its inputs, of shape `(batchsize,) * d + (m,)`
    (eg. `(batchsize, batchsize, 1)` for `d=2` and `m=1`). See the paper for
    more details.

    Parameters
    ----------
    key
        A JAX random key that will be used to initialize the network parameters
    d
        The number of dimensions to treat separately. At most 25: the einsum
        formula uses one lowercase letter per dimension plus `z`.
    r
        An integer. The dimension of the embedding.
    eqx_list
        A tuple of tuples of successive equinox modules and activation functions to
        describe the architecture of each separated MLP. The inner tuples must
        have the eqx module or activation function as first item, other items
        represents arguments that could be required (eg. the size of the layer).
        The `key` argument need not be given.
        Each separated MLP receives a single coordinate, so the first layer
        must have input dimension 1 and the last layer output dimension
        `r * m`. Thus typical example is
        `eqx_list=((eqx.nn.Linear, 1, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, 20),
        (jax.nn.tanh,),
        (eqx.nn.Linear, 20, r * m)
        )`.
    eq_type : Literal["ODE", "statio_PDE", "nonstatio_PDE"]
        A string with three possibilities.
        "ODE": the PINN is called with one input `t`.
        "statio_PDE": the PINN is called with one input `x`, `x`
        can be high dimensional.
        "nonstatio_PDE": the PINN is called with two inputs `t` and `x`, `x`
        can be high dimensional.
    m
        The output dimension of the neural network. According to
        the SPINN article, a total embedding dimension of `r*m` is defined. We
        then sum groups of `r` embedding dimensions to compute each output.
        Default is 1.

    Returns
    -------
    spinn
        An instantiated SPINN
    spinn.init_params
        The initial set of parameters of the model

    Raises
    ------
    RuntimeError
        If the parameter value for eq_type is not in `["ODE", "statio_PDE",
        "nonstatio_PDE"]`
    ValueError
        If the declared input dimension of the first layer is not 1, if the
        declared output dimension of the last layer is not `r * m`, or if `d`
        is greater than 25.
    """

    if eq_type not in ["ODE", "statio_PDE", "nonstatio_PDE"]:
        raise RuntimeError("Wrong parameter value for eq_type")

    try:
        # normally we look for 2nd ele of 1st layer
        nb_inputs_declared = eqx_list[0][1]
    except (IndexError, TypeError):
        # but we can have, eg, a flatten first layer (1-tuple or bare callable)
        nb_inputs_declared = eqx_list[1][1]
    if nb_inputs_declared != 1:
        raise ValueError("Input dim must be set to 1 in SPINN!")

    try:
        # normally we look for 3rd ele of last layer
        nb_outputs_declared = eqx_list[-1][2]
    except (IndexError, TypeError):
        # but we can have, eg, a `jnp.exp` last layer (1-tuple or bare callable)
        nb_outputs_declared = eqx_list[-2][2]
    if nb_outputs_declared != r * m:
        raise ValueError("Output dim must be set to r * m in SPINN!")

    # SPINN.__call__ uses letters "a", "b", ... (one per dimension) plus "z"
    # for the summed embedding axis: "a".."y" gives room for 25 dimensions.
    if d > 25:
        raise ValueError(
            "Too many dimensions, not enough letters available in jnp.einsum"
        )

    spinn_mlp = _SPINN(key=key, d=d, eqx_list=eqx_list)
    spinn = SPINN(spinn_mlp=spinn_mlp, d=d, r=r, eq_type=eq_type, m=m)

    return spinn, spinn.init_params
|
jinns-1.2.0.dist-info/RECORD
DELETED
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
jinns/__init__.py,sha256=5p7V5VJd7PXEINqhqS4mUsnQtXlyPwfctRhL4p0loFg,181
|
|
2
|
-
jinns/data/_Batchs.py,sha256=oc7-N1wEbsEvbe9fjVFKG2OPoZJVEjzPm8uj_icACf4,817
|
|
3
|
-
jinns/data/_DataGenerators.py,sha256=BWNF28jM7EN6wyJ3X2X1qjG4sjs6vHLgIVVY6aqhKS8,65099
|
|
4
|
-
jinns/data/__init__.py,sha256=TRCH0Z4-SQZ50MbSf46CUYWBkWVDmXCyez9T-EGiv_8,338
|
|
5
|
-
jinns/experimental/__init__.py,sha256=3jCIy2R2i_0Erwxg-HwISdH79Nt1XCXhS9yY1F5awiY,208
|
|
6
|
-
jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
|
|
7
|
-
jinns/loss/_DynamicLoss.py,sha256=x1pJLZ_4P2iHhf5hjRhkWi1v5Q2dAH-v7Gctv4ax73E,25819
|
|
8
|
-
jinns/loss/_DynamicLossAbstract.py,sha256=bqmPxyrcvZh_dL74DTpj-TGiFxchvG8qC6KhuGeyOoA,12006
|
|
9
|
-
jinns/loss/_LossODE.py,sha256=lTJ3b2EnlYk1eAcldiVyFRh_XQrk83eYRQBuqWSiILg,23418
|
|
10
|
-
jinns/loss/_LossPDE.py,sha256=ufECwKdQXFEy8h5fdDA2uFTjiO_sVdqi9nfP7NTyu60,48775
|
|
11
|
-
jinns/loss/__init__.py,sha256=PRiJV9fd2GSwaCBVCPyh6pFc6pdA40jfb_T1YvO8ERc,712
|
|
12
|
-
jinns/loss/_boundary_conditions.py,sha256=KL_UUajQWLtRFqmBO_lfEcxrUsY-398ySTgpfaHdeYk,12277
|
|
13
|
-
jinns/loss/_loss_utils.py,sha256=RXu20VYeAelpuRIAJ67k3h7jdLsu2VIdjzmLVdiMjns,14223
|
|
14
|
-
jinns/loss/_loss_weights.py,sha256=F0Fgji2XpVk3pr9oIryGuXcG1FGQo4Dv6WFgze2BtA0,2201
|
|
15
|
-
jinns/loss/_operators.py,sha256=Hf5zWUQCGRSm1SS8f6IZLqC8bL7suC-R4ckXNbXqoAI,21790
|
|
16
|
-
jinns/parameters/__init__.py,sha256=1gxNLoAXUjhUzBWuh86YjU5pYy8SOboCs8TrKcU1wZc,158
|
|
17
|
-
jinns/parameters/_derivative_keys.py,sha256=UyEcgfNF1vwPcGWD2ShAZkZiq4thzRDm_OUJzOfjjiY,21909
|
|
18
|
-
jinns/parameters/_params.py,sha256=wK9ZSqoL9KnjOWqc_ZhJ09ffbsgeUEcttc1Rhme0lLk,3550
|
|
19
|
-
jinns/plot/__init__.py,sha256=Q279h5veYWNLQyttsC8_tDOToqUHh8WaRON90CiWXqk,81
|
|
20
|
-
jinns/plot/_plot.py,sha256=jS0fdI9VHCd6HL4K8ZzDjmMr10EJFqU6a0S4iWpEzFk,11433
|
|
21
|
-
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
22
|
-
jinns/solver/_rar.py,sha256=BCvUdNLicdU1YizHKcSaJLmkUkCybyGmTRjHnjoiQKA,9730
|
|
23
|
-
jinns/solver/_solve.py,sha256=TR-RuH2tYIFq6YhbZ-Aa7r8OYy7CstjEYNm2RMI3xck,22641
|
|
24
|
-
jinns/solver/_utils.py,sha256=b2zYvwZY_fU0NMNWvUEMvHez9s7hwcxfpGzQlz5F6HA,5762
|
|
25
|
-
jinns/utils/__init__.py,sha256=nx7wi4RfgeHr5wLM_c1BAG8UQteTUTK6B0t9enmTn1E,243
|
|
26
|
-
jinns/utils/_containers.py,sha256=a7A-iUApnjc1YVc7bdt9tKUvHHPDOKMB9OfdrDZGWN8,1450
|
|
27
|
-
jinns/utils/_hyperpinn.py,sha256=hcdHRHXer03I3Twupok5moPR1vRGgPQUCjFHAZMCMdc,16872
|
|
28
|
-
jinns/utils/_pinn.py,sha256=8wDliUrDbBUKWllJgllapNX5BLLHMWF7Mz2CaEA-6vk,12667
|
|
29
|
-
jinns/utils/_ppinn.py,sha256=K4xuGaT-4XddIxsi1wOYO24G9l1RAY2m4bZqP7C3klI,8690
|
|
30
|
-
jinns/utils/_save_load.py,sha256=BngwUjkm5jibctnUCB8klMypYkfE6DKLaoWehwWnTNU,8585
|
|
31
|
-
jinns/utils/_spinn.py,sha256=tjly3qNZj9C2JntFKuwrLfFAbVfvy7Z5HPjj0ymfnLk,8337
|
|
32
|
-
jinns/utils/_types.py,sha256=EQXxUAipwy7lVk39hYw4NDf_a66OP01w_T3ziPITxpU,1901
|
|
33
|
-
jinns/utils/_utils.py,sha256=hoRcJqcTuQi_Ip40oI4EbxW46E1rp2C01_HfuCpwKRM,2932
|
|
34
|
-
jinns/validation/__init__.py,sha256=Jv58mzgC3F7cRfXA6caicL1t_U0UAhbwLrmMNVg6E7s,66
|
|
35
|
-
jinns/validation/_validation.py,sha256=bvqL2poTFJfn9lspWqMqXvQGcQIodKwKrC786QtEZ7A,4700
|
|
36
|
-
jinns-1.2.0.dist-info/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
|
|
37
|
-
jinns-1.2.0.dist-info/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
38
|
-
jinns-1.2.0.dist-info/METADATA,sha256=PxuzXslHG4YzOcuwDY3ulo9zJaF95ChjK9y584eMnN0,4662
|
|
39
|
-
jinns-1.2.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
|
40
|
-
jinns-1.2.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
41
|
-
jinns-1.2.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|