jinns 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +114 -187
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +146 -520
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +20 -19
  19. jinns/loss/_loss_utils.py +78 -159
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -74
  22. jinns/nn/__init__.py +15 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +94 -57
  25. jinns/nn/_mlp.py +50 -25
  26. jinns/nn/_pinn.py +33 -19
  27. jinns/nn/_ppinn.py +70 -34
  28. jinns/nn/_save_load.py +21 -51
  29. jinns/nn/_spinn.py +33 -16
  30. jinns/nn/_spinn_mlp.py +28 -22
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +35 -34
  37. jinns/solver/_rar.py +80 -63
  38. jinns/solver/_solve.py +89 -63
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -0
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns-1.3.0.dist-info/RECORD +0 -44
  51. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  52. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  53. {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py CHANGED
@@ -8,56 +8,71 @@ from __future__ import (
  )  # https://docs.python.org/3/library/typing.html#constant

  import time
- from typing import TYPE_CHECKING, NamedTuple, Dict, Union
+ from typing import TYPE_CHECKING, Any, TypeAlias, Callable
  from functools import partial
  import optax
  import jax
  from jax import jit
  import jax.numpy as jnp
- from jaxtyping import Int, Bool, Float, Array
+ from jaxtyping import Float, Array
  from jinns.solver._rar import init_rar, trigger_rar
  from jinns.utils._utils import _check_nan_in_pytree
  from jinns.solver._utils import _check_batch_size
- from jinns.utils._containers import *
- from jinns.data._DataGenerators import (
-     DataGeneratorODE,
-     CubicMeshPDEStatio,
-     CubicMeshPDENonStatio,
-     append_obs_batch,
-     append_param_batch,
+ from jinns.utils._containers import (
+     DataGeneratorContainer,
+     OptimizationContainer,
+     OptimizationExtraContainer,
+     LossContainer,
+     StoredObjectContainer,
  )
+ from jinns.data._utils import append_param_batch, append_obs_batch

  if TYPE_CHECKING:
-     from jinns.utils._types import *
+     from jinns.parameters._params import Params
+     from jinns.utils._types import AnyLoss, AnyBatch
+     from jinns.validation._validation import AbstractValidationModule
+     from jinns.data._DataGeneratorParameter import DataGeneratorParameter
+     from jinns.data._DataGeneratorObservations import DataGeneratorObservations
+     from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+
+     main_carry: TypeAlias = tuple[
+         int,
+         AnyLoss,
+         OptimizationContainer,
+         OptimizationExtraContainer,
+         DataGeneratorContainer,
+         AbstractValidationModule | None,
+         LossContainer,
+         StoredObjectContainer,
+         Float[Array, " n_iter"] | None,
+     ]


  def solve(
-     n_iter: Int,
-     init_params: AnyParams,
-     data: AnyDataGenerator,
+     n_iter: int,
+     init_params: Params[Array],
+     data: AbstractDataGenerator,
      loss: AnyLoss,
      optimizer: optax.GradientTransformation,
-     print_loss_every: Int = 1000,
-     opt_state: Union[NamedTuple, None] = None,
-     tracked_params: Params | ParamsDict | None = None,
+     print_loss_every: int = 1000,
+     opt_state: optax.OptState | None = None,
+     tracked_params: Params[Any | None] | None = None,
      param_data: DataGeneratorParameter | None = None,
-     obs_data: (
-         DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None
-     ) = None,
+     obs_data: DataGeneratorObservations | None = None,
      validation: AbstractValidationModule | None = None,
      obs_batch_sharding: jax.sharding.Sharding | None = None,
-     verbose: Bool = True,
-     ahead_of_time: Bool = True,
+     verbose: bool = True,
+     ahead_of_time: bool = True,
  ) -> tuple[
-     Params | ParamsDict,
-     Float[Array, "n_iter"],
-     Dict[str, Float[Array, "n_iter"]],
-     AnyDataGenerator,
+     Params[Array],
+     Float[Array, " n_iter"],
+     dict[str, Float[Array, " n_iter"]],
+     AbstractDataGenerator,
      AnyLoss,
-     NamedTuple,
-     AnyParams,
-     Float[Array, "n_iter"],
-     AnyParams,
+     optax.OptState,
+     Params[Array | None],
+     Float[Array, " n_iter"] | None,
+     Params[Array],
  ]:
      """
      Performs the optimization process via stochastic gradient descent
@@ -94,8 +109,7 @@ def solve(
          Default None. A DataGeneratorParameter object which can be used to
          sample equation parameters.
      obs_data
-         Default None. A DataGeneratorObservations or
-         DataGeneratorObservationsMultiPINNs
+         Default None. A DataGeneratorObservations
          object which can be used to sample minibatches of observations.
      validation
          Default None. Otherwise, a callable ``eqx.Module`` which implements a
@@ -171,11 +185,21 @@ def solve(
      _check_batch_size(obs_data, param_data, "n")

      if opt_state is None:
-         opt_state = optimizer.init(init_params)
+         opt_state = optimizer.init(init_params)  # type: ignore
+         # our Params are eqx.Module (dataclass + PyTree), PyTree is
+         # compatible with optax transform but not dataclass, this leads to a
+         # type hint error: we could prevent this by ensuring with the eqx.filter that
+         # we have only floating points optimizable params given to optax
+         # see https://docs.kidger.site/equinox/faq/#optax-throwing-a-typeerror
+         # opt_state = optimizer.init(
+         #     eqx.filter(init_params, eqx.is_inexact_array)
+         # )
+         # but this seems like a hack and there is no better way
+         # https://github.com/google-deepmind/optax/issues/384

      # RAR sampling init (ouside scanned function to avoid dynamic slice error)
      # If RAR is not used the _rar_step_*() are juste None and data is unchanged
-     data, _rar_step_true, _rar_step_false = init_rar(data)
+     data, _rar_step_true, _rar_step_false = init_rar(data)  # type: ignore

      # Seq2seq
      curr_seq = 0
@@ -292,7 +316,7 @@ def solve(
          if verbose:
              _print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")

-         if validation is not None:
+         if validation is not None and validation_crit_values is not None:
              # there is a jax.lax.cond because we do not necesarily call the
              # validation step every iteration
              (
@@ -306,7 +330,7 @@ def solve(
                  lambda operands: (
                      operands[0],
                      False,
-                     validation_crit_values[i - 1],
+                     validation_crit_values[i - 1],  # type: ignore don't know why it can still be None
                      False,
                  ),
                  (
@@ -431,7 +455,7 @@ def solve(
      # get ready to return the parameters at last iteration...
      # (by default arbitrary choice, this could be None)
      validation_parameters = optimization.last_non_nan_params
-     if validation is not None:
+     if validation is not None and validation_crit_values is not None:
          jax.debug.print(
              "validation loss value = {validation_loss_val}",
              validation_loss_val=validation_crit_values[i - 1],
@@ -466,24 +490,28 @@ def _gradient_step(
      loss: AnyLoss,
      optimizer: optax.GradientTransformation,
      batch: AnyBatch,
-     params: AnyParams,
-     opt_state: NamedTuple,
-     last_non_nan_params: AnyParams,
+     params: Params[Array],
+     opt_state: optax.OptState,
+     last_non_nan_params: Params[Array],
  ) -> tuple[
      AnyLoss,
      float,
-     Dict[str, float],
-     AnyParams,
-     NamedTuple,
-     AnyParams,
+     dict[str, float],
+     Params[Array],
+     optax.OptState,
+     Params[Array],
  ]:
      """
      optimizer cannot be jit-ted.
      """
      value_grad_loss = jax.value_and_grad(loss, has_aux=True)
      (loss_val, loss_terms), grads = value_grad_loss(params, batch)
-     updates, opt_state = optimizer.update(grads, opt_state, params)
-     params = optax.apply_updates(params, updates)
+     updates, opt_state = optimizer.update(
+         grads,
+         opt_state,
+         params,  # type: ignore
+     )  # see optimizer.init for explaination
+     params = optax.apply_updates(params, updates)  # type: ignore

      # check if any of the parameters is NaN
      last_non_nan_params = jax.lax.cond(
@@ -504,7 +532,7 @@ def _gradient_step(


  @partial(jit, static_argnames=["prefix"])
- def _print_fn(i: Int, loss_val: Float, print_loss_every: Int, prefix: str = ""):
+ def _print_fn(i: int, loss_val: Float, print_loss_every: int, prefix: str = ""):
      # note that if the following is not jitted in the main lor loop, it is
      # super slow
      _ = jax.lax.cond(
@@ -521,17 +549,15 @@ def _print_fn(i: Int, loss_val: Float, print_loss_every: Int, prefix: str = ""):

  @jit
  def _store_loss_and_params(
-     i: Int,
-     params: AnyParams,
-     stored_params: AnyParams,
-     stored_loss_terms: Dict[str, Float[Array, "n_iter"]],
-     train_loss_values: Float[Array, "n_iter"],
+     i: int,
+     params: Params[Array],
+     stored_params: Params[Array],
+     stored_loss_terms: dict[str, Float[Array, " n_iter"]],
+     train_loss_values: Float[Array, " n_iter"],
      train_loss_val: float,
-     loss_terms: Dict[str, float],
-     tracked_params: AnyParams,
- ) -> tuple[
-     Params | ParamsDict, Dict[str, Float[Array, "n_iter"]], Float[Array, "n_iter"]
- ]:
+     loss_terms: dict[str, float],
+     tracked_params: Params,
+ ) -> tuple[Params, dict[str, Float[Array, " n_iter"]], Float[Array, " n_iter"]]:
      stored_params = jax.tree_util.tree_map(
          lambda stored_value, param, tracked_param: (
              None
@@ -558,7 +584,7 @@ def _store_loss_and_params(
      return (stored_params, stored_loss_terms, train_loss_values)


- def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
+ def _get_break_fun(n_iter: int, verbose: bool) -> Callable[[main_carry], bool]:
      """
      Wrapper to get the break_fun with appropriate `n_iter`.
      The verbose argument is here to control printing (or not) when exiting
@@ -599,7 +625,7 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
      bool_nan_in_params = jax.lax.cond(
          _check_nan_in_pytree(optimization.params),
          lambda _: stop_while_loop(
-             "NaN values in parameters " "(returning last non NaN values)"
+             "NaN values in parameters (returning last non NaN values)"
          ),
          continue_while_loop,
          None,
@@ -622,18 +648,18 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:


  def _get_get_batch(
-     obs_batch_sharding: jax.sharding.Sharding,
+     obs_batch_sharding: jax.sharding.Sharding | None,
  ) -> Callable[
      [
-         AnyDataGenerator,
+         AbstractDataGenerator,
          DataGeneratorParameter | None,
-         DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+         DataGeneratorObservations | None,
      ],
      tuple[
          AnyBatch,
-         AnyDataGenerator,
+         AbstractDataGenerator,
          DataGeneratorParameter | None,
-         DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+         DataGeneratorObservations | None,
      ],
  ]:
      """
jinns/solver/_utils.py CHANGED
@@ -1,9 +1,7 @@
- from jinns.data._DataGenerators import (
-     DataGeneratorODE,
-     CubicMeshPDEStatio,
-     CubicMeshPDENonStatio,
-     DataGeneratorParameter,
- )
+ from jinns.data._DataGeneratorODE import DataGeneratorODE
+ from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
+ from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
+ from jinns.data._DataGeneratorParameter import DataGeneratorParameter


  def _check_batch_size(other_data, main_data, attr_name):
jinns/utils/__init__.py CHANGED
@@ -1 +1,3 @@
  from ._utils import get_grid
+
+ __all__ = ["get_grid"]
jinns/utils/_containers.py CHANGED
@@ -11,23 +11,26 @@ from jaxtyping import PyTree, Array, Float, Bool
  from optax import OptState
  import equinox as eqx

+ from jinns.parameters._params import Params
+
  if TYPE_CHECKING:
-     from jinns.utils._types import *
+     from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+     from jinns.data._DataGeneratorParameter import DataGeneratorParameter
+     from jinns.data._DataGeneratorObservations import DataGeneratorObservations
+     from jinns.utils._types import AnyLoss


  class DataGeneratorContainer(eqx.Module):
-     data: AnyDataGenerator
+     data: AbstractDataGenerator
      param_data: DataGeneratorParameter | None = None
-     obs_data: DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None = (
-         None
-     )
+     obs_data: DataGeneratorObservations | None = None


  class ValidationContainer(eqx.Module):
      loss: AnyLoss | None
      data: DataGeneratorContainer
      hyperparams: PyTree = None
-     loss_values: Float[Array, "n_iter"] | None = None
+     loss_values: Float[Array, " n_iter"] | None = None


  class OptimizationContainer(eqx.Module):
@@ -45,9 +48,9 @@ class OptimizationExtraContainer(eqx.Module):


  class LossContainer(eqx.Module):
-     stored_loss_terms: Dict[str, Float[Array, "n_iter"]]
-     train_loss_values: Float[Array, "n_iter"]
+     stored_loss_terms: Dict[str, Float[Array, " n_iter"]]
+     train_loss_values: Float[Array, " n_iter"]


  class StoredObjectContainer(eqx.Module):
-     stored_params: list | None
+     stored_params: Params[Array | None]
jinns/utils/_types.py CHANGED
@@ -1,65 +1,19 @@
- # pragma: exclude file
  from __future__ import (
      annotations,
  )  # https://docs.python.org/3/library/typing.html#constant

- from typing import TypeAlias, TYPE_CHECKING, NewType
- from jaxtyping import Int
+ from typing import TypeAlias, TYPE_CHECKING, Callable
+ from jaxtyping import Float, Array

  if TYPE_CHECKING:
-     from jinns.loss._LossPDE import (
-         LossPDEStatio,
-         LossPDENonStatio,
-         SystemLossPDE,
-     )
+     from jinns.data._Batchs import ODEBatch, PDEStatioBatch, PDENonStatioBatch
+     from jinns.loss._LossODE import LossODE
+     from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio

-     from jinns.loss._LossODE import LossODE, SystemLossODE
-     from jinns.parameters._params import Params, ParamsDict
-     from jinns.data._DataGenerators import (
-         DataGeneratorODE,
-         CubicMeshPDEStatio,
-         CubicMeshPDENonStatio,
-         DataGeneratorObservations,
-         DataGeneratorParameter,
-         DataGeneratorObservationsMultiPINNs,
-     )
+     # Here we define types available for the whole package
+     BoundaryConditionFun: TypeAlias = Callable[
+         [Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
+     ]

-     from jinns.loss import DynamicLoss
-     from jinns.data._Batchs import *
-     from jinns.nn._pinn import PINN
-     from jinns.nn._hyperpinn import HyperPINN
-     from jinns.nn._spinn_mlp import SPINN
-     from jinns.utils._containers import *
-     from jinns.validation._validation import AbstractValidationModule
-
-     AnyLoss: TypeAlias = (
-         LossPDEStatio | LossPDENonStatio | SystemLossPDE | LossODE | SystemLossODE
-     )
-
-     AnyParams: TypeAlias = Params | ParamsDict
-
-     AnyDataGenerator: TypeAlias = (
-         DataGeneratorODE | CubicMeshPDEStatio | CubicMeshPDENonStatio
-     )
-
-     AnyPINN: TypeAlias = PINN | HyperPINN | SPINN
-
-     AnyBatch: TypeAlias = ODEBatch | PDEStatioBatch | PDENonStatioBatch
-     rar_operands = NewType(
-         "rar_operands", tuple[AnyLoss, AnyParams, AnyDataGenerator, Int]
-     )
-
-     main_carry = NewType(
-         "main_carry",
-         tuple[
-             Int,
-             AnyLoss,
-             OptimizationContainer,
-             OptimizationExtraContainer,
-             DataGeneratorContainer,
-             AbstractValidationModule,
-             LossContainer,
-             StoredObjectContainer,
-             Float[Array, "n_iter"],
-         ],
-     )
+     AnyBatch: TypeAlias = ODEBatch | PDENonStatioBatch | PDEStatioBatch
+     AnyLoss: TypeAlias = LossODE | LossPDEStatio | LossPDENonStatio
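
The new `BoundaryConditionFun` alias describes a callable mapping a boundary point of dimension `dim` (or `dim + 1`, presumably the time-augmented point in the non-stationary case) to an array of size `dim_solution`. A minimal illustrative sketch of a conforming function, assuming `dim = 2` and a scalar solution; this function is not part of jinns.

```python
import jax.numpy as jnp
from jaxtyping import Array, Float


def zero_dirichlet(x: Float[Array, " 2"]) -> Float[Array, " 1"]:
    # Homogeneous Dirichlet condition: the prescribed boundary value is 0
    # whatever the boundary point x.
    return jnp.zeros((1,))


print(zero_dirichlet(jnp.array([0.0, 1.0])))  # [0.]
```
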
jinns/utils/_utils.py CHANGED
@@ -2,20 +2,13 @@
  Implements various utility functions
  """

- from math import prod
  import warnings
  import jax
  import jax.numpy as jnp
- from jaxtyping import PyTree, Array
+ from jaxtyping import PyTree, Array, Bool

- from jinns.data._DataGenerators import (
-     DataGeneratorODE,
-     CubicMeshPDEStatio,
-     CubicMeshPDENonStatio,
- )

-
- def _check_nan_in_pytree(pytree: PyTree) -> bool:
+ def _check_nan_in_pytree(pytree: PyTree) -> Bool[Array, " "]:
      """
      Check if there is a NaN value anywhere is the pytree

@@ -55,7 +48,7 @@ def get_grid(in_array: Array) -> Array:


  def _check_shape_and_type(
-     r: Array | int, expected_shape: tuple, cause: str = "", binop: str = ""
+     r: Array | int | float, expected_shape: tuple, cause: str = "", binop: str = ""
  ) -> Array | float:
      """
      Ensures float type and correct shapes for broadcasting when performing a
@@ -90,7 +83,7 @@


  def _subtract_with_check(
-     a: Array | int, b: Array | int, cause: str = ""
+     a: Array | int | float, b: Array, cause: str = ""
  ) -> Array | float:
      a = _check_shape_and_type(a, b.shape, cause=cause, binop="-")
      return a - b
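
`_check_nan_in_pytree` now advertises a scalar `Bool[Array, " "]` return value instead of a Python `bool`, which is the kind of traced predicate `jax.lax.cond` expects in `solve`. A minimal sketch of such a check over array leaves; this is an illustration with an assumed implementation, not the jinns code.

```python
import jax
import jax.numpy as jnp


def check_nan_in_pytree(pytree) -> jax.Array:
    # Reduce every array leaf to "does it contain a NaN?", then OR the
    # per-leaf results into one scalar boolean Array (jit/cond friendly).
    leaves = jax.tree_util.tree_leaves(pytree)
    return jnp.any(jnp.stack([jnp.any(jnp.isnan(leaf)) for leaf in leaves]))


params = {"w": jnp.ones((2, 2)), "b": jnp.array([0.0, jnp.nan])}
print(check_nan_in_pytree(params))  # True
```
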
jinns/validation/__init__.py CHANGED
@@ -1 +1,3 @@
  from ._validation import AbstractValidationModule, ValidationLoss
+
+ __all__ = ["AbstractValidationModule", "ValidationLoss"]
jinns/validation/_validation.py CHANGED
@@ -7,19 +7,23 @@ from __future__ import (
  )  # https://docs.python.org/3/library/typing.html#constant

  import abc
- from typing import TYPE_CHECKING, Union
+ from typing import TYPE_CHECKING
  import equinox as eqx
  import jax
  import jax.numpy as jnp
- from jaxtyping import Array
+ from jaxtyping import Array, Float

- from jinns.data._DataGenerators import (
+ from jinns.data._utils import (
      append_obs_batch,
      append_param_batch,
  )

  if TYPE_CHECKING:
-     from jinns.utils._types import *
+     from jinns.data._DataGeneratorParameter import DataGeneratorParameter
+     from jinns.data._DataGeneratorObservations import DataGeneratorObservations
+     from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+     from jinns.parameters._params import Params
+     from jinns.loss._abstract_loss import AbstractLoss

  # Using eqx Module for the DataClass + Pytree inheritance
  # Abstract class and abstract/final pattern is used
@@ -40,8 +44,8 @@ class AbstractValidationModule(eqx.Module):

      @abc.abstractmethod
      def __call__(
-         self, params: Params | ParamsDict
-     ) -> tuple["AbstractValidationModule", bool, Array, bool]:
+         self, params: Params[Array]
+     ) -> tuple[AbstractValidationModule, bool, Array, Params[Array]]:
          raise NotImplementedError


@@ -52,24 +56,20 @@ class ValidationLoss(AbstractValidationModule):
      for more complicated validation strategy.
      """

-     loss: AnyLoss = eqx.field(kw_only=True)  # NOTE that
-     # there used to be a deepcopy here which has been suppressed. 1) No need
-     # because loss are now eqx.Module (immutable) so no risk of in-place
-     # modification. 2) deepcopy is buggy with equinox, InitVar etc. (see issue
-     # #857 on equinox github)
-     validation_data: Union[AnyDataGenerator] = eqx.field(kw_only=True)
-     validation_param_data: Union[DataGeneratorParameter, None] = eqx.field(
+     loss: AbstractLoss = eqx.field(kw_only=True)
+     validation_data: AbstractDataGenerator = eqx.field(kw_only=True)
+     validation_param_data: DataGeneratorParameter = eqx.field(
+         kw_only=True, default=None
+     )
+     validation_obs_data: DataGeneratorObservations | None = eqx.field(
          kw_only=True, default=None
      )
-     validation_obs_data: Union[
-         DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
-     ] = eqx.field(kw_only=True, default=None)
      call_every: int = eqx.field(kw_only=True, default=250)  # concrete typing
      early_stopping: bool = eqx.field(
          kw_only=True, default=True
      )  # globally control if early stopping happens

-     patience: Union[int] = eqx.field(kw_only=True, default=10)
+     patience: int = eqx.field(kw_only=True, default=10)
      best_val_loss: Array = eqx.field(
          converter=jnp.asarray, default_factory=lambda: jnp.array(jnp.inf), kw_only=True
      )
@@ -79,10 +79,11 @@ class ValidationLoss(AbstractValidationModule):
      )

      def __call__(
-         self, params: AnyParams
-     ) -> tuple["ValidationLoss", bool, float, AnyParams]:
+         self, params: Params[Array]
+     ) -> tuple[ValidationLoss, bool, Float[Array, " "], Params[Array]]:
          # do in-place mutation

+         # pylint / pyright complains below when using the self attributes see: https://github.com/patrick-kidger/equinox/issues/1013
          validation_data, val_batch = self.validation_data.get_batch()
          if self.validation_param_data is not None:
              validation_param_data, param_batch = self.validation_param_data.get_batch()
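
The comment kept in this file ("Abstract class and abstract/final pattern is used" on top of `eqx.Module`) refers to a general Equinox idiom. A minimal, self-contained sketch of that pattern, with hypothetical class names that are not part of jinns:

```python
import abc
import equinox as eqx
import jax.numpy as jnp


class AbstractCriterion(eqx.Module):
    # abstract base: declares the interface, is never instantiated
    @abc.abstractmethod
    def __call__(self, value): ...


class ThresholdCriterion(AbstractCriterion):
    # final subclass: an immutable dataclass/pytree with a concrete __call__
    threshold: float = 1.0

    def __call__(self, value):
        # returns a scalar boolean Array, usable inside jax.lax.cond
        return jnp.asarray(value) < self.threshold


crit = ThresholdCriterion(threshold=0.5)
print(crit(jnp.array(0.1)))  # True
```
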
{jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: jinns
- Version: 1.3.0
+ Version: 1.4.0
  Summary: Physics Informed Neural Network with JAX
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -25,6 +25,7 @@ Requires-Dist: matplotlib
  Provides-Extra: notebook
  Requires-Dist: jupyter; extra == "notebook"
  Requires-Dist: seaborn; extra == "notebook"
+ Dynamic: license-file

  jinns
  =====
@@ -99,7 +100,7 @@ Here are the contributors guidelines:
  pip install -e .
  ```

- 3. Install pre-commit and run it.
+ 3. Install pre-commit and run it. Our pre-commit hooks consist in `ruff format` and `ruff check`. You can install `ruff` simply by `pip install ruff`. We highly recommend you to check the code type hints with `pyright` even though we currently have no rule concerning type checking in the pipeline.

  ```bash
  pip install pre-commit
jinns-1.4.0.dist-info/RECORD ADDED
@@ -0,0 +1,53 @@
+ jinns/__init__.py,sha256=hyh3QKO2cQGK5cmvFYP0MrXb-tK_DM2T9CwLwO-sEX8,500
+ jinns/data/_AbstractDataGenerator.py,sha256=O61TBOyeOFKwf1xqKzFD4KwCWRDnm2XgyJ-kKY9fmB4,557
+ jinns/data/_Batchs.py,sha256=-DlD6Qag3zs5QbKtKAOvOzV7JOpNOqAm_P8cwo1dIZg,1574
+ jinns/data/_CubicMeshPDENonStatio.py,sha256=c_8czJpxSoEvgZ8LDpL2sqtF9dcW4ELNO4juEFMOxog,16400
+ jinns/data/_CubicMeshPDEStatio.py,sha256=stZ0Kbb7_VwFmWUSPs0P6a6qRj2Tu67p7sxEfb1Ajks,17865
+ jinns/data/_DataGeneratorODE.py,sha256=5RzUbQFEsooAZsocDw4wRgA_w5lJmDMuY4M6u79K-1c,7260
+ jinns/data/_DataGeneratorObservations.py,sha256=jknepLsJatSJHFq5lLMD-fFHkPGj5q286LEjE-vH24k,7738
+ jinns/data/_DataGeneratorParameter.py,sha256=IedX3jcOj7ZDW_18IAcRR75KVzQzo85z9SICIKDBJl4,8539
+ jinns/data/__init__.py,sha256=4b4eVsoGHV89m2kGDiAOHsrGialZQ6j5ja575qWwQHs,677
+ jinns/data/_utils.py,sha256=XxaLIg_HIgcB7ACBIhTpHbCT1HXKcDaY1NABncAYX1c,5223
+ jinns/experimental/__init__.py,sha256=DT9e57zbjfzPeRnXemGUqnGd--MhV77FspChT0z4YrE,410
+ jinns/experimental/_diffrax_solver.py,sha256=upMr3kTTNrxEiSUO_oLvCXcjS9lPxSjvbB81h3qlhaU,6813
+ jinns/loss/_DynamicLoss.py,sha256=4mb7OCP-cGZ_mG2MQ-AniddDcuBT78p4bQI7rZpwte4,22722
+ jinns/loss/_DynamicLossAbstract.py,sha256=HIs6TtE9ouvT5H3cBR52UWSkALTgRWRG_kB3s890b2U,11253
+ jinns/loss/_LossODE.py,sha256=aCx5vD3CXmhws36gYre1iu_t29MefpTxW54gUxKej_Q,11856
+ jinns/loss/_LossPDE.py,sha256=mANXuWjm02bKwfoIn-RH8vxPY-RG3jVErcSrwwK3HzM,33259
+ jinns/loss/__init__.py,sha256=qnNRGjl6Tcga1koztMJnJ3eL8XNP0gRbNsXVgq4CkOI,1162
+ jinns/loss/_abstract_loss.py,sha256=hf6ohNoSAqzskFyivCiuE2SCqhsV4UWLv65L4V8H3ys,407
+ jinns/loss/_boundary_conditions.py,sha256=9HGw1cGLfmEilP4V4B2T0zl0YP1kNtrtXVLQNiBmWgc,12464
+ jinns/loss/_loss_utils.py,sha256=R_jhBHkTwGu41gWnhYRswunxdzetPZ9-Gmkghzorock,11745
+ jinns/loss/_loss_weights.py,sha256=5BVZglM7Y3m_8muXcKT898fAC6_RbdLNQ7WWx3lOE9k,1077
+ jinns/loss/_operators.py,sha256=Ds5yRH7hu-jaGRp7PYbt821BgYuEvgWHufWhYgdMjw0,22909
+ jinns/nn/__init__.py,sha256=gwE48oqB_FsSIE-hUvCLz0jPaqX350LBxzH6ueFWYk4,456
+ jinns/nn/_abstract_pinn.py,sha256=JUFjlV_nyheZw-max_tAUgFh6SspIbD5we_4bn70V6k,671
+ jinns/nn/_hyperpinn.py,sha256=hF7HRLMMVBPT9CTQC2DjpDRcQDJCrT9cAj8wfApT_WE,19412
+ jinns/nn/_mlp.py,sha256=Xmi-mG6uakN67R2S2UsBazdXIJVaGsD2B6TeJM1QjGY,8881
+ jinns/nn/_pinn.py,sha256=4pvgUPQdQaO3cPBuEU7W4UaLV7lodqcR3pVR1sC0ni4,8774
+ jinns/nn/_ppinn.py,sha256=LtjGQaLozdA4Kwutn8Pyerbu9yOc0t3_b701yfMb1ac,10392
+ jinns/nn/_save_load.py,sha256=UqVy2oBzvIeBy6XB9tb61x3-x8i4dNCXJHC5_-bko-I,7477
+ jinns/nn/_spinn.py,sha256=u5YG2FXcrg8p_uS2QFGmWoeFXYLxXnyV2e6BUHpo4xk,4774
+ jinns/nn/_spinn_mlp.py,sha256=uCL454sF0Tfj7KT-fdXPnvKJYRQOuq60N0r2b2VAB8Q,7606
+ jinns/nn/_utils.py,sha256=9UXz73iHKHVQYPBPIEitrHYJzJ14dspRwPfLA8avx0c,1120
+ jinns/parameters/__init__.py,sha256=O0n7y6R1LRmFzzugCxMFCMS2pgsuWSh-XHjfFViN_eg,265
+ jinns/parameters/_derivative_keys.py,sha256=YlLDX49PfYhr2Tj--t3praiD8JOUTZU6PTmjbNZsbMc,19173
+ jinns/parameters/_params.py,sha256=qn4IGMJhD9lDBqOWmGEMy4gXt5a6KHfirkYZwHO7Vwk,2633
+ jinns/plot/__init__.py,sha256=KPHX0Um4FbciZO1yD8kjZbkaT8tT964Y6SE2xCQ4eDU,135
+ jinns/plot/_plot.py,sha256=-A5auNeElaz2_8UzVQJQE4143ZFg0zgMjStU7kwttEY,11565
+ jinns/solver/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ jinns/solver/_rar.py,sha256=vSVTnCGCusI1vTZCvIkP2_G8we44G_42yZHx2sOK9DE,10291
+ jinns/solver/_solve.py,sha256=uPJsN4Pv_QEHYMlMdo29hlJXmWyCtf2aFZlj2M8Fl2U,24886
+ jinns/solver/_utils.py,sha256=sM2UbVzYyjw24l4QSIR3IlynJTPGD_S08r8v0lXMxA8,5876
+ jinns/utils/__init__.py,sha256=OEYWLCw8pKE7xoQREbd6SHvCjuw2QZHuVA6YwDcsBE8,53
+ jinns/utils/_containers.py,sha256=XANkmGiXvFb7Qh8MtGuhcZQl4Fpw4woJcn17-y1-VHs,1690
+ jinns/utils/_types.py,sha256=PEPVEZ4XGT-7gCIasUHDYpIrMP_Ke1KTXGloXJPlK_k,746
+ jinns/utils/_utils.py,sha256=M7NXX9ok-BkH5o_xo74PB1_Cc8XiDipSl51rq82dTH4,2821
+ jinns/validation/__init__.py,sha256=FTyUO-v1b8Tv-FDSQsntrH7zl9E0ENexqKMT_dFRkYo,124
+ jinns/validation/_validation.py,sha256=8p6sMKiBAvA6JNm65hjkMj0997LJ0BkyCREEh0AnPVE,4803
+ jinns-1.4.0.dist-info/licenses/AUTHORS,sha256=7NwCj9nU-HNG1asvy4qhQ2w7oZHrn-Lk5_wK_Ve7a3M,80
+ jinns-1.4.0.dist-info/licenses/LICENSE,sha256=BIAkGtXB59Q_BG8f6_OqtQ1BHPv60ggE9mpXJYz2dRM,11337
+ jinns-1.4.0.dist-info/METADATA,sha256=MkB5xNjrdFcJHmlwhk_RRwNNuCdkHLYpAxB-7TYhykg,5031
+ jinns-1.4.0.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
+ jinns-1.4.0.dist-info/top_level.txt,sha256=RXbkr2hzy8WBE8aiRyrJYFqn3JeMJIhMdybLjjLTB9c,6
+ jinns-1.4.0.dist-info/RECORD,,
{jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.8.2)
+ Generator: setuptools (80.3.1)
  Root-Is-Purelib: true
  Tag: py3-none-any