jinns 0.8.5__tar.gz → 0.8.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109) hide show
  1. {jinns-0.8.5 → jinns-0.8.6}/PKG-INFO +1 -1
  2. {jinns-0.8.5 → jinns-0.8.6}/doc/source/index.rst +5 -1
  3. {jinns-0.8.5 → jinns-0.8.6}/jinns/utils/_hyperpinn.py +1 -13
  4. {jinns-0.8.5 → jinns-0.8.6}/jinns/utils/_pinn.py +1 -13
  5. {jinns-0.8.5 → jinns-0.8.6}/jinns/utils/_spinn.py +1 -7
  6. {jinns-0.8.5 → jinns-0.8.6}/jinns.egg-info/PKG-INFO +1 -1
  7. {jinns-0.8.5 → jinns-0.8.6}/jinns.egg-info/SOURCES.txt +4 -1
  8. {jinns-0.8.5 → jinns-0.8.6}/tests/runtests.sh +4 -0
  9. jinns-0.8.6/tests/utils_tests/test_hyperpinns.py +138 -0
  10. jinns-0.8.6/tests/utils_tests/test_pinn.py +131 -0
  11. jinns-0.8.6/tests/utils_tests/test_spinn.py +110 -0
  12. {jinns-0.8.5 → jinns-0.8.6}/.gitignore +0 -0
  13. {jinns-0.8.5 → jinns-0.8.6}/.gitlab-ci.yml +0 -0
  14. {jinns-0.8.5 → jinns-0.8.6}/.pre-commit-config.yaml +0 -0
  15. {jinns-0.8.5 → jinns-0.8.6}/LICENSE +0 -0
  16. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/ODE/1D_Generalized_Lotka_Volterra.ipynb +0 -0
  17. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/ODE/1D_Generalized_Lotka_Volterra_seq2seq.ipynb +0 -0
  18. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/ODE/sbinn_data/glucose.dat +0 -0
  19. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/ODE/sbinn_data/meal.dat +0 -0
  20. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +0 -0
  21. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/1D_non_stationary_Burger.ipynb +0 -0
  22. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/1D_non_stationary_Fisher_KPP_Bounded_Domain.ipynb +0 -0
  23. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/2D_Heat_inverse_problem.ipynb +0 -0
  24. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow.ipynb +0 -0
  25. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb +0 -0
  26. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_SoftConstraints.ipynb +0 -0
  27. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/2D_Poisson_inverse_problem.ipynb +0 -0
  28. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/2D_non_stationary_OU.ipynb +0 -0
  29. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/2d_nonstatio_ou_standardsampling.png +0 -0
  30. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/OU_1D_nonstatio_solution_grid.npy +0 -0
  31. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/Reaction_Diffusion_2D_heterogenous_model.ipynb +0 -0
  32. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/Reaction_Diffusion_2D_homogeneous_metamodel_hyperpinn_diffrax.ipynb +0 -0
  33. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/burger_solution_grid.npy +0 -0
  34. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/PDE/imperfect_modeling_sobolev_reg.ipynb +0 -0
  35. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +0 -0
  36. {jinns-0.8.5 → jinns-0.8.6}/Notebooks/Tutorials/load_save_model.ipynb +0 -0
  37. {jinns-0.8.5 → jinns-0.8.6}/README.md +0 -0
  38. {jinns-0.8.5 → jinns-0.8.6}/doc/Makefile +0 -0
  39. {jinns-0.8.5 → jinns-0.8.6}/doc/source/boundary_conditions.rst +0 -0
  40. {jinns-0.8.5 → jinns-0.8.6}/doc/source/conf.py +0 -0
  41. {jinns-0.8.5 → jinns-0.8.6}/doc/source/data.rst +0 -0
  42. {jinns-0.8.5 → jinns-0.8.6}/doc/source/dynamic_loss.rst +0 -0
  43. {jinns-0.8.5 → jinns-0.8.6}/doc/source/experimental.rst +0 -0
  44. {jinns-0.8.5 → jinns-0.8.6}/doc/source/fokker_planck.qmd +0 -0
  45. {jinns-0.8.5 → jinns-0.8.6}/doc/source/loss.rst +0 -0
  46. {jinns-0.8.5 → jinns-0.8.6}/doc/source/loss_ode.rst +0 -0
  47. {jinns-0.8.5 → jinns-0.8.6}/doc/source/loss_pde.rst +0 -0
  48. {jinns-0.8.5 → jinns-0.8.6}/doc/source/losses.rst +0 -0
  49. {jinns-0.8.5 → jinns-0.8.6}/doc/source/math_pinn.qmd +0 -0
  50. {jinns-0.8.5 → jinns-0.8.6}/doc/source/operators.rst +0 -0
  51. {jinns-0.8.5 → jinns-0.8.6}/doc/source/param_estim_pinn.qmd +0 -0
  52. {jinns-0.8.5 → jinns-0.8.6}/doc/source/rar.rst +0 -0
  53. {jinns-0.8.5 → jinns-0.8.6}/doc/source/seq2seq.rst +0 -0
  54. {jinns-0.8.5 → jinns-0.8.6}/doc/source/solve.rst +0 -0
  55. {jinns-0.8.5 → jinns-0.8.6}/doc/source/solver.rst +0 -0
  56. {jinns-0.8.5 → jinns-0.8.6}/doc/source/utils.rst +0 -0
  57. {jinns-0.8.5 → jinns-0.8.6}/jinns/__init__.py +0 -0
  58. {jinns-0.8.5 → jinns-0.8.6}/jinns/data/_DataGenerators.py +0 -0
  59. {jinns-0.8.5 → jinns-0.8.6}/jinns/data/__init__.py +0 -0
  60. {jinns-0.8.5 → jinns-0.8.6}/jinns/data/_display.py +0 -0
  61. {jinns-0.8.5 → jinns-0.8.6}/jinns/experimental/__init__.py +0 -0
  62. {jinns-0.8.5 → jinns-0.8.6}/jinns/experimental/_diffrax_solver.py +0 -0
  63. {jinns-0.8.5 → jinns-0.8.6}/jinns/loss/_DynamicLoss.py +0 -0
  64. {jinns-0.8.5 → jinns-0.8.6}/jinns/loss/_DynamicLossAbstract.py +0 -0
  65. {jinns-0.8.5 → jinns-0.8.6}/jinns/loss/_LossODE.py +0 -0
  66. {jinns-0.8.5 → jinns-0.8.6}/jinns/loss/_LossPDE.py +0 -0
  67. {jinns-0.8.5 → jinns-0.8.6}/jinns/loss/_Losses.py +0 -0
  68. {jinns-0.8.5 → jinns-0.8.6}/jinns/loss/__init__.py +0 -0
  69. {jinns-0.8.5 → jinns-0.8.6}/jinns/loss/_boundary_conditions.py +0 -0
  70. {jinns-0.8.5 → jinns-0.8.6}/jinns/loss/_operators.py +0 -0
  71. {jinns-0.8.5 → jinns-0.8.6}/jinns/solver/__init__.py +0 -0
  72. {jinns-0.8.5 → jinns-0.8.6}/jinns/solver/_rar.py +0 -0
  73. {jinns-0.8.5 → jinns-0.8.6}/jinns/solver/_seq2seq.py +0 -0
  74. {jinns-0.8.5 → jinns-0.8.6}/jinns/solver/_solve.py +0 -0
  75. {jinns-0.8.5 → jinns-0.8.6}/jinns/utils/__init__.py +0 -0
  76. {jinns-0.8.5 → jinns-0.8.6}/jinns/utils/_optim.py +0 -0
  77. {jinns-0.8.5 → jinns-0.8.6}/jinns/utils/_save_load.py +0 -0
  78. {jinns-0.8.5 → jinns-0.8.6}/jinns/utils/_utils.py +0 -0
  79. {jinns-0.8.5 → jinns-0.8.6}/jinns/utils/_utils_uspinn.py +0 -0
  80. {jinns-0.8.5 → jinns-0.8.6}/jinns.egg-info/dependency_links.txt +0 -0
  81. {jinns-0.8.5 → jinns-0.8.6}/jinns.egg-info/requires.txt +0 -0
  82. {jinns-0.8.5 → jinns-0.8.6}/jinns.egg-info/top_level.txt +0 -0
  83. {jinns-0.8.5 → jinns-0.8.6}/pyproject.toml +0 -0
  84. {jinns-0.8.5 → jinns-0.8.6}/setup.cfg +0 -0
  85. {jinns-0.8.5 → jinns-0.8.6}/tests/conftest.py +0 -0
  86. {jinns-0.8.5 → jinns-0.8.6}/tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py +0 -0
  87. {jinns-0.8.5 → jinns-0.8.6}/tests/dataGenerator_tests/test_CubicMeshPDEStatio.py +0 -0
  88. {jinns-0.8.5 → jinns-0.8.6}/tests/dataGenerator_tests/test_DataGeneratorODE.py +0 -0
  89. {jinns-0.8.5 → jinns-0.8.6}/tests/dataGenerator_tests/test_DataGeneratorParameter.py +0 -0
  90. {jinns-0.8.5 → jinns-0.8.6}/tests/save_load_tests/test_saving_loading_hyperpinn.py +0 -0
  91. {jinns-0.8.5 → jinns-0.8.6}/tests/save_load_tests/test_saving_loading_pinn.py +0 -0
  92. {jinns-0.8.5 → jinns-0.8.6}/tests/save_load_tests/test_saving_loading_spinn.py +0 -0
  93. {jinns-0.8.5 → jinns-0.8.6}/tests/sharding_tests/test_Burger_x32_multiple_shardings.py +0 -0
  94. {jinns-0.8.5 → jinns-0.8.6}/tests/sharding_tests/test_imperfect_sobolev_x32_multiple_shardings.py +0 -0
  95. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests/test_Burger_x32.py +0 -0
  96. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests/test_Burger_x64.py +0 -0
  97. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests/test_Fisher_x32.py +0 -0
  98. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests/test_Fisher_x64.py +0 -0
  99. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests/test_GLV_x32.py +0 -0
  100. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests/test_GLV_x64.py +0 -0
  101. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests/test_NSPipeFlow_x32.py +0 -0
  102. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests/test_NSPipeFlow_x64.py +0 -0
  103. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests/test_OU2D_x32.py +0 -0
  104. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests/test_imperfect_sobolev_x32.py +0 -0
  105. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests_spinn/test_Burger_x32.py +0 -0
  106. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests_spinn/test_Fisher_x32.py +0 -0
  107. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests_spinn/test_NSPipeFlow_x32_spinn.py +0 -0
  108. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests_spinn/test_OU2D_x32.py +0 -0
  109. {jinns-0.8.5 → jinns-0.8.6}/tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jinns
3
- Version: 0.8.5
3
+ Version: 0.8.6
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -8,9 +8,13 @@ Welcome to jinn's documentation!
8
8
 
9
9
  Changelog:
10
10
 
11
+ * v0.8.6:
12
+
13
+ - Merge `[!37] <https://gitlab.com/mia_jinns/jinns/-/merge_requests/37>`_
14
+
11
15
  * v0.8.5:
12
16
 
13
- - Merge [!36](https://gitlab.com/mia_jinns/jinns/-/merge_requests/36)
17
+ - Merge `[!36] <https://gitlab.com/mia_jinns/jinns/-/merge_requests/36>`_
14
18
 
15
19
  * v0.8.4:
16
20
 
@@ -216,13 +216,7 @@ def create_HYPERPINN(
216
216
 
217
217
  Returns
218
218
  -------
219
- init_fn
220
- A function which (re-)initializes the PINN parameters with the provided
221
- jax random key
222
- apply_fn
223
- A function to apply the neural network on given inputs for given
224
- parameters. A typical call will be of the form `u(t, params)` for
225
- ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
219
+ `u`, a :class:`.HyperPINN` object which inherits from `eqx.Module` (hence callable).
226
220
 
227
221
  Raises
228
222
  ------
@@ -289,7 +283,6 @@ def create_HYPERPINN(
289
283
 
290
284
  if shared_pinn_outputs is not None:
291
285
  hyperpinns = []
292
- static = None
293
286
  for output_slice in shared_pinn_outputs:
294
287
  hyperpinn = HYPERPINN(
295
288
  mlp,
@@ -302,11 +295,6 @@ def create_HYPERPINN(
302
295
  hypernet_input_size,
303
296
  output_slice,
304
297
  )
305
- # all the pinns are in fact the same so we share the same static
306
- if static is None:
307
- static = hyperpinn.static
308
- else:
309
- hyperpinn.static = static
310
298
  hyperpinns.append(hyperpinn)
311
299
  return hyperpinns
312
300
  hyperpinn = HYPERPINN(
@@ -200,13 +200,7 @@ def create_PINN(
200
200
 
201
201
  Returns
202
202
  -------
203
- init_fn
204
- A function which (re-)initializes the PINN parameters with the provided
205
- jax random key
206
- apply_fn
207
- A function to apply the neural network on given inputs for given
208
- parameters. A typical call will be of the form `u(t, params)` for
209
- ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
203
+ `u`, a :class:`.PINN` object which inherits from `eqx.Module` (hence callable). This comes with a bound method :func:`u.init_params() <PINN.init_params>`. When `shared_pinn_outputs` is not None, a list of :class:`.PINN` with the same structure is returned, differing only by their final slicing of the network output.
210
204
 
211
205
  Raises
212
206
  ------
@@ -253,7 +247,6 @@ def create_PINN(
253
247
 
254
248
  if shared_pinn_outputs is not None:
255
249
  pinns = []
256
- static = None
257
250
  for output_slice in shared_pinn_outputs:
258
251
  pinn = PINN(
259
252
  mlp,
@@ -263,11 +256,6 @@ def create_PINN(
263
256
  output_transform,
264
257
  output_slice,
265
258
  )
266
- # all the pinns are in fact the same so we share the same static
267
- if static is None:
268
- static = pinn.static
269
- else:
270
- pinn.static = static
271
259
  pinns.append(pinn)
272
260
  return pinns
273
261
  pinn = PINN(mlp, slice_solution, eq_type, input_transform, output_transform, None)
@@ -194,13 +194,7 @@ def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
194
194
 
195
195
  Returns
196
196
  -------
197
- init_fn
198
- A function which (re-)initializes the SPINN parameters with the provided
199
- jax random key
200
- apply_fn
201
- A function to apply the neural network on given inputs for given
202
- parameters. A typical call will be of the form `u(t, params)` for
203
- ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
197
+ `u`, a :class:`.SPINN` object which inherits from `eqx.Module` (hence callable).
204
198
 
205
199
  Raises
206
200
  ------
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jinns
3
- Version: 0.8.5
3
+ Version: 0.8.6
4
4
  Summary: Physics Informed Neural Network with JAX
5
5
  Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
6
6
  Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
@@ -102,4 +102,7 @@ tests/solver_tests_spinn/test_Burger_x32.py
102
102
  tests/solver_tests_spinn/test_Fisher_x32.py
103
103
  tests/solver_tests_spinn/test_NSPipeFlow_x32_spinn.py
104
104
  tests/solver_tests_spinn/test_OU2D_x32.py
105
- tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64.py
105
+ tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64.py
106
+ tests/utils_tests/test_hyperpinns.py
107
+ tests/utils_tests/test_pinn.py
108
+ tests/utils_tests/test_spinn.py
@@ -16,4 +16,8 @@ if [ $? -ne 0 ]; then
16
16
  exit $?
17
17
  fi
18
18
  pytest solver_tests_spinn/*
# Capture pytest's status immediately. Inside the `then` branch, `$?` holds
# the status of the `[` test (always 0 there), so `exit $?` used to report
# success even when the solver_tests_spinn run had failed.
status=$?
if [ "$status" -ne 0 ]; then
    exit "$status"
fi
pytest utils_tests/*
exit $?
@@ -0,0 +1,138 @@
1
+ """
2
+ Test script for the custom HYPERPINN eqx.Module
3
+ """
4
+
5
+ import pytest
6
+ import jax
7
+ import jax.random as random
8
+ import jax.numpy as jnp
9
+ import equinox as eqx
10
+
11
+ import jinns
12
+ from jinns.utils import create_PINN
13
+ import jinns.utils
14
+
15
# Module-wide PRNG subkey shared by every fixture below.
key, subkey = random.split(random.PRNGKey(2))

# Input dimension used by the PDE fixtures, and the number of
# hyper-parameters given to the hypernetwork.
d = 5
n_param = 42
hyperparams = ["param {}".format(i) for i in range(n_param)]

# Hidden part of the PINN; each fixture prepends its own input layer.
EQX_LIST = [
    [jax.nn.swish],
    [eqx.nn.Linear, 16, 16],
    [jax.nn.swish],
    [eqx.nn.Linear, 16, 16],
]

# Hypernetwork layers; the first layer takes one entry per hyper-parameter.
eqx_list_hyper = [
    [eqx.nn.Linear, n_param, 32],
    [jax.nn.tanh],
    [eqx.nn.Linear, 32, 16],
    [jax.nn.tanh],
    # 1000 is only a placeholder width: the original authors note it is
    # filled in with the correct value by create_HYPERPINN.
    [eqx.nn.Linear, 16, 1000],
]
40
+
41
+
42
@pytest.fixture
def create_pinn_ode():
    """Return a HYPERPINN for an ODE: one scalar input and no spatial dim."""
    input_layer = [eqx.nn.Linear, 1, 16]
    return jinns.utils.create_HYPERPINN(
        subkey,
        [input_layer, *EQX_LIST],
        "ODE",
        hyperparams,
        n_param,
        0,
        eqx_list_hyper,
    )
49
+
50
+
51
@pytest.fixture
def create_pinn_statio():
    """Return a HYPERPINN for a stationary PDE with a ``d``-wide input."""
    layers = [[eqx.nn.Linear, d, 16], *EQX_LIST]
    return jinns.utils.create_HYPERPINN(
        subkey, layers, "statio_PDE", hyperparams, n_param, d, eqx_list_hyper
    )
58
+
59
+
60
@pytest.fixture
def create_pinn_nonstatio():
    """Return a HYPERPINN for a non-stationary PDE built with width ``d + 1``."""
    width = d + 1
    layers = [[eqx.nn.Linear, width, 16], *EQX_LIST]
    return jinns.utils.create_HYPERPINN(
        subkey, layers, "nonstatio_PDE", hyperparams, n_param, width, eqx_list_hyper
    )
68
+
69
+
70
@pytest.fixture
def create_pinn_nonstatio_shared_output():
    """Return two HYPERPINNs built from one network with different output slices.

    Returns
    -------
    tuple
        ``(u1, u2, shared_pinn_output)``, where ``shared_pinn_output`` is the
        tuple of slices passed to ``create_HYPERPINN``.
    """
    # One slice per returned object: u1 -> outputs [:2], u2 -> output 2.
    shared_pinn_output = (jnp.s_[:2], jnp.s_[2])

    # NOTE(review): the plain nonstatio fixture uses an input width of d + 1
    # for the same eq_type; confirm which width is intended.
    layers = [[eqx.nn.Linear, d, 16], *EQX_LIST]
    u1, u2 = jinns.utils.create_HYPERPINN(
        subkey,
        layers,
        "nonstatio_PDE",
        hyperparams,
        n_param,
        d,
        eqx_list_hyper,
        shared_pinn_outputs=shared_pinn_output,
    )
    return u1, u2, shared_pinn_output
91
+
92
+
93
def test_ode_pinn_struct(create_pinn_ode):
    """The ODE HYPERPINN has the right type, eq_type and default slicing."""
    net = create_pinn_ode
    assert isinstance(net, jinns.utils._hyperpinn.HYPERPINN)
    assert net.eq_type == "ODE"
    assert net.output_slice is None
    assert isinstance(net.slice_solution, slice)
    # Smoke check: parameter initialisation must not raise.
    net.init_params()
101
+
102
+
103
def test_statio_pinn_struct(create_pinn_statio):
    """The stationary HYPERPINN has the right eq_type, type and slicing."""
    net = create_pinn_statio
    assert net.eq_type == "statio_PDE"
    assert isinstance(net, jinns.utils._hyperpinn.HYPERPINN)
    assert net.output_slice is None
    assert isinstance(net.slice_solution, slice)
    # Smoke check: parameter initialisation must not raise.
    net.init_params()
111
+
112
+
113
def test_nonstatio_pinn_struct(create_pinn_nonstatio):
    """The non-stationary HYPERPINN has the right eq_type, type and slicing."""
    net = create_pinn_nonstatio
    assert net.eq_type == "nonstatio_PDE"
    assert isinstance(net, jinns.utils._hyperpinn.HYPERPINN)
    assert net.output_slice is None
    assert isinstance(net.slice_solution, slice)
    # Smoke check: parameter initialisation must not raise.
    net.init_params()
121
+
122
+
123
def test_nonstatio_pinn_shared_output(create_pinn_nonstatio_shared_output):
    """Shared-output HYPERPINNs keep their own slice but identical parameters."""
    u1, u2, slices = create_pinn_nonstatio_shared_output
    for net, expected_slice in zip((u1, u2), slices):
        assert net.eq_type == "nonstatio_PDE"
        assert net.output_slice == expected_slice
        assert isinstance(net.slice_solution, slice)

    # Both objects are built from the same network, so their initial
    # parameters must be the same PyTree.
    assert eqx.tree_equal(u1.init_params(), u2.init_params())
@@ -0,0 +1,131 @@
1
+ """
2
+ Test script for custom PINN eqx.Module
3
+ """
4
+
5
+ import pytest
6
+ import jax
7
+ import jax.random as random
8
+ import jax.numpy as jnp
9
+ import equinox as eqx
10
+
11
+ import jinns
12
+ from jinns.utils import create_PINN
13
+
14
+
15
d: int = 5  # input dimension shared by the PDE fixtures below
16
+
17
+
18
@pytest.fixture
def create_pinn_ode():
    """Return an ODE PINN: 1 input -> 128 tanh units -> 1 output."""
    subkey = random.split(random.PRNGKey(2))[1]
    layers = [
        [eqx.nn.Linear, 1, 128],
        [jax.nn.tanh],
        [eqx.nn.Linear, 128, 1],
    ]
    u_ode = create_PINN(subkey, layers, "ODE")
    return u_ode
30
+
31
+
32
@pytest.fixture
def create_pinn_statio():
    """Return a stationary-PDE PINN: d inputs -> 128 tanh units -> 1 output."""
    subkey = random.split(random.PRNGKey(2))[1]
    layers = [
        [eqx.nn.Linear, d, 128],
        [jax.nn.tanh],
        [eqx.nn.Linear, 128, 1],
    ]
    return create_PINN(subkey, layers, "statio_PDE", d)
44
+
45
+
46
@pytest.fixture
def create_pinn_nonstatio():
    """Return a non-stationary-PDE PINN: d inputs -> 128 tanh units -> 1 output."""
    subkey = random.split(random.PRNGKey(2))[1]
    layers = [
        [eqx.nn.Linear, d, 128],
        [jax.nn.tanh],
        [eqx.nn.Linear, 128, 1],
    ]
    return create_PINN(subkey, layers, "nonstatio_PDE", d)
58
+
59
+
60
@pytest.fixture
def create_pinn_nonstatio_shared_output():
    """Return two non-stationary PINNs built from one MLP with different slices.

    Returns
    -------
    tuple
        ``(u1, u2, shared_pinn_output)``, where ``shared_pinn_output`` is the
        tuple of output slices passed to ``create_PINN``.
    """
    key = random.PRNGKey(2)
    key, subkey = random.split(key)
    # u1 is bound to outputs [:2] and u2 to output 2 of the shared network.
    shared_pinn_output = (
        jnp.s_[:2],
        jnp.s_[2],
    )
    # The last layer needs 3 outputs so that both slices select existing
    # units. With a width of 1, index 2 would be out of bounds, and JAX would
    # silently clamp it instead of raising.
    eqx_list = [
        [eqx.nn.Linear, d, 128],
        [jax.nn.tanh],
        [eqx.nn.Linear, 128, 3],
    ]
    u1, u2 = jinns.utils.create_PINN(
        subkey,
        eqx_list,
        "nonstatio_PDE",
        d,
        shared_pinn_outputs=shared_pinn_output,
    )
    return u1, u2, shared_pinn_output
83
+
84
+
85
def test_ode_pinn_struct(create_pinn_ode):
    """The ODE PINN has the right eq_type, type and default slicing."""
    net = create_pinn_ode
    assert net.eq_type == "ODE"
    assert isinstance(net, jinns.utils._pinn.PINN)
    assert net.output_slice is None
    assert isinstance(net.slice_solution, slice)
    # Smoke check: parameter initialisation must not raise.
    net.init_params()
93
+
94
+
95
def test_statio_pinn_struct(create_pinn_statio):
    """The stationary PINN has the right eq_type, type and default slicing."""
    net = create_pinn_statio
    assert net.eq_type == "statio_PDE"
    assert isinstance(net, jinns.utils._pinn.PINN)
    assert net.output_slice is None
    assert isinstance(net.slice_solution, slice)
    # Smoke check: parameter initialisation must not raise.
    net.init_params()
104
+
105
+
106
def test_nonstatio_pinn_struct(create_pinn_nonstatio):
    """The non-stationary PINN has the right eq_type, type and default slicing."""
    net = create_pinn_nonstatio
    assert net.eq_type == "nonstatio_PDE"
    assert isinstance(net, jinns.utils._pinn.PINN)
    assert net.output_slice is None
    assert isinstance(net.slice_solution, slice)
    # Smoke check: parameter initialisation must not raise.
    net.init_params()
114
+
115
+
116
def test_nonstatio_pinn_shared_output(create_pinn_nonstatio_shared_output):
    """Shared-output PINNs keep their own slice but identical parameters."""
    u1, u2, slices = create_pinn_nonstatio_shared_output
    for net, expected_slice in zip((u1, u2), slices):
        assert net.eq_type == "nonstatio_PDE"
        assert net.output_slice == expected_slice
        assert isinstance(net.slice_solution, slice)

    # Both objects wrap the same network, so their initial parameters must
    # be the same PyTree.
    assert eqx.tree_equal(u1.init_params(), u2.init_params())
@@ -0,0 +1,110 @@
1
+ """
2
+ Test script for the custom SPINN eqx.Module
3
+ """
4
+
5
+ import pytest
6
+ import jax
7
+ import jax.random as random
8
+ import jax.numpy as jnp
9
+ import equinox as eqx
10
+
11
+ import jinns
12
+ from jinns.utils import create_SPINN
13
+ import jinns.utils
14
+
15
+
16
d = 5
r = 100  # embedding dimension
m = 1  # output dimension

# Scalar-input MLP shared by all fixtures. Its final width must be r * m,
# which test_raising_error_init_SPINN checks is enforced.
eqx_list = (
    [[eqx.nn.Linear, 1, 128], [jax.nn.tanh]]
    + [[eqx.nn.Linear, 128, 128], [jax.nn.tanh]]
    + [[eqx.nn.Linear, 128, 128], [jax.nn.tanh]]
    + [[eqx.nn.Linear, 128, r * m]]
)
28
+
29
+
30
def _assert_attr_equal(u):
    """Assert that ``u`` carries the module-level sizes ``m`` and ``r``."""
    for attr_name, expected in (("m", m), ("r", r)):
        assert getattr(u, attr_name) == expected
33
+
34
+
35
@pytest.fixture
def create_SPINN_ode():
    """Return a SPINN for an ODE, i.e. with a single input dimension."""
    subkey = random.split(random.PRNGKey(2))[1]
    u_ode = create_SPINN(subkey, 1, r, eqx_list, "ODE", m)
    return u_ode
42
+
43
+
44
@pytest.fixture
def create_SPINN_statio():
    """Return a SPINN for a stationary PDE in ``d`` dimensions."""
    subkey = random.split(random.PRNGKey(2))[1]
    return create_SPINN(subkey, d, r, eqx_list, "statio_PDE", m)
51
+
52
+
53
@pytest.fixture
def create_SPINN_nonstatio():
    """Return a SPINN for a non-stationary PDE in ``d`` dimensions."""
    subkey = random.split(random.PRNGKey(2))[1]
    return create_SPINN(subkey, d, r, eqx_list, "nonstatio_PDE", m)
60
+
61
+
62
def test_ode_pinn_struct(create_SPINN_ode):
    """The ODE SPINN has the right eq_type, type, dimension and sizes."""
    net = create_SPINN_ode
    assert net.eq_type == "ODE"
    assert isinstance(net, jinns.utils._spinn.SPINN)
    assert net.d == 1
    _assert_attr_equal(net)
    # Smoke check: parameter initialisation must not raise.
    net.init_params()
70
+
71
+
72
def test_statio_pinn_struct(create_SPINN_statio):
    """The stationary SPINN has the right eq_type, type, dimension and sizes."""
    net = create_SPINN_statio
    assert net.eq_type == "statio_PDE"
    assert isinstance(net, jinns.utils._spinn.SPINN)
    assert net.d == d
    _assert_attr_equal(net)
    # Smoke check: parameter initialisation must not raise.
    net.init_params()
81
+
82
+
83
def test_nonstatio_pinn_struct(create_SPINN_nonstatio):
    """The non-stationary SPINN has the right eq_type, type, dimension and sizes."""
    net = create_SPINN_nonstatio
    assert net.eq_type == "nonstatio_PDE"
    assert isinstance(net, jinns.utils._spinn.SPINN)
    # For a non-stationary SPINN the user includes `t` in `d`.
    assert net.d == d
    _assert_attr_equal(net)
    # Smoke check: parameter initialisation must not raise.
    net.init_params()
91
+
92
+
93
def test_raising_error_init_SPINN():
    """``create_SPINN`` rejects a wrong output width and too many dimensions.

    Each invalid configuration is checked on its own, so that the ValueError
    comes from the check under test and not from another one.
    """
    # Case 1: the last layer must output exactly r * m features.
    wrong_eqx_list = [
        [eqx.nn.Linear, 1, 128],
        [jax.nn.tanh],
        [eqx.nn.Linear, 128, r * m + 1],  # output_dim != r*m
    ]
    with pytest.raises(ValueError):
        create_SPINN(random.PRNGKey(1), d, r, wrong_eqx_list, "nonstatio_PDE", m)

    # Case 2: d > 24 must be rejected. Use the valid module-level `eqx_list`
    # so the output-width check cannot be what raises. Use d = 25, because
    # d = 24 does not exceed the limit.
    with pytest.raises(ValueError):
        create_SPINN(random.PRNGKey(1), 25, r, eqx_list, "nonstatio_PDE", m)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes