jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +116 -189
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +176 -513
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +22 -21
  19. jinns/loss/_loss_utils.py +98 -173
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -76
  22. jinns/nn/__init__.py +22 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +434 -0
  25. jinns/nn/_mlp.py +217 -0
  26. jinns/nn/_pinn.py +204 -0
  27. jinns/nn/_ppinn.py +239 -0
  28. jinns/{utils → nn}/_save_load.py +39 -53
  29. jinns/nn/_spinn.py +123 -0
  30. jinns/nn/_spinn_mlp.py +202 -0
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +38 -37
  37. jinns/solver/_rar.py +82 -65
  38. jinns/solver/_solve.py +111 -71
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -5
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns/utils/_hyperpinn.py +0 -420
  51. jinns/utils/_pinn.py +0 -324
  52. jinns/utils/_ppinn.py +0 -227
  53. jinns/utils/_spinn.py +0 -249
  54. jinns-1.2.0.dist-info/RECORD +0 -41
  55. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  56. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  57. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/loss/__init__.py CHANGED
@@ -1,21 +1,18 @@
1
1
  from ._DynamicLossAbstract import DynamicLoss, ODE, PDEStatio, PDENonStatio
2
- from ._LossODE import LossODE, SystemLossODE
3
- from ._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
2
+ from ._LossODE import LossODE
3
+ from ._LossPDE import LossPDEStatio, LossPDENonStatio
4
4
  from ._DynamicLoss import (
5
5
  GeneralizedLotkaVolterra,
6
6
  BurgersEquation,
7
7
  FPENonStatioLoss2D,
8
8
  OU_FPENonStatioLoss2D,
9
9
  FisherKPP,
10
- MassConservation2DStatio,
11
- NavierStokes2DStatio,
10
+ NavierStokesMassConservation2DStatio,
12
11
  )
13
12
  from ._loss_weights import (
14
13
  LossWeightsODE,
15
- LossWeightsODEDict,
16
14
  LossWeightsPDENonStatio,
17
15
  LossWeightsPDEStatio,
18
- LossWeightsPDEDict,
19
16
  )
20
17
 
21
18
  from ._operators import (
@@ -26,3 +23,28 @@ from ._operators import (
26
23
  vectorial_laplacian_fwd,
27
24
  vectorial_laplacian_rev,
28
25
  )
26
+
27
+ __all__ = [
28
+ "DynamicLoss",
29
+ "ODE",
30
+ "PDEStatio",
31
+ "PDENonStatio",
32
+ "LossODE",
33
+ "LossPDEStatio",
34
+ "LossPDENonStatio",
35
+ "GeneralizedLotkaVolterra",
36
+ "BurgersEquation",
37
+ "FPENonStatioLoss2D",
38
+ "OU_FPENonStatioLoss2D",
39
+ "FisherKPP",
40
+ "NavierStokesMassConservation2DStatio",
41
+ "LossWeightsODE",
42
+ "LossWeightsPDEStatio",
43
+ "LossWeightsPDENonStatio",
44
+ "divergence_fwd",
45
+ "divergence_rev",
46
+ "laplacian_fwd",
47
+ "laplacian_rev",
48
+ "vectorial_laplacian_fwd",
49
+ "vectorial_laplacian_rev",
50
+ ]
jinns/loss/_abstract_loss.py ADDED
@@ -0,0 +1,15 @@
1
+ import abc
2
+ from jaxtyping import Array
3
+ import equinox as eqx
4
+
5
+
6
+ class AbstractLoss(eqx.Module):
7
+ """
8
+ Basically just a way to add a __call__ to an eqx.Module.
9
+ The way to go for correct type hints apparently
10
+ https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
11
+ """
12
+
13
+ @abc.abstractmethod
14
+ def __call__(self, *_, **__) -> Array:
15
+ pass
jinns/loss/_boundary_conditions.py CHANGED
@@ -7,31 +7,31 @@ from __future__ import (
7
7
  ) # https://docs.python.org/3/library/typing.html#constant
8
8
 
9
9
  from typing import TYPE_CHECKING, Callable
10
+ from jaxtyping import Array, Float
10
11
  import jax
11
12
  import jax.numpy as jnp
12
13
  from jax import vmap, grad
13
- import equinox as eqx
14
14
  from jinns.utils._utils import get_grid, _subtract_with_check
15
- from jinns.data._Batchs import *
16
- from jinns.utils._pinn import PINN
17
- from jinns.utils._spinn import SPINN
15
+ from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
16
+ from jinns.nn._pinn import PINN
17
+ from jinns.nn._spinn import SPINN
18
18
 
19
19
  if TYPE_CHECKING:
20
- from jinns.utils._types import *
20
+ from jinns.parameters._params import Params
21
+ from jinns.utils._types import BoundaryConditionFun
22
+ from jinns.nn._abstract_pinn import AbstractPINN
21
23
 
22
24
 
23
25
  def _compute_boundary_loss(
24
26
  boundary_condition_type: str,
25
- f: Callable[
26
- [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
27
- ],
27
+ f: BoundaryConditionFun,
28
28
  batch: PDEStatioBatch | PDENonStatioBatch,
29
- u: eqx.Module,
30
- params: AnyParams,
29
+ u: AbstractPINN,
30
+ params: Params[Array],
31
31
  facet: int,
32
32
  dim_to_apply: slice,
33
33
  vmap_in_axes: tuple,
34
- ) -> float:
34
+ ) -> Float[Array, " "]:
35
35
  r"""A generic function that will compute the mini-batch MSE of a
36
36
  boundary condition in the stationary case, resp. non-stationary, given by:
37
37
 
@@ -67,7 +67,7 @@ def _compute_boundary_loss(
67
67
  u
68
68
  a PINN
69
69
  params
70
- Params or ParamsDict
70
+ Params
71
71
  facet
72
72
  An integer which represents the id of the facet which is currently
73
73
  considered (in the order provided by the DataGenerator which is fixed)
@@ -96,15 +96,15 @@ def _compute_boundary_loss(
96
96
 
97
97
  def boundary_dirichlet(
98
98
  f: Callable[
99
- [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
99
+ [Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
100
100
  ],
101
101
  batch: PDEStatioBatch | PDENonStatioBatch,
102
- u: eqx.Module,
103
- params: Params | ParamsDict,
102
+ u: AbstractPINN,
103
+ params: Params[Array],
104
104
  facet: int,
105
105
  dim_to_apply: slice,
106
106
  vmap_in_axes: tuple,
107
- ) -> float:
107
+ ) -> Float[Array, " "]:
108
108
  r"""
109
109
  This omega boundary condition enforces a solution that is equal to `f`
110
110
  at `times_batch` x `omega_border` (non stationary case) or at `omega_border`
@@ -135,6 +135,7 @@ def boundary_dirichlet(
135
135
  vmap_in_axes
136
136
  A tuple object which specifies the in_axes of the vmapping
137
137
  """
138
+ assert batch.border_batch is not None
138
139
  batch_array = batch.border_batch
139
140
  batch_array = batch_array[..., facet]
140
141
 
@@ -168,15 +169,15 @@ def boundary_dirichlet(
168
169
 
169
170
  def boundary_neumann(
170
171
  f: Callable[
171
- [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
172
+ [Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
172
173
  ],
173
174
  batch: PDEStatioBatch | PDENonStatioBatch,
174
- u: eqx.Module,
175
- params: Params | ParamsDict,
175
+ u: AbstractPINN,
176
+ params: Params[Array],
176
177
  facet: int,
177
178
  dim_to_apply: slice,
178
179
  vmap_in_axes: tuple,
179
- ) -> float:
180
+ ) -> Float[Array, " "]:
180
181
  r"""
181
182
  This omega boundary condition enforces a solution where $\nabla u\cdot
182
183
  n$ is equal to `f` at the cartesian product of `time_batch` x `omega
@@ -208,6 +209,7 @@ def boundary_neumann(
208
209
  vmap_in_axes
209
210
  A tuple object which specifies the in_axes of the vmapping
210
211
  """
212
+ assert batch.border_batch is not None
211
213
  batch_array = batch.border_batch
212
214
  batch_array = batch_array[..., facet]
213
215
 
@@ -223,7 +225,6 @@ def boundary_neumann(
223
225
  n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
224
226
 
225
227
  if isinstance(u, PINN):
226
-
227
228
  u_ = lambda inputs, params: jnp.squeeze(u(inputs, params)[dim_to_apply])
228
229
 
229
230
  if u.eq_type == "statio_PDE":
jinns/loss/_loss_utils.py CHANGED
@@ -6,50 +6,53 @@ from __future__ import (
6
6
  annotations,
7
7
  ) # https://docs.python.org/3/library/typing.html#constant
8
8
 
9
- from typing import TYPE_CHECKING, Callable, Dict
9
+ from typing import TYPE_CHECKING, Callable, TypeGuard
10
+ from types import EllipsisType
10
11
  import jax
11
12
  import jax.numpy as jnp
12
13
  from jax import vmap
13
- import equinox as eqx
14
- from jaxtyping import Float, Array, PyTree
14
+ from jaxtyping import Float, Array
15
15
 
16
16
  from jinns.loss._boundary_conditions import (
17
17
  _compute_boundary_loss,
18
18
  )
19
19
  from jinns.utils._utils import _subtract_with_check, get_grid
20
- from jinns.data._DataGenerators import append_obs_batch, make_cartesian_product
20
+ from jinns.data._utils import make_cartesian_product
21
21
  from jinns.parameters._params import _get_vmap_in_axes_params
22
- from jinns.utils._pinn import PINN
23
- from jinns.utils._spinn import SPINN
24
- from jinns.utils._hyperpinn import HYPERPINN
25
- from jinns.data._Batchs import *
26
- from jinns.parameters._params import Params, ParamsDict
22
+ from jinns.nn._pinn import PINN
23
+ from jinns.nn._spinn import SPINN
24
+ from jinns.nn._hyperpinn import HyperPINN
25
+ from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
26
+ from jinns.parameters._params import Params
27
27
 
28
28
  if TYPE_CHECKING:
29
- from jinns.utils._types import *
29
+ from jinns.utils._types import BoundaryConditionFun
30
+ from jinns.nn._abstract_pinn import AbstractPINN
30
31
 
31
32
 
32
33
  def dynamic_loss_apply(
33
- dyn_loss: DynamicLoss,
34
- u: eqx.Module,
34
+ dyn_loss: Callable,
35
+ u: AbstractPINN,
35
36
  batch: (
36
- Float[Array, "batch_size 1"]
37
- | Float[Array, "batch_size dim"]
38
- | Float[Array, "batch_size 1+dim"]
37
+ Float[Array, " batch_size 1"]
38
+ | Float[Array, " batch_size dim"]
39
+ | Float[Array, " batch_size 1+dim"]
39
40
  ),
40
- params: Params | ParamsDict,
41
- vmap_axes: tuple[int | None, ...],
42
- loss_weight: float | Float[Array, "dyn_loss_dimension"],
43
- u_type: PINN | HYPERPINN | None = None,
44
- ) -> float:
41
+ params: Params[Array],
42
+ vmap_axes: tuple[int, Params[int | None] | None],
43
+ loss_weight: float | Float[Array, " dyn_loss_dimension"],
44
+ u_type: PINN | HyperPINN | None = None,
45
+ ) -> Float[Array, " "]:
45
46
  """
46
47
  Sometimes when u is a lambda function a or dict we do not have access to
47
48
  its type here, hence the last argument
48
49
  """
49
- if u_type == PINN or u_type == HYPERPINN or isinstance(u, (PINN, HYPERPINN)):
50
+ if u_type == PINN or u_type == HyperPINN or isinstance(u, (PINN, HyperPINN)):
50
51
  v_dyn_loss = vmap(
51
52
  lambda batch, params: dyn_loss(
52
- batch, u, params # we must place the params at the end
53
+ batch,
54
+ u,
55
+ params, # we must place the params at the end
53
56
  ),
54
57
  vmap_axes,
55
58
  0,
@@ -66,36 +69,38 @@ def dynamic_loss_apply(
66
69
 
67
70
 
68
71
  def normalization_loss_apply(
69
- u: eqx.Module,
72
+ u: AbstractPINN,
70
73
  batches: (
71
- tuple[Float[Array, "nb_norm_samples dim"]]
74
+ tuple[Float[Array, " nb_norm_samples dim"]]
72
75
  | tuple[
73
- Float[Array, "nb_norm_time_slices 1"], Float[Array, "nb_norm_samples dim"]
76
+ Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
74
77
  ]
75
78
  ),
76
- params: Params | ParamsDict,
77
- vmap_axes_params: tuple[int | None, ...],
78
- int_length: int,
79
+ params: Params[Array],
80
+ vmap_axes_params: tuple[Params[int | None] | None],
81
+ norm_weights: Float[Array, " nb_norm_samples"],
79
82
  loss_weight: float,
80
- ) -> float:
83
+ ) -> Float[Array, " "]:
81
84
  """
82
85
  Note the squeezing on each result. We expect unidimensional *PINN since
83
86
  they represent probability distributions
84
87
  """
85
- if isinstance(u, (PINN, HYPERPINN)):
88
+ if isinstance(u, (PINN, HyperPINN)):
86
89
  if len(batches) == 1:
87
90
  v_u = vmap(
88
- lambda b: u(b)[u.slice_solution],
91
+ lambda *b: u(*b)[u.slice_solution],
89
92
  (0,) + vmap_axes_params,
90
93
  0,
91
94
  )
92
95
  res = v_u(*batches, params)
96
+ assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
97
+ # Monte-Carlo integration using importance sampling
93
98
  mse_norm_loss = loss_weight * (
94
- jnp.abs(jnp.mean(res.squeeze()) * int_length - 1) ** 2
99
+ jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
95
100
  )
96
101
  else:
97
102
  # NOTE this cartesian product is costly
98
- batches = make_cartesian_product(
103
+ batch_cart_prod = make_cartesian_product(
99
104
  batches[0],
100
105
  batches[1],
101
106
  ).reshape(batches[0].shape[0], batches[1].shape[0], -1)
@@ -106,21 +111,24 @@ def normalization_loss_apply(
106
111
  ),
107
112
  in_axes=(0,) + vmap_axes_params,
108
113
  )
109
- res = v_u(batches, params)
110
- # Over all the times t, we perform a integration
114
+ res = v_u(batch_cart_prod, params)
115
+ assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
116
+ # For all times t, we perform an integration. Then we average the
117
+ # losses over times.
111
118
  mse_norm_loss = loss_weight * jnp.mean(
112
- jnp.abs(jnp.mean(res.squeeze(), axis=-1) * int_length - 1) ** 2
119
+ jnp.abs(jnp.mean(res.squeeze() * norm_weights, axis=-1) - 1) ** 2
113
120
  )
114
121
  elif isinstance(u, SPINN):
115
122
  if len(batches) == 1:
116
123
  res = u(*batches, params)
124
+ assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
117
125
  mse_norm_loss = (
118
126
  loss_weight
119
127
  * jnp.abs(
120
128
  jnp.mean(
121
129
  res.squeeze(),
122
130
  )
123
- * int_length
131
+ * norm_weights
124
132
  - 1
125
133
  )
126
134
  ** 2
@@ -134,14 +142,15 @@ def normalization_loss_apply(
134
142
  ),
135
143
  params,
136
144
  )
145
+ assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
137
146
  # the outer mean() below is for the times stamps
138
147
  mse_norm_loss = loss_weight * jnp.mean(
139
148
  jnp.abs(
140
149
  jnp.mean(
141
150
  res.squeeze(),
142
- axis=(d + 1 for d in range(res.ndim - 2)),
151
+ axis=list(d + 1 for d in range(res.ndim - 2)),
143
152
  )
144
- * int_length
153
+ * norm_weights
145
154
  - 1
146
155
  )
147
156
  ** 2
@@ -153,18 +162,34 @@ def normalization_loss_apply(
153
162
 
154
163
 
155
164
  def boundary_condition_apply(
156
- u: eqx.Module,
165
+ u: AbstractPINN,
157
166
  batch: PDEStatioBatch | PDENonStatioBatch,
158
- params: Params | ParamsDict,
159
- omega_boundary_fun: Callable,
160
- omega_boundary_condition: str,
161
- omega_boundary_dim: int,
162
- loss_weight: float | Float[Array, "boundary_cond_dim"],
163
- ) -> float:
164
-
167
+ params: Params[Array],
168
+ omega_boundary_fun: BoundaryConditionFun | dict[str, BoundaryConditionFun],
169
+ omega_boundary_condition: str | dict[str, str],
170
+ omega_boundary_dim: slice | dict[str, slice],
171
+ loss_weight: float | Float[Array, " boundary_cond_dim"],
172
+ ) -> Float[Array, " "]:
173
+ assert batch.border_batch is not None
165
174
  vmap_in_axes = (0,) + _get_vmap_in_axes_params(batch.param_batch_dict, params)
166
175
 
167
- if isinstance(omega_boundary_fun, dict):
176
+ def _check_tuple_of_dict(
177
+ val,
178
+ ) -> TypeGuard[
179
+ tuple[
180
+ dict[str, BoundaryConditionFun],
181
+ dict[str, BoundaryConditionFun],
182
+ dict[str, BoundaryConditionFun],
183
+ ]
184
+ ]:
185
+ return all(isinstance(x, dict) for x in val)
186
+
187
+ omega_boundary_dicts = (
188
+ omega_boundary_condition,
189
+ omega_boundary_fun,
190
+ omega_boundary_dim,
191
+ )
192
+ if _check_tuple_of_dict(omega_boundary_dicts):
168
193
  # We must create the facet tree dictionary as we do not have the
169
194
  # enumerate from the for loop to pass the id integer
170
195
  if batch.border_batch.shape[-1] == 2:
@@ -186,10 +211,10 @@ def boundary_condition_apply(
186
211
  )
187
212
  )
188
213
  ),
189
- omega_boundary_condition,
190
- omega_boundary_fun,
214
+ omega_boundary_dicts[0], # omega_boundary_condition,
215
+ omega_boundary_dicts[1], # omega_boundary_fun,
191
216
  facet_tree,
192
- omega_boundary_dim,
217
+ omega_boundary_dicts[2], # omega_boundary_dim,
193
218
  is_leaf=lambda x: x is None,
194
219
  ) # when exploring leaves with None value (no condition) the returned
195
220
  # mse is None and we get rid of the None leaves of b_losses_by_facet
@@ -202,13 +227,13 @@ def boundary_condition_apply(
202
227
  lambda fa: jnp.mean(
203
228
  loss_weight
204
229
  * _compute_boundary_loss(
205
- omega_boundary_condition,
206
- omega_boundary_fun,
230
+ omega_boundary_dicts[0], # type: ignore -> need TypeIs from 3.13
231
+ omega_boundary_dicts[1], # type: ignore -> need TypeIs from 3.13
207
232
  batch,
208
233
  u,
209
234
  params,
210
235
  fa,
211
- omega_boundary_dim,
236
+ omega_boundary_dicts[2], # type: ignore -> need TypeIs from 3.13
212
237
  vmap_in_axes,
213
238
  )
214
239
  ),
@@ -221,22 +246,21 @@ def boundary_condition_apply(
221
246
 
222
247
 
223
248
  def observations_loss_apply(
224
- u: eqx.Module,
225
- batches: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
226
- params: Params | ParamsDict,
227
- vmap_axes: tuple[int | None, ...],
228
- observed_values: Float[Array, "batch_size observation_dim"],
229
- loss_weight: float | Float[Array, "observation_dim"],
230
- obs_slice: slice,
231
- ) -> float:
232
- # TODO implement for SPINN
233
- if isinstance(u, (PINN, HYPERPINN)):
249
+ u: AbstractPINN,
250
+ batch: Float[Array, " obs_batch_size input_dim"],
251
+ params: Params[Array],
252
+ vmap_axes: tuple[int, Params[int | None] | None],
253
+ observed_values: Float[Array, " obs_batch_size observation_dim"],
254
+ loss_weight: float | Float[Array, " observation_dim"],
255
+ obs_slice: EllipsisType | slice | None,
256
+ ) -> Float[Array, " "]:
257
+ if isinstance(u, (PINN, HyperPINN)):
234
258
  v_u = vmap(
235
259
  lambda *args: u(*args)[u.slice_solution],
236
260
  vmap_axes,
237
261
  0,
238
262
  )
239
- val = v_u(*batches, params)[:, obs_slice]
263
+ val = v_u(batch, params)[:, obs_slice]
240
264
  mse_observation_loss = jnp.mean(
241
265
  jnp.sum(
242
266
  loss_weight
@@ -255,16 +279,17 @@ def observations_loss_apply(
255
279
 
256
280
 
257
281
  def initial_condition_apply(
258
- u: eqx.Module,
259
- omega_batch: Float[Array, "dimension"],
260
- params: Params | ParamsDict,
261
- vmap_axes: tuple[int | None, ...],
282
+ u: AbstractPINN,
283
+ omega_batch: Float[Array, " dimension"],
284
+ params: Params[Array],
285
+ vmap_axes: tuple[int, Params[int | None] | None],
262
286
  initial_condition_fun: Callable,
263
- loss_weight: float | Float[Array, "initial_condition_dimension"],
264
- ) -> float:
287
+ t0: Float[Array, " 1"],
288
+ loss_weight: float | Float[Array, " initial_condition_dimension"],
289
+ ) -> Float[Array, " "]:
265
290
  n = omega_batch.shape[0]
266
- t0_omega_batch = jnp.concatenate([jnp.zeros((n, 1)), omega_batch], axis=1)
267
- if isinstance(u, (PINN, HYPERPINN)):
291
+ t0_omega_batch = jnp.concatenate([t0 * jnp.ones((n, 1)), omega_batch], axis=1)
292
+ if isinstance(u, (PINN, HyperPINN)):
268
293
  v_u_t0 = vmap(
269
294
  lambda t0_x, params: _subtract_with_check(
270
295
  initial_condition_fun(t0_x[1:]),
@@ -296,103 +321,3 @@ def initial_condition_apply(
296
321
  else:
297
322
  raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
298
323
  return mse_initial_condition
299
-
300
-
301
- def constraints_system_loss_apply(
302
- u_constraints_dict: Dict,
303
- batch: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
304
- params_dict: ParamsDict,
305
- loss_weights: Dict[str, float | Array],
306
- loss_weight_struct: PyTree,
307
- ):
308
- """
309
- Same function for systemlossODE and systemlossPDE!
310
- """
311
- # Transpose so we have each u_dict as outer structure and the
312
- # associated loss_weight as inner structure
313
- loss_weights_T = jax.tree_util.tree_transpose(
314
- jax.tree_util.tree_structure(loss_weight_struct),
315
- jax.tree_util.tree_structure(loss_weights["initial_condition"]),
316
- loss_weights,
317
- )
318
-
319
- if isinstance(params_dict.nn_params, dict):
320
-
321
- def apply_u_constraint(
322
- u_constraint, nn_params, eq_params, loss_weights_for_u, obs_batch_u
323
- ):
324
- res_dict_for_u = u_constraint.evaluate(
325
- Params(
326
- nn_params=nn_params,
327
- eq_params=eq_params,
328
- ),
329
- append_obs_batch(batch, obs_batch_u),
330
- )[1]
331
- res_dict_ponderated = jax.tree_util.tree_map(
332
- lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
333
- )
334
- return res_dict_ponderated
335
-
336
- # Note in the case of multiple PINNs, batch.obs_batch_dict is a dict
337
- # with keys corresponding to the PINN and value correspondinf to an
338
- # original obs_batch_dict. Hence the tree mapping also interates over
339
- # batch.obs_batch_dict
340
- res_dict = jax.tree_util.tree_map(
341
- apply_u_constraint,
342
- u_constraints_dict,
343
- params_dict.nn_params,
344
- (
345
- params_dict.eq_params
346
- if params_dict.eq_params.keys() == params_dict.nn_params.keys()
347
- else {k: params_dict.eq_params for k in params_dict.nn_params.keys()}
348
- ), # this manipulation is needed since we authorize eq_params not to have the same structure as nn_params in ParamsDict
349
- loss_weights_T,
350
- batch.obs_batch_dict,
351
- is_leaf=lambda x: (
352
- not isinstance(x, dict) # to not traverse more than the first
353
- # outer dict of the pytrees passed to the function. This will
354
- # work because u_constraints_dict is a dict of Losses, and it
355
- # thus stops the traversing of other dict too
356
- ),
357
- )
358
- # TODO try to get rid of this condition?
359
- else:
360
-
361
- def apply_u_constraint(u_constraint, loss_weights_for_u, obs_batch_u):
362
- res_dict_for_u = u_constraint.evaluate(
363
- params_dict,
364
- append_obs_batch(batch, obs_batch_u),
365
- )[1]
366
- res_dict_ponderated = jax.tree_util.tree_map(
367
- lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
368
- )
369
- return res_dict_ponderated
370
-
371
- res_dict = jax.tree_util.tree_map(
372
- apply_u_constraint, u_constraints_dict, loss_weights_T, batch.obs_batch_dict
373
- )
374
-
375
- # Transpose back so we have mses as outer structures and their values
376
- # for each u_dict as inner structures. The tree_leaves transforms the
377
- # inner structure into a list so we can catch is as leaf it the
378
- # tree_map below
379
- res_dict = jax.tree_util.tree_transpose(
380
- jax.tree_util.tree_structure(
381
- jax.tree_util.tree_leaves(loss_weights["initial_condition"])
382
- ),
383
- jax.tree_util.tree_structure(loss_weight_struct),
384
- res_dict,
385
- )
386
- # For each mse, sum their values on each u_dict
387
- res_dict = jax.tree_util.tree_map(
388
- lambda mse: jax.tree_util.tree_reduce(
389
- lambda x, y: x + y, jax.tree_util.tree_leaves(mse)
390
- ),
391
- res_dict,
392
- is_leaf=lambda x: isinstance(x, list),
393
- )
394
- # Total loss
395
- total_loss = jax.tree_util.tree_reduce(
396
- lambda x, y: x + y, jax.tree_util.tree_leaves(res_dict)
397
- )
398
- return total_loss, res_dict
jinns/loss/_loss_weights.py CHANGED
@@ -2,58 +2,26 @@
2
2
  Formalize the loss weights data structure
3
3
  """
4
4
 
5
- from typing import Dict
6
5
  from jaxtyping import Array, Float
7
6
  import equinox as eqx
8
7
 
9
8
 
10
9
  class LossWeightsODE(eqx.Module):
11
-
12
- dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
13
- initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
14
- observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
15
-
16
-
17
- class LossWeightsODEDict(eqx.Module):
18
-
19
- dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=None)
20
- initial_condition: Dict[str, Array | Float | None] = eqx.field(
21
- kw_only=True, default=None
22
- )
23
- observations: Dict[str, Array | Float | None] = eqx.field(
24
- kw_only=True, default=None
25
- )
10
+ dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
11
+ initial_condition: Array | Float = eqx.field(kw_only=True, default=0.0)
12
+ observations: Array | Float = eqx.field(kw_only=True, default=0.0)
26
13
 
27
14
 
28
15
  class LossWeightsPDEStatio(eqx.Module):
29
-
30
- dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
31
- norm_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
32
- boundary_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
33
- observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
16
+ dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
17
+ norm_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
18
+ boundary_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
19
+ observations: Array | Float = eqx.field(kw_only=True, default=0.0)
34
20
 
35
21
 
36
22
  class LossWeightsPDENonStatio(eqx.Module):
37
-
38
- dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
39
- norm_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
40
- boundary_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
41
- observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
42
- initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
43
-
44
-
45
- class LossWeightsPDEDict(eqx.Module):
46
- """
47
- Only one type of LossWeights data structure for the SystemLossPDE:
48
- Include the initial condition always for the code to be more generic
49
- """
50
-
51
- dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
52
- norm_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
53
- boundary_loss: Dict[str, Array | Float | None] = eqx.field(
54
- kw_only=True, default=1.0
55
- )
56
- observations: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
57
- initial_condition: Dict[str, Array | Float | None] = eqx.field(
58
- kw_only=True, default=1.0
59
- )
23
+ dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
24
+ norm_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
25
+ boundary_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
26
+ observations: Array | Float = eqx.field(kw_only=True, default=0.0)
27
+ initial_condition: Array | Float = eqx.field(kw_only=True, default=0.0)