jinns 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -4
- jinns/data/_DataGeneratorODE.py +1 -1
- jinns/data/_DataGeneratorObservations.py +498 -90
- jinns/loss/_DynamicLossAbstract.py +3 -1
- jinns/loss/_LossODE.py +103 -65
- jinns/loss/_LossPDE.py +145 -77
- jinns/loss/_abstract_loss.py +64 -6
- jinns/loss/_boundary_conditions.py +6 -6
- jinns/loss/_loss_utils.py +2 -2
- jinns/loss/_loss_weight_updates.py +30 -0
- jinns/loss/_loss_weights.py +4 -0
- jinns/loss/_operators.py +27 -27
- jinns/nn/_abstract_pinn.py +1 -1
- jinns/nn/_hyperpinn.py +6 -6
- jinns/nn/_mlp.py +3 -3
- jinns/nn/_pinn.py +7 -7
- jinns/nn/_ppinn.py +6 -6
- jinns/nn/_spinn.py +4 -4
- jinns/nn/_spinn_mlp.py +7 -7
- jinns/solver/_rar.py +19 -9
- jinns/solver/_solve.py +4 -1
- jinns/solver/_utils.py +17 -11
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/METADATA +14 -4
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/RECORD +28 -28
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/WHEEL +1 -1
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.7.0.dist-info → jinns-1.7.1.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py
CHANGED
|
@@ -122,7 +122,7 @@ class _LossPDEAbstract(
|
|
|
122
122
|
made broadcastable to `norm_samples`.
|
|
123
123
|
These correspond to the weights $w_k = \frac{1}{q(x_k)}$ where
|
|
124
124
|
$q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
|
|
125
|
-
obs_slice : EllipsisType | slice, default=None
|
|
125
|
+
obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
|
|
126
126
|
slice object specifying the beginning/ending of the PINN output
|
|
127
127
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
128
128
|
key : Key | None
|
|
@@ -136,7 +136,7 @@ class _LossPDEAbstract(
|
|
|
136
136
|
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
137
137
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
138
138
|
u: eqx.AbstractVar[AbstractPINN]
|
|
139
|
-
dynamic_loss: eqx.AbstractVar[Y]
|
|
139
|
+
dynamic_loss: tuple[eqx.AbstractVar[Y] | None, ...]
|
|
140
140
|
omega_boundary_fun: (
|
|
141
141
|
BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
|
|
142
142
|
) = eqx.field(static=True)
|
|
@@ -146,7 +146,7 @@ class _LossPDEAbstract(
|
|
|
146
146
|
omega_boundary_dim: slice | dict[str, slice] = eqx.field(static=True)
|
|
147
147
|
norm_samples: Float[Array, " nb_norm_samples dimension"] | None
|
|
148
148
|
norm_weights: Float[Array, " nb_norm_samples"] | None
|
|
149
|
-
obs_slice: EllipsisType | slice = eqx.field(static=True)
|
|
149
|
+
obs_slice: tuple[EllipsisType | slice, ...] = eqx.field(static=True)
|
|
150
150
|
key: PRNGKeyArray | None
|
|
151
151
|
|
|
152
152
|
def __init__(
|
|
@@ -159,7 +159,10 @@ class _LossPDEAbstract(
|
|
|
159
159
|
omega_boundary_dim: int | slice | dict[str, int | slice] | None = None,
|
|
160
160
|
norm_samples: Float[Array, " nb_norm_samples dimension"] | None = None,
|
|
161
161
|
norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = None,
|
|
162
|
-
obs_slice: EllipsisType | slice
|
|
162
|
+
obs_slice: tuple[EllipsisType | slice, ...]
|
|
163
|
+
| EllipsisType
|
|
164
|
+
| slice
|
|
165
|
+
| None = None,
|
|
163
166
|
key: PRNGKeyArray | None = None,
|
|
164
167
|
derivative_keys: DKPDE,
|
|
165
168
|
**kwargs: Any, # for arguments for super()
|
|
@@ -181,7 +184,9 @@ class _LossPDEAbstract(
|
|
|
181
184
|
)
|
|
182
185
|
|
|
183
186
|
if obs_slice is None:
|
|
184
|
-
self.obs_slice = jnp.s_[...]
|
|
187
|
+
self.obs_slice = (jnp.s_[...],)
|
|
188
|
+
elif not isinstance(obs_slice, tuple):
|
|
189
|
+
self.obs_slice = (obs_slice,)
|
|
185
190
|
else:
|
|
186
191
|
self.obs_slice = obs_slice
|
|
187
192
|
|
|
@@ -321,16 +326,23 @@ class _LossPDEAbstract(
|
|
|
321
326
|
|
|
322
327
|
def _get_dyn_loss_fun(
|
|
323
328
|
self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
|
|
324
|
-
) -> Callable[[Params[Array]], Array] | None:
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
329
|
+
) -> tuple[Callable[[Params[Array]], Array], ...] | None:
|
|
330
|
+
# Note, for the record, multiple dynamic losses
|
|
331
|
+
# have been introduced in MR 92
|
|
332
|
+
if self.dynamic_loss != (None,):
|
|
333
|
+
dyn_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
|
|
334
|
+
jax.tree.map(
|
|
335
|
+
lambda d: lambda p: dynamic_loss_apply(
|
|
336
|
+
d.evaluate,
|
|
337
|
+
self.u,
|
|
338
|
+
self._get_dynamic_loss_batch(batch),
|
|
339
|
+
_set_derivatives(p, self.derivative_keys.dyn_loss),
|
|
340
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
341
|
+
),
|
|
342
|
+
self.dynamic_loss,
|
|
343
|
+
is_leaf=lambda x: isinstance(
|
|
344
|
+
x, (PDEStatio, PDENonStatio)
|
|
345
|
+
), # do not traverse further than first level
|
|
334
346
|
)
|
|
335
347
|
)
|
|
336
348
|
else:
|
|
@@ -340,9 +352,13 @@ class _LossPDEAbstract(
|
|
|
340
352
|
|
|
341
353
|
def _get_norm_loss_fun(
|
|
342
354
|
self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
|
|
343
|
-
) -> Callable[[Params[Array]], Array] | None:
|
|
355
|
+
) -> tuple[Callable[[Params[Array]], Array], ...] | None:
|
|
356
|
+
# Note that since MR 92
|
|
357
|
+
# norm_loss_fun is formed as a tuple for
|
|
358
|
+
# consistency with dynamic and observation
|
|
359
|
+
# losses and more modularity for later
|
|
344
360
|
if self.norm_samples is not None:
|
|
345
|
-
norm_loss_fun: Callable[[Params[Array]], Array] | None = (
|
|
361
|
+
norm_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
|
|
346
362
|
lambda p: normalization_loss_apply(
|
|
347
363
|
self.u,
|
|
348
364
|
cast(
|
|
@@ -351,7 +367,7 @@ class _LossPDEAbstract(
|
|
|
351
367
|
_set_derivatives(p, self.derivative_keys.norm_loss),
|
|
352
368
|
vmap_in_axes_params,
|
|
353
369
|
self.norm_weights, # type: ignore -> can't get the __post_init__ narrowing here
|
|
354
|
-
)
|
|
370
|
+
),
|
|
355
371
|
)
|
|
356
372
|
else:
|
|
357
373
|
norm_loss_fun = None
|
|
@@ -359,20 +375,24 @@ class _LossPDEAbstract(
|
|
|
359
375
|
|
|
360
376
|
def _get_boundary_loss_fun(
|
|
361
377
|
self, batch: B
|
|
362
|
-
) -> Callable[[Params[Array]], Array] | None:
|
|
378
|
+
) -> tuple[Callable[[Params[Array]], Array], ...] | None:
|
|
379
|
+
# Note that since MR 92
|
|
380
|
+
# boundary_loss_fun is formed as a tuple for
|
|
381
|
+
# consistency with dynamic and observation
|
|
382
|
+
# losses and more modularity for later
|
|
363
383
|
if (
|
|
364
384
|
self.omega_boundary_condition is not None
|
|
365
385
|
and self.omega_boundary_fun is not None
|
|
366
386
|
):
|
|
367
|
-
boundary_loss_fun: Callable[[Params[Array]], Array] | None = (
|
|
387
|
+
boundary_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
|
|
368
388
|
lambda p: boundary_condition_apply(
|
|
369
389
|
self.u,
|
|
370
390
|
batch,
|
|
371
391
|
_set_derivatives(p, self.derivative_keys.boundary_loss),
|
|
372
392
|
self.omega_boundary_fun, # type: ignore (we are in lambda)
|
|
373
393
|
self.omega_boundary_condition, # type: ignore
|
|
374
|
-
self.omega_boundary_dim,
|
|
375
|
-
)
|
|
394
|
+
self.omega_boundary_dim,
|
|
395
|
+
),
|
|
376
396
|
)
|
|
377
397
|
else:
|
|
378
398
|
boundary_loss_fun = None
|
|
@@ -384,24 +404,37 @@ class _LossPDEAbstract(
|
|
|
384
404
|
batch: B,
|
|
385
405
|
vmap_in_axes_params: tuple[Params[int | None] | None],
|
|
386
406
|
params: Params[Array],
|
|
387
|
-
) -> tuple[
|
|
407
|
+
) -> tuple[
|
|
408
|
+
Params[Array] | None, tuple[Callable[[Params[Array]], Array], ...] | None
|
|
409
|
+
]:
|
|
410
|
+
# Note, for the record, multiple DGObs
|
|
411
|
+
# (leading to batch.obs_batch_dict being tuple | None)
|
|
412
|
+
# have been introduced in MR 92
|
|
388
413
|
if batch.obs_batch_dict is not None:
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
414
|
+
if len(batch.obs_batch_dict) != len(self.obs_slice):
|
|
415
|
+
raise ValueError(
|
|
416
|
+
"There must be the same number of "
|
|
417
|
+
"observation datasets as the number of "
|
|
418
|
+
"obs_slice"
|
|
419
|
+
)
|
|
420
|
+
params_obs = jax.tree.map(
|
|
421
|
+
lambda d: update_eq_params(params, d["eq_params"]),
|
|
422
|
+
batch.obs_batch_dict,
|
|
423
|
+
is_leaf=lambda x: isinstance(x, dict),
|
|
395
424
|
)
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
425
|
+
obs_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
|
|
426
|
+
jax.tree.map(
|
|
427
|
+
lambda d, slice_: lambda po: observations_loss_apply(
|
|
428
|
+
self.u,
|
|
429
|
+
d["pinn_in"],
|
|
430
|
+
_set_derivatives(po, self.derivative_keys.observations),
|
|
431
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
432
|
+
d["val"],
|
|
433
|
+
slice_,
|
|
434
|
+
),
|
|
435
|
+
batch.obs_batch_dict,
|
|
404
436
|
self.obs_slice,
|
|
437
|
+
is_leaf=lambda x: isinstance(x, dict),
|
|
405
438
|
)
|
|
406
439
|
)
|
|
407
440
|
else:
|
|
@@ -434,7 +467,7 @@ class LossPDEStatio(
|
|
|
434
467
|
----------
|
|
435
468
|
u : AbstractPINN
|
|
436
469
|
the PINN
|
|
437
|
-
dynamic_loss : PDEStatio | None
|
|
470
|
+
dynamic_loss : tuple[PDEStatio, ...] | PDEStatio | None
|
|
438
471
|
the stationary PDE dynamic part of the loss, basically the differential
|
|
439
472
|
operator $\mathcal{N}[u](x)$. Should implement a method
|
|
440
473
|
`dynamic_loss.evaluate(x, u, params)`.
|
|
@@ -497,7 +530,7 @@ class LossPDEStatio(
|
|
|
497
530
|
Alternatively, the user can pass a float or an integer.
|
|
498
531
|
These correspond to the weights $w_k = \frac{1}{q(x_k)}$ where
|
|
499
532
|
$q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
|
|
500
|
-
obs_slice : slice, default=None
|
|
533
|
+
obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
|
|
501
534
|
slice object specifying the beginning/ending of the PINN output
|
|
502
535
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
503
536
|
|
|
@@ -512,7 +545,7 @@ class LossPDEStatio(
|
|
|
512
545
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
513
546
|
|
|
514
547
|
u: AbstractPINN
|
|
515
|
-
dynamic_loss: PDEStatio | None
|
|
548
|
+
dynamic_loss: tuple[PDEStatio | None, ...]
|
|
516
549
|
loss_weights: LossWeightsPDEStatio
|
|
517
550
|
derivative_keys: DerivativeKeysPDEStatio
|
|
518
551
|
|
|
@@ -522,7 +555,7 @@ class LossPDEStatio(
|
|
|
522
555
|
self,
|
|
523
556
|
*,
|
|
524
557
|
u: AbstractPINN,
|
|
525
|
-
dynamic_loss: PDEStatio | None,
|
|
558
|
+
dynamic_loss: tuple[PDEStatio, ...] | PDEStatio | None,
|
|
526
559
|
loss_weights: LossWeightsPDEStatio | None = None,
|
|
527
560
|
derivative_keys: DerivativeKeysPDEStatio | None = None,
|
|
528
561
|
params: Params[Array] | None = None,
|
|
@@ -543,15 +576,16 @@ class LossPDEStatio(
|
|
|
543
576
|
"Problem at derivative_keys initialization "
|
|
544
577
|
f"received {derivative_keys=} and {params=}"
|
|
545
578
|
) from exc
|
|
546
|
-
else:
|
|
547
|
-
derivative_keys = derivative_keys
|
|
548
579
|
|
|
549
580
|
super().__init__(
|
|
550
581
|
derivative_keys=derivative_keys,
|
|
551
582
|
vmap_in_axes=(0,),
|
|
552
583
|
**kwargs,
|
|
553
584
|
)
|
|
554
|
-
|
|
585
|
+
if not isinstance(dynamic_loss, tuple):
|
|
586
|
+
self.dynamic_loss = (dynamic_loss,)
|
|
587
|
+
else:
|
|
588
|
+
self.dynamic_loss = dynamic_loss
|
|
555
589
|
|
|
556
590
|
def _get_dynamic_loss_batch(
|
|
557
591
|
self, batch: PDEStatioBatch
|
|
@@ -619,29 +653,44 @@ class LossPDEStatio(
|
|
|
619
653
|
)
|
|
620
654
|
|
|
621
655
|
# get the unweighted mses for each loss term as well as the gradients
|
|
622
|
-
all_funs: PDEStatioComponents[
|
|
656
|
+
all_funs: PDEStatioComponents[
|
|
657
|
+
tuple[Callable[[Params[Array]], Array], ...] | None
|
|
658
|
+
] = PDEStatioComponents(
|
|
659
|
+
dyn_loss_fun,
|
|
660
|
+
norm_loss_fun,
|
|
661
|
+
boundary_loss_fun,
|
|
662
|
+
obs_loss_fun,
|
|
663
|
+
)
|
|
664
|
+
all_params: PDEStatioComponents[tuple[Params[Array], ...] | None] = (
|
|
623
665
|
PDEStatioComponents(
|
|
624
|
-
|
|
666
|
+
jax.tree.map(lambda l: params, dyn_loss_fun),
|
|
667
|
+
jax.tree.map(lambda l: params, norm_loss_fun),
|
|
668
|
+
jax.tree.map(lambda l: params, boundary_loss_fun),
|
|
669
|
+
params_obs,
|
|
625
670
|
)
|
|
626
671
|
)
|
|
627
|
-
all_params: PDEStatioComponents[Params[Array] | None] = PDEStatioComponents(
|
|
628
|
-
params, params, params, params_obs
|
|
629
|
-
)
|
|
630
672
|
mses_grads = jax.tree.map(
|
|
631
673
|
self.get_gradients,
|
|
632
674
|
all_funs,
|
|
633
675
|
all_params,
|
|
634
676
|
is_leaf=lambda x: x is None,
|
|
635
677
|
)
|
|
678
|
+
# NOTE the is_leaf below is more complex since it must be able to pass through the tuple
|
|
679
|
+
# of dyn_loss and then stop (but it must also account for the fact that it should not stop when
|
|
680
|
+
# the tuple of dyn_loss is of length 2)
|
|
636
681
|
mses = jax.tree.map(
|
|
637
|
-
lambda leaf: leaf[0],
|
|
682
|
+
lambda leaf: leaf[0],
|
|
638
683
|
mses_grads,
|
|
639
|
-
is_leaf=lambda x: isinstance(x, tuple)
|
|
684
|
+
is_leaf=lambda x: isinstance(x, tuple)
|
|
685
|
+
and len(x) == 2
|
|
686
|
+
and isinstance(x[1], Params),
|
|
640
687
|
)
|
|
641
688
|
grads = jax.tree.map(
|
|
642
|
-
lambda leaf: leaf[1],
|
|
689
|
+
lambda leaf: leaf[1],
|
|
643
690
|
mses_grads,
|
|
644
|
-
is_leaf=lambda x: isinstance(x, tuple)
|
|
691
|
+
is_leaf=lambda x: isinstance(x, tuple)
|
|
692
|
+
and len(x) == 2
|
|
693
|
+
and isinstance(x[1], Params),
|
|
645
694
|
)
|
|
646
695
|
|
|
647
696
|
return mses, grads
|
|
@@ -673,7 +722,7 @@ class LossPDENonStatio(
|
|
|
673
722
|
----------
|
|
674
723
|
u : AbstractPINN
|
|
675
724
|
the PINN
|
|
676
|
-
dynamic_loss : PDENonStatio
|
|
725
|
+
dynamic_loss : tuple[PDENonStatio, ...] | PDENonStatio | None
|
|
677
726
|
the non stationary PDE dynamic part of the loss, basically the differential
|
|
678
727
|
operator $\mathcal{N}[u](t, x)$. Should implement a method
|
|
679
728
|
`dynamic_loss.evaluate(t, x, u, params)`.
|
|
@@ -750,14 +799,13 @@ class LossPDENonStatio(
|
|
|
750
799
|
Alternatively, the user can pass a float or an integer.
|
|
751
800
|
These correspond to the weights $w_k = \frac{1}{q(x_k)}$ where
|
|
752
801
|
$q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
|
|
753
|
-
obs_slice : slice, default=None
|
|
802
|
+
obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
|
|
754
803
|
slice object specifying the beginning/ending of the PINN output
|
|
755
804
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
756
|
-
|
|
757
805
|
"""
|
|
758
806
|
|
|
759
807
|
u: AbstractPINN
|
|
760
|
-
dynamic_loss: PDENonStatio | None
|
|
808
|
+
dynamic_loss: tuple[PDENonStatio | None, ...]
|
|
761
809
|
loss_weights: LossWeightsPDENonStatio
|
|
762
810
|
derivative_keys: DerivativeKeysPDENonStatio
|
|
763
811
|
params: InitVar[Params[Array] | None]
|
|
@@ -774,7 +822,7 @@ class LossPDENonStatio(
|
|
|
774
822
|
self,
|
|
775
823
|
*,
|
|
776
824
|
u: AbstractPINN,
|
|
777
|
-
dynamic_loss: PDENonStatio | None,
|
|
825
|
+
dynamic_loss: tuple[PDENonStatio, ...] | PDENonStatio | None,
|
|
778
826
|
loss_weights: LossWeightsPDENonStatio | None = None,
|
|
779
827
|
derivative_keys: DerivativeKeysPDENonStatio | None = None,
|
|
780
828
|
initial_condition_fun: Callable[[Float[Array, " dimension"]], Array]
|
|
@@ -800,8 +848,6 @@ class LossPDENonStatio(
|
|
|
800
848
|
"Problem at derivative_keys initialization "
|
|
801
849
|
f"received {derivative_keys=} and {params=}"
|
|
802
850
|
) from exc
|
|
803
|
-
else:
|
|
804
|
-
derivative_keys = derivative_keys
|
|
805
851
|
|
|
806
852
|
super().__init__(
|
|
807
853
|
derivative_keys=derivative_keys,
|
|
@@ -809,7 +855,10 @@ class LossPDENonStatio(
|
|
|
809
855
|
**kwargs,
|
|
810
856
|
)
|
|
811
857
|
|
|
812
|
-
|
|
858
|
+
if not isinstance(dynamic_loss, tuple):
|
|
859
|
+
self.dynamic_loss = (dynamic_loss,)
|
|
860
|
+
else:
|
|
861
|
+
self.dynamic_loss = dynamic_loss
|
|
813
862
|
|
|
814
863
|
if initial_condition_fun is None:
|
|
815
864
|
warnings.warn(
|
|
@@ -904,7 +953,13 @@ class LossPDENonStatio(
|
|
|
904
953
|
|
|
905
954
|
# initial condition
|
|
906
955
|
if self.initial_condition_fun is not None:
|
|
907
|
-
|
|
956
|
+
# Note that since MR 92
|
|
957
|
+
# initial_condition_fun is formed as a tuple for
|
|
958
|
+
# consistency with dynamic and observation
|
|
959
|
+
# losses and more modularity for later
|
|
960
|
+
initial_condition_fun: (
|
|
961
|
+
tuple[Callable[[Params[Array]], Array], ...] | None
|
|
962
|
+
) = (
|
|
908
963
|
lambda p: initial_condition_apply(
|
|
909
964
|
self.u,
|
|
910
965
|
omega_initial_batch,
|
|
@@ -912,39 +967,52 @@ class LossPDENonStatio(
|
|
|
912
967
|
(0,) + vmap_in_axes_params,
|
|
913
968
|
self.initial_condition_fun, # type: ignore
|
|
914
969
|
self.t0,
|
|
915
|
-
)
|
|
970
|
+
),
|
|
916
971
|
)
|
|
917
972
|
else:
|
|
918
|
-
|
|
973
|
+
initial_condition_fun = None
|
|
919
974
|
|
|
920
975
|
# get the unweighted mses for each loss term as well as the gradients
|
|
921
|
-
all_funs: PDENonStatioComponents[
|
|
976
|
+
all_funs: PDENonStatioComponents[
|
|
977
|
+
tuple[Callable[[Params[Array]], Array], ...] | None
|
|
978
|
+
] = PDENonStatioComponents(
|
|
979
|
+
dyn_loss_fun,
|
|
980
|
+
norm_loss_fun,
|
|
981
|
+
boundary_loss_fun,
|
|
982
|
+
obs_loss_fun,
|
|
983
|
+
initial_condition_fun,
|
|
984
|
+
)
|
|
985
|
+
all_params: PDENonStatioComponents[tuple[Params[Array], ...] | None] = (
|
|
922
986
|
PDENonStatioComponents(
|
|
923
|
-
dyn_loss_fun,
|
|
924
|
-
norm_loss_fun,
|
|
925
|
-
boundary_loss_fun,
|
|
926
|
-
|
|
927
|
-
|
|
987
|
+
jax.tree.map(lambda l: params, dyn_loss_fun),
|
|
988
|
+
jax.tree.map(lambda l: params, norm_loss_fun),
|
|
989
|
+
jax.tree.map(lambda l: params, boundary_loss_fun),
|
|
990
|
+
params_obs,
|
|
991
|
+
jax.tree.map(lambda l: params, initial_condition_fun),
|
|
928
992
|
)
|
|
929
993
|
)
|
|
930
|
-
all_params: PDENonStatioComponents[Params[Array] | None] = (
|
|
931
|
-
PDENonStatioComponents(params, params, params, params_obs, params)
|
|
932
|
-
)
|
|
933
994
|
mses_grads = jax.tree.map(
|
|
934
995
|
self.get_gradients,
|
|
935
996
|
all_funs,
|
|
936
997
|
all_params,
|
|
937
998
|
is_leaf=lambda x: x is None,
|
|
938
999
|
)
|
|
1000
|
+
# NOTE the is_leaf below is more complex since it must be able to pass through the tuple
|
|
1001
|
+
# of dyn_loss and then stop (but it must also account for the fact that it should not stop when
|
|
1002
|
+
# the tuple of dyn_loss is of length 2)
|
|
939
1003
|
mses = jax.tree.map(
|
|
940
|
-
lambda leaf: leaf[0],
|
|
1004
|
+
lambda leaf: leaf[0],
|
|
941
1005
|
mses_grads,
|
|
942
|
-
is_leaf=lambda x: isinstance(x, tuple)
|
|
1006
|
+
is_leaf=lambda x: isinstance(x, tuple)
|
|
1007
|
+
and len(x) == 2
|
|
1008
|
+
and isinstance(x[1], Params),
|
|
943
1009
|
)
|
|
944
1010
|
grads = jax.tree.map(
|
|
945
|
-
lambda leaf: leaf[1],
|
|
1011
|
+
lambda leaf: leaf[1],
|
|
946
1012
|
mses_grads,
|
|
947
|
-
is_leaf=lambda x: isinstance(x, tuple)
|
|
1013
|
+
is_leaf=lambda x: isinstance(x, tuple)
|
|
1014
|
+
and len(x) == 2
|
|
1015
|
+
and isinstance(x[1], Params),
|
|
948
1016
|
)
|
|
949
1017
|
|
|
950
1018
|
return mses, grads
|
jinns/loss/_abstract_loss.py
CHANGED
|
@@ -1,14 +1,21 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import abc
|
|
4
|
-
|
|
5
|
-
from
|
|
4
|
+
import warnings
|
|
5
|
+
from typing import Self, Literal, Callable, TypeVar, Generic, Any, get_args
|
|
6
|
+
from dataclasses import InitVar
|
|
7
|
+
from jaxtyping import Array, PyTree, Float, PRNGKeyArray
|
|
6
8
|
import equinox as eqx
|
|
7
9
|
import jax
|
|
8
10
|
import jax.numpy as jnp
|
|
9
11
|
import optax
|
|
10
12
|
from jinns.parameters._params import Params
|
|
11
|
-
from jinns.loss._loss_weight_updates import
|
|
13
|
+
from jinns.loss._loss_weight_updates import (
|
|
14
|
+
soft_adapt,
|
|
15
|
+
lr_annealing,
|
|
16
|
+
ReLoBRaLo,
|
|
17
|
+
prior_loss,
|
|
18
|
+
)
|
|
12
19
|
from jinns.utils._types import (
|
|
13
20
|
AnyLossComponents,
|
|
14
21
|
AnyBatch,
|
|
@@ -38,6 +45,11 @@ DK = TypeVar("DK", bound=AnyDerivativeKeys)
|
|
|
38
45
|
# the return types of evaluate_by_terms for example!
|
|
39
46
|
|
|
40
47
|
|
|
48
|
+
AvailableUpdateWeightMethods = Literal[
|
|
49
|
+
"softadapt", "soft_adapt", "prior_loss", "lr_annealing", "ReLoBRaLo"
|
|
50
|
+
]
|
|
51
|
+
|
|
52
|
+
|
|
41
53
|
class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):
|
|
42
54
|
"""
|
|
43
55
|
About the call:
|
|
@@ -46,10 +58,43 @@ class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):
|
|
|
46
58
|
|
|
47
59
|
derivative_keys: eqx.AbstractVar[DK]
|
|
48
60
|
loss_weights: eqx.AbstractVar[L]
|
|
49
|
-
|
|
50
|
-
|
|
61
|
+
loss_weight_scales: L = eqx.field(init=False)
|
|
62
|
+
update_weight_method: AvailableUpdateWeightMethods | None = eqx.field(
|
|
63
|
+
kw_only=True, default=None, static=True
|
|
51
64
|
)
|
|
52
65
|
vmap_in_axes: tuple[int] = eqx.field(static=True)
|
|
66
|
+
keep_initial_loss_weight_scales: InitVar[bool] = eqx.field(
|
|
67
|
+
default=True, kw_only=True
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
def __init__(
|
|
71
|
+
self,
|
|
72
|
+
*,
|
|
73
|
+
loss_weights,
|
|
74
|
+
derivative_keys,
|
|
75
|
+
vmap_in_axes,
|
|
76
|
+
update_weight_method=None,
|
|
77
|
+
keep_initial_loss_weight_scales: bool = False,
|
|
78
|
+
):
|
|
79
|
+
if update_weight_method is not None and update_weight_method not in get_args(
|
|
80
|
+
AvailableUpdateWeightMethods
|
|
81
|
+
):
|
|
82
|
+
raise ValueError(f"{update_weight_method=} is not a valid method")
|
|
83
|
+
self.update_weight_method = update_weight_method
|
|
84
|
+
self.loss_weights = loss_weights
|
|
85
|
+
self.derivative_keys = derivative_keys
|
|
86
|
+
self.vmap_in_axes = vmap_in_axes
|
|
87
|
+
if keep_initial_loss_weight_scales:
|
|
88
|
+
self.loss_weight_scales = self.loss_weights
|
|
89
|
+
if self.update_weight_method is not None:
|
|
90
|
+
warnings.warn(
|
|
91
|
+
"Loss weights out from update_weight_method will still be"
|
|
92
|
+
" multiplied by the initial input loss_weights"
|
|
93
|
+
)
|
|
94
|
+
else:
|
|
95
|
+
self.loss_weight_scales = optax.tree_utils.tree_ones_like(self.loss_weights)
|
|
96
|
+
# self.loss_weight_scales will contain None where self.loss_weights
|
|
97
|
+
# has None
|
|
53
98
|
|
|
54
99
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
55
100
|
return self.evaluate(*args, **kwargs)
|
|
@@ -127,7 +172,11 @@ class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):
|
|
|
127
172
|
raise ValueError(
|
|
128
173
|
"The numbers of declared loss weights and "
|
|
129
174
|
"declared loss terms do not concord "
|
|
130
|
-
f" got {len(weights)} and {len(terms_list)}"
|
|
175
|
+
f" got {len(weights)} and {len(terms_list)}. "
|
|
176
|
+
"If you passed tuple of dyn_loss, make sure to pass "
|
|
177
|
+
"tuple of loss weights at LossWeights.dyn_loss."
|
|
178
|
+
"If you passed tuple of obs datasets, make sure to pass "
|
|
179
|
+
"tuple of loss weights at LossWeights.observations."
|
|
131
180
|
)
|
|
132
181
|
|
|
133
182
|
def ponderate_and_sum_gradient(self, terms: C) -> Params[Array | None]:
|
|
@@ -171,6 +220,8 @@ class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):
|
|
|
171
220
|
new_weights = soft_adapt(
|
|
172
221
|
self.loss_weights, iteration_nb, loss_terms, stored_loss_terms
|
|
173
222
|
)
|
|
223
|
+
elif self.update_weight_method == "prior_loss":
|
|
224
|
+
new_weights = prior_loss(self.loss_weights, iteration_nb, stored_loss_terms)
|
|
174
225
|
elif self.update_weight_method == "lr_annealing":
|
|
175
226
|
new_weights = lr_annealing(self.loss_weights, grad_terms)
|
|
176
227
|
elif self.update_weight_method == "ReLoBRaLo":
|
|
@@ -183,6 +234,13 @@ class AbstractLoss(eqx.Module, Generic[L, B, C, DK]):
|
|
|
183
234
|
# Below we update the non None entry in the PyTree self.loss_weights
|
|
184
235
|
# we directly get the non None entries because None is not treated as a
|
|
185
236
|
# leaf
|
|
237
|
+
|
|
238
|
+
new_weights = jax.lax.cond(
|
|
239
|
+
iteration_nb == 0,
|
|
240
|
+
lambda nw: nw,
|
|
241
|
+
lambda nw: jnp.array(jax.tree.leaves(self.loss_weight_scales)) * nw,
|
|
242
|
+
new_weights,
|
|
243
|
+
)
|
|
186
244
|
return eqx.tree_at(
|
|
187
245
|
lambda pt: jax.tree.leaves(pt.loss_weights), self, new_weights
|
|
188
246
|
)
|
|
@@ -227,7 +227,7 @@ def boundary_neumann(
|
|
|
227
227
|
if isinstance(u, PINN):
|
|
228
228
|
u_ = lambda inputs, params: jnp.squeeze(u(inputs, params)[dim_to_apply])
|
|
229
229
|
|
|
230
|
-
if u.eq_type == "
|
|
230
|
+
if u.eq_type == "PDEStatio":
|
|
231
231
|
v_neumann = vmap(
|
|
232
232
|
lambda inputs, params: _subtract_with_check(
|
|
233
233
|
f(inputs),
|
|
@@ -240,7 +240,7 @@ def boundary_neumann(
|
|
|
240
240
|
vmap_in_axes,
|
|
241
241
|
0,
|
|
242
242
|
)
|
|
243
|
-
elif u.eq_type == "
|
|
243
|
+
elif u.eq_type == "PDENonStatio":
|
|
244
244
|
v_neumann = vmap(
|
|
245
245
|
lambda inputs, params: _subtract_with_check(
|
|
246
246
|
f(inputs),
|
|
@@ -274,14 +274,14 @@ def boundary_neumann(
|
|
|
274
274
|
if (batch_array.shape[0] == 1 and isinstance(batch, PDEStatioBatch)) or (
|
|
275
275
|
batch_array.shape[-1] == 2 and isinstance(batch, PDENonStatioBatch)
|
|
276
276
|
):
|
|
277
|
-
if u.eq_type == "
|
|
277
|
+
if u.eq_type == "PDEStatio":
|
|
278
278
|
_, du_dx = jax.jvp(
|
|
279
279
|
lambda inputs: u(inputs, params)[..., dim_to_apply],
|
|
280
280
|
(batch_array,),
|
|
281
281
|
(jnp.ones_like(batch_array),),
|
|
282
282
|
)
|
|
283
283
|
values = du_dx * n[facet]
|
|
284
|
-
if u.eq_type == "
|
|
284
|
+
if u.eq_type == "PDENonStatio":
|
|
285
285
|
_, du_dx = jax.jvp(
|
|
286
286
|
lambda inputs: u(inputs, params)[..., dim_to_apply],
|
|
287
287
|
(batch_array,),
|
|
@@ -291,7 +291,7 @@ def boundary_neumann(
|
|
|
291
291
|
elif (batch_array.shape[-1] == 2 and isinstance(batch, PDEStatioBatch)) or (
|
|
292
292
|
batch_array.shape[-1] == 3 and isinstance(batch, PDENonStatioBatch)
|
|
293
293
|
):
|
|
294
|
-
if u.eq_type == "
|
|
294
|
+
if u.eq_type == "PDEStatio":
|
|
295
295
|
tangent_vec_0 = jnp.repeat(
|
|
296
296
|
jnp.array([1.0, 0.0])[None], batch_array.shape[0], axis=0
|
|
297
297
|
)
|
|
@@ -309,7 +309,7 @@ def boundary_neumann(
|
|
|
309
309
|
(tangent_vec_1,),
|
|
310
310
|
)
|
|
311
311
|
values = du_dx1 * n[0, facet] + du_dx2 * n[1, facet] # dot product
|
|
312
|
-
if u.eq_type == "
|
|
312
|
+
if u.eq_type == "PDENonStatio":
|
|
313
313
|
tangent_vec_0 = jnp.repeat(
|
|
314
314
|
jnp.array([0.0, 1.0, 0.0])[None], batch_array.shape[0], axis=0
|
|
315
315
|
)
|
jinns/loss/_loss_utils.py
CHANGED
|
@@ -26,12 +26,12 @@ from jinns.data._Batchs import PDEStatioBatch, PDENonStatioBatch
|
|
|
26
26
|
from jinns.parameters._params import Params
|
|
27
27
|
|
|
28
28
|
if TYPE_CHECKING:
|
|
29
|
-
from jinns.utils._types import BoundaryConditionFun
|
|
29
|
+
from jinns.utils._types import BoundaryConditionFun, AnyBatch
|
|
30
30
|
from jinns.nn._abstract_pinn import AbstractPINN
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
def dynamic_loss_apply(
|
|
34
|
-
dyn_loss: Callable,
|
|
34
|
+
dyn_loss: Callable[[AnyBatch, AbstractPINN, Params[Array]], Array],
|
|
35
35
|
u: AbstractPINN,
|
|
36
36
|
batch: (
|
|
37
37
|
Float[Array, " batch_size 1"]
|
|
@@ -13,6 +13,36 @@ if TYPE_CHECKING:
|
|
|
13
13
|
from jinns.utils._types import AnyLossComponents, AnyLossWeights
|
|
14
14
|
|
|
15
15
|
|
|
16
|
+
def prior_loss(
|
|
17
|
+
loss_weights: AnyLossWeights,
|
|
18
|
+
iteration_nb: int,
|
|
19
|
+
stored_loss_terms: AnyLossComponents,
|
|
20
|
+
) -> Array:
|
|
21
|
+
"""
|
|
22
|
+
Simple adaptive weights according to the prior loss idea:
|
|
23
|
+
the weight in front of a loss term is given by the inverse of the
|
|
24
|
+
value of that loss term at the previous iteration
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
def do_nothing(loss_weights, _):
|
|
28
|
+
return jnp.array(
|
|
29
|
+
jax.tree.leaves(loss_weights, is_leaf=eqx.is_inexact_array), dtype=float
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
def _prior_loss(_, stored_loss_terms):
|
|
33
|
+
new_weights = jax.tree.map(
|
|
34
|
+
lambda slt: 1 / (slt[iteration_nb - 1] + 1e-6), stored_loss_terms
|
|
35
|
+
)
|
|
36
|
+
return jnp.array(jax.tree.leaves(new_weights), dtype=float)
|
|
37
|
+
|
|
38
|
+
return jax.lax.cond(
|
|
39
|
+
iteration_nb == 0,
|
|
40
|
+
lambda op: do_nothing(*op),
|
|
41
|
+
lambda op: _prior_loss(*op),
|
|
42
|
+
(loss_weights, stored_loss_terms),
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
|
|
16
46
|
def soft_adapt(
|
|
17
47
|
loss_weights: AnyLossWeights,
|
|
18
48
|
iteration_nb: int,
|
jinns/loss/_loss_weights.py
CHANGED
|
@@ -18,6 +18,10 @@ from jinns.loss._loss_components import (
|
|
|
18
18
|
def lw_converter(x: Array | None) -> Array | None:
|
|
19
19
|
if x is None:
|
|
20
20
|
return x
|
|
21
|
+
elif isinstance(x, tuple):
|
|
22
|
+
# user might input tuple of scalar loss weights to account for cases
|
|
23
|
+
# when dyn loss is also a tuple of (possibly 1D) dyn_loss
|
|
24
|
+
return tuple(jnp.asarray(x_) for x_ in x)
|
|
21
25
|
else:
|
|
22
26
|
return jnp.asarray(x)
|
|
23
27
|
|