jinns 1.6.1__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +2 -1
- jinns/data/_Batchs.py +4 -4
- jinns/data/_DataGeneratorODE.py +1 -1
- jinns/data/_DataGeneratorObservations.py +498 -90
- jinns/loss/_DynamicLossAbstract.py +3 -1
- jinns/loss/_LossODE.py +138 -73
- jinns/loss/_LossPDE.py +208 -104
- jinns/loss/_abstract_loss.py +97 -14
- jinns/loss/_boundary_conditions.py +6 -6
- jinns/loss/_loss_utils.py +2 -2
- jinns/loss/_loss_weight_updates.py +30 -0
- jinns/loss/_loss_weights.py +4 -0
- jinns/loss/_operators.py +27 -27
- jinns/nn/_abstract_pinn.py +1 -1
- jinns/nn/_hyperpinn.py +6 -6
- jinns/nn/_mlp.py +3 -3
- jinns/nn/_pinn.py +7 -7
- jinns/nn/_ppinn.py +6 -6
- jinns/nn/_spinn.py +4 -4
- jinns/nn/_spinn_mlp.py +7 -7
- jinns/parameters/_derivative_keys.py +13 -6
- jinns/parameters/_params.py +10 -0
- jinns/solver/_rar.py +19 -9
- jinns/solver/_solve.py +102 -367
- jinns/solver/_solve_alternate.py +885 -0
- jinns/solver/_utils.py +520 -11
- jinns/utils/_DictToModuleMeta.py +3 -1
- jinns/utils/_containers.py +8 -4
- jinns/utils/_types.py +42 -1
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/METADATA +26 -14
- jinns-1.7.1.dist-info/RECORD +58 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/WHEEL +1 -1
- jinns-1.6.1.dist-info/RECORD +0 -57
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.6.1.dist-info → jinns-1.7.1.dist-info}/top_level.txt +0 -0
jinns/loss/_LossPDE.py
CHANGED
|
@@ -68,11 +68,14 @@ B = TypeVar("B", bound=PDEStatioBatch | PDENonStatioBatch)
|
|
|
68
68
|
C = TypeVar(
|
|
69
69
|
"C", bound=PDEStatioComponents[Array | None] | PDENonStatioComponents[Array | None]
|
|
70
70
|
)
|
|
71
|
-
|
|
71
|
+
DKPDE = TypeVar("DKPDE", bound=DerivativeKeysPDEStatio | DerivativeKeysPDENonStatio)
|
|
72
72
|
Y = TypeVar("Y", bound=PDEStatio | PDENonStatio | None)
|
|
73
73
|
|
|
74
74
|
|
|
75
|
-
class _LossPDEAbstract(
|
|
75
|
+
class _LossPDEAbstract(
|
|
76
|
+
AbstractLoss[L, B, C, DKPDE],
|
|
77
|
+
Generic[L, B, C, DKPDE, Y],
|
|
78
|
+
):
|
|
76
79
|
r"""
|
|
77
80
|
Parameters
|
|
78
81
|
----------
|
|
@@ -119,7 +122,7 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
|
|
|
119
122
|
made broadcastable to `norm_samples`.
|
|
120
123
|
These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
|
|
121
124
|
$q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
|
|
122
|
-
obs_slice : EllipsisType | slice, default=None
|
|
125
|
+
obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
|
|
123
126
|
slice object specifying the begininning/ending of the PINN output
|
|
124
127
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
125
128
|
key : Key | None
|
|
@@ -133,7 +136,7 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
|
|
|
133
136
|
# NOTE static=True only for leaf attributes that are not valid JAX types
|
|
134
137
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
135
138
|
u: eqx.AbstractVar[AbstractPINN]
|
|
136
|
-
dynamic_loss: eqx.AbstractVar[Y]
|
|
139
|
+
dynamic_loss: tuple[eqx.AbstractVar[Y] | None, ...]
|
|
137
140
|
omega_boundary_fun: (
|
|
138
141
|
BoundaryConditionFun | dict[str, BoundaryConditionFun] | None
|
|
139
142
|
) = eqx.field(static=True)
|
|
@@ -143,7 +146,7 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
|
|
|
143
146
|
omega_boundary_dim: slice | dict[str, slice] = eqx.field(static=True)
|
|
144
147
|
norm_samples: Float[Array, " nb_norm_samples dimension"] | None
|
|
145
148
|
norm_weights: Float[Array, " nb_norm_samples"] | None
|
|
146
|
-
obs_slice: EllipsisType | slice = eqx.field(static=True)
|
|
149
|
+
obs_slice: tuple[EllipsisType | slice, ...] = eqx.field(static=True)
|
|
147
150
|
key: PRNGKeyArray | None
|
|
148
151
|
|
|
149
152
|
def __init__(
|
|
@@ -156,14 +159,34 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
|
|
|
156
159
|
omega_boundary_dim: int | slice | dict[str, int | slice] | None = None,
|
|
157
160
|
norm_samples: Float[Array, " nb_norm_samples dimension"] | None = None,
|
|
158
161
|
norm_weights: Float[Array, " nb_norm_samples"] | float | int | None = None,
|
|
159
|
-
obs_slice: EllipsisType | slice
|
|
162
|
+
obs_slice: tuple[EllipsisType | slice, ...]
|
|
163
|
+
| EllipsisType
|
|
164
|
+
| slice
|
|
165
|
+
| None = None,
|
|
160
166
|
key: PRNGKeyArray | None = None,
|
|
167
|
+
derivative_keys: DKPDE,
|
|
161
168
|
**kwargs: Any, # for arguments for super()
|
|
162
169
|
):
|
|
163
|
-
super().__init__(
|
|
170
|
+
super().__init__(
|
|
171
|
+
loss_weights=self.loss_weights,
|
|
172
|
+
derivative_keys=derivative_keys,
|
|
173
|
+
**kwargs,
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
if self.update_weight_method is not None and jnp.any(
|
|
177
|
+
jnp.array(jax.tree.leaves(self.loss_weights)) == 0
|
|
178
|
+
):
|
|
179
|
+
warnings.warn(
|
|
180
|
+
"self.update_weight_method is activated while some loss "
|
|
181
|
+
"weights are zero. The update weight method will likely "
|
|
182
|
+
"update the zero weight to some non-zero value. Check that "
|
|
183
|
+
"this is the desired behaviour."
|
|
184
|
+
)
|
|
164
185
|
|
|
165
186
|
if obs_slice is None:
|
|
166
|
-
self.obs_slice = jnp.s_[...]
|
|
187
|
+
self.obs_slice = (jnp.s_[...],)
|
|
188
|
+
elif not isinstance(obs_slice, tuple):
|
|
189
|
+
self.obs_slice = (obs_slice,)
|
|
167
190
|
else:
|
|
168
191
|
self.obs_slice = obs_slice
|
|
169
192
|
|
|
@@ -303,16 +326,23 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
|
|
|
303
326
|
|
|
304
327
|
def _get_dyn_loss_fun(
|
|
305
328
|
self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
|
|
306
|
-
) -> Callable[[Params[Array]], Array] | None:
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
329
|
+
) -> tuple[Callable[[Params[Array]], Array], ...] | None:
|
|
330
|
+
# Note, for the record, multiple dynamic losses
|
|
331
|
+
# have been introduced in MR 92
|
|
332
|
+
if self.dynamic_loss != (None,):
|
|
333
|
+
dyn_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
|
|
334
|
+
jax.tree.map(
|
|
335
|
+
lambda d: lambda p: dynamic_loss_apply(
|
|
336
|
+
d.evaluate,
|
|
337
|
+
self.u,
|
|
338
|
+
self._get_dynamic_loss_batch(batch),
|
|
339
|
+
_set_derivatives(p, self.derivative_keys.dyn_loss),
|
|
340
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
341
|
+
),
|
|
342
|
+
self.dynamic_loss,
|
|
343
|
+
is_leaf=lambda x: isinstance(
|
|
344
|
+
x, (PDEStatio, PDENonStatio)
|
|
345
|
+
), # do not traverse further than first level
|
|
316
346
|
)
|
|
317
347
|
)
|
|
318
348
|
else:
|
|
@@ -322,9 +352,13 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
|
|
|
322
352
|
|
|
323
353
|
def _get_norm_loss_fun(
|
|
324
354
|
self, batch: B, vmap_in_axes_params: tuple[Params[int | None] | None]
|
|
325
|
-
) -> Callable[[Params[Array]], Array] | None:
|
|
355
|
+
) -> tuple[Callable[[Params[Array]], Array], ...] | None:
|
|
356
|
+
# Note that since MR 92
|
|
357
|
+
# norm_loss_fun is formed as a tuple for
|
|
358
|
+
# consistency with dynamic and observation
|
|
359
|
+
# losses and more modularity for later
|
|
326
360
|
if self.norm_samples is not None:
|
|
327
|
-
norm_loss_fun: Callable[[Params[Array]], Array] | None = (
|
|
361
|
+
norm_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
|
|
328
362
|
lambda p: normalization_loss_apply(
|
|
329
363
|
self.u,
|
|
330
364
|
cast(
|
|
@@ -333,7 +367,7 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
|
|
|
333
367
|
_set_derivatives(p, self.derivative_keys.norm_loss),
|
|
334
368
|
vmap_in_axes_params,
|
|
335
369
|
self.norm_weights, # type: ignore -> can't get the __post_init__ narrowing here
|
|
336
|
-
)
|
|
370
|
+
),
|
|
337
371
|
)
|
|
338
372
|
else:
|
|
339
373
|
norm_loss_fun = None
|
|
@@ -341,20 +375,24 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
|
|
|
341
375
|
|
|
342
376
|
def _get_boundary_loss_fun(
|
|
343
377
|
self, batch: B
|
|
344
|
-
) -> Callable[[Params[Array]], Array] | None:
|
|
378
|
+
) -> tuple[Callable[[Params[Array]], Array], ...] | None:
|
|
379
|
+
# Note that since MR 92
|
|
380
|
+
# boundary_loss_fun is formed as a tuple for
|
|
381
|
+
# consistency with dynamic and observation
|
|
382
|
+
# losses and more modularity for later
|
|
345
383
|
if (
|
|
346
384
|
self.omega_boundary_condition is not None
|
|
347
385
|
and self.omega_boundary_fun is not None
|
|
348
386
|
):
|
|
349
|
-
boundary_loss_fun: Callable[[Params[Array]], Array] | None = (
|
|
387
|
+
boundary_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
|
|
350
388
|
lambda p: boundary_condition_apply(
|
|
351
389
|
self.u,
|
|
352
390
|
batch,
|
|
353
391
|
_set_derivatives(p, self.derivative_keys.boundary_loss),
|
|
354
392
|
self.omega_boundary_fun, # type: ignore (we are in lambda)
|
|
355
393
|
self.omega_boundary_condition, # type: ignore
|
|
356
|
-
self.omega_boundary_dim,
|
|
357
|
-
)
|
|
394
|
+
self.omega_boundary_dim,
|
|
395
|
+
),
|
|
358
396
|
)
|
|
359
397
|
else:
|
|
360
398
|
boundary_loss_fun = None
|
|
@@ -366,24 +404,37 @@ class _LossPDEAbstract(AbstractLoss[L, B, C], Generic[L, B, C, D, Y]):
|
|
|
366
404
|
batch: B,
|
|
367
405
|
vmap_in_axes_params: tuple[Params[int | None] | None],
|
|
368
406
|
params: Params[Array],
|
|
369
|
-
) -> tuple[
|
|
407
|
+
) -> tuple[
|
|
408
|
+
Params[Array] | None, tuple[Callable[[Params[Array]], Array], ...] | None
|
|
409
|
+
]:
|
|
410
|
+
# Note, for the record, multiple DGObs
|
|
411
|
+
# (leading to batch.obs_batch_dict being tuple | None)
|
|
412
|
+
# have been introduced in MR 92
|
|
370
413
|
if batch.obs_batch_dict is not None:
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
414
|
+
if len(batch.obs_batch_dict) != len(self.obs_slice):
|
|
415
|
+
raise ValueError(
|
|
416
|
+
"There must be the same number of "
|
|
417
|
+
"observation datasets as the number of "
|
|
418
|
+
"obs_slice"
|
|
419
|
+
)
|
|
420
|
+
params_obs = jax.tree.map(
|
|
421
|
+
lambda d: update_eq_params(params, d["eq_params"]),
|
|
422
|
+
batch.obs_batch_dict,
|
|
423
|
+
is_leaf=lambda x: isinstance(x, dict),
|
|
377
424
|
)
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
425
|
+
obs_loss_fun: tuple[Callable[[Params[Array]], Array], ...] | None = (
|
|
426
|
+
jax.tree.map(
|
|
427
|
+
lambda d, slice_: lambda po: observations_loss_apply(
|
|
428
|
+
self.u,
|
|
429
|
+
d["pinn_in"],
|
|
430
|
+
_set_derivatives(po, self.derivative_keys.observations),
|
|
431
|
+
self.vmap_in_axes + vmap_in_axes_params,
|
|
432
|
+
d["val"],
|
|
433
|
+
slice_,
|
|
434
|
+
),
|
|
435
|
+
batch.obs_batch_dict,
|
|
386
436
|
self.obs_slice,
|
|
437
|
+
is_leaf=lambda x: isinstance(x, dict),
|
|
387
438
|
)
|
|
388
439
|
)
|
|
389
440
|
else:
|
|
@@ -416,7 +467,7 @@ class LossPDEStatio(
|
|
|
416
467
|
----------
|
|
417
468
|
u : AbstractPINN
|
|
418
469
|
the PINN
|
|
419
|
-
dynamic_loss : PDEStatio | None
|
|
470
|
+
dynamic_loss : tuple[PDEStatio, ...] | PDEStatio | None
|
|
420
471
|
the stationary PDE dynamic part of the loss, basically the differential
|
|
421
472
|
operator $\mathcal{N}[u](x)$. Should implement a method
|
|
422
473
|
`dynamic_loss.evaluate(x, u, params)`.
|
|
@@ -479,7 +530,7 @@ class LossPDEStatio(
|
|
|
479
530
|
Alternatively, the user can pass a float or an integer.
|
|
480
531
|
These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
|
|
481
532
|
$q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
|
|
482
|
-
obs_slice : slice, default=None
|
|
533
|
+
obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
|
|
483
534
|
slice object specifying the begininning/ending of the PINN output
|
|
484
535
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
485
536
|
|
|
@@ -494,10 +545,9 @@ class LossPDEStatio(
|
|
|
494
545
|
# (ie. jax.Array cannot be static) and that we do not expect to change
|
|
495
546
|
|
|
496
547
|
u: AbstractPINN
|
|
497
|
-
dynamic_loss: PDEStatio | None
|
|
548
|
+
dynamic_loss: tuple[PDEStatio | None, ...]
|
|
498
549
|
loss_weights: LossWeightsPDEStatio
|
|
499
550
|
derivative_keys: DerivativeKeysPDEStatio
|
|
500
|
-
vmap_in_axes: tuple[int] = eqx.field(static=True)
|
|
501
551
|
|
|
502
552
|
params: InitVar[Params[Array] | None]
|
|
503
553
|
|
|
@@ -505,7 +555,7 @@ class LossPDEStatio(
|
|
|
505
555
|
self,
|
|
506
556
|
*,
|
|
507
557
|
u: AbstractPINN,
|
|
508
|
-
dynamic_loss: PDEStatio | None,
|
|
558
|
+
dynamic_loss: tuple[PDEStatio, ...] | PDEStatio | None,
|
|
509
559
|
loss_weights: LossWeightsPDEStatio | None = None,
|
|
510
560
|
derivative_keys: DerivativeKeysPDEStatio | None = None,
|
|
511
561
|
params: Params[Array] | None = None,
|
|
@@ -516,25 +566,26 @@ class LossPDEStatio(
|
|
|
516
566
|
self.loss_weights = LossWeightsPDEStatio()
|
|
517
567
|
else:
|
|
518
568
|
self.loss_weights = loss_weights
|
|
519
|
-
self.dynamic_loss = dynamic_loss
|
|
520
|
-
|
|
521
|
-
super().__init__(
|
|
522
|
-
**kwargs,
|
|
523
|
-
)
|
|
524
569
|
|
|
525
570
|
if derivative_keys is None:
|
|
526
571
|
# be default we only take gradient wrt nn_params
|
|
527
572
|
try:
|
|
528
|
-
|
|
573
|
+
derivative_keys = DerivativeKeysPDEStatio(params=params)
|
|
529
574
|
except ValueError as exc:
|
|
530
575
|
raise ValueError(
|
|
531
576
|
"Problem at derivative_keys initialization "
|
|
532
577
|
f"received {derivative_keys=} and {params=}"
|
|
533
578
|
) from exc
|
|
534
|
-
else:
|
|
535
|
-
self.derivative_keys = derivative_keys
|
|
536
579
|
|
|
537
|
-
|
|
580
|
+
super().__init__(
|
|
581
|
+
derivative_keys=derivative_keys,
|
|
582
|
+
vmap_in_axes=(0,),
|
|
583
|
+
**kwargs,
|
|
584
|
+
)
|
|
585
|
+
if not isinstance(dynamic_loss, tuple):
|
|
586
|
+
self.dynamic_loss = (dynamic_loss,)
|
|
587
|
+
else:
|
|
588
|
+
self.dynamic_loss = dynamic_loss
|
|
538
589
|
|
|
539
590
|
def _get_dynamic_loss_batch(
|
|
540
591
|
self, batch: PDEStatioBatch
|
|
@@ -549,11 +600,12 @@ class LossPDEStatio(
|
|
|
549
600
|
# we could have used typing.cast though
|
|
550
601
|
|
|
551
602
|
def evaluate_by_terms(
|
|
552
|
-
self,
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
603
|
+
self,
|
|
604
|
+
opt_params: Params[Array],
|
|
605
|
+
batch: PDEStatioBatch,
|
|
606
|
+
*,
|
|
607
|
+
non_opt_params: Params[Array] | None = None,
|
|
608
|
+
) -> tuple[PDEStatioComponents[Array | None], PDEStatioComponents[Array | None]]:
|
|
557
609
|
"""
|
|
558
610
|
Evaluate the loss function at a batch of points for given parameters.
|
|
559
611
|
|
|
@@ -561,15 +613,22 @@ class LossPDEStatio(
|
|
|
561
613
|
|
|
562
614
|
Parameters
|
|
563
615
|
---------
|
|
564
|
-
|
|
565
|
-
Parameters at which the loss is evaluated
|
|
616
|
+
opt_params
|
|
617
|
+
Parameters, which are optimized, at which the loss is evaluated
|
|
566
618
|
batch
|
|
567
619
|
Composed of a batch of points in the
|
|
568
620
|
domain, a batch of points in the domain
|
|
569
621
|
border and an optional additional batch of parameters (eg. for
|
|
570
622
|
metamodeling) and an optional additional batch of observed
|
|
571
623
|
inputs/outputs/parameters
|
|
624
|
+
non_opt_params
|
|
625
|
+
Parameters, which are non optimized, at which the loss is evaluated
|
|
572
626
|
"""
|
|
627
|
+
if non_opt_params is not None:
|
|
628
|
+
params = eqx.combine(opt_params, non_opt_params)
|
|
629
|
+
else:
|
|
630
|
+
params = opt_params
|
|
631
|
+
|
|
573
632
|
# Retrieve the optional eq_params_batch
|
|
574
633
|
# and update eq_params with the latter
|
|
575
634
|
# and update vmap_in_axes
|
|
@@ -594,29 +653,44 @@ class LossPDEStatio(
|
|
|
594
653
|
)
|
|
595
654
|
|
|
596
655
|
# get the unweighted mses for each loss term as well as the gradients
|
|
597
|
-
all_funs: PDEStatioComponents[
|
|
656
|
+
all_funs: PDEStatioComponents[
|
|
657
|
+
tuple[Callable[[Params[Array]], Array], ...] | None
|
|
658
|
+
] = PDEStatioComponents(
|
|
659
|
+
dyn_loss_fun,
|
|
660
|
+
norm_loss_fun,
|
|
661
|
+
boundary_loss_fun,
|
|
662
|
+
obs_loss_fun,
|
|
663
|
+
)
|
|
664
|
+
all_params: PDEStatioComponents[tuple[Params[Array], ...] | None] = (
|
|
598
665
|
PDEStatioComponents(
|
|
599
|
-
|
|
666
|
+
jax.tree.map(lambda l: params, dyn_loss_fun),
|
|
667
|
+
jax.tree.map(lambda l: params, norm_loss_fun),
|
|
668
|
+
jax.tree.map(lambda l: params, boundary_loss_fun),
|
|
669
|
+
params_obs,
|
|
600
670
|
)
|
|
601
671
|
)
|
|
602
|
-
all_params: PDEStatioComponents[Params[Array] | None] = PDEStatioComponents(
|
|
603
|
-
params, params, params, params_obs
|
|
604
|
-
)
|
|
605
672
|
mses_grads = jax.tree.map(
|
|
606
673
|
self.get_gradients,
|
|
607
674
|
all_funs,
|
|
608
675
|
all_params,
|
|
609
676
|
is_leaf=lambda x: x is None,
|
|
610
677
|
)
|
|
678
|
+
# NOTE the is_leaf below is more complex since it must pass possible the tuple
|
|
679
|
+
# of dyn_loss and then stops (but also account it should not stop when
|
|
680
|
+
# the tuple of dyn_loss is of length 2)
|
|
611
681
|
mses = jax.tree.map(
|
|
612
|
-
lambda leaf: leaf[0],
|
|
682
|
+
lambda leaf: leaf[0],
|
|
613
683
|
mses_grads,
|
|
614
|
-
is_leaf=lambda x: isinstance(x, tuple)
|
|
684
|
+
is_leaf=lambda x: isinstance(x, tuple)
|
|
685
|
+
and len(x) == 2
|
|
686
|
+
and isinstance(x[1], Params),
|
|
615
687
|
)
|
|
616
688
|
grads = jax.tree.map(
|
|
617
|
-
lambda leaf: leaf[1],
|
|
689
|
+
lambda leaf: leaf[1],
|
|
618
690
|
mses_grads,
|
|
619
|
-
is_leaf=lambda x: isinstance(x, tuple)
|
|
691
|
+
is_leaf=lambda x: isinstance(x, tuple)
|
|
692
|
+
and len(x) == 2
|
|
693
|
+
and isinstance(x[1], Params),
|
|
620
694
|
)
|
|
621
695
|
|
|
622
696
|
return mses, grads
|
|
@@ -648,7 +722,7 @@ class LossPDENonStatio(
|
|
|
648
722
|
----------
|
|
649
723
|
u : AbstractPINN
|
|
650
724
|
the PINN
|
|
651
|
-
dynamic_loss : PDENonStatio
|
|
725
|
+
dynamic_loss : tuple[PDENonStatio, ...] | PDENonStatio | None
|
|
652
726
|
the non stationary PDE dynamic part of the loss, basically the differential
|
|
653
727
|
operator $\mathcal{N}[u](t, x)$. Should implement a method
|
|
654
728
|
`dynamic_loss.evaluate(t, x, u, params)`.
|
|
@@ -725,14 +799,13 @@ class LossPDENonStatio(
|
|
|
725
799
|
Alternatively, the user can pass a float or an integer.
|
|
726
800
|
These corresponds to the weights $w_k = \frac{1}{q(x_k)}$ where
|
|
727
801
|
$q(\cdot)$ is the proposal p.d.f. and $x_k$ are the Monte-Carlo samples.
|
|
728
|
-
obs_slice : slice, default=None
|
|
802
|
+
obs_slice : tuple[EllipsisType | slice, ...] | EllipsisType | slice | None, default=None
|
|
729
803
|
slice object specifying the begininning/ending of the PINN output
|
|
730
804
|
that is observed (this is then useful for multidim PINN). Default is None.
|
|
731
|
-
|
|
732
805
|
"""
|
|
733
806
|
|
|
734
807
|
u: AbstractPINN
|
|
735
|
-
dynamic_loss: PDENonStatio | None
|
|
808
|
+
dynamic_loss: tuple[PDENonStatio | None, ...]
|
|
736
809
|
loss_weights: LossWeightsPDENonStatio
|
|
737
810
|
derivative_keys: DerivativeKeysPDENonStatio
|
|
738
811
|
params: InitVar[Params[Array] | None]
|
|
@@ -740,7 +813,6 @@ class LossPDENonStatio(
|
|
|
740
813
|
initial_condition_fun: Callable[[Float[Array, " dimension"]], Array] | None = (
|
|
741
814
|
eqx.field(static=True)
|
|
742
815
|
)
|
|
743
|
-
vmap_in_axes: tuple[int] = eqx.field(static=True)
|
|
744
816
|
max_norm_samples_omega: int = eqx.field(static=True)
|
|
745
817
|
max_norm_time_slices: int = eqx.field(static=True)
|
|
746
818
|
|
|
@@ -750,7 +822,7 @@ class LossPDENonStatio(
|
|
|
750
822
|
self,
|
|
751
823
|
*,
|
|
752
824
|
u: AbstractPINN,
|
|
753
|
-
dynamic_loss: PDENonStatio | None,
|
|
825
|
+
dynamic_loss: tuple[PDENonStatio, ...] | PDENonStatio | None,
|
|
754
826
|
loss_weights: LossWeightsPDENonStatio | None = None,
|
|
755
827
|
derivative_keys: DerivativeKeysPDENonStatio | None = None,
|
|
756
828
|
initial_condition_fun: Callable[[Float[Array, " dimension"]], Array]
|
|
@@ -766,25 +838,27 @@ class LossPDENonStatio(
|
|
|
766
838
|
self.loss_weights = LossWeightsPDENonStatio()
|
|
767
839
|
else:
|
|
768
840
|
self.loss_weights = loss_weights
|
|
769
|
-
self.dynamic_loss = dynamic_loss
|
|
770
|
-
|
|
771
|
-
super().__init__(
|
|
772
|
-
**kwargs,
|
|
773
|
-
)
|
|
774
841
|
|
|
775
842
|
if derivative_keys is None:
|
|
776
843
|
# be default we only take gradient wrt nn_params
|
|
777
844
|
try:
|
|
778
|
-
|
|
845
|
+
derivative_keys = DerivativeKeysPDENonStatio(params=params)
|
|
779
846
|
except ValueError as exc:
|
|
780
847
|
raise ValueError(
|
|
781
848
|
"Problem at derivative_keys initialization "
|
|
782
849
|
f"received {derivative_keys=} and {params=}"
|
|
783
850
|
) from exc
|
|
784
|
-
else:
|
|
785
|
-
self.derivative_keys = derivative_keys
|
|
786
851
|
|
|
787
|
-
|
|
852
|
+
super().__init__(
|
|
853
|
+
derivative_keys=derivative_keys,
|
|
854
|
+
vmap_in_axes=(0,), # for t_x
|
|
855
|
+
**kwargs,
|
|
856
|
+
)
|
|
857
|
+
|
|
858
|
+
if not isinstance(dynamic_loss, tuple):
|
|
859
|
+
self.dynamic_loss = (dynamic_loss,)
|
|
860
|
+
else:
|
|
861
|
+
self.dynamic_loss = dynamic_loss
|
|
788
862
|
|
|
789
863
|
if initial_condition_fun is None:
|
|
790
864
|
warnings.warn(
|
|
@@ -820,7 +894,11 @@ class LossPDENonStatio(
|
|
|
820
894
|
)
|
|
821
895
|
|
|
822
896
|
def evaluate_by_terms(
|
|
823
|
-
self,
|
|
897
|
+
self,
|
|
898
|
+
opt_params: Params[Array],
|
|
899
|
+
batch: PDENonStatioBatch,
|
|
900
|
+
*,
|
|
901
|
+
non_opt_params: Params[Array] | None = None,
|
|
824
902
|
) -> tuple[
|
|
825
903
|
PDENonStatioComponents[Array | None], PDENonStatioComponents[Array | None]
|
|
826
904
|
]:
|
|
@@ -831,15 +909,22 @@ class LossPDENonStatio(
|
|
|
831
909
|
|
|
832
910
|
Parameters
|
|
833
911
|
---------
|
|
834
|
-
|
|
835
|
-
Parameters at which the loss is evaluated
|
|
912
|
+
opt_params
|
|
913
|
+
Parameters, which are optimized, at which the loss is evaluated
|
|
836
914
|
batch
|
|
837
915
|
Composed of a batch of points in the
|
|
838
916
|
domain, a batch of points in the domain
|
|
839
917
|
border and an optional additional batch of parameters (eg. for
|
|
840
918
|
metamodeling) and an optional additional batch of observed
|
|
841
919
|
inputs/outputs/parameters
|
|
920
|
+
non_opt_params
|
|
921
|
+
Parameters, which are non optimized, at which the loss is evaluated
|
|
842
922
|
"""
|
|
923
|
+
if non_opt_params is not None:
|
|
924
|
+
params = eqx.combine(opt_params, non_opt_params)
|
|
925
|
+
else:
|
|
926
|
+
params = opt_params
|
|
927
|
+
|
|
843
928
|
omega_initial_batch = batch.initial_batch
|
|
844
929
|
assert omega_initial_batch is not None
|
|
845
930
|
|
|
@@ -868,7 +953,13 @@ class LossPDENonStatio(
|
|
|
868
953
|
|
|
869
954
|
# initial condition
|
|
870
955
|
if self.initial_condition_fun is not None:
|
|
871
|
-
|
|
956
|
+
# Note that since MR 92
|
|
957
|
+
# initial_condition_fun is formed as a tuple for
|
|
958
|
+
# consistency with dynamic and observation
|
|
959
|
+
# losses and more modularity for later
|
|
960
|
+
initial_condition_fun: (
|
|
961
|
+
tuple[Callable[[Params[Array]], Array], ...] | None
|
|
962
|
+
) = (
|
|
872
963
|
lambda p: initial_condition_apply(
|
|
873
964
|
self.u,
|
|
874
965
|
omega_initial_batch,
|
|
@@ -876,39 +967,52 @@ class LossPDENonStatio(
|
|
|
876
967
|
(0,) + vmap_in_axes_params,
|
|
877
968
|
self.initial_condition_fun, # type: ignore
|
|
878
969
|
self.t0,
|
|
879
|
-
)
|
|
970
|
+
),
|
|
880
971
|
)
|
|
881
972
|
else:
|
|
882
|
-
|
|
973
|
+
initial_condition_fun = None
|
|
883
974
|
|
|
884
975
|
# get the unweighted mses for each loss term as well as the gradients
|
|
885
|
-
all_funs: PDENonStatioComponents[
|
|
976
|
+
all_funs: PDENonStatioComponents[
|
|
977
|
+
tuple[Callable[[Params[Array]], Array], ...] | None
|
|
978
|
+
] = PDENonStatioComponents(
|
|
979
|
+
dyn_loss_fun,
|
|
980
|
+
norm_loss_fun,
|
|
981
|
+
boundary_loss_fun,
|
|
982
|
+
obs_loss_fun,
|
|
983
|
+
initial_condition_fun,
|
|
984
|
+
)
|
|
985
|
+
all_params: PDENonStatioComponents[tuple[Params[Array], ...] | None] = (
|
|
886
986
|
PDENonStatioComponents(
|
|
887
|
-
dyn_loss_fun,
|
|
888
|
-
norm_loss_fun,
|
|
889
|
-
boundary_loss_fun,
|
|
890
|
-
|
|
891
|
-
|
|
987
|
+
jax.tree.map(lambda l: params, dyn_loss_fun),
|
|
988
|
+
jax.tree.map(lambda l: params, norm_loss_fun),
|
|
989
|
+
jax.tree.map(lambda l: params, boundary_loss_fun),
|
|
990
|
+
params_obs,
|
|
991
|
+
jax.tree.map(lambda l: params, initial_condition_fun),
|
|
892
992
|
)
|
|
893
993
|
)
|
|
894
|
-
all_params: PDENonStatioComponents[Params[Array] | None] = (
|
|
895
|
-
PDENonStatioComponents(params, params, params, params_obs, params)
|
|
896
|
-
)
|
|
897
994
|
mses_grads = jax.tree.map(
|
|
898
995
|
self.get_gradients,
|
|
899
996
|
all_funs,
|
|
900
997
|
all_params,
|
|
901
998
|
is_leaf=lambda x: x is None,
|
|
902
999
|
)
|
|
1000
|
+
# NOTE the is_leaf below is more complex since it must pass possible the tuple
|
|
1001
|
+
# of dyn_loss and then stops (but also account it should not stop when
|
|
1002
|
+
# the tuple of dyn_loss is of length 2)
|
|
903
1003
|
mses = jax.tree.map(
|
|
904
|
-
lambda leaf: leaf[0],
|
|
1004
|
+
lambda leaf: leaf[0],
|
|
905
1005
|
mses_grads,
|
|
906
|
-
is_leaf=lambda x: isinstance(x, tuple)
|
|
1006
|
+
is_leaf=lambda x: isinstance(x, tuple)
|
|
1007
|
+
and len(x) == 2
|
|
1008
|
+
and isinstance(x[1], Params),
|
|
907
1009
|
)
|
|
908
1010
|
grads = jax.tree.map(
|
|
909
|
-
lambda leaf: leaf[1],
|
|
1011
|
+
lambda leaf: leaf[1],
|
|
910
1012
|
mses_grads,
|
|
911
|
-
is_leaf=lambda x: isinstance(x, tuple)
|
|
1013
|
+
is_leaf=lambda x: isinstance(x, tuple)
|
|
1014
|
+
and len(x) == 2
|
|
1015
|
+
and isinstance(x[1], Params),
|
|
912
1016
|
)
|
|
913
1017
|
|
|
914
1018
|
return mses, grads
|