jinns 0.4.1__tar.gz → 0.4.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. {jinns-0.4.1 → jinns-0.4.2}/PKG-INFO +1 -1
  2. {jinns-0.4.1 → jinns-0.4.2}/doc/source/index.rst +4 -0
  3. {jinns-0.4.1 → jinns-0.4.2}/jinns/solver/_solve.py +2 -3
  4. {jinns-0.4.1 → jinns-0.4.2}/jinns.egg-info/PKG-INFO +1 -1
  5. {jinns-0.4.1 → jinns-0.4.2}/tests/solver_tests/test_Burger_x32.py +1 -2
  6. {jinns-0.4.1 → jinns-0.4.2}/tests/solver_tests/test_Burger_x64.py +1 -2
  7. {jinns-0.4.1 → jinns-0.4.2}/tests/solver_tests/test_Fisher_x32.py +1 -2
  8. {jinns-0.4.1 → jinns-0.4.2}/tests/solver_tests/test_Fisher_x64.py +1 -2
  9. {jinns-0.4.1 → jinns-0.4.2}/tests/solver_tests/test_GLV_x32.py +1 -2
  10. {jinns-0.4.1 → jinns-0.4.2}/tests/solver_tests/test_GLV_x64.py +1 -2
  11. {jinns-0.4.1 → jinns-0.4.2}/tests/solver_tests/test_NSPipeFlow_x32.py +1 -2
  12. {jinns-0.4.1 → jinns-0.4.2}/tests/solver_tests/test_NSPipeFlow_x64.py +1 -2
  13. {jinns-0.4.1 → jinns-0.4.2}/.gitignore +0 -0
  14. {jinns-0.4.1 → jinns-0.4.2}/.gitlab-ci.yml +0 -0
  15. {jinns-0.4.1 → jinns-0.4.2}/.pre-commit-config.yaml +0 -0
  16. {jinns-0.4.1 → jinns-0.4.2}/LICENSE +0 -0
  17. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/ODE/1D_Generalized_Lotka_Volterra.ipynb +0 -0
  18. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/ODE/1D_Generalized_Lotka_Volterra_seq2seq.ipynb +0 -0
  19. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +0 -0
  20. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/PDE/1D_non_stationary_Burger.ipynb +0 -0
  21. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/PDE/1D_non_stationary_Burger_JointEstimation_Vanilla.ipynb +0 -0
  22. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/PDE/1D_non_stationary_Fisher_KPP_Bounded_Domain.ipynb +0 -0
  23. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/PDE/1D_non_stationary_OU.ipynb +0 -0
  24. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow.ipynb +0 -0
  25. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel.ipynb +0 -0
  26. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_SoftConstraints.ipynb +0 -0
  27. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/PDE/2D_non_stationary_OU_RAR.ipynb +0 -0
  28. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/PDE/2d_nonstatio_ou_standardsampling.png +0 -0
  29. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/PDE/OU_1D_nonstatio_solution_grid.npy +0 -0
  30. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/PDE/Reaction_Diffusion_2D_heterogenous_model.ipynb +0 -0
  31. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/PDE/burger_solution_grid.npy +0 -0
  32. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/PDE/imperfect_modeling_sobolev_reg.ipynb +0 -0
  33. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/Tutorials/implementing_your_own_ODE_problem.ipynb +0 -0
  34. {jinns-0.4.1 → jinns-0.4.2}/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +0 -0
  35. {jinns-0.4.1 → jinns-0.4.2}/README.md +0 -0
  36. {jinns-0.4.1 → jinns-0.4.2}/doc/Makefile +0 -0
  37. {jinns-0.4.1 → jinns-0.4.2}/doc/source/PinnSolver.rst +0 -0
  38. {jinns-0.4.1 → jinns-0.4.2}/doc/source/boundary_conditions.rst +0 -0
  39. {jinns-0.4.1 → jinns-0.4.2}/doc/source/conf.py +0 -0
  40. {jinns-0.4.1 → jinns-0.4.2}/doc/source/data.rst +0 -0
  41. {jinns-0.4.1 → jinns-0.4.2}/doc/source/dynamic_loss.rst +0 -0
  42. {jinns-0.4.1 → jinns-0.4.2}/doc/source/fokker_planck.qmd +0 -0
  43. {jinns-0.4.1 → jinns-0.4.2}/doc/source/loss.rst +0 -0
  44. {jinns-0.4.1 → jinns-0.4.2}/doc/source/loss_ode.rst +0 -0
  45. {jinns-0.4.1 → jinns-0.4.2}/doc/source/loss_pde.rst +0 -0
  46. {jinns-0.4.1 → jinns-0.4.2}/doc/source/math_pinn.qmd +0 -0
  47. {jinns-0.4.1 → jinns-0.4.2}/doc/source/operators.rst +0 -0
  48. {jinns-0.4.1 → jinns-0.4.2}/doc/source/param_estim_pinn.qmd +0 -0
  49. {jinns-0.4.1 → jinns-0.4.2}/doc/source/rar.rst +0 -0
  50. {jinns-0.4.1 → jinns-0.4.2}/doc/source/seq2seq.rst +0 -0
  51. {jinns-0.4.1 → jinns-0.4.2}/doc/source/solver.rst +0 -0
  52. {jinns-0.4.1 → jinns-0.4.2}/doc/source/utils.rst +0 -0
  53. {jinns-0.4.1 → jinns-0.4.2}/jinns/__init__.py +0 -0
  54. {jinns-0.4.1 → jinns-0.4.2}/jinns/data/_DataGenerators.py +0 -0
  55. {jinns-0.4.1 → jinns-0.4.2}/jinns/data/__init__.py +0 -0
  56. {jinns-0.4.1 → jinns-0.4.2}/jinns/data/_display.py +0 -0
  57. {jinns-0.4.1 → jinns-0.4.2}/jinns/loss/_DynamicLoss.py +0 -0
  58. {jinns-0.4.1 → jinns-0.4.2}/jinns/loss/_DynamicLossAbstract.py +0 -0
  59. {jinns-0.4.1 → jinns-0.4.2}/jinns/loss/_LossODE.py +0 -0
  60. {jinns-0.4.1 → jinns-0.4.2}/jinns/loss/_LossPDE.py +0 -0
  61. {jinns-0.4.1 → jinns-0.4.2}/jinns/loss/__init__.py +0 -0
  62. {jinns-0.4.1 → jinns-0.4.2}/jinns/loss/_boundary_conditions.py +0 -0
  63. {jinns-0.4.1 → jinns-0.4.2}/jinns/loss/_operators.py +0 -0
  64. {jinns-0.4.1 → jinns-0.4.2}/jinns/solver/__init__.py +0 -0
  65. {jinns-0.4.1 → jinns-0.4.2}/jinns/solver/_rar.py +0 -0
  66. {jinns-0.4.1 → jinns-0.4.2}/jinns/solver/_seq2seq.py +0 -0
  67. {jinns-0.4.1 → jinns-0.4.2}/jinns/utils/__init__.py +0 -0
  68. {jinns-0.4.1 → jinns-0.4.2}/jinns/utils/_utils.py +0 -0
  69. {jinns-0.4.1 → jinns-0.4.2}/jinns.egg-info/SOURCES.txt +0 -0
  70. {jinns-0.4.1 → jinns-0.4.2}/jinns.egg-info/dependency_links.txt +0 -0
  71. {jinns-0.4.1 → jinns-0.4.2}/jinns.egg-info/requires.txt +0 -0
  72. {jinns-0.4.1 → jinns-0.4.2}/jinns.egg-info/top_level.txt +0 -0
  73. {jinns-0.4.1 → jinns-0.4.2}/pyproject.toml +0 -0
  74. {jinns-0.4.1 → jinns-0.4.2}/setup.cfg +0 -0
  75. {jinns-0.4.1 → jinns-0.4.2}/tests/conftest.py +0 -0
  76. {jinns-0.4.1 → jinns-0.4.2}/tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py +0 -0
  77. {jinns-0.4.1 → jinns-0.4.2}/tests/dataGenerator_tests/test_CubicMeshPDEStatio.py +0 -0
  78. {jinns-0.4.1 → jinns-0.4.2}/tests/dataGenerator_tests/test_DataGeneratorODE.py +0 -0
  79. {jinns-0.4.1 → jinns-0.4.2}/tests/runtests.sh +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jinns
