jinns 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +532 -341
- jinns/loss/_DynamicLoss.py +150 -173
- jinns/loss/_DynamicLossAbstract.py +25 -73
- jinns/loss/_LossODE.py +3 -3
- jinns/loss/_LossPDE.py +27 -36
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +148 -279
- jinns/loss/_loss_utils.py +78 -56
- jinns/loss/_operators.py +441 -184
- jinns/plot/_plot.py +111 -98
- jinns/solver/_rar.py +102 -407
- jinns/solver/_solve.py +73 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +3 -1
- jinns/utils/_hyperpinn.py +17 -7
- jinns/utils/_pinn.py +17 -27
- jinns/utils/_ppinn.py +227 -0
- jinns/utils/_save_load.py +13 -13
- jinns/utils/_spinn.py +24 -43
- jinns/utils/_types.py +1 -0
- jinns/utils/_utils.py +40 -12
- jinns-1.2.0.dist-info/METADATA +127 -0
- jinns-1.2.0.dist-info/RECORD +41 -0
- {jinns-1.1.0.dist-info → jinns-1.2.0.dist-info}/WHEEL +1 -1
- jinns-1.1.0.dist-info/METADATA +0 -85
- jinns-1.1.0.dist-info/RECORD +0 -39
- {jinns-1.1.0.dist-info → jinns-1.2.0.dist-info}/AUTHORS +0 -0
- {jinns-1.1.0.dist-info → jinns-1.2.0.dist-info}/LICENSE +0 -0
- {jinns-1.1.0.dist-info → jinns-1.2.0.dist-info}/top_level.txt +0 -0
|
@@ -21,53 +21,22 @@ else:
|
|
|
21
21
|
from equinox import AbstractClassVar
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
def _decorator_heteregeneous_params(evaluate
|
|
24
|
+
def _decorator_heteregeneous_params(evaluate):
|
|
25
25
|
|
|
26
|
-
def
|
|
27
|
-
self,
|
|
26
|
+
def wrapper(*args):
|
|
27
|
+
self, inputs, u, params = args
|
|
28
28
|
_params = eqx.tree_at(
|
|
29
29
|
lambda p: p.eq_params,
|
|
30
30
|
params,
|
|
31
31
|
self._eval_heterogeneous_parameters(
|
|
32
|
-
|
|
32
|
+
inputs, u, params, self.eq_params_heterogeneity
|
|
33
33
|
),
|
|
34
34
|
)
|
|
35
35
|
new_args = args[:-1] + (_params,)
|
|
36
36
|
res = evaluate(*new_args)
|
|
37
37
|
return res
|
|
38
38
|
|
|
39
|
-
|
|
40
|
-
self, x, u, params = args
|
|
41
|
-
_params = eqx.tree_at(
|
|
42
|
-
lambda p: p.eq_params,
|
|
43
|
-
params,
|
|
44
|
-
self._eval_heterogeneous_parameters(
|
|
45
|
-
None, x, u, params, self.eq_params_heterogeneity
|
|
46
|
-
),
|
|
47
|
-
)
|
|
48
|
-
new_args = args[:-1] + (_params,)
|
|
49
|
-
res = evaluate(*new_args)
|
|
50
|
-
return res
|
|
51
|
-
|
|
52
|
-
def wrapper_pde_non_statio(*args):
|
|
53
|
-
self, t, x, u, params = args
|
|
54
|
-
_params = eqx.tree_at(
|
|
55
|
-
lambda p: p.eq_params,
|
|
56
|
-
params,
|
|
57
|
-
self._eval_heterogeneous_parameters(
|
|
58
|
-
t, x, u, params, self.eq_params_heterogeneity
|
|
59
|
-
),
|
|
60
|
-
)
|
|
61
|
-
new_args = args[:-1] + (_params,)
|
|
62
|
-
res = evaluate(*new_args)
|
|
63
|
-
return res
|
|
64
|
-
|
|
65
|
-
if eq_type == "ODE":
|
|
66
|
-
return wrapper_ode
|
|
67
|
-
elif eq_type == "Statio PDE":
|
|
68
|
-
return wrapper_pde_statio
|
|
69
|
-
elif eq_type == "Non-statio PDE":
|
|
70
|
-
return wrapper_pde_non_statio
|
|
39
|
+
return wrapper
|
|
71
40
|
|
|
72
41
|
|
|
73
42
|
class DynamicLoss(eqx.Module):
|
|
@@ -110,8 +79,7 @@ class DynamicLoss(eqx.Module):
|
|
|
110
79
|
|
|
111
80
|
def _eval_heterogeneous_parameters(
|
|
112
81
|
self,
|
|
113
|
-
|
|
114
|
-
x: Float[Array, "dim"],
|
|
82
|
+
inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
|
|
115
83
|
u: eqx.Module,
|
|
116
84
|
params: Params | ParamsDict,
|
|
117
85
|
eq_params_heterogeneity: Dict[str, Callable | None] = None,
|
|
@@ -124,14 +92,7 @@ class DynamicLoss(eqx.Module):
|
|
|
124
92
|
if eq_params_heterogeneity[k] is None:
|
|
125
93
|
eq_params_[k] = p
|
|
126
94
|
else:
|
|
127
|
-
|
|
128
|
-
# signature will vary according to _eq_type
|
|
129
|
-
if self._eq_type == "ODE":
|
|
130
|
-
eq_params_[k] = eq_params_heterogeneity[k](t, u, params)
|
|
131
|
-
elif self._eq_type == "Statio PDE":
|
|
132
|
-
eq_params_[k] = eq_params_heterogeneity[k](x, u, params)
|
|
133
|
-
elif self._eq_type == "Non-statio PDE":
|
|
134
|
-
eq_params_[k] = eq_params_heterogeneity[k](t, x, u, params)
|
|
95
|
+
eq_params_[k] = eq_params_heterogeneity[k](inputs, u, params)
|
|
135
96
|
except KeyError:
|
|
136
97
|
# we authorize missing eq_params_heterogeneity key
|
|
137
98
|
# if its heterogeneity is None anyway
|
|
@@ -140,22 +101,17 @@ class DynamicLoss(eqx.Module):
|
|
|
140
101
|
|
|
141
102
|
def _evaluate(
|
|
142
103
|
self,
|
|
143
|
-
|
|
144
|
-
x: Float[Array, "dim"],
|
|
104
|
+
inputs: Float[Array, "1"] | Float[Array, "dim"] | Float[Array, "1+dim"],
|
|
145
105
|
u: eqx.Module,
|
|
146
106
|
params: Params | ParamsDict,
|
|
147
107
|
) -> float:
|
|
148
|
-
|
|
149
|
-
if
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
else:
|
|
156
|
-
raise NotImplementedError("the equation type is not handled.")
|
|
157
|
-
|
|
158
|
-
return ans
|
|
108
|
+
evaluation = self.equation(inputs, u, params)
|
|
109
|
+
if len(evaluation.shape) == 0:
|
|
110
|
+
raise ValueError(
|
|
111
|
+
"The output of dynamic loss must be vectorial, "
|
|
112
|
+
"i.e. of shape (d,) with d >= 1"
|
|
113
|
+
)
|
|
114
|
+
return evaluation
|
|
159
115
|
|
|
160
116
|
@abc.abstractmethod
|
|
161
117
|
def equation(self, *args, **kwargs):
|
|
@@ -191,7 +147,7 @@ class ODE(DynamicLoss):
|
|
|
191
147
|
|
|
192
148
|
_eq_type: ClassVar[str] = "ODE"
|
|
193
149
|
|
|
194
|
-
@partial(_decorator_heteregeneous_params
|
|
150
|
+
@partial(_decorator_heteregeneous_params)
|
|
195
151
|
def evaluate(
|
|
196
152
|
self,
|
|
197
153
|
t: Float[Array, "1"],
|
|
@@ -199,7 +155,7 @@ class ODE(DynamicLoss):
|
|
|
199
155
|
params: Params | ParamsDict,
|
|
200
156
|
) -> float:
|
|
201
157
|
"""Here we call DynamicLoss._evaluate with x=None"""
|
|
202
|
-
return self._evaluate(t,
|
|
158
|
+
return self._evaluate(t, u, params)
|
|
203
159
|
|
|
204
160
|
@abc.abstractmethod
|
|
205
161
|
def equation(
|
|
@@ -260,12 +216,12 @@ class PDEStatio(DynamicLoss):
|
|
|
260
216
|
|
|
261
217
|
_eq_type: ClassVar[str] = "Statio PDE"
|
|
262
218
|
|
|
263
|
-
@partial(_decorator_heteregeneous_params
|
|
219
|
+
@partial(_decorator_heteregeneous_params)
|
|
264
220
|
def evaluate(
|
|
265
221
|
self, x: Float[Array, "dimension"], u: eqx.Module, params: Params | ParamsDict
|
|
266
222
|
) -> float:
|
|
267
223
|
"""Here we call the DynamicLoss._evaluate with t=None"""
|
|
268
|
-
return self._evaluate(
|
|
224
|
+
return self._evaluate(x, u, params)
|
|
269
225
|
|
|
270
226
|
@abc.abstractmethod
|
|
271
227
|
def equation(
|
|
@@ -325,23 +281,21 @@ class PDENonStatio(DynamicLoss):
|
|
|
325
281
|
|
|
326
282
|
_eq_type: ClassVar[str] = "Non-statio PDE"
|
|
327
283
|
|
|
328
|
-
@partial(_decorator_heteregeneous_params
|
|
284
|
+
@partial(_decorator_heteregeneous_params)
|
|
329
285
|
def evaluate(
|
|
330
286
|
self,
|
|
331
|
-
|
|
332
|
-
x: Float[Array, "dim"],
|
|
287
|
+
t_x: Float[Array, "1 + dim"],
|
|
333
288
|
u: eqx.Module,
|
|
334
289
|
params: Params | ParamsDict,
|
|
335
290
|
) -> float:
|
|
336
291
|
"""Here we call the DynamicLoss._evaluate with full arguments"""
|
|
337
|
-
ans = self._evaluate(
|
|
292
|
+
ans = self._evaluate(t_x, u, params)
|
|
338
293
|
return ans
|
|
339
294
|
|
|
340
295
|
@abc.abstractmethod
|
|
341
296
|
def equation(
|
|
342
297
|
self,
|
|
343
|
-
|
|
344
|
-
x: Float[Array, "dim"],
|
|
298
|
+
t_x: Float[Array, "1 + dim"],
|
|
345
299
|
u: eqx.Module,
|
|
346
300
|
params: Params | ParamsDict,
|
|
347
301
|
) -> float:
|
|
@@ -353,10 +307,8 @@ class PDENonStatio(DynamicLoss):
|
|
|
353
307
|
|
|
354
308
|
Parameters
|
|
355
309
|
----------
|
|
356
|
-
|
|
357
|
-
A
|
|
358
|
-
x : Float[Array, "d"]
|
|
359
|
-
A `d` dimensional jnp.array representing a point in the spatial domain $\Omega$.
|
|
310
|
+
t_x : Float[Array, "1 + dim"]
|
|
311
|
+
A jnp array containing the concatenation of a time point and a point in $\Omega$
|
|
360
312
|
u : eqx.Module
|
|
361
313
|
The neural network.
|
|
362
314
|
params : Params | ParamsDict
|
jinns/loss/_LossODE.py
CHANGED
|
@@ -209,7 +209,7 @@ class LossODE(_LossODEAbstract):
|
|
|
209
209
|
mse_dyn_loss = dynamic_loss_apply(
|
|
210
210
|
self.dynamic_loss.evaluate,
|
|
211
211
|
self.u,
|
|
212
|
-
|
|
212
|
+
temporal_batch,
|
|
213
213
|
_set_derivatives(params, self.derivative_keys.dyn_loss),
|
|
214
214
|
self.vmap_in_axes + vmap_in_axes_params,
|
|
215
215
|
self.loss_weights.dyn_loss,
|
|
@@ -227,7 +227,7 @@ class LossODE(_LossODEAbstract):
|
|
|
227
227
|
else:
|
|
228
228
|
v_u = vmap(self.u, (None,) + vmap_in_axes_params)
|
|
229
229
|
t0, u0 = self.initial_condition # pylint: disable=unpacking-non-sequence
|
|
230
|
-
t0 = jnp.array(t0)
|
|
230
|
+
t0 = jnp.array([t0])
|
|
231
231
|
u0 = jnp.array(u0)
|
|
232
232
|
mse_initial_condition = jnp.mean(
|
|
233
233
|
self.loss_weights.initial_condition
|
|
@@ -522,7 +522,7 @@ class SystemLossODE(eqx.Module):
|
|
|
522
522
|
return dynamic_loss_apply(
|
|
523
523
|
dyn_loss.evaluate,
|
|
524
524
|
self.u_dict,
|
|
525
|
-
|
|
525
|
+
temporal_batch,
|
|
526
526
|
_set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
|
|
527
527
|
vmap_in_axes_t + vmap_in_axes_params,
|
|
528
528
|
loss_weight,
|
jinns/loss/_LossPDE.py
CHANGED
|
@@ -22,9 +22,7 @@ from jinns.loss._loss_utils import (
|
|
|
22
22
|
initial_condition_apply,
|
|
23
23
|
constraints_system_loss_apply,
|
|
24
24
|
)
|
|
25
|
-
from jinns.data._DataGenerators import
|
|
26
|
-
append_obs_batch,
|
|
27
|
-
)
|
|
25
|
+
from jinns.data._DataGenerators import append_obs_batch
|
|
28
26
|
from jinns.parameters._params import (
|
|
29
27
|
_get_vmap_in_axes_params,
|
|
30
28
|
_update_eq_params_dict,
|
|
@@ -370,8 +368,8 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
370
368
|
|
|
371
369
|
def _get_dynamic_loss_batch(
|
|
372
370
|
self, batch: PDEStatioBatch
|
|
373
|
-
) ->
|
|
374
|
-
return
|
|
371
|
+
) -> Float[Array, "batch_size dimension"]:
|
|
372
|
+
return batch.domain_batch
|
|
375
373
|
|
|
376
374
|
def _get_normalization_loss_batch(
|
|
377
375
|
self, _
|
|
@@ -432,7 +430,7 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
432
430
|
self.u,
|
|
433
431
|
self._get_normalization_loss_batch(batch),
|
|
434
432
|
_set_derivatives(params, self.derivative_keys.norm_loss),
|
|
435
|
-
|
|
433
|
+
vmap_in_axes_params,
|
|
436
434
|
self.norm_int_length,
|
|
437
435
|
self.loss_weights.norm_loss,
|
|
438
436
|
)
|
|
@@ -575,6 +573,9 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
575
573
|
kw_only=True, default=None, static=True
|
|
576
574
|
)
|
|
577
575
|
|
|
576
|
+
_max_norm_samples_omega: Int = eqx.field(init=False, static=True)
|
|
577
|
+
_max_norm_time_slices: Int = eqx.field(init=False, static=True)
|
|
578
|
+
|
|
578
579
|
def __post_init__(self, params=None):
|
|
579
580
|
"""
|
|
580
581
|
Note that neither __init__ or __post_init__ are called when udating a
|
|
@@ -585,7 +586,7 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
585
586
|
) # because __init__ or __post_init__ of Base
|
|
586
587
|
# class is not automatically called
|
|
587
588
|
|
|
588
|
-
self.vmap_in_axes = (0,
|
|
589
|
+
self.vmap_in_axes = (0,) # for t_x
|
|
589
590
|
|
|
590
591
|
if self.initial_condition_fun is None:
|
|
591
592
|
warnings.warn(
|
|
@@ -593,28 +594,28 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
593
594
|
"case (e.g by. hardcoding it into the PINN output)."
|
|
594
595
|
)
|
|
595
596
|
|
|
597
|
+
# witht the variables below we avoid memory overflow since a cartesian
|
|
598
|
+
# product is taken
|
|
599
|
+
self._max_norm_time_slices = 100
|
|
600
|
+
self._max_norm_samples_omega = 1000
|
|
601
|
+
|
|
596
602
|
def _get_dynamic_loss_batch(
|
|
597
603
|
self, batch: PDENonStatioBatch
|
|
598
|
-
) ->
|
|
599
|
-
|
|
600
|
-
omega_batch = batch.times_x_inside_batch[:, 1:]
|
|
601
|
-
return (times_batch, omega_batch)
|
|
604
|
+
) -> Float[Array, "batch_size 1+dimension"]:
|
|
605
|
+
return batch.domain_batch
|
|
602
606
|
|
|
603
607
|
def _get_normalization_loss_batch(
|
|
604
608
|
self, batch: PDENonStatioBatch
|
|
605
|
-
) ->
|
|
609
|
+
) -> Float[Array, "nb_norm_time_slices nb_norm_samples dimension"]:
|
|
606
610
|
return (
|
|
607
|
-
batch.
|
|
608
|
-
self.norm_samples,
|
|
611
|
+
batch.domain_batch[: self._max_norm_time_slices, 0:1],
|
|
612
|
+
self.norm_samples[: self._max_norm_samples_omega],
|
|
609
613
|
)
|
|
610
614
|
|
|
611
615
|
def _get_observations_loss_batch(
|
|
612
616
|
self, batch: PDENonStatioBatch
|
|
613
617
|
) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
|
|
614
|
-
return (
|
|
615
|
-
batch.obs_batch_dict["pinn_in"][:, 0:1],
|
|
616
|
-
batch.obs_batch_dict["pinn_in"][:, 1:],
|
|
617
|
-
)
|
|
618
|
+
return (batch.obs_batch_dict["pinn_in"],)
|
|
618
619
|
|
|
619
620
|
def __call__(self, *args, **kwargs):
|
|
620
621
|
return self.evaluate(*args, **kwargs)
|
|
@@ -637,7 +638,7 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
637
638
|
of parameters (eg. for metamodeling) and an optional additional batch of observed
|
|
638
639
|
inputs/outputs/parameters
|
|
639
640
|
"""
|
|
640
|
-
omega_batch = batch.
|
|
641
|
+
omega_batch = batch.initial_batch
|
|
641
642
|
|
|
642
643
|
# Retrieve the optional eq_params_batch
|
|
643
644
|
# and update eq_params with the latter
|
|
@@ -660,7 +661,6 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
660
661
|
_set_derivatives(params, self.derivative_keys.initial_condition),
|
|
661
662
|
(0,) + vmap_in_axes_params,
|
|
662
663
|
self.initial_condition_fun,
|
|
663
|
-
omega_batch.shape[0],
|
|
664
664
|
self.loss_weights.initial_condition,
|
|
665
665
|
)
|
|
666
666
|
else:
|
|
@@ -1016,19 +1016,7 @@ class SystemLossPDE(eqx.Module):
|
|
|
1016
1016
|
if self.u_dict.keys() != params_dict.nn_params.keys():
|
|
1017
1017
|
raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
|
|
1018
1018
|
|
|
1019
|
-
|
|
1020
|
-
omega_batch, _ = batch.inside_batch, batch.border_batch
|
|
1021
|
-
vmap_in_axes_x_or_x_t = (0,)
|
|
1022
|
-
|
|
1023
|
-
batches = (omega_batch,)
|
|
1024
|
-
elif isinstance(batch, PDENonStatioBatch):
|
|
1025
|
-
times_batch = batch.times_x_inside_batch[:, 0:1]
|
|
1026
|
-
omega_batch = batch.times_x_inside_batch[:, 1:]
|
|
1027
|
-
|
|
1028
|
-
batches = (omega_batch, times_batch)
|
|
1029
|
-
vmap_in_axes_x_or_x_t = (0, 0)
|
|
1030
|
-
else:
|
|
1031
|
-
raise ValueError("Wrong type of batch")
|
|
1019
|
+
vmap_in_axes = (0,)
|
|
1032
1020
|
|
|
1033
1021
|
# Retrieve the optional eq_params_batch
|
|
1034
1022
|
# and update eq_params with the latter
|
|
@@ -1036,7 +1024,6 @@ class SystemLossPDE(eqx.Module):
|
|
|
1036
1024
|
if batch.param_batch_dict is not None:
|
|
1037
1025
|
eq_params_batch_dict = batch.param_batch_dict
|
|
1038
1026
|
|
|
1039
|
-
# TODO
|
|
1040
1027
|
# feed the eq_params with the batch
|
|
1041
1028
|
for k in eq_params_batch_dict.keys():
|
|
1042
1029
|
params_dict.eq_params[k] = eq_params_batch_dict[k]
|
|
@@ -1050,9 +1037,13 @@ class SystemLossPDE(eqx.Module):
|
|
|
1050
1037
|
return dynamic_loss_apply(
|
|
1051
1038
|
dyn_loss.evaluate,
|
|
1052
1039
|
self.u_dict,
|
|
1053
|
-
|
|
1040
|
+
(
|
|
1041
|
+
batch.domain_batch
|
|
1042
|
+
if isinstance(batch, PDEStatioBatch)
|
|
1043
|
+
else batch.domain_batch
|
|
1044
|
+
),
|
|
1054
1045
|
_set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
|
|
1055
|
-
|
|
1046
|
+
vmap_in_axes + vmap_in_axes_params,
|
|
1056
1047
|
loss_weight,
|
|
1057
1048
|
u_type=type(list(self.u_dict.values())[0]),
|
|
1058
1049
|
)
|
jinns/loss/__init__.py
CHANGED
|
@@ -3,7 +3,7 @@ from ._LossODE import LossODE, SystemLossODE
|
|
|
3
3
|
from ._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
|
|
4
4
|
from ._DynamicLoss import (
|
|
5
5
|
GeneralizedLotkaVolterra,
|
|
6
|
-
|
|
6
|
+
BurgersEquation,
|
|
7
7
|
FPENonStatioLoss2D,
|
|
8
8
|
OU_FPENonStatioLoss2D,
|
|
9
9
|
FisherKPP,
|
|
@@ -19,9 +19,10 @@ from ._loss_weights import (
|
|
|
19
19
|
)
|
|
20
20
|
|
|
21
21
|
from ._operators import (
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
22
|
+
divergence_fwd,
|
|
23
|
+
divergence_rev,
|
|
24
|
+
laplacian_fwd,
|
|
25
|
+
laplacian_rev,
|
|
26
|
+
vectorial_laplacian_fwd,
|
|
27
|
+
vectorial_laplacian_rev,
|
|
27
28
|
)
|