jinns 0.8.1__tar.gz → 0.8.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. {jinns-0.8.1 → jinns-0.8.2}/PKG-INFO +1 -1
  2. {jinns-0.8.1 → jinns-0.8.2}/doc/source/index.rst +4 -0
  3. {jinns-0.8.1 → jinns-0.8.2}/jinns/utils/_hyperpinn.py +2 -2
  4. {jinns-0.8.1 → jinns-0.8.2}/jinns.egg-info/PKG-INFO +1 -1
  5. {jinns-0.8.1 → jinns-0.8.2}/tests/save_load_tests/test_saving_loading_hyperpinn.py +23 -2
  6. {jinns-0.8.1 → jinns-0.8.2}/tests/save_load_tests/test_saving_loading_pinn.py +21 -2
  7. {jinns-0.8.1 → jinns-0.8.2}/tests/save_load_tests/test_saving_loading_spinn.py +20 -2
  8. {jinns-0.8.1 → jinns-0.8.2}/.gitignore +0 -0
  9. {jinns-0.8.1 → jinns-0.8.2}/.gitlab-ci.yml +0 -0
  10. {jinns-0.8.1 → jinns-0.8.2}/.pre-commit-config.yaml +0 -0
  11. {jinns-0.8.1 → jinns-0.8.2}/LICENSE +0 -0
  12. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/ODE/1D_Generalized_Lotka_Volterra.ipynb +0 -0
  13. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/ODE/1D_Generalized_Lotka_Volterra_seq2seq.ipynb +0 -0
  14. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/ODE/sbinn_data/glucose.dat +0 -0
  15. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/ODE/sbinn_data/meal.dat +0 -0
  16. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +0 -0
  17. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/1D_non_stationary_Burger.ipynb +0 -0
  18. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/1D_non_stationary_Fisher_KPP_Bounded_Domain.ipynb +0 -0
  19. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/2D_Heat_inverse_problem.ipynb +0 -0
  20. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow.ipynb +0 -0
  21. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb +0 -0
  22. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_SoftConstraints.ipynb +0 -0
  23. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/2D_Poisson_inverse_problem.ipynb +0 -0
  24. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/2D_non_stationary_OU.ipynb +0 -0
  25. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/2d_nonstatio_ou_standardsampling.png +0 -0
  26. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/OU_1D_nonstatio_solution_grid.npy +0 -0
  27. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/Reaction_Diffusion_2D_heterogenous_model.ipynb +0 -0
  28. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/Reaction_Diffusion_2D_homogeneous_metamodel_hyperpinn_diffrax.ipynb +0 -0
  29. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/burger_solution_grid.npy +0 -0
  30. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/PDE/imperfect_modeling_sobolev_reg.ipynb +0 -0
  31. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +0 -0
  32. {jinns-0.8.1 → jinns-0.8.2}/Notebooks/Tutorials/load_save_model.ipynb +0 -0
  33. {jinns-0.8.1 → jinns-0.8.2}/README.md +0 -0
  34. {jinns-0.8.1 → jinns-0.8.2}/doc/Makefile +0 -0
  35. {jinns-0.8.1 → jinns-0.8.2}/doc/source/boundary_conditions.rst +0 -0
  36. {jinns-0.8.1 → jinns-0.8.2}/doc/source/conf.py +0 -0
  37. {jinns-0.8.1 → jinns-0.8.2}/doc/source/data.rst +0 -0
  38. {jinns-0.8.1 → jinns-0.8.2}/doc/source/dynamic_loss.rst +0 -0
  39. {jinns-0.8.1 → jinns-0.8.2}/doc/source/experimental.rst +0 -0
  40. {jinns-0.8.1 → jinns-0.8.2}/doc/source/fokker_planck.qmd +0 -0
  41. {jinns-0.8.1 → jinns-0.8.2}/doc/source/loss.rst +0 -0
  42. {jinns-0.8.1 → jinns-0.8.2}/doc/source/loss_ode.rst +0 -0
  43. {jinns-0.8.1 → jinns-0.8.2}/doc/source/loss_pde.rst +0 -0
  44. {jinns-0.8.1 → jinns-0.8.2}/doc/source/losses.rst +0 -0
  45. {jinns-0.8.1 → jinns-0.8.2}/doc/source/math_pinn.qmd +0 -0
  46. {jinns-0.8.1 → jinns-0.8.2}/doc/source/operators.rst +0 -0
  47. {jinns-0.8.1 → jinns-0.8.2}/doc/source/param_estim_pinn.qmd +0 -0
  48. {jinns-0.8.1 → jinns-0.8.2}/doc/source/rar.rst +0 -0
  49. {jinns-0.8.1 → jinns-0.8.2}/doc/source/seq2seq.rst +0 -0
  50. {jinns-0.8.1 → jinns-0.8.2}/doc/source/solve.rst +0 -0
  51. {jinns-0.8.1 → jinns-0.8.2}/doc/source/solver.rst +0 -0
  52. {jinns-0.8.1 → jinns-0.8.2}/doc/source/utils.rst +0 -0
  53. {jinns-0.8.1 → jinns-0.8.2}/jinns/__init__.py +0 -0
  54. {jinns-0.8.1 → jinns-0.8.2}/jinns/data/_DataGenerators.py +0 -0
  55. {jinns-0.8.1 → jinns-0.8.2}/jinns/data/__init__.py +0 -0
  56. {jinns-0.8.1 → jinns-0.8.2}/jinns/data/_display.py +0 -0
  57. {jinns-0.8.1 → jinns-0.8.2}/jinns/experimental/__init__.py +0 -0
  58. {jinns-0.8.1 → jinns-0.8.2}/jinns/experimental/_diffrax_solver.py +0 -0
  59. {jinns-0.8.1 → jinns-0.8.2}/jinns/loss/_DynamicLoss.py +0 -0
  60. {jinns-0.8.1 → jinns-0.8.2}/jinns/loss/_DynamicLossAbstract.py +0 -0
  61. {jinns-0.8.1 → jinns-0.8.2}/jinns/loss/_LossODE.py +0 -0
  62. {jinns-0.8.1 → jinns-0.8.2}/jinns/loss/_LossPDE.py +0 -0
  63. {jinns-0.8.1 → jinns-0.8.2}/jinns/loss/_Losses.py +0 -0
  64. {jinns-0.8.1 → jinns-0.8.2}/jinns/loss/__init__.py +0 -0
  65. {jinns-0.8.1 → jinns-0.8.2}/jinns/loss/_boundary_conditions.py +0 -0
  66. {jinns-0.8.1 → jinns-0.8.2}/jinns/loss/_operators.py +0 -0
  67. {jinns-0.8.1 → jinns-0.8.2}/jinns/solver/__init__.py +0 -0
  68. {jinns-0.8.1 → jinns-0.8.2}/jinns/solver/_rar.py +0 -0
  69. {jinns-0.8.1 → jinns-0.8.2}/jinns/solver/_seq2seq.py +0 -0
  70. {jinns-0.8.1 → jinns-0.8.2}/jinns/solver/_solve.py +0 -0
  71. {jinns-0.8.1 → jinns-0.8.2}/jinns/utils/__init__.py +0 -0
  72. {jinns-0.8.1 → jinns-0.8.2}/jinns/utils/_optim.py +0 -0
  73. {jinns-0.8.1 → jinns-0.8.2}/jinns/utils/_pinn.py +0 -0
  74. {jinns-0.8.1 → jinns-0.8.2}/jinns/utils/_save_load.py +0 -0
  75. {jinns-0.8.1 → jinns-0.8.2}/jinns/utils/_spinn.py +0 -0
  76. {jinns-0.8.1 → jinns-0.8.2}/jinns/utils/_utils.py +0 -0
  77. {jinns-0.8.1 → jinns-0.8.2}/jinns/utils/_utils_uspinn.py +0 -0
  78. {jinns-0.8.1 → jinns-0.8.2}/jinns.egg-info/SOURCES.txt +0 -0
  79. {jinns-0.8.1 → jinns-0.8.2}/jinns.egg-info/dependency_links.txt +0 -0
  80. {jinns-0.8.1 → jinns-0.8.2}/jinns.egg-info/requires.txt +0 -0
  81. {jinns-0.8.1 → jinns-0.8.2}/jinns.egg-info/top_level.txt +0 -0
  82. {jinns-0.8.1 → jinns-0.8.2}/pyproject.toml +0 -0
  83. {jinns-0.8.1 → jinns-0.8.2}/setup.cfg +0 -0
  84. {jinns-0.8.1 → jinns-0.8.2}/tests/conftest.py +0 -0
  85. {jinns-0.8.1 → jinns-0.8.2}/tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py +0 -0
  86. {jinns-0.8.1 → jinns-0.8.2}/tests/dataGenerator_tests/test_CubicMeshPDEStatio.py +0 -0
  87. {jinns-0.8.1 → jinns-0.8.2}/tests/dataGenerator_tests/test_DataGeneratorODE.py +0 -0
  88. {jinns-0.8.1 → jinns-0.8.2}/tests/runtests.sh +0 -0
  89. {jinns-0.8.1 → jinns-0.8.2}/tests/sharding_tests/test_Burger_x32_multiple_shardings.py +0 -0
  90. {jinns-0.8.1 → jinns-0.8.2}/tests/sharding_tests/test_imperfect_sobolev_x32_multiple_shardings.py +0 -0
  91. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests/test_Burger_x32.py +0 -0
  92. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests/test_Burger_x64.py +0 -0
  93. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests/test_Fisher_x32.py +0 -0
  94. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests/test_Fisher_x64.py +0 -0
  95. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests/test_GLV_x32.py +0 -0
  96. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests/test_GLV_x64.py +0 -0
  97. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests/test_NSPipeFlow_x32.py +0 -0
  98. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests/test_NSPipeFlow_x64.py +0 -0
  99. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests/test_OU2D_x32.py +0 -0
  100. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests/test_imperfect_sobolev_x32.py +0 -0
  101. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests_spinn/test_Burger_x32.py +0 -0
  102. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests_spinn/test_Fisher_x32.py +0 -0
  103. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests_spinn/test_NSPipeFlow_x32_spinn.py +0 -0
  104. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests_spinn/test_OU2D_x32.py +0 -0
  105. {jinns-0.8.1 → jinns-0.8.2}/tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jinns
3
- Version: 0.8.1
3
+ Version: 0.8.2
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -8,6 +8,10 @@ Welcome to jinn's documentation!
8
8
 
9
9
  Changelog:
10
10
 
11
+ * v0.8.2:
12
+
13
+ - Fix a bug: it was not possible to jit a reloaded HyperPINN model
14
+
11
15
  * v0.8.1:
12
16
 
13
17
  - New feature: `save_pinn` and `load_pinn` in `jinns.utils` for pre-trained
@@ -39,8 +39,8 @@ class HYPERPINN(PINN):
39
39
  static_hyper: eqx.Module
40
40
  hyperparams: list = eqx.field(static=True)
41
41
  hypernet_input_size: int
42
- pinn_params_sum: ArrayLike
43
- pinn_params_cumsum: ArrayLike
42
+ pinn_params_sum: ArrayLike = eqx.field(static=True)
43
+ pinn_params_cumsum: ArrayLike = eqx.field(static=True)
44
44
 
45
45
  def __init__(
46
46
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jinns
3
- Version: 0.8.1
3
+ Version: 0.8.2
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -9,7 +9,7 @@ from jinns.utils import save_pinn, load_pinn
9
9
 
10
10
 
11
11
  @pytest.fixture
12
- def save_reload():
12
+ def save_reload(tmpdir):
13
13
  jax.config.update("jax_enable_x64", False)
14
14
  key = random.PRNGKey(2)
15
15
 
@@ -73,7 +73,7 @@ def save_reload():
73
73
  }
