jinns 1.6.1__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -1
- jinns/data/_Batchs.py +4 -4
- jinns/data/_DataGeneratorODE.py +1 -1
- jinns/data/_DataGeneratorObservations.py +498 -90
- jinns/loss/_DynamicLossAbstract.py +3 -1
- jinns/loss/_LossODE.py +138 -73
- jinns/loss/_LossPDE.py +208 -104
- jinns/loss/_abstract_loss.py +97 -14
- jinns/loss/_boundary_conditions.py +6 -6
- jinns/loss/_loss_utils.py +2 -2
- jinns/loss/_loss_weight_updates.py +30 -0
- jinns/loss/_loss_weights.py +4 -0
- jinns/loss/_operators.py +27 -27
- jinns/nn/_abstract_pinn.py +1 -1
- jinns/nn/_hyperpinn.py +6 -6
- jinns/nn/_mlp.py +3 -3
- jinns/nn/_pinn.py +7 -7
- jinns/nn/_ppinn.py +6 -6
- jinns/nn/_spinn.py +4 -4
- jinns/nn/_spinn_mlp.py +7 -7
- jinns/parameters/_derivative_keys.py +13 -6
- jinns/parameters/_params.py +10 -0
- jinns/solver/_rar.py +19 -9
- jinns/solver/_solve.py +102 -367
- jinns/solver/_solve_alternate.py +885 -0
- jinns/solver/_utils.py +520 -11
- jinns/utils/_DictToModuleMeta.py +3 -1
- jinns/utils/_containers.py +8 -4
- jinns/utils/_types.py +42 -1
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/METADATA +26 -14
- jinns-1.7.1.dist-info/RECORD +58 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/WHEEL +1 -1
- jinns-1.6.1.dist-info/RECORD +0 -57
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/top_level.txt +0 -0
jinns/nn/_spinn.py
CHANGED
|
@@ -21,12 +21,12 @@ class SPINN(AbstractPINN):
|
|
|
21
21
|
used for non-stationnary equations.
|
|
22
22
|
r : int
|
|
23
23
|
An integer. The dimension of the embedding.
|
|
24
|
-
eq_type : Literal["ODE", "
|
|
24
|
+
eq_type : Literal["ODE", "PDEStatio", "PDENonStatio"]
|
|
25
25
|
A string with three possibilities.
|
|
26
26
|
"ODE": the PINN is called with one input `t`.
|
|
27
|
-
"
|
|
27
|
+
"PDEStatio": the PINN is called with one input `x`, `x`
|
|
28
28
|
can be high dimensional.
|
|
29
|
-
"
|
|
29
|
+
"PDENonStatio": the PINN is called with two inputs `t` and `x`, `x`
|
|
30
30
|
can be high dimensional.
|
|
31
31
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
32
32
|
of the dimension of `t` + the dimension of `x`.
|
|
@@ -49,7 +49,7 @@ class SPINN(AbstractPINN):
|
|
|
49
49
|
|
|
50
50
|
"""
|
|
51
51
|
|
|
52
|
-
eq_type: Literal["ODE", "
|
|
52
|
+
eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"] = eqx.field(
|
|
53
53
|
static=True, kw_only=True
|
|
54
54
|
)
|
|
55
55
|
d: int = eqx.field(static=True, kw_only=True)
|
jinns/nn/_spinn_mlp.py
CHANGED
|
@@ -78,7 +78,7 @@ class SPINN_MLP(SPINN):
|
|
|
78
78
|
d: int,
|
|
79
79
|
r: int,
|
|
80
80
|
eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...],
|
|
81
|
-
eq_type: Literal["ODE", "
|
|
81
|
+
eq_type: Literal["ODE", "PDEStatio", "PDENonStatio"],
|
|
82
82
|
m: int = 1,
|
|
83
83
|
filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = None,
|
|
84
84
|
) -> tuple[Self, SPINN]:
|
|
@@ -114,12 +114,12 @@ class SPINN_MLP(SPINN):
|
|
|
114
114
|
(jax.nn.tanh,),
|
|
115
115
|
(eqx.nn.Linear, 20, r * m)
|
|
116
116
|
)`.
|
|
117
|
-
eq_type : Literal["ODE", "
|
|
117
|
+
eq_type : Literal["ODE", "PDEStatio", "PDENonStatio"]
|
|
118
118
|
A string with three possibilities.
|
|
119
119
|
"ODE": the PINN is called with one input `t`.
|
|
120
|
-
"
|
|
120
|
+
"PDEStatio": the PINN is called with one input `x`, `x`
|
|
121
121
|
can be high dimensional.
|
|
122
|
-
"
|
|
122
|
+
"PDENonStatio": the PINN is called with two inputs `t` and `x`, `x`
|
|
123
123
|
can be high dimensional.
|
|
124
124
|
**Note**: the input dimension as given in eqx_list has to match the sum
|
|
125
125
|
of the dimension of `t` + the dimension of `x`.
|
|
@@ -150,11 +150,11 @@ class SPINN_MLP(SPINN):
|
|
|
150
150
|
Raises
|
|
151
151
|
------
|
|
152
152
|
RuntimeError
|
|
153
|
-
If the parameter value for eq_type is not in `["ODE", "
|
|
154
|
-
"
|
|
153
|
+
If the parameter value for eq_type is not in `["ODE", "PDEStatio",
|
|
154
|
+
"PDENonStatio"]` and for various failing checks
|
|
155
155
|
"""
|
|
156
156
|
|
|
157
|
-
if eq_type not in ["ODE", "
|
|
157
|
+
if eq_type not in ["ODE", "PDEStatio", "PDENonStatio"]:
|
|
158
158
|
raise RuntimeError("Wrong parameter value for eq_type")
|
|
159
159
|
|
|
160
160
|
def element_is_layer(element: tuple) -> TypeGuard[tuple[Callable, int, int]]:
|
|
@@ -47,9 +47,9 @@ class DerivativeKeysODE(eqx.Module):
|
|
|
47
47
|
[`DynamicLoss`][jinns.loss.DynamicLoss] should be differentiated both with
|
|
48
48
|
respect to the neural network parameters *and* the equation parameters, or only some of them.
|
|
49
49
|
|
|
50
|
-
To do so, user can either use strings or a `Params` object
|
|
51
|
-
with PyTree structure matching the parameters of the problem at
|
|
52
|
-
hand, and booleans indicating if gradient is to be taken or not. Internally,
|
|
50
|
+
To do so, user can either use strings or a `Params[bool]` object
|
|
51
|
+
with PyTree structure matching the parameters of the problem (`Params[Array]`) at
|
|
52
|
+
hand, and leaves being booleans indicating if gradient is to be taken or not. Internally,
|
|
53
53
|
a `jax.lax.stop_gradient()` is appropriately set to each `True` node when
|
|
54
54
|
computing each loss term.
|
|
55
55
|
|
|
@@ -156,12 +156,12 @@ class DerivativeKeysODE(eqx.Module):
|
|
|
156
156
|
"""
|
|
157
157
|
Construct the DerivativeKeysODE from strings. For each term of the
|
|
158
158
|
loss, specify whether to differentiate wrt the neural network
|
|
159
|
-
parameters, the equation parameters or both. The `Params` object, which
|
|
159
|
+
parameters, the equation parameters or both. The `Params[Array]` object, which
|
|
160
160
|
contains the actual array of parameters must be passed to
|
|
161
161
|
construct the fields with the appropriate PyTree structure.
|
|
162
162
|
|
|
163
163
|
!!! note
|
|
164
|
-
You can mix strings and `Params` if you need granularity.
|
|
164
|
+
You can mix strings and `Params[bool]` if you need granularity.
|
|
165
165
|
|
|
166
166
|
Parameters
|
|
167
167
|
----------
|
|
@@ -498,7 +498,14 @@ def _set_derivatives(
|
|
|
498
498
|
`Params(nn_params=True | False, eq_params={"alpha":True | False,
|
|
499
499
|
"beta":True | False})`.
|
|
500
500
|
"""
|
|
501
|
-
|
|
501
|
+
assert jax.tree.structure(params_.eq_params) == jax.tree.structure(
|
|
502
|
+
derivative_mask.eq_params
|
|
503
|
+
), (
|
|
504
|
+
"The derivative "
|
|
505
|
+
"mask for eq_params does not have the same tree structure as "
|
|
506
|
+
"Params.eq_params. This is often due to a wrong Params[bool] "
|
|
507
|
+
"passed when initializing the derivative key object."
|
|
508
|
+
)
|
|
502
509
|
return Params(
|
|
503
510
|
nn_params=jax.lax.cond(
|
|
504
511
|
derivative_mask.nn_params,
|
jinns/parameters/_params.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
Formalize the data structure for the parameters
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
+
from __future__ import annotations
|
|
5
6
|
from dataclasses import fields
|
|
6
7
|
from typing import Generic, TypeVar
|
|
7
8
|
import equinox as eqx
|
|
@@ -60,6 +61,15 @@ class Params(eqx.Module, Generic[T]):
|
|
|
60
61
|
else:
|
|
61
62
|
self.eq_params = eq_params
|
|
62
63
|
|
|
64
|
+
def partition(self, mask: Params[bool] | None):
|
|
65
|
+
"""
|
|
66
|
+
following the boolean mask, partition into two Params
|
|
67
|
+
"""
|
|
68
|
+
if mask is not None:
|
|
69
|
+
return eqx.partition(self, mask)
|
|
70
|
+
else:
|
|
71
|
+
return self, None
|
|
72
|
+
|
|
63
73
|
|
|
64
74
|
def update_eq_params(
|
|
65
75
|
params: Params[Array],
|
jinns/solver/_rar.py
CHANGED
|
@@ -10,6 +10,7 @@ from jax import vmap
|
|
|
10
10
|
import jax.numpy as jnp
|
|
11
11
|
import equinox as eqx
|
|
12
12
|
|
|
13
|
+
from jinns.loss._DynamicLossAbstract import ODE, PDEStatio, PDENonStatio
|
|
13
14
|
from jinns.data._DataGeneratorODE import DataGeneratorODE
|
|
14
15
|
from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
|
|
15
16
|
from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
|
|
@@ -176,16 +177,25 @@ def _rar_step_init(
|
|
|
176
177
|
)
|
|
177
178
|
|
|
178
179
|
data = eqx.tree_at(lambda m: m.key, data, new_key)
|
|
179
|
-
|
|
180
|
-
v_dyn_loss = vmap(
|
|
181
|
-
lambda inputs: loss.dynamic_loss.evaluate(inputs, loss.u, params),
|
|
182
|
-
)
|
|
183
|
-
dyn_on_s = v_dyn_loss(new_samples)
|
|
184
|
-
|
|
185
|
-
if dyn_on_s.ndim > 1:
|
|
186
|
-
mse_on_s = (jnp.linalg.norm(dyn_on_s, axis=-1) ** 2).flatten()
|
|
187
180
|
else:
|
|
188
|
-
|
|
181
|
+
raise ValueError("Wrong DataGenerator type")
|
|
182
|
+
|
|
183
|
+
v_dyn_loss = jax.tree.map(
|
|
184
|
+
lambda d: vmap(
|
|
185
|
+
lambda inputs: d.evaluate(inputs, loss.u, params),
|
|
186
|
+
),
|
|
187
|
+
loss.dynamic_loss,
|
|
188
|
+
is_leaf=lambda x: isinstance(x, (ODE, PDEStatio, PDENonStatio)),
|
|
189
|
+
)
|
|
190
|
+
dyn_on_s = jax.tree.map(lambda d: d(new_samples), v_dyn_loss)
|
|
191
|
+
|
|
192
|
+
mse_on_s = jax.tree.reduce(
|
|
193
|
+
jnp.add,
|
|
194
|
+
jax.tree.map(
|
|
195
|
+
lambda v: (jnp.linalg.norm(v, axis=-1) ** 2).flatten(), dyn_on_s
|
|
196
|
+
),
|
|
197
|
+
0,
|
|
198
|
+
)
|
|
189
199
|
|
|
190
200
|
## Select the m points with higher dynamic loss
|
|
191
201
|
higher_residual_idx = jax.lax.dynamic_slice(
|