jinns 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. jinns/__init__.py +2 -0
  2. jinns/data/_Batchs.py +27 -0
  3. jinns/data/_DataGenerators.py +904 -1203
  4. jinns/data/__init__.py +4 -8
  5. jinns/experimental/__init__.py +0 -2
  6. jinns/experimental/_diffrax_solver.py +5 -5
  7. jinns/loss/_DynamicLoss.py +282 -305
  8. jinns/loss/_DynamicLossAbstract.py +322 -167
  9. jinns/loss/_LossODE.py +324 -322
  10. jinns/loss/_LossPDE.py +652 -1027
  11. jinns/loss/__init__.py +21 -5
  12. jinns/loss/_boundary_conditions.py +87 -41
  13. jinns/loss/{_Losses.py → _loss_utils.py} +101 -45
  14. jinns/loss/_loss_weights.py +59 -0
  15. jinns/loss/_operators.py +78 -72
  16. jinns/parameters/__init__.py +6 -0
  17. jinns/parameters/_derivative_keys.py +521 -0
  18. jinns/parameters/_params.py +115 -0
  19. jinns/plot/__init__.py +5 -0
  20. jinns/{data/_display.py → plot/_plot.py} +98 -75
  21. jinns/solver/_rar.py +183 -39
  22. jinns/solver/_solve.py +151 -124
  23. jinns/utils/__init__.py +3 -9
  24. jinns/utils/_containers.py +37 -44
  25. jinns/utils/_hyperpinn.py +224 -119
  26. jinns/utils/_pinn.py +183 -111
  27. jinns/utils/_save_load.py +121 -56
  28. jinns/utils/_spinn.py +113 -86
  29. jinns/utils/_types.py +64 -0
  30. jinns/utils/_utils.py +6 -160
  31. jinns/validation/_validation.py +48 -140
  32. jinns-1.1.0.dist-info/AUTHORS +2 -0
  33. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/METADATA +5 -4
  34. jinns-1.1.0.dist-info/RECORD +39 -0
  35. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/WHEEL +1 -1
  36. jinns/experimental/_sinuspinn.py +0 -135
  37. jinns/experimental/_spectralpinn.py +0 -87
  38. jinns/solver/_seq2seq.py +0 -157
  39. jinns/utils/_optim.py +0 -147
  40. jinns/utils/_utils_uspinn.py +0 -727
  41. jinns-0.9.0.dist-info/RECORD +0 -36
  42. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/LICENSE +0 -0
  43. {jinns-0.9.0.dist-info → jinns-1.1.0.dist-info}/top_level.txt +0 -0
jinns/loss/__init__.py CHANGED
@@ -1,11 +1,27 @@
1
- from ._DynamicLossAbstract import ODE, PDEStatio, PDENonStatio
1
+ from ._DynamicLossAbstract import DynamicLoss, ODE, PDEStatio, PDENonStatio
2
+ from ._LossODE import LossODE, SystemLossODE
3
+ from ._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
2
4
  from ._DynamicLoss import (
3
- FisherKPP,
4
- BurgerEquation,
5
5
  GeneralizedLotkaVolterra,
6
+ BurgerEquation,
7
+ FPENonStatioLoss2D,
6
8
  OU_FPENonStatioLoss2D,
9
+ FisherKPP,
7
10
  MassConservation2DStatio,
8
11
  NavierStokes2DStatio,
9
12
  )
10
- from ._LossPDE import LossPDENonStatio, LossPDEStatio, SystemLossPDE
11
- from ._LossODE import LossODE, SystemLossODE
13
+ from ._loss_weights import (
14
+ LossWeightsODE,
15
+ LossWeightsODEDict,
16
+ LossWeightsPDENonStatio,
17
+ LossWeightsPDEStatio,
18
+ LossWeightsPDEDict,
19
+ )
20
+
21
+ from ._operators import (
22
+ _div_fwd,
23
+ _div_rev,
24
+ _laplacian_fwd,
25
+ _laplacian_rev,
26
+ _vectorial_laplacian,
27
+ )
@@ -2,38 +2,53 @@
2
2
  Implements the main boundary conditions for all kinds of losses in jinns
3
3
  """
4
4
 
5
+ from __future__ import (
6
+ annotations,
7
+ ) # https://docs.python.org/3/library/typing.html#constant
8
+
9
+ from typing import TYPE_CHECKING, Callable
5
10
  import jax
6
11
  import jax.numpy as jnp
7
12
  from jax import vmap, grad
13
+ import equinox as eqx
8
14
  from jinns.utils._utils import (
9
15
  _get_grid,
10
16
  _check_user_func_return,
11
- _get_vmap_in_axes_params,
12
17
  )
13
- from jinns.data._DataGenerators import PDEStatioBatch, PDENonStatioBatch
18
+ from jinns.parameters._params import _get_vmap_in_axes_params
19
+ from jinns.data._Batchs import *
14
20
  from jinns.utils._pinn import PINN
15
21
  from jinns.utils._spinn import SPINN
16
22
 
23
+ if TYPE_CHECKING:
24
+ from jinns.utils._types import *
25
+
17
26
 
18
27
  def _compute_boundary_loss(
19
- boundary_condition_type, f, batch, u, params, facet, dim_to_apply
20
- ):
28
+ boundary_condition_type: str,
29
+ f: Callable,
30
+ batch: PDEStatioBatch | PDENonStatioBatch,
31
+ u: eqx.Module,
32
+ params: Params | ParamsDict,
33
+ facet: int,
34
+ dim_to_apply: slice,
35
+ ) -> float:
21
36
  r"""A generic function that will compute the mini-batch MSE of a
22
37
  boundary condition in the stationary case, resp. non-stationary, given by:
23
38
 
24
- .. math::
39
+ $$
25
40
  D[u](\partial x) = f(\partial x), \forall \partial x \in \partial \Omega
26
-
41
+ $$
27
42
  resp.,
28
43
 
29
- .. math::
44
+ $$
30
45
  D[u](t, \partial x) = f(\partial x), \forall t \in I, \forall \partial
31
46
  x \in \partial \Omega
47
+ $$
32
48
 
49
+ Where $D[\cdot]$ is a differential operator, possibly identity.
33
50
 
34
- Where :math:`D[\cdot]` is a differential operator, possibly identity.
35
-
36
- __Note__: if using a batch.param_batch_dict, we need to resolve the
51
+ **Note**: if using a batch.param_batch_dict, we need to resolve the
37
52
  vmapping axes in the boundary functions, however params["eq_params"]
38
53
  has already been fed with the batch in the `evaluate()` of `LossPDEStatio`,
39
54
  resp. `LossPDENonStatio`.
