jinns 0.5.0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,9 @@
+"""
+Implements several dynamic losses
+"""
+
 import jax
-from jax import jit, grad, jacrev, jacfwd
+from jax import grad, jacrev
 import jax.numpy as jnp
 from jinns.utils._utils import _get_grid
 from jinns.utils._pinn import PINN
@@ -51,7 +55,7 @@ class FisherKPP(PDENonStatio):
         super().__init__(Tmax, derivatives, eq_params_heterogeneity)
 
     def evaluate(self, t, x, u, params):
-        """
+        r"""
         Evaluate the dynamic loss at :math:`(t,x)`.
 
         Parameters
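Several hunks in this diff make the same one-character fix, turning `"""` docstrings into `r"""` raw strings. These docstrings embed Sphinx `:math:` roles whose LaTeX commands contain backslashes, and without the `r` prefix Python treats sequences like `\m` in `\mathbf` as string escapes. A minimal, self-contained illustration (not jinns code):

```python
# With the r prefix, the backslash in \mathbf reaches Sphinx intact.
def evaluate():
    r"""Evaluate the dynamic loss at :math:`(t,\mathbf{x})`."""

print(evaluate.__doc__)  # Evaluate the dynamic loss at :math:`(t,\mathbf{x})`.
# Without r, Python treats "\m" as an escape sequence: recent CPython
# emits a SyntaxWarning (slated to become an error), so documentation
# tools are no longer guaranteed to see the LaTeX verbatim.
```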
@@ -86,7 +90,7 @@ class FisherKPP(PDENonStatio):
                 - u(t, x, nn_params, eq_params)
                 * (eq_params["r"] - eq_params["g"] * u(t, x, nn_params, eq_params))
             )
-        elif isinstance(u, SPINN):
+        if isinstance(u, SPINN):
             nn_params, eq_params = self.set_stop_gradient(params)
             x_grid = _get_grid(x)
             eq_params = self._eval_heterogeneous_parameters(
@@ -103,6 +107,7 @@ class FisherKPP(PDENonStatio):
                 -eq_params["D"] * lap
                 - u_tx * (eq_params["r"][..., None] - eq_params["g"] * u_tx)
             )
+        raise ValueError("u is not among the recognized types (PINN or SPINN)")
 
 
 class Malthus(ODE):
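The control-flow change above is repeated for every dynamic loss in this diff: `elif` becomes a plain `if`, and an unconditional `raise` is appended. Since each `isinstance` branch ends in `return`, the `elif` was redundant; more importantly, a `u` of an unhandled type used to fall off the end of `evaluate` and silently return `None`, whereas it now fails loudly. A minimal sketch of the pattern with stand-in classes (not the real jinns `PINN`/`SPINN`):

```python
class PINN: ...
class SPINN: ...

def evaluate(u):
    # each branch returns, so no elif chain is needed
    if isinstance(u, PINN):
        return "PINN residual"
    if isinstance(u, SPINN):
        return "SPINN residual"
    # previously: an implicit `return None` for any other type
    raise ValueError("u is not among the recognized types (PINN or SPINN)")

print(evaluate(SPINN()))  # "SPINN residual"; evaluate(42) raises ValueError
```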
@@ -206,7 +211,7 @@ class BurgerEquation(PDENonStatio):
         super().__init__(Tmax, derivatives, eq_params_heterogeneity)
 
     def evaluate(self, t, x, u, params):
-        """
+        r"""
         Evaluate the dynamic loss at :math:`(t,x)`.
 
         Parameters
@@ -243,7 +248,7 @@ class BurgerEquation(PDENonStatio):
                 - eq_params["nu"] * d2u_dx2(t, x)
             )
 
-        elif isinstance(u, SPINN):
+        if isinstance(u, SPINN):
             nn_params, eq_params = self.set_stop_gradient(params)
             x_grid = _get_grid(x)
             eq_params = self._eval_heterogeneous_parameters(
@@ -264,6 +269,7 @@ class BurgerEquation(PDENonStatio):
             du_dx, d2u_dx2 = jax.jvp(du_dx_fun, (x,), (jnp.ones_like(x),))
             # Note that ones_like(x) works because x is Bx1 !
             return du_dt + self.Tmax * (u_tx * du_dx - eq_params["nu"] * d2u_dx2)
+        raise ValueError("u is not among the recognized types (PINN or SPINN)")
 
 
 class GeneralizedLotkaVolterra(ODE):
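As an aside, the `jax.jvp` call kept as context in the hunk above is the standard forward-mode trick for getting a function value and a directional derivative in one pass: pushing a tangent of ones through `du_dx_fun` yields both `du_dx` and `d2u_dx2`, which works here precisely because `x` is a Bx1 column. A standalone check of that identity, using a toy derivative function rather than the jinns one:

```python
import jax
import jax.numpy as jnp

def du_dx_fun(x):
    # analytic first derivative of u(x) = x**3, applied elementwise
    return 3.0 * x**2

x = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)  # Bx1, as in the comment above
du_dx, d2u_dx2 = jax.jvp(du_dx_fun, (x,), (jnp.ones_like(x),))
print(jnp.allclose(du_dx, 3.0 * x**2))   # True: the primal output
print(jnp.allclose(d2u_dx2, 6.0 * x))    # True: the tangent output J @ ones
```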
@@ -416,7 +422,7 @@ class FPENonStatioLoss2D(PDENonStatio):
         super().__init__(Tmax, derivatives, eq_params_heterogeneity)
 
     def evaluate(self, t, x, u, params):
-        """
+        r"""
         Evaluate the dynamic loss at :math:`(t,\mathbf{x})`.
 
         Parameters
@@ -492,7 +498,7 @@ class FPENonStatioLoss2D(PDENonStatio):
 
             return -du_dt + self.Tmax * (-order_1 + order_2)
 
-        elif isinstance(u, SPINN):
+        if isinstance(u, SPINN):
             nn_params, eq_params = self.set_stop_gradient(params)
             x_grid = _get_grid(x)
             eq_params = self._eval_heterogeneous_parameters(
@@ -556,6 +562,7 @@ class FPENonStatioLoss2D(PDENonStatio):
                 -(dau_dx1 + dau_dx2)
                 + (d2su_dx12 + d2su_dx22 + d2su_dx1dx2 + d2su_dx2dx1)
             )
+        raise ValueError("u is not among the recognized types (PINN or SPINN)")
 
 
 class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
@@ -650,13 +657,12 @@ class OU_FPENonStatioLoss2D(FPENonStatioLoss2D):
                     jnp.transpose(self.sigma_mat(t, x, eq_params)),
                 )
             )
-        else:
-            return 0.5 * (
-                jnp.matmul(
-                    self.sigma_mat(t, x, eq_params),
-                    jnp.transpose(self.sigma_mat(t, x, eq_params)),
-                )[i, j]
-            )
+        return 0.5 * (
+            jnp.matmul(
+                self.sigma_mat(t, x, eq_params),
+                jnp.transpose(self.sigma_mat(t, x, eq_params)),
+            )[i, j]
+        )
 
 
 class ConvectionDiffusionNonStatio(FPENonStatioLoss2D):
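The refactor above drops a redundant `else:` (the preceding branch returns) without changing the value computed: entry `(i, j)` of the diffusion tensor `0.5 * sigma @ sigma.T`. A quick numeric sketch of that formula, with an arbitrary 2x2 matrix standing in for `sigma_mat(t, x, eq_params)`:

```python
import jax.numpy as jnp

sigma = jnp.array([[1.0, 0.0],
                   [0.5, 2.0]])  # placeholder for sigma_mat(t, x, eq_params)
half_ssT = 0.5 * jnp.matmul(sigma, jnp.transpose(sigma))
i, j = 0, 1
print(half_ssT[i, j])  # 0.25: the single entry returned by the branch above
```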
@@ -773,7 +779,7 @@ class MassConservation2DStatio(PDEStatio):
         super().__init__(derivatives, eq_params_heterogeneity)
 
     def evaluate(self, x, u_dict, params_dict):
-        """
+        r"""
         Evaluate the dynamic loss at `\mathbf{x}`.
         For stability we implement the dynamic loss in log space.
 
@@ -794,21 +800,20 @@ class MassConservation2DStatio(PDEStatio):
             nn_params, eq_params = self.set_stop_gradient(params_dict)
 
             nn_params = nn_params[self.nn_key]
-            eq_params = eq_params
 
             u = u_dict[self.nn_key]
 
             return _div_rev(u, nn_params, eq_params, x)[..., None]
 
-        elif isinstance(u_dict[self.nn_key], SPINN):
+        if isinstance(u_dict[self.nn_key], SPINN):
             nn_params, eq_params = self.set_stop_gradient(params_dict)
 
             nn_params = nn_params[self.nn_key]
-            eq_params = eq_params
 
             u = u_dict[self.nn_key]
 
             return _div_fwd(u, nn_params, eq_params, x)[..., None]
+        raise ValueError("u is not among the recognized types (PINN or SPINN)")
 
 
 class NavierStokes2DStatio(PDEStatio):
@@ -843,7 +848,7 @@ class NavierStokes2DStatio(PDEStatio):
     def __init__(
         self, u_key, p_key, derivatives="nn_params", eq_params_heterogeneity=None
     ):
-        """
+        r"""
         Parameters
         ----------
         u_key
@@ -899,7 +904,6 @@ class NavierStokes2DStatio(PDEStatio):
 
             u_nn_params = nn_params[self.u_key]
             p_nn_params = nn_params[self.p_key]
-            eq_params = eq_params
 
             u = u_dict[self.u_key]
 
@@ -929,12 +933,11 @@ class NavierStokes2DStatio(PDEStatio):
             # output is 2D
             return jnp.stack([result_x, result_y], axis=-1)
 
-        elif isinstance(u_dict[self.u_key], SPINN):
+        if isinstance(u_dict[self.u_key], SPINN):
            nn_params, eq_params = self.set_stop_gradient(params_dict)
 
            u_nn_params = nn_params[self.u_key]
            p_nn_params = nn_params[self.p_key]
-           eq_params = eq_params
 
            u = u_dict[self.u_key]
 
@@ -968,3 +971,4 @@ class NavierStokes2DStatio(PDEStatio):
 
             # output is 2D
             return jnp.stack([result_x, result_y], axis=-1)
+        raise ValueError("u is not among the recognized types (PINN or SPINN)")
@@ -1,6 +1,8 @@
+"""
+Implements abstract classes for dynamic losses
+"""
+
 import jax
-from jax import jit, grad
-import jax.numpy as jnp
 
 
 class DynamicLoss:
@@ -29,7 +31,13 @@ class DynamicLoss:
             equation solution with as PINN.
         eq_params_heterogeneity
             Default None. A dict with the keys being the same as in eq_params
-            and the value being either None (no heterogeneity) or a function which encodes for the spatio-temporal heterogeneity of the parameter. Such a function must be jittable and take three arguments `t`, `x` and `params["eq_params"]` even if one is not used. Therefore, one can introduce spatio-temporal covariates upon which a particular parameter can depend, e.g. in a GLM fashion. The effect of these covariables can themselves be estimated by being in `eq_params` too.
+            and the value being either None (no heterogeneity) or a function
+            which encodes for the spatio-temporal heterogeneity of the parameter.
+            Such a function must be jittable and take three arguments `t`,
+            `x` and `params["eq_params"]` even if one is not used. Therefore,
+            one can introduce spatio-temporal covariates upon which a particular
+            parameter can depend, e.g. in a GLM fashion. The effect of these
+            covariables can themselves be estimated by being in `eq_params` too.
             A value can be missing, in this case there is no heterogeneity (=None).
             If eq_params_heterogeneity is None this means there is no
             heterogeneity for no parameters.
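The reflowed docstring above documents a precise contract: each heterogeneity entry must be a jittable function of `(t, x, eq_params)`, even when an argument is unused. A hedged sketch of such an entry (the key `"D"`, the covariate weight `"D0"`, and the quadratic spatial form are invented for illustration):

```python
import jax.numpy as jnp

def D_heterogeneity(t, x, eq_params):
    # jittable; takes t, x and the eq_params dict even though t is unused
    return eq_params["D0"] * (1.0 + jnp.sum(x**2, axis=-1))

eq_params_heterogeneity = {
    "D": D_heterogeneity,  # spatially varying diffusion coefficient
    "r": None,             # homogeneous parameter: no heterogeneity
}
```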
@@ -79,15 +87,14 @@ class DynamicLoss:
 
         if self.derivatives == "nn_params":
             return (nn_params, jax.lax.stop_gradient(eq_params))
-        elif self.derivatives == "eq_params":
+        if self.derivatives == "eq_params":
             return (jax.lax.stop_gradient(nn_params), eq_params)
-        elif self.derivatives == "both":
+        if self.derivatives == "both":
             return (nn_params, eq_params)
-        else:
-            return (
-                jax.lax.stop_gradient(nn_params),
-                jax.lax.stop_gradient(eq_params),
-            )
+        return (
+            jax.lax.stop_gradient(nn_params),
+            jax.lax.stop_gradient(eq_params),
+        )
 
 
 class ODE(DynamicLoss):
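For readers unfamiliar with `jax.lax.stop_gradient`, the method simplified above decides which parameter pytree receives gradients when the dynamic loss is differentiated. A self-contained sketch of the mechanism, with a toy loss and toy parameters rather than the jinns API:

```python
import jax
import jax.numpy as jnp

def loss(nn_params, eq_params):
    # mirror of the "nn_params" branch above: eq_params is frozen
    eq_params = jax.lax.stop_gradient(eq_params)
    return jnp.sum(nn_params**2 * eq_params)

g_nn, g_eq = jax.grad(loss, argnums=(0, 1))(jnp.ones(3), 2.0 * jnp.ones(3))
print(g_nn)  # [4. 4. 4.]  gradients flow to nn_params
print(g_eq)  # [0. 0. 0.]  blocked by stop_gradient
```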
jinns/loss/_LossODE.py CHANGED
@@ -1,8 +1,11 @@
+"""
+Main module to implement a ODE loss in jinns
+"""
+
 import jax
 import jax.numpy as jnp
 from jax import vmap
 from jax.tree_util import register_pytree_node_class
-from jinns.data._DataGenerators import ODEBatch
 from jinns.utils._utils import _get_vmap_in_axes_params
 
 
@@ -29,7 +32,6 @@ class LossODE:
         initial_condition=None,
         obs_batch=None,
         obs_slice=None,
-        nn_output_weights=None,
     ):
         r"""
         Parameters
@@ -62,9 +64,6 @@ class LossODE:
             slice object specifying the begininning/ending
             slice of u output(s) that is observed (this is then useful for
             multidim PINN). Default is None.
-        nn_output_weights:
-            Give different weights to the output of the neural network
-            Default is None, ie all nn outputs are
 
         Raises
         ------
@@ -105,11 +104,6 @@ class LossODE:
         batch
             A batch of time points at which to evaluate the loss
         """
-        if isinstance(params, tuple):
-            params_ = params[0]
-        else:
-            params_ = params
-
         temporal_batch = batch.temporal_batch
         vmap_in_axes_t = (0,)
 
@@ -138,7 +132,7 @@
                 * jnp.mean(v_dyn_loss(temporal_batch, params) ** 2, axis=0)
             )
         else:
-            mse_dyn_loss = 0
+            mse_dyn_loss = jnp.array(0.0)
 
         # initial condition
         if self.initial_condition is not None:
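Replacing the Python scalar `0` with `jnp.array(0.0)` (here and in the two hunks below) keeps every loss term a concrete JAX array, so the summed total has a stable type and dtype no matter which terms are disabled. A small sketch of the difference, with illustrative values rather than jinns code:

```python
import jax.numpy as jnp

mse_dyn_loss = jnp.array(0.0)                       # disabled term, new style
mse_initial_condition = jnp.mean(jnp.ones(4) ** 2)  # an active term
total = mse_dyn_loss + mse_initial_condition
print(type(total), total.dtype)  # always a JAX array with a float dtype

# With the old style, a loss built only from disabled terms would be the
# Python int 0 rather than an array, changing the return type of evaluate.
print(type(0 + 0))               # <class 'int'>
```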
@@ -158,7 +152,7 @@
                 ** 2
             )
         else:
-            mse_initial_condition = 0
+            mse_initial_condition = jnp.array(0.0)
 
         # MSE loss wrt to an observed batch
         if self.obs_batch is not None:
@@ -176,7 +170,7 @@
                 )
             )
         else:
-            mse_observation_loss = 0
+            mse_observation_loss = jnp.array(0.0)
 
         # total loss
         total_loss = mse_dyn_loss + mse_initial_condition + mse_observation_loss
@@ -408,7 +402,7 @@ class SystemLossODE:
         for i in self.dynamic_loss_dict.keys():
             # dynamic part
             v_dyn_loss = vmap(
-                lambda t, params_dict: self.dynamic_loss_dict[i].evaluate(
+                lambda t, params_dict, key=i: self.dynamic_loss_dict[key].evaluate(
                     t, self.u_dict, params_dict
                 ),
                 vmap_in_axes_t + vmap_in_axes_params,
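The last hunk fixes a classic closure pitfall: a lambda created inside a loop captures the variable `i`, not its value at that iteration, so every vmapped closure would have evaluated the dynamic loss registered under the final key. Binding the value through a default argument (`key=i`) freezes it per iteration. The pitfall in miniature:

```python
# buggy: all three closures see the final value of i
fns = [lambda: i for i in range(3)]
print([f() for f in fns])        # [2, 2, 2]

# fixed: a default argument captures the value of i at definition time,
# exactly what `key=i` does in the vmapped lambda above
fns = [lambda i=i: i for i in range(3)]
print([f() for f in fns])        # [0, 1, 2]
```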