jinns 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +146 -520
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_utils.py +78 -159
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +89 -63
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py
CHANGED
|
@@ -8,56 +8,71 @@ from __future__ import (
|
|
|
8
8
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
9
9
|
|
|
10
10
|
import time
|
|
11
|
-
from typing import TYPE_CHECKING,
|
|
11
|
+
from typing import TYPE_CHECKING, Any, TypeAlias, Callable
|
|
12
12
|
from functools import partial
|
|
13
13
|
import optax
|
|
14
14
|
import jax
|
|
15
15
|
from jax import jit
|
|
16
16
|
import jax.numpy as jnp
|
|
17
|
-
from jaxtyping import
|
|
17
|
+
from jaxtyping import Float, Array
|
|
18
18
|
from jinns.solver._rar import init_rar, trigger_rar
|
|
19
19
|
from jinns.utils._utils import _check_nan_in_pytree
|
|
20
20
|
from jinns.solver._utils import _check_batch_size
|
|
21
|
-
from jinns.utils._containers import
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
append_param_batch,
|
|
21
|
+
from jinns.utils._containers import (
|
|
22
|
+
DataGeneratorContainer,
|
|
23
|
+
OptimizationContainer,
|
|
24
|
+
OptimizationExtraContainer,
|
|
25
|
+
LossContainer,
|
|
26
|
+
StoredObjectContainer,
|
|
28
27
|
)
|
|
28
|
+
from jinns.data._utils import append_param_batch, append_obs_batch
|
|
29
29
|
|
|
30
30
|
if TYPE_CHECKING:
|
|
31
|
-
from jinns.
|
|
31
|
+
from jinns.parameters._params import Params
|
|
32
|
+
from jinns.utils._types import AnyLoss, AnyBatch
|
|
33
|
+
from jinns.validation._validation import AbstractValidationModule
|
|
34
|
+
from jinns.data._DataGeneratorParameter import DataGeneratorParameter
|
|
35
|
+
from jinns.data._DataGeneratorObservations import DataGeneratorObservations
|
|
36
|
+
from jinns.data._AbstractDataGenerator import AbstractDataGenerator
|
|
37
|
+
|
|
38
|
+
main_carry: TypeAlias = tuple[
|
|
39
|
+
int,
|
|
40
|
+
AnyLoss,
|
|
41
|
+
OptimizationContainer,
|
|
42
|
+
OptimizationExtraContainer,
|
|
43
|
+
DataGeneratorContainer,
|
|
44
|
+
AbstractValidationModule | None,
|
|
45
|
+
LossContainer,
|
|
46
|
+
StoredObjectContainer,
|
|
47
|
+
Float[Array, " n_iter"] | None,
|
|
48
|
+
]
|
|
32
49
|
|
|
33
50
|
|
|
34
51
|
def solve(
|
|
35
|
-
n_iter:
|
|
36
|
-
init_params:
|
|
37
|
-
data:
|
|
52
|
+
n_iter: int,
|
|
53
|
+
init_params: Params[Array],
|
|
54
|
+
data: AbstractDataGenerator,
|
|
38
55
|
loss: AnyLoss,
|
|
39
56
|
optimizer: optax.GradientTransformation,
|
|
40
|
-
print_loss_every:
|
|
41
|
-
opt_state:
|
|
42
|
-
tracked_params: Params |
|
|
57
|
+
print_loss_every: int = 1000,
|
|
58
|
+
opt_state: optax.OptState | None = None,
|
|
59
|
+
tracked_params: Params[Any | None] | None = None,
|
|
43
60
|
param_data: DataGeneratorParameter | None = None,
|
|
44
|
-
obs_data:
|
|
45
|
-
DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None
|
|
46
|
-
) = None,
|
|
61
|
+
obs_data: DataGeneratorObservations | None = None,
|
|
47
62
|
validation: AbstractValidationModule | None = None,
|
|
48
63
|
obs_batch_sharding: jax.sharding.Sharding | None = None,
|
|
49
|
-
verbose:
|
|
50
|
-
ahead_of_time:
|
|
64
|
+
verbose: bool = True,
|
|
65
|
+
ahead_of_time: bool = True,
|
|
51
66
|
) -> tuple[
|
|
52
|
-
Params
|
|
53
|
-
Float[Array, "n_iter"],
|
|
54
|
-
|
|
55
|
-
|
|
67
|
+
Params[Array],
|
|
68
|
+
Float[Array, " n_iter"],
|
|
69
|
+
dict[str, Float[Array, " n_iter"]],
|
|
70
|
+
AbstractDataGenerator,
|
|
56
71
|
AnyLoss,
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
Float[Array, "n_iter"],
|
|
60
|
-
|
|
72
|
+
optax.OptState,
|
|
73
|
+
Params[Array | None],
|
|
74
|
+
Float[Array, " n_iter"] | None,
|
|
75
|
+
Params[Array],
|
|
61
76
|
]:
|
|
62
77
|
"""
|
|
63
78
|
Performs the optimization process via stochastic gradient descent
|
|
@@ -94,8 +109,7 @@ def solve(
|
|
|
94
109
|
Default None. A DataGeneratorParameter object which can be used to
|
|
95
110
|
sample equation parameters.
|
|
96
111
|
obs_data
|
|
97
|
-
Default None. A DataGeneratorObservations
|
|
98
|
-
DataGeneratorObservationsMultiPINNs
|
|
112
|
+
Default None. A DataGeneratorObservations
|
|
99
113
|
object which can be used to sample minibatches of observations.
|
|
100
114
|
validation
|
|
101
115
|
Default None. Otherwise, a callable ``eqx.Module`` which implements a
|
|
@@ -171,11 +185,21 @@ def solve(
|
|
|
171
185
|
_check_batch_size(obs_data, param_data, "n")
|
|
172
186
|
|
|
173
187
|
if opt_state is None:
|
|
174
|
-
opt_state = optimizer.init(init_params)
|
|
188
|
+
opt_state = optimizer.init(init_params) # type: ignore
|
|
189
|
+
# our Params are eqx.Module (dataclass + PyTree), PyTree is
|
|
190
|
+
# compatible with optax transform but not dataclass, this leads to a
|
|
191
|
+
# type hint error: we could prevent this by ensuring with the eqx.filter that
|
|
192
|
+
# we have only floating points optimizable params given to optax
|
|
193
|
+
# see https://docs.kidger.site/equinox/faq/#optax-throwing-a-typeerror
|
|
194
|
+
# opt_state = optimizer.init(
|
|
195
|
+
# eqx.filter(init_params, eqx.is_inexact_array)
|
|
196
|
+
# )
|
|
197
|
+
# but this seems like a hack and there is no better way
|
|
198
|
+
# https://github.com/google-deepmind/optax/issues/384
|
|
175
199
|
|
|
176
200
|
# RAR sampling init (ouside scanned function to avoid dynamic slice error)
|
|
177
201
|
# If RAR is not used the _rar_step_*() are juste None and data is unchanged
|
|
178
|
-
data, _rar_step_true, _rar_step_false = init_rar(data)
|
|
202
|
+
data, _rar_step_true, _rar_step_false = init_rar(data) # type: ignore
|
|
179
203
|
|
|
180
204
|
# Seq2seq
|
|
181
205
|
curr_seq = 0
|
|
@@ -292,7 +316,7 @@ def solve(
|
|
|
292
316
|
if verbose:
|
|
293
317
|
_print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")
|
|
294
318
|
|
|
295
|
-
if validation is not None:
|
|
319
|
+
if validation is not None and validation_crit_values is not None:
|
|
296
320
|
# there is a jax.lax.cond because we do not necesarily call the
|
|
297
321
|
# validation step every iteration
|
|
298
322
|
(
|
|
@@ -306,7 +330,7 @@ def solve(
|
|
|
306
330
|
lambda operands: (
|
|
307
331
|
operands[0],
|
|
308
332
|
False,
|
|
309
|
-
validation_crit_values[i - 1],
|
|
333
|
+
validation_crit_values[i - 1], # type: ignore don't know why it can still be None
|
|
310
334
|
False,
|
|
311
335
|
),
|
|
312
336
|
(
|
|
@@ -431,7 +455,7 @@ def solve(
|
|
|
431
455
|
# get ready to return the parameters at last iteration...
|
|
432
456
|
# (by default arbitrary choice, this could be None)
|
|
433
457
|
validation_parameters = optimization.last_non_nan_params
|
|
434
|
-
if validation is not None:
|
|
458
|
+
if validation is not None and validation_crit_values is not None:
|
|
435
459
|
jax.debug.print(
|
|
436
460
|
"validation loss value = {validation_loss_val}",
|
|
437
461
|
validation_loss_val=validation_crit_values[i - 1],
|
|
@@ -466,24 +490,28 @@ def _gradient_step(
|
|
|
466
490
|
loss: AnyLoss,
|
|
467
491
|
optimizer: optax.GradientTransformation,
|
|
468
492
|
batch: AnyBatch,
|
|
469
|
-
params:
|
|
470
|
-
opt_state:
|
|
471
|
-
last_non_nan_params:
|
|
493
|
+
params: Params[Array],
|
|
494
|
+
opt_state: optax.OptState,
|
|
495
|
+
last_non_nan_params: Params[Array],
|
|
472
496
|
) -> tuple[
|
|
473
497
|
AnyLoss,
|
|
474
498
|
float,
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
499
|
+
dict[str, float],
|
|
500
|
+
Params[Array],
|
|
501
|
+
optax.OptState,
|
|
502
|
+
Params[Array],
|
|
479
503
|
]:
|
|
480
504
|
"""
|
|
481
505
|
optimizer cannot be jit-ted.
|
|
482
506
|
"""
|
|
483
507
|
value_grad_loss = jax.value_and_grad(loss, has_aux=True)
|
|
484
508
|
(loss_val, loss_terms), grads = value_grad_loss(params, batch)
|
|
485
|
-
updates, opt_state = optimizer.update(
|
|
486
|
-
|
|
509
|
+
updates, opt_state = optimizer.update(
|
|
510
|
+
grads,
|
|
511
|
+
opt_state,
|
|
512
|
+
params, # type: ignore
|
|
513
|
+
) # see optimizer.init for explaination
|
|
514
|
+
params = optax.apply_updates(params, updates) # type: ignore
|
|
487
515
|
|
|
488
516
|
# check if any of the parameters is NaN
|
|
489
517
|
last_non_nan_params = jax.lax.cond(
|
|
@@ -504,7 +532,7 @@ def _gradient_step(
|
|
|
504
532
|
|
|
505
533
|
|
|
506
534
|
@partial(jit, static_argnames=["prefix"])
|
|
507
|
-
def _print_fn(i:
|
|
535
|
+
def _print_fn(i: int, loss_val: Float, print_loss_every: int, prefix: str = ""):
|
|
508
536
|
# note that if the following is not jitted in the main lor loop, it is
|
|
509
537
|
# super slow
|
|
510
538
|
_ = jax.lax.cond(
|
|
@@ -521,17 +549,15 @@ def _print_fn(i: Int, loss_val: Float, print_loss_every: Int, prefix: str = ""):
|
|
|
521
549
|
|
|
522
550
|
@jit
|
|
523
551
|
def _store_loss_and_params(
|
|
524
|
-
i:
|
|
525
|
-
params:
|
|
526
|
-
stored_params:
|
|
527
|
-
stored_loss_terms:
|
|
528
|
-
train_loss_values: Float[Array, "n_iter"],
|
|
552
|
+
i: int,
|
|
553
|
+
params: Params[Array],
|
|
554
|
+
stored_params: Params[Array],
|
|
555
|
+
stored_loss_terms: dict[str, Float[Array, " n_iter"]],
|
|
556
|
+
train_loss_values: Float[Array, " n_iter"],
|
|
529
557
|
train_loss_val: float,
|
|
530
|
-
loss_terms:
|
|
531
|
-
tracked_params:
|
|
532
|
-
) -> tuple[
|
|
533
|
-
Params | ParamsDict, Dict[str, Float[Array, "n_iter"]], Float[Array, "n_iter"]
|
|
534
|
-
]:
|
|
558
|
+
loss_terms: dict[str, float],
|
|
559
|
+
tracked_params: Params,
|
|
560
|
+
) -> tuple[Params, dict[str, Float[Array, " n_iter"]], Float[Array, " n_iter"]]:
|
|
535
561
|
stored_params = jax.tree_util.tree_map(
|
|
536
562
|
lambda stored_value, param, tracked_param: (
|
|
537
563
|
None
|
|
@@ -558,7 +584,7 @@ def _store_loss_and_params(
|
|
|
558
584
|
return (stored_params, stored_loss_terms, train_loss_values)
|
|
559
585
|
|
|
560
586
|
|
|
561
|
-
def _get_break_fun(n_iter:
|
|
587
|
+
def _get_break_fun(n_iter: int, verbose: bool) -> Callable[[main_carry], bool]:
|
|
562
588
|
"""
|
|
563
589
|
Wrapper to get the break_fun with appropriate `n_iter`.
|
|
564
590
|
The verbose argument is here to control printing (or not) when exiting
|
|
@@ -599,7 +625,7 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
|
|
|
599
625
|
bool_nan_in_params = jax.lax.cond(
|
|
600
626
|
_check_nan_in_pytree(optimization.params),
|
|
601
627
|
lambda _: stop_while_loop(
|
|
602
|
-
"NaN values in parameters
|
|
628
|
+
"NaN values in parameters (returning last non NaN values)"
|
|
603
629
|
),
|
|
604
630
|
continue_while_loop,
|
|
605
631
|
None,
|
|
@@ -622,18 +648,18 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
|
|
|
622
648
|
|
|
623
649
|
|
|
624
650
|
def _get_get_batch(
|
|
625
|
-
obs_batch_sharding: jax.sharding.Sharding,
|
|
651
|
+
obs_batch_sharding: jax.sharding.Sharding | None,
|
|
626
652
|
) -> Callable[
|
|
627
653
|
[
|
|
628
|
-
|
|
654
|
+
AbstractDataGenerator,
|
|
629
655
|
DataGeneratorParameter | None,
|
|
630
|
-
DataGeneratorObservations |
|
|
656
|
+
DataGeneratorObservations | None,
|
|
631
657
|
],
|
|
632
658
|
tuple[
|
|
633
659
|
AnyBatch,
|
|
634
|
-
|
|
660
|
+
AbstractDataGenerator,
|
|
635
661
|
DataGeneratorParameter | None,
|
|
636
|
-
DataGeneratorObservations |
|
|
662
|
+
DataGeneratorObservations | None,
|
|
637
663
|
],
|
|
638
664
|
]:
|
|
639
665
|
"""
|
jinns/solver/_utils.py
CHANGED
|
@@ -1,9 +1,7 @@
|
|
|
1
|
-
from jinns.data.
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
DataGeneratorParameter,
|
|
6
|
-
)
|
|
1
|
+
from jinns.data._DataGeneratorODE import DataGeneratorODE
|
|
2
|
+
from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
|
|
3
|
+
from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
|
|
4
|
+
from jinns.data._DataGeneratorParameter import DataGeneratorParameter
|
|
7
5
|
|
|
8
6
|
|
|
9
7
|
def _check_batch_size(other_data, main_data, attr_name):
|
jinns/utils/__init__.py
CHANGED
jinns/utils/_containers.py
CHANGED
|
@@ -11,23 +11,26 @@ from jaxtyping import PyTree, Array, Float, Bool
|
|
|
11
11
|
from optax import OptState
|
|
12
12
|
import equinox as eqx
|
|
13
13
|
|
|
14
|
+
from jinns.parameters._params import Params
|
|
15
|
+
|
|
14
16
|
if TYPE_CHECKING:
|
|
15
|
-
from jinns.
|
|
17
|
+
from jinns.data._AbstractDataGenerator import AbstractDataGenerator
|
|
18
|
+
from jinns.data._DataGeneratorParameter import DataGeneratorParameter
|
|
19
|
+
from jinns.data._DataGeneratorObservations import DataGeneratorObservations
|
|
20
|
+
from jinns.utils._types import AnyLoss
|
|
16
21
|
|
|
17
22
|
|
|
18
23
|
class DataGeneratorContainer(eqx.Module):
|
|
19
|
-
data:
|
|
24
|
+
data: AbstractDataGenerator
|
|
20
25
|
param_data: DataGeneratorParameter | None = None
|
|
21
|
-
obs_data: DataGeneratorObservations |
|
|
22
|
-
None
|
|
23
|
-
)
|
|
26
|
+
obs_data: DataGeneratorObservations | None = None
|
|
24
27
|
|
|
25
28
|
|
|
26
29
|
class ValidationContainer(eqx.Module):
|
|
27
30
|
loss: AnyLoss | None
|
|
28
31
|
data: DataGeneratorContainer
|
|
29
32
|
hyperparams: PyTree = None
|
|
30
|
-
loss_values: Float[Array, "n_iter"] | None = None
|
|
33
|
+
loss_values: Float[Array, " n_iter"] | None = None
|
|
31
34
|
|
|
32
35
|
|
|
33
36
|
class OptimizationContainer(eqx.Module):
|
|
@@ -45,9 +48,9 @@ class OptimizationExtraContainer(eqx.Module):
|
|
|
45
48
|
|
|
46
49
|
|
|
47
50
|
class LossContainer(eqx.Module):
|
|
48
|
-
stored_loss_terms: Dict[str, Float[Array, "n_iter"]]
|
|
49
|
-
train_loss_values: Float[Array, "n_iter"]
|
|
51
|
+
stored_loss_terms: Dict[str, Float[Array, " n_iter"]]
|
|
52
|
+
train_loss_values: Float[Array, " n_iter"]
|
|
50
53
|
|
|
51
54
|
|
|
52
55
|
class StoredObjectContainer(eqx.Module):
|
|
53
|
-
stored_params:
|
|
56
|
+
stored_params: Params[Array | None]
|
jinns/utils/_types.py
CHANGED
|
@@ -1,65 +1,19 @@
|
|
|
1
|
-
# pragma: exclude file
|
|
2
1
|
from __future__ import (
|
|
3
2
|
annotations,
|
|
4
3
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
5
4
|
|
|
6
|
-
from typing import TypeAlias, TYPE_CHECKING,
|
|
7
|
-
from jaxtyping import
|
|
5
|
+
from typing import TypeAlias, TYPE_CHECKING, Callable
|
|
6
|
+
from jaxtyping import Float, Array
|
|
8
7
|
|
|
9
8
|
if TYPE_CHECKING:
|
|
10
|
-
from jinns.
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
SystemLossPDE,
|
|
14
|
-
)
|
|
9
|
+
from jinns.data._Batchs import ODEBatch, PDEStatioBatch, PDENonStatioBatch
|
|
10
|
+
from jinns.loss._LossODE import LossODE
|
|
11
|
+
from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio
|
|
15
12
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
CubicMeshPDEStatio,
|
|
21
|
-
CubicMeshPDENonStatio,
|
|
22
|
-
DataGeneratorObservations,
|
|
23
|
-
DataGeneratorParameter,
|
|
24
|
-
DataGeneratorObservationsMultiPINNs,
|
|
25
|
-
)
|
|
13
|
+
# Here we define types available for the whole package
|
|
14
|
+
BoundaryConditionFun: TypeAlias = Callable[
|
|
15
|
+
[Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
|
|
16
|
+
]
|
|
26
17
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
from jinns.nn._pinn import PINN
|
|
30
|
-
from jinns.nn._hyperpinn import HyperPINN
|
|
31
|
-
from jinns.nn._spinn_mlp import SPINN
|
|
32
|
-
from jinns.utils._containers import *
|
|
33
|
-
from jinns.validation._validation import AbstractValidationModule
|
|
34
|
-
|
|
35
|
-
AnyLoss: TypeAlias = (
|
|
36
|
-
LossPDEStatio | LossPDENonStatio | SystemLossPDE | LossODE | SystemLossODE
|
|
37
|
-
)
|
|
38
|
-
|
|
39
|
-
AnyParams: TypeAlias = Params | ParamsDict
|
|
40
|
-
|
|
41
|
-
AnyDataGenerator: TypeAlias = (
|
|
42
|
-
DataGeneratorODE | CubicMeshPDEStatio | CubicMeshPDENonStatio
|
|
43
|
-
)
|
|
44
|
-
|
|
45
|
-
AnyPINN: TypeAlias = PINN | HyperPINN | SPINN
|
|
46
|
-
|
|
47
|
-
AnyBatch: TypeAlias = ODEBatch | PDEStatioBatch | PDENonStatioBatch
|
|
48
|
-
rar_operands = NewType(
|
|
49
|
-
"rar_operands", tuple[AnyLoss, AnyParams, AnyDataGenerator, Int]
|
|
50
|
-
)
|
|
51
|
-
|
|
52
|
-
main_carry = NewType(
|
|
53
|
-
"main_carry",
|
|
54
|
-
tuple[
|
|
55
|
-
Int,
|
|
56
|
-
AnyLoss,
|
|
57
|
-
OptimizationContainer,
|
|
58
|
-
OptimizationExtraContainer,
|
|
59
|
-
DataGeneratorContainer,
|
|
60
|
-
AbstractValidationModule,
|
|
61
|
-
LossContainer,
|
|
62
|
-
StoredObjectContainer,
|
|
63
|
-
Float[Array, "n_iter"],
|
|
64
|
-
],
|
|
65
|
-
)
|
|
18
|
+
AnyBatch: TypeAlias = ODEBatch | PDENonStatioBatch | PDEStatioBatch
|
|
19
|
+
AnyLoss: TypeAlias = LossODE | LossPDEStatio | LossPDENonStatio
|
jinns/utils/_utils.py
CHANGED
|
@@ -2,20 +2,13 @@
|
|
|
2
2
|
Implements various utility functions
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from math import prod
|
|
6
5
|
import warnings
|
|
7
6
|
import jax
|
|
8
7
|
import jax.numpy as jnp
|
|
9
|
-
from jaxtyping import PyTree, Array
|
|
8
|
+
from jaxtyping import PyTree, Array, Bool
|
|
10
9
|
|
|
11
|
-
from jinns.data._DataGenerators import (
|
|
12
|
-
DataGeneratorODE,
|
|
13
|
-
CubicMeshPDEStatio,
|
|
14
|
-
CubicMeshPDENonStatio,
|
|
15
|
-
)
|
|
16
10
|
|
|
17
|
-
|
|
18
|
-
def _check_nan_in_pytree(pytree: PyTree) -> bool:
|
|
11
|
+
def _check_nan_in_pytree(pytree: PyTree) -> Bool[Array, " "]:
|
|
19
12
|
"""
|
|
20
13
|
Check if there is a NaN value anywhere is the pytree
|
|
21
14
|
|
|
@@ -55,7 +48,7 @@ def get_grid(in_array: Array) -> Array:
|
|
|
55
48
|
|
|
56
49
|
|
|
57
50
|
def _check_shape_and_type(
|
|
58
|
-
r: Array | int, expected_shape: tuple, cause: str = "", binop: str = ""
|
|
51
|
+
r: Array | int | float, expected_shape: tuple, cause: str = "", binop: str = ""
|
|
59
52
|
) -> Array | float:
|
|
60
53
|
"""
|
|
61
54
|
Ensures float type and correct shapes for broadcasting when performing a
|
|
@@ -90,7 +83,7 @@ def _check_shape_and_type(
|
|
|
90
83
|
|
|
91
84
|
|
|
92
85
|
def _subtract_with_check(
|
|
93
|
-
a: Array | int, b: Array
|
|
86
|
+
a: Array | int | float, b: Array, cause: str = ""
|
|
94
87
|
) -> Array | float:
|
|
95
88
|
a = _check_shape_and_type(a, b.shape, cause=cause, binop="-")
|
|
96
89
|
return a - b
|
jinns/validation/__init__.py
CHANGED
jinns/validation/_validation.py
CHANGED
|
@@ -7,19 +7,23 @@ from __future__ import (
|
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
|
|
9
9
|
import abc
|
|
10
|
-
from typing import TYPE_CHECKING
|
|
10
|
+
from typing import TYPE_CHECKING
|
|
11
11
|
import equinox as eqx
|
|
12
12
|
import jax
|
|
13
13
|
import jax.numpy as jnp
|
|
14
|
-
from jaxtyping import Array
|
|
14
|
+
from jaxtyping import Array, Float
|
|
15
15
|
|
|
16
|
-
from jinns.data.
|
|
16
|
+
from jinns.data._utils import (
|
|
17
17
|
append_obs_batch,
|
|
18
18
|
append_param_batch,
|
|
19
19
|
)
|
|
20
20
|
|
|
21
21
|
if TYPE_CHECKING:
|
|
22
|
-
from jinns.
|
|
22
|
+
from jinns.data._DataGeneratorParameter import DataGeneratorParameter
|
|
23
|
+
from jinns.data._DataGeneratorObservations import DataGeneratorObservations
|
|
24
|
+
from jinns.data._AbstractDataGenerator import AbstractDataGenerator
|
|
25
|
+
from jinns.parameters._params import Params
|
|
26
|
+
from jinns.loss._abstract_loss import AbstractLoss
|
|
23
27
|
|
|
24
28
|
# Using eqx Module for the DataClass + Pytree inheritance
|
|
25
29
|
# Abstract class and abstract/final pattern is used
|
|
@@ -40,8 +44,8 @@ class AbstractValidationModule(eqx.Module):
|
|
|
40
44
|
|
|
41
45
|
@abc.abstractmethod
|
|
42
46
|
def __call__(
|
|
43
|
-
self, params: Params
|
|
44
|
-
) -> tuple[
|
|
47
|
+
self, params: Params[Array]
|
|
48
|
+
) -> tuple[AbstractValidationModule, bool, Array, Params[Array]]:
|
|
45
49
|
raise NotImplementedError
|
|
46
50
|
|
|
47
51
|
|
|
@@ -52,24 +56,20 @@ class ValidationLoss(AbstractValidationModule):
|
|
|
52
56
|
for more complicated validation strategy.
|
|
53
57
|
"""
|
|
54
58
|
|
|
55
|
-
loss:
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
validation_param_data: Union[DataGeneratorParameter, None] = eqx.field(
|
|
59
|
+
loss: AbstractLoss = eqx.field(kw_only=True)
|
|
60
|
+
validation_data: AbstractDataGenerator = eqx.field(kw_only=True)
|
|
61
|
+
validation_param_data: DataGeneratorParameter = eqx.field(
|
|
62
|
+
kw_only=True, default=None
|
|
63
|
+
)
|
|
64
|
+
validation_obs_data: DataGeneratorObservations | None = eqx.field(
|
|
62
65
|
kw_only=True, default=None
|
|
63
66
|
)
|
|
64
|
-
validation_obs_data: Union[
|
|
65
|
-
DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
|
|
66
|
-
] = eqx.field(kw_only=True, default=None)
|
|
67
67
|
call_every: int = eqx.field(kw_only=True, default=250) # concrete typing
|
|
68
68
|
early_stopping: bool = eqx.field(
|
|
69
69
|
kw_only=True, default=True
|
|
70
70
|
) # globally control if early stopping happens
|
|
71
71
|
|
|
72
|
-
patience:
|
|
72
|
+
patience: int = eqx.field(kw_only=True, default=10)
|
|
73
73
|
best_val_loss: Array = eqx.field(
|
|
74
74
|
converter=jnp.asarray, default_factory=lambda: jnp.array(jnp.inf), kw_only=True
|
|
75
75
|
)
|
|
@@ -79,10 +79,11 @@ class ValidationLoss(AbstractValidationModule):
|
|
|
79
79
|
)
|
|
80
80
|
|
|
81
81
|
def __call__(
|
|
82
|
-
self, params:
|
|
83
|
-
) -> tuple[
|
|
82
|
+
self, params: Params[Array]
|
|
83
|
+
) -> tuple[ValidationLoss, bool, Float[Array, " "], Params[Array]]:
|
|
84
84
|
# do in-place mutation
|
|
85
85
|
|
|
86
|
+
# pylint / pyright complains below when using the self attributes see: https://github.com/patrick-kidger/equinox/issues/1013
|
|
86
87
|
validation_data, val_batch = self.validation_data.get_batch()
|
|
87
88
|
if self.validation_param_data is not None:
|
|
88
89
|
validation_param_data, param_batch = self.validation_param_data.get_batch()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.4.0
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -25,6 +25,7 @@ Requires-Dist: matplotlib
|
|
|
25
25
|
Provides-Extra: notebook
|
|
26
26
|
Requires-Dist: jupyter; extra == "notebook"
|
|
27
27
|
Requires-Dist: seaborn; extra == "notebook"
|
|
28
|
+
Dynamic: license-file
|
|
28
29
|
|
|
29
30
|
jinns
|
|
30
31
|
=====
|
|
@@ -99,7 +100,7 @@ Here are the contributors guidelines:
|
|
|
99
100
|
pip install -e .
|
|
100
101
|
```
|
|
101
102
|
|
|
102
|
-
3. Install pre-commit and run it.
|
|
103
|
+
3. Install pre-commit and run it. Our pre-commit hooks consist of `ruff format` and `ruff check`. You can install `ruff` with `pip install ruff`. We highly recommend checking the code's type hints with `pyright`, even though the pipeline currently has no type-checking rule.
|
|
103
104
|
|
|
104
105
|
```bash
|
|
105
106
|
pip install pre-commit
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
jinns/__init__.py,sha256=hyh3QKO2cQGK5cmvFYP0MrXb-tK_DM2T9CwLwO-sEX8,500
|
|
2
|
+
jinns/data/_AbstractDataGenerator.py,sha256=O61TBOyeOFKwf1xqKzFD4KwCWRDnm2XgyJ-kKY9fmB4,557
|
|
3
|
+
jinns/data/_Batchs.py,sha256=-DlD6Qag3zs5QbKtKAOvOzV7JOpNOqAm_P8cwo1dIZg,1574
|
|
4
|
+
jinns/data/_CubicMeshPDENonStatio.py,sha256=c_8czJpxSoEvgZ8LDpL2sqtF9dcW4ELNO4juEFMOxog,16400
|
|
5
|
+
jinns/data/_CubicMeshPDEStatio.py,sha256=stZ0Kbb7_VwFmWUSPs0P6a6qRj2Tu67p7sxEfb1Ajks,17865
|
|
6
|
+
jinns/data/_DataGeneratorODE.py,sha256=5RzUbQFEsooAZsocDw4wRgA_w5lJmDMuY4M6u79K-1c,7260
|
|
7
|
+
jinns/data/_DataGeneratorObservations.py,sha256=jknepLsJatSJHFq5lLMD-fFHkPGj5q286LEjE-vH24k,7738
|
|
8
|
+
jinns/data/_DataGeneratorParameter.py,sha256=IedX3jcOj7ZDW_18IAcRR75KVzQzo85z9SICIKDBJl4,8539
|
|
9
|
+
jinns/data/__init__.py,sha256=4b4eVsoGHV89m2kGDiAOHsrGialZQ6j5ja575qWwQHs,677
|
|
10
|
+
jinns/data/_utils.py,sha256=XxaLIg_HIgcB7ACBIhTpHbCT1HXKcDaY1NABncAYX1c,5223
|
|
11
|
+
jinns/experimental/__init__.py,sha256=DT9e57zbjfzPeRnXemGUqnGd--MhV77FspChT0z4YrE,410
|
|
12
|
+
jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
|
|
13
|
+
jinns/loss/_DynamicLoss.py,sha256=4mb7OCP-cGZ_mG2MQ-AniddDcuBT78p4bQI7rZpwte4,22722
|
|
14
|
+
jinns/loss/_DynamicLossAbstract.py,sha256=HIs6TtE9ouvT5H3cBR52UWSkALTgRWRG_kB3s890b2U,11253
|
|
15
|
+
jinns/loss/_LossODE.py,sha256=aCx5vD3CXmhws36gYre1iu_t29MefpTxW54gUxKej_Q,11856
|
|
16
|
+
jinns/loss/_LossPDE.py,sha256=mANXuWjm02bKwfoIn-RH8vxPY-RG3jVErcSrwwK3HzM,33259
|
|
17
|
+
jinns/loss/__init__.py,sha256=qnNRGjl6Tcga1koztMJnJ3eL8XNP0gRbNsXVgq4CkOI,1162
|
|
18
|
+
jinns/loss/_abstract_loss.py,sha256=hf6ohNoSAqzskFyivCiuE2SCqhsV4UWLv65L4V8H3ys,407
|
|
19
|
+
jinns/loss/_boundary_conditions.py,sha256=9HGw1cGLfmEilP4V4B2T0zl0YP1kNtrtXVLQNiBmWgc,12464
|
|
20
|
+
jinns/loss/_loss_utils.py,sha256=R_jhBHkTwGu41gWnhYRswunxdzetPZ9-Gmkghzorock,11745
|
|
21
|
+
jinns/loss/_loss_weights.py,sha256=5BVZglM7Y3m_8muXcKT898fAC6_RbdLNQ7WWx3lOE9k,1077
|
|
22
|
+
jinns/loss/_operators.py,sha256=Ds5yRH7hu-jaGRp7PYbt821BgYuEvgWHufWhYgdMjw0,22909
|
|
23
|
+
jinns/nn/__init__.py,sha256=gwE48oqB_FsSIE-hUvCLz0jPaqX350LBxzH6ueFWYk4,456
|
|
24
|
+
jinns/nn/_abstract_pinn.py,sha256=JUFjlV_nyheZw-max_tAUgFh6SspIbD5we_4bn70V6k,671
|
|
25
|
+
jinns/nn/_hyperpinn.py,sha256=hF7HRLMMVBPT9CTQC2DjpDRcQDJCrT9cAj8wfApT_WE,19412
|
|
26
|
+
jinns/nn/_mlp.py,sha256=Xmi-mG6uakN67R2S2UsBazdXIJVaGsD2B6TeJM1QjGY,8881
|
|
27
|
+
jinns/nn/_pinn.py,sha256=4pvgUPQdQaO3cPBuEU7W4UaLV7lodqcR3pVR1sC0ni4,8774
|
|
28
|
+
jinns/nn/_ppinn.py,sha256=LtjGQaLozdA4Kwutn8Pyerbu9yOc0t3_b701yfMb1ac,10392
|
|
29
|
+
jinns/nn/_save_load.py,sha256=UqVy2oBzvIeBy6XB9tb61x3-x8i4dNCXJHC5_-bko-I,7477
|
|
30
|
+
jinns/nn/_spinn.py,sha256=u5YG2FXcrg8p_uS2QFGmWoeFXYLxXnyV2e6BUHpo4xk,4774
|
|
31
|
+
jinns/nn/_spinn_mlp.py,sha256=uCL454sF0Tfj7KT-fdXPnvKJYRQOuq60N0r2b2VAB8Q,7606
|
|
32
|
+
jinns/nn/_utils.py,sha256=9UXz73iHKHVQYPBPIEitrHYJzJ14dspRwPfLA8avx0c,1120
|
|
33
|
+
jinns/parameters/__init__.py,sha256=O0n7y6R1LRmFzzugCxMFCMS2pgsuWSh-XHjfFViN_eg,265
|
|
34
|
+
jinns/parameters/_derivative_keys.py,sha256=YlLDX49PfYhr2Tj--t3praiD8JOUTZU6PTmjbNZsbMc,19173
|
|
35
|
+
jinns/parameters/_params.py,sha256=qn4IGMJhD9lDBqOWmGEMy4gXt5a6KHfirkYZwHO7Vwk,2633
|
|
36
|
+
jinns/plot/__init__.py,sha256=KPHX0Um4FbciZO1yD8kjZbkaT8tT964Y6SE2xCQ4eDU,135
|
|
37
|
+
jinns/plot/_plot.py,sha256=-A5auNeElaz2_8UzVQJQE4143ZFg0zgMjStU7kwttEY,11565
|
|
38
|
+
jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
39
|
+
jinns/solver/_rar.py,sha256=vSVTnCGCusI1vTZCvIkP2_G8we44G_42yZHx2sOK9DE,10291
|
|
40
|
+
jinns/solver/_solve.py,sha256=uPJsN4Pv_QEHYMlMdo29hlJXmWyCtf2aFZlj2M8Fl2U,24886
|
|
41
|
+
jinns/solver/_utils.py,sha256=sM2UbVzYyjw24l4QSIR3IlynJTPGD_S08r8v0lXMxA8,5876
|
|
42
|
+
jinns/utils/__init__.py,sha256=OEYWLCw8pKE7xoQREbd6SHvCjuw2QZHuVA6YwDcsBE8,53
|
|
43
|
+
jinns/utils/_containers.py,sha256=XANkmGiXvFb7Qh8MtGuhcZQl4Fpw4woJcn17-y1-VHs,1690
|
|
44
|
+
jinns/utils/_types.py,sha256=PEPVEZ4XGT-7gCIasUHDYpIrMP_Ke1KTXGloXJPlK_k,746
|
|
45
|
+
jinns/utils/_utils.py,sha256=M7NXX9ok-BkH5o_xo74PB1_Cc8XiDipSl51rq82dTH4,2821
|
|
46
|
+
jinns/validation/__init__.py,sha256=FTyUO-v1b8Tv-FDSQsntrH7zl9E0ENexqKMT_dFRkYo,124
|
|
47
|
+
jinns/validation/_validation.py,sha256=8p6sMKiBAvA6JNm65hjkMj0997LJ0BkyCREEh0AnPVE,4803
|
|
48
|
+
jinns-1.4.0.dist-info/licenses/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
|
|
49
|
+
jinns-1.4.0.dist-info/licenses/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
|
|
50
|
+
jinns-1.4.0.dist-info/METADATA,sha256=MkB5xNjrdFcJHmlwhk_RRwNNuCdkHLYpAxB-7TYhykg,5031
|
|
51
|
+
jinns-1.4.0.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
|
|
52
|
+
jinns-1.4.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
|
|
53
|
+
jinns-1.4.0.dist-info/RECORD,,
|