jinns 1.1.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff shows the changes between the two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +534 -343
- jinns/loss/_DynamicLoss.py +152 -175
- jinns/loss/_DynamicLossAbstract.py +25 -73
- jinns/loss/_LossODE.py +4 -4
- jinns/loss/_LossPDE.py +102 -74
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +150 -281
- jinns/loss/_loss_utils.py +95 -67
- jinns/loss/_operators.py +441 -186
- jinns/nn/__init__.py +7 -0
- jinns/nn/_hyperpinn.py +397 -0
- jinns/nn/_mlp.py +192 -0
- jinns/nn/_pinn.py +190 -0
- jinns/nn/_ppinn.py +203 -0
- jinns/{utils → nn}/_save_load.py +47 -31
- jinns/nn/_spinn.py +106 -0
- jinns/nn/_spinn_mlp.py +196 -0
- jinns/plot/_plot.py +113 -100
- jinns/solver/_rar.py +104 -409
- jinns/solver/_solve.py +87 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +1 -4
- jinns/utils/_containers.py +3 -1
- jinns/utils/_types.py +5 -4
- jinns/utils/_utils.py +40 -12
- jinns-1.3.0.dist-info/METADATA +127 -0
- jinns-1.3.0.dist-info/RECORD +44 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/WHEEL +1 -1
- jinns/utils/_hyperpinn.py +0 -410
- jinns/utils/_pinn.py +0 -334
- jinns/utils/_spinn.py +0 -268
- jinns-1.1.0.dist-info/METADATA +0 -85
- jinns-1.1.0.dist-info/RECORD +0 -39
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/AUTHORS +0 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/LICENSE +0 -0
- {jinns-1.1.0.dist-info → jinns-1.3.0.dist-info}/top_level.txt +0 -0
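The neural-network wrappers move out of `jinns.utils` into a new `jinns.nn` subpackage (`{utils → nn}/_save_load.py`, the new `jinns/nn/_*.py` modules, and the removed `jinns/utils/_pinn.py`, `_spinn.py`, `_hyperpinn.py`). A sketch of how import paths shift, based only on the module paths visible in this diff; the public re-exports of `jinns.nn.__init__` are not shown here, so treat the exact import form as an assumption:

```python
# jinns 1.1.0 (modules removed in 1.3.0)
# from jinns.utils._pinn import PINN
# from jinns.utils._spinn import SPINN

# jinns 1.3.0 (new jinns.nn subpackage, as imported by jinns/loss/_LossPDE.py below)
from jinns.nn._pinn import PINN
from jinns.nn._spinn import SPINN
```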
jinns/loss/_DynamicLossAbstract.py
CHANGED

@@ -21,53 +21,22 @@ else:
     from equinox import AbstractClassVar


-def _decorator_heteregeneous_params(evaluate, eq_type):
+def _decorator_heteregeneous_params(evaluate):

-    def wrapper_ode(*args):
-        self, t, u, params = args
+    def wrapper(*args):
+        self, inputs, u, params = args
         _params = eqx.tree_at(
             lambda p: p.eq_params,
             params,
             self._eval_heterogeneous_parameters(
-                t, None, u, params, self.eq_params_heterogeneity
+                inputs, u, params, self.eq_params_heterogeneity
             ),
         )
         new_args = args[:-1] + (_params,)
         res = evaluate(*new_args)
         return res

-    def wrapper_pde_statio(*args):
-        self, x, u, params = args
-        _params = eqx.tree_at(
-            lambda p: p.eq_params,
-            params,
-            self._eval_heterogeneous_parameters(
-                None, x, u, params, self.eq_params_heterogeneity
-            ),
-        )
-        new_args = args[:-1] + (_params,)
-        res = evaluate(*new_args)
-        return res
-
-    def wrapper_pde_non_statio(*args):
-        self, t, x, u, params = args
-        _params = eqx.tree_at(
-            lambda p: p.eq_params,
-            params,
-            self._eval_heterogeneous_parameters(
-                t, x, u, params, self.eq_params_heterogeneity
-            ),
-        )
-        new_args = args[:-1] + (_params,)
-        res = evaluate(*new_args)
-        return res
-
-    if eq_type == "ODE":
-        return wrapper_ode
-    elif eq_type == "Statio PDE":
-        return wrapper_pde_statio
-    elif eq_type == "Non-statio PDE":
-        return wrapper_pde_non_statio
+    return wrapper


 class DynamicLoss(eqx.Module):
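The three equation-specific wrappers collapse into one because every `evaluate` now takes the same `(self, inputs, u, params)` arguments. Reduced to its core, the pattern is a decorator that rebuilds the last positional argument and forwards the rest unchanged; this is a generic sketch (the `rewrite_params` hook is hypothetical), not the jinns implementation:

```python
import functools

def swap_last_argument(evaluate):
    # Same shape as _decorator_heteregeneous_params: recompute the trailing
    # `params` argument from the inputs, then call the wrapped method with it.
    @functools.wraps(evaluate)
    def wrapper(*args):
        self, inputs, u, params = args
        new_params = self.rewrite_params(inputs, u, params)  # hypothetical hook
        return evaluate(*args[:-1], new_params)

    return wrapper
```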
@@ -110,8 +79,7 @@ class DynamicLoss(eqx.Module):

     def _eval_heterogeneous_parameters(
         self,
-        t: Float[Array, "1"],
-        x: Float[Array, "dim"],
+        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
         u: eqx.Module,
         params: Params | ParamsDict,
         eq_params_heterogeneity: Dict[str, Callable | None] = None,
@@ -124,14 +92,7 @@ class DynamicLoss(eqx.Module):
                 if eq_params_heterogeneity[k] is None:
                     eq_params_[k] = p
                 else:
-
-                    # signature will vary according to _eq_type
-                    if self._eq_type == "ODE":
-                        eq_params_[k] = eq_params_heterogeneity[k](t, u, params)
-                    elif self._eq_type == "Statio PDE":
-                        eq_params_[k] = eq_params_heterogeneity[k](x, u, params)
-                    elif self._eq_type == "Non-statio PDE":
-                        eq_params_[k] = eq_params_heterogeneity[k](t, x, u, params)
+                    eq_params_[k] = eq_params_heterogeneity[k](inputs, u, params)
             except KeyError:
                 # we authorize missing eq_params_heterogeneity key
                 # if its heterogeneity is None anyway
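A heterogeneity callable therefore no longer depends on the equation type: it always receives the same `inputs` array as the dynamic loss itself (`t`, `x`, or their concatenation `[t, x]`). A minimal sketch of the new calling convention; the `"D"`/`"r"` parameter names and the spatial profile are made up, only the `(inputs, u, params)` signature comes from the diff:

```python
import jax.numpy as jnp

def heterogeneous_D(inputs, u, params):
    # Hypothetical spatially varying diffusion coefficient; for a
    # non-stationary PDE, inputs = [t, x_1, ..., x_d].
    x = inputs[1:]
    return params.eq_params["D"] * (1.0 + jnp.sum(x**2))

# None means the corresponding parameter stays homogeneous
eq_params_heterogeneity = {"D": heterogeneous_D, "r": None}
```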
@@ -140,22 +101,17 @@ class DynamicLoss(eqx.Module):

     def _evaluate(
         self,
-        t: Float[Array, "1"],
-        x: Float[Array, "dim"],
+        inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
         u: eqx.Module,
         params: Params | ParamsDict,
     ) -> float:
-
-        if self._eq_type == "ODE":
-            ans = self.equation(t, u, params)
-        elif self._eq_type == "Statio PDE":
-            ans = self.equation(x, u, params)
-        elif self._eq_type == "Non-statio PDE":
-            ans = self.equation(t, x, u, params)
-        else:
-            raise NotImplementedError("the equation type is not handled.")
-
-        return ans
+        evaluation = self.equation(inputs, u, params)
+        if len(evaluation.shape) == 0:
+            raise ValueError(
+                "The output of dynamic loss must be vectorial, "
+                "i.e. of shape (d,) with d >= 1"
+            )
+        return evaluation

     @abc.abstractmethod
     def equation(self, *args, **kwargs):
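`_evaluate` now calls `self.equation` directly and enforces a vectorial residual: `equation` must return an array of shape `(d,)` with `d >= 1`, never a bare scalar. A sketch of a conforming ODE residual; the decay equation and the `"r"` parameter name are illustrative, and `u(t, params)` returning a shape-`(1,)` output is an assumption:

```python
import jax
import jax.numpy as jnp

def equation(self, t, u, params):
    # Residual of du/dt = r * u for a scalar ODE.
    du_dt = jax.jacfwd(u)(t, params)  # shape (1, 1): d(output)/d(t)
    residual = du_dt.squeeze() - params.eq_params["r"] * u(t, params).squeeze()
    return jnp.atleast_1d(residual)  # shape (1,), passes the new shape check
```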
@@ -191,7 +147,7 @@ class ODE(DynamicLoss):

     _eq_type: ClassVar[str] = "ODE"

-    @partial(_decorator_heteregeneous_params
+    @partial(_decorator_heteregeneous_params)
     def evaluate(
         self,
         t: Float[Array, "1"],
@@ -199,7 +155,7 @@ class ODE(DynamicLoss):
         params: Params | ParamsDict,
     ) -> float:
         """Here we call DynamicLoss._evaluate with x=None"""
-        return self._evaluate(t, None, u, params)
+        return self._evaluate(t, u, params)

     @abc.abstractmethod
     def equation(
@@ -260,12 +216,12 @@ class PDEStatio(DynamicLoss):

     _eq_type: ClassVar[str] = "Statio PDE"

-    @partial(_decorator_heteregeneous_params
+    @partial(_decorator_heteregeneous_params)
     def evaluate(
         self, x: Float[Array, "dimension"], u: eqx.Module, params: Params | ParamsDict
     ) -> float:
         """Here we call the DynamicLoss._evaluate with t=None"""
-        return self._evaluate(None, x, u, params)
+        return self._evaluate(x, u, params)

     @abc.abstractmethod
     def equation(
@@ -325,23 +281,21 @@ class PDENonStatio(DynamicLoss):

     _eq_type: ClassVar[str] = "Non-statio PDE"

-    @partial(_decorator_heteregeneous_params
+    @partial(_decorator_heteregeneous_params)
     def evaluate(
         self,
-        t: Float[Array, "1"],
-        x: Float[Array, "dim"],
+        t_x: Float[Array, "1 + dim"],
         u: eqx.Module,
         params: Params | ParamsDict,
     ) -> float:
         """Here we call the DynamicLoss._evaluate with full arguments"""
-        ans = self._evaluate(t, x, u, params)
+        ans = self._evaluate(t_x, u, params)
         return ans

     @abc.abstractmethod
     def equation(
         self,
-        t: Float[Array, "1"],
-        x: Float[Array, "dim"],
+        t_x: Float[Array, "1 + dim"],
         u: eqx.Module,
         params: Params | ParamsDict,
     ) -> float:
@@ -353,10 +307,8 @@ class PDENonStatio(DynamicLoss):

         Parameters
         ----------
-        t : Float[Array, "1"]
-            A 1 dimensional jnp.array representing a time point.
-        x : Float[Array, "d"]
-            A `d` dimensional jnp.array representing a point in the spatial domain $\Omega$.
+        t_x : Float[Array, "1 + dim"]
+            A jnp array containing the concatenation of a time point and a point in $\Omega$
         u : eqx.Module
             The neural network.
         params : Params | ParamsDict
jinns/loss/_LossODE.py
CHANGED

@@ -28,7 +28,7 @@ from jinns.parameters._params import (
 from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
 from jinns.loss._loss_weights import LossWeightsODE, LossWeightsODEDict
 from jinns.loss._DynamicLossAbstract import ODE
-from jinns.utils._pinn import PINN
+from jinns.nn._pinn import PINN

 if TYPE_CHECKING:
     from jinns.utils._types import *
@@ -209,7 +209,7 @@ class LossODE(_LossODEAbstract):
             mse_dyn_loss = dynamic_loss_apply(
                 self.dynamic_loss.evaluate,
                 self.u,
-
+                temporal_batch,
                 _set_derivatives(params, self.derivative_keys.dyn_loss),
                 self.vmap_in_axes + vmap_in_axes_params,
                 self.loss_weights.dyn_loss,
@@ -227,7 +227,7 @@ class LossODE(_LossODEAbstract):
         else:
             v_u = vmap(self.u, (None,) + vmap_in_axes_params)
             t0, u0 = self.initial_condition  # pylint: disable=unpacking-non-sequence
-            t0 = jnp.array(t0)
+            t0 = jnp.array([t0])
             u0 = jnp.array(u0)
             mse_initial_condition = jnp.mean(
                 self.loss_weights.initial_condition
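The initial time is now wrapped as a rank-1 array so that it matches the `Float[Array, "1"]` inputs the PINN expects. The difference, in plain JAX:

```python
import jax.numpy as jnp

t0 = 0.0
jnp.array(t0).shape    # () : a 0-d scalar, what 1.1.0 built
jnp.array([t0]).shape  # (1,) : a 1-d time point, matching Float[Array, "1"]
```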
@@ -522,7 +522,7 @@ class SystemLossODE(eqx.Module):
             return dynamic_loss_apply(
                 dyn_loss.evaluate,
                 self.u_dict,
-
+                temporal_batch,
                 _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
                 vmap_in_axes_t + vmap_in_axes_params,
                 loss_weight,
jinns/loss/_LossPDE.py
CHANGED

@@ -22,9 +22,7 @@ from jinns.loss._loss_utils import (
     initial_condition_apply,
     constraints_system_loss_apply,
 )
-from jinns.data._DataGenerators import (
-    append_obs_batch,
-)
+from jinns.data._DataGenerators import append_obs_batch
 from jinns.parameters._params import (
     _get_vmap_in_axes_params,
     _update_eq_params_dict,
@@ -40,8 +38,8 @@ from jinns.loss._loss_weights import (
     LossWeightsPDEDict,
 )
 from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
-from jinns.utils._pinn import PINN
-from jinns.utils._spinn import SPINN
+from jinns.nn._pinn import PINN
+from jinns.nn._spinn import SPINN
 from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch


@@ -94,12 +92,16 @@ class _LossPDEAbstract(eqx.Module):
         Note that it must be a slice and not an integer
         (but a preprocessing of the user provided argument takes care of it)
     norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
-
+        Monte-Carlo sample points for computing the
         normalization constant. Default is None.
-
-
-
-
+    norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
+        The importance sampling weights for Monte-Carlo integration of the
+        normalization constant. Must be provided if `norm_samples` is provided.
+        `norm_weights` should have the same leading dimension as
+        `norm_samples`.
+        Alternatively, the user can pass a float or an integer.
+        These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
+        $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
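The weights are the standard importance-sampling weights: with samples $x_k \sim q$, the normalization constant $\int_\Omega u_\theta(x)\,dx$ is estimated by $\frac{1}{K}\sum_{k} w_k\, u_\theta(x_k)$ with $w_k = 1/q(x_k)$. A self-contained sketch of that estimator in plain JAX (not the jinns internals):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Monte-Carlo samples from a uniform proposal q over [0, 2]^2 (area 4, so q = 1/4)
norm_samples = jax.random.uniform(key, (1000, 2), minval=0.0, maxval=2.0)
norm_weights = 4.0                     # w_k = 1 / q(x_k)

u = lambda x: jnp.exp(-jnp.sum(x**2))  # stand-in for the PINN output
norm_constant = jnp.mean(norm_weights * jax.vmap(u)(norm_samples))
```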
@@ -129,7 +131,9 @@ class _LossPDEAbstract(eqx.Module):
     norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
         kw_only=True, default=None
     )
-
+    norm_weights: Float[Array, "nb_norm_samples"] | float | int | None = eqx.field(
+        kw_only=True, default=None
+    )
     obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)

     params: InitVar[Params] = eqx.field(kw_only=True, default=None)
@@ -253,8 +257,25 @@ class _LossPDEAbstract(eqx.Module):
         if not isinstance(self.omega_boundary_dim, slice):
             raise ValueError("self.omega_boundary_dim must be a jnp.s_ object")

-        if self.norm_samples is not None
-
+        if self.norm_samples is not None:
+            if self.norm_weights is None:
+                raise ValueError(
+                    "`norm_weights` must be provided when `norm_samples` is used!"
+                )
+            try:
+                assert self.norm_weights.shape[0] == self.norm_samples.shape[0]
+            except (AssertionError, AttributeError):
+                if isinstance(self.norm_weights, (int, float)):
+                    self.norm_weights = jnp.array(
+                        [self.norm_weights], dtype=jax.dtypes.canonicalize_dtype(float)
+                    )
+                else:
+                    raise ValueError(
+                        "`norm_weights` should have the same leading dimension"
+                        " as `norm_samples`,"
+                        f" got shape {self.norm_weights.shape} and"
+                        f" shape {self.norm_samples.shape}."
+                    )

     @abc.abstractmethod
     def evaluate(
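In other words, `norm_weights` may be one weight per sample or a single number; the scalar form is promoted to a shape-`(1,)` array and broadcast against `norm_samples`. The two accepted forms, with illustrative values:

```python
import jax.numpy as jnp

norm_samples = jnp.zeros((1000, 2))      # Monte-Carlo points in a 2-D domain

# one importance weight per sample: leading dimensions must match
norm_weights_per_sample = jnp.full((1000,), 4.0)

# a single scalar, valid for a constant proposal density; __post_init__
# turns it into jnp.array([4.0]) and broadcasting handles the rest
norm_weights_scalar = 4.0
```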
@@ -326,12 +347,16 @@ class LossPDEStatio(_LossPDEAbstract):
         Note that it must be a slice and not an integer
         (but a preprocessing of the user provided argument takes care of it)
     norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
-
+        Monte-Carlo sample points for computing the
         normalization constant. Default is None.
-
-
-
-
+    norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
+        The importance sampling weights for Monte-Carlo integration of the
+        normalization constant. Must be provided if `norm_samples` is provided.
+        `norm_weights` should have the same leading dimension as
+        `norm_samples`.
+        Alternatively, the user can pass a float or an integer.
+        These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
+        $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
@@ -370,8 +395,8 @@ class LossPDEStatio(_LossPDEAbstract):

     def _get_dynamic_loss_batch(
         self, batch: PDEStatioBatch
-    ) ->
-        return
+    ) -> Float[Array, "batch_size dimension"]:
+        return batch.domain_batch

     def _get_normalization_loss_batch(
         self, _
@@ -432,8 +457,8 @@ class LossPDEStatio(_LossPDEAbstract):
                 self.u,
                 self._get_normalization_loss_batch(batch),
                 _set_derivatives(params, self.derivative_keys.norm_loss),
-
-                self.
+                vmap_in_axes_params,
+                self.norm_weights,
                 self.loss_weights.norm_loss,
             )
         else:
@@ -551,12 +576,16 @@ class LossPDENonStatio(LossPDEStatio):
         Note that it must be a slice and not an integer
         (but a preprocessing of the user provided argument takes care of it)
     norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
-
+        Monte-Carlo sample points for computing the
         normalization constant. Default is None.
-
-
-
-
+    norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
+        The importance sampling weights for Monte-Carlo integration of the
+        normalization constant. Must be provided if `norm_samples` is provided.
+        `norm_weights` should have the same leading dimension as
+        `norm_samples`.
+        Alternatively, the user can pass a float or an integer.
+        These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
+        $q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
     obs_slice : slice, default=None
         slice object specifying the begininning/ending of the PINN output
         that is observed (this is then useful for multidim PINN). Default is None.
@@ -575,6 +604,9 @@ class LossPDENonStatio(LossPDEStatio):
         kw_only=True, default=None, static=True
     )

+    _max_norm_samples_omega: Int = eqx.field(init=False, static=True)
+    _max_norm_time_slices: Int = eqx.field(init=False, static=True)
+
     def __post_init__(self, params=None):
         """
         Note that neither __init__ or __post_init__ are called when udating a
@@ -585,7 +617,7 @@ class LossPDENonStatio(LossPDEStatio):
         )  # because __init__ or __post_init__ of Base
         # class is not automatically called

-        self.vmap_in_axes = (0,
+        self.vmap_in_axes = (0,)  # for t_x

         if self.initial_condition_fun is None:
             warnings.warn(
@@ -593,28 +625,28 @@ class LossPDENonStatio(LossPDEStatio):
                 "case (e.g by. hardcoding it into the PINN output)."
             )

+        # witht the variables below we avoid memory overflow since a cartesian
+        # product is taken
+        self._max_norm_time_slices = 100
+        self._max_norm_samples_omega = 1000
+
     def _get_dynamic_loss_batch(
         self, batch: PDENonStatioBatch
-    ) ->
-        times_batch = batch.times_x_inside_batch[:, 0:1]
-        omega_batch = batch.times_x_inside_batch[:, 1:]
-        return (times_batch, omega_batch)
+    ) -> Float[Array, "batch_size 1+dimension"]:
+        return batch.domain_batch

     def _get_normalization_loss_batch(
         self, batch: PDENonStatioBatch
-    ) ->
+    ) -> Float[Array, "nb_norm_time_slices nb_norm_samples dimension"]:
         return (
-            batch.times_x_inside_batch[:, 0:1],
-            self.norm_samples,
+            batch.domain_batch[: self._max_norm_time_slices, 0:1],
+            self.norm_samples[: self._max_norm_samples_omega],
         )

     def _get_observations_loss_batch(
         self, batch: PDENonStatioBatch
     ) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
-        return (
-            batch.obs_batch_dict["pinn_in"][:, 0:1],
-            batch.obs_batch_dict["pinn_in"][:, 1:],
-        )
+        return (batch.obs_batch_dict["pinn_in"],)

     def __call__(self, *args, **kwargs):
         return self.evaluate(*args, **kwargs)
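The comment above refers to the normalization term evaluating the network on every (time, space) pair, i.e. the Cartesian product of the retained time slices and spatial samples; the two caps bound that product. A rough plain-JAX sketch of the shape involved (the `u` stand-in and sizes simply mirror the defaults above):

```python
import jax
import jax.numpy as jnp

u = lambda t_x: jnp.exp(-jnp.sum(t_x**2))     # stand-in for the PINN
times = jnp.linspace(0.0, 1.0, 100)[:, None]  # <= _max_norm_time_slices rows
omega = jnp.zeros((1000, 2))                  # <= _max_norm_samples_omega rows

values = jax.vmap(
    lambda t: jax.vmap(lambda x: u(jnp.concatenate([t, x])))(omega)
)(times)
values.shape  # (100, 1000): 100_000 evaluations, bounded by the two caps
```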
@@ -637,7 +669,7 @@ class LossPDENonStatio(LossPDEStatio):
         of parameters (eg. for metamodeling) and an optional additional batch of observed
         inputs/outputs/parameters
         """
-        omega_batch = batch.
+        omega_batch = batch.initial_batch

         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
@@ -660,7 +692,6 @@ class LossPDENonStatio(LossPDEStatio):
                 _set_derivatives(params, self.derivative_keys.initial_condition),
                 (0,) + vmap_in_axes_params,
                 self.initial_condition_fun,
-                omega_batch.shape[0],
                 self.loss_weights.initial_condition,
             )
         else:
@@ -730,15 +761,21 @@ class SystemLossPDE(eqx.Module):
         (default) then no temporal boundary condition is applied
         Must share the keys of `u_dict`
     norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None, default=None
-        A dict of
-        normalization constant. Default is None
-        Must share the keys of `u_dict`
-    norm_int_length_dict : Dict[str, float | None] | None, default=None
-        A dict of Float. The domain area
-        (or interval length in 1D) upon which we perform the numerical
-        integration for each element of u_dict.
-        Default is None
+        A dict of Monte-Carlo sample points for computing the
+        normalization constant. Default is None.
         Must share the keys of `u_dict`
+    norm_weights_dict : Dict[str, Array[Float, "nb_norm_samples"] | float | int | None] | None, default=None
+        A dict of jnp.array with the same keys as `u_dict`. The importance
+        sampling weights for Monte-Carlo integration of the
+        normalization constant for each element of u_dict. Must be provided if
+        `norm_samples_dict` is provided.
+        `norm_weights_dict[key]` should have the same leading dimension as
+        `norm_samples_dict[key]` for each `key`.
+        Alternatively, the user can pass a float or an integer.
+        For each key, an array of similar shape to `norm_samples_dict[key]`
+        or shape `(1,)` is expected. These corresponds to the weights $w_k =
+        \frac{1}{q(x_k)}$ where $q(\cdot)$ is the proposal p.d.f. and $x_k$ are
+        the Monte-Carlo samples. Default is None
     obs_slice_dict : Dict[str, slice | None] | None, default=None
         dict of obs_slice, with keys from `u_dict` to designate the
         output(s) channels that are forced to observed values, for each
@@ -774,9 +811,9 @@ class SystemLossPDE(eqx.Module):
     norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
         eqx.field(kw_only=True, default=None)
     )
-    norm_int_length_dict: Dict[str, float | None] | None = (
-        eqx.field(kw_only=True, default=None)
-    )
+    norm_weights_dict: (
+        Dict[str, Float[Array, "nb_norm_samples dimension"] | float | int | None] | None
+    ) = eqx.field(kw_only=True, default=None)
     obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
         kw_only=True, default=None, static=True
     )
@@ -819,8 +856,8 @@ class SystemLossPDE(eqx.Module):
             self.initial_condition_fun_dict = self.u_dict_with_none
         if self.norm_samples_dict is None:
             self.norm_samples_dict = self.u_dict_with_none
-        if self.norm_int_length_dict is None:
-            self.norm_int_length_dict = self.u_dict_with_none
+        if self.norm_weights_dict is None:
+            self.norm_weights_dict = self.u_dict_with_none
         if self.obs_slice_dict is None:
             self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
         if self.u_dict.keys() != self.obs_slice_dict.keys():
@@ -861,7 +898,7 @@ class SystemLossPDE(eqx.Module):
             or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
             or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
             or self.u_dict.keys() != self.norm_samples_dict.keys()
-            or self.u_dict.keys() != self.norm_int_length_dict.keys()
+            or self.u_dict.keys() != self.norm_weights_dict.keys()
         ):
             raise ValueError("All the dicts concerning the PINNs should have same keys")

@@ -890,7 +927,7 @@ class SystemLossPDE(eqx.Module):
                     omega_boundary_condition=self.omega_boundary_condition_dict[i],
                     omega_boundary_dim=self.omega_boundary_dim_dict[i],
                     norm_samples=self.norm_samples_dict[i],
-
+                    norm_weights=self.norm_weights_dict[i],
                     obs_slice=self.obs_slice_dict[i],
                 )
             elif self.u_dict[i].eq_type == "nonstatio_PDE":
@@ -911,7 +948,7 @@ class SystemLossPDE(eqx.Module):
                     omega_boundary_dim=self.omega_boundary_dim_dict[i],
                     initial_condition_fun=self.initial_condition_fun_dict[i],
                     norm_samples=self.norm_samples_dict[i],
-
+                    norm_weights=self.norm_weights_dict[i],
                     obs_slice=self.obs_slice_dict[i],
                 )
             else:
@@ -1016,19 +1053,7 @@ class SystemLossPDE(eqx.Module):
         if self.u_dict.keys() != params_dict.nn_params.keys():
             raise ValueError("u_dict and params_dict[nn_params] should have same keys ")

-        if isinstance(batch, PDEStatioBatch):
-            omega_batch, _ = batch.inside_batch, batch.border_batch
-            vmap_in_axes_x_or_x_t = (0,)
-
-            batches = (omega_batch,)
-        elif isinstance(batch, PDENonStatioBatch):
-            times_batch = batch.times_x_inside_batch[:, 0:1]
-            omega_batch = batch.times_x_inside_batch[:, 1:]
-
-            batches = (omega_batch, times_batch)
-            vmap_in_axes_x_or_x_t = (0, 0)
-        else:
-            raise ValueError("Wrong type of batch")
+        vmap_in_axes = (0,)

         # Retrieve the optional eq_params_batch
         # and update eq_params with the latter
@@ -1036,7 +1061,6 @@ class SystemLossPDE(eqx.Module):
         if batch.param_batch_dict is not None:
             eq_params_batch_dict = batch.param_batch_dict

-            # TODO
             # feed the eq_params with the batch
             for k in eq_params_batch_dict.keys():
                 params_dict.eq_params[k] = eq_params_batch_dict[k]
@@ -1050,11 +1074,15 @@ class SystemLossPDE(eqx.Module):
             return dynamic_loss_apply(
                 dyn_loss.evaluate,
                 self.u_dict,
-
+                (
+                    batch.domain_batch
+                    if isinstance(batch, PDEStatioBatch)
+                    else batch.domain_batch
+                ),
                 _set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
-
+                vmap_in_axes + vmap_in_axes_params,
                 loss_weight,
-                u_type=
+                u_type=list(self.u_dict.values())[0].__class__.__base__,
             )

         dyn_loss_mse_dict = jax.tree_util.tree_map(
jinns/loss/__init__.py
CHANGED

@@ -3,7 +3,7 @@ from ._LossODE import LossODE, SystemLossODE
 from ._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
 from ._DynamicLoss import (
     GeneralizedLotkaVolterra,
-
+    BurgersEquation,
     FPENonStatioLoss2D,
     OU_FPENonStatioLoss2D,
     FisherKPP,
@@ -19,9 +19,10 @@ from ._loss_weights import (
 )

 from ._operators import (
-
-
-
-
-
+    divergence_fwd,
+    divergence_rev,
+    laplacian_fwd,
+    laplacian_rev,
+    vectorial_laplacian_fwd,
+    vectorial_laplacian_rev,
 )