jinns-0.4.2-py3-none-any.whl → jinns-0.5.0-py3-none-any.whl
- jinns/data/_display.py +78 -21
- jinns/loss/_DynamicLoss.py +405 -907
- jinns/loss/_LossPDE.py +303 -154
- jinns/loss/__init__.py +0 -6
- jinns/loss/_boundary_conditions.py +231 -65
- jinns/loss/_operators.py +201 -45
- jinns/utils/__init__.py +2 -1
- jinns/utils/_pinn.py +308 -0
- jinns/utils/_spinn.py +237 -0
- jinns/utils/_utils.py +32 -306
- {jinns-0.4.2.dist-info → jinns-0.5.0.dist-info}/METADATA +15 -2
- jinns-0.5.0.dist-info/RECORD +24 -0
- jinns-0.4.2.dist-info/RECORD +0 -22
- {jinns-0.4.2.dist-info → jinns-0.5.0.dist-info}/LICENSE +0 -0
- {jinns-0.4.2.dist-info → jinns-0.5.0.dist-info}/WHEEL +0 -0
- {jinns-0.4.2.dist-info → jinns-0.5.0.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py
CHANGED
```diff
@@ -10,7 +10,13 @@ from jinns.loss._boundary_conditions import (
 )
 from jinns.loss._DynamicLoss import ODE, PDEStatio, PDENonStatio
 from jinns.data._DataGenerators import PDEStatioBatch, PDENonStatioBatch
-from jinns.utils._utils import …
+from jinns.utils._utils import (
+    _get_vmap_in_axes_params,
+    _get_grid,
+    _check_user_func_return,
+)
+from jinns.utils._pinn import PINN
+from jinns.utils._spinn import SPINN
 from jinns.loss._operators import _sobolev
 
 _IMPLEMENTED_BOUNDARY_CONDITIONS = [
```
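The import block above carries the headline change of 0.5.0: the network wrappers now live in their own modules (`jinns/utils/_pinn.py` and `jinns/utils/_spinn.py`, both new files per the summary above), and every loss term below dispatches on the wrapper type. A minimal sketch of that dispatch pattern, with hypothetical stand-in classes since the actual `PINN` and `SPINN` definitions are not part of this diff:

```python
# Sketch only: PointwiseNet/SeparableNet are stand-ins for jinns' PINN and
# SPINN wrappers, whose real definitions live in the new _pinn.py/_spinn.py.
import jax.numpy as jnp
from jax import vmap


class PointwiseNet:  # PINN-like: one collocation point in, one output out
    def __call__(self, x):
        return jnp.sin(x).sum(keepdims=True)


class SeparableNet:  # SPINN-like: the whole batch in, an output grid out
    def __call__(self, xs):
        return jnp.prod(jnp.sin(xs), axis=-1, keepdims=True)


def residual_mse(u, xs, weight=1.0):
    if isinstance(u, PointwiseNet):
        res = vmap(u)(xs)  # pointwise nets are vmapped over the batch
    else:
        res = u(xs)  # separable nets consume the batch in one call
    return jnp.mean(weight * jnp.sum(res**2, axis=-1))


print(residual_mse(PointwiseNet(), jnp.ones((16, 2))))
print(residual_mse(SeparableNet(), jnp.ones((16, 2))))
```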
```diff
@@ -260,7 +266,7 @@ class LossPDEStatio(LossPDEAbstract):
     sobolev_m
         An integer. Default is None.
         It corresponds to the Sobolev regularization order as proposed in
-        …
+        *Convergence and error analysis of PINNs*,
         Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
 
 
```
```diff
@@ -417,38 +423,67 @@ class LossPDEStatio(LossPDEAbstract):
 
         # dynamic part
         if self.dynamic_loss is not None:
-            …
+            if isinstance(self.u, PINN):
+                v_dyn_loss = vmap(
+                    lambda x, params: self.dynamic_loss.evaluate(
+                        x,
+                        self.u,
+                        params,
+                    ),
+                    vmap_in_axes_x + vmap_in_axes_params,
+                    0,
+                )
+                mse_dyn_loss = jnp.mean(
+                    self.loss_weights["dyn_loss"]
+                    * jnp.sum(v_dyn_loss(omega_batch, params) ** 2, axis=-1)
+                )
+            elif isinstance(self.u, SPINN):
+                residuals = self.dynamic_loss.evaluate(omega_batch, self.u, params)
+                mse_dyn_loss = jnp.mean(
+                    self.loss_weights["dyn_loss"] * jnp.sum(residuals**2, axis=-1)
+                )
 
         else:
             mse_dyn_loss = 0
 
         # normalization part
         if self.normalization_loss is not None:
-            …
+            if isinstance(self.u, PINN):
+                v_u = vmap(
+                    partial(
+                        self.u,
+                        u_params=params["nn_params"],
+                        eq_params=jax.lax.stop_gradient(params["eq_params"]),
+                    ),
+                    (0),
+                    0,
+                )
+                mse_norm_loss = self.loss_weights["norm_loss"] * jnp.mean(
+                    jnp.abs(
+                        jnp.mean(v_u(self.get_norm_samples()), axis=-1)
+                        * self.int_length
+                        - 1
+                    )
+                    ** 2
+                )
+            elif isinstance(self.u, SPINN):
+                norm_samples = self.get_norm_samples()
+                res = self.u(
+                    norm_samples,
+                    params["nn_params"],
+                    jax.lax.stop_gradient(params["eq_params"]),
+                )
+                mse_norm_loss = self.loss_weights["norm_loss"] * (
+                    jnp.abs(
+                        jnp.mean(
+                            jnp.mean(res, axis=-1),
+                            axis=tuple(range(res.ndim - 1)),
+                        )
+                        * self.int_length
+                        - 1
+                    )
+                    ** 2
+                )
         else:
             mse_norm_loss = 0
             self.loss_weights["norm_loss"] = 0
```
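Both branches of the normalization term above estimate the same constraint by Monte Carlo: u should integrate to one over Ω, so the loss is |mean(u(samples)) · |Ω| − 1|², with `self.int_length` playing the role of the measure of Ω. A self-contained sketch of the idea, under the assumption of a toy `u` and uniform sampling:

```python
# Hedged sketch of the normalization penalty: |MC-integral(u) - 1| ** 2.
import jax
import jax.numpy as jnp
from jax import vmap

u = lambda x: jnp.exp(-jnp.sum(x**2))  # toy stand-in for the network
key = jax.random.PRNGKey(0)
samples = jax.random.uniform(key, (1024, 2), minval=-3.0, maxval=3.0)
int_length = 6.0 * 6.0  # measure of Omega = [-3, 3] x [-3, 3]

mc_integral = jnp.mean(vmap(u)(samples)) * int_length  # approximates integral of u
norm_penalty = jnp.abs(mc_integral - 1) ** 2  # driven towards 0 during training
print(mc_integral, norm_penalty)
```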
```diff
@@ -493,33 +528,46 @@ class LossPDEStatio(LossPDEAbstract):
         # NOTE that it does not use jax.lax.stop_gradient on "eq_params" here
         # since we may wish to optimize on it.
         if self.obs_batch is not None:
-            …
+            # TODO implement for SPINN
+            if isinstance(self.u, PINN):
+                v_u = vmap(
+                    lambda x: self.u(x, params["nn_params"], params["eq_params"]),
+                    0,
+                    0,
+                )
+                mse_observation_loss = jnp.mean(
+                    self.loss_weights["observations"]
+                    * jnp.sum(
+                        (v_u(self.obs_batch[0][:, None]) - self.obs_batch[1]) ** 2,
+                        axis=-1,
+                    )
+                )
+            elif isinstance(self.u, SPINN):
+                raise RuntimeError(
+                    "observation loss term not yet implemented for SPINNs"
                 )
-            )
         else:
             mse_observation_loss = 0
             self.loss_weights["observations"] = 0
 
         # Sobolev regularization
         if self.sobolev_reg is not None:
-            …
+            # TODO implement for SPINN
+            if isinstance(self.u, PINN):
+                v_sob_reg = vmap(
+                    lambda x: self.sobolev_reg(
+                        x,
+                        params["nn_params"],
+                        jax.lax.stop_gradient(params["eq_params"]),
+                    ),
+                    (0, 0),
+                    0,
+                )
+                mse_sobolev_loss = self.loss_weights["sobolev"] * jnp.mean(
+                    v_sob_reg(omega_batch)
+                )
+            elif isinstance(self.u, SPINN):
+                raise RuntimeError("Sobolev loss term not yet implemented for SPINNs")
         else:
             mse_sobolev_loss = 0
             self.loss_weights["sobolev"] = 0
```
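The Sobolev term above (order `sobolev_m`, after Doumeche et al., 2023) penalizes derivative magnitudes of the network; the heavy lifting happens in the private `_sobolev` operator imported at the top of the file, which this diff does not show. A first-order illustration of the idea only, not jinns' actual operator:

```python
# Illustrative m = 1 Sobolev-style penalty: mean squared gradient norm over
# the collocation batch. jinns' _sobolev handles a general order m.
import jax
import jax.numpy as jnp

u = lambda x: jnp.sum(jnp.tanh(x))  # toy scalar-valued network
sob = lambda x: jnp.sum(jax.grad(u)(x) ** 2)  # |grad u(x)|^2

omega_batch = jnp.ones((32, 2))
mse_sobolev_loss = jnp.mean(jax.vmap(sob)(omega_batch))
print(mse_sobolev_loss)
```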
```diff
@@ -660,7 +708,7 @@ class LossPDENonStatio(LossPDEStatio):
     sobolev_m
         An integer. Default is None.
         It corresponds to the Sobolev regularization order as proposed in
-        …
+        *Convergence and error analysis of PINNs*,
         Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
 
 
```
```diff
@@ -769,21 +817,25 @@ class LossPDENonStatio(LossPDEStatio):
 
         vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
 
-        …
+        if isinstance(self.u, PINN):
+            omega_batch_ = jnp.tile(omega_batch, reps=(nt, 1))  # it is tiled
+            times_batch_ = rep_times(n)  # it is repeated
 
         # dynamic part
         if self.dynamic_loss is not None:
-            …
+            if isinstance(self.u, PINN):
+                v_dyn_loss = vmap(
+                    lambda t, x, params: self.dynamic_loss.evaluate(
+                        t, x, self.u, params
+                    ),
+                    vmap_in_axes_x_t + vmap_in_axes_params,
+                    0,
+                )
+                residuals = v_dyn_loss(times_batch_, omega_batch_, params)
+                mse_dyn_loss = jnp.mean(
+                    self.loss_weights["dyn_loss"] * jnp.sum(residuals**2, axis=-1)
+                )
+                # TODO implement Causality is all you need (not yet implemented)
                 # epsilon = 0.01
                 # times_batch_ = jnp.sort(times_batch_)
                 # val_dyn_loss = v_dyn_loss(times_batch_, omega_batch_, params)
```
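The tiled/repeated pair above builds the full Cartesian product of the `nt` time points and `n` space points so the vmapped residual sees every (t, x) combination exactly once. A small shape check, with `jnp.repeat` standing in for the `rep_times` helper the diff calls:

```python
# Tile space, repeat time: row k then pairs time k // n with point k % n.
import jax.numpy as jnp

n, nt = 3, 2
omega_batch = jnp.arange(n * 2.0).reshape(n, 2)  # n space points in 2-D
times_batch = jnp.array([[0.1], [0.2]])  # nt time points

omega_batch_ = jnp.tile(omega_batch, reps=(nt, 1))  # (nt * n, 2), tiled
times_batch_ = jnp.repeat(times_batch, n, axis=0)  # (nt * n, 1), repeated
assert omega_batch_.shape[0] == times_batch_.shape[0] == n * nt
print(jnp.hstack([times_batch_, omega_batch_]))
```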
```diff
@@ -791,33 +843,65 @@ class LossPDENonStatio(LossPDEStatio):
                 # jnp.cumsum(val_dyn_loss)), shift=1,
                 # axis=0)) * val_dyn_loss
                 # mse_dyn_loss = jnp.mean(causality_is_all_you_need ** 2)
+            elif isinstance(self.u, SPINN):
+                residuals = self.dynamic_loss.evaluate(
+                    times_batch, omega_batch, self.u, params
+                )
+                mse_dyn_loss = jnp.mean(
+                    self.loss_weights["dyn_loss"] * jnp.sum(residuals**2, axis=-1)
+                )
+                # TODO implement Causality is all you need (not yet implemented)
+                # epsilon = 0.01
+                # times_batch = jnp.sort(times_batch)
+                # residuals = jax.lax.stop_gradient(jnp.roll(jnp.exp(-epsilon *
+                #     jnp.cumsum(jnp.cumsum(residuals, axis=-2), axis=-1)
+                #     ),
+                #     shift=1, axis=0)) * residuals
         else:
             mse_dyn_loss = 0
 
         # normalization part
         if self.normalization_loss is not None:
-            …
+            if isinstance(self.u, PINN):
+                v_u = vmap(
+                    vmap(
+                        lambda t, x: self.u(
+                            t,
+                            x,
+                            params["nn_params"],
+                            jax.lax.stop_gradient(params["eq_params"]),
+                        ),
+                        in_axes=(None, 0),
                     ),
-                    in_axes=(
-                    )
-            …
+                    in_axes=(0, None),
+                )
+                res = v_u(times_batch, self.get_norm_samples())
+                # the outer mean() below is for the times stamps
+                mse_norm_loss = self.loss_weights["norm_loss"] * jnp.mean(
+                    jnp.abs(jnp.mean(res, axis=(-2, -1)) * self.int_length - 1) ** 2
+                )
+            elif isinstance(self.u, SPINN):
+                norm_samples = self.get_norm_samples()
+                assert norm_samples.shape[0] % times_batch.shape[0] == 0
+                rep_t = norm_samples.shape[0] // times_batch.shape[0]
+                res = self.u(
+                    jnp.repeat(times_batch, rep_t, axis=0),
+                    norm_samples,
+                    params["nn_params"],
+                    jax.lax.stop_gradient(params["eq_params"]),
+                )
+                # the outer mean() below is for the times stamps
+                mse_norm_loss = self.loss_weights["norm_loss"] * jnp.mean(
+                    jnp.abs(
+                        jnp.mean(
+                            jnp.mean(res, axis=-1),
+                            axis=(d + 1 for d in range(res.ndim - 2)),
+                        )
+                        * self.int_length
+                        - 1
+                    )
+                    ** 2
                 )
-                    ** 2
-                )
 
         else:
             mse_norm_loss = 0
```
```diff
@@ -860,23 +944,40 @@ class LossPDENonStatio(LossPDEStatio):
         else:
             mse_boundary_loss = 0
 
-        #
+        # initial condition
         if self.initial_condition_fun is not None:
-            …
+            if isinstance(self.u, PINN):
+                v_u_t0 = vmap(
+                    lambda x: self.initial_condition_fun(x)
+                    - self.u(
+                        t=jnp.zeros((1,)),
+                        x=x,
+                        u_params=params["nn_params"],
+                        eq_params=jax.lax.stop_gradient(params["eq_params"]),
+                    ),
+                    (0),
+                    0,
+                )
+                res = v_u_t0(omega_batch)
+                mse_initial_condition = jnp.mean(
+                    self.loss_weights["initial_condition"] * jnp.sum(res**2, axis=-1)
+                )
+            elif isinstance(self.u, SPINN):
+                values = lambda x: self.u(
+                    jnp.repeat(jnp.zeros((1, 1)), omega_batch.shape[0], axis=0),
+                    x,
+                    params["nn_params"],
+                    jax.lax.stop_gradient(params["eq_params"]),
+                )[0]
+                omega_batch_grid = _get_grid(omega_batch)
+                v_ini = values(omega_batch)
+                ini = _check_user_func_return(
+                    self.initial_condition_fun(omega_batch_grid), v_ini.shape
+                )
+                res = ini - v_ini
+                mse_initial_condition = jnp.mean(
+                    self.loss_weights["initial_condition"] * jnp.sum(res**2, axis=-1)
+                )
         else:
             mse_initial_condition = 0
 
```
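The initial-condition term added above penalizes the gap between the network at t = 0 and the user-supplied `initial_condition_fun` over the space batch. A compact PINN-style sketch with a toy stand-in network (jinns' real wrapper takes `u_params`/`eq_params` as shown in the hunk):

```python
# Hedged sketch: mean over the space batch of sum((u0(x) - u(0, x))^2).
import jax.numpy as jnp
from jax import vmap

u = lambda t, x: jnp.sin(t + x)  # toy stand-in network
u0 = lambda x: jnp.sin(x)  # initial condition, equals u(0, x) here
omega_batch = jnp.linspace(0.0, 1.0, 5)[:, None]  # (5, 1) space points
weight = 1.0

res = vmap(lambda x: u0(x) - u(jnp.zeros((1,)), x))(omega_batch)
mse_initial_condition = jnp.mean(weight * jnp.sum(res**2, axis=-1))
print(mse_initial_condition)  # 0.0, since the toy u already satisfies u0
```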
```diff
@@ -884,41 +985,51 @@ class LossPDENonStatio(LossPDEStatio):
         # NOTE that it does not use jax.lax.stop_gradient on "eq_params" here
         # since we may wish to optimize on it.
         if self.obs_batch is not None:
-            …
+            # TODO implement for SPINN
+            if isinstance(self.u, PINN):
+                v_u = vmap(
+                    lambda t, x: self.u(t, x, params["nn_params"], params["eq_params"]),
+                    (0, 0),
+                    0,
+                )
+                mse_observation_loss = jnp.mean(
+                    self.loss_weights["observations"]
+                    * jnp.mean(
+                        (
+                            v_u(self.obs_batch[0][:, None], self.obs_batch[1])
+                            - self.obs_batch[2]
+                        )
+                        ** 2,
+                        axis=-1,
                     )
-                    ** 2,
-                    axis=0,
                 )
-            )
+            elif isinstance(self.u, SPINN):
+                raise RuntimeError(
+                    "observation loss term not yet implemented for SPINNs"
+                )
         else:
             mse_observation_loss = 0
             self.loss_weights["observations"] = 0
 
         # Sobolev regularization
         if self.sobolev_reg is not None:
-            …
+            # TODO implement for SPINN
+            if isinstance(self.u, PINN):
+                v_sob_reg = vmap(
+                    lambda t, x: self.sobolev_reg(
+                        t,
+                        x,
+                        params["nn_params"],
+                        jax.lax.stop_gradient(params["eq_params"]),
+                    ),
+                    (0, 0),
+                    0,
+                )
+                mse_sobolev_loss = self.loss_weights["sobolev"] * jnp.mean(
+                    v_sob_reg(omega_batch_, times_batch_)
+                )
+            elif isinstance(self.u, SPINN):
+                raise RuntimeError("Sobolev loss term not yet implemented for SPINNs")
         else:
             mse_sobolev_loss = 0
             self.loss_weights["sobolev"] = 0
```
```diff
@@ -1029,7 +1140,7 @@ class SystemLossPDE:
     nn_type_dict
         A dict whose keys are that of u_dict whose value is either
         `nn_statio` or `nn_nonstatio` which signifies either the PINN has a
-        time component in input or not
+        time component in input or not.
     omega_boundary_fun_dict
         A dict of functions to be matched in the border condition, or a
         dict of dict of functions (see doc for `omega_boundary_fun` in
```
```diff
@@ -1069,7 +1180,7 @@ class SystemLossPDE:
         Default is None. A dictionary of integers, one per key which must
         match `u_dict`.
         It corresponds to the Sobolev regularization order as proposed in
-        …
+        *Convergence and error analysis of PINNs*,
         Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
 
 
```
```diff
@@ -1134,6 +1245,8 @@ class SystemLossPDE:
 
         self.dynamic_loss_dict = dynamic_loss_dict
         self.u_dict = u_dict
+        # TODO nn_type should become a class attribute now that we have PINN
+        # class and SPINNs class
         self.nn_type_dict = nn_type_dict
 
         self.loss_weights = loss_weights  # This calls the setter
```
```diff
@@ -1189,6 +1302,15 @@ class SystemLossPDE:
                     f"Wrong value for nn_type_dict[i], got " "{nn_type_dict[i]}"
                 )
 
+        # also make sure we only have PINNs or SPINNs
+        if not (
+            all(type(value) == PINN for value in u_dict.values())
+            or all(type(value) == SPINN for value in u_dict.values())
+        ):
+            raise ValueError(
+                "We only accept dictionary of PINNs or dictionary" " of SPINNs"
+            )
+
     @property
     def loss_weights(self):
         return self._loss_weights
```
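This guard exists because `evaluate` (next hunk) decides between the PINN and SPINN code paths by inspecting only the first value of `u_dict`, so a mixed dictionary would silently take the wrong path. A standalone rendition of the check, with dummy classes in place of jinns' `PINN`/`SPINN`:

```python
# Standalone rendition of the homogeneity check (A/B stand in for PINN/SPINN).
def is_homogeneous(u_dict, allowed_types):
    return any(
        all(type(v) is t for v in u_dict.values()) for t in allowed_types
    )


class A: ...


class B: ...


assert is_homogeneous({"u1": A(), "u2": A()}, (A, B))
assert not is_homogeneous({"u1": A(), "u2": B()}, (A, B))  # jinns raises ValueError
```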
```diff
@@ -1310,39 +1432,66 @@ class SystemLossPDE:
         for i in self.dynamic_loss_dict.keys():
             # dynamic part
             if isinstance(self.dynamic_loss_dict[i], PDEStatio):
-                …
+                # Below we just look at the first element because we suppose we
+                # must only have SPINNs or only PINNs
+                if isinstance(list(self.u_dict.values())[0], PINN):
+                    v_dyn_loss = vmap(
+                        lambda x, params_dict: self.dynamic_loss_dict[i].evaluate(
+                            x,
+                            self.u_dict,
+                            params_dict,
+                        ),
+                        vmap_in_axes_x + vmap_in_axes_params,
+                        0,
+                    )
+                    mse_dyn_loss += jnp.mean(
+                        self._loss_weights["dyn_loss"][i]
+                        * jnp.sum(v_dyn_loss(omega_batch, params_dict) ** 2, axis=-1)
+                    )
+                elif isinstance(list(self.u_dict.values())[0], SPINN):
+                    residuals = self.dynamic_loss_dict[i].evaluate(
+                        omega_batch, self.u_dict, params_dict
+                    )
+                    mse_dyn_loss += jnp.mean(
+                        self._loss_weights["dyn_loss"][i]
+                        * jnp.sum(
+                            residuals**2,
+                            axis=-1,
+                        )
+                    )
             else:
-                …
+                if isinstance(list(self.u_dict.values())[0], PINN):
+                    v_dyn_loss = vmap(
+                        lambda t, x, params_dict: self.dynamic_loss_dict[i].evaluate(
+                            t, x, self.u_dict, params_dict
+                        ),
+                        vmap_in_axes_x_t + vmap_in_axes_params,
+                        0,
+                    )
 
-                …
+                    tile_omega_batch = jnp.tile(omega_batch, reps=(nt, 1))
 
-                …
+                    omega_batch_ = jnp.tile(omega_batch, reps=(nt, 1))  # it is tiled
+                    times_batch_ = rep_times(n)  # it is repeated
 
-                …
+                    mse_dyn_loss += jnp.mean(
+                        self._loss_weights["dyn_loss"][i]
+                        * jnp.sum(
+                            v_dyn_loss(times_batch_, omega_batch_, params_dict) ** 2,
+                            axis=-1,
+                        )
+                    )
+                elif isinstance(list(self.u_dict.values())[0], SPINN):
+                    residuals = self.dynamic_loss_dict[i].evaluate(
+                        times_batch, omega_batch, self.u_dict, params_dict
+                    )
+                    mse_dyn_loss += jnp.mean(
+                        self._loss_weights["dyn_loss"][i]
+                        * jnp.sum(
+                            residuals**2,
+                            axis=-1,
+                        )
                     )
-                )
 
         # boundary conditions, normalization conditions, observation_loss,
         # initial condition... loss this is done via the internal
```
jinns/loss/__init__.py
CHANGED
```diff
@@ -4,12 +4,6 @@ from ._DynamicLoss import (
     Malthus,
     BurgerEquation,
     GeneralizedLotkaVolterra,
-    OU_FPEStatioLoss1D,
-    CIR_FPEStatioLoss1D,
-    OU_FPEStatioLoss2D,
-    OU_FPENonStatioLoss1D,
-    CIR_FPENonStatioLoss1D,
-    Sinus_FPENonStatioLoss1D,
     OU_FPENonStatioLoss2D,
     ConvectionDiffusionNonStatio,
     MassConservation2DStatio,
```
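For downstream code, this removal means six Fokker-Planck losses disappear from the `jinns.loss` namespace in 0.5.0 (and `_DynamicLoss.py` itself shrinks by roughly 900 lines per the summary above, so they do not appear to have simply moved). Imports are worth checking on upgrade; a hedged illustration:

```python
# Worked on 0.4.2; expected to fail on 0.5.0 given the removal above.
# from jinns.loss import OU_FPEStatioLoss1D  # ImportError in 0.5.0

# Still re-exported by jinns.loss in 0.5.0:
from jinns.loss import OU_FPENonStatioLoss2D, BurgerEquation
```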