jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +532 -341
- jinns/loss/_DynamicLoss.py +150 -173
- jinns/loss/_DynamicLossAbstract.py +27 -73
- jinns/loss/_LossODE.py +45 -26
- jinns/loss/_LossPDE.py +85 -84
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +148 -279
- jinns/loss/_loss_utils.py +85 -58
- jinns/loss/_operators.py +441 -184
- jinns/parameters/_derivative_keys.py +487 -60
- jinns/plot/_plot.py +111 -98
- jinns/solver/_rar.py +102 -407
- jinns/solver/_solve.py +73 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +3 -1
- jinns/utils/_hyperpinn.py +17 -7
- jinns/utils/_pinn.py +17 -27
- jinns/utils/_ppinn.py +227 -0
- jinns/utils/_save_load.py +13 -13
- jinns/utils/_spinn.py +24 -43
- jinns/utils/_types.py +1 -0
- jinns/utils/_utils.py +40 -12
- jinns-1.2.0.dist-info/AUTHORS +2 -0
- jinns-1.2.0.dist-info/METADATA +127 -0
- jinns-1.2.0.dist-info/RECORD +41 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/WHEEL +1 -1
- jinns-1.0.0.dist-info/METADATA +0 -84
- jinns-1.0.0.dist-info/RECORD +0 -38
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/LICENSE +0 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py
CHANGED
|
@@ -22,9 +22,7 @@ from jinns.loss._loss_utils import (
|
|
|
22
22
|
initial_condition_apply,
|
|
23
23
|
constraints_system_loss_apply,
|
|
24
24
|
)
|
|
25
|
-
from jinns.data._DataGenerators import
|
|
26
|
-
append_obs_batch,
|
|
27
|
-
)
|
|
25
|
+
from jinns.data._DataGenerators import append_obs_batch
|
|
28
26
|
from jinns.parameters._params import (
|
|
29
27
|
_get_vmap_in_axes_params,
|
|
30
28
|
_update_eq_params_dict,
|
|
@@ -103,6 +101,9 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
103
101
|
obs_slice : slice, default=None
|
|
104
102
|
slice object specifying the begininning/ending of the PINN output
|
|
105
103
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
104
|
+
params : InitVar[Params], default=None
|
|
105
|
+
The main Params object of the problem needed to instanciate the
|
|
106
|
+
DerivativeKeysODE if the latter is not specified.
|
|
106
107
|
"""
|
|
107
108
|
|
|
108
109
|
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
@@ -129,18 +130,26 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
129
130
|
norm_int_length: float | None = eqx.field(kw_only=True, default=None)
|
|
130
131
|
obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
|
|
131
132
|
|
|
132
|
-
|
|
133
|
+
params: InitVar[Params] = eqx.field(kw_only=True, default=None)
|
|
134
|
+
|
|
135
|
+
def __post_init__(self, params=None):
|
|
133
136
|
"""
|
|
134
137
|
Note that neither __init__ or __post_init__ are called when udating a
|
|
135
138
|
Module with eqx.tree_at
|
|
136
139
|
"""
|
|
137
140
|
if self.derivative_keys is None:
|
|
138
141
|
# be default we only take gradient wrt nn_params
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
142
|
+
try:
|
|
143
|
+
self.derivative_keys = (
|
|
144
|
+
DerivativeKeysPDENonStatio(params=params)
|
|
145
|
+
if isinstance(self, LossPDENonStatio)
|
|
146
|
+
else DerivativeKeysPDEStatio(params=params)
|
|
147
|
+
)
|
|
148
|
+
except ValueError as exc:
|
|
149
|
+
raise ValueError(
|
|
150
|
+
"Problem at self.derivative_keys initialization "
|
|
151
|
+
f"received {self.derivative_keys=} and {params=}"
|
|
152
|
+
) from exc
|
|
144
153
|
|
|
145
154
|
if self.loss_weights is None:
|
|
146
155
|
self.loss_weights = (
|
|
@@ -324,6 +333,9 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
324
333
|
obs_slice : slice, default=None
|
|
325
334
|
slice object specifying the begininning/ending of the PINN output
|
|
326
335
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
336
|
+
params : InitVar[Params], default=None
|
|
337
|
+
The main Params object of the problem needed to instanciate the
|
|
338
|
+
DerivativeKeysODE if the latter is not specified.
|
|
327
339
|
|
|
328
340
|
|
|
329
341
|
Raises
|
|
@@ -342,20 +354,22 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
342
354
|
|
|
343
355
|
vmap_in_axes: tuple[Int] = eqx.field(init=False, static=True)
|
|
344
356
|
|
|
345
|
-
def __post_init__(self):
|
|
357
|
+
def __post_init__(self, params=None):
|
|
346
358
|
"""
|
|
347
359
|
Note that neither __init__ or __post_init__ are called when udating a
|
|
348
360
|
Module with eqx.tree_at!
|
|
349
361
|
"""
|
|
350
|
-
super().__post_init__(
|
|
362
|
+
super().__post_init__(
|
|
363
|
+
params=params
|
|
364
|
+
) # because __init__ or __post_init__ of Base
|
|
351
365
|
# class is not automatically called
|
|
352
366
|
|
|
353
367
|
self.vmap_in_axes = (0,) # for x only here
|
|
354
368
|
|
|
355
369
|
def _get_dynamic_loss_batch(
|
|
356
370
|
self, batch: PDEStatioBatch
|
|
357
|
-
) ->
|
|
358
|
-
return
|
|
371
|
+
) -> Float[Array, "batch_size dimension"]:
|
|
372
|
+
return batch.domain_batch
|
|
359
373
|
|
|
360
374
|
def _get_normalization_loss_batch(
|
|
361
375
|
self, _
|
|
@@ -416,7 +430,7 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
416
430
|
self.u,
|
|
417
431
|
self._get_normalization_loss_batch(batch),
|
|
418
432
|
_set_derivatives(params, self.derivative_keys.norm_loss),
|
|
419
|
-
|
|
433
|
+
vmap_in_axes_params,
|
|
420
434
|
self.norm_int_length,
|
|
421
435
|
self.loss_weights.norm_loss,
|
|
422
436
|
)
|
|
@@ -547,6 +561,9 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
547
561
|
initial_condition_fun : Callable, default=None
|
|
548
562
|
A function representing the temporal initial condition. If None
|
|
549
563
|
(default) then no initial condition is applied
|
|
564
|
+
params : InitVar[Params], default=None
|
|
565
|
+
The main Params object of the problem needed to instanciate the
|
|
566
|
+
DerivativeKeysODE if the latter is not specified.
|
|
550
567
|
|
|
551
568
|
"""
|
|
552
569
|
|
|
@@ -556,15 +573,20 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
556
573
|
kw_only=True, default=None, static=True
|
|
557
574
|
)
|
|
558
575
|
|
|
559
|
-
|
|
576
|
+
_max_norm_samples_omega: Int = eqx.field(init=False, static=True)
|
|
577
|
+
_max_norm_time_slices: Int = eqx.field(init=False, static=True)
|
|
578
|
+
|
|
579
|
+
def __post_init__(self, params=None):
|
|
560
580
|
"""
|
|
561
581
|
Note that neither __init__ or __post_init__ are called when udating a
|
|
562
582
|
Module with eqx.tree_at!
|
|
563
583
|
"""
|
|
564
|
-
super().__post_init__(
|
|
584
|
+
super().__post_init__(
|
|
585
|
+
params=params
|
|
586
|
+
) # because __init__ or __post_init__ of Base
|
|
565
587
|
# class is not automatically called
|
|
566
588
|
|
|
567
|
-
self.vmap_in_axes = (0,
|
|
589
|
+
self.vmap_in_axes = (0,) # for t_x
|
|
568
590
|
|
|
569
591
|
if self.initial_condition_fun is None:
|
|
570
592
|
warnings.warn(
|
|
@@ -572,28 +594,28 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
572
594
|
"case (e.g by. hardcoding it into the PINN output)."
|
|
573
595
|
)
|
|
574
596
|
|
|
597
|
+
# witht the variables below we avoid memory overflow since a cartesian
|
|
598
|
+
# product is taken
|
|
599
|
+
self._max_norm_time_slices = 100
|
|
600
|
+
self._max_norm_samples_omega = 1000
|
|
601
|
+
|
|
575
602
|
def _get_dynamic_loss_batch(
|
|
576
603
|
self, batch: PDENonStatioBatch
|
|
577
|
-
) ->
|
|
578
|
-
|
|
579
|
-
omega_batch = batch.times_x_inside_batch[:, 1:]
|
|
580
|
-
return (times_batch, omega_batch)
|
|
604
|
+
) -> Float[Array, "batch_size 1+dimension"]:
|
|
605
|
+
return batch.domain_batch
|
|
581
606
|
|
|
582
607
|
def _get_normalization_loss_batch(
|
|
583
608
|
self, batch: PDENonStatioBatch
|
|
584
|
-
) ->
|
|
609
|
+
) -> Float[Array, "nb_norm_time_slices nb_norm_samples dimension"]:
|
|
585
610
|
return (
|
|
586
|
-
batch.
|
|
587
|
-
self.norm_samples,
|
|
611
|
+
batch.domain_batch[: self._max_norm_time_slices, 0:1],
|
|
612
|
+
self.norm_samples[: self._max_norm_samples_omega],
|
|
588
613
|
)
|
|
589
614
|
|
|
590
615
|
def _get_observations_loss_batch(
|
|
591
616
|
self, batch: PDENonStatioBatch
|
|
592
617
|
) -> tuple[Float[Array, "batch_size 1"], Float[Array, "batch_size dimension"]]:
|
|
593
|
-
return (
|
|
594
|
-
batch.obs_batch_dict["pinn_in"][:, 0:1],
|
|
595
|
-
batch.obs_batch_dict["pinn_in"][:, 1:],
|
|
596
|
-
)
|
|
618
|
+
return (batch.obs_batch_dict["pinn_in"],)
|
|
597
619
|
|
|
598
620
|
def __call__(self, *args, **kwargs):
|
|
599
621
|
return self.evaluate(*args, **kwargs)
|
|
@@ -616,8 +638,7 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
616
638
|
of parameters (eg. for metamodeling) and an optional additional batch of observed
|
|
617
639
|
inputs/outputs/parameters
|
|
618
640
|
"""
|
|
619
|
-
|
|
620
|
-
omega_batch = batch.times_x_inside_batch[:, 1:]
|
|
641
|
+
omega_batch = batch.initial_batch
|
|
621
642
|
|
|
622
643
|
# Retrieve the optional eq_params_batch
|
|
623
644
|
# and update eq_params with the latter
|
|
@@ -640,7 +661,6 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
640
661
|
_set_derivatives(params, self.derivative_keys.initial_condition),
|
|
641
662
|
(0,) + vmap_in_axes_params,
|
|
642
663
|
self.initial_condition_fun,
|
|
643
|
-
omega_batch.shape[0],
|
|
644
664
|
self.loss_weights.initial_condition,
|
|
645
665
|
)
|
|
646
666
|
else:
|
|
@@ -725,6 +745,9 @@ class SystemLossPDE(eqx.Module):
|
|
|
725
745
|
PINNs. Default is None. But if a value is given, all the entries of
|
|
726
746
|
`u_dict` must be represented here with default value `jnp.s_[...]`
|
|
727
747
|
if no particular slice is to be given
|
|
748
|
+
params : InitVar[ParamsDict], default=None
|
|
749
|
+
The main Params object of the problem needed to instanciate the
|
|
750
|
+
DerivativeKeysODE if the latter is not specified.
|
|
728
751
|
|
|
729
752
|
"""
|
|
730
753
|
|
|
@@ -763,22 +786,20 @@ class SystemLossPDE(eqx.Module):
|
|
|
763
786
|
loss_weights: InitVar[LossWeightsPDEDict | None] = eqx.field(
|
|
764
787
|
kw_only=True, default=None
|
|
765
788
|
)
|
|
789
|
+
params_dict: InitVar[ParamsDict] = eqx.field(kw_only=True, default=None)
|
|
766
790
|
|
|
767
791
|
# following have init=False and are set in the __post_init__
|
|
768
792
|
u_constraints_dict: Dict[str, LossPDEStatio | LossPDENonStatio] = eqx.field(
|
|
769
793
|
init=False
|
|
770
794
|
)
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
derivative_keys_dyn_loss_dict: Dict[
|
|
775
|
-
str, DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio
|
|
776
|
-
] = eqx.field(init=False)
|
|
795
|
+
derivative_keys_dyn_loss: DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio = (
|
|
796
|
+
eqx.field(init=False)
|
|
797
|
+
)
|
|
777
798
|
u_dict_with_none: Dict[str, None] = eqx.field(init=False)
|
|
778
799
|
# internally the loss weights are handled with a dictionary
|
|
779
800
|
_loss_weights: Dict[str, dict] = eqx.field(init=False)
|
|
780
801
|
|
|
781
|
-
def __post_init__(self, loss_weights):
|
|
802
|
+
def __post_init__(self, loss_weights=None, params_dict=None):
|
|
782
803
|
# a dictionary that will be useful at different places
|
|
783
804
|
self.u_dict_with_none = {k: None for k in self.u_dict.keys()}
|
|
784
805
|
# First, for all the optional dict,
|
|
@@ -818,24 +839,19 @@ class SystemLossPDE(eqx.Module):
|
|
|
818
839
|
# iterating on dynamic_loss_dict. So each time we will require dome
|
|
819
840
|
# derivative_keys_dict
|
|
820
841
|
|
|
821
|
-
#
|
|
822
|
-
#
|
|
823
|
-
#
|
|
824
|
-
|
|
825
|
-
# default values
|
|
826
|
-
for k in self.dynamic_loss_dict.keys():
|
|
842
|
+
# derivative keys for the u_constraints. Note that we create missing
|
|
843
|
+
# DerivativeKeysODE around a Params object and not ParamsDict
|
|
844
|
+
# this works because u_dict.keys == params_dict.nn_params.keys()
|
|
845
|
+
for k in self.u_dict.keys():
|
|
827
846
|
if self.derivative_keys_dict[k] is None:
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
self.derivative_keys_dict[k] = DerivativeKeysPDEStatio()
|
|
837
|
-
else:
|
|
838
|
-
self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio()
|
|
847
|
+
if self.u_dict[k].eq_type == "statio_PDE":
|
|
848
|
+
self.derivative_keys_dict[k] = DerivativeKeysPDEStatio(
|
|
849
|
+
params=params_dict.extract_params(k)
|
|
850
|
+
)
|
|
851
|
+
else:
|
|
852
|
+
self.derivative_keys_dict[k] = DerivativeKeysPDENonStatio(
|
|
853
|
+
params=params_dict.extract_params(k)
|
|
854
|
+
)
|
|
839
855
|
|
|
840
856
|
# Second we make sure that all the dicts (except dynamic_loss_dict) have the same keys
|
|
841
857
|
if (
|
|
@@ -904,16 +920,11 @@ class SystemLossPDE(eqx.Module):
|
|
|
904
920
|
f"got {self.u_dict[i].eq_type[i]}"
|
|
905
921
|
)
|
|
906
922
|
|
|
907
|
-
#
|
|
908
|
-
#
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
}
|
|
913
|
-
self.derivative_keys_u_dict = {
|
|
914
|
-
k: self.derivative_keys_dict[k]
|
|
915
|
-
for k in self.u_dict.keys() & self.derivative_keys_dict.keys()
|
|
916
|
-
}
|
|
923
|
+
# derivative keys for the dynamic loss. Note that we create a
|
|
924
|
+
# DerivativeKeysODE around a ParamsDict object because a whole
|
|
925
|
+
# params_dict is feed to DynamicLoss.evaluate functions (extract_params
|
|
926
|
+
# happen inside it)
|
|
927
|
+
self.derivative_keys_dyn_loss = DerivativeKeysPDENonStatio(params=params_dict)
|
|
917
928
|
|
|
918
929
|
# also make sure we only have PINNs or SPINNs
|
|
919
930
|
if not (
|
|
@@ -1005,19 +1016,7 @@ class SystemLossPDE(eqx.Module):
|
|
|
1005
1016
|
if self.u_dict.keys() != params_dict.nn_params.keys():
|
|
1006
1017
|
raise ValueError("u_dict and params_dict[nn_params] should have same keys ")
|
|
1007
1018
|
|
|
1008
|
-
|
|
1009
|
-
omega_batch, _ = batch.inside_batch, batch.border_batch
|
|
1010
|
-
vmap_in_axes_x_or_x_t = (0,)
|
|
1011
|
-
|
|
1012
|
-
batches = (omega_batch,)
|
|
1013
|
-
elif isinstance(batch, PDENonStatioBatch):
|
|
1014
|
-
times_batch = batch.times_x_inside_batch[:, 0:1]
|
|
1015
|
-
omega_batch = batch.times_x_inside_batch[:, 1:]
|
|
1016
|
-
|
|
1017
|
-
batches = (omega_batch, times_batch)
|
|
1018
|
-
vmap_in_axes_x_or_x_t = (0, 0)
|
|
1019
|
-
else:
|
|
1020
|
-
raise ValueError("Wrong type of batch")
|
|
1019
|
+
vmap_in_axes = (0,)
|
|
1021
1020
|
|
|
1022
1021
|
# Retrieve the optional eq_params_batch
|
|
1023
1022
|
# and update eq_params with the latter
|
|
@@ -1025,7 +1024,6 @@ class SystemLossPDE(eqx.Module):
|
|
|
1025
1024
|
if batch.param_batch_dict is not None:
|
|
1026
1025
|
eq_params_batch_dict = batch.param_batch_dict
|
|
1027
1026
|
|
|
1028
|
-
# TODO
|
|
1029
1027
|
# feed the eq_params with the batch
|
|
1030
1028
|
for k in eq_params_batch_dict.keys():
|
|
1031
1029
|
params_dict.eq_params[k] = eq_params_batch_dict[k]
|
|
@@ -1034,14 +1032,18 @@ class SystemLossPDE(eqx.Module):
|
|
|
1034
1032
|
batch.param_batch_dict, params_dict
|
|
1035
1033
|
)
|
|
1036
1034
|
|
|
1037
|
-
def dyn_loss_for_one_key(dyn_loss,
|
|
1035
|
+
def dyn_loss_for_one_key(dyn_loss, loss_weight):
|
|
1038
1036
|
"""The function used in tree_map"""
|
|
1039
1037
|
return dynamic_loss_apply(
|
|
1040
1038
|
dyn_loss.evaluate,
|
|
1041
1039
|
self.u_dict,
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1040
|
+
(
|
|
1041
|
+
batch.domain_batch
|
|
1042
|
+
if isinstance(batch, PDEStatioBatch)
|
|
1043
|
+
else batch.domain_batch
|
|
1044
|
+
),
|
|
1045
|
+
_set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
|
|
1046
|
+
vmap_in_axes + vmap_in_axes_params,
|
|
1045
1047
|
loss_weight,
|
|
1046
1048
|
u_type=type(list(self.u_dict.values())[0]),
|
|
1047
1049
|
)
|
|
@@ -1049,7 +1051,6 @@ class SystemLossPDE(eqx.Module):
|
|
|
1049
1051
|
dyn_loss_mse_dict = jax.tree_util.tree_map(
|
|
1050
1052
|
dyn_loss_for_one_key,
|
|
1051
1053
|
self.dynamic_loss_dict,
|
|
1052
|
-
self.derivative_keys_dyn_loss_dict,
|
|
1053
1054
|
self._loss_weights["dyn_loss"],
|
|
1054
1055
|
is_leaf=lambda x: isinstance(
|
|
1055
1056
|
x, (PDEStatio, PDENonStatio)
|
jinns/loss/__init__.py
CHANGED
|
@@ -3,7 +3,7 @@ from ._LossODE import LossODE, SystemLossODE
|
|
|
3
3
|
from ._LossPDE import LossPDEStatio, LossPDENonStatio, SystemLossPDE
|
|
4
4
|
from ._DynamicLoss import (
|
|
5
5
|
GeneralizedLotkaVolterra,
|
|
6
|
-
|
|
6
|
+
BurgersEquation,
|
|
7
7
|
FPENonStatioLoss2D,
|
|
8
8
|
OU_FPENonStatioLoss2D,
|
|
9
9
|
FisherKPP,
|
|
@@ -19,9 +19,10 @@ from ._loss_weights import (
|
|
|
19
19
|
)
|
|
20
20
|
|
|
21
21
|
from ._operators import (
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
22
|
+
divergence_fwd,
|
|
23
|
+
divergence_rev,
|
|
24
|
+
laplacian_fwd,
|
|
25
|
+
laplacian_rev,
|
|
26
|
+
vectorial_laplacian_fwd,
|
|
27
|
+
vectorial_laplacian_rev,
|
|
27
28
|
)
|