jinns 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jinns/loss/_LossPDE.py CHANGED
@@ -10,7 +10,13 @@ from jinns.loss._boundary_conditions import (
10
10
  )
11
11
  from jinns.loss._DynamicLoss import ODE, PDEStatio, PDENonStatio
12
12
  from jinns.data._DataGenerators import PDEStatioBatch, PDENonStatioBatch
13
- from jinns.utils._utils import _get_vmap_in_axes_params
13
+ from jinns.utils._utils import (
14
+ _get_vmap_in_axes_params,
15
+ _get_grid,
16
+ _check_user_func_return,
17
+ )
18
+ from jinns.utils._pinn import PINN
19
+ from jinns.utils._spinn import SPINN
14
20
  from jinns.loss._operators import _sobolev
15
21
 
16
22
  _IMPLEMENTED_BOUNDARY_CONDITIONS = [
@@ -260,7 +266,7 @@ class LossPDEStatio(LossPDEAbstract):
260
266
  sobolev_m
261
267
  An integer. Default is None.
262
268
  It corresponds to the Sobolev regularization order as proposed in
263
- _Convergence and error analysis of PINNs_,
269
+ *Convergence and error analysis of PINNs*,
264
270
  Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
265
271
 
266
272
 
@@ -417,38 +423,67 @@ class LossPDEStatio(LossPDEAbstract):
417
423
 
418
424
  # dynamic part
419
425
  if self.dynamic_loss is not None:
420
- v_dyn_loss = vmap(
421
- lambda x, params: self.dynamic_loss.evaluate(
422
- x,
423
- self.u,
424
- params,
425
- ),
426
- vmap_in_axes_x + vmap_in_axes_params,
427
- 0,
428
- )
429
- mse_dyn_loss = jnp.mean(
430
- self.loss_weights["dyn_loss"]
431
- * jnp.mean(v_dyn_loss(omega_batch, params) ** 2, axis=0)
432
- )
426
+ if isinstance(self.u, PINN):
427
+ v_dyn_loss = vmap(
428
+ lambda x, params: self.dynamic_loss.evaluate(
429
+ x,
430
+ self.u,
431
+ params,
432
+ ),
433
+ vmap_in_axes_x + vmap_in_axes_params,
434
+ 0,
435
+ )
436
+ mse_dyn_loss = jnp.mean(
437
+ self.loss_weights["dyn_loss"]
438
+ * jnp.sum(v_dyn_loss(omega_batch, params) ** 2, axis=-1)
439
+ )
440
+ elif isinstance(self.u, SPINN):
441
+ residuals = self.dynamic_loss.evaluate(omega_batch, self.u, params)
442
+ mse_dyn_loss = jnp.mean(
443
+ self.loss_weights["dyn_loss"] * jnp.sum(residuals**2, axis=-1)
444
+ )
433
445
 
434
446
  else:
435
447
  mse_dyn_loss = 0
436
448
 
437
449
  # normalization part
438
450
  if self.normalization_loss is not None:
439
- v_u = vmap(
440
- partial(
441
- self.u,
442
- u_params=params["nn_params"],
443
- eq_params=jax.lax.stop_gradient(params["eq_params"]),
444
- ),
445
- (0),
446
- 0,
447
- )
448
- mse_norm_loss = self.loss_weights["norm_loss"] * (
449
- jnp.abs(jnp.mean(v_u(self.get_norm_samples())) * self.int_length - 1)
450
- ** 2
451
- )
451
+ if isinstance(self.u, PINN):
452
+ v_u = vmap(
453
+ partial(
454
+ self.u,
455
+ u_params=params["nn_params"],
456
+ eq_params=jax.lax.stop_gradient(params["eq_params"]),
457
+ ),
458
+ (0),
459
+ 0,
460
+ )
461
+ mse_norm_loss = self.loss_weights["norm_loss"] * jnp.mean(
462
+ jnp.abs(
463
+ jnp.mean(v_u(self.get_norm_samples()), axis=-1)
464
+ * self.int_length
465
+ - 1
466
+ )
467
+ ** 2
468
+ )
469
+ elif isinstance(self.u, SPINN):
470
+ norm_samples = self.get_norm_samples()
471
+ res = self.u(
472
+ norm_samples,
473
+ params["nn_params"],
474
+ jax.lax.stop_gradient(params["eq_params"]),
475
+ )
476
+ mse_norm_loss = self.loss_weights["norm_loss"] * (
477
+ jnp.abs(
478
+ jnp.mean(
479
+ jnp.mean(res, axis=-1),
480
+ axis=tuple(range(res.ndim - 1)),
481
+ )
482
+ * self.int_length
483
+ - 1
484
+ )
485
+ ** 2
486
+ )
452
487
  else:
453
488
  mse_norm_loss = 0
454
489
  self.loss_weights["norm_loss"] = 0
@@ -493,33 +528,46 @@ class LossPDEStatio(LossPDEAbstract):
493
528
  # NOTE that it does not use jax.lax.stop_gradient on "eq_params" here
494
529
  # since we may wish to optimize on it.
495
530
  if self.obs_batch is not None:
496
- v_u = vmap(
497
- lambda x: self.u(x, params["nn_params"], params["eq_params"]),
498
- 0,
499
- 0,
500
- )
501
- mse_observation_loss = jnp.mean(
502
- self.loss_weights["observations"]
503
- * jnp.mean(
504
- (v_u(self.obs_batch[0][:, None]) - self.obs_batch[1]) ** 2, axis=0
531
+ # TODO implement for SPINN
532
+ if isinstance(self.u, PINN):
533
+ v_u = vmap(
534
+ lambda x: self.u(x, params["nn_params"], params["eq_params"]),
535
+ 0,
536
+ 0,
537
+ )
538
+ mse_observation_loss = jnp.mean(
539
+ self.loss_weights["observations"]
540
+ * jnp.sum(
541
+ (v_u(self.obs_batch[0][:, None]) - self.obs_batch[1]) ** 2,
542
+ axis=-1,
543
+ )
544
+ )
545
+ elif isinstance(self.u, SPINN):
546
+ raise RuntimeError(
547
+ "observation loss term not yet implemented for SPINNs"
505
548
  )
506
- )
507
549
  else:
508
550
  mse_observation_loss = 0
509
551
  self.loss_weights["observations"] = 0
510
552
 
511
553
  # Sobolev regularization
512
554
  if self.sobolev_reg is not None:
513
- v_sob_reg = vmap(
514
- lambda x: self.sobolev_reg(
515
- x, params["nn_params"], jax.lax.stop_gradient(params["eq_params"])
516
- ),
517
- (0, 0),
518
- 0,
519
- )
520
- mse_sobolev_loss = self.loss_weights["sobolev"] * jnp.mean(
521
- v_sob_reg(omega_batch)
522
- )
555
+ # TODO implement for SPINN
556
+ if isinstance(self.u, PINN):
557
+ v_sob_reg = vmap(
558
+ lambda x: self.sobolev_reg(
559
+ x,
560
+ params["nn_params"],
561
+ jax.lax.stop_gradient(params["eq_params"]),
562
+ ),
563
+ (0, 0),
564
+ 0,
565
+ )
566
+ mse_sobolev_loss = self.loss_weights["sobolev"] * jnp.mean(
567
+ v_sob_reg(omega_batch)
568
+ )
569
+ elif isinstance(self.u, SPINN):
570
+ raise RuntimeError("Sobolev loss term not yet implemented for SPINNs")
523
571
  else:
524
572
  mse_sobolev_loss = 0
525
573
  self.loss_weights["sobolev"] = 0
@@ -660,7 +708,7 @@ class LossPDENonStatio(LossPDEStatio):
660
708
  sobolev_m
661
709
  An integer. Default is None.
662
710
  It corresponds to the Sobolev regularization order as proposed in
663
- _Convergence and error analysis of PINNs_,
711
+ *Convergence and error analysis of PINNs*,
664
712
  Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
665
713
 
666
714
 
@@ -769,21 +817,25 @@ class LossPDENonStatio(LossPDEStatio):
769
817
 
770
818
  vmap_in_axes_params = _get_vmap_in_axes_params(batch.param_batch_dict, params)
771
819
 
772
- omega_batch_ = jnp.tile(omega_batch, reps=(nt, 1)) # it is tiled
773
- times_batch_ = rep_times(n) # it is repeated
820
+ if isinstance(self.u, PINN):
821
+ omega_batch_ = jnp.tile(omega_batch, reps=(nt, 1)) # it is tiled
822
+ times_batch_ = rep_times(n) # it is repeated
774
823
 
775
824
  # dynamic part
776
825
  if self.dynamic_loss is not None:
777
- v_dyn_loss = vmap(
778
- lambda t, x, params: self.dynamic_loss.evaluate(t, x, self.u, params),
779
- vmap_in_axes_x_t + vmap_in_axes_params,
780
- 0,
781
- )
782
- mse_dyn_loss = jnp.mean(
783
- self.loss_weights["dyn_loss"]
784
- * jnp.mean(v_dyn_loss(times_batch_, omega_batch_, params) ** 2, axis=0)
785
- )
786
- # OR for Causality is all you need (not yet implemented)
826
+ if isinstance(self.u, PINN):
827
+ v_dyn_loss = vmap(
828
+ lambda t, x, params: self.dynamic_loss.evaluate(
829
+ t, x, self.u, params
830
+ ),
831
+ vmap_in_axes_x_t + vmap_in_axes_params,
832
+ 0,
833
+ )
834
+ residuals = v_dyn_loss(times_batch_, omega_batch_, params)
835
+ mse_dyn_loss = jnp.mean(
836
+ self.loss_weights["dyn_loss"] * jnp.sum(residuals**2, axis=-1)
837
+ )
838
+ # TODO implement Causality is all you need (not yet implemented)
787
839
  # epsilon = 0.01
788
840
  # times_batch_ = jnp.sort(times_batch_)
789
841
  # val_dyn_loss = v_dyn_loss(times_batch_, omega_batch_, params)
@@ -791,33 +843,65 @@ class LossPDENonStatio(LossPDEStatio):
791
843
  # jnp.cumsum(val_dyn_loss)), shift=1,
792
844
  # axis=0)) * val_dyn_loss
793
845
  # mse_dyn_loss = jnp.mean(causality_is_all_you_need ** 2)
846
+ elif isinstance(self.u, SPINN):
847
+ residuals = self.dynamic_loss.evaluate(
848
+ times_batch, omega_batch, self.u, params
849
+ )
850
+ mse_dyn_loss = jnp.mean(
851
+ self.loss_weights["dyn_loss"] * jnp.sum(residuals**2, axis=-1)
852
+ )
853
+ # TODO implement Causality is all you need (not yet implemented)
854
+ # epsilon = 0.01
855
+ # times_batch = jnp.sort(times_batch)
856
+ # residuals = jax.lax.stop_gradient(jnp.roll(jnp.exp(-epsilon *
857
+ # jnp.cumsum(jnp.cumsum(residuals, axis=-2), axis=-1)
858
+ # ),
859
+ # shift=1, axis=0)) * residuals
794
860
  else:
795
861
  mse_dyn_loss = 0
796
862
 
797
863
  # normalization part
798
864
  if self.normalization_loss is not None:
799
- v_u = vmap(
800
- vmap(
801
- lambda t, x: self.u(
802
- t,
803
- x,
804
- params["nn_params"],
805
- jax.lax.stop_gradient(params["eq_params"]),
865
+ if isinstance(self.u, PINN):
866
+ v_u = vmap(
867
+ vmap(
868
+ lambda t, x: self.u(
869
+ t,
870
+ x,
871
+ params["nn_params"],
872
+ jax.lax.stop_gradient(params["eq_params"]),
873
+ ),
874
+ in_axes=(None, 0),
806
875
  ),
807
- in_axes=(None, 0),
808
- ),
809
- in_axes=(0, None),
810
- ) # Note that it is not faster to have it as a static
811
- # attribute
812
- mse_norm_loss = self.loss_weights["norm_loss"] * jnp.sum(
813
- (1 / nt)
814
- * jnp.abs(
815
- jnp.mean(v_u(times_batch, self.get_norm_samples()), axis=-1)
816
- * self.int_length
817
- - 1
876
+ in_axes=(0, None),
877
+ )
878
+ res = v_u(times_batch, self.get_norm_samples())
879
+ # the outer mean() below is for the times stamps
880
+ mse_norm_loss = self.loss_weights["norm_loss"] * jnp.mean(
881
+ jnp.abs(jnp.mean(res, axis=(-2, -1)) * self.int_length - 1) ** 2
882
+ )
883
+ elif isinstance(self.u, SPINN):
884
+ norm_samples = self.get_norm_samples()
885
+ assert norm_samples.shape[0] % times_batch.shape[0] == 0
886
+ rep_t = norm_samples.shape[0] // times_batch.shape[0]
887
+ res = self.u(
888
+ jnp.repeat(times_batch, rep_t, axis=0),
889
+ norm_samples,
890
+ params["nn_params"],
891
+ jax.lax.stop_gradient(params["eq_params"]),
892
+ )
893
+ # the outer mean() below is for the times stamps
894
+ mse_norm_loss = self.loss_weights["norm_loss"] * jnp.mean(
895
+ jnp.abs(
896
+ jnp.mean(
897
+ jnp.mean(res, axis=-1),
898
+ axis=(d + 1 for d in range(res.ndim - 2)),
899
+ )
900
+ * self.int_length
901
+ - 1
902
+ )
903
+ ** 2
818
904
  )
819
- ** 2
820
- )
821
905
 
822
906
  else:
823
907
  mse_norm_loss = 0
@@ -860,23 +944,40 @@ class LossPDENonStatio(LossPDEStatio):
860
944
  else:
861
945
  mse_boundary_loss = 0
862
946
 
863
- # temporal part
947
+ # initial condition
864
948
  if self.initial_condition_fun is not None:
865
- v_u_t0 = vmap(
866
- lambda x: self.initial_condition_fun(x)
867
- - self.u(
868
- t=jnp.zeros((1,)),
869
- x=x,
870
- u_params=params["nn_params"],
871
- eq_params=jax.lax.stop_gradient(params["eq_params"]),
872
- ),
873
- (0),
874
- 0,
875
- )
876
- mse_initial_condition = jnp.mean(
877
- self.loss_weights["initial_condition"]
878
- * jnp.mean(v_u_t0(omega_batch) ** 2, axis=0)
879
- )
949
+ if isinstance(self.u, PINN):
950
+ v_u_t0 = vmap(
951
+ lambda x: self.initial_condition_fun(x)
952
+ - self.u(
953
+ t=jnp.zeros((1,)),
954
+ x=x,
955
+ u_params=params["nn_params"],
956
+ eq_params=jax.lax.stop_gradient(params["eq_params"]),
957
+ ),
958
+ (0),
959
+ 0,
960
+ )
961
+ res = v_u_t0(omega_batch)
962
+ mse_initial_condition = jnp.mean(
963
+ self.loss_weights["initial_condition"] * jnp.sum(res**2, axis=-1)
964
+ )
965
+ elif isinstance(self.u, SPINN):
966
+ values = lambda x: self.u(
967
+ jnp.repeat(jnp.zeros((1, 1)), omega_batch.shape[0], axis=0),
968
+ x,
969
+ params["nn_params"],
970
+ jax.lax.stop_gradient(params["eq_params"]),
971
+ )[0]
972
+ omega_batch_grid = _get_grid(omega_batch)
973
+ v_ini = values(omega_batch)
974
+ ini = _check_user_func_return(
975
+ self.initial_condition_fun(omega_batch_grid), v_ini.shape
976
+ )
977
+ res = ini - v_ini
978
+ mse_initial_condition = jnp.mean(
979
+ self.loss_weights["initial_condition"] * jnp.sum(res**2, axis=-1)
980
+ )
880
981
  else:
881
982
  mse_initial_condition = 0
882
983
 
@@ -884,41 +985,51 @@ class LossPDENonStatio(LossPDEStatio):
884
985
  # NOTE that it does not use jax.lax.stop_gradient on "eq_params" here
885
986
  # since we may wish to optimize on it.
886
987
  if self.obs_batch is not None:
887
- v_u = vmap(
888
- lambda t, x: self.u(t, x, params["nn_params"], params["eq_params"]),
889
- (0, 0),
890
- 0,
891
- )
892
- mse_observation_loss = jnp.mean(
893
- self.loss_weights["observations"]
894
- * jnp.mean(
895
- (
896
- v_u(self.obs_batch[0][:, None], self.obs_batch[1])
897
- - self.obs_batch[2]
988
+ # TODO implement for SPINN
989
+ if isinstance(self.u, PINN):
990
+ v_u = vmap(
991
+ lambda t, x: self.u(t, x, params["nn_params"], params["eq_params"]),
992
+ (0, 0),
993
+ 0,
994
+ )
995
+ mse_observation_loss = jnp.mean(
996
+ self.loss_weights["observations"]
997
+ * jnp.mean(
998
+ (
999
+ v_u(self.obs_batch[0][:, None], self.obs_batch[1])
1000
+ - self.obs_batch[2]
1001
+ )
1002
+ ** 2,
1003
+ axis=-1,
898
1004
  )
899
- ** 2,
900
- axis=0,
901
1005
  )
902
- )
1006
+ elif isinstance(self.u, SPINN):
1007
+ raise RuntimeError(
1008
+ "observation loss term not yet implemented for SPINNs"
1009
+ )
903
1010
  else:
904
1011
  mse_observation_loss = 0
905
1012
  self.loss_weights["observations"] = 0
906
1013
 
907
1014
  # Sobolev regularization
908
1015
  if self.sobolev_reg is not None:
909
- v_sob_reg = vmap(
910
- lambda t, x: self.sobolev_reg(
911
- t,
912
- x,
913
- params["nn_params"],
914
- jax.lax.stop_gradient(params["eq_params"]),
915
- ),
916
- (0, 0),
917
- 0,
918
- )
919
- mse_sobolev_loss = self.loss_weights["sobolev"] * jnp.mean(
920
- v_sob_reg(omega_batch_, times_batch_)
921
- )
1016
+ # TODO implement for SPINN
1017
+ if isinstance(self.u, PINN):
1018
+ v_sob_reg = vmap(
1019
+ lambda t, x: self.sobolev_reg(
1020
+ t,
1021
+ x,
1022
+ params["nn_params"],
1023
+ jax.lax.stop_gradient(params["eq_params"]),
1024
+ ),
1025
+ (0, 0),
1026
+ 0,
1027
+ )
1028
+ mse_sobolev_loss = self.loss_weights["sobolev"] * jnp.mean(
1029
+ v_sob_reg(omega_batch_, times_batch_)
1030
+ )
1031
+ elif isinstance(self.u, SPINN):
1032
+ raise RuntimeError("Sobolev loss term not yet implemented for SPINNs")
922
1033
  else:
923
1034
  mse_sobolev_loss = 0
924
1035
  self.loss_weights["sobolev"] = 0
@@ -1029,7 +1140,7 @@ class SystemLossPDE:
1029
1140
  nn_type_dict
1030
1141
  A dict whose keys are that of u_dict whose value is either
1031
1142
  `nn_statio` or `nn_nonstatio` which signifies either the PINN has a
1032
- time component in input or not
1143
+ time component in input or not.
1033
1144
  omega_boundary_fun_dict
1034
1145
  A dict of functions to be matched in the border condition, or a
1035
1146
  dict of dict of functions (see doc for `omega_boundary_fun` in
@@ -1069,7 +1180,7 @@ class SystemLossPDE:
1069
1180
  Default is None. A dictionary of integers, one per key which must
1070
1181
  match `u_dict`.
1071
1182
  It corresponds to the Sobolev regularization order as proposed in
1072
- _Convergence and error analysis of PINNs_,
1183
+ *Convergence and error analysis of PINNs*,
1073
1184
  Doumeche et al., 2023, https://arxiv.org/pdf/2305.01240.pdf
1074
1185
 
1075
1186
 
@@ -1134,6 +1245,8 @@ class SystemLossPDE:
1134
1245
 
1135
1246
  self.dynamic_loss_dict = dynamic_loss_dict
1136
1247
  self.u_dict = u_dict
1248
+ # TODO nn_type should become a class attribute now that we have PINN
1249
+ # class and SPINNs class
1137
1250
  self.nn_type_dict = nn_type_dict
1138
1251
 
1139
1252
  self.loss_weights = loss_weights # This calls the setter
@@ -1189,6 +1302,15 @@ class SystemLossPDE:
1189
1302
  f"Wrong value for nn_type_dict[i], got " "{nn_type_dict[i]}"
1190
1303
  )
1191
1304
 
1305
+ # also make sure we only have PINNs or SPINNs
1306
+ if not (
1307
+ all(type(value) == PINN for value in u_dict.values())
1308
+ or all(type(value) == SPINN for value in u_dict.values())
1309
+ ):
1310
+ raise ValueError(
1311
+ "We only accept dictionary of PINNs or dictionary" " of SPINNs"
1312
+ )
1313
+
1192
1314
  @property
1193
1315
  def loss_weights(self):
1194
1316
  return self._loss_weights
@@ -1310,39 +1432,66 @@ class SystemLossPDE:
1310
1432
  for i in self.dynamic_loss_dict.keys():
1311
1433
  # dynamic part
1312
1434
  if isinstance(self.dynamic_loss_dict[i], PDEStatio):
1313
- v_dyn_loss = vmap(
1314
- lambda x, params_dict: self.dynamic_loss_dict[i].evaluate(
1315
- x,
1316
- self.u_dict,
1317
- params_dict,
1318
- ),
1319
- vmap_in_axes_x + vmap_in_axes_params,
1320
- 0,
1321
- )
1322
- mse_dyn_loss += jnp.mean(
1323
- self._loss_weights["dyn_loss"][i]
1324
- * jnp.mean(v_dyn_loss(omega_batch, params_dict) ** 2, axis=0)
1325
- )
1435
+ # Below we just look at the first element because we suppose we
1436
+ # must only have SPINNs or only PINNs
1437
+ if isinstance(list(self.u_dict.values())[0], PINN):
1438
+ v_dyn_loss = vmap(
1439
+ lambda x, params_dict: self.dynamic_loss_dict[i].evaluate(
1440
+ x,
1441
+ self.u_dict,
1442
+ params_dict,
1443
+ ),
1444
+ vmap_in_axes_x + vmap_in_axes_params,
1445
+ 0,
1446
+ )
1447
+ mse_dyn_loss += jnp.mean(
1448
+ self._loss_weights["dyn_loss"][i]
1449
+ * jnp.sum(v_dyn_loss(omega_batch, params_dict) ** 2, axis=-1)
1450
+ )
1451
+ elif isinstance(list(self.u_dict.values())[0], SPINN):
1452
+ residuals = self.dynamic_loss_dict[i].evaluate(
1453
+ omega_batch, self.u_dict, params_dict
1454
+ )
1455
+ mse_dyn_loss += jnp.mean(
1456
+ self._loss_weights["dyn_loss"][i]
1457
+ * jnp.sum(
1458
+ residuals**2,
1459
+ axis=-1,
1460
+ )
1461
+ )
1326
1462
  else:
1327
- v_dyn_loss = vmap(
1328
- lambda t, x, params_dict: self.dynamic_loss_dict[i].evaluate(
1329
- t, x, self.u_dict, params_dict
1330
- ),
1331
- vmap_in_axes_x_t + vmap_in_axes_params,
1332
- 0,
1333
- )
1463
+ if isinstance(list(self.u_dict.values())[0], PINN):
1464
+ v_dyn_loss = vmap(
1465
+ lambda t, x, params_dict: self.dynamic_loss_dict[i].evaluate(
1466
+ t, x, self.u_dict, params_dict
1467
+ ),
1468
+ vmap_in_axes_x_t + vmap_in_axes_params,
1469
+ 0,
1470
+ )
1334
1471
 
1335
- tile_omega_batch = jnp.tile(omega_batch, reps=(nt, 1))
1472
+ tile_omega_batch = jnp.tile(omega_batch, reps=(nt, 1))
1336
1473
 
1337
- omega_batch_ = jnp.tile(omega_batch, reps=(nt, 1)) # it is tiled
1338
- times_batch_ = rep_times(n) # it is repeated
1474
+ omega_batch_ = jnp.tile(omega_batch, reps=(nt, 1)) # it is tiled
1475
+ times_batch_ = rep_times(n) # it is repeated
1339
1476
 
1340
- mse_dyn_loss += jnp.mean(
1341
- self._loss_weights["dyn_loss"][i]
1342
- * jnp.mean(
1343
- v_dyn_loss(times_batch_, omega_batch_, params_dict) ** 2, axis=0
1477
+ mse_dyn_loss += jnp.mean(
1478
+ self._loss_weights["dyn_loss"][i]
1479
+ * jnp.sum(
1480
+ v_dyn_loss(times_batch_, omega_batch_, params_dict) ** 2,
1481
+ axis=-1,
1482
+ )
1483
+ )
1484
+ elif isinstance(list(self.u_dict.values())[0], SPINN):
1485
+ residuals = self.dynamic_loss_dict[i].evaluate(
1486
+ times_batch, omega_batch, self.u_dict, params_dict
1487
+ )
1488
+ mse_dyn_loss += jnp.mean(
1489
+ self._loss_weights["dyn_loss"][i]
1490
+ * jnp.sum(
1491
+ residuals**2,
1492
+ axis=-1,
1493
+ )
1344
1494
  )
1345
- )
1346
1495
 
1347
1496
  # boundary conditions, normalization conditions, observation_loss,
1348
1497
  # initial condition... loss this is done via the internal
jinns/loss/__init__.py CHANGED
@@ -4,12 +4,6 @@ from ._DynamicLoss import (
4
4
  Malthus,
5
5
  BurgerEquation,
6
6
  GeneralizedLotkaVolterra,
7
- OU_FPEStatioLoss1D,
8
- CIR_FPEStatioLoss1D,
9
- OU_FPEStatioLoss2D,
10
- OU_FPENonStatioLoss1D,
11
- CIR_FPENonStatioLoss1D,
12
- Sinus_FPENonStatioLoss1D,
13
7
  OU_FPENonStatioLoss2D,
14
8
  ConvectionDiffusionNonStatio,
15
9
  MassConservation2DStatio,