74
74
 
75
75
  # Save
76
- filename = "./test"
76
+ filename = str(tmpdir.join("test"))
77
77
  kwargs_creation = {
78
78
  "key": subkey,
79
79
  "eqx_list": eqx_list,
@@ -107,3 +107,24 @@ def test_equality_save_reload(save_reload):
107
107
  v_u_reloaded(test_points[:, 0:1], test_points[:, 1:3], params_reloaded),
108
108
  atol=1e-3,
109
109
  )
110
+
111
+
112
+ def test_jitting_reloaded_hyperpinn(save_reload):
113
+ """
114
+ This test ensures that the reloaded hyperpinn is jit-able.
115
+ Some conversion of onp.array nodes can arise when reloading.
116
+ See this MR : https://gitlab.com/mia_jinns/jinns/-/merge_requests/32
117
+ jinns v0.8.2 uses eqx.field(static=True) to solve the problem.
118
+ This tests is here for testimony.
119
+ """
120
+
121
+ key, _, _, params_reloaded, u_reloaded = save_reload
122
+
123
+ v_u_reloaded = jax.vmap(
124
+ u_reloaded, (0, 0, {"nn_params": None, "eq_params": {"D": 0, "r": 0}})
125
+ )
126
+ v_u_reloaded_jitted = jax.jit(v_u_reloaded)
127
+
128
+ key, subkey = jax.random.split(key, 2)
129
+ test_points = jax.random.normal(subkey, shape=(10, 5))
130
+ v_u_reloaded_jitted(test_points[:, 0:1], test_points[:, 1:3], params_reloaded)
@@ -9,7 +9,7 @@ from jinns.utils import save_pinn, load_pinn
9
9
 