@@ -41,27 +56,24 @@ def _compute_boundary_loss(
41
56
  Parameters
42
57
  ----------
43
58
  boundary_condition_type
44
- a string defining the differential operator :math:`D[\cdot]`.
45
- Currently implements one of "Dirichlet" (:math:`D = Id`) and Von
46
- Neuman (:math:`D[\cdot] = \nabla \cdot n`) where :math:`n` is the
47
- unitary outgoing vector normal to :math:`\partial\Omega`
59
+ a string defining the differential operator $D[\cdot]$.
60
+ Currently implements one of "Dirichlet" ($D = Id$) and Von
61
+ Neuman ($D[u] = \nabla u \cdot n$) where $n$ is the
62
+ unitary outgoing vector normal to $\partial\Omega$
48
63
  f
49
64
  the function to be matched in the boundary condition. It should have
50
- one argument only (other are ignored).
65
+ one or two arguments only (other are ignored).
51
66
  batch
52
- a PDEStatioBatch object or PDENonStatioBatch
67
+ a PDEStatioBatch or PDENonStatioBatch
53
68
  u
54
69
  a PINN
55
70
  params
56
- The dictionary of parameters of the model.
57
- Typically, it is a dictionary of
58
- dictionaries: `eq_params` and `nn_params``, respectively the
59
- differential equation parameters and the neural network parameter
60
- facet:
71
+ Params or ParamsDict
72
+ facet
61
73
  An integer which represents the id of the facet which is currently
62
74
  considered (in the order provided by the DataGenerator which is fixed)
63
75
  dim_to_apply
64
- A jnp.s\_ object which indicates which dimension(s) of u will be forced
76
+ A `jnp.s_` object which indicates which dimension(s) of u will be forced
65
77
  to match the boundary condition
66
78
 
67
79
  Returns
@@ -91,7 +103,14 @@ def _compute_boundary_loss(
91
103
  return mse
92
104
 
93
105
 
94
- def boundary_dirichlet_statio(f, batch, u, params, facet, dim_to_apply):
106
+ def boundary_dirichlet_statio(
107
+ f: Callable,
108
+ batch: PDEStatioBatch,
109
+ u: eqx.Module,
110
+ params: Params | ParamsDict,
111
+ facet: int,
112
+ dim_to_apply: slice,
113
+ ) -> float:
95
114
  r"""
96
115
  This omega boundary condition enforces a solution that is equal to f on
97
116
  border batch.
@@ -102,17 +121,14 @@ def boundary_dirichlet_statio(f, batch, u, params, facet, dim_to_apply):
102
121
 
103
122
  Parameters
104
123
  ----------
105
- f:
124
+ f
106
125
  the constraint function
107
126
  batch
108
127
  A PDEStatioBatch object.
109
128
  u
110
129
  The PINN
111
130
  params
112
- The dictionary of parameters of the model.
113
- Typically, it is a dictionary of
114
- dictionaries: `eq_params` and `nn_params``, respectively the
115
- differential equation parameters and the neural network parameter
131
+ Params or ParamsDict
116
132
  dim_to_apply
117
133
  A jnp.s\_ object. The dimension of u on which to apply the boundary condition
118
134
  """
@@ -139,14 +155,23 @@ def boundary_dirichlet_statio(f, batch, u, params, facet, dim_to_apply):
139
155
  res**2,
140
156
  axis=-1,
141
157
  )
158
+ else:
159
+ raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
142
160
  return mse_u_boundary
143
161
 
144
162
 
145
- def boundary_neumann_statio(f, batch, u, params, facet, dim_to_apply):
163
+ def boundary_neumann_statio(
164
+ f: Callable,
165
+ batch: PDEStatioBatch,
166
+ u: eqx.Module,
167
+ params: Params | ParamsDict,
168
+ facet: int,
169
+ dim_to_apply: slice,
170
+ ) -> float:
146
171
  r"""
147
- This omega boundary condition enforces a solution where :math:`\nabla u\cdot
148
- n` is equal to `f` on omega borders. :math:`n` is the unitary
149
- outgoing vector normal at border :math:`\partial\Omega`.
172
+ This omega boundary condition enforces a solution where $\nabla u\cdot
173
+ n$ is equal to `f` on omega borders. $n$ is the unitary
174
+ outgoing vector normal at border $\partial\Omega$.
150
175
 
151
176
  __Note__: if using a batch.param_batch_dict, we need to resolve the
152
177
  vmapping axes here however params["eq_params"] has already been fed with
@@ -165,7 +190,7 @@ def boundary_neumann_statio(f, batch, u, params, facet, dim_to_apply):
165
190
  Typically, it is a dictionary of
166
191
  dictionaries: `eq_params` and `nn_params``, respectively the
167
192
  differential equation parameters and the neural network parameter
168
- facet:
193
+ facet
169
194
  An integer which represents the id of the facet which is currently
170
195
  considered (in the order provided wy the DataGenerator which is fixed)
171
196
  dim_to_apply
@@ -248,13 +273,22 @@ def boundary_neumann_statio(f, batch, u, params, facet, dim_to_apply):
248
273
  boundaries = _check_user_func_return(f(x_grid), values.shape)
249
274
  res = values - boundaries
250
275
  mse_u_boundary = jnp.sum(res**2, axis=-1)
276
+ else:
277
+ raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
251
278
  return mse_u_boundary
252
279
 
253
280
 
254
- def boundary_dirichlet_nonstatio(f, batch, u, params, facet, dim_to_apply):
281
+ def boundary_dirichlet_nonstatio(
282
+ f: Callable,
283
+ batch: PDENonStatioBatch,
284
+ u: eqx.Module,
285
+ params: Params | ParamsDict,
286
+ facet: int,
287
+ dim_to_apply: slice,
288
+ ) -> float:
255
289
  r"""
256
- This omega boundary condition enforces a solution that is equal to f
257
- at times_batch x omega borders
290
+ This omega boundary condition enforces a solution that is equal to `f`
291
+ at `times_batch` x `omega borders`
258
292
 
259
293
  __Note__: if using a batch.param_batch_dict, we need to resolve the
260
294
  vmapping axes here however params["eq_params"] has already been fed with
@@ -271,7 +305,7 @@ def boundary_dirichlet_nonstatio(f, batch, u, params, facet, dim_to_apply):
271
305
  params
272
306
  The dictionary of parameters of the model.
273
307
  Typically, it is a dictionary of
274
- dictionaries: `eq_params` and `nn_params``, respectively the
308
+ dictionaries: `eq_params` and `nn_params`, respectively the
275
309
  differential equation parameters and the neural network parameter
276
310
  facet:
277
311
  An integer which represents the id of the facet which is currently
@@ -309,14 +343,24 @@ def boundary_dirichlet_nonstatio(f, batch, u, params, facet, dim_to_apply):
309
343
  )
310
344
  res = values - boundaries
311
345
  mse_u_boundary = jnp.sum(res**2, axis=-1)
346
+ else:
347
+ raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
312
348
  return mse_u_boundary
313
349
 
314
350
 
315
- def boundary_neumann_nonstatio(f, batch, u, params, facet, dim_to_apply):
351
+ def boundary_neumann_nonstatio(
352
+ f: Callable,
353
+ batch: PDENonStatioBatch,
354
+ u: eqx.Module,
355
+ params: Params | ParamsDict,
356
+ facet: int,
357
+ dim_to_apply: slice,
358
+ ) -> float:
316
359
  r"""
317
- This omega boundary condition enforces a solution where :math:`\nabla u\cdot
318
- n` is equal to `f` at time_batch x omega borders. :math:`n` is the unitary
319
- outgoing vector normal at border :math:`\partial\Omega`.
360
+ This omega boundary condition enforces a solution where $\nabla u\cdot
361
+ n$ is equal to `f` at the cartesian product of `time_batch` x `omega
362
+ borders`. $n$ is the unitary outgoing vector normal at border
363
+ $\partial\Omega$.
320
364
 
321
365
  __Note__: if using a batch.param_batch_dict, we need to resolve the
322
366
  vmapping axes here however params["eq_params"] has already been fed with
@@ -424,4 +468,6 @@ def boundary_neumann_nonstatio(f, batch, u, params, facet, dim_to_apply):
424
468
  res**2,
425
469
  axis=-1,
426
470
  )
471
+ else:
472
+ raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
427
473
  return mse_u_boundary
@@ -2,22 +2,43 @@
2
2
  Interface for diverse loss functions to factorize code
3
3
  """
4
4
 
5
+ from __future__ import (
6
+ annotations,
7
+ ) # https://docs.python.org/3/library/typing.html#constant
8
+
9
+ from typing import TYPE_CHECKING, Callable, Dict
5
10
  import jax
6
11
  import jax.numpy as jnp
7
12
  from jax import vmap
13
+ import equinox as eqx
14
+ from jaxtyping import Float, Array, PyTree
8
15
 
9
- from jinns.utils._pinn import PINN
10
- from jinns.utils._spinn import SPINN
11
- from jinns.utils._hyperpinn import HYPERPINN
12
16
  from jinns.loss._boundary_conditions import (
13
17
  _compute_boundary_loss,
14
18
  )
15
19
  from jinns.utils._utils import _check_user_func_return, _get_grid
20
+ from jinns.data._DataGenerators import (
21
+ append_obs_batch,
22
+ )
23
+ from jinns.utils._pinn import PINN
24
+ from jinns.utils._spinn import SPINN
25
+ from jinns.utils._hyperpinn import HYPERPINN
26
+ from jinns.data._Batchs import *
27
+ from jinns.parameters._params import Params, ParamsDict
28
+
29
+ if TYPE_CHECKING:
30
+ from jinns.utils._types import *
16
31
 
17
32
 
18
33
  def dynamic_loss_apply(
19
- dyn_loss, u, batches, params, vmap_axes, loss_weight, u_type=None
20
- ):
34
+ dyn_loss: DynamicLoss,
35
+ u: eqx.Module,
36
+ batches: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
37
+ params: Params | ParamsDict,
38
+ vmap_axes: tuple[int | None, ...],
39
+ loss_weight: float | Float[Array, "dyn_loss_dimension"],
40
+ u_type: PINN | HYPERPINN | None = None,
41
+ ) -> float:
21
42
  """
22
43
  Sometimes when u is a lambda function a or dict we do not have access to
23
44
  its type here, hence the last argument
@@ -35,11 +56,20 @@ def dynamic_loss_apply(
35
56
  elif u_type == SPINN or isinstance(u, SPINN):
36
57
  residuals = dyn_loss(*batches, u, params)
37
58
  mse_dyn_loss = jnp.mean(jnp.sum(loss_weight * residuals**2, axis=-1))
59
+ else:
60
+ raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
38
61
 
39
62
  return mse_dyn_loss
40
63
 
41
64
 
42
- def normalization_loss_apply(u, batches, params, vmap_axes, int_length, loss_weight):
65
+ def normalization_loss_apply(
66
+ u: eqx.Module,
67
+ batches: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
68
+ params: Params | ParamsDict,
69
+ vmap_axes: tuple[int | None, ...],
70
+ int_length: int,
71
+ loss_weight: float,
72
+ ) -> float:
43
73
  # TODO merge stationary and non stationary cases
44
74
  if isinstance(u, (PINN, HYPERPINN)):
45
75
  if len(batches) == 1:
@@ -95,26 +125,38 @@ def normalization_loss_apply(u, batches, params, vmap_axes, int_length, loss_wei
95
125
  )
96
126
  ** 2
97
127
  )
128
+ else:
129
+ raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
98
130
 
99
131
  return mse_norm_loss
100
132
 
101
133
 
102
134
  def boundary_condition_apply(
103
- u,
104
- batch,
105
- params,
106
- omega_boundary_fun,
107
- omega_boundary_condition,
108
- omega_boundary_dim,
109
- loss_weight,
110
- ):
135
+ u: eqx.Module,
136
+ batch: PDEStatioBatch | PDENonStatioBatch,
137
+ params: Params | ParamsDict,
138
+ omega_boundary_fun: Callable,
139
+ omega_boundary_condition: str,
140
+ omega_boundary_dim: int,
141
+ loss_weight: float | Float[Array, "boundary_cond_dim"],
142
+ ) -> float:
111
143
  if isinstance(omega_boundary_fun, dict):
112
144
  # We must create the facet tree dictionary as we do not have the
113
145
  # enumerate from the for loop to pass the id integer
114
- if batch[1].shape[-1] == 2:
146
+ if (
147
+ isinstance(batch, PDEStatioBatch) and batch.border_batch.shape[-1] == 2
148
+ ) or (
149
+ isinstance(batch, PDENonStatioBatch)
150
+ and batch.times_x_border_batch.shape[-1] == 2
151
+ ):
115
152
  # 1D
116
153
  facet_tree = {"xmin": 0, "xmax": 1}
117
- elif batch[1].shape[-1] == 4:
154
+ elif (
155
+ isinstance(batch, PDEStatioBatch) and batch.border_batch.shape[-1] == 4
156
+ ) or (
157
+ isinstance(batch, PDENonStatioBatch)
158
+ and batch.times_x_border_batch.shape[-1] == 4
159
+ ):
118
160
  # 2D
119
161
  facet_tree = {"xmin": 0, "xmax": 1, "ymin": 2, "ymax": 3}
120
162
  else:
@@ -138,7 +180,10 @@ def boundary_condition_apply(
138
180
  # Note that to keep the behaviour given in the comment above we neede
139
181
  # to specify is_leaf according to the note in the release of 0.4.29
140
182
  else:
141
- facet_tuple = tuple(f for f in range(batch[1].shape[-1]))
183
+ if isinstance(batch, PDEStatioBatch):
184
+ facet_tuple = tuple(f for f in range(batch.border_batch.shape[-1]))
185
+ else:
186
+ facet_tuple = tuple(f for f in range(batch.times_x_border_batch.shape[-1]))
142
187
  b_losses_by_facet = jax.tree_util.tree_map(
143
188
  lambda fa: jnp.mean(
144
189
  loss_weight
@@ -161,8 +206,14 @@ def boundary_condition_apply(
161
206
 
162
207
 
163
208
  def observations_loss_apply(
164
- u, batches, params, vmap_axes, observed_values, loss_weight, obs_slice
165
- ):
209
+ u: eqx.Module,
210
+ batches: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
211
+ params: Params | ParamsDict,
212
+ vmap_axes: tuple[int | None, ...],
213
+ observed_values: Float[Array, "batch_size observation_dim"],
214
+ loss_weight: float | Float[Array, "observation_dim"],
215
+ obs_slice: slice,
216
+ ) -> float:
166
217
  # TODO implement for SPINN
167
218
  if isinstance(u, (PINN, HYPERPINN)):
168
219
  v_u = vmap(
@@ -181,12 +232,20 @@ def observations_loss_apply(
181
232
  )
182
233
  elif isinstance(u, SPINN):
183
234
  raise RuntimeError("observation loss term not yet implemented for SPINNs")
235
+ else:
236
+ raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
184
237
  return mse_observation_loss
185
238
 
186
239
 
187
240
  def initial_condition_apply(
188
- u, omega_batch, params, vmap_axes, initial_condition_fun, n, loss_weight
189
- ):
241
+ u: eqx.Module,
242
+ omega_batch: Float[Array, "dimension"],
243
+ params: Params | ParamsDict,
244
+ vmap_axes: tuple[int | None, ...],
245
+ initial_condition_fun: Callable,
246
+ n: int,
247
+ loss_weight: float | Float[Array, "initial_condition_dimension"],
248
+ ) -> float:
190
249
  if isinstance(u, (PINN, HYPERPINN)):
191
250
  v_u_t0 = vmap(
192
251
  lambda x, params: initial_condition_fun(x) - u(jnp.zeros((1,)), x, params),
@@ -212,25 +271,17 @@ def initial_condition_apply(
212
271
  )
213
272
  res = ini - v_ini
214
273
  mse_initial_condition = jnp.mean(jnp.sum(loss_weight * res**2, axis=-1))
274
+ else:
275
+ raise ValueError(f"Bad type for u. Got {type(u)}, expected PINN or SPINN")
215
276
  return mse_initial_condition
216
277
 
217
278
 
218
- def sobolev_reg_apply(u, batches, params, vmap_axes, sobolev_reg, loss_weight):
219
- # TODO implement for SPINN
220
- if isinstance(u, (PINN, HYPERPINN)):
221
- v_sob_reg = vmap(
222
- lambda *args: sobolev_reg(*args), # pylint: disable=E1121
223
- vmap_axes,
224
- 0,
225
- )
226
- mse_sobolev_loss = loss_weight * jnp.mean(v_sob_reg(*batches, params))
227
- elif isinstance(u, SPINN):
228
- raise RuntimeError("Sobolev loss term not yet implemented for SPINNs")
229
- return mse_sobolev_loss
230
-
231
-
232
279
  def constraints_system_loss_apply(
233
- u_constraints_dict, batch, params_dict, loss_weights, loss_weight_struct
280
+ u_constraints_dict: Dict,
281
+ batch: ODEBatch | PDEStatioBatch | PDENonStatioBatch,
282
+ params_dict: ParamsDict,
283
+ loss_weights: Dict[str, float | Array],
284
+ loss_weight_struct: PyTree,
234
285
  ):
235
286
  """
