jinns 1.2.0__tar.gz → 1.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. {jinns-1.2.0 → jinns-1.3.0}/.gitlab-ci.yml +10 -6
  2. {jinns-1.2.0 → jinns-1.3.0}/.pre-commit-config.yaml +2 -2
  3. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/ODE/1D_Generalized_Lotka_Volterra.ipynb +11 -11
  4. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/ODE/linear_fo_equation.ipynb +11 -11
  5. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +4 -5
  6. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/1D_non_stationary_Burgers.ipynb +2 -2
  7. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/1D_non_stationary_Fisher_KPP_Bounded_Domain.ipynb +2 -2
  8. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/2D_Heat_inverse_problem.ipynb +3 -3
  9. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow.ipynb +2 -2
  10. jinns-1.3.0/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb +1402 -0
  11. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_SoftConstraints.ipynb +4 -4
  12. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/2D_Poisson_inverse_problem.ipynb +1 -1
  13. jinns-1.3.0/Notebooks/PDE/2D_non_stationary_OU.ipynb +1059 -0
  14. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/Reaction_Diffusion_2D_heterogenous_model.ipynb +2 -2
  15. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/Reaction_Diffusion_2D_homogeneous_metamodel_hyperpinn_diffrax.ipynb +1 -1
  16. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/Tutorials/1D_non_stationary_Burgers_JointEstimation_Vanilla.ipynb +1 -1
  17. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +3 -3
  18. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/Tutorials/introducing_validation_loss.ipynb +2 -2
  19. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/Tutorials/load_save_model.ipynb +2 -2
  20. {jinns-1.2.0 → jinns-1.3.0}/PKG-INFO +9 -9
  21. {jinns-1.2.0 → jinns-1.3.0}/README.md +7 -7
  22. jinns-1.3.0/docs/api/pinn/hyperpinn.md +3 -0
  23. jinns-1.3.0/docs/api/pinn/pinn.md +7 -0
  24. jinns-1.3.0/docs/api/pinn/ppinn.md +3 -0
  25. jinns-1.3.0/docs/api/pinn/save_load.md +5 -0
  26. jinns-1.3.0/docs/api/pinn/spinn.md +7 -0
  27. {jinns-1.2.0 → jinns-1.3.0}/docs/changelog.md +8 -0
  28. {jinns-1.2.0 → jinns-1.3.0}/docs/index.md +12 -4
  29. {jinns-1.2.0 → jinns-1.3.0}/jinns/data/_DataGenerators.py +2 -2
  30. {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_DynamicLoss.py +2 -2
  31. {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_LossODE.py +1 -1
  32. {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_LossPDE.py +75 -38
  33. {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_boundary_conditions.py +2 -2
  34. {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_loss_utils.py +21 -15
  35. {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_operators.py +0 -2
  36. jinns-1.3.0/jinns/nn/__init__.py +7 -0
  37. jinns-1.3.0/jinns/nn/_hyperpinn.py +397 -0
  38. jinns-1.3.0/jinns/nn/_mlp.py +192 -0
  39. jinns-1.3.0/jinns/nn/_pinn.py +190 -0
  40. jinns-1.3.0/jinns/nn/_ppinn.py +203 -0
  41. {jinns-1.2.0/jinns/utils → jinns-1.3.0/jinns/nn}/_save_load.py +39 -23
  42. jinns-1.3.0/jinns/nn/_spinn.py +106 -0
  43. jinns-1.3.0/jinns/nn/_spinn_mlp.py +196 -0
  44. {jinns-1.2.0 → jinns-1.3.0}/jinns/plot/_plot.py +3 -3
  45. {jinns-1.2.0 → jinns-1.3.0}/jinns/solver/_rar.py +3 -3
  46. {jinns-1.2.0 → jinns-1.3.0}/jinns/solver/_solve.py +23 -9
  47. jinns-1.3.0/jinns/utils/__init__.py +1 -0
  48. {jinns-1.2.0 → jinns-1.3.0}/jinns/utils/_types.py +4 -4
  49. {jinns-1.2.0 → jinns-1.3.0}/jinns.egg-info/PKG-INFO +9 -9
  50. {jinns-1.2.0 → jinns-1.3.0}/jinns.egg-info/SOURCES.txt +18 -8
  51. {jinns-1.2.0 → jinns-1.3.0}/mkdocs.yml +2 -10
  52. jinns-1.3.0/tests/loss_tests/test_lossPDEstatio.py +138 -0
  53. jinns-1.3.0/tests/loss_tests/test_norm_loss.py +92 -0
  54. jinns-1.3.0/tests/nn_tests/test_hyperpinns.py +109 -0
  55. jinns-1.3.0/tests/nn_tests/test_mlp.py +104 -0
  56. jinns-1.3.0/tests/nn_tests/test_pinn.py +84 -0
  57. jinns-1.3.0/tests/nn_tests/test_ppinn_mlp.py +107 -0
  58. jinns-1.3.0/tests/nn_tests/test_smlp.py +72 -0
  59. {jinns-1.2.0/tests/utils_tests → jinns-1.3.0/tests/nn_tests}/test_spinn.py +16 -10
  60. {jinns-1.2.0 → jinns-1.3.0}/tests/operator_tests/test_divergence_fwd.py +2 -2
  61. {jinns-1.2.0 → jinns-1.3.0}/tests/operator_tests/test_divergence_rev.py +4 -2
  62. {jinns-1.2.0 → jinns-1.3.0}/tests/operator_tests/test_laplacian_fwd.py +2 -2
  63. {jinns-1.2.0 → jinns-1.3.0}/tests/operator_tests/test_laplacian_rev.py +4 -2
  64. {jinns-1.2.0 → jinns-1.3.0}/tests/operator_tests/test_vectorial_laplacian_fwd.py +2 -2
  65. {jinns-1.2.0 → jinns-1.3.0}/tests/operator_tests/test_vectorial_laplacian_rev.py +4 -2
  66. {jinns-1.2.0 → jinns-1.3.0}/tests/parameters_tests/test_DerivativeKeysODE.py +1 -1
  67. {jinns-1.2.0 → jinns-1.3.0}/tests/plot_tests/test_plot1D.py +1 -1
  68. {jinns-1.2.0 → jinns-1.3.0}/tests/plot_tests/test_plot2D.py +1 -1
  69. {jinns-1.2.0 → jinns-1.3.0}/tests/save_load_tests/test_saving_loading_hyperpinn.py +10 -20
  70. {jinns-1.2.0 → jinns-1.3.0}/tests/save_load_tests/test_saving_loading_pinn.py +14 -12
  71. {jinns-1.2.0 → jinns-1.3.0}/tests/save_load_tests/test_saving_loading_spinn.py +3 -3
  72. {jinns-1.2.0 → jinns-1.3.0}/tests/sharding_tests/test_Burgers_x32_multiple_shardings.py +6 -2
  73. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_Burgers_x32.py +5 -3
  74. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_Burgers_x64.py +5 -3
  75. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_Fisher_x32.py +5 -3
  76. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_Fisher_x64.py +5 -3
  77. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_GLV_x32.py +8 -4
  78. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_GLV_x64.py +8 -4
  79. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_NSPipeFlow_x32.py +10 -4
  80. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_NSPipeFlow_x64.py +10 -4
  81. jinns-1.3.0/tests/solver_tests/test_OU1D_statio_x32.py +134 -0
  82. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_OU2D_x32.py +9 -8
  83. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_nan_params_catch.py +3 -1
  84. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_parameter_tracker.py +3 -1
  85. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_rar_algorithm.py +18 -15
  86. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests_hyperpinn/test_NSPipeFlow_x32_hyperpinn.py +23 -12
  87. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests_spinn/test_Burgers_x32_spinn.py +3 -3
  88. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests_spinn/test_Fisher_x32_spinn.py +3 -3
  89. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests_spinn/test_NSPipeFlow_x32_spinn.py +2 -2
  90. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests_spinn/test_OU2D_x32_spinn.py +8 -6
  91. {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64_spinn.py +2 -2
  92. {jinns-1.2.0 → jinns-1.3.0}/tests/utils_tests/test_solver_utils.py +1 -1
  93. {jinns-1.2.0 → jinns-1.3.0}/tests/validation_tests/test_vanilla_validation.py +1 -1
  94. jinns-1.2.0/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb +0 -1377
  95. jinns-1.2.0/Notebooks/PDE/2D_non_stationary_OU.ipynb +0 -1083
  96. jinns-1.2.0/docs/api/pinn/hyperpinn.md +0 -5
  97. jinns-1.2.0/docs/api/pinn/pinn.md +0 -5
  98. jinns-1.2.0/docs/api/pinn/save_load.md +0 -5
  99. jinns-1.2.0/docs/api/pinn/spinn.md +0 -5
  100. jinns-1.2.0/jinns/utils/__init__.py +0 -6
  101. jinns-1.2.0/jinns/utils/_hyperpinn.py +0 -420
  102. jinns-1.2.0/jinns/utils/_pinn.py +0 -324
  103. jinns-1.2.0/jinns/utils/_ppinn.py +0 -227
  104. jinns-1.2.0/jinns/utils/_spinn.py +0 -249
  105. jinns-1.2.0/tests/utils_tests/test_hyperpinns.py +0 -132
  106. jinns-1.2.0/tests/utils_tests/test_pinn.py +0 -125
  107. {jinns-1.2.0 → jinns-1.3.0}/.gitignore +0 -0
  108. {jinns-1.2.0 → jinns-1.3.0}/AUTHORS +0 -0
  109. {jinns-1.2.0 → jinns-1.3.0}/LICENSE +0 -0
  110. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/ODE/sbinn_data/glucose.dat +0 -0
  111. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/ODE/sbinn_data/meal.dat +0 -0
  112. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/2d_nonstatio_ou_standardsampling.png +0 -0
  113. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/OU_1D_nonstatio_solution_grid.npy +0 -0
  114. {jinns-1.2.0 → jinns-1.3.0}/Notebooks/Tutorials/burger_solution_grid.npy +0 -0
  115. {jinns-1.2.0 → jinns-1.3.0}/codemeta.json +0 -0
  116. {jinns-1.2.0 → jinns-1.3.0}/docs/README.md +0 -0
  117. {jinns-1.2.0 → jinns-1.3.0}/docs/_static/custom_css.css +0 -0
  118. {jinns-1.2.0 → jinns-1.3.0}/docs/_static/favicon.png +0 -0
  119. {jinns-1.2.0 → jinns-1.3.0}/docs/api/advanced/derivative_keys.md +0 -0
  120. {jinns-1.2.0 → jinns-1.3.0}/docs/api/advanced/differential_operators.md +0 -0
  121. {jinns-1.2.0 → jinns-1.3.0}/docs/api/datagenerators/datagenerators_core.md +0 -0
  122. {jinns-1.2.0 → jinns-1.3.0}/docs/api/datagenerators/datagenerators_other.md +0 -0
  123. {jinns-1.2.0 → jinns-1.3.0}/docs/api/loss/dynamic_loss.md +0 -0
  124. {jinns-1.2.0 → jinns-1.3.0}/docs/api/loss/loss_xde.md +0 -0
  125. {jinns-1.2.0 → jinns-1.3.0}/docs/api/loss/systems_of_xde.md +0 -0
  126. {jinns-1.2.0 → jinns-1.3.0}/docs/api/plot.md +0 -0
  127. {jinns-1.2.0 → jinns-1.3.0}/docs/api/solver.md +0 -0
  128. {jinns-1.2.0 → jinns-1.3.0}/docs/doc_requirements.txt +0 -0
  129. {jinns-1.2.0 → jinns-1.3.0}/docs/javascripts/katex.js +0 -0
  130. {jinns-1.2.0 → jinns-1.3.0}/docs/maths/fokker_planck.md +0 -0
  131. {jinns-1.2.0 → jinns-1.3.0}/docs/maths/introduction_to_pinns.md +0 -0
  132. {jinns-1.2.0 → jinns-1.3.0}/img/jinns-diagram.png +0 -0
  133. {jinns-1.2.0 → jinns-1.3.0}/jinns/__init__.py +0 -0
  134. {jinns-1.2.0 → jinns-1.3.0}/jinns/data/_Batchs.py +0 -0
  135. {jinns-1.2.0 → jinns-1.3.0}/jinns/data/__init__.py +0 -0
  136. {jinns-1.2.0 → jinns-1.3.0}/jinns/experimental/__init__.py +0 -0
  137. {jinns-1.2.0 → jinns-1.3.0}/jinns/experimental/_diffrax_solver.py +0 -0
  138. {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_DynamicLossAbstract.py +0 -0
  139. {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/__init__.py +0 -0
  140. {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_loss_weights.py +0 -0
  141. {jinns-1.2.0 → jinns-1.3.0}/jinns/parameters/__init__.py +0 -0
  142. {jinns-1.2.0 → jinns-1.3.0}/jinns/parameters/_derivative_keys.py +0 -0
  143. {jinns-1.2.0 → jinns-1.3.0}/jinns/parameters/_params.py +0 -0
  144. {jinns-1.2.0 → jinns-1.3.0}/jinns/plot/__init__.py +0 -0
  145. {jinns-1.2.0 → jinns-1.3.0}/jinns/solver/__init__.py +0 -0
  146. {jinns-1.2.0 → jinns-1.3.0}/jinns/solver/_utils.py +0 -0
  147. {jinns-1.2.0 → jinns-1.3.0}/jinns/utils/_containers.py +0 -0
  148. {jinns-1.2.0 → jinns-1.3.0}/jinns/utils/_utils.py +0 -0
  149. {jinns-1.2.0 → jinns-1.3.0}/jinns/validation/__init__.py +0 -0
  150. {jinns-1.2.0 → jinns-1.3.0}/jinns/validation/_validation.py +0 -0
  151. {jinns-1.2.0 → jinns-1.3.0}/jinns.egg-info/dependency_links.txt +0 -0
  152. {jinns-1.2.0 → jinns-1.3.0}/jinns.egg-info/requires.txt +0 -0
  153. {jinns-1.2.0 → jinns-1.3.0}/jinns.egg-info/top_level.txt +0 -0
  154. {jinns-1.2.0 → jinns-1.3.0}/pyproject.toml +0 -0
  155. {jinns-1.2.0 → jinns-1.3.0}/setup.cfg +0 -0
  156. {jinns-1.2.0 → jinns-1.3.0}/tests/conftest.py +0 -0
  157. {jinns-1.2.0 → jinns-1.3.0}/tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py +0 -0
  158. {jinns-1.2.0 → jinns-1.3.0}/tests/dataGenerator_tests/test_CubicMeshPDEStatio.py +0 -0
  159. {jinns-1.2.0 → jinns-1.3.0}/tests/dataGenerator_tests/test_DataGeneratorODE.py +0 -0
  160. {jinns-1.2.0 → jinns-1.3.0}/tests/dataGenerator_tests/test_DataGeneratorParameter.py +0 -0
  161. {jinns-1.2.0 → jinns-1.3.0}/tests/utils_tests/test_subtract_with_check.py +0 -0
@@ -17,10 +17,11 @@ black:
17
17
  run_tests:
18
18
  stage: tests
19
19
  before_script:
20
- - pip install --break-system-packages pytest coverage pytest-cov
21
- - pwd
20
+ - virtualenv venv
21
+ - source venv/bin/activate
22
+ - pip install pytest coverage pytest-cov
23
+ - pip install -e .
22
24
  script:
23
- - pip install --break-system-packages -e .
24
25
  - pytest --cov=jinns --ignore=tests/solver_tests/test_NSPipeFlow_x64.py
25
26
  coverage: '/TOTAL.*\s+(\d+%)$/'
26
27
 
@@ -28,10 +29,13 @@ build_doc:
28
29
  stage: build
29
30
  needs: [] # don't need to wait for other jobs
30
31
  before_script:
31
-
32
+ - virtualenv venv
33
+ - source venv/bin/activate
34
+ - pip install -e .
35
+ - pip install -r docs/doc_requirements.txt
32
36
  script:
33
- - pip install --break-system-packages -e .
34
- - pip install --break-system-packages -r docs/doc_requirements.txt
37
+ - pip list
38
+ - python -m mkdocstrings_handlers.python.debug
35
39
  - mkdocs build
36
40
  - mkdocs build # twice, see https://github.com/patrick-kidger/pytkdocs_tweaks
37
41
 
@@ -1,14 +1,14 @@
1
1
  default_stages: [commit]
2
2
  repos:
3
3
  - repo: https://github.com/pre-commit/pre-commit-hooks
4
- rev: v4.6.0
4
+ rev: v5.0.0
5
5
  hooks:
6
6
  - id: trailing-whitespace
7
7
  stages: [commit]
8
8
  - id: end-of-file-fixer
9
9
  stages: [commit]
10
10
  - repo: https://github.com/psf/black
11
- rev: 24.3.0
11
+ rev: 25.1.0
12
12
  hooks:
13
13
  - id: black
14
14
  stages: [commit]
@@ -109,16 +109,16 @@
109
109
  "metadata": {},
110
110
  "outputs": [],
111
111
  "source": [
112
- "eqx_list = [\n",
113
- " [eqx.nn.Linear, 1, 20],\n",
114
- " [jax.nn.tanh],\n",
115
- " [eqx.nn.Linear, 20, 20],\n",
116
- " [jax.nn.tanh],\n",
117
- " [eqx.nn.Linear, 20, 20],\n",
118
- " [jax.nn.tanh],\n",
119
- " [eqx.nn.Linear, 20, 1],\n",
120
- " [jnp.exp]\n",
121
- "]\n",
112
+ "eqx_list = (\n",
113
+ " (eqx.nn.Linear, 1, 20),\n",
114
+ " (jax.nn.tanh,),\n",
115
+ " (eqx.nn.Linear, 20, 20),\n",
116
+ " (jax.nn.tanh,),\n",
117
+ " (eqx.nn.Linear, 20, 20),\n",
118
+ " (jax.nn.tanh,),\n",
119
+ " (eqx.nn.Linear, 20, 1),\n",
120
+ " (jnp.exp,)\n",
121
+ ")\n",
122
122
  "key, subkey = random.split(key)"
123
123
  ]
124
124
  },
@@ -172,7 +172,7 @@
172
172
  "init_nn_params_list = []\n",
173
173
  "for _ in range(3):\n",
174
174
  " key, subkey = random.split(key)\n",
175
- " u, init_nn_params = jinns.utils.create_PINN(subkey, eqx_list, \"ODE\", 0)\n",
175
+ " u, init_nn_params = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"ODE\")\n",
176
176
  " init_nn_params_list.append(init_nn_params)"
177
177
  ]
178
178
  },
@@ -103,18 +103,18 @@
103
103
  "metadata": {},
104
104
  "outputs": [],
105
105
  "source": [
106
- "eqx_list = [\n",
107
- " [eqx.nn.Linear, 1, 20],\n",
108
- " [jax.nn.tanh],\n",
109
- " [eqx.nn.Linear, 20, 20],\n",
110
- " [jax.nn.tanh],\n",
111
- " [eqx.nn.Linear, 20, 20],\n",
112
- " [jax.nn.tanh],\n",
113
- " [eqx.nn.Linear, 20, 1],\n",
114
- " # [jnp.exp]\n",
115
- "]\n",
106
+ "eqx_list = (\n",
107
+ " (eqx.nn.Linear, 1, 20),\n",
108
+ " (jax.nn.tanh,),\n",
109
+ " (eqx.nn.Linear, 20, 20),\n",
110
+ " (jax.nn.tanh,),\n",
111
+ " (eqx.nn.Linear, 20, 20),\n",
112
+ " (jax.nn.tanh,),\n",
113
+ " (eqx.nn.Linear, 20, 1),\n",
114
+ " # (jnp.exp,)\n",
115
+ ")\n",
116
116
  "key, subkey = random.split(key)\n",
117
- "u, init_nn_params = jinns.utils.create_PINN(subkey, eqx_list, \"ODE\")"
117
+ "u, init_nn_params = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"ODE\")"
118
118
  ]
119
119
  },
120
120
  {
@@ -253,11 +253,10 @@
253
253
  "\n",
254
254
  "for i, k in enumerate(nn_keys):\n",
255
255
  " key, subkey = random.split(key)\n",
256
- " u_k, init_nn_params_k = jinns.utils.create_PINN(\n",
257
- " subkey,\n",
258
- " eqx_list,\n",
259
- " \"ODE\",\n",
260
- " 0,\n",
256
+ " u_k, init_nn_params_k = jinns.nn.PINN_MLP.create(\n",
257
+ " key=subkey,\n",
258
+ " eqx_list=eqx_list,\n",
259
+ " eq_type=\"ODE\",\n",
261
260
  " input_transform=feature_transform,\n",
262
261
  " output_transform=partial(output_transform, id_component=i),\n",
263
262
  " )\n",
@@ -116,7 +116,7 @@
116
116
  " (eqx.nn.Linear, 32, 1)\n",
117
117
  ")\n",
118
118
  "key, subkey = random.split(key)\n",
119
- "u_pinn, init_nn_params_pinn = jinns.utils.create_PINN(subkey, eqx_list, \"nonstatio_PDE\", 1)"
119
+ "u_pinn, init_nn_params_pinn = jinns.nn.PINN_MLP(key=subkey, eqx_list=eqx_list, eq_type=\"nonstatio_PDE\")"
120
120
  ]
121
121
  },
122
122
  {
@@ -146,7 +146,7 @@
146
146
  " (eqx.nn.Linear, 128, r)\n",
147
147
  ")\n",
148
148
  "key, subkey = random.split(key)\n",
149
- "u_spinn, init_nn_params_spinn = jinns.utils.create_SPINN(subkey, d, r, eqx_list, \"nonstatio_PDE\")"
149
+ "u_spinn, init_nn_params_spinn = jinns.nn.SPINN_MLP.create(subkey, d, r, eqx_list, \"nonstatio_PDE\")"
150
150
  ]
151
151
  },
152
152
  {
@@ -153,7 +153,7 @@
153
153
  " (jnp.exp,)\n",
154
154
  ")\n",
155
155
  "key, subkey = random.split(key)\n",
156
- "u_pinn, init_nn_params_pinn = jinns.utils.create_PINN(subkey, eqx_list, \"nonstatio_PDE\", 1)"
156
+ "u_pinn, init_nn_params_pinn = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"nonstatio_PDE\")"
157
157
  ]
158
158
  },
159
159
  {
@@ -184,7 +184,7 @@
184
184
  "\n",
185
185
  ")\n",
186
186
  "key, subkey = random.split(key)\n",
187
- "u_spinn, init_nn_params_spinn = jinns.utils.create_SPINN(subkey, d, r, eqx_list, \"nonstatio_PDE\")"
187
+ "u_spinn, init_nn_params_spinn = jinns.nn.SPINN_MLP.create(subkey, d, r, eqx_list, \"nonstatio_PDE\")"
188
188
  ]
189
189
  },
190
190
  {
@@ -201,7 +201,7 @@
201
201
  ")\n",
202
202
  "\n",
203
203
  "key, subkey = random.split(key)\n",
204
- "u, init_sol_nn_params= jinns.utils.create_PINN(subkey, eqx_list, \"nonstatio_PDE\", 2, slice_solution=jnp.s_[:1])"
204
+ "u, init_sol_nn_params= jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"nonstatio_PDE\", slice_solution=jnp.s_[:1])"
205
205
  ]
206
206
  },
207
207
  {
@@ -230,7 +230,7 @@
230
230
  "# )\n",
231
231
  "\n",
232
232
  "# key, subkey = random.split(key)\n",
233
- "# sol_pinn = jinns.utils.create_PINN(subkey, eqx_list, \"nonstatio_PDE\", 2)\n",
233
+ "# sol_pinn = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"nonstatio_PDE\")\n",
234
234
  "# init_sol_nn_params = sol_pinn.init_params()\n",
235
235
  "# eqx_list = (\n",
236
236
  "# (eqx.nn.Linear, 3, 50), # 3 = t + x (2D)\n",
@@ -243,7 +243,7 @@
243
243
  "# )\n",
244
244
  "\n",
245
245
  "# key, subkey = random.split(key)\n",
246
- "# a_pinn = jinns.utils.create_PINN(subkey, eqx_list, \"nonstatio_PDE\", 2)\n",
246
+ "# a_pinn = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"nonstatio_PDE\")\n",
247
247
  "# init_a_nn_params = a_pinn.init_params()\n",
248
248
  "\n",
249
249
  "# from jinns.utils._pinn import PINN\n",
@@ -206,7 +206,7 @@
206
206
  "]\n",
207
207
  "key, subkey = random.split(key)\n",
208
208
  "u_output_transform = lambda pinn_in, pinn_out, params: pinn_out * (R**2 - pinn_in[1] ** 2)\n",
209
- "u, u_init_nn_params = jinns.utils.create_PINN(subkey, eqx_list, \"statio_PDE\", 2, output_transform=u_output_transform)\n",
209
+ "u, u_init_nn_params = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"statio_PDE\", output_transform=u_output_transform)\n",
210
210
  "\n",
211
211
  "eqx_list = [\n",
212
212
  " [eqx.nn.Linear, 2, 50],\n",
@@ -223,7 +223,7 @@
223
223
  " + (xmax - pinn_in[0]) / (xmax - xmin) * p_in\n",
224
224
  " + (xmin - pinn_in[0]) * (xmax - pinn_in[0]) * pinn_out\n",
225
225
  " )\n",
226
- "p, p_init_nn_params = jinns.utils.create_PINN(subkey, eqx_list, \"statio_PDE\", 2, output_transform=p_output_transform)"
226
+ "p, p_init_nn_params = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"statio_PDE\", output_transform=p_output_transform)"
227
227
  ]
228
228
  },
229
229
  {