10
10
 
11
11
  @pytest.fixture
12
- def save_reload():
12
+ def save_reload(tmpdir):
13
13
  jax.config.update("jax_enable_x64", False)
14
14
  key = random.PRNGKey(2)
15
15
  eqx_list = [
@@ -28,7 +28,7 @@ def save_reload():
28
28
  params = {"nn_params": params, "eq_params": {}}
29
29
 
30
30
  # Save
31
- filename = "./test"
31
+ filename = str(tmpdir.join("test"))
32
32
  kwargs_creation = {
33
33
  "key": subkey,
34
34
  "eqx_list": eqx_list,
@@ -57,3 +57,22 @@ def test_equality_save_reload(save_reload):
57
57
  v_u_reloaded(test_points[:, 0:1], test_points[:, 1:], params_reloaded),
58
58
  atol=1e-3,
59
59
  )
60
+
61
+
62
+ def test_jitting_reloaded_pinn(save_reload):
63
+ """
64
+ This test ensures that the reloaded pinn is jit-able.
65
+ Some conversion of onp.array nodes can arise when reloading.
66
+ See this MR : https://gitlab.com/mia_jinns/jinns/-/merge_requests/32
67
+ jinns v0.8.2 uses eqx.field(static=True) to solve the problem.
68
+ This tests is here for testimony.
69
+ """
70
+
71
+ key, _, _, params_reloaded, u_reloaded = save_reload
72
+
73
+ key, subkey = jax.random.split(key, 2)
74
+ test_points = jax.random.normal(subkey, shape=(10, 2))
75
+ v_u_reloaded = jax.vmap(u_reloaded, (0, 0, None))
76
+ v_u_reloaded_jitted = jax.jit(v_u_reloaded)
77
+
78
+ v_u_reloaded_jitted(test_points[:, 0:1], test_points[:, 1:3], params_reloaded)
@@ -9,7 +9,7 @@ from jinns.utils import save_pinn, load_pinn
9
9
 
10
10
 
11
11
  @pytest.fixture
12
- def save_reload():
12
+ def save_reload(tmpdir):
13
13
  jax.config.update("jax_enable_x64", False)
14
14
  key = random.PRNGKey(2)
15
15
  d = 2
@@ -30,7 +30,7 @@ def save_reload():
30
30
  params = {"nn_params": params, "eq_params": {}}
31
31
 
32
32
  # Save
33
- filename = "./test"
33
+ filename = str(tmpdir.join("test"))
34
34
  kwargs_creation = {
35
35
  "key": subkey,
36
36
  "d": d,
@@ -58,3 +58,21 @@ def test_equality_save_reload(save_reload):
58
58
  u_reloaded(test_points[:, 0:1], test_points[:, 1:], params_reloaded),
59
59
  atol=1e-3,
60
60
  )
61
+
62
+
63
+ def test_jitting_reloaded_spinn(save_reload):
64
+ """
65
+ This test ensures that the reloaded spinn is jit-able.
66
+ Some conversion of onp.array nodes can arise when reloading.
67
+ See this MR : https://gitlab.com/mia_jinns/jinns/-/merge_requests/32
68
+ jinns v0.8.2 uses eqx.field(static=True) to solve the problem.
69
+ This tests is here for testimony.
70
+ """
71
+
72
+ key, _, _, params_reloaded, u_reloaded = save_reload
73
+
74
+ key, subkey = jax.random.split(key, 2)
75
+ test_points = jax.random.normal(subkey, shape=(10, 2))
76
+
77
+ u_reloaded_jitted = jax.jit(u_reloaded.__call__)
78
+ u_reloaded_jitted(test_points[:, 0:1], test_points[:, 1:], params_reloaded)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes