jinns 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +146 -520
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_utils.py +78 -159
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -74
  22. jinns/nn/__init__.py +15 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +94 -57
  25. jinns/nn/_mlp.py +50 -25
  26. jinns/nn/_pinn.py +33 -19
  27. jinns/nn/_ppinn.py +70 -34
  28. jinns/nn/_save_load.py +21 -51
  29. jinns/nn/_spinn.py +33 -16
  30. jinns/nn/_spinn_mlp.py +28 -22
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +35 -34
  37. jinns/solver/_rar.py +80 -63
  38. jinns/solver/_solve.py +89 -63
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -0
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns-1.3.0.dist-info/RECORD +0 -44
  51. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  52. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  53. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/loss/__init__.py CHANGED
@@ -1,21 +1,18 @@
1
1
  from ._DynamicLossAbstract import DynamicLoss, ODE, PDEStatio, PDENonStatio
2
- from ._LossODE import LossODE, SystemLossODE
3
- from ._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
2
+ from ._LossODE import LossODE
3
+ from ._LossPDE import LossPDEStatio, LossPDENonStatio
4
4
  from ._DynamicLoss import (
5
5
  GeneralizedLotkaVolterra,
6
6
  BurgersEquation,
7
7
  FPENonStatioLoss2D,
8
8
  OU_FPENonStatioLoss2D,
9
9
  FisherKPP,
10
- MassConservation2DStatio,
11
- NavierStokes2DStatio,
10
+ NavierStokesMassConservation2DStatio,
12
11
  )
13
12
  from ._loss_weights import (
14
13
  LossWeightsODE,
15
- LossWeightsODEDict,
16
14
  LossWeightsPDENonStatio,
17
15
  LossWeightsPDEStatio,
18
- LossWeightsPDEDict,
19
16
  )
20
17
 
21
18
  from ._operators import (
@@ -26,3 +23,28 @@ from ._operators import (
26
23
  vectorial_laplacian_fwd,
27
24
  vectorial_laplacian_rev,
28
25
  )
26
+
27
+ __all__ = [
28
+ "DynamicLoss",
29
+ "ODE",
30
+ "PDEStatio",
31
+ "PDENonStatio",
32
+ "LossODE",
33
+ "LossPDEStatio",
34
+ "LossPDENonStatio",
35
+ "GeneralizedLotkaVolterra",
36
+ "BurgersEquation",
37
+ "FPENonStatioLoss2D",
38
+ "OU_FPENonStatioLoss2D",
39
+ "FisherKPP",
40
+ "NavierStokesMassConservation2DStatio",
41
+ "LossWeightsODE",
42
+ "LossWeightsPDEStatio",
43
+ "LossWeightsPDENonStatio",
44
+ "divergence_fwd",
45
+ "divergence_rev",
46
+ "laplacian_fwd",
47
+ "laplacian_rev",
48
+ "vectorial_laplacian_fwd",
49
+ "vectorial_laplacian_rev",
50
+ ]
@@ -0,0 +1,15 @@
1
+ import abc
2
+ from jaxtyping import Array
3
+ import equinox as eqx
4
+
5
+
6
+ class AbstractLoss(eqx.Module):
7
+ """
8
+ Basically just a way to add a __call__ to an eqx.Module.
9
+ The way to go for correct type hints apparently
10
+ https://github.com/patrick-kidger/equinox/issues/1002 + https://docs.kidger.site/equinox/pattern/
11
+ """
12
+
13
+ @abc.abstractmethod
14
+ def __call__(self, *_, **__) -> Array:
15
+ pass
@@ -7,31 +7,31 @@ from __future__ import (
7
7
  ) # https://docs.python.org/3/library/typing.html#constant
8
8
 
9
9
  from typing import TYPE_CHECKING, Callable
10
+ from jaxtyping import Array, Float
10
11
  import jax
11
12
  import jax.numpy as jnp
12
13
  from jax import vmap, grad
13
- import equinox as eqx
14
14
  from jinns.utils._utils import get_grid, _subtract_with_check
15
- from jinns.data._Batchs import *
15
+ from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
16
16
  from jinns.nn._pinn import PINN
17
17
  from jinns.nn._spinn import SPINN
18
18
 
19
19
  if TYPE_CHECKING:
20
- from jinns.utils._types import *
20
+ from jinns.parameters._params import Params
21
+ from jinns.utils._types import BoundaryConditionFun
22
+ from jinns.nn._abstract_pinn import AbstractPINN
21
23
 
22
24
 
23
25
  def _compute_boundary_loss(
24
26
  boundary_condition_type: str,
25
- f: Callable[
26
- [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
27
- ],
27
+ f: BoundaryConditionFun,
28
28
  batch: PDEStatioBatch | PDENonStatioBatch,
29
- u: eqx.Module,
30
- params: AnyParams,
29
+ u: AbstractPINN,
30
+ params: Params[Array],
31
31
  facet: int,
32
32
  dim_to_apply: slice,
33
33
  vmap_in_axes: tuple,
34
- ) -> float:
34
+ ) -> Float[Array, " "]:
35
35
  r"""A generic function that will compute the mini-batch MSE of a
36
36
  boundary condition in the stationary case, resp. non-stationary, given by:
37
37
 
@@ -67,7 +67,7 @@ def _compute_boundary_loss(
67
67
  u
68
68
  a PINN
69
69
  params
70
- Params or ParamsDict
70
+ Params
71
71
  facet
72
72
  An integer which represents the id of the facet which is currently
73
73
  considered (in the order provided by the DataGenerator which is fixed)
@@ -96,15 +96,15 @@ def _compute_boundary_loss(
96
96
 
97
97
  def boundary_dirichlet(
98
98
  f: Callable[
99
- [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
99
+ [Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
100
100
  ],
101
101
  batch: PDEStatioBatch | PDENonStatioBatch,
102
- u: eqx.Module,
103
- params: Params | ParamsDict,
102
+ u: AbstractPINN,
103
+ params: Params[Array],
104
104
  facet: int,
105
105
  dim_to_apply: slice,
106
106
  vmap_in_axes: tuple,
107
- ) -> float:
107
+ ) -> Float[Array, " "]:
108
108
  r"""
109
109
  This omega boundary condition enforces a solution that is equal to `f`
110
110
  at `times_batch` x `omega_border` (non stationary case) or at `omega_border`
@@ -135,6 +135,7 @@ def boundary_dirichlet(
135
135
  vmap_in_axes
136
136
  A tuple object which specifies the in_axes of the vmapping
137
137
  """
138
+ assert batch.border_batch is not None
138
139
  batch_array = batch.border_batch
139
140
  batch_array = batch_array[..., facet]
140
141
 
@@ -168,15 +169,15 @@ def boundary_dirichlet(
168
169
 
169
170
  def boundary_neumann(
170
171
  f: Callable[
171
- [Float[Array, "dim"] | Float[Array, "dim + 1"]], Float[Array, "dim_solution"]
172
+ [Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
172
173
  ],
173
174
  batch: PDEStatioBatch | PDENonStatioBatch,
174
- u: eqx.Module,
175
- params: Params | ParamsDict,
175
+ u: AbstractPINN,
176
+ params: Params[Array],
176
177
  facet: int,
177
178
  dim_to_apply: slice,
178
179
  vmap_in_axes: tuple,
179
- ) -> float:
180
+ ) -> Float[Array, " "]:
180
181
  r"""
181
182
  This omega boundary condition enforces a solution where $\nabla u\cdot
182
183
  n$ is equal to `f` at the cartesian product of `time_batch` x `omega
@@ -208,6 +209,7 @@ def boundary_neumann(
208
209
  vmap_in_axes
209
210
  A tuple object which specifies the in_axes of the vmapping
210
211
  """
212
+ assert batch.border_batch is not None
211
213
  batch_array = batch.border_batch
212
214
  batch_array = batch_array[..., facet]
213
215
 
@@ -223,7 +225,6 @@ def boundary_neumann(
223
225
  n = jnp.array([[-1, 1, 0, 0], [0, 0, -1, 1]])
224
226
 
225
227
  if isinstance(u, PINN):
226
-
227
228
  u_ = lambda inputs, params: jnp.squeeze(u(inputs, params)[dim_to_apply])
228
229
 
229
230
  if u.eq_type == "statio_PDE":
jinns/loss/_loss_utils.py CHANGED
@@ -6,42 +6,43 @@ from __future__ import (
6
6
  annotations,
7
7
  ) # https://docs.python.org/3/library/typing.html#constant
8
8
 
9
- from typing import TYPE_CHECKING, Callable, Dict
9
+ from typing import TYPE_CHECKING, Callable, TypeGuard
10
+ from types import EllipsisType
10
11
  import jax
11
12
  import jax.numpy as jnp
12
13
  from jax import vmap
13
- import equinox as eqx
14
- from jaxtyping import Float, Array, PyTree
14
+ from jaxtyping import Float, Array
15
15
 
16
16
  from jinns.loss._boundary_conditions import (
17
17
  _compute_boundary_loss,
18
18
  )
19
19
  from jinns.utils._utils import _subtract_with_check, get_grid
20
- from jinns.data._DataGenerators import append_obs_batch, make_cartesian_product
20
+ from jinns.data._utils import make_cartesian_product
21
21
  from jinns.parameters._params import _get_vmap_in_axes_params
22
22
  from jinns.nn._pinn import PINN
23
23
  from jinns.nn._spinn import SPINN
24
24
  from jinns.nn._hyperpinn import HyperPINN
25
- from jinns.data._Batchs import *
26
- from jinns.parameters._params import Params, ParamsDict
25
+ from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
26
+ from jinns.parameters._params import Params
27
27
 
28
28
  if TYPE_CHECKING:
29
- from jinns.utils._types import *
29
+ from jinns.utils._types import BoundaryConditionFun
30
+ from jinns.nn._abstract_pinn import AbstractPINN
30
31
 
31
32
 
32
33
  def dynamic_loss_apply(
33
- dyn_loss: DynamicLoss,
34
- u: eqx.Module,
34
+ dyn_loss: Callable,
35
+ u: AbstractPINN,
35
36
  batch: (
36
- Float[Array, "batch_size 1"]
37
- | Float[Array, "batch_size dim"]
38
- | Float[Array, "batch_size 1+dim"]
37
+ Float[Array, " batch_size 1"]
38
+ | Float[Array, " batch_size dim"]
39
+ | Float[Array, " batch_size 1+dim"]
39
40
  ),
40
- params: Params | ParamsDict,
41
- vmap_axes: tuple[int | None, ...],
42
- loss_weight: float | Float[Array, "dyn_loss_dimension"],
41
+ params: Params[Array],
42
+ vmap_axes: tuple[int, Params[int | None] | None],
43
+ loss_weight: float | Float[Array, " dyn_loss_dimension"],
43
44
  u_type: PINN | HyperPINN | None = None,
44
- ) -> float:
45
+ ) -> Float[Array, " "]:
45
46
  """
46
47
  Sometimes when u is a lambda function a or dict we do not have access to
47
48
  its type here, hence the last argument
@@ -49,7 +50,9 @@ def dynamic_loss_apply(
49
50
  if u_type == PINN or u_type == HyperPINN or isinstance(u, (PINN, HyperPINN)):
50
51
  v_dyn_loss = vmap(
51
52
  lambda batch, params: dyn_loss(
52
- batch, u, params # we must place the params at the end
53
+ batch,
54
+ u,
55
+ params, # we must place the params at the end
53
56
  ),
54
57
  vmap_axes,
55
58
  0,
@@ -66,18 +69,18 @@ def dynamic_loss_apply(
66
69
 
67
70
 
68
71
  def normalization_loss_apply(
69
- u: eqx.Module,
72
+ u: AbstractPINN,
70
73
  batches: (
71
- tuple[Float[Array, "nb_norm_samples dim"]]
74
+ tuple[Float[Array, " nb_norm_samples dim"]]
72
75
  | tuple[
73
- Float[Array, "nb_norm_time_slices 1"], Float[Array, "nb_norm_samples dim"]
76
+ Float[Array, " nb_norm_time_slices 1"], Float[Array, " nb_norm_samples dim"]
74
77
  ]
75
78
  ),
76
- params: Params | ParamsDict,
77
- vmap_axes_params: tuple[int | None, ...],
78
- norm_weights: Float[Array, "nb_norm_samples"],
79
+ params: Params[Array],
80
+ vmap_axes_params: tuple[Params[int | None] | None],
81
+ norm_weights: Float[Array, " nb_norm_samples"],
79
82
  loss_weight: float,
80
- ) -> float:
83
+ ) -> Float[Array, " "]:
81
84
  """
82
85
  Note the squeezing on each result. We expect unidimensional *PINN since
83
86
  they represent probability distributions
@@ -97,7 +100,7 @@ def normalization_loss_apply(
97
100
  )
98
101
  else:
99
102
  # NOTE this cartesian product is costly
100
- batches = make_cartesian_product(
103
+ batch_cart_prod = make_cartesian_product(
101
104
  batches[0],
102
105
  batches[1],
103
106
  ).reshape(batches[0].shape[0], batches[1].shape[0], -1)
@@ -108,7 +111,7 @@ def normalization_loss_apply(
108
111
  ),
109
112
  in_axes=(0,) + vmap_axes_params,
110
113
  )
111
- res = v_u(batches, params)
114
+ res = v_u(batch_cart_prod, params)
112
115
  assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
113
116
  # For all times t, we perform an integration. Then we average the
114
117
  # losses over times.
@@ -145,7 +148,7 @@ def normalization_loss_apply(
145
148
  jnp.abs(
146
149
  jnp.mean(
147
150
  res.squeeze(),
148
- axis=(d + 1 for d in range(res.ndim - 2)),
151
+ axis=list(d + 1 for d in range(res.ndim - 2)),
149
152
  )
150
153
  * norm_weights
151
154
  - 1
@@ -159,18 +162,34 @@ def normalization_loss_apply(
159
162
 
160
163
 
161
164
  def boundary_condition_apply(
162
- u: eqx.Module,
165
+ u: AbstractPINN,
163
166
  batch: PDEStatioBatch | PDENonStatioBatch,
164
- params: Params | ParamsDict,
165
- omega_boundary_fun: Callable,
166
- omega_boundary_condition: str,
167
- omega_boundary_dim: int,
168
- loss_weight: float | Float[Array, "boundary_cond_dim"],
169
- ) -> float:
170
-
167
+ params: Params[Array],
168
+ omega_boundary_fun: BoundaryConditionFun | dict[str, BoundaryConditionFun],
169
+ omega_boundary_condition: str | dict[str, str],
170
+ omega_boundary_dim: slice | dict[str, slice],
171
+ loss_weight: float | Float[Array, " boundary_cond_dim"],
172
+ ) -> Float[Array, " "]:
173
+ assert batch.border_batch is not None
171
174
  vmap_in_axes = (0,) + _get_vmap_in_axes_params(batch.param_batch_dict, params)
172
175
 
173
- if isinstance(omega_boundary_fun, dict):
176
+ def _check_tuple_of_dict(
177
+ val,
178
+ ) -> TypeGuard[
179
+ tuple[
180
+ dict[str, BoundaryConditionFun],
181
+ dict[str, BoundaryConditionFun],
182
+ dict[str, BoundaryConditionFun],
183
+ ]
184
+ ]:
185
+ return all(isinstance(x, dict) for x in val)
186
+
187
+ omega_boundary_dicts = (
188
+ omega_boundary_condition,
189
+ omega_boundary_fun,
190
+ omega_boundary_dim,
191
+ )
192
+ if _check_tuple_of_dict(omega_boundary_dicts):
174
193
  # We must create the facet tree dictionary as we do not have the
175
194
  # enumerate from the for loop to pass the id integer
176
195
  if batch.border_batch.shape[-1] == 2:
@@ -192,10 +211,10 @@ def boundary_condition_apply(
192
211
  )
193
212
  )
194
213
  ),
195
- omega_boundary_condition,
196
- omega_boundary_fun,
214
+ omega_boundary_dicts[0], # omega_boundary_condition,
215
+ omega_boundary_dicts[1], # omega_boundary_fun,
197
216
  facet_tree,
198
- omega_boundary_dim,
217
+ omega_boundary_dicts[2], # omega_boundary_dim,
199
218
  is_leaf=lambda x: x is None,
200
219
  ) # when exploring leaves with None value (no condition) the returned
201
220
  # mse is None and we get rid of the None leaves of b_losses_by_facet
@@ -208,13 +227,13 @@ def boundary_condition_apply(
208
227
  lambda fa: jnp.mean(
209
228
  loss_weight
210
229
  * _compute_boundary_loss(
211
- omega_boundary_condition,
212
- omega_boundary_fun,
230
+ omega_boundary_dicts[0], # type: ignore -> need TypeIs from 3.13
231
+ omega_boundary_dicts[1], # type: ignore -> need TypeIs from 3.13
213
232
  batch,
214
233
  u,
215
234
  params,
216
235
  fa,
217
- omega_boundary_dim,
236
+ omega_boundary_dicts[2], # type: ignore -> need TypeIs from 3.13
218
237
  vmap_in_axes,
219
238
  )
220
239
  ),
@@ -227,22 +246,21 @@ def boundary_condition_apply(
227
246
 
228
247
 
229
248
  def observations_loss_apply(
230
- u: eqx.Module,
231
- batches: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
232
- params: Params | ParamsDict,
233
- vmap_axes: tuple[int | None, ...],
234
- observed_values: Float[Array, "batch_size observation_dim"],
235
- loss_weight: float | Float[Array, "observation_dim"],
236
- obs_slice: slice,
237
- ) -> float:
238
- # TODO implement for SPINN
249
+ u: AbstractPINN,
250
+ batch: Float[Array, " obs_batch_size input_dim"],
251
+ params: Params[Array],
252
+ vmap_axes: tuple[int, Params[int | None] | None],
253
+ observed_values: Float[Array, " obs_batch_size observation_dim"],
254
+ loss_weight: float | Float[Array, " observation_dim"],
255
+ obs_slice: EllipsisType | slice | None,
256
+ ) -> Float[Array, " "]:
239
257
  if isinstance(u, (PINN, HyperPINN)):
240
258
  v_u = vmap(
241
259
  lambda *args: u(*args)[u.slice_solution],
242
260
  vmap_axes,
243
261
  0,
244
262
  )
245
- val = v_u(*batches, params)[:, obs_slice]
263
+ val = v_u(batch, params)[:, obs_slice]
246
264
  mse_observation_loss = jnp.mean(
247
265
  jnp.sum(
248
266
  loss_weight
@@ -261,15 +279,16 @@ def observations_loss_apply(
261
279
 
262
280
 
263
281
  def initial_condition_apply(
264
- u: eqx.Module,
265
- omega_batch: Float[Array, "dimension"],
266
- params: Params | ParamsDict,
267
- vmap_axes: tuple[int | None, ...],
282
+ u: AbstractPINN,
283
+ omega_batch: Float[Array, " dimension"],
284
+ params: Params[Array],
285
+ vmap_axes: tuple[int, Params[int | None] | None],
268
286
  initial_condition_fun: Callable,
269
- loss_weight: float | Float[Array, "initial_condition_dimension"],
270
- ) -> float:
287
+ t0: Float[Array, " 1"],
288
+ loss_weight: float | Float[Array, " initial_condition_dimension"],
289
+ ) -> Float[Array, " "]:
271
290
  n = omega_batch.shape[0]
272
- t0_omega_batch = jnp.concatenate([jnp.zeros((n, 1)), omega_batch], axis=1)
291
+ t0_omega_batch = jnp.concatenate([t0 * jnp.ones((n, 1)), omega_batch], axis=1)
273
292
  if isinstance(u, (PINN, HyperPINN)):
274
293
  v_u_t0 = vmap(
275
294
  lambda t0_x, params: _subtract_with_check(
@@ -302,103 +321,3 @@ def initial_condition_apply(
302
321
  else:
303
322
  raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
304
323
  return mse_initial_condition
305
-
306
-
307
- def constraints_system_loss_apply(
308
- u_constraints_dict: Dict,
309
- batch: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
310
- params_dict: ParamsDict,
311
- loss_weights: Dict[str, float | Array],
312
- loss_weight_struct: PyTree,
313
- ):
314
- """
315
- Same function for systemlossODE and systemlossPDE!
316
- """
317
- # Transpose so we have each u_dict as outer structure and the
318
- # associated loss_weight as inner structure
319
- loss_weights_T = jax.tree_util.tree_transpose(
320
- jax.tree_util.tree_structure(loss_weight_struct),
321
- jax.tree_util.tree_structure(loss_weights["initial_condition"]),
322
- loss_weights,
323
- )
324
-
325
- if isinstance(params_dict.nn_params, dict):
326
-
327
- def apply_u_constraint(
328
- u_constraint, nn_params, eq_params, loss_weights_for_u, obs_batch_u
329
- ):
330
- res_dict_for_u = u_constraint.evaluate(
331
- Params(
332
- nn_params=nn_params,
333
- eq_params=eq_params,
334
- ),
335
- append_obs_batch(batch, obs_batch_u),
336
- )[1]
337
- res_dict_ponderated = jax.tree_util.tree_map(
338
- lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
339
- )
340
- return res_dict_ponderated
341
-
342
- # Note in the case of multiple PINNs, batch.obs_batch_dict is a dict
343
- # with keys corresponding to the PINN and value correspondinf to an
344
- # original obs_batch_dict. Hence the tree mapping also interates over
345
- # batch.obs_batch_dict
346
- res_dict = jax.tree_util.tree_map(
347
- apply_u_constraint,
348
- u_constraints_dict,
349
- params_dict.nn_params,
350
- (
351
- params_dict.eq_params
352
- if params_dict.eq_params.keys() == params_dict.nn_params.keys()
353
- else {k: params_dict.eq_params for k in params_dict.nn_params.keys()}
354
- ), # this manipulation is needed since we authorize eq_params not to have the same structure as nn_params in ParamsDict
355
- loss_weights_T,
356
- batch.obs_batch_dict,
357
- is_leaf=lambda x: (
358
- not isinstance(x, dict) # to not traverse more than the first
359
- # outer dict of the pytrees passed to the function. This will
360
- # work because u_constraints_dict is a dict of Losses, and it
361
- # thus stops the traversing of other dict too
362
- ),
363
- )
364
- # TODO try to get rid of this condition?
365
- else:
366
-
367
- def apply_u_constraint(u_constraint, loss_weights_for_u, obs_batch_u):
368
- res_dict_for_u = u_constraint.evaluate(
369
- params_dict,
370
- append_obs_batch(batch, obs_batch_u),
371
- )[1]
372
- res_dict_ponderated = jax.tree_util.tree_map(
373
- lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
374
- )
375
- return res_dict_ponderated
376
-
377
- res_dict = jax.tree_util.tree_map(
378
- apply_u_constraint, u_constraints_dict, loss_weights_T, batch.obs_batch_dict
379
- )
380
-
381
- # Transpose back so we have mses as outer structures and their values
382
- # for each u_dict as inner structures. The tree_leaves transforms the
383
- # inner structure into a list so we can catch is as leaf it the
384
- # tree_map below
385
- res_dict = jax.tree_util.tree_transpose(
386
- jax.tree_util.tree_structure(
387
- jax.tree_util.tree_leaves(loss_weights["initial_condition"])
388
- ),
389
- jax.tree_util.tree_structure(loss_weight_struct),
390
- res_dict,
391
- )
392
- # For each mse, sum their values on each u_dict
393
- res_dict = jax.tree_util.tree_map(
394
- lambda mse: jax.tree_util.tree_reduce(
395
- lambda x, y: x + y, jax.tree_util.tree_leaves(mse)
396
- ),
397
- res_dict,
398
- is_leaf=lambda x: isinstance(x, list),
399
- )
400
- # Total loss
401
- total_loss = jax.tree_util.tree_reduce(
402
- lambda x, y: x + y, jax.tree_util.tree_leaves(res_dict)
403
- )
404
- return total_loss, res_dict
@@ -2,58 +2,26 @@
2
2
  Formalize the loss weights data structure
3
3
  """
4
4
 
5
- from typing import Dict
6
5
  from jaxtyping import Array, Float
7
6
  import equinox as eqx
8
7
 
9
8
 
10
9
  class LossWeightsODE(eqx.Module):
11
-
12
- dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
13
- initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
14
- observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
15
-
16
-
17
- class LossWeightsODEDict(eqx.Module):
18
-
19
- dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=None)
20
- initial_condition: Dict[str, Array | Float | None] = eqx.field(
21
- kw_only=True, default=None
22
- )
23
- observations: Dict[str, Array | Float | None] = eqx.field(
24
- kw_only=True, default=None
25
- )
10
+ dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
11
+ initial_condition: Array | Float = eqx.field(kw_only=True, default=0.0)
12
+ observations: Array | Float = eqx.field(kw_only=True, default=0.0)
26
13
 
27
14
 
28
15
  class LossWeightsPDEStatio(eqx.Module):
29
-
30
- dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
31
- norm_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
32
- boundary_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
33
- observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
16
+ dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
17
+ norm_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
18
+ boundary_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
19
+ observations: Array | Float = eqx.field(kw_only=True, default=0.0)
34
20
 
35
21
 
36
22
  class LossWeightsPDENonStatio(eqx.Module):
37
-
38
- dyn_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
39
- norm_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
40
- boundary_loss: Array | Float | None = eqx.field(kw_only=True, default=1.0)
41
- observations: Array | Float | None = eqx.field(kw_only=True, default=1.0)
42
- initial_condition: Array | Float | None = eqx.field(kw_only=True, default=1.0)
43
-
44
-
45
- class LossWeightsPDEDict(eqx.Module):
46
- """
47
- Only one type of LossWeights data structure for the SystemLossPDE:
48
- Include the initial condition always for the code to be more generic
49
- """
50
-
51
- dyn_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
52
- norm_loss: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
53
- boundary_loss: Dict[str, Array | Float | None] = eqx.field(
54
- kw_only=True, default=1.0
55
- )
56
- observations: Dict[str, Array | Float | None] = eqx.field(kw_only=True, default=1.0)
57
- initial_condition: Dict[str, Array | Float | None] = eqx.field(
58
- kw_only=True, default=1.0
59
- )
23
+ dyn_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
24
+ norm_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
25
+ boundary_loss: Array | Float = eqx.field(kw_only=True, default=0.0)
26
+ observations: Array | Float = eqx.field(kw_only=True, default=0.0)
27
+ initial_condition: Array | Float = eqx.field(kw_only=True, default=0.0)