236
287
  Same function for systemlossODE and systemlossPDE!
@@ -243,17 +294,17 @@ def constraints_system_loss_apply(
243
294
  loss_weights,
244
295
  )
245
296
 
246
- if isinstance(params_dict["nn_params"], dict):
297
+ if isinstance(params_dict.nn_params, dict):
247
298
 
248
299
  def apply_u_constraint(
249
- u_constraint, nn_params, loss_weights_for_u, obs_batch_u
300
+ u_constraint, nn_params, eq_params, loss_weights_for_u, obs_batch_u
250
301
  ):
251
302
  res_dict_for_u = u_constraint.evaluate(
252
- {
253
- "nn_params": nn_params,
254
- "eq_params": params_dict["eq_params"],
255
- },
256
- batch._replace(obs_batch_dict=obs_batch_u),
303
+ Params(
304
+ nn_params=nn_params,
305
+ eq_params=eq_params,
306
+ ),
307
+ append_obs_batch(batch, obs_batch_u),
257
308
  )[1]
258
309
  res_dict_ponderated = jax.tree_util.tree_map(
259
310
  lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
@@ -267,7 +318,12 @@ def constraints_system_loss_apply(
267
318
  res_dict = jax.tree_util.tree_map(
268
319
  apply_u_constraint,
269
320
  u_constraints_dict,
270
- params_dict["nn_params"],
321
+ params_dict.nn_params,
322
+ (
323
+ params_dict.eq_params
324
+ if params_dict.eq_params.keys() == params_dict.nn_params.keys()
325
+ else {k: params_dict.eq_params for k in params_dict.nn_params.keys()}
326
+ ), # this manipulation is needed since we authorize eq_params not to have the same structure as nn_params in ParamsDict
271
327
  loss_weights_T,
272
328
  batch.obs_batch_dict,
273
329
  is_leaf=lambda x: (
@@ -283,7 +339,7 @@ def constraints_system_loss_apply(
283
339
  def apply_u_constraint(u_constraint, loss_weights_for_u, obs_batch_u):
284
340
  res_dict_for_u = u_constraint.evaluate(
285
341
  params_dict,
286
- batch._replace(obs_batch_dict=obs_batch_u),
342
+ append_obs_batch(batch, obs_batch_u),
287
343
  )[1]
288
344
  res_dict_ponderated = jax.tree_util.tree_map(
289
345
  lambda w, l: w * l, res_dict_for_u, loss_weights_for_u
@@ -0,0 +1,59 @@
1
+ """
2
+ Formalize the loss weights data structure
3
+ """
4
+
5
+ from typing import Dict
6
+ from jaxtyping import Array, Float
7
+ import equinox as eqx
8
+
9
+
10
class LossWeightsODE(eqx.Module):
    """
    Weights applied to the loss terms of a (single) ODE problem.

    Every field is keyword-only and defaults to ``1.0``. A weight may be a
    Python scalar, a JAX array or ``None``.

    NOTE(review): an array weight presumably gives one weight per output
    dimension of the corresponding term (cf. the
    ``Float[Array, "dyn_loss_dimension"]`` annotations of the loss
    utilities), and the meaning of ``None`` is decided by the consuming loss
    class -- confirm there.
    """

    # weight of the dynamic loss (residual) term
    dyn_loss: Array | float | None = eqx.field(kw_only=True, default=1.0)
    # weight of the initial condition term
    initial_condition: Array | float | None = eqx.field(kw_only=True, default=1.0)
    # weight of the observations term
    observations: Array | float | None = eqx.field(kw_only=True, default=1.0)
15
+
16
+
17
class LossWeightsODEDict(eqx.Module):
    """
    Weights applied to the loss terms of a system of ODEs.

    Each field is meant to hold a dictionary of weights, presumably keyed like
    the equations of the system (TODO confirm against the system loss). All
    fields are keyword-only and default to ``None``, so the annotation also
    admits ``None`` (and a single scalar/array weight).

    NOTE(review): how ``None`` or a non-dict value is expanded over the
    equations is decided by the consuming system loss -- confirm there.
    """

    # weights of the dynamic loss (residual) terms
    dyn_loss: Dict[str, Array | float | None] | Array | float | None = eqx.field(
        kw_only=True, default=None
    )
    # weights of the initial condition terms
    initial_condition: (
        Dict[str, Array | float | None] | Array | float | None
    ) = eqx.field(kw_only=True, default=None)
    # weights of the observations terms
    observations: (
        Dict[str, Array | float | None] | Array | float | None
    ) = eqx.field(kw_only=True, default=None)
26
+
27
+
28
class LossWeightsPDEStatio(eqx.Module):
    """
    Weights applied to the loss terms of a stationary PDE problem.

    Every field is keyword-only and defaults to ``1.0``. A weight may be a
    Python scalar, a JAX array or ``None``.

    NOTE(review): the meaning of ``None`` is decided by the consuming loss
    class -- confirm there.
    """

    # weight of the dynamic loss (residual) term
    dyn_loss: Array | float | None = eqx.field(kw_only=True, default=1.0)
    # weight of the normalization term
    norm_loss: Array | float | None = eqx.field(kw_only=True, default=1.0)
    # weight of the boundary condition term
    boundary_loss: Array | float | None = eqx.field(kw_only=True, default=1.0)
    # weight of the observations term
    observations: Array | float | None = eqx.field(kw_only=True, default=1.0)
34
+
35
+
36
class LossWeightsPDENonStatio(eqx.Module):
    """
    Weights applied to the loss terms of a non-stationary PDE problem.

    Same terms as the stationary case plus the initial condition term. Every
    field is keyword-only and defaults to ``1.0``. A weight may be a Python
    scalar, a JAX array or ``None``.

    NOTE(review): the meaning of ``None`` is decided by the consuming loss
    class -- confirm there.
    """

    # weight of the dynamic loss (residual) term
    dyn_loss: Array | float | None = eqx.field(kw_only=True, default=1.0)
    # weight of the normalization term
    norm_loss: Array | float | None = eqx.field(kw_only=True, default=1.0)
    # weight of the boundary condition term
    boundary_loss: Array | float | None = eqx.field(kw_only=True, default=1.0)
    # weight of the observations term
    observations: Array | float | None = eqx.field(kw_only=True, default=1.0)
    # weight of the initial condition term
    initial_condition: Array | float | None = eqx.field(kw_only=True, default=1.0)
43
+
44
+
45
class LossWeightsPDEDict(eqx.Module):
    """
    Weights applied to the loss terms of a system of PDEs.

    This is the single loss-weights structure used for systems of PDEs,
    whether stationary or not. To keep the code generic it always includes an
    ``initial_condition`` field, even for stationary problems.

    Each field is meant to hold a dictionary of weights, presumably keyed like
    the equations of the system (TODO confirm against the system loss). All
    fields are keyword-only and default to the scalar ``1.0``, so the
    annotation also admits a single scalar/array weight (and ``None``).

    NOTE(review): how a scalar or ``None`` is expanded over the equations is
    decided by the consuming system loss -- confirm there.
    """

    # weights of the dynamic loss (residual) terms
    dyn_loss: Dict[str, Array | float | None] | Array | float | None = eqx.field(
        kw_only=True, default=1.0
    )
    # weights of the normalization terms
    norm_loss: Dict[str, Array | float | None] | Array | float | None = eqx.field(
        kw_only=True, default=1.0
    )
    # weights of the boundary condition terms
    boundary_loss: (
        Dict[str, Array | float | None] | Array | float | None
    ) = eqx.field(kw_only=True, default=1.0)
    # weights of the observations terms
    observations: (
        Dict[str, Array | float | None] | Array | float | None
    ) = eqx.field(kw_only=True, default=1.0)
    # weights of the initial condition terms (kept for stationary systems too)
    initial_condition: (
        Dict[str, Array | float | None] | Array | float | None
    ) = eqx.field(kw_only=True, default=1.0)