jinns 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_DataGenerators.py +2 -2
- jinns/loss/_DynamicLoss.py +2 -2
- jinns/loss/_LossODE.py +1 -1
- jinns/loss/_LossPDE.py +75 -38
- jinns/loss/_boundary_conditions.py +2 -2
- jinns/loss/_loss_utils.py +21 -15
- jinns/loss/_operators.py +0 -2
- jinns/nn/__init__.py +7 -0
- jinns/nn/_hyperpinn.py +397 -0
- jinns/nn/_mlp.py +192 -0
- jinns/nn/_pinn.py +190 -0
- jinns/nn/_ppinn.py +203 -0
- jinns/{utils → nn}/_save_load.py +39 -23
- jinns/nn/_spinn.py +106 -0
- jinns/nn/_spinn_mlp.py +196 -0
- jinns/plot/_plot.py +3 -3
- jinns/solver/_rar.py +3 -3
- jinns/solver/_solve.py +23 -9
- jinns/utils/__init__.py +0 -5
- jinns/utils/_types.py +4 -4
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/METADATA +9 -9
- jinns-1.3.0.dist-info/RECORD +44 -0
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/WHEEL +1 -1
- jinns/utils/_hyperpinn.py +0 -420
- jinns/utils/_pinn.py +0 -324
- jinns/utils/_ppinn.py +0 -227
- jinns/utils/_spinn.py +0 -249
- jinns-1.2.0.dist-info/RECORD +0 -41
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/AUTHORS +0 -0
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/LICENSE +0 -0
- {jinns-1.2.0.dist-info → jinns-1.3.0.dist-info}/top_level.txt +0 -0
jinns/data/_DataGenerators.py
CHANGED
|
@@ -54,7 +54,7 @@ def make_cartesian_product(
|
|
|
54
54
|
|
|
55
55
|
|
|
56
56
|
def _reset_batch_idx_and_permute(
|
|
57
|
-
operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]]
|
|
57
|
+
operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
|
|
58
58
|
) -> tuple[Key, Float[Array, "n dimension"], Int]:
|
|
59
59
|
key, domain, curr_idx, _, p = operands
|
|
60
60
|
# resetting counter
|
|
@@ -78,7 +78,7 @@ def _reset_batch_idx_and_permute(
|
|
|
78
78
|
|
|
79
79
|
|
|
80
80
|
def _increment_batch_idx(
|
|
81
|
-
operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]]
|
|
81
|
+
operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
|
|
82
82
|
) -> tuple[Key, Float[Array, "n dimension"], Int]:
|
|
83
83
|
key, domain, curr_idx, batch_size, _ = operands
|
|
84
84
|
# simply increases counter and get the batch
|
jinns/loss/_DynamicLoss.py
CHANGED
|
@@ -13,8 +13,8 @@ from jax import grad
|
|
|
13
13
|
import jax.numpy as jnp
|
|
14
14
|
import equinox as eqx
|
|
15
15
|
|
|
16
|
-
from jinns.
|
|
17
|
-
from jinns.
|
|
16
|
+
from jinns.nn._pinn import PINN
|
|
17
|
+
from jinns.nn._spinn_mlp import SPINN
|
|
18
18
|
|
|
19
19
|
from jinns.utils._utils import get_grid
|
|
20
20
|
from jinns.loss._DynamicLossAbstract import ODE, PDEStatio, PDENonStatio
|
jinns/loss/_LossODE.py
CHANGED
|
@@ -28,7 +28,7 @@ from jinns.parameters._params import (
|
|
|
28
28
|
from jinns.parameters._derivative_keys import _set_derivatives, DerivativeKeysODE
|
|
29
29
|
from jinns.loss._loss_weights import LossWeightsODE, LossWeightsODEDict
|
|
30
30
|
from jinns.loss._DynamicLossAbstract import ODE
|
|
31
|
-
from jinns.
|
|
31
|
+
from jinns.nn._pinn import PINN
|
|
32
32
|
|
|
33
33
|
if TYPE_CHECKING:
|
|
34
34
|
from jinns.utils._types import *
|
jinns/loss/_LossPDE.py
CHANGED
|
@@ -38,8 +38,8 @@ from jinns.loss._loss_weights import (
|
|
|
38
38
|
LossWeightsPDEDict,
|
|
39
39
|
)
|
|
40
40
|
from jinns.loss._DynamicLossAbstract import PDEStatio, PDENonStatio
|
|
41
|
-
from jinns.
|
|
42
|
-
from jinns.
|
|
41
|
+
from jinns.nn._pinn import PINN
|
|
42
|
+
from jinns.nn._spinn import SPINN
|
|
43
43
|
from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
|
|
44
44
|
|
|
45
45
|
|
|
@@ -92,12 +92,16 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
92
92
|
Note that it must be a slice and not an integer
|
|
93
93
|
(but a preprocessing of the user provided argument takes care of it)
|
|
94
94
|
norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
|
|
95
|
-
|
|
95
|
+
Monte-Carlo sample points for computing the
|
|
96
96
|
normalization constant. Default is None.
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
97
|
+
norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
|
|
98
|
+
The importance sampling weights for Monte-Carlo integration of the
|
|
99
|
+
normalization constant. Must be provided if `norm_samples` is provided.
|
|
100
|
+
`norm_weights` should have the same leading dimension as
|
|
101
|
+
`norm_samples`.
|
|
102
|
+
Alternatively, the user can pass a float or an integer.
|
|
103
|
+
These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
|
|
104
|
+
$q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
|
|
101
105
|
obs_slice : slice, default=None
|
|
102
106
|
slice object specifying the begininning/ending of the PINN output
|
|
103
107
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
@@ -127,7 +131,9 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
127
131
|
norm_samples: Float[Array, "nb_norm_samples dimension"] | None = eqx.field(
|
|
128
132
|
kw_only=True, default=None
|
|
129
133
|
)
|
|
130
|
-
|
|
134
|
+
norm_weights: Float[Array, "nb_norm_samples"] | float | int | None = eqx.field(
|
|
135
|
+
kw_only=True, default=None
|
|
136
|
+
)
|
|
131
137
|
obs_slice: slice | None = eqx.field(kw_only=True, default=None, static=True)
|
|
132
138
|
|
|
133
139
|
params: InitVar[Params] = eqx.field(kw_only=True, default=None)
|
|
@@ -251,8 +257,25 @@ class _LossPDEAbstract(eqx.Module):
|
|
|
251
257
|
if not isinstance(self.omega_boundary_dim, slice):
|
|
252
258
|
raise ValueError("self.omega_boundary_dim must be a jnp.s_ object")
|
|
253
259
|
|
|
254
|
-
if self.norm_samples is not None
|
|
255
|
-
|
|
260
|
+
if self.norm_samples is not None:
|
|
261
|
+
if self.norm_weights is None:
|
|
262
|
+
raise ValueError(
|
|
263
|
+
"`norm_weights` must be provided when `norm_samples` is used!"
|
|
264
|
+
)
|
|
265
|
+
try:
|
|
266
|
+
assert self.norm_weights.shape[0] == self.norm_samples.shape[0]
|
|
267
|
+
except (AssertionError, AttributeError):
|
|
268
|
+
if isinstance(self.norm_weights, (int, float)):
|
|
269
|
+
self.norm_weights = jnp.array(
|
|
270
|
+
[self.norm_weights], dtype=jax.dtypes.canonicalize_dtype(float)
|
|
271
|
+
)
|
|
272
|
+
else:
|
|
273
|
+
raise ValueError(
|
|
274
|
+
"`norm_weights` should have the same leading dimension"
|
|
275
|
+
" as `norm_samples`,"
|
|
276
|
+
f" got shape {self.norm_weights.shape} and"
|
|
277
|
+
f" shape {self.norm_samples.shape}."
|
|
278
|
+
)
|
|
256
279
|
|
|
257
280
|
@abc.abstractmethod
|
|
258
281
|
def evaluate(
|
|
@@ -324,12 +347,16 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
324
347
|
Note that it must be a slice and not an integer
|
|
325
348
|
(but a preprocessing of the user provided argument takes care of it)
|
|
326
349
|
norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
|
|
327
|
-
|
|
350
|
+
Monte-Carlo sample points for computing the
|
|
328
351
|
normalization constant. Default is None.
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
352
|
+
norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
|
|
353
|
+
The importance sampling weights for Monte-Carlo integration of the
|
|
354
|
+
normalization constant. Must be provided if `norm_samples` is provided.
|
|
355
|
+
`norm_weights` should have the same leading dimension as
|
|
356
|
+
`norm_samples`.
|
|
357
|
+
Alternatively, the user can pass a float or an integer.
|
|
358
|
+
These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
|
|
359
|
+
$q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
|
|
333
360
|
obs_slice : slice, default=None
|
|
334
361
|
slice object specifying the begininning/ending of the PINN output
|
|
335
362
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
@@ -431,7 +458,7 @@ class LossPDEStatio(_LossPDEAbstract):
|
|
|
431
458
|
self._get_normalization_loss_batch(batch),
|
|
432
459
|
_set_derivatives(params, self.derivative_keys.norm_loss),
|
|
433
460
|
vmap_in_axes_params,
|
|
434
|
-
self.
|
|
461
|
+
self.norm_weights,
|
|
435
462
|
self.loss_weights.norm_loss,
|
|
436
463
|
)
|
|
437
464
|
else:
|
|
@@ -549,12 +576,16 @@ class LossPDENonStatio(LossPDEStatio):
|
|
|
549
576
|
Note that it must be a slice and not an integer
|
|
550
577
|
(but a preprocessing of the user provided argument takes care of it)
|
|
551
578
|
norm_samples : Float[Array, "nb_norm_samples dimension"], default=None
|
|
552
|
-
|
|
579
|
+
Monte-Carlo sample points for computing the
|
|
553
580
|
normalization constant. Default is None.
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
581
|
+
norm_weights : Float[Array, "nb_norm_samples"] | float | int, default=None
|
|
582
|
+
The importance sampling weights for Monte-Carlo integration of the
|
|
583
|
+
normalization constant. Must be provided if `norm_samples` is provided.
|
|
584
|
+
`norm_weights` should have the same leading dimension as
|
|
585
|
+
`norm_samples`.
|
|
586
|
+
Alternatively, the user can pass a float or an integer.
|
|
587
|
+
These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
|
|
588
|
+
$q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
|
|
558
589
|
obs_slice : slice, default=None
|
|
559
590
|
slice object specifying the begininning/ending of the PINN output
|
|
560
591
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
@@ -730,15 +761,21 @@ class SystemLossPDE(eqx.Module):
|
|
|
730
761
|
(default) then no temporal boundary condition is applied
|
|
731
762
|
Must share the keys of `u_dict`
|
|
732
763
|
norm_samples_dict : Dict[str, Float[Array, "nb_norm_samples dimension"] | None, default=None
|
|
733
|
-
A dict of
|
|
734
|
-
normalization constant. Default is None
|
|
735
|
-
Must share the keys of `u_dict`
|
|
736
|
-
norm_int_length_dict : Dict[str, float | None] | None, default=None
|
|
737
|
-
A dict of Float. The domain area
|
|
738
|
-
(or interval length in 1D) upon which we perform the numerical
|
|
739
|
-
integration for each element of u_dict.
|
|
740
|
-
Default is None
|
|
764
|
+
A dict of Monte-Carlo sample points for computing the
|
|
765
|
+
normalization constant. Default is None.
|
|
741
766
|
Must share the keys of `u_dict`
|
|
767
|
+
norm_weights_dict : Dict[str, Array[Float, "nb_norm_samples"] | float | int | None] | None, default=None
|
|
768
|
+
A dict of jnp.array with the same keys as `u_dict`. The importance
|
|
769
|
+
sampling weights for Monte-Carlo integration of the
|
|
770
|
+
normalization constant for each element of u_dict. Must be provided if
|
|
771
|
+
`norm_samples_dict` is provided.
|
|
772
|
+
`norm_weights_dict[key]` should have the same leading dimension as
|
|
773
|
+
`norm_samples_dict[key]` for each `key`.
|
|
774
|
+
Alternatively, the user can pass a float or an integer.
|
|
775
|
+
For each key, an array of similar shape to `norm_samples_dict[key]`
|
|
776
|
+
or shape `(1,)` is expected. These corresponds to the weights $w_k =
|
|
777
|
+
\frac{1}{q(x_k)}$ where $q(\cdot)$ is the proposal p.d.f. and $x_k$ are
|
|
778
|
+
the Monte-Carlo samples. Default is None
|
|
742
779
|
obs_slice_dict : Dict[str, slice | None] | None, default=None
|
|
743
780
|
dict of obs_slice, with keys from `u_dict` to designate the
|
|
744
781
|
output(s) channels that are forced to observed values, for each
|
|
@@ -774,9 +811,9 @@ class SystemLossPDE(eqx.Module):
|
|
|
774
811
|
norm_samples_dict: Dict[str, Float[Array, "nb_norm_samples dimension"]] | None = (
|
|
775
812
|
eqx.field(kw_only=True, default=None)
|
|
776
813
|
)
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
)
|
|
814
|
+
norm_weights_dict: (
|
|
815
|
+
Dict[str, Float[Array, "nb_norm_samples dimension"] | float | int | None] | None
|
|
816
|
+
) = eqx.field(kw_only=True, default=None)
|
|
780
817
|
obs_slice_dict: Dict[str, slice | None] | None = eqx.field(
|
|
781
818
|
kw_only=True, default=None, static=True
|
|
782
819
|
)
|
|
@@ -819,8 +856,8 @@ class SystemLossPDE(eqx.Module):
|
|
|
819
856
|
self.initial_condition_fun_dict = self.u_dict_with_none
|
|
820
857
|
if self.norm_samples_dict is None:
|
|
821
858
|
self.norm_samples_dict = self.u_dict_with_none
|
|
822
|
-
if self.
|
|
823
|
-
self.
|
|
859
|
+
if self.norm_weights_dict is None:
|
|
860
|
+
self.norm_weights_dict = self.u_dict_with_none
|
|
824
861
|
if self.obs_slice_dict is None:
|
|
825
862
|
self.obs_slice_dict = {k: jnp.s_[...] for k in self.u_dict.keys()}
|
|
826
863
|
if self.u_dict.keys() != self.obs_slice_dict.keys():
|
|
@@ -861,7 +898,7 @@ class SystemLossPDE(eqx.Module):
|
|
|
861
898
|
or self.u_dict.keys() != self.omega_boundary_dim_dict.keys()
|
|
862
899
|
or self.u_dict.keys() != self.initial_condition_fun_dict.keys()
|
|
863
900
|
or self.u_dict.keys() != self.norm_samples_dict.keys()
|
|
864
|
-
or self.u_dict.keys() != self.
|
|
901
|
+
or self.u_dict.keys() != self.norm_weights_dict.keys()
|
|
865
902
|
):
|
|
866
903
|
raise ValueError("All the dicts concerning the PINNs should have same keys")
|
|
867
904
|
|
|
@@ -890,7 +927,7 @@ class SystemLossPDE(eqx.Module):
|
|
|
890
927
|
omega_boundary_condition=self.omega_boundary_condition_dict[i],
|
|
891
928
|
omega_boundary_dim=self.omega_boundary_dim_dict[i],
|
|
892
929
|
norm_samples=self.norm_samples_dict[i],
|
|
893
|
-
|
|
930
|
+
norm_weights=self.norm_weights_dict[i],
|
|
894
931
|
obs_slice=self.obs_slice_dict[i],
|
|
895
932
|
)
|
|
896
933
|
elif self.u_dict[i].eq_type == "nonstatio_PDE":
|
|
@@ -911,7 +948,7 @@ class SystemLossPDE(eqx.Module):
|
|
|
911
948
|
omega_boundary_dim=self.omega_boundary_dim_dict[i],
|
|
912
949
|
initial_condition_fun=self.initial_condition_fun_dict[i],
|
|
913
950
|
norm_samples=self.norm_samples_dict[i],
|
|
914
|
-
|
|
951
|
+
norm_weights=self.norm_weights_dict[i],
|
|
915
952
|
obs_slice=self.obs_slice_dict[i],
|
|
916
953
|
)
|
|
917
954
|
else:
|
|
@@ -1045,7 +1082,7 @@ class SystemLossPDE(eqx.Module):
|
|
|
1045
1082
|
_set_derivatives(params_dict, self.derivative_keys_dyn_loss.dyn_loss),
|
|
1046
1083
|
vmap_in_axes + vmap_in_axes_params,
|
|
1047
1084
|
loss_weight,
|
|
1048
|
-
u_type=
|
|
1085
|
+
u_type=list(self.u_dict.values())[0].__class__.__base__,
|
|
1049
1086
|
)
|
|
1050
1087
|
|
|
1051
1088
|
dyn_loss_mse_dict = jax.tree_util.tree_map(
|
|
@@ -13,8 +13,8 @@ from jax import vmap, grad
|
|
|
13
13
|
import equinox as eqx
|
|
14
14
|
from jinns.utils._utils import get_grid, _subtract_with_check
|
|
15
15
|
from jinns.data._Batchs import *
|
|
16
|
-
from jinns.
|
|
17
|
-
from jinns.
|
|
16
|
+
from jinns.nn._pinn import PINN
|
|
17
|
+
from jinns.nn._spinn import SPINN
|
|
18
18
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
20
|
from jinns.utils._types import *
|
jinns/loss/_loss_utils.py
CHANGED
|
@@ -19,9 +19,9 @@ from jinns.loss._boundary_conditions import (
|
|
|
19
19
|
from jinns.utils._utils import _subtract_with_check, get_grid
|
|
20
20
|
from jinns.data._DataGenerators import append_obs_batch, make_cartesian_product
|
|
21
21
|
from jinns.parameters._params import _get_vmap_in_axes_params
|
|
22
|
-
from jinns.
|
|
23
|
-
from jinns.
|
|
24
|
-
from jinns.
|
|
22
|
+
from jinns.nn._pinn import PINN
|
|
23
|
+
from jinns.nn._spinn import SPINN
|
|
24
|
+
from jinns.nn._hyperpinn import HyperPINN
|
|
25
25
|
from jinns.data._Batchs import *
|
|
26
26
|
from jinns.parameters._params import Params, ParamsDict
|
|
27
27
|
|
|
@@ -40,13 +40,13 @@ def dynamic_loss_apply(
|
|
|
40
40
|
params: Params | ParamsDict,
|
|
41
41
|
vmap_axes: tuple[int | None, ...],
|
|
42
42
|
loss_weight: float | Float[Array, "dyn_loss_dimension"],
|
|
43
|
-
u_type: PINN |
|
|
43
|
+
u_type: PINN | HyperPINN | None = None,
|
|
44
44
|
) -> float:
|
|
45
45
|
"""
|
|
46
46
|
Sometimes when u is a lambda function a or dict we do not have access to
|
|
47
47
|
its type here, hence the last argument
|
|
48
48
|
"""
|
|
49
|
-
if u_type == PINN or u_type ==
|
|
49
|
+
if u_type == PINN or u_type == HyperPINN or isinstance(u, (PINN, HyperPINN)):
|
|
50
50
|
v_dyn_loss = vmap(
|
|
51
51
|
lambda batch, params: dyn_loss(
|
|
52
52
|
batch, u, params # we must place the params at the end
|
|
@@ -75,23 +75,25 @@ def normalization_loss_apply(
|
|
|
75
75
|
),
|
|
76
76
|
params: Params | ParamsDict,
|
|
77
77
|
vmap_axes_params: tuple[int | None, ...],
|
|
78
|
-
|
|
78
|
+
norm_weights: Float[Array, "nb_norm_samples"],
|
|
79
79
|
loss_weight: float,
|
|
80
80
|
) -> float:
|
|
81
81
|
"""
|
|
82
82
|
Note the squeezing on each result. We expect unidimensional *PINN since
|
|
83
83
|
they represent probability distributions
|
|
84
84
|
"""
|
|
85
|
-
if isinstance(u, (PINN,
|
|
85
|
+
if isinstance(u, (PINN, HyperPINN)):
|
|
86
86
|
if len(batches) == 1:
|
|
87
87
|
v_u = vmap(
|
|
88
|
-
lambda b: u(b)[u.slice_solution],
|
|
88
|
+
lambda *b: u(*b)[u.slice_solution],
|
|
89
89
|
(0,) + vmap_axes_params,
|
|
90
90
|
0,
|
|
91
91
|
)
|
|
92
92
|
res = v_u(*batches, params)
|
|
93
|
+
assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
|
|
94
|
+
# Monte-Carlo integration using importance sampling
|
|
93
95
|
mse_norm_loss = loss_weight * (
|
|
94
|
-
jnp.abs(jnp.mean(res.squeeze()
|
|
96
|
+
jnp.abs(jnp.mean(res.squeeze() * norm_weights) - 1) ** 2
|
|
95
97
|
)
|
|
96
98
|
else:
|
|
97
99
|
# NOTE this cartesian product is costly
|
|
@@ -107,20 +109,23 @@ def normalization_loss_apply(
|
|
|
107
109
|
in_axes=(0,) + vmap_axes_params,
|
|
108
110
|
)
|
|
109
111
|
res = v_u(batches, params)
|
|
110
|
-
|
|
112
|
+
assert res.shape[-1] == 1, "norm loss expects unidimensional *PINN"
|
|
113
|
+
# For all times t, we perform an integration. Then we average the
|
|
114
|
+
# losses over times.
|
|
111
115
|
mse_norm_loss = loss_weight * jnp.mean(
|
|
112
|
-
jnp.abs(jnp.mean(res.squeeze(), axis=-1)
|
|
116
|
+
jnp.abs(jnp.mean(res.squeeze() * norm_weights, axis=-1) - 1) ** 2
|
|
113
117
|
)
|
|
114
118
|
elif isinstance(u, SPINN):
|
|
115
119
|
if len(batches) == 1:
|
|
116
120
|
res = u(*batches, params)
|
|
121
|
+
assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
|
|
117
122
|
mse_norm_loss = (
|
|
118
123
|
loss_weight
|
|
119
124
|
* jnp.abs(
|
|
120
125
|
jnp.mean(
|
|
121
126
|
res.squeeze(),
|
|
122
127
|
)
|
|
123
|
-
*
|
|
128
|
+
* norm_weights
|
|
124
129
|
- 1
|
|
125
130
|
)
|
|
126
131
|
** 2
|
|
@@ -134,6 +139,7 @@ def normalization_loss_apply(
|
|
|
134
139
|
),
|
|
135
140
|
params,
|
|
136
141
|
)
|
|
142
|
+
assert res.shape[-1] == 1, "norm loss expects unidimensional *SPINN"
|
|
137
143
|
# the outer mean() below is for the times stamps
|
|
138
144
|
mse_norm_loss = loss_weight * jnp.mean(
|
|
139
145
|
jnp.abs(
|
|
@@ -141,7 +147,7 @@ def normalization_loss_apply(
|
|
|
141
147
|
res.squeeze(),
|
|
142
148
|
axis=(d + 1 for d in range(res.ndim - 2)),
|
|
143
149
|
)
|
|
144
|
-
*
|
|
150
|
+
* norm_weights
|
|
145
151
|
- 1
|
|
146
152
|
)
|
|
147
153
|
** 2
|
|
@@ -230,7 +236,7 @@ def observations_loss_apply(
|
|
|
230
236
|
obs_slice: slice,
|
|
231
237
|
) -> float:
|
|
232
238
|
# TODO implement for SPINN
|
|
233
|
-
if isinstance(u, (PINN,
|
|
239
|
+
if isinstance(u, (PINN, HyperPINN)):
|
|
234
240
|
v_u = vmap(
|
|
235
241
|
lambda *args: u(*args)[u.slice_solution],
|
|
236
242
|
vmap_axes,
|
|
@@ -264,7 +270,7 @@ def initial_condition_apply(
|
|
|
264
270
|
) -> float:
|
|
265
271
|
n = omega_batch.shape[0]
|
|
266
272
|
t0_omega_batch = jnp.concatenate([jnp.zeros((n, 1)), omega_batch], axis=1)
|
|
267
|
-
if isinstance(u, (PINN,
|
|
273
|
+
if isinstance(u, (PINN, HyperPINN)):
|
|
268
274
|
v_u_t0 = vmap(
|
|
269
275
|
lambda t0_x, params: _subtract_with_check(
|
|
270
276
|
initial_condition_fun(t0_x[1:]),
|
jinns/loss/_operators.py
CHANGED
jinns/nn/__init__.py
ADDED