jinns 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +7 -7
- jinns/data/_AbstractDataGenerator.py +1 -1
- jinns/data/_Batchs.py +47 -13
- jinns/data/_CubicMeshPDENonStatio.py +203 -54
- jinns/data/_CubicMeshPDEStatio.py +190 -54
- jinns/data/_DataGeneratorODE.py +48 -22
- jinns/data/_DataGeneratorObservations.py +75 -32
- jinns/data/_DataGeneratorParameter.py +152 -101
- jinns/data/__init__.py +2 -1
- jinns/data/_utils.py +22 -10
- jinns/loss/_DynamicLoss.py +21 -20
- jinns/loss/_DynamicLossAbstract.py +51 -36
- jinns/loss/_LossODE.py +210 -191
- jinns/loss/_LossPDE.py +441 -368
- jinns/loss/_abstract_loss.py +60 -25
- jinns/loss/_loss_components.py +4 -25
- jinns/loss/_loss_utils.py +23 -0
- jinns/loss/_loss_weight_updates.py +6 -7
- jinns/loss/_loss_weights.py +34 -35
- jinns/nn/_abstract_pinn.py +0 -2
- jinns/nn/_hyperpinn.py +34 -23
- jinns/nn/_mlp.py +5 -4
- jinns/nn/_pinn.py +1 -16
- jinns/nn/_ppinn.py +5 -16
- jinns/nn/_save_load.py +11 -4
- jinns/nn/_spinn.py +1 -16
- jinns/nn/_spinn_mlp.py +5 -5
- jinns/nn/_utils.py +33 -38
- jinns/parameters/__init__.py +3 -1
- jinns/parameters/_derivative_keys.py +99 -41
- jinns/parameters/_params.py +58 -25
- jinns/solver/_solve.py +14 -8
- jinns/utils/_DictToModuleMeta.py +66 -0
- jinns/utils/_ItemizableModule.py +19 -0
- jinns/utils/__init__.py +2 -1
- jinns/utils/_types.py +25 -15
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
- jinns-1.6.0.dist-info/RECORD +57 -0
- jinns-1.5.0.dist-info/RECORD +0 -55
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/loss/_abstract_loss.py
CHANGED
|
@@ -1,41 +1,77 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import abc
|
|
4
|
-
from typing import
|
|
5
|
-
from jaxtyping import Array, PyTree,
|
|
4
|
+
from typing import Self, Literal, Callable, TypeVar, Generic, Any
|
|
5
|
+
from jaxtyping import PRNGKeyArray, Array, PyTree, Float
|
|
6
6
|
import equinox as eqx
|
|
7
7
|
import jax
|
|
8
8
|
import jax.numpy as jnp
|
|
9
9
|
import optax
|
|
10
|
-
from jinns.loss._loss_weights import AbstractLossWeights
|
|
11
10
|
from jinns.parameters._params import Params
|
|
12
11
|
from jinns.loss._loss_weight_updates import soft_adapt, lr_annealing, ReLoBRaLo
|
|
12
|
+
from jinns.utils._types import AnyLossComponents, AnyBatch, AnyLossWeights
|
|
13
13
|
|
|
14
|
-
|
|
15
|
-
|
|
14
|
+
L = TypeVar(
|
|
15
|
+
"L", bound=AnyLossWeights
|
|
16
|
+
) # we want to be able to use one of the element of AnyLossWeights
|
|
17
|
+
# that is https://stackoverflow.com/a/79534258 via `bound`
|
|
16
18
|
|
|
19
|
+
B = TypeVar(
|
|
20
|
+
"B", bound=AnyBatch
|
|
21
|
+
) # The above comment also works with Unions (https://docs.python.org/3/library/typing.html#typing.TypeVar)
|
|
22
|
+
# We then do the same TypeVar to be able to use one of the element of AnyBatch
|
|
23
|
+
# in the evaluate_by_terms methods of child classes.
|
|
24
|
+
C = TypeVar(
|
|
25
|
+
"C", bound=AnyLossComponents[Array | None]
|
|
26
|
+
) # The above comment also works with Unions (https://docs.python.org/3/library/typing.html#typing.TypeVar)
|
|
17
27
|
|
|
18
|
-
|
|
28
|
+
# In the cases above, without the bound, we could not have covariance on
|
|
29
|
+
# the type because it would break LSP. Note that covariance on the return type
|
|
30
|
+
# is authorized in LSP hence we do not need the same TypeVar instruction for
|
|
31
|
+
# the return types of evaluate_by_terms for example!
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class AbstractLoss(eqx.Module, Generic[L, B, C]):
|
|
19
35
|
"""
|
|
20
36
|
About the call:
|
|
21
37
|
https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
|
|
22
38
|
"""
|
|
23
39
|
|
|
24
|
-
loss_weights:
|
|
40
|
+
loss_weights: eqx.AbstractVar[L]
|
|
25
41
|
update_weight_method: Literal["soft_adapt", "lr_annealing", "ReLoBRaLo"] | None = (
|
|
26
42
|
eqx.field(kw_only=True, default=None, static=True)
|
|
27
43
|
)
|
|
28
44
|
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
pass
|
|
45
|
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
46
|
+
return self.evaluate(*args, **kwargs)
|
|
32
47
|
|
|
33
48
|
@abc.abstractmethod
|
|
34
|
-
def evaluate_by_terms(
|
|
35
|
-
self, params: Params[Array], batch: AnyBatch
|
|
36
|
-
) -> tuple[AnyLossComponents, AnyLossComponents]:
|
|
49
|
+
def evaluate_by_terms(self, params: Params[Array], batch: B) -> tuple[C, C]:
|
|
37
50
|
pass
|
|
38
51
|
|
|
52
|
+
def evaluate(self, params: Params[Array], batch: B) -> tuple[Float[Array, " "], C]:
|
|
53
|
+
"""
|
|
54
|
+
Evaluate the loss function at a batch of points for given parameters.
|
|
55
|
+
|
|
56
|
+
We retrieve the total value itself and a PyTree with loss values for each term
|
|
57
|
+
|
|
58
|
+
Parameters
|
|
59
|
+
---------
|
|
60
|
+
params
|
|
61
|
+
Parameters at which the loss is evaluated
|
|
62
|
+
batch
|
|
63
|
+
Composed of a batch of points in the
|
|
64
|
+
domain, a batch of points in the domain
|
|
65
|
+
border and an optional additional batch of parameters (eg. for
|
|
66
|
+
metamodeling) and an optional additional batch of observed
|
|
67
|
+
inputs/outputs/parameters
|
|
68
|
+
"""
|
|
69
|
+
loss_terms, _ = self.evaluate_by_terms(params, batch)
|
|
70
|
+
|
|
71
|
+
loss_val = self.ponderate_and_sum_loss(loss_terms)
|
|
72
|
+
|
|
73
|
+
return loss_val, loss_terms
|
|
74
|
+
|
|
39
75
|
def get_gradients(
|
|
40
76
|
self, fun: Callable[[Params[Array]], Array], params: Params[Array]
|
|
41
77
|
) -> tuple[Array, Array]:
|
|
@@ -48,7 +84,7 @@ class AbstractLoss(eqx.Module):
|
|
|
48
84
|
loss_val, grads = value_grad_loss(params)
|
|
49
85
|
return loss_val, grads
|
|
50
86
|
|
|
51
|
-
def ponderate_and_sum_loss(self, terms):
|
|
87
|
+
def ponderate_and_sum_loss(self, terms: C) -> Array:
|
|
52
88
|
"""
|
|
53
89
|
Get total loss from individual loss terms and weights
|
|
54
90
|
|
|
@@ -58,19 +94,18 @@ class AbstractLoss(eqx.Module):
|
|
|
58
94
|
self.loss_weights,
|
|
59
95
|
is_leaf=lambda x: eqx.is_inexact_array(x) and x is not None,
|
|
60
96
|
)
|
|
61
|
-
|
|
97
|
+
terms_list = jax.tree.leaves(
|
|
62
98
|
terms, is_leaf=lambda x: eqx.is_inexact_array(x) and x is not None
|
|
63
99
|
)
|
|
64
|
-
if len(weights) == len(
|
|
65
|
-
return jnp.sum(jnp.array(weights) * jnp.array(
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
)
|
|
100
|
+
if len(weights) == len(terms_list):
|
|
101
|
+
return jnp.sum(jnp.array(weights) * jnp.array(terms_list))
|
|
102
|
+
raise ValueError(
|
|
103
|
+
"The numbers of declared loss weights and "
|
|
104
|
+
"declared loss terms do not concord "
|
|
105
|
+
f" got {len(weights)} and {len(terms_list)}"
|
|
106
|
+
)
|
|
72
107
|
|
|
73
|
-
def ponderate_and_sum_gradient(self, terms):
|
|
108
|
+
def ponderate_and_sum_gradient(self, terms: C) -> C:
|
|
74
109
|
"""
|
|
75
110
|
Get total gradients from individual loss gradients and weights
|
|
76
111
|
for each parameter
|
|
@@ -102,7 +137,7 @@ class AbstractLoss(eqx.Module):
|
|
|
102
137
|
loss_terms: PyTree,
|
|
103
138
|
stored_loss_terms: PyTree,
|
|
104
139
|
grad_terms: PyTree,
|
|
105
|
-
key:
|
|
140
|
+
key: PRNGKeyArray,
|
|
106
141
|
) -> Self:
|
|
107
142
|
"""
|
|
108
143
|
Update the loss weights according to a predefined scheme
|
jinns/loss/_loss_components.py
CHANGED
|
@@ -1,38 +1,17 @@
|
|
|
1
1
|
from typing import TypeVar, Generic
|
|
2
|
-
|
|
3
|
-
|
|
2
|
+
|
|
3
|
+
from jinns.utils._ItemizableModule import ItemizableModule
|
|
4
4
|
|
|
5
5
|
T = TypeVar("T")
|
|
6
6
|
|
|
7
7
|
|
|
8
|
-
class
|
|
9
|
-
"""
|
|
10
|
-
Provides a template for ODE components with generic types.
|
|
11
|
-
One can inherit to specialize and add methods and attributes
|
|
12
|
-
We do not enforce keyword only to avoid being to verbose (this then can
|
|
13
|
-
work like a tuple)
|
|
14
|
-
"""
|
|
15
|
-
|
|
16
|
-
def items(self):
|
|
17
|
-
"""
|
|
18
|
-
For the dataclass to be iterated like a dictionary.
|
|
19
|
-
Practical and retrocompatible with old code when loss components were
|
|
20
|
-
dictionaries
|
|
21
|
-
"""
|
|
22
|
-
return {
|
|
23
|
-
field.name: getattr(self, field.name)
|
|
24
|
-
for field in fields(self)
|
|
25
|
-
if getattr(self, field.name) is not None
|
|
26
|
-
}.items()
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class ODEComponents(XDEComponentsAbstract[T]):
|
|
8
|
+
class ODEComponents(ItemizableModule, Generic[T]):
|
|
30
9
|
dyn_loss: T
|
|
31
10
|
initial_condition: T
|
|
32
11
|
observations: T
|
|
33
12
|
|
|
34
13
|
|
|
35
|
-
class PDEStatioComponents(
|
|
14
|
+
class PDEStatioComponents(ItemizableModule, Generic[T]):
|
|
36
15
|
dyn_loss: T
|
|
37
16
|
norm_loss: T
|
|
38
17
|
boundary_loss: T
|
jinns/loss/_loss_utils.py
CHANGED
|
@@ -308,3 +308,26 @@ def initial_condition_apply(
|
|
|
308
308
|
else:
|
|
309
309
|
raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
|
|
310
310
|
return mse_initial_condition
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def initial_condition_check(x, dim_size=None):
|
|
314
|
+
"""
|
|
315
|
+
Make a (dim_size,) jnp array from an int, a float or a 0D jnp array
|
|
316
|
+
|
|
317
|
+
"""
|
|
318
|
+
if isinstance(x, Array):
|
|
319
|
+
if not x.shape: # e.g. user input: jnp.array(0.)
|
|
320
|
+
x = jnp.array([x])
|
|
321
|
+
if dim_size is not None: # we check for the required dims_ize
|
|
322
|
+
if x.shape != (dim_size,):
|
|
323
|
+
raise ValueError(
|
|
324
|
+
f"Wrong dim_size. It should be({dim_size},). Got shape: {x.shape}"
|
|
325
|
+
)
|
|
326
|
+
|
|
327
|
+
elif isinstance(x, float): # e.g. user input: 0.
|
|
328
|
+
x = jnp.array([x])
|
|
329
|
+
elif isinstance(x, int): # e.g. user input: 0
|
|
330
|
+
x = jnp.array([float(x)])
|
|
331
|
+
else:
|
|
332
|
+
raise ValueError(f"Wrong value, expected Array, float or int, got {type(x)}")
|
|
333
|
+
return x
|
|
@@ -4,18 +4,17 @@ A collection of specific weight update schemes in jinns
|
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
from typing import TYPE_CHECKING
|
|
7
|
-
from jaxtyping import Array,
|
|
7
|
+
from jaxtyping import Array, PRNGKeyArray
|
|
8
8
|
import jax.numpy as jnp
|
|
9
9
|
import jax
|
|
10
10
|
import equinox as eqx
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
|
-
from jinns.
|
|
14
|
-
from jinns.utils._types import AnyLossComponents
|
|
13
|
+
from jinns.utils._types import AnyLossComponents, AnyLossWeights
|
|
15
14
|
|
|
16
15
|
|
|
17
16
|
def soft_adapt(
|
|
18
|
-
loss_weights:
|
|
17
|
+
loss_weights: AnyLossWeights,
|
|
19
18
|
iteration_nb: int,
|
|
20
19
|
loss_terms: AnyLossComponents,
|
|
21
20
|
stored_loss_terms: AnyLossComponents,
|
|
@@ -58,11 +57,11 @@ def soft_adapt(
|
|
|
58
57
|
|
|
59
58
|
|
|
60
59
|
def ReLoBRaLo(
|
|
61
|
-
loss_weights:
|
|
60
|
+
loss_weights: AnyLossWeights,
|
|
62
61
|
iteration_nb: int,
|
|
63
62
|
loss_terms: AnyLossComponents,
|
|
64
63
|
stored_loss_terms: AnyLossComponents,
|
|
65
|
-
key:
|
|
64
|
+
key: PRNGKeyArray,
|
|
66
65
|
decay_factor: float = 0.9,
|
|
67
66
|
tau: float = 1, ## referred to as temperature in the article
|
|
68
67
|
p: float = 0.9,
|
|
@@ -146,7 +145,7 @@ def ReLoBRaLo(
|
|
|
146
145
|
|
|
147
146
|
|
|
148
147
|
def lr_annealing(
|
|
149
|
-
loss_weights:
|
|
148
|
+
loss_weights: AnyLossWeights,
|
|
150
149
|
grad_terms: AnyLossComponents,
|
|
151
150
|
decay_factor: float = 0.9, # 0.9 is the recommended value from the article
|
|
152
151
|
eps: float = 1e-6,
|
jinns/loss/_loss_weights.py
CHANGED
|
@@ -3,81 +3,80 @@ Formalize the loss weights data structure
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
|
-
from dataclasses import fields
|
|
7
6
|
|
|
8
7
|
from jaxtyping import Array
|
|
9
8
|
import jax.numpy as jnp
|
|
10
9
|
import equinox as eqx
|
|
11
10
|
|
|
11
|
+
from jinns.loss._loss_components import (
|
|
12
|
+
ODEComponents,
|
|
13
|
+
PDEStatioComponents,
|
|
14
|
+
PDENonStatioComponents,
|
|
15
|
+
)
|
|
12
16
|
|
|
13
|
-
|
|
17
|
+
|
|
18
|
+
def lw_converter(x: Array | None) -> Array | None:
|
|
14
19
|
if x is None:
|
|
15
20
|
return x
|
|
16
21
|
else:
|
|
17
22
|
return jnp.asarray(x)
|
|
18
23
|
|
|
19
24
|
|
|
20
|
-
class
|
|
25
|
+
class LossWeightsODE(ODEComponents[Array | None]):
|
|
21
26
|
"""
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
TODO in the future maybe loss weights could be subclasses of
|
|
25
|
-
XDEComponentsAbstract?
|
|
27
|
+
Value given at initialization is converted to a jnp.array orunmodified if None.
|
|
28
|
+
This means that at initialization, the user can pass a float or int
|
|
26
29
|
"""
|
|
27
30
|
|
|
28
|
-
|
|
29
|
-
"""
|
|
30
|
-
For the dataclass to be iterated like a dictionary.
|
|
31
|
-
Practical and retrocompatible with old code when loss components were
|
|
32
|
-
dictionaries
|
|
33
|
-
"""
|
|
34
|
-
return {
|
|
35
|
-
field.name: getattr(self, field.name)
|
|
36
|
-
for field in fields(self)
|
|
37
|
-
if getattr(self, field.name) is not None
|
|
38
|
-
}.items()
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
class LossWeightsODE(AbstractLossWeights):
|
|
42
|
-
dyn_loss: Array | float | None = eqx.field(
|
|
31
|
+
dyn_loss: Array | None = eqx.field(
|
|
43
32
|
kw_only=True, default=None, converter=lw_converter
|
|
44
33
|
)
|
|
45
|
-
initial_condition: Array |
|
|
34
|
+
initial_condition: Array | None = eqx.field(
|
|
46
35
|
kw_only=True, default=None, converter=lw_converter
|
|
47
36
|
)
|
|
48
|
-
observations: Array |
|
|
37
|
+
observations: Array | None = eqx.field(
|
|
49
38
|
kw_only=True, default=None, converter=lw_converter
|
|
50
39
|
)
|
|
51
40
|
|
|
52
41
|
|
|
53
|
-
class LossWeightsPDEStatio(
|
|
54
|
-
|
|
42
|
+
class LossWeightsPDEStatio(PDEStatioComponents[Array | None]):
|
|
43
|
+
"""
|
|
44
|
+
Value given at initialization is converted to a jnp.array orunmodified if None.
|
|
45
|
+
This means that at initialization, the user can pass a float or int
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
dyn_loss: Array | None = eqx.field(
|
|
55
49
|
kw_only=True, default=None, converter=lw_converter
|
|
56
50
|
)
|
|
57
|
-
norm_loss: Array |
|
|
51
|
+
norm_loss: Array | None = eqx.field(
|
|
58
52
|
kw_only=True, default=None, converter=lw_converter
|
|
59
53
|
)
|
|
60
|
-
boundary_loss: Array |
|
|
54
|
+
boundary_loss: Array | None = eqx.field(
|
|
61
55
|
kw_only=True, default=None, converter=lw_converter
|
|
62
56
|
)
|
|
63
|
-
observations: Array |
|
|
57
|
+
observations: Array | None = eqx.field(
|
|
64
58
|
kw_only=True, default=None, converter=lw_converter
|
|
65
59
|
)
|
|
66
60
|
|
|
67
61
|
|
|
68
|
-
class LossWeightsPDENonStatio(
|
|
69
|
-
|
|
62
|
+
class LossWeightsPDENonStatio(PDENonStatioComponents[Array | None]):
|
|
63
|
+
"""
|
|
64
|
+
Value given at initialization is converted to a jnp.array orunmodified if None.
|
|
65
|
+
This means that at initialization, the user can pass a float or int
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
dyn_loss: Array | None = eqx.field(
|
|
70
69
|
kw_only=True, default=None, converter=lw_converter
|
|
71
70
|
)
|
|
72
|
-
norm_loss: Array |
|
|
71
|
+
norm_loss: Array | None = eqx.field(
|
|
73
72
|
kw_only=True, default=None, converter=lw_converter
|
|
74
73
|
)
|
|
75
|
-
boundary_loss: Array |
|
|
74
|
+
boundary_loss: Array | None = eqx.field(
|
|
76
75
|
kw_only=True, default=None, converter=lw_converter
|
|
77
76
|
)
|
|
78
|
-
observations: Array |
|
|
77
|
+
observations: Array | None = eqx.field(
|
|
79
78
|
kw_only=True, default=None, converter=lw_converter
|
|
80
79
|
)
|
|
81
|
-
initial_condition: Array |
|
|
80
|
+
initial_condition: Array | None = eqx.field(
|
|
82
81
|
kw_only=True, default=None, converter=lw_converter
|
|
83
82
|
)
|
jinns/nn/_abstract_pinn.py
CHANGED
|
@@ -3,7 +3,6 @@ from typing import Literal, Any
|
|
|
3
3
|
from jaxtyping import Array
|
|
4
4
|
import equinox as eqx
|
|
5
5
|
|
|
6
|
-
from jinns.nn._utils import _PyTree_to_Params
|
|
7
6
|
from jinns.parameters._params import Params
|
|
8
7
|
|
|
9
8
|
|
|
@@ -17,6 +16,5 @@ class AbstractPINN(eqx.Module):
|
|
|
17
16
|
eq_type: eqx.AbstractVar[Literal["ODE", "statio_PDE", "nonstatio_PDE"]]
|
|
18
17
|
|
|
19
18
|
@abc.abstractmethod
|
|
20
|
-
@_PyTree_to_Params
|
|
21
19
|
def __call__(self, inputs: Any, params: Params[Array], *args, **kwargs) -> Any:
|
|
22
20
|
pass
|
jinns/nn/_hyperpinn.py
CHANGED
|
@@ -7,18 +7,17 @@ from __future__ import annotations
|
|
|
7
7
|
|
|
8
8
|
import warnings
|
|
9
9
|
from dataclasses import InitVar
|
|
10
|
-
from typing import Callable, Literal, Self, Union, Any, cast
|
|
10
|
+
from typing import Callable, Literal, Self, Union, Any, cast
|
|
11
11
|
from math import prod
|
|
12
12
|
import jax
|
|
13
13
|
import jax.numpy as jnp
|
|
14
|
-
from jaxtyping import Array, Float, PyTree
|
|
14
|
+
from jaxtyping import PRNGKeyArray, Array, Float, PyTree
|
|
15
15
|
import equinox as eqx
|
|
16
16
|
import numpy as onp
|
|
17
17
|
|
|
18
18
|
from jinns.nn._pinn import PINN
|
|
19
19
|
from jinns.nn._mlp import MLP
|
|
20
20
|
from jinns.parameters._params import Params
|
|
21
|
-
from jinns.nn._utils import _PyTree_to_Params
|
|
22
21
|
|
|
23
22
|
|
|
24
23
|
def _get_param_nb(
|
|
@@ -138,6 +137,32 @@ class HyperPINN(PINN):
|
|
|
138
137
|
jnp.split(hyper_output, self.pinn_params_cumsum[:-1]),
|
|
139
138
|
)
|
|
140
139
|
|
|
140
|
+
# For the record. We exhibited that the jnp.split was a serious time
|
|
141
|
+
# bottleneck. However none of the approaches below improved the speed.
|
|
142
|
+
# Moreover, this operation is not well implemented by a triton kernel
|
|
143
|
+
# apparently so such an optim is not an option.
|
|
144
|
+
# 1)
|
|
145
|
+
# pinn_params_flat = jax.tree.unflatten(self.pinn_params_struct,
|
|
146
|
+
# jnp.split(hyper_output, self.pinn_params_cumsum[:-1]),
|
|
147
|
+
# )
|
|
148
|
+
# 2)
|
|
149
|
+
# pinn_params_flat = jax.tree.unflatten(self.pinn_params_struct,
|
|
150
|
+
# [jax.lax.slice(hyper_output, (s,), (e,)).reshape(r) for s, e, r in
|
|
151
|
+
# zip(self.pinn_params_cumsum_start, self.pinn_params_cumsum,
|
|
152
|
+
# self.pinn_params_shapes)]
|
|
153
|
+
# )
|
|
154
|
+
# 3)
|
|
155
|
+
# pinn_params_flat = jax.tree.unflatten(self.pinn_params_struct,
|
|
156
|
+
# [hyper_output[s:e].reshape(r) for s, e, r in
|
|
157
|
+
# zip(self.pinn_params_cumsum_start, self.pinn_params_cumsum,
|
|
158
|
+
# self.pinn_params_shapes)]
|
|
159
|
+
# )
|
|
160
|
+
# 4)
|
|
161
|
+
# pinn_params_flat = jax.tree.unflatten(self.pinn_params_struct,
|
|
162
|
+
# [jax.lax.dynamic_slice(hyper_output, (s,), (size,)) for s, size in
|
|
163
|
+
# zip(self.pinn_params_cumsum_start, self.pinn_params_cumsum_size)]
|
|
164
|
+
# )
|
|
165
|
+
|
|
141
166
|
return jax.tree.map(
|
|
142
167
|
lambda a, b: a.reshape(b.shape),
|
|
143
168
|
pinn_params_flat,
|
|
@@ -145,17 +170,6 @@ class HyperPINN(PINN):
|
|
|
145
170
|
is_leaf=lambda x: isinstance(x, jnp.ndarray),
|
|
146
171
|
)
|
|
147
172
|
|
|
148
|
-
@overload
|
|
149
|
-
@_PyTree_to_Params
|
|
150
|
-
def __call__(
|
|
151
|
-
self,
|
|
152
|
-
inputs: Float[Array, " input_dim"],
|
|
153
|
-
params: PyTree,
|
|
154
|
-
*args,
|
|
155
|
-
**kwargs,
|
|
156
|
-
) -> Float[Array, " output_dim"]: ...
|
|
157
|
-
|
|
158
|
-
@_PyTree_to_Params
|
|
159
173
|
def __call__(
|
|
160
174
|
self,
|
|
161
175
|
inputs: Float[Array, " input_dim"],
|
|
@@ -175,13 +189,10 @@ class HyperPINN(PINN):
|
|
|
175
189
|
# DataGenerators)
|
|
176
190
|
inputs = inputs[None]
|
|
177
191
|
|
|
178
|
-
# try:
|
|
179
192
|
hyper = eqx.combine(params.nn_params, self.static_hyper)
|
|
180
|
-
# except (KeyError, AttributeError, TypeError) as e: # give more flexibility
|
|
181
|
-
# hyper = eqx.combine(params, self.static_hyper)
|
|
182
193
|
|
|
183
194
|
eq_params_batch = jnp.concatenate(
|
|
184
|
-
[params.eq_params
|
|
195
|
+
[getattr(params.eq_params, k).flatten() for k in self.hyperparams],
|
|
185
196
|
axis=0,
|
|
186
197
|
)
|
|
187
198
|
|
|
@@ -202,12 +213,13 @@ class HyperPINN(PINN):
|
|
|
202
213
|
@classmethod
|
|
203
214
|
def create(
|
|
204
215
|
cls,
|
|
216
|
+
*,
|
|
205
217
|
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
|
|
206
218
|
hyperparams: list[str],
|
|
207
219
|
hypernet_input_size: int,
|
|
220
|
+
key: PRNGKeyArray | None = None,
|
|
208
221
|
eqx_network: eqx.nn.MLP | MLP | None = None,
|
|
209
222
|
eqx_hyper_network: eqx.nn.MLP | MLP | None = None,
|
|
210
|
-
key: Key = None,
|
|
211
223
|
eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
|
|
212
224
|
eqx_list_hyper: (
|
|
213
225
|
tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None
|
|
@@ -359,10 +371,10 @@ class HyperPINN(PINN):
|
|
|
359
371
|
|
|
360
372
|
### Now we finetune the hypernetwork architecture
|
|
361
373
|
|
|
362
|
-
|
|
374
|
+
subkey1, subkey2 = jax.random.split(key, 2)
|
|
363
375
|
# with warnings.catch_warnings():
|
|
364
376
|
# warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
|
|
365
|
-
eqx_network = MLP(key=
|
|
377
|
+
eqx_network = MLP(key=subkey1, eqx_list=eqx_list)
|
|
366
378
|
# quick partitioning to get the params to get the correct number of neurons
|
|
367
379
|
# for the last layer of hyper network
|
|
368
380
|
params_mlp, _ = eqx.partition(eqx_network, eqx.is_inexact_array)
|
|
@@ -405,10 +417,9 @@ class HyperPINN(PINN):
|
|
|
405
417
|
+ eqx_list_hyper[2:]
|
|
406
418
|
),
|
|
407
419
|
)
|
|
408
|
-
key, subkey = jax.random.split(key, 2)
|
|
409
420
|
# with warnings.catch_warnings():
|
|
410
421
|
# warnings.filterwarnings("ignore", message="A JAX array is being set as static!")
|
|
411
|
-
eqx_hyper_network = cast(MLP, MLP(key=
|
|
422
|
+
eqx_hyper_network = cast(MLP, MLP(key=subkey2, eqx_list=eqx_list_hyper))
|
|
412
423
|
|
|
413
424
|
### End of finetuning the hypernetwork architecture
|
|
414
425
|
|
jinns/nn/_mlp.py
CHANGED
|
@@ -9,7 +9,7 @@ from dataclasses import InitVar
|
|
|
9
9
|
import jax
|
|
10
10
|
import equinox as eqx
|
|
11
11
|
from typing import Protocol
|
|
12
|
-
from jaxtyping import Array,
|
|
12
|
+
from jaxtyping import Array, PRNGKeyArray, PyTree, Float
|
|
13
13
|
|
|
14
14
|
from jinns.parameters._params import Params
|
|
15
15
|
from jinns.nn._pinn import PINN
|
|
@@ -33,7 +33,7 @@ class MLP(eqx.Module):
|
|
|
33
33
|
|
|
34
34
|
Parameters
|
|
35
35
|
----------
|
|
36
|
-
key : InitVar[
|
|
36
|
+
key : InitVar[PRNGKeyArray]
|
|
37
37
|
A jax random key for the layer initializations.
|
|
38
38
|
eqx_list : InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]]
|
|
39
39
|
A tuple of tuples of successive equinox modules and activation functions to
|
|
@@ -52,7 +52,7 @@ class MLP(eqx.Module):
|
|
|
52
52
|
)`.
|
|
53
53
|
"""
|
|
54
54
|
|
|
55
|
-
key: InitVar[
|
|
55
|
+
key: InitVar[PRNGKeyArray] = eqx.field(kw_only=True)
|
|
56
56
|
eqx_list: InitVar[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] = (
|
|
57
57
|
eqx.field(kw_only=True)
|
|
58
58
|
)
|
|
@@ -94,9 +94,10 @@ class PINN_MLP(PINN):
|
|
|
94
94
|
@classmethod
|
|
95
95
|
def create(
|
|
96
96
|
cls,
|
|
97
|
+
*,
|
|
97
98
|
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
|
|
99
|
+
key: PRNGKeyArray | None = None,
|
|
98
100
|
eqx_network: eqx.nn.MLP | MLP | None = None,
|
|
99
|
-
key: Key = None,
|
|
100
101
|
eqx_list: tuple[tuple[Callable, int, int] | tuple[Callable], ...] | None = None,
|
|
101
102
|
input_transform: (
|
|
102
103
|
Callable[
|
jinns/nn/_pinn.py
CHANGED
|
@@ -4,14 +4,13 @@ Implement abstract class for PINN architectures
|
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
|
-
from typing import Callable, Union, Any, Literal
|
|
7
|
+
from typing import Callable, Union, Any, Literal
|
|
8
8
|
from dataclasses import InitVar
|
|
9
9
|
import equinox as eqx
|
|
10
10
|
from jaxtyping import Float, Array, PyTree
|
|
11
11
|
import jax.numpy as jnp
|
|
12
12
|
from jinns.parameters._params import Params
|
|
13
13
|
from jinns.nn._abstract_pinn import AbstractPINN
|
|
14
|
-
from jinns.nn._utils import _PyTree_to_Params
|
|
15
14
|
|
|
16
15
|
|
|
17
16
|
class PINN(AbstractPINN):
|
|
@@ -157,17 +156,6 @@ class PINN(AbstractPINN):
|
|
|
157
156
|
|
|
158
157
|
return network(inputs)
|
|
159
158
|
|
|
160
|
-
@overload
|
|
161
|
-
@_PyTree_to_Params
|
|
162
|
-
def __call__(
|
|
163
|
-
self,
|
|
164
|
-
inputs: Float[Array, " input_dim"],
|
|
165
|
-
params: PyTree,
|
|
166
|
-
*args,
|
|
167
|
-
**kwargs,
|
|
168
|
-
) -> Float[Array, " output_dim"]: ...
|
|
169
|
-
|
|
170
|
-
@_PyTree_to_Params
|
|
171
159
|
def __call__(
|
|
172
160
|
self,
|
|
173
161
|
inputs: Float[Array, " input_dim"],
|
|
@@ -180,9 +168,6 @@ class PINN(AbstractPINN):
|
|
|
180
168
|
`params` and `self.static` to recreate the callable eqx.Module
|
|
181
169
|
architecture. The rest of the content of this function is dependent on
|
|
182
170
|
the network.
|
|
183
|
-
|
|
184
|
-
Note that that thanks to the decorator, params can also directly be the
|
|
185
|
-
PyTree (SPINN, PINN_MLP, ...) that we get out of eqx.combine
|
|
186
171
|
"""
|
|
187
172
|
|
|
188
173
|
if len(inputs.shape) == 0:
|
jinns/nn/_ppinn.py
CHANGED
|
@@ -4,18 +4,17 @@ Implements utility function to create PINNs
|
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
|
-
from typing import Callable, Literal, Self, cast
|
|
7
|
+
from typing import Callable, Literal, Self, cast
|
|
8
8
|
from dataclasses import InitVar
|
|
9
9
|
import jax
|
|
10
10
|
import jax.numpy as jnp
|
|
11
11
|
import equinox as eqx
|
|
12
12
|
|
|
13
|
-
from jaxtyping import Array,
|
|
13
|
+
from jaxtyping import Array, Float, PRNGKeyArray
|
|
14
14
|
|
|
15
15
|
from jinns.parameters._params import Params
|
|
16
16
|
from jinns.nn._pinn import PINN
|
|
17
17
|
from jinns.nn._mlp import MLP
|
|
18
|
-
from jinns.nn._utils import _PyTree_to_Params
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
class PPINN_MLP(PINN):
|
|
@@ -85,17 +84,6 @@ class PPINN_MLP(PINN):
|
|
|
85
84
|
self.init_params = self.init_params + (params,)
|
|
86
85
|
self.static = self.static + (static,)
|
|
87
86
|
|
|
88
|
-
@overload
|
|
89
|
-
@_PyTree_to_Params
|
|
90
|
-
def __call__(
|
|
91
|
-
self,
|
|
92
|
-
inputs: Float[Array, " input_dim"],
|
|
93
|
-
params: PyTree,
|
|
94
|
-
*args,
|
|
95
|
-
**kwargs,
|
|
96
|
-
) -> Float[Array, " output_dim"]: ...
|
|
97
|
-
|
|
98
|
-
@_PyTree_to_Params
|
|
99
87
|
def __call__(
|
|
100
88
|
self,
|
|
101
89
|
inputs: Float[Array, " 1"] | Float[Array, " dim"] | Float[Array, " 1+dim"],
|
|
@@ -135,9 +123,10 @@ class PPINN_MLP(PINN):
|
|
|
135
123
|
@classmethod
|
|
136
124
|
def create(
|
|
137
125
|
cls,
|
|
126
|
+
*,
|
|
127
|
+
key: PRNGKeyArray | None = None,
|
|
138
128
|
eq_type: Literal["ODE", "statio_PDE", "nonstatio_PDE"],
|
|
139
129
|
eqx_network_list: list[eqx.nn.MLP | MLP] | None = None,
|
|
140
|
-
key: Key = None,
|
|
141
130
|
eqx_list_list: (
|
|
142
131
|
list[tuple[tuple[Callable, int, int] | tuple[Callable], ...]] | None
|
|
143
132
|
) = None,
|
|
@@ -225,7 +214,7 @@ class PPINN_MLP(PINN):
|
|
|
225
214
|
|
|
226
215
|
eqx_network_list = []
|
|
227
216
|
for eqx_list in eqx_list_list:
|
|
228
|
-
key, subkey = jax.random.split(key, 2)
|
|
217
|
+
key, subkey = jax.random.split(key, 2) # type: ignore
|
|
229
218
|
eqx_network_list.append(MLP(key=subkey, eqx_list=eqx_list))
|
|
230
219
|
|
|
231
220
|
ppinn = cls(
|