jinns 1.6.1__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -1
- jinns/data/_Batchs.py +4 -4
- jinns/data/_DataGeneratorODE.py +1 -1
- jinns/data/_DataGeneratorObservations.py +498 -90
- jinns/loss/_DynamicLossAbstract.py +3 -1
- jinns/loss/_LossODE.py +138 -73
- jinns/loss/_LossPDE.py +208 -104
- jinns/loss/_abstract_loss.py +97 -14
- jinns/loss/_boundary_conditions.py +6 -6
- jinns/loss/_loss_utils.py +2 -2
- jinns/loss/_loss_weight_updates.py +30 -0
- jinns/loss/_loss_weights.py +4 -0
- jinns/loss/_operators.py +27 -27
- jinns/nn/_abstract_pinn.py +1 -1
- jinns/nn/_hyperpinn.py +6 -6
- jinns/nn/_mlp.py +3 -3
- jinns/nn/_pinn.py +7 -7
- jinns/nn/_ppinn.py +6 -6
- jinns/nn/_spinn.py +4 -4
- jinns/nn/_spinn_mlp.py +7 -7
- jinns/parameters/_derivative_keys.py +13 -6
- jinns/parameters/_params.py +10 -0
- jinns/solver/_rar.py +19 -9
- jinns/solver/_solve.py +102 -367
- jinns/solver/_solve_alternate.py +885 -0
- jinns/solver/_utils.py +520 -11
- jinns/utils/_DictToModuleMeta.py +3 -1
- jinns/utils/_containers.py +8 -4
- jinns/utils/_types.py +42 -1
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/METADATA +26 -14
- jinns-1.7.1.dist-info/RECORD +58 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/WHEEL +1 -1
- jinns-1.6.1.dist-info/RECORD +0 -57
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/top_level.txt +0 -0
jinns/__init__.py
CHANGED
|
@@ -7,8 +7,9 @@ from jinns import parameters as parameters
|
|
|
7
7
|
from jinns import plot as plot
|
|
8
8
|
from jinns import nn as nn
|
|
9
9
|
from jinns.solver._solve import solve
|
|
10
|
+
from jinns.solver._solve_alternate import solve_alternate
|
|
10
11
|
|
|
11
|
-
__all__ = ["nn", "solve"]
|
|
12
|
+
__all__ = ["nn", "solve", "solve_alternate"]
|
|
12
13
|
|
|
13
14
|
import warnings
|
|
14
15
|
|
jinns/data/_Batchs.py
CHANGED
|
@@ -26,14 +26,14 @@ class ObsBatchDict(TypedDict):
|
|
|
26
26
|
class ODEBatch(eqx.Module):
|
|
27
27
|
temporal_batch: Float[Array, " batch_size"]
|
|
28
28
|
param_batch_dict: eqx.Module | None = eqx.field(default=None)
|
|
29
|
-
obs_batch_dict: ObsBatchDict | None = eqx.field(default=None)
|
|
29
|
+
obs_batch_dict: tuple[ObsBatchDict, ...] | None = eqx.field(default=None)
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
class PDEStatioBatch(eqx.Module):
|
|
33
33
|
domain_batch: Float[Array, " batch_size dimension"]
|
|
34
34
|
border_batch: Float[Array, " batch_size dimension n_facets"] | None
|
|
35
35
|
param_batch_dict: eqx.Module | None
|
|
36
|
-
obs_batch_dict: ObsBatchDict | None
|
|
36
|
+
obs_batch_dict: tuple[ObsBatchDict, ...] | None
|
|
37
37
|
|
|
38
38
|
# rewrite __init__ to be able to use inheritance for the NonStatio case
|
|
39
39
|
# below. That way PDENonStatioBatch is a subtype of PDEStatioBatch which
|
|
@@ -44,7 +44,7 @@ class PDEStatioBatch(eqx.Module):
|
|
|
44
44
|
domain_batch: Float[Array, " batch_size dimension"],
|
|
45
45
|
border_batch: Float[Array, " batch_size dimension n_facets"] | None,
|
|
46
46
|
param_batch_dict: eqx.Module | None = None,
|
|
47
|
-
obs_batch_dict: ObsBatchDict | None = None,
|
|
47
|
+
obs_batch_dict: tuple[ObsBatchDict, ...] | None = None,
|
|
48
48
|
):
|
|
49
49
|
# TODO: document this ?
|
|
50
50
|
self.domain_batch = domain_batch
|
|
@@ -67,7 +67,7 @@ class PDENonStatioBatch(PDEStatioBatch):
|
|
|
67
67
|
border_batch: Float[Array, " batch_size dimension n_facets"] | None,
|
|
68
68
|
initial_batch: Float[Array, " batch_size dimension"] | None,
|
|
69
69
|
param_batch_dict: eqx.Module | None = None,
|
|
70
|
-
obs_batch_dict: ObsBatchDict | None = None,
|
|
70
|
+
obs_batch_dict: tuple[ObsBatchDict, ...] | None = None,
|
|
71
71
|
):
|
|
72
72
|
self.domain_batch = domain_batch
|
|
73
73
|
self.border_batch = border_batch
|
jinns/data/_DataGeneratorODE.py
CHANGED
|
@@ -77,7 +77,7 @@ class DataGeneratorODE(AbstractDataGenerator):
|
|
|
77
77
|
nt: int,
|
|
78
78
|
tmin: float,
|
|
79
79
|
tmax: float,
|
|
80
|
-
temporal_batch_size: int | None,
|
|
80
|
+
temporal_batch_size: int | None = None,
|
|
81
81
|
method: str = "uniform",
|
|
82
82
|
rar_parameters: None | dict[str, int] = None,
|
|
83
83
|
n_start: None | int = None,
|