3
- Version: 0.4.1
3
+ Version: 0.4.2
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -8,6 +8,10 @@ Welcome to jinn's documentation!
8
8
 
9
9
  Changelog:
10
10
 
11
+ * v0.4.2:
12
+
13
+ - Critical bug correction concerning the manipulation of the optimizer's `opt_state` which caused weird failures in the optimization process
14
+
11
15
  * v0.4.1:
12
16
 
13
17
  - Generalize heterogeneity for the equation parameters. It can now be an arbitrary function provided by the user and thus depend on covariables. Update the corresponding notebook.
@@ -196,7 +196,7 @@ def solve(
196
196
  if carry["param_data"] is not None:
197
197
  batch = append_param_batch(batch, carry["param_data"].get_batch())
198
198
  carry["params"], carry["opt_state"] = optimizer.update(
199
- params=carry["params"], state=carry["state"], batch=batch
199
+ params=carry["params"], state=carry["opt_state"], batch=batch
200
200
  )
201
201
 
202
202
  # check if any of the parameters is NaN
@@ -260,7 +260,6 @@ def solve(
260
260
  {
261
261
  "params": init_params,
262
262
  "last_non_nan_params": init_params.copy(),
263
- "state": opt_state,
264
263
  "data": data,
265
264
  "curr_seq": curr_seq,
266
265
  "seq2seq": seq2seq,
@@ -281,7 +280,7 @@ def solve(
281
280
 
282
281
  params = res["params"]
283
282
  last_non_nan_params = res["last_non_nan_params"]
284
- opt_state = res["state"]
283
+ opt_state = res["opt_state"]
285
284
  data = res["data"]
286
285
  loss = res["loss"]
287
286
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jinns
3
- Version: 0.4.1
3
+ Version: 0.4.2
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -11,7 +11,6 @@ import jinns
11
11
  @pytest.fixture
12
12
  def train_Burger_init():
13
13
  jax.config.update("jax_enable_x64", False)
14
- print(jax.config.FLAGS.jax_enable_x64)
15
14
  print(jax.devices())
16
15
  key = random.PRNGKey(2)
17
16
  eqx_list = [
@@ -111,4 +110,4 @@ def test_initial_loss_Burger(train_Burger_init):
111
110
 
112
111
  def test_10it_Burger(train_Burger_10it):
113
112
  total_loss_val = train_Burger_10it
114
- assert jnp.round(total_loss_val, 5) == jnp.round(2.0262, 5)
113
+ assert jnp.round(total_loss_val, 5) == jnp.round(2.04182, 5)
@@ -11,7 +11,6 @@ import jinns
11
11
  @pytest.fixture
12
12
  def train_Burger_init():
13
13
  jax.config.update("jax_enable_x64", True)
14
- print(jax.config.FLAGS.jax_enable_x64)
15
14
  print(jax.devices())
16
15
  key = random.PRNGKey(2)
17
16
  eqx_list = [
@@ -110,4 +109,4 @@ def test_initial_loss_Burger(train_Burger_init):
110
109
 
111
110
  def test_10it_Burger(train_Burger_10it):
112
111
  total_loss_val = train_Burger_10it
113
- assert jnp.round(total_loss_val, 5) == jnp.round(2.77797, 5)
112
+ assert jnp.round(total_loss_val, 5) == jnp.round(2.80653, 5)
@@ -12,7 +12,6 @@ import jinns
12
12
  @pytest.fixture
13
13
  def train_Fisher_init():
14
14
  jax.config.update("jax_enable_x64", False)
15
- print(jax.config.FLAGS.jax_enable_x64)
16
15
  print(jax.devices())
17
16
  key = random.PRNGKey(2)
18
17
  eqx_list = [
@@ -142,4 +141,4 @@ def test_initial_loss_Fisher(train_Fisher_init):
142
141
 
143
142
  def test_10it_Fisher(train_Fisher_10it):
144
143
  total_loss_val = train_Fisher_10it
145
- assert jnp.round(total_loss_val, 5) == jnp.round(10.7401, 5)
144
+ assert jnp.round(total_loss_val, 5) == jnp.round(10.79058, 5)
@@ -12,7 +12,6 @@ import jinns
12
12
  @pytest.fixture
13
13
  def train_Fisher_init():
14
14
  jax.config.update("jax_enable_x64", True)
15
- print(jax.config.FLAGS.jax_enable_x64)
16
15
  print(jax.devices())
17
16
  key = random.PRNGKey(2)
18
17
  eqx_list = [
@@ -142,4 +141,4 @@ def test_initial_loss_Fisher(train_Fisher_init):
142
141
 
143
142
  def test_10it_Fisher(train_Fisher_10it):
144
143
  total_loss_val = train_Fisher_10it
145
- assert jnp.round(total_loss_val, 5) == jnp.round(10.87023, 5)
144
+ assert jnp.round(total_loss_val, 5) == jnp.round(10.88394, 5)
@@ -11,7 +11,6 @@ import jinns
11
11
  @pytest.fixture
12
12
  def train_GLV_init():
13
13
  jax.config.update("jax_enable_x64", False)
14
- print(jax.config.FLAGS.jax_enable_x64)
15
14
  print(jax.devices())
16
15
  key = random.PRNGKey(2)
17
16
  key, subkey = random.split(key)
@@ -127,4 +126,4 @@ def test_initial_loss_GLV(train_GLV_init):
127
126
 
128
127
  def test_10it_GLV(train_GLV_10it):
129
128
  total_loss_val = train_GLV_10it
130
- assert jnp.round(total_loss_val, 5) == jnp.round(4317.0625, 5)
129
+ assert jnp.round(total_loss_val, 5) == jnp.round(4318.117, 5)
@@ -11,7 +11,6 @@ import jinns
11
11
  @pytest.fixture
12
12
  def train_GLV_init():
13
13
  jax.config.update("jax_enable_x64", True)
14
- print(jax.config.FLAGS.jax_enable_x64)
15
14
  print(jax.devices())
16
15
  key = random.PRNGKey(2)
17
16
  key, subkey = random.split(key)
@@ -123,4 +122,4 @@ def test_initial_loss_GLV(train_GLV_init):
123
122
 
124
123
  def test_10it_GLV(train_GLV_10it):
125
124
  total_loss_val = train_GLV_10it
126
- assert jnp.round(total_loss_val, 5) == jnp.round(3819.72582, 5)
125
+ assert jnp.round(total_loss_val, 5) == jnp.round(3819.15887, 5)
@@ -11,7 +11,6 @@ import jinns
11
11
  @pytest.fixture
12
12
  def train_NSPipeFlow_init():
13
13
  jax.config.update("jax_enable_x64", False)
14
- print(jax.config.FLAGS.jax_enable_x64)
15
14
  print(jax.devices())
16
15
  key = random.PRNGKey(2)
17
16
 
@@ -138,4 +137,4 @@ def test_initial_loss_NSPipeFlow(train_NSPipeFlow_init):
138
137
 
139
138
  def test_10it_NSPipeFlow(train_NSPipeFlow_10it):
140
139
  total_loss_val = train_NSPipeFlow_10it
141
- assert jnp.round(total_loss_val, 5) == jnp.round(0.00531, 5)
140
+ assert jnp.round(total_loss_val, 5) == jnp.round(0.00534, 5)
@@ -11,7 +11,6 @@ import jinns
11
11
  @pytest.fixture
12
12
  def train_NSPipeFlow_init():
13
13
  jax.config.update("jax_enable_x64", True)
14
- print(jax.config.FLAGS.jax_enable_x64)
15
14
  print(jax.devices())
16
15
  key = random.PRNGKey(2)
17
16
 
@@ -138,4 +137,4 @@ def test_initial_loss_NSPipeFlow(train_NSPipeFlow_init):
138
137
 
139
138
  def test_10it_NSPipeFlow(train_NSPipeFlow_10it):
140
139
  total_loss_val = train_NSPipeFlow_10it
141
- assert jnp.round(total_loss_val, 5) == jnp.round(0.00495, 5)
140
+ assert jnp.round(total_loss_val, 5) == jnp.round(0.00504, 5)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes