jinns 1.2.0__tar.gz → 1.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jinns-1.2.0 → jinns-1.3.0}/.gitlab-ci.yml +10 -6
- {jinns-1.2.0 → jinns-1.3.0}/.pre-commit-config.yaml +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/ODE/1D_Generalized_Lotka_Volterra.ipynb +11 -11
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/ODE/linear_fo_equation.ipynb +11 -11
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +4 -5
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/1D_non_stationary_Burgers.ipynb +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/1D_non_stationary_Fisher_KPP_Bounded_Domain.ipynb +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/2D_Heat_inverse_problem.ipynb +3 -3
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow.ipynb +2 -2
- jinns-1.3.0/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb +1402 -0
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_SoftConstraints.ipynb +4 -4
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/2D_Poisson_inverse_problem.ipynb +1 -1
- jinns-1.3.0/Notebooks/PDE/2D_non_stationary_OU.ipynb +1059 -0
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/Reaction_Diffusion_2D_heterogenous_model.ipynb +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/Reaction_Diffusion_2D_homogeneous_metamodel_hyperpinn_diffrax.ipynb +1 -1
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/Tutorials/1D_non_stationary_Burgers_JointEstimation_Vanilla.ipynb +1 -1
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +3 -3
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/Tutorials/introducing_validation_loss.ipynb +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/Tutorials/load_save_model.ipynb +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/PKG-INFO +9 -9
- {jinns-1.2.0 → jinns-1.3.0}/README.md +7 -7
- jinns-1.3.0/docs/api/pinn/hyperpinn.md +3 -0
- jinns-1.3.0/docs/api/pinn/pinn.md +7 -0
- jinns-1.3.0/docs/api/pinn/ppinn.md +3 -0
- jinns-1.3.0/docs/api/pinn/save_load.md +5 -0
- jinns-1.3.0/docs/api/pinn/spinn.md +7 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/changelog.md +8 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/index.md +12 -4
- {jinns-1.2.0 → jinns-1.3.0}/jinns/data/_DataGenerators.py +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_DynamicLoss.py +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_LossODE.py +1 -1
- {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_LossPDE.py +75 -38
- {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_boundary_conditions.py +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_loss_utils.py +21 -15
- {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_operators.py +0 -2
- jinns-1.3.0/jinns/nn/__init__.py +7 -0
- jinns-1.3.0/jinns/nn/_hyperpinn.py +397 -0
- jinns-1.3.0/jinns/nn/_mlp.py +192 -0
- jinns-1.3.0/jinns/nn/_pinn.py +190 -0
- jinns-1.3.0/jinns/nn/_ppinn.py +203 -0
- {jinns-1.2.0/jinns/utils → jinns-1.3.0/jinns/nn}/_save_load.py +39 -23
- jinns-1.3.0/jinns/nn/_spinn.py +106 -0
- jinns-1.3.0/jinns/nn/_spinn_mlp.py +196 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/plot/_plot.py +3 -3
- {jinns-1.2.0 → jinns-1.3.0}/jinns/solver/_rar.py +3 -3
- {jinns-1.2.0 → jinns-1.3.0}/jinns/solver/_solve.py +23 -9
- jinns-1.3.0/jinns/utils/__init__.py +1 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/utils/_types.py +4 -4
- {jinns-1.2.0 → jinns-1.3.0}/jinns.egg-info/PKG-INFO +9 -9
- {jinns-1.2.0 → jinns-1.3.0}/jinns.egg-info/SOURCES.txt +18 -8
- {jinns-1.2.0 → jinns-1.3.0}/mkdocs.yml +2 -10
- jinns-1.3.0/tests/loss_tests/test_lossPDEstatio.py +138 -0
- jinns-1.3.0/tests/loss_tests/test_norm_loss.py +92 -0
- jinns-1.3.0/tests/nn_tests/test_hyperpinns.py +109 -0
- jinns-1.3.0/tests/nn_tests/test_mlp.py +104 -0
- jinns-1.3.0/tests/nn_tests/test_pinn.py +84 -0
- jinns-1.3.0/tests/nn_tests/test_ppinn_mlp.py +107 -0
- jinns-1.3.0/tests/nn_tests/test_smlp.py +72 -0
- {jinns-1.2.0/tests/utils_tests → jinns-1.3.0/tests/nn_tests}/test_spinn.py +16 -10
- {jinns-1.2.0 → jinns-1.3.0}/tests/operator_tests/test_divergence_fwd.py +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/tests/operator_tests/test_divergence_rev.py +4 -2
- {jinns-1.2.0 → jinns-1.3.0}/tests/operator_tests/test_laplacian_fwd.py +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/tests/operator_tests/test_laplacian_rev.py +4 -2
- {jinns-1.2.0 → jinns-1.3.0}/tests/operator_tests/test_vectorial_laplacian_fwd.py +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/tests/operator_tests/test_vectorial_laplacian_rev.py +4 -2
- {jinns-1.2.0 → jinns-1.3.0}/tests/parameters_tests/test_DerivativeKeysODE.py +1 -1
- {jinns-1.2.0 → jinns-1.3.0}/tests/plot_tests/test_plot1D.py +1 -1
- {jinns-1.2.0 → jinns-1.3.0}/tests/plot_tests/test_plot2D.py +1 -1
- {jinns-1.2.0 → jinns-1.3.0}/tests/save_load_tests/test_saving_loading_hyperpinn.py +10 -20
- {jinns-1.2.0 → jinns-1.3.0}/tests/save_load_tests/test_saving_loading_pinn.py +14 -12
- {jinns-1.2.0 → jinns-1.3.0}/tests/save_load_tests/test_saving_loading_spinn.py +3 -3
- {jinns-1.2.0 → jinns-1.3.0}/tests/sharding_tests/test_Burgers_x32_multiple_shardings.py +6 -2
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_Burgers_x32.py +5 -3
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_Burgers_x64.py +5 -3
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_Fisher_x32.py +5 -3
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_Fisher_x64.py +5 -3
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_GLV_x32.py +8 -4
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_GLV_x64.py +8 -4
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_NSPipeFlow_x32.py +10 -4
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_NSPipeFlow_x64.py +10 -4
- jinns-1.3.0/tests/solver_tests/test_OU1D_statio_x32.py +134 -0
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_OU2D_x32.py +9 -8
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_nan_params_catch.py +3 -1
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_parameter_tracker.py +3 -1
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests/test_rar_algorithm.py +18 -15
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests_hyperpinn/test_NSPipeFlow_x32_hyperpinn.py +23 -12
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests_spinn/test_Burgers_x32_spinn.py +3 -3
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests_spinn/test_Fisher_x32_spinn.py +3 -3
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests_spinn/test_NSPipeFlow_x32_spinn.py +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests_spinn/test_OU2D_x32_spinn.py +8 -6
- {jinns-1.2.0 → jinns-1.3.0}/tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64_spinn.py +2 -2
- {jinns-1.2.0 → jinns-1.3.0}/tests/utils_tests/test_solver_utils.py +1 -1
- {jinns-1.2.0 → jinns-1.3.0}/tests/validation_tests/test_vanilla_validation.py +1 -1
- jinns-1.2.0/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb +0 -1377
- jinns-1.2.0/Notebooks/PDE/2D_non_stationary_OU.ipynb +0 -1083
- jinns-1.2.0/docs/api/pinn/hyperpinn.md +0 -5
- jinns-1.2.0/docs/api/pinn/pinn.md +0 -5
- jinns-1.2.0/docs/api/pinn/save_load.md +0 -5
- jinns-1.2.0/docs/api/pinn/spinn.md +0 -5
- jinns-1.2.0/jinns/utils/__init__.py +0 -6
- jinns-1.2.0/jinns/utils/_hyperpinn.py +0 -420
- jinns-1.2.0/jinns/utils/_pinn.py +0 -324
- jinns-1.2.0/jinns/utils/_ppinn.py +0 -227
- jinns-1.2.0/jinns/utils/_spinn.py +0 -249
- jinns-1.2.0/tests/utils_tests/test_hyperpinns.py +0 -132
- jinns-1.2.0/tests/utils_tests/test_pinn.py +0 -125
- {jinns-1.2.0 → jinns-1.3.0}/.gitignore +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/AUTHORS +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/LICENSE +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/ODE/sbinn_data/glucose.dat +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/ODE/sbinn_data/meal.dat +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/2d_nonstatio_ou_standardsampling.png +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/PDE/OU_1D_nonstatio_solution_grid.npy +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/Notebooks/Tutorials/burger_solution_grid.npy +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/codemeta.json +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/README.md +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/_static/custom_css.css +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/_static/favicon.png +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/api/advanced/derivative_keys.md +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/api/advanced/differential_operators.md +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/api/datagenerators/datagenerators_core.md +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/api/datagenerators/datagenerators_other.md +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/api/loss/dynamic_loss.md +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/api/loss/loss_xde.md +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/api/loss/systems_of_xde.md +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/api/plot.md +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/api/solver.md +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/doc_requirements.txt +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/javascripts/katex.js +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/maths/fokker_planck.md +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/docs/maths/introduction_to_pinns.md +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/img/jinns-diagram.png +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/__init__.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/data/_Batchs.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/data/__init__.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/experimental/__init__.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/experimental/_diffrax_solver.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_DynamicLossAbstract.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/__init__.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/loss/_loss_weights.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/parameters/__init__.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/parameters/_derivative_keys.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/parameters/_params.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/plot/__init__.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/solver/__init__.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/solver/_utils.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/utils/_containers.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/utils/_utils.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/validation/__init__.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns/validation/_validation.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns.egg-info/dependency_links.txt +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns.egg-info/requires.txt +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/jinns.egg-info/top_level.txt +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/pyproject.toml +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/setup.cfg +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/tests/conftest.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/tests/dataGenerator_tests/test_CubicMeshPDEStatio.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/tests/dataGenerator_tests/test_DataGeneratorODE.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/tests/dataGenerator_tests/test_DataGeneratorParameter.py +0 -0
- {jinns-1.2.0 → jinns-1.3.0}/tests/utils_tests/test_subtract_with_check.py +0 -0
|
@@ -17,10 +17,11 @@ black:
|
|
|
17
17
|
run_tests:
|
|
18
18
|
stage: tests
|
|
19
19
|
before_script:
|
|
20
|
-
-
|
|
21
|
-
-
|
|
20
|
+
- virtualenv venv
|
|
21
|
+
- source venv/bin/activate
|
|
22
|
+
- pip install pytest coverage pytest-cov
|
|
23
|
+
- pip install -e .
|
|
22
24
|
script:
|
|
23
|
-
- pip install --break-system-packages -e .
|
|
24
25
|
- pytest --cov=jinns --ignore=tests/solver_tests/test_NSPipeFlow_x64.py
|
|
25
26
|
coverage: '/TOTAL.*\s+(\d+%)$/'
|
|
26
27
|
|
|
@@ -28,10 +29,13 @@ build_doc:
|
|
|
28
29
|
stage: build
|
|
29
30
|
needs: [] # don't need to wait for other jobs
|
|
30
31
|
before_script:
|
|
31
|
-
|
|
32
|
+
- virtualenv venv
|
|
33
|
+
- source venv/bin/activate
|
|
34
|
+
- pip install -e .
|
|
35
|
+
- pip install -r docs/doc_requirements.txt
|
|
32
36
|
script:
|
|
33
|
-
- pip
|
|
34
|
-
-
|
|
37
|
+
- pip list
|
|
38
|
+
- python -m mkdocstrings_handlers.python.debug
|
|
35
39
|
- mkdocs build
|
|
36
40
|
- mkdocs build # twice, see https://github.com/patrick-kidger/pytkdocs_tweaks
|
|
37
41
|
|
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
default_stages: [commit]
|
|
2
2
|
repos:
|
|
3
3
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
4
|
-
rev:
|
|
4
|
+
rev: v5.0.0
|
|
5
5
|
hooks:
|
|
6
6
|
- id: trailing-whitespace
|
|
7
7
|
stages: [commit]
|
|
8
8
|
- id: end-of-file-fixer
|
|
9
9
|
stages: [commit]
|
|
10
10
|
- repo: https://github.com/psf/black
|
|
11
|
-
rev:
|
|
11
|
+
rev: 25.1.0
|
|
12
12
|
hooks:
|
|
13
13
|
- id: black
|
|
14
14
|
stages: [commit]
|
|
@@ -109,16 +109,16 @@
|
|
|
109
109
|
"metadata": {},
|
|
110
110
|
"outputs": [],
|
|
111
111
|
"source": [
|
|
112
|
-
"eqx_list =
|
|
113
|
-
"
|
|
114
|
-
"
|
|
115
|
-
"
|
|
116
|
-
"
|
|
117
|
-
"
|
|
118
|
-
"
|
|
119
|
-
"
|
|
120
|
-
"
|
|
121
|
-
"
|
|
112
|
+
"eqx_list = (\n",
|
|
113
|
+
" (eqx.nn.Linear, 1, 20),\n",
|
|
114
|
+
" (jax.nn.tanh,),\n",
|
|
115
|
+
" (eqx.nn.Linear, 20, 20),\n",
|
|
116
|
+
" (jax.nn.tanh,),\n",
|
|
117
|
+
" (eqx.nn.Linear, 20, 20),\n",
|
|
118
|
+
" (jax.nn.tanh,),\n",
|
|
119
|
+
" (eqx.nn.Linear, 20, 1),\n",
|
|
120
|
+
" (jnp.exp,)\n",
|
|
121
|
+
")\n",
|
|
122
122
|
"key, subkey = random.split(key)"
|
|
123
123
|
]
|
|
124
124
|
},
|
|
@@ -172,7 +172,7 @@
|
|
|
172
172
|
"init_nn_params_list = []\n",
|
|
173
173
|
"for _ in range(3):\n",
|
|
174
174
|
" key, subkey = random.split(key)\n",
|
|
175
|
-
" u, init_nn_params = jinns.
|
|
175
|
+
" u, init_nn_params = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"ODE\")\n",
|
|
176
176
|
" init_nn_params_list.append(init_nn_params)"
|
|
177
177
|
]
|
|
178
178
|
},
|
|
@@ -103,18 +103,18 @@
|
|
|
103
103
|
"metadata": {},
|
|
104
104
|
"outputs": [],
|
|
105
105
|
"source": [
|
|
106
|
-
"eqx_list =
|
|
107
|
-
"
|
|
108
|
-
"
|
|
109
|
-
"
|
|
110
|
-
"
|
|
111
|
-
"
|
|
112
|
-
"
|
|
113
|
-
"
|
|
114
|
-
" #
|
|
115
|
-
"
|
|
106
|
+
"eqx_list = (\n",
|
|
107
|
+
" (eqx.nn.Linear, 1, 20),\n",
|
|
108
|
+
" (jax.nn.tanh,),\n",
|
|
109
|
+
" (eqx.nn.Linear, 20, 20),\n",
|
|
110
|
+
" (jax.nn.tanh,),\n",
|
|
111
|
+
" (eqx.nn.Linear, 20, 20),\n",
|
|
112
|
+
" (jax.nn.tanh,),\n",
|
|
113
|
+
" (eqx.nn.Linear, 20, 1),\n",
|
|
114
|
+
" # (jnp.exp,)\n",
|
|
115
|
+
")\n",
|
|
116
116
|
"key, subkey = random.split(key)\n",
|
|
117
|
-
"u, init_nn_params = jinns.
|
|
117
|
+
"u, init_nn_params = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"ODE\")"
|
|
118
118
|
]
|
|
119
119
|
},
|
|
120
120
|
{
|
|
@@ -253,11 +253,10 @@
|
|
|
253
253
|
"\n",
|
|
254
254
|
"for i, k in enumerate(nn_keys):\n",
|
|
255
255
|
" key, subkey = random.split(key)\n",
|
|
256
|
-
" u_k, init_nn_params_k = jinns.
|
|
257
|
-
" subkey,\n",
|
|
258
|
-
" eqx_list,\n",
|
|
259
|
-
"
|
|
260
|
-
" 0,\n",
|
|
256
|
+
" u_k, init_nn_params_k = jinns.nn.PINN_MLP.create(\n",
|
|
257
|
+
" key=subkey,\n",
|
|
258
|
+
" eqx_list=eqx_list,\n",
|
|
259
|
+
" eq_type=\"ODE\",\n",
|
|
261
260
|
" input_transform=feature_transform,\n",
|
|
262
261
|
" output_transform=partial(output_transform, id_component=i),\n",
|
|
263
262
|
" )\n",
|
|
@@ -116,7 +116,7 @@
|
|
|
116
116
|
" (eqx.nn.Linear, 32, 1)\n",
|
|
117
117
|
")\n",
|
|
118
118
|
"key, subkey = random.split(key)\n",
|
|
119
|
-
"u_pinn, init_nn_params_pinn = jinns.
|
|
119
|
+
"u_pinn, init_nn_params_pinn = jinns.nn.PINN_MLP(key=subkey, eqx_list=eqx_list, eq_type=\"nonstatio_PDE\")"
|
|
120
120
|
]
|
|
121
121
|
},
|
|
122
122
|
{
|
|
@@ -146,7 +146,7 @@
|
|
|
146
146
|
" (eqx.nn.Linear, 128, r)\n",
|
|
147
147
|
")\n",
|
|
148
148
|
"key, subkey = random.split(key)\n",
|
|
149
|
-
"u_spinn, init_nn_params_spinn = jinns.
|
|
149
|
+
"u_spinn, init_nn_params_spinn = jinns.nn.SPINN_MLP.create(subkey, d, r, eqx_list, \"nonstatio_PDE\")"
|
|
150
150
|
]
|
|
151
151
|
},
|
|
152
152
|
{
|
|
@@ -153,7 +153,7 @@
|
|
|
153
153
|
" (jnp.exp,)\n",
|
|
154
154
|
")\n",
|
|
155
155
|
"key, subkey = random.split(key)\n",
|
|
156
|
-
"u_pinn, init_nn_params_pinn = jinns.
|
|
156
|
+
"u_pinn, init_nn_params_pinn = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"nonstatio_PDE\")"
|
|
157
157
|
]
|
|
158
158
|
},
|
|
159
159
|
{
|
|
@@ -184,7 +184,7 @@
|
|
|
184
184
|
"\n",
|
|
185
185
|
")\n",
|
|
186
186
|
"key, subkey = random.split(key)\n",
|
|
187
|
-
"u_spinn, init_nn_params_spinn = jinns.
|
|
187
|
+
"u_spinn, init_nn_params_spinn = jinns.nn.SPINN_MLP.create(subkey, d, r, eqx_list, \"nonstatio_PDE\")"
|
|
188
188
|
]
|
|
189
189
|
},
|
|
190
190
|
{
|
|
@@ -201,7 +201,7 @@
|
|
|
201
201
|
")\n",
|
|
202
202
|
"\n",
|
|
203
203
|
"key, subkey = random.split(key)\n",
|
|
204
|
-
"u, init_sol_nn_params= jinns.
|
|
204
|
+
"u, init_sol_nn_params= jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"nonstatio_PDE\", slice_solution=jnp.s_[:1])"
|
|
205
205
|
]
|
|
206
206
|
},
|
|
207
207
|
{
|
|
@@ -230,7 +230,7 @@
|
|
|
230
230
|
"# )\n",
|
|
231
231
|
"\n",
|
|
232
232
|
"# key, subkey = random.split(key)\n",
|
|
233
|
-
"# sol_pinn = jinns.
|
|
233
|
+
"# sol_pinn = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"nonstatio_PDE\")\n",
|
|
234
234
|
"# init_sol_nn_params = sol_pinn.init_params()\n",
|
|
235
235
|
"# eqx_list = (\n",
|
|
236
236
|
"# (eqx.nn.Linear, 3, 50), # 3 = t + x (2D)\n",
|
|
@@ -243,7 +243,7 @@
|
|
|
243
243
|
"# )\n",
|
|
244
244
|
"\n",
|
|
245
245
|
"# key, subkey = random.split(key)\n",
|
|
246
|
-
"# a_pinn = jinns.
|
|
246
|
+
"# a_pinn = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"nonstatio_PDE\")\n",
|
|
247
247
|
"# init_a_nn_params = a_pinn.init_params()\n",
|
|
248
248
|
"\n",
|
|
249
249
|
"# from jinns.utils._pinn import PINN\n",
|
|
@@ -206,7 +206,7 @@
|
|
|
206
206
|
"]\n",
|
|
207
207
|
"key, subkey = random.split(key)\n",
|
|
208
208
|
"u_output_transform = lambda pinn_in, pinn_out, params: pinn_out * (R**2 - pinn_in[1] ** 2)\n",
|
|
209
|
-
"u, u_init_nn_params = jinns.
|
|
209
|
+
"u, u_init_nn_params = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"statio_PDE\", output_transform=u_output_transform)\n",
|
|
210
210
|
"\n",
|
|
211
211
|
"eqx_list = [\n",
|
|
212
212
|
" [eqx.nn.Linear, 2, 50],\n",
|
|
@@ -223,7 +223,7 @@
|
|
|
223
223
|
" + (xmax - pinn_in[0]) / (xmax - xmin) * p_in\n",
|
|
224
224
|
" + (xmin - pinn_in[0]) * (xmax - pinn_in[0]) * pinn_out\n",
|
|
225
225
|
" )\n",
|
|
226
|
-
"p, p_init_nn_params = jinns.
|
|
226
|
+
"p, p_init_nn_params = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"statio_PDE\", output_transform=p_output_transform)"
|
|
227
227
|
]
|
|
228
228
|
},
|
|
229
229
|
{
|