jinns 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. jinns/__init__.py +17 -7
  2. jinns/data/_AbstractDataGenerator.py +19 -0
  3. jinns/data/_Batchs.py +31 -12
  4. jinns/data/_CubicMeshPDENonStatio.py +431 -0
  5. jinns/data/_CubicMeshPDEStatio.py +464 -0
  6. jinns/data/_DataGeneratorODE.py +187 -0
  7. jinns/data/_DataGeneratorObservations.py +189 -0
  8. jinns/data/_DataGeneratorParameter.py +206 -0
  9. jinns/data/__init__.py +19 -9
  10. jinns/data/_utils.py +149 -0
  11. jinns/experimental/__init__.py +9 -0
  12. jinns/loss/_DynamicLoss.py +116 -189
  13. jinns/loss/_DynamicLossAbstract.py +45 -68
  14. jinns/loss/_LossODE.py +71 -336
  15. jinns/loss/_LossPDE.py +176 -513
  16. jinns/loss/__init__.py +28 -6
  17. jinns/loss/_abstract_loss.py +15 -0
  18. jinns/loss/_boundary_conditions.py +22 -21
  19. jinns/loss/_loss_utils.py +98 -173
  20. jinns/loss/_loss_weights.py +12 -44
  21. jinns/loss/_operators.py +84 -76
  22. jinns/nn/__init__.py +22 -0
  23. jinns/nn/_abstract_pinn.py +22 -0
  24. jinns/nn/_hyperpinn.py +434 -0
  25. jinns/nn/_mlp.py +217 -0
  26. jinns/nn/_pinn.py +204 -0
  27. jinns/nn/_ppinn.py +239 -0
  28. jinns/{utils → nn}/_save_load.py +39 -53
  29. jinns/nn/_spinn.py +123 -0
  30. jinns/nn/_spinn_mlp.py +202 -0
  31. jinns/nn/_utils.py +38 -0
  32. jinns/parameters/__init__.py +8 -1
  33. jinns/parameters/_derivative_keys.py +116 -177
  34. jinns/parameters/_params.py +18 -46
  35. jinns/plot/__init__.py +2 -0
  36. jinns/plot/_plot.py +38 -37
  37. jinns/solver/_rar.py +82 -65
  38. jinns/solver/_solve.py +111 -71
  39. jinns/solver/_utils.py +4 -6
  40. jinns/utils/__init__.py +2 -5
  41. jinns/utils/_containers.py +12 -9
  42. jinns/utils/_types.py +11 -57
  43. jinns/utils/_utils.py +4 -11
  44. jinns/validation/__init__.py +2 -0
  45. jinns/validation/_validation.py +20 -19
  46. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA +11 -10
  47. jinns-1.4.0.dist-info/RECORD +53 -0
  48. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
  49. jinns/data/_DataGenerators.py +0 -1634
  50. jinns/utils/_hyperpinn.py +0 -420
  51. jinns/utils/_pinn.py +0 -324
  52. jinns/utils/_ppinn.py +0 -227
  53. jinns/utils/_spinn.py +0 -249
  54. jinns-1.2.0.dist-info/RECORD +0 -41
  55. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
  56. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
  57. {jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/solver/_solve.py CHANGED
@@ -8,55 +8,71 @@ from __future__ import (
  ) # https://docs.python.org/3/library/typing.html#constant

  import time
- from typing import TYPE_CHECKING, NamedTuple, Dict, Union
+ from typing import TYPE_CHECKING, Any, TypeAlias, Callable
  from functools import partial
  import optax
  import jax
  from jax import jit
  import jax.numpy as jnp
- from jaxtyping import Int, Bool, Float, Array
+ from jaxtyping import Float, Array
  from jinns.solver._rar import init_rar, trigger_rar
  from jinns.utils._utils import _check_nan_in_pytree
  from jinns.solver._utils import _check_batch_size
- from jinns.utils._containers import *
- from jinns.data._DataGenerators import (
- DataGeneratorODE,
- CubicMeshPDEStatio,
- CubicMeshPDENonStatio,
- append_obs_batch,
- append_param_batch,
+ from jinns.utils._containers import (
+ DataGeneratorContainer,
+ OptimizationContainer,
+ OptimizationExtraContainer,
+ LossContainer,
+ StoredObjectContainer,
  )
+ from jinns.data._utils import append_param_batch, append_obs_batch

  if TYPE_CHECKING:
- from jinns.utils._types import *
+ from jinns.parameters._params import Params
+ from jinns.utils._types import AnyLoss, AnyBatch
+ from jinns.validation._validation import AbstractValidationModule
+ from jinns.data._DataGeneratorParameter import DataGeneratorParameter
+ from jinns.data._DataGeneratorObservations import DataGeneratorObservations
+ from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+
+ main_carry: TypeAlias = tuple[
+ int,
+ AnyLoss,
+ OptimizationContainer,
+ OptimizationExtraContainer,
+ DataGeneratorContainer,
+ AbstractValidationModule | None,
+ LossContainer,
+ StoredObjectContainer,
+ Float[Array, " n_iter"] | None,
+ ]


  def solve(
- n_iter: Int,
- init_params: AnyParams,
- data: AnyDataGenerator,
+ n_iter: int,
+ init_params: Params[Array],
+ data: AbstractDataGenerator,
  loss: AnyLoss,
  optimizer: optax.GradientTransformation,
- print_loss_every: Int = 1000,
- opt_state: Union[NamedTuple, None] = None,
- tracked_params: Params | ParamsDict | None = None,
+ print_loss_every: int = 1000,
+ opt_state: optax.OptState | None = None,
+ tracked_params: Params[Any | None] | None = None,
  param_data: DataGeneratorParameter | None = None,
- obs_data: (
- DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None
- ) = None,
+ obs_data: DataGeneratorObservations | None = None,
  validation: AbstractValidationModule | None = None,
  obs_batch_sharding: jax.sharding.Sharding | None = None,
- verbose: Bool = True,
+ verbose: bool = True,
+ ahead_of_time: bool = True,
  ) -> tuple[
- Params | ParamsDict,
- Float[Array, "n_iter"],
- Dict[str, Float[Array, "n_iter"]],
- AnyDataGenerator,
+ Params[Array],
+ Float[Array, " n_iter"],
+ dict[str, Float[Array, " n_iter"]],
+ AbstractDataGenerator,
  AnyLoss,
- NamedTuple,
- AnyParams,
- Float[Array, "n_iter"],
- AnyParams,
+ optax.OptState,
+ Params[Array | None],
+ Float[Array, " n_iter"] | None,
+ Params[Array],
  ]:
  """
  Performs the optimization process via stochastic gradient descent
@@ -93,8 +109,7 @@ def solve(
  Default None. A DataGeneratorParameter object which can be used to
  sample equation parameters.
  obs_data
- Default None. A DataGeneratorObservations or
- DataGeneratorObservationsMultiPINNs
+ Default None. A DataGeneratorObservations
  object which can be used to sample minibatches of observations.
  validation
  Default None. Otherwise, a callable ``eqx.Module`` which implements a
@@ -118,6 +133,14 @@ def solve(
  verbose
  Default True. If False, no std output (loss or cause of
  exiting the optimization loop) will be produced.
+ ahead_of_time
+ Default True. Separate the compilation of the main training loop from
+ the execution to get both timings. You might need to avoid this
+ behaviour if you need to perform JAX transforms over chunks of code
+ containing `jinns.solve()` since AOT-compiled functions cannot be JAX
+ transformed (see https://jax.readthedocs.io/en/latest/aot.html#aot-compiled-functions-cannot-be-transformed).
+ When False, jinns does not provide any timing information (which would
+ be nonsense in a JIT transformed `solve()` function).

  Returns
  -------
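The new `ahead_of_time` argument documented above controls whether the training loop is AOT-compiled (and timed) before execution. A minimal illustrative sketch of both modes; `init_params`, `data` and `loss` are placeholders assumed to be built with the jinns API beforehand:

```python
import optax
import jinns

optimizer = optax.adam(1e-3)

# Default: lower and compile the training loop ahead of time, with separate
# compilation/execution timings printed when verbose=True.
out = jinns.solve(
    n_iter=5000,
    init_params=init_params,
    data=data,
    loss=loss,
    optimizer=optimizer,
    ahead_of_time=True,
)

# If solve() itself must sit under a JAX transform (e.g. jax.jit of a larger
# pipeline), AOT-compiled functions cannot be transformed, so disable it:
# out = jinns.solve(..., ahead_of_time=False)
```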
@@ -162,11 +185,21 @@ def solve(
  _check_batch_size(obs_data, param_data, "n")

  if opt_state is None:
- opt_state = optimizer.init(init_params)
+ opt_state = optimizer.init(init_params) # type: ignore
+ # our Params are eqx.Module (dataclass + PyTree), PyTree is
+ # compatible with optax transform but not dataclass, this leads to a
+ # type hint error: we could prevent this by ensuring with the eqx.filter that
+ # we have only floating points optimizable params given to optax
+ # see https://docs.kidger.site/equinox/faq/#optax-throwing-a-typeerror
+ # opt_state = optimizer.init(
+ # eqx.filter(init_params, eqx.is_inexact_array)
+ # )
+ # but this seems like a hack and there is no better way
+ # https://github.com/google-deepmind/optax/issues/384

  # RAR sampling init (ouside scanned function to avoid dynamic slice error)
  # If RAR is not used the _rar_step_*() are juste None and data is unchanged
- data, _rar_step_true, _rar_step_false = init_rar(data)
+ data, _rar_step_true, _rar_step_false = init_rar(data) # type: ignore

  # Seq2seq
  curr_seq = 0
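The comment above points to the equinox FAQ workaround of initialising optax on the inexact-array leaves only. A short sketch of that alternative (a toy parameter pytree, not jinns' actual `Params` class):

```python
import equinox as eqx
import jax.numpy as jnp
import optax

# Toy pytree mixing optimizable arrays and static (non-float) fields.
params = {"theta": jnp.ones(3), "nn_depth": 4}

optimizer = optax.adam(1e-3)

# Keep only floating-point array leaves before handing the pytree to optax,
# per https://docs.kidger.site/equinox/faq/#optax-throwing-a-typeerror
opt_state = optimizer.init(eqx.filter(params, eqx.is_inexact_array))
```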
@@ -283,7 +316,7 @@ def solve(
  if verbose:
  _print_fn(i, train_loss_value, print_loss_every, prefix="[train] ")

- if validation is not None:
+ if validation is not None and validation_crit_values is not None:
  # there is a jax.lax.cond because we do not necesarily call the
  # validation step every iteration
  (
@@ -297,7 +330,7 @@ def solve(
  lambda operands: (
  operands[0],
  False,
- validation_crit_values[i - 1],
+ validation_crit_values[i - 1], # type: ignore don't know why it can still be None
  False,
  ),
  (
@@ -384,16 +417,21 @@ def solve(
  def train_fun(carry):
  return jax.lax.while_loop(break_fun, _one_iteration, carry)

- start = time.time()
- compiled_train_fun = jax.jit(train_fun).lower(carry).compile()
- end = time.time()
- print("\nCompilation took\n", end - start, "\n")
+ if ahead_of_time:
+ start = time.time()
+ compiled_train_fun = jax.jit(train_fun).lower(carry).compile()
+ end = time.time()
+ if verbose:
+ print("\nCompilation took\n", end - start, "\n")

- start = time.time()
- carry = compiled_train_fun(carry)
- jax.block_until_ready(carry)
- end = time.time()
- print("\nTraining took\n", end - start, "\n")
+ start = time.time()
+ carry = compiled_train_fun(carry)
+ jax.block_until_ready(carry)
+ end = time.time()
+ if verbose:
+ print("\nTraining took\n", end - start, "\n")
+ else:
+ carry = train_fun(carry)

  (
  i,
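The branch above separates JAX's ahead-of-time lowering/compilation from execution so both can be timed. A generic, self-contained sketch of the same `jit → lower → compile` pattern on a toy function (placeholders, not jinns code):

```python
import time
import jax
import jax.numpy as jnp

def step(x):
    # toy computation standing in for the jitted training while-loop
    return jnp.tanh(x) * 2.0

x = jnp.ones((1000,))

t0 = time.time()
compiled_step = jax.jit(step).lower(x).compile()  # AOT: trace + compile now
t1 = time.time()

y = compiled_step(x)
jax.block_until_ready(y)  # wait for the async result before timing
t2 = time.time()

print(f"compilation: {t1 - t0:.3f}s, execution: {t2 - t1:.3f}s")
```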
@@ -417,7 +455,7 @@ def solve(
  # get ready to return the parameters at last iteration...
  # (by default arbitrary choice, this could be None)
  validation_parameters = optimization.last_non_nan_params
- if validation is not None:
+ if validation is not None and validation_crit_values is not None:
  jax.debug.print(
  "validation loss value = {validation_loss_val}",
  validation_loss_val=validation_crit_values[i - 1],
@@ -452,24 +490,28 @@ def _gradient_step(
  loss: AnyLoss,
  optimizer: optax.GradientTransformation,
  batch: AnyBatch,
- params: AnyParams,
- opt_state: NamedTuple,
- last_non_nan_params: AnyParams,
+ params: Params[Array],
+ opt_state: optax.OptState,
+ last_non_nan_params: Params[Array],
  ) -> tuple[
  AnyLoss,
  float,
- Dict[str, float],
- AnyParams,
- NamedTuple,
- AnyParams,
+ dict[str, float],
+ Params[Array],
+ optax.OptState,
+ Params[Array],
  ]:
  """
  optimizer cannot be jit-ted.
  """
  value_grad_loss = jax.value_and_grad(loss, has_aux=True)
  (loss_val, loss_terms), grads = value_grad_loss(params, batch)
- updates, opt_state = optimizer.update(grads, opt_state, params)
- params = optax.apply_updates(params, updates)
+ updates, opt_state = optimizer.update(
+ grads,
+ opt_state,
+ params, # type: ignore
+ ) # see optimizer.init for explaination
+ params = optax.apply_updates(params, updates) # type: ignore

  # check if any of the parameters is NaN
  last_non_nan_params = jax.lax.cond(
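`_gradient_step` follows the standard optax cycle: differentiate the loss with an auxiliary output, transform the gradients, apply the updates. A minimal self-contained sketch of that cycle on a toy quadratic loss (not jinns code):

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
    # toy loss returning (value, aux dict), mirroring has_aux=True
    residual = params["w"] * batch["x"] - batch["y"]
    mse = jnp.mean(residual**2)
    return mse, {"mse": mse}

params = {"w": jnp.array(0.0)}
batch = {"x": jnp.arange(4.0), "y": 2.0 * jnp.arange(4.0)}

optimizer = optax.adam(1e-2)
opt_state = optimizer.init(params)

(loss_val, loss_terms), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, batch)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```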
@@ -490,7 +532,7 @@ def _gradient_step(


  @partial(jit, static_argnames=["prefix"])
- def _print_fn(i: Int, loss_val: Float, print_loss_every: Int, prefix: str = ""):
+ def _print_fn(i: int, loss_val: Float, print_loss_every: int, prefix: str = ""):
  # note that if the following is not jitted in the main lor loop, it is
  # super slow
  _ = jax.lax.cond(
@@ -507,17 +549,15 @@ def _print_fn(i: Int, loss_val: Float, print_loss_every: Int, prefix: str = ""):

  @jit
  def _store_loss_and_params(
- i: Int,
- params: AnyParams,
- stored_params: AnyParams,
- stored_loss_terms: Dict[str, Float[Array, "n_iter"]],
- train_loss_values: Float[Array, "n_iter"],
+ i: int,
+ params: Params[Array],
+ stored_params: Params[Array],
+ stored_loss_terms: dict[str, Float[Array, " n_iter"]],
+ train_loss_values: Float[Array, " n_iter"],
  train_loss_val: float,
- loss_terms: Dict[str, float],
- tracked_params: AnyParams,
- ) -> tuple[
- Params | ParamsDict, Dict[str, Float[Array, "n_iter"]], Float[Array, "n_iter"]
- ]:
+ loss_terms: dict[str, float],
+ tracked_params: Params,
+ ) -> tuple[Params, dict[str, Float[Array, " n_iter"]], Float[Array, " n_iter"]]:
  stored_params = jax.tree_util.tree_map(
  lambda stored_value, param, tracked_param: (
  None
@@ -544,7 +584,7 @@ def _store_loss_and_params(
  return (stored_params, stored_loss_terms, train_loss_values)


- def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
+ def _get_break_fun(n_iter: int, verbose: bool) -> Callable[[main_carry], bool]:
  """
  Wrapper to get the break_fun with appropriate `n_iter`.
  The verbose argument is here to control printing (or not) when exiting
@@ -585,7 +625,7 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:
  bool_nan_in_params = jax.lax.cond(
  _check_nan_in_pytree(optimization.params),
  lambda _: stop_while_loop(
- "NaN values in parameters " "(returning last non NaN values)"
+ "NaN values in parameters (returning last non NaN values)"
  ),
  continue_while_loop,
  None,
@@ -608,18 +648,18 @@ def _get_break_fun(n_iter: Int, verbose: Bool) -> Callable[[main_carry], Bool]:


  def _get_get_batch(
- obs_batch_sharding: jax.sharding.Sharding,
+ obs_batch_sharding: jax.sharding.Sharding | None,
  ) -> Callable[
  [
- AnyDataGenerator,
+ AbstractDataGenerator,
  DataGeneratorParameter | None,
- DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+ DataGeneratorObservations | None,
  ],
  tuple[
  AnyBatch,
- AnyDataGenerator,
+ AbstractDataGenerator,
  DataGeneratorParameter | None,
- DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None,
+ DataGeneratorObservations | None,
  ],
  ]:
  """
jinns/solver/_utils.py CHANGED
@@ -1,9 +1,7 @@
- from jinns.data._DataGenerators import (
- DataGeneratorODE,
- CubicMeshPDEStatio,
- CubicMeshPDENonStatio,
- DataGeneratorParameter,
- )
+ from jinns.data._DataGeneratorODE import DataGeneratorODE
+ from jinns.data._CubicMeshPDEStatio import CubicMeshPDEStatio
+ from jinns.data._CubicMeshPDENonStatio import CubicMeshPDENonStatio
+ from jinns.data._DataGeneratorParameter import DataGeneratorParameter


  def _check_batch_size(other_data, main_data, attr_name):
jinns/utils/__init__.py CHANGED
@@ -1,6 +1,3 @@
- from ._pinn import create_PINN, PINN
- from ._ppinn import create_PPINN, PPINN
- from ._spinn import create_SPINN, SPINN
- from ._hyperpinn import create_HYPERPINN, HYPERPINN
- from ._save_load import save_pinn, load_pinn
  from ._utils import get_grid
+
+ __all__ = ["get_grid"]
jinns/utils/_containers.py CHANGED
@@ -11,23 +11,26 @@ from jaxtyping import PyTree, Array, Float, Bool
  from optax import OptState
  import equinox as eqx

+ from jinns.parameters._params import Params
+
  if TYPE_CHECKING:
- from jinns.utils._types import *
+ from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+ from jinns.data._DataGeneratorParameter import DataGeneratorParameter
+ from jinns.data._DataGeneratorObservations import DataGeneratorObservations
+ from jinns.utils._types import AnyLoss


  class DataGeneratorContainer(eqx.Module):
- data: AnyDataGenerator
+ data: AbstractDataGenerator
  param_data: DataGeneratorParameter | None = None
- obs_data: DataGeneratorObservations | DataGeneratorObservationsMultiPINNs | None = (
- None
- )
+ obs_data: DataGeneratorObservations | None = None


  class ValidationContainer(eqx.Module):
  loss: AnyLoss | None
  data: DataGeneratorContainer
  hyperparams: PyTree = None
- loss_values: Float[Array, "n_iter"] | None = None
+ loss_values: Float[Array, " n_iter"] | None = None


  class OptimizationContainer(eqx.Module):
@@ -45,9 +48,9 @@ class OptimizationExtraContainer(eqx.Module):


  class LossContainer(eqx.Module):
- stored_loss_terms: Dict[str, Float[Array, "n_iter"]]
- train_loss_values: Float[Array, "n_iter"]
+ stored_loss_terms: Dict[str, Float[Array, " n_iter"]]
+ train_loss_values: Float[Array, " n_iter"]


  class StoredObjectContainer(eqx.Module):
- stored_params: list | None
+ stored_params: Params[Array | None]
jinns/utils/_types.py CHANGED
@@ -1,65 +1,19 @@
- # pragma: exclude file
  from __future__ import (
  annotations,
  ) # https://docs.python.org/3/library/typing.html#constant

- from typing import TypeAlias, TYPE_CHECKING, NewType
- from jaxtyping import Int
+ from typing import TypeAlias, TYPE_CHECKING, Callable
+ from jaxtyping import Float, Array

  if TYPE_CHECKING:
- from jinns.loss._LossPDE import (
- LossPDEStatio,
- LossPDENonStatio,
- SystemLossPDE,
- )
+ from jinns.data._Batchs import ODEBatch, PDEStatioBatch, PDENonStatioBatch
+ from jinns.loss._LossODE import LossODE
+ from jinns.loss._LossPDE import LossPDEStatio, LossPDENonStatio

- from jinns.loss._LossODE import LossODE, SystemLossODE
- from jinns.parameters._params import Params, ParamsDict
- from jinns.data._DataGenerators import (
- DataGeneratorODE,
- CubicMeshPDEStatio,
- CubicMeshPDENonStatio,
- DataGeneratorObservations,
- DataGeneratorParameter,
- DataGeneratorObservationsMultiPINNs,
- )
+ # Here we define types available for the whole package
+ BoundaryConditionFun: TypeAlias = Callable[
+ [Float[Array, " dim"] | Float[Array, " dim + 1"]], Float[Array, " dim_solution"]
+ ]

- from jinns.loss import DynamicLoss
- from jinns.data._Batchs import *
- from jinns.utils._pinn import PINN
- from jinns.utils._hyperpinn import HYPERPINN
- from jinns.utils._spinn import SPINN
- from jinns.utils._containers import *
- from jinns.validation._validation import AbstractValidationModule
-
- AnyLoss: TypeAlias = (
- LossPDEStatio | LossPDENonStatio | SystemLossPDE | LossODE | SystemLossODE
- )
-
- AnyParams: TypeAlias = Params | ParamsDict
-
- AnyDataGenerator: TypeAlias = (
- DataGeneratorODE | CubicMeshPDEStatio | CubicMeshPDENonStatio
- )
-
- AnyPINN: TypeAlias = PINN | HYPERPINN | SPINN
-
- AnyBatch: TypeAlias = ODEBatch | PDEStatioBatch | PDENonStatioBatch
- rar_operands = NewType(
- "rar_operands", tuple[AnyLoss, AnyParams, AnyDataGenerator, Int]
- )
-
- main_carry = NewType(
- "main_carry",
- tuple[
- Int,
- AnyLoss,
- OptimizationContainer,
- OptimizationExtraContainer,
- DataGeneratorContainer,
- AbstractValidationModule,
- LossContainer,
- StoredObjectContainer,
- Float[Array, "n_iter"],
- ],
- )
+ AnyBatch: TypeAlias = ODEBatch | PDENonStatioBatch | PDEStatioBatch
+ AnyLoss: TypeAlias = LossODE | LossPDEStatio | LossPDENonStatio
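These aliases exist only for static type checkers, so consumers must guard the import with `TYPE_CHECKING` and defer annotation evaluation. A small sketch of a downstream module using `AnyBatch`/`AnyLoss` that way (the `evaluate` helper is hypothetical; the loss-call convention `(value, aux) = loss(params, batch)` matches `_gradient_step` above):

```python
from __future__ import annotations  # keep annotations as strings at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # imported by type checkers only, never at runtime
    from jinns.parameters._params import Params
    from jinns.utils._types import AnyBatch, AnyLoss


def evaluate(loss: AnyLoss, params: Params, batch: AnyBatch) -> float:
    # hypothetical helper: jinns losses are callables of (params, batch)
    value, _aux = loss(params, batch)
    return float(value)
```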
jinns/utils/_utils.py CHANGED
@@ -2,20 +2,13 @@
  Implements various utility functions
  """

- from math import prod
  import warnings
  import jax
  import jax.numpy as jnp
- from jaxtyping import PyTree, Array
+ from jaxtyping import PyTree, Array, Bool

- from jinns.data._DataGenerators import (
- DataGeneratorODE,
- CubicMeshPDEStatio,
- CubicMeshPDENonStatio,
- )

-
- def _check_nan_in_pytree(pytree: PyTree) -> bool:
+ def _check_nan_in_pytree(pytree: PyTree) -> Bool[Array, " "]:
  """
  Check if there is a NaN value anywhere is the pytree

@@ -55,7 +48,7 @@ def get_grid(in_array: Array) -> Array:


  def _check_shape_and_type(
- r: Array | int, expected_shape: tuple, cause: str = "", binop: str = ""
+ r: Array | int | float, expected_shape: tuple, cause: str = "", binop: str = ""
  ) -> Array | float:
  """
  Ensures float type and correct shapes for broadcasting when performing a
@@ -90,7 +83,7 @@


  def _subtract_with_check(
- a: Array | int, b: Array | int, cause: str = ""
+ a: Array | int | float, b: Array, cause: str = ""
  ) -> Array | float:
  a = _check_shape_and_type(a, b.shape, cause=cause, binop="-")
  return a - b
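`_check_nan_in_pytree` is now annotated as returning a traced boolean scalar (`Bool[Array, " "]`) rather than a Python `bool`, matching its use as a `jax.lax.cond` predicate inside the jitted training loop. An illustrative re-implementation of such a check (not the exact jinns code):

```python
import jax
import jax.numpy as jnp
from jaxtyping import Array, Bool, PyTree

def check_nan_in_pytree(pytree: PyTree) -> Bool[Array, " "]:
    # Reduce every leaf to "does it contain a NaN?", then OR over all leaves.
    # The result is a scalar boolean array, usable as a lax.cond predicate.
    leaves_have_nan = [
        jnp.any(jnp.isnan(leaf)) for leaf in jax.tree_util.tree_leaves(pytree)
    ]
    return jnp.any(jnp.stack(leaves_have_nan))

# e.g. check_nan_in_pytree({"w": jnp.array([1.0, jnp.nan])}) -> Array(True)
```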
jinns/validation/__init__.py CHANGED
@@ -1 +1,3 @@
  from ._validation import AbstractValidationModule, ValidationLoss
+
+ __all__ = ["AbstractValidationModule", "ValidationLoss"]
jinns/validation/_validation.py CHANGED
@@ -7,19 +7,23 @@ from __future__ import (
  ) # https://docs.python.org/3/library/typing.html#constant

  import abc
- from typing import TYPE_CHECKING, Union
+ from typing import TYPE_CHECKING
  import equinox as eqx
  import jax
  import jax.numpy as jnp
- from jaxtyping import Array
+ from jaxtyping import Array, Float

- from jinns.data._DataGenerators import (
+ from jinns.data._utils import (
  append_obs_batch,
  append_param_batch,
  )

  if TYPE_CHECKING:
- from jinns.utils._types import *
+ from jinns.data._DataGeneratorParameter import DataGeneratorParameter
+ from jinns.data._DataGeneratorObservations import DataGeneratorObservations
+ from jinns.data._AbstractDataGenerator import AbstractDataGenerator
+ from jinns.parameters._params import Params
+ from jinns.loss._abstract_loss import AbstractLoss

  # Using eqx Module for the DataClass + Pytree inheritance
  # Abstract class and abstract/final pattern is used
@@ -40,8 +44,8 @@ class AbstractValidationModule(eqx.Module):

  @abc.abstractmethod
  def __call__(
- self, params: Params | ParamsDict
- ) -> tuple["AbstractValidationModule", bool, Array, bool]:
+ self, params: Params[Array]
+ ) -> tuple[AbstractValidationModule, bool, Array, Params[Array]]:
  raise NotImplementedError


@@ -52,24 +56,20 @@ class ValidationLoss(AbstractValidationModule):
  for more complicated validation strategy.
  """

- loss: AnyLoss = eqx.field(kw_only=True) # NOTE that
- # there used to be a deepcopy here which has been suppressed. 1) No need
- # because loss are now eqx.Module (immutable) so no risk of in-place
- # modification. 2) deepcopy is buggy with equinox, InitVar etc. (see issue
- # #857 on equinox github)
- validation_data: Union[AnyDataGenerator] = eqx.field(kw_only=True)
- validation_param_data: Union[DataGeneratorParameter, None] = eqx.field(
+ loss: AbstractLoss = eqx.field(kw_only=True)
+ validation_data: AbstractDataGenerator = eqx.field(kw_only=True)
+ validation_param_data: DataGeneratorParameter = eqx.field(
+ kw_only=True, default=None
+ )
+ validation_obs_data: DataGeneratorObservations | None = eqx.field(
  kw_only=True, default=None
  )
- validation_obs_data: Union[
- DataGeneratorObservations, DataGeneratorObservationsMultiPINNs, None
- ] = eqx.field(kw_only=True, default=None)
  call_every: int = eqx.field(kw_only=True, default=250) # concrete typing
  early_stopping: bool = eqx.field(
  kw_only=True, default=True
  ) # globally control if early stopping happens

- patience: Union[int] = eqx.field(kw_only=True, default=10)
+ patience: int = eqx.field(kw_only=True, default=10)
  best_val_loss: Array = eqx.field(
  converter=jnp.asarray, default_factory=lambda: jnp.array(jnp.inf), kw_only=True
  )
@@ -79,10 +79,11 @@ class ValidationLoss(AbstractValidationModule):
  )

  def __call__(
- self, params: AnyParams
- ) -> tuple["ValidationLoss", bool, float, AnyParams]:
+ self, params: Params[Array]
+ ) -> tuple[ValidationLoss, bool, Float[Array, " "], Params[Array]]:
  # do in-place mutation

+ # pylint / pyright complains below when using the self attributes see: https://github.com/patrick-kidger/equinox/issues/1013
  validation_data, val_batch = self.validation_data.get_batch()
  if self.validation_param_data is not None:
  validation_param_data, param_batch = self.validation_param_data.get_batch()
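All `ValidationLoss` fields above are keyword-only. A hedged sketch of constructing one and handing it to the solver via the `validation` argument of `solve()`; `loss` and `validation_data` are placeholders assumed to be built with the jinns API:

```python
from jinns.validation import ValidationLoss

validation = ValidationLoss(
    loss=loss,                        # loss evaluated on the held-out data below
    validation_data=validation_data,  # an AbstractDataGenerator
    call_every=250,                   # run the validation step every 250 iterations
    early_stopping=True,
    patience=10,                      # early-stopping patience
)

# then: jinns.solve(..., validation=validation)
```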
{jinns-1.2.0.dist-info → jinns-1.4.0.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: jinns
- Version: 1.2.0
+ Version: 1.4.0
  Summary: Physics Informed Neural Network with JAX
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -25,6 +25,7 @@ Requires-Dist: matplotlib
  Provides-Extra: notebook
  Requires-Dist: jupyter; extra == "notebook"
  Requires-Dist: seaborn; extra == "notebook"
+ Dynamic: license-file

  jinns
  =====
@@ -99,7 +100,7 @@ Here are the contributors guidelines:
  pip install -e .
  ```

- 3. Install pre-commit and run it.
+ 3. Install pre-commit and run it. Our pre-commit hooks consist in `ruff format` and `ruff check`. You can install `ruff` simply by `pip install ruff`. We highly recommend you to check the code type hints with `pyright` even though we currently have no rule concerning type checking in the pipeline.

  ```bash
  pip install pre-commit
@@ -112,16 +113,16 @@ pre-commit install

  Don't hesitate to contribute and get your name on the list here !

- **List of contributors:** Hugo Gangloff, Nicolas Jouvin
+ **List of contributors:** Hugo Gangloff, Nicolas Jouvin, Lucia Clarotto, Inass Soukarieh

  # Cite us

- Please consider citing our work if you found it useful to yours, using the following lines
+ Please consider citing our work if you found it useful to yours, using this [ArXiV preprint](https://arxiv.org/abs/2412.14132)
  ```
- @software{jinns2024,
- title={\texttt{jinns}: Physics-Informed Neural Networks with JAX},
- author={Gangloff, Hugo and Jouvin, Nicolas},
- url={https://gitlab.com/mia_jinns},
- year={2024}
+ @article{gangloff_jouvin2024jinns,
+ title={jinns: a JAX Library for Physics-Informed Neural Networks},
+ author={Gangloff, Hugo and Jouvin, Nicolas},
+ journal={arXiv preprint arXiv:2412.14132},
+ year={2024}
  }
  ```