jinns-1.3.0-py3-none-any.whl → jinns-1.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +74 -69
  14. jinns/loss/_LossODE.py +132 -348
  15. jinns/loss/_LossPDE.py +262 -549
  16. jinns/loss/__init__.py +32 -6
  17. jinns/loss/_abstract_loss.py +128 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_components.py +43 -0
  20. jinns/loss/_loss_utils.py +85 -179
  21. jinns/loss/_loss_weight_updates.py +202 -0
  22. jinns/loss/_loss_weights.py +64 -40
  23. jinns/loss/_operators.py +84 -74
  24. jinns/nn/__init__.py +15 -0
  25. jinns/nn/_abstract_pinn.py +22 -0
  26. jinns/nn/_hyperpinn.py +94 -57
  27. jinns/nn/_mlp.py +50 -25
  28. jinns/nn/_pinn.py +33 -19
  29. jinns/nn/_ppinn.py +70 -34
  30. jinns/nn/_save_load.py +21 -51
  31. jinns/nn/_spinn.py +33 -16
  32. jinns/nn/_spinn_mlp.py +28 -22
  33. jinns/nn/_utils.py +38 -0
  34. jinns/parameters/__init__.py +8 -1
  35. jinns/parameters/_derivative_keys.py +116 -177
  36. jinns/parameters/_params.py +18 -46
  37. jinns/plot/__init__.py +2 -0
  38. jinns/plot/_plot.py +35 -34
  39. jinns/solver/_rar.py +80 -63
  40. jinns/solver/_solve.py +207 -92
  41. jinns/solver/_utils.py +4 -6
  42. jinns/utils/__init__.py +2 -0
  43. jinns/utils/_containers.py +16 -10
  44. jinns/utils/_types.py +20 -54
  45. jinns/utils/_utils.py +4 -11
  46. jinns/validation/__init__.py +2 -0
  47. jinns/validation/_validation.py +20 -19
  48. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/METADATA +8 -4
  49. jinns-1.5.0.dist-info/RECORD +55 -0
  50. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/WHEEL +1 -1
  51. jinns/data/_DataGenerators.py +0 -1634
  52. jinns-1.3.0.dist-info/RECORD +0 -44
  53. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/AUTHORS +0 -0
  54. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info/licenses}/LICENSE +0 -0
  55. {jinns-1.3.0.dist-info → jinns-1.5.0.dist-info}/top_level.txt +0 -0
jinns/loss/_DynamicLossAbstract.py

@@ -6,23 +6,27 @@ from __future__ import (
     annotations,
 ) # https://docs.python.org/3/library/typing.html#constant

+import warnings
+import abc
+from functools import partial
+from typing import Callable, TYPE_CHECKING, ClassVar, Generic, TypeVar
 import equinox as eqx
-from typing import Callable, Dict, TYPE_CHECKING, ClassVar
 from jaxtyping import Float, Array
-from functools import partial
-import abc
+import jax.numpy as jnp


 # See : https://docs.kidger.site/equinox/api/module/advanced_fields/#equinox.AbstractClassVar--known-issues
 if TYPE_CHECKING:
     from typing import ClassVar as AbstractClassVar
-    from jinns.parameters import Params, ParamsDict
+    from jinns.parameters import Params
+    from jinns.nn._abstract_pinn import AbstractPINN
 else:
     from equinox import AbstractClassVar

+InputDim = TypeVar("InputDim")

-def _decorator_heteregeneous_params(evaluate):

+def _decorator_heteregeneous_params(evaluate):
     def wrapper(*args):
         self, inputs, u, params = args
         _params = eqx.tree_at(
@@ -39,7 +43,7 @@ def _decorator_heteregeneous_params(evaluate):
     return wrapper


-class DynamicLoss(eqx.Module):
+class DynamicLoss(eqx.Module, Generic[InputDim]):
     r"""
     Abstract base class for dynamic losses. Implements the physical term:

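The Generic[InputDim] parametrization introduced above lets each concrete subclass pin its input type once, instead of repeating the Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"] union in every signature. A standalone sketch of the pattern (independent of jinns; the class names here are illustrative only):

    from typing import Generic, TypeVar
    import equinox as eqx

    InputDim = TypeVar("InputDim")

    class Base(eqx.Module, Generic[InputDim]):
        # every subclass inherits evaluate() with its own InputDim bound
        def evaluate(self, inputs: InputDim) -> None: ...

    class TimeOnly(Base[float]):  # an ODE-style subclass fixes InputDim once
        ...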
@@ -55,7 +59,7 @@ class DynamicLoss(eqx.Module):
        Tmax needs to be given when the PINN time input is normalized in
        [0, 1], ie. we have performed renormalization of the differential
        equation
-    eq_params_heterogeneity : Dict[str, Callable | None], default=None
+    eq_params_heterogeneity : dict[str, Callable | None], default=None
        A dict with the same keys as eq_params and the value being either None
        (no heterogeneity) or a function which encodes for the spatio-temporal
        heterogeneity of the parameter.
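A minimal sketch of such a heterogeneity function, for a hypothetical non-stationary equation whose eq_params contain a diffusivity "D" and a growth rate "r" (both key names are assumptions for illustration, not taken from this diff):

    import jax.numpy as jnp

    def heterogeneous_D(t_x, u, params):
        # spatially varying diffusivity: larger on the right half of the domain
        x = t_x[1:]  # drop the time coordinate
        return params.eq_params["D"] * (1.0 + 0.5 * jnp.tanh(x[0]))

    # "r" stays homogeneous (None); it could also simply be omitted
    eq_params_heterogeneity = {"D": heterogeneous_D, "r": None}

Each function receives (inputs, u, params) and returns the parameter value at that point, which is exactly how _eval_heterogeneous_parameters calls it below.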
@@ -68,49 +72,78 @@ class DynamicLoss(eqx.Module):
        A value can be missing, in this case there is no heterogeneity (=None).
        Default None, meaning there is no heterogeneity in the equation
        parameters.
+    vectorial_dyn_loss_ponderation : Float[Array, " dim"], default=None
+        Add a different ponderation weight to each of the dimension to the
+        dynamic loss. This array must have the same dimension as the output of
+        the dynamic loss equation or an error is raised. Default is None which
+        means that a ponderation of 1 is applied on each dimension.
+        `vectorial_dyn_loss_ponderation`
+        is different from loss weights, which are attributes of Loss
+        classes and which implement scalar (and possibly dynamic)
+        ponderations for each term of the total loss.
+        `vectorial_dyn_loss_ponderation` can be used with loss weights.
     """

     _eq_type = AbstractClassVar[str] # class variable denoting the type of
     # differential equation
     Tmax: Float = eqx.field(kw_only=True, default=1)
-    eq_params_heterogeneity: Dict[str, Callable | None] = eqx.field(
+    eq_params_heterogeneity: dict[str, Callable | None] | None = eqx.field(
         kw_only=True, default=None, static=True
     )
+    vectorial_dyn_loss_ponderation: Float[Array, " dim"] | None = eqx.field(
+        kw_only=True, default=None
+    )
+
+    def __post_init__(self):
+        if self.vectorial_dyn_loss_ponderation is None:
+            self.vectorial_dyn_loss_ponderation = jnp.array(1.0)

     def _eval_heterogeneous_parameters(
         self,
-        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
-        u: eqx.Module,
-        params: Params | ParamsDict,
-        eq_params_heterogeneity: Dict[str, Callable | None] = None,
-    ) -> Dict[str, float | Float[Array, "parameter_dimension"]]:
+        inputs: InputDim,
+        u: AbstractPINN,
+        params: Params[Array],
+        eq_params_heterogeneity: dict[str, Callable | None] | None = None,
+    ) -> dict[str, Array]:
         eq_params_ = {}
         if eq_params_heterogeneity is None:
             return params.eq_params
+
         for k, p in params.eq_params.items():
             try:
-                if eq_params_heterogeneity[k] is None:
-                    eq_params_[k] = p
+                if eq_params_heterogeneity[k] is not None:
+                    eq_params_[k] = eq_params_heterogeneity[k](inputs, u, params)  # type: ignore don't know why pyright says
+                    # eq_params_heterogeneity[k] can be None here
                 else:
-                    eq_params_[k] = eq_params_heterogeneity[k](inputs, u, params)
+                    eq_params_[k] = p
             except KeyError:
                 # we authorize missing eq_params_heterogeneity key
                 # if its heterogeneity is None anyway
                 eq_params_[k] = p
         return eq_params_

-    def _evaluate(
+    @partial(_decorator_heteregeneous_params)
+    def evaluate(
         self,
-        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
-        u: eqx.Module,
-        params: Params | ParamsDict,
+        inputs: InputDim,
+        u: AbstractPINN,
+        params: Params[Array],
     ) -> float:
-        evaluation = self.equation(inputs, u, params)
+        evaluation = self.vectorial_dyn_loss_ponderation * self.equation(
+            inputs, u, params
+        )
         if len(evaluation.shape) == 0:
             raise ValueError(
                 "The output of dynamic loss must be vectorial, "
                 "i.e. of shape (d,) with d >= 1"
             )
+        if len(evaluation.shape) > 1:
+            warnings.warn(
+                "Return value from DynamicLoss' equation has more "
+                "than one dimension. This is in general a mistake (probably from "
+                "an unfortunate broadcast in jnp.array computations) resulting in "
+                "bad reduction operations in losses."
+            )
         return evaluation

     @abc.abstractmethod
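The new vectorial_dyn_loss_ponderation is applied inside evaluate(), which now lives on the base class. A minimal sketch of using it with a hypothetical two-equation system (the class name, the eq_params keys "alpha" and "beta", and the jinns.loss export path are assumptions for illustration):

    import jax
    import jax.numpy as jnp
    from jinns.loss import ODE  # assumed public re-export of the class below

    class LotkaVolterra(ODE):  # hypothetical 2D residual, output shape (2,)
        def equation(self, t, u, params):
            v = u(t, params)                                    # (2,)
            dv_dt = jax.jacfwd(lambda t: u(t, params))(t).squeeze()  # (2,)
            alpha = params.eq_params["alpha"]
            beta = params.eq_params["beta"]
            rhs = jnp.array(
                [alpha * v[0] - beta * v[0] * v[1],
                 beta * v[0] * v[1] - alpha * v[1]]
            )
            return dv_dt - rhs

    # weight the first residual dimension twice as much as the second;
    # the array length must match the equation output dimension
    dyn_loss = LotkaVolterra(
        Tmax=1.0,
        vectorial_dyn_loss_ponderation=jnp.array([2.0, 1.0]),
    )

Leaving vectorial_dyn_loss_ponderation unset keeps the previous behaviour: __post_init__ replaces None with jnp.array(1.0), a unit weight on every dimension.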
@@ -120,7 +153,7 @@ class DynamicLoss(eqx.Module):
         raise NotImplementedError("You should implement your equation.")


-class ODE(DynamicLoss):
+class ODE(DynamicLoss[Float[Array, " 1"]]):
     r"""
     Abstract base class for ODE dynamic losses. All dynamic loss must subclass
     this class and override the abstract method `equation`.
@@ -131,7 +164,7 @@ class ODE(DynamicLoss):
        Tmax needs to be given when the PINN time input is normalized in
        [0, 1], ie. we have performed renormalization of the differential
        equation
-    eq_params_heterogeneity : Dict[str, Callable | None], default=None
+    eq_params_heterogeneity : dict[str, Callable | None], default=None
        Default None. A dict with the keys being the same as in eq_params
        and the value being either None (no heterogeneity) or a function
        which encodes for the spatio-temporal heterogeneity of the parameter.
@@ -147,19 +180,9 @@ class ODE(DynamicLoss):

     _eq_type: ClassVar[str] = "ODE"

-    @partial(_decorator_heteregeneous_params)
-    def evaluate(
-        self,
-        t: Float[Array, "1"],
-        u: eqx.Module | Dict[str, eqx.Module],
-        params: Params | ParamsDict,
-    ) -> float:
-        """Here we call DynamicLoss._evaluate with x=None"""
-        return self._evaluate(t, u, params)
-
     @abc.abstractmethod
     def equation(
-        self, t: Float[Array, "1"], u: eqx.Module, params: Params | ParamsDict
+        self, t: Float[Array, " 1"], u: AbstractPINN, params: Params[Array]
     ) -> float:
         r"""
         The differential operator defining the ODE.
@@ -170,11 +193,11 @@ class ODE(DynamicLoss):

        Parameters
        ----------
-        t : Float[Array, "1"]
+        t : Float[Array, " 1"]
            A 1-dimensional jnp.array representing the time point.
-        u : eqx.Module
+        u : AbstractPINN
            The network with a call signature `u(t, params)`.
-        params : Params | ParamsDict
+        params : Params[Array]
            The equation and neural network parameters $\theta$ and $\nu$.

        Returns
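A sketch of a concrete subclass under the new signature, for the hypothetical scalar equation du/dt = -theta * u (the class name and the eq_params key "theta" are illustrative; the import path is an assumed public re-export):

    import jax
    from jinns.loss import ODE

    class ExponentialDecay(ODE):
        def equation(self, t, u, params):
            # residual of du/dt + theta * u = 0, kept vectorial with shape (1,)
            u_t = jax.jacfwd(lambda t: u(t, params))(t).squeeze()
            theta = params.eq_params["theta"]
            return (u_t + theta * u(t, params).squeeze()).reshape((1,))

The (1,) reshape matters: evaluate() raises a ValueError on scalar output and warns on output with more than one dimension.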
@@ -190,7 +213,7 @@ class DynamicLoss(eqx.Module):
         raise NotImplementedError


-class PDEStatio(DynamicLoss):
+class PDEStatio(DynamicLoss[Float[Array, " dim"]]):
     r"""
     Abstract base class for stationnary PDE dynamic losses. All dynamic loss must subclass this class and override the abstract method `equation`.

@@ -200,7 +223,7 @@ class PDEStatio(DynamicLoss):
        Tmax needs to be given when the PINN time input is normalized in
        [0, 1], ie. we have performed renormalization of the differential
        equation
-    eq_params_heterogeneity : Dict[str, Callable | None], default=None
+    eq_params_heterogeneity : dict[str, Callable | None], default=None
        Default None. A dict with the keys being the same as in eq_params
        and the value being either None (no heterogeneity) or a function
        which encodes for the spatio-temporal heterogeneity of the parameter.
@@ -216,16 +239,9 @@ class PDEStatio(DynamicLoss):

     _eq_type: ClassVar[str] = "Statio PDE"

-    @partial(_decorator_heteregeneous_params)
-    def evaluate(
-        self, x: Float[Array, "dimension"], u: eqx.Module, params: Params | ParamsDict
-    ) -> float:
-        """Here we call the DynamicLoss._evaluate with t=None"""
-        return self._evaluate(x, u, params)
-
     @abc.abstractmethod
     def equation(
-        self, x: Float[Array, "d"], u: eqx.Module, params: Params | ParamsDict
+        self, x: Float[Array, " dim"], u: AbstractPINN, params: Params[Array]
     ) -> float:
         r"""The differential operator defining the stationnary PDE.

@@ -235,11 +251,11 @@ class PDEStatio(DynamicLoss):

        Parameters
        ----------
-        x : Float[Array, "d"]
+        x : Float[Array, " dim"]
            A `d` dimensional jnp.array representing a point in the spatial domain $\Omega$.
-        u : eqx.Module
+        u : AbstractPINN
            The neural network.
-        params : Params | ParamsDict
+        params : Params[Array]
            The parameters of the equation and the networks, $\theta$ and $\nu$ respectively.

        Returns
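A sketch of a stationary subclass under the new signature: a Poisson residual -Δu(x) - f = 0 with a constant source term (the class name and the eq_params key "f" are assumptions; the import path is an assumed public re-export):

    import jax
    import jax.numpy as jnp
    from jinns.loss import PDEStatio

    class Poisson(PDEStatio):
        def equation(self, x, u, params):
            # the trace of the Hessian is the Laplacian of the scalar field u
            lap = jnp.trace(jax.hessian(lambda x: u(x, params).squeeze())(x))
            return (-lap - params.eq_params["f"]).reshape((1,))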
@@ -255,7 +271,7 @@ class PDEStatio(DynamicLoss):
         raise NotImplementedError


-class PDENonStatio(DynamicLoss):
+class PDENonStatio(DynamicLoss[Float[Array, " 1 + dim"]]):
     """
     Abstract base class for non-stationnary PDE dynamic losses. All dynamic loss must subclass this class and override the abstract method `equation`.

@@ -265,7 +281,7 @@ class PDENonStatio(DynamicLoss):
        Tmax needs to be given when the PINN time input is normalized in
        [0, 1], ie. we have performed renormalization of the differential
        equation
-    eq_params_heterogeneity : Dict[str, Callable | None], default=None
+    eq_params_heterogeneity : dict[str, Callable | None], default=None
        Default None. A dict with the keys being the same as in eq_params
        and the value being either None (no heterogeneity) or a function
        which encodes for the spatio-temporal heterogeneity of the parameter.
@@ -281,23 +297,12 @@ class PDENonStatio(DynamicLoss):

     _eq_type: ClassVar[str] = "Non-statio PDE"

-    @partial(_decorator_heteregeneous_params)
-    def evaluate(
-        self,
-        t_x: Float[Array, "1 + dim"],
-        u: eqx.Module,
-        params: Params | ParamsDict,
-    ) -> float:
-        """Here we call the DynamicLoss._evaluate with full arguments"""
-        ans = self._evaluate(t_x, u, params)
-        return ans
-
     @abc.abstractmethod
     def equation(
         self,
-        t_x: Float[Array, "1 + dim"],
-        u: eqx.Module,
-        params: Params | ParamsDict,
+        t_x: Float[Array, " 1 + dim"],
+        u: AbstractPINN,
+        params: Params[Array],
     ) -> float:
         r"""The differential operator defining the non-stationnary PDE.

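A sketch of a non-stationary subclass: the 1D heat equation ∂u/∂t - D ∂²u/∂x² = 0, with the time point and spatial point concatenated in t_x as documented below (the class name and the eq_params key "D" are assumptions; the import path is an assumed public re-export):

    import jax
    from jinns.loss import PDENonStatio

    class Heat1D(PDENonStatio):
        def equation(self, t_x, u, params):
            scalar_u = lambda t_x: u(t_x, params).squeeze()
            u_t = jax.grad(scalar_u)(t_x)[0]         # ∂u/∂t (index 0 is time)
            u_xx = jax.hessian(scalar_u)(t_x)[1, 1]  # ∂²u/∂x² (index 1 is space)
            return (u_t - params.eq_params["D"] * u_xx).reshape((1,))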
@@ -307,11 +312,11 @@ class PDENonStatio(DynamicLoss):

        Parameters
        ----------
-        t_x : Float[Array, "1 + dim"]
+        t_x : Float[Array, " 1 + dim"]
            A jnp array containing the concatenation of a time point and a point in $\Omega$
-        u : eqx.Module
+        u : AbstractPINN
            The neural network.
-        params : Params | ParamsDict
+        params : Params[Array]
            The parameters of the equation and the networks, $\theta$ and $\nu$ respectively.

        Returns
        -------