jinns 0.5.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/loss/_DynamicLoss.py +26 -22
- jinns/loss/_DynamicLossAbstract.py +17 -10
- jinns/loss/_LossODE.py +8 -14
- jinns/loss/_LossPDE.py +64 -64
- jinns/loss/_boundary_conditions.py +17 -17
- jinns/loss/_operators.py +26 -29
- jinns/solver/_solve.py +10 -5
- jinns/utils/__init__.py +0 -1
- jinns/utils/_pinn.py +21 -31
- jinns/utils/_spinn.py +9 -8
- jinns/utils/_utils.py +31 -142
- {jinns-0.5.0.dist-info → jinns-0.5.1.dist-info}/METADATA +2 -4
- jinns-0.5.1.dist-info/RECORD +24 -0
- jinns-0.5.0.dist-info/RECORD +0 -24
- {jinns-0.5.0.dist-info → jinns-0.5.1.dist-info}/LICENSE +0 -0
- {jinns-0.5.0.dist-info → jinns-0.5.1.dist-info}/WHEEL +0 -0
- {jinns-0.5.0.dist-info → jinns-0.5.1.dist-info}/top_level.txt +0 -0
jinns/loss/_DynamicLoss.py
CHANGED
|
@@ -1,5 +1,9 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implements several dynamic losses
|
|
3
|
+
"""
|
|
4
|
+
|
|
1
5
|
import jax
|
|
2
|
-
from jax import
|
|
6
|
+
from jax import grad, jacrev
|
|
3
7
|
import jax.numpy as jnp
|
|
4
8
|
from jinns.utils._utils import _get_grid
|
|
5
9
|
from jinns.utils._pinn import PINN
|
|
@@ -51,7 +55,7 @@ class FisherKPP(PDENonStatio):
|
|
|
51
55
|
super().__init__(Tmax, derivatives, eq_params_heterogeneity)
|
|
52
56
|
|
|
53
57
|
def evaluate(self, t, x, u, params):
|
|
54
|
-
"""
|
|
58
|
+
r"""
|
|
55
59
|
Evaluate the dynamic loss at :math:`(t,x)`.
|
|
56
60
|
|
|
57
61
|
Parameters
|
|
@@ -86,7 +90,7 @@ class FisherKPP(PDENonStatio):
|
|
|
86
90
|
- u(t, x, nn_params, eq_params)
|
|
87
91
|
* (eq_params["r"] - eq_params["g"] * u(t, x, nn_params, eq_params))
|
|
88
92
|
)
|
|
89
|
-
|
|
93
|
+
if isinstance(u, SPINN):
|
|
90
94
|
nn_params, eq_params = self.set_stop_gradient(params)
|
|
91
95
|
x_grid = _get_grid(x)
|
|
92
96
|
eq_params = self._eval_heterogeneous_parameters(
|
|
@@ -103,6 +107,7 @@ class FisherKPP(PDENonStatio):
|
|
|
103
107
|
-eq_params["D"] * lap
|
|
104
108
|
- u_tx * (eq_params["r"][..., None] - eq_params["g"] * u_tx)
|
|
105
109
|
)
|
|
110
|
+
raise ValueError("u is not among the recognized types (PINN or SPINN)")
|
|
106
111
|
|
|
107
112
|
|
|
108
113
|
class Malthus(ODE):
|
|
@@ -206,7 +211,7 @@ class BurgerEquation(PDENonStatio):
|
|
|
206
211
|
super().__init__(Tmax, derivatives, eq_params_heterogeneity)
|
|
207
212
|
|
|
208
213
|
def evaluate(self, t, x, u, params):
|
|
209
|
-
"""
|
|
214
|
+
r"""
|
|
210
215
|
Evaluate the dynamic loss at :math:`(t,x)`.
|
|
211
216
|
|
|
212
217
|
Parameters
|
|
@@ -243,7 +248,7 @@ class BurgerEquation(PDENonStatio):
|
|
|
243
248
|
- eq_params["nu"] * d2u_dx2(t, x)
|
|
244
249
|
)
|
|
245
250
|
|
|
246
|
-
|
|
251
|
+
if isinstance(u, SPINN):
|
|
247
252
|
nn_params, eq_params = self.set_stop_gradient(params)
|
|
248
253
|
x_grid = _get_grid(x)
|
|
249
254
|
eq_params = self._eval_heterogeneous_parameters(
|
|
@@ -264,6 +269,7 @@ class BurgerEquation(PDENonStatio):
|
|
|
264
269
|
du_dx, d2u_dx2 = jax.jvp(du_dx_fun, (x,), (jnp.ones_like(x),))
|
|
265
270
|
# Note that ones_like(x) works because x is Bx1 !
|
|
266
271
|
return du_dt + self.Tmax * (u_tx * du_dx - eq_params["nu"] * d2u_dx2)
|
|
272
|
+
raise ValueError("u is not among the recognized types (PINN or SPINN)")
|
|
267
273
|
|
|
268
274
|
|
|
269
275
|
class GeneralizedLotkaVolterra(ODE):
|
|
@@ -416,7 +422,7 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
416
422
|
super().__init__(Tmax, derivatives, eq_params_heterogeneity)
|
|
417
423
|
|
|
418
424
|
def evaluate(self, t, x, u, params):
|
|
419
|
-
"""
|
|
425
|
+
r"""
|
|
420
426
|
Evaluate the dynamic loss at :math:`(t,\mathbf{x})`.
|
|
421
427
|
|
|
422
428
|
Parameters
|
|
@@ -492,7 +498,7 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
492
498
|
|
|
493
499
|
return -du_dt + self.Tmax * (-order_1 + order_2)
|
|
494
500
|
|
|
495
|
-
|
|
501
|
+
if isinstance(u, SPINN):
|
|
496
502
|
nn_params, eq_params = self.set_stop_gradient(params)
|
|
497
503
|
x_grid = _get_grid(x)
|
|
498
504
|
eq_params = self._eval_heterogeneous_parameters(
|
|
@@ -556,6 +562,7 @@ class FPENonStatioLoss2D(PDENonStatio):
|
|
|
556
562
|
-(dau_dx1 + dau_dx2)
|
|
557
563
|
+ (d2su_dx12 + d2su_dx22 + d2su_dx1dx2 + d2su_dx2dx1)
|
|
558
564
|
)
|
|
565
|
+
raise ValueError("u is not among the recognized types (PINN or SPINN)")
|
|
559
566
|
|
|
560
567
|
|
|
561
568
|
class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
|
|
@@ -650,13 +657,12 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
|
|
|
650
657
|
jnp.transpose(self.sigma_mat(t, x, eq_params)),
|
|
651
658
|
)
|
|
652
659
|
)
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
)
|
|
660
|
+
return 0.5 * (
|
|
661
|
+
jnp.matmul(
|
|
662
|
+
self.sigma_mat(t, x, eq_params),
|
|
663
|
+
jnp.transpose(self.sigma_mat(t, x, eq_params)),
|
|
664
|
+
)[i, j]
|
|
665
|
+
)
|
|
660
666
|
|
|
661
667
|
|
|
662
668
|
class ConvectionDiffusionNonStatio(FPENonStatioLoss2D):
|
|
@@ -773,7 +779,7 @@ class MassConservation2DStatio(PDEStatio):
|
|
|
773
779
|
super().__init__(derivatives, eq_params_heterogeneity)
|
|
774
780
|
|
|
775
781
|
def evaluate(self, x, u_dict, params_dict):
|
|
776
|
-
"""
|
|
782
|
+
r"""
|
|
777
783
|
Evaluate the dynamic loss at `\mathbf{x}`.
|
|
778
784
|
For stability we implement the dynamic loss in log space.
|
|
779
785
|
|
|
@@ -794,21 +800,20 @@ class MassConservation2DStatio(PDEStatio):
|
|
|
794
800
|
nn_params, eq_params = self.set_stop_gradient(params_dict)
|
|
795
801
|
|
|
796
802
|
nn_params = nn_params[self.nn_key]
|
|
797
|
-
eq_params = eq_params
|
|
798
803
|
|
|
799
804
|
u = u_dict[self.nn_key]
|
|
800
805
|
|
|
801
806
|
return _div_rev(u, nn_params, eq_params, x)[..., None]
|
|
802
807
|
|
|
803
|
-
|
|
808
|
+
if isinstance(u_dict[self.nn_key], SPINN):
|
|
804
809
|
nn_params, eq_params = self.set_stop_gradient(params_dict)
|
|
805
810
|
|
|
806
811
|
nn_params = nn_params[self.nn_key]
|
|
807
|
-
eq_params = eq_params
|
|
808
812
|
|
|
809
813
|
u = u_dict[self.nn_key]
|
|
810
814
|
|
|
811
815
|
return _div_fwd(u, nn_params, eq_params, x)[..., None]
|
|
816
|
+
raise ValueError("u is not among the recognized types (PINN or SPINN)")
|
|
812
817
|
|
|
813
818
|
|
|
814
819
|
class NavierStokes2DStatio(PDEStatio):
|
|
@@ -843,7 +848,7 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
843
848
|
def __init__(
|
|
844
849
|
self, u_key, p_key, derivatives="nn_params", eq_params_heterogeneity=None
|
|
845
850
|
):
|
|
846
|
-
"""
|
|
851
|
+
r"""
|
|
847
852
|
Parameters
|
|
848
853
|
----------
|
|
849
854
|
u_key
|
|
@@ -899,7 +904,6 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
899
904
|
|
|
900
905
|
u_nn_params = nn_params[self.u_key]
|
|
901
906
|
p_nn_params = nn_params[self.p_key]
|
|
902
|
-
eq_params = eq_params
|
|
903
907
|
|
|
904
908
|
u = u_dict[self.u_key]
|
|
905
909
|
|
|
@@ -929,12 +933,11 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
929
933
|
# output is 2D
|
|
930
934
|
return jnp.stack([result_x, result_y], axis=-1)
|
|
931
935
|
|
|
932
|
-
|
|
936
|
+
if isinstance(u_dict[self.u_key], SPINN):
|
|
933
937
|
nn_params, eq_params = self.set_stop_gradient(params_dict)
|
|
934
938
|
|
|
935
939
|
u_nn_params = nn_params[self.u_key]
|
|
936
940
|
p_nn_params = nn_params[self.p_key]
|
|
937
|
-
eq_params = eq_params
|
|
938
941
|
|
|
939
942
|
u = u_dict[self.u_key]
|
|
940
943
|
|
|
@@ -968,3 +971,4 @@ class NavierStokes2DStatio(PDEStatio):
|
|
|
968
971
|
|
|
969
972
|
# output is 2D
|
|
970
973
|
return jnp.stack([result_x, result_y], axis=-1)
|
|
974
|
+
raise ValueError("u is not among the recognized types (PINN or SPINN)")
|
|
@@ -1,6 +1,8 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implements abstract classes for dynamic losses
|
|
3
|
+
"""
|
|
4
|
+
|
|
1
5
|
import jax
|
|
2
|
-
from jax import jit, grad
|
|
3
|
-
import jax.numpy as jnp
|
|
4
6
|
|
|
5
7
|
|
|
6
8
|
class DynamicLoss:
|
|
@@ -29,7 +31,13 @@ class DynamicLoss:
|
|
|
29
31
|
equation solution with as PINN.
|
|
30
32
|
eq_params_heterogeneity
|
|
31
33
|
Default None. A dict with the keys being the same as in eq_params
|
|
32
|
-
and the value being either None (no heterogeneity) or a function
|
|
34
|
+
and the value being either None (no heterogeneity) or a function
|
|
35
|
+
which encodes for the spatio-temporal heterogeneity of the parameter.
|
|
36
|
+
Such a function must be jittable and take three arguments `t`,
|
|
37
|
+
`x` and `params["eq_params"]` even if one is not used. Therefore,
|
|
38
|
+
one can introduce spatio-temporal covariates upon which a particular
|
|
39
|
+
parameter can depend, e.g. in a GLM fashion. The effect of these
|
|
40
|
+
covariables can themselves be estimated by being in `eq_params` too.
|
|
33
41
|
A value can be missing, in this case there is no heterogeneity (=None).
|
|
34
42
|
If eq_params_heterogeneity is None this means there is no
|
|
35
43
|
heterogeneity for no parameters.
|
|
@@ -79,15 +87,14 @@ class DynamicLoss:
|
|
|
79
87
|
|
|
80
88
|
if self.derivatives == "nn_params":
|
|
81
89
|
return (nn_params, jax.lax.stop_gradient(eq_params))
|
|
82
|
-
|
|
90
|
+
if self.derivatives == "eq_params":
|
|
83
91
|
return (jax.lax.stop_gradient(nn_params), eq_params)
|
|
84
|
-
|
|
92
|
+
if self.derivatives == "both":
|
|
85
93
|
return (nn_params, eq_params)
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
)
|
|
94
|
+
return (
|
|
95
|
+
jax.lax.stop_gradient(nn_params),
|
|
96
|
+
jax.lax.stop_gradient(eq_params),
|
|
97
|
+
)
|
|
91
98
|
|
|
92
99
|
|
|
93
100
|
class ODE(DynamicLoss):
|
jinns/loss/_LossODE.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Main module to implement a ODE loss in jinns
|
|
3
|
+
"""
|
|
4
|
+
|
|
1
5
|
import jax
|
|
2
6
|
import jax.numpy as jnp
|
|
3
7
|
from jax import vmap
|
|
4
8
|
from jax.tree_util import register_pytree_node_class
|
|
5
|
-
from jinns.data._DataGenerators import ODEBatch
|
|
6
9
|
from jinns.utils._utils import _get_vmap_in_axes_params
|
|
7
10
|
|
|
8
11
|
|
|
@@ -29,7 +32,6 @@ class LossODE:
|
|
|
29
32
|
initial_condition=None,
|
|
30
33
|
obs_batch=None,
|
|
31
34
|
obs_slice=None,
|
|
32
|
-
nn_output_weights=None,
|
|
33
35
|
):
|
|
34
36
|
r"""
|
|
35
37
|
Parameters
|
|
@@ -62,9 +64,6 @@ class LossODE:
|
|
|
62
64
|
slice object specifying the begininning/ending
|
|
63
65
|
slice of u output(s) that is observed (this is then useful for
|
|
64
66
|
multidim PINN). Default is None.
|
|
65
|
-
nn_output_weights:
|
|
66
|
-
Give different weights to the output of the neural network
|
|
67
|
-
Default is None, ie all nn outputs are
|
|
68
67
|
|
|
69
68
|
Raises
|
|
70
69
|
------
|
|
@@ -105,11 +104,6 @@ class LossODE:
|
|
|
105
104
|
batch
|
|
106
105
|
A batch of time points at which to evaluate the loss
|
|
107
106
|
"""
|
|
108
|
-
if isinstance(params, tuple):
|
|
109
|
-
params_ = params[0]
|
|
110
|
-
else:
|
|
111
|
-
params_ = params
|
|
112
|
-
|
|
113
107
|
temporal_batch = batch.temporal_batch
|
|
114
108
|
|
|
115
109
|
vmap_in_axes_t = (0,)
|
|
@@ -138,7 +132,7 @@ class LossODE:
|
|
|
138
132
|
* jnp.mean(v_dyn_loss(temporal_batch, params) ** 2, axis=0)
|
|
139
133
|
)
|
|
140
134
|
else:
|
|
141
|
-
mse_dyn_loss = 0
|
|
135
|
+
mse_dyn_loss = jnp.array(0.0)
|
|
142
136
|
|
|
143
137
|
# initial condition
|
|
144
138
|
if self.initial_condition is not None:
|
|
@@ -158,7 +152,7 @@ class LossODE:
|
|
|
158
152
|
** 2
|
|
159
153
|
)
|
|
160
154
|
else:
|
|
161
|
-
mse_initial_condition = 0
|
|
155
|
+
mse_initial_condition = jnp.array(0.0)
|
|
162
156
|
|
|
163
157
|
# MSE loss wrt to an observed batch
|
|
164
158
|
if self.obs_batch is not None:
|
|
@@ -176,7 +170,7 @@ class LossODE:
|
|
|
176
170
|
)
|
|
177
171
|
)
|
|
178
172
|
else:
|
|
179
|
-
mse_observation_loss = 0
|
|
173
|
+
mse_observation_loss = jnp.array(0.0)
|
|
180
174
|
|
|
181
175
|
# total loss
|
|
182
176
|
total_loss = mse_dyn_loss + mse_initial_condition + mse_observation_loss
|
|
@@ -408,7 +402,7 @@ class SystemLossODE:
|
|
|
408
402
|
for i in self.dynamic_loss_dict.keys():
|
|
409
403
|
# dynamic part
|
|
410
404
|
v_dyn_loss = vmap(
|
|
411
|
-
lambda t, params_dict: self.dynamic_loss_dict[
|
|
405
|
+
lambda t, params_dict, key=i: self.dynamic_loss_dict[key].evaluate(
|
|
412
406
|
t, self.u_dict, params_dict
|
|
413
407
|
),
|
|
414
408
|
vmap_in_axes_t + vmap_in_axes_params,
|