jinns 1.4.0__tar.gz → 1.5.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/ODE/1D_Generalized_Lotka_Volterra.ipynb +70 -47
- jinns-1.5.1/Notebooks/ODE/MS_model_Verhulst.ipynb +1951 -0
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/ODE/linear_fo_equation.ipynb +28 -23
- jinns-1.5.1/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +899 -0
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/PDE/1D_non_stationary_Burgers.ipynb +4 -6
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/PDE/1D_non_stationary_Fisher_KPP_Bounded_Domain.ipynb +91 -44
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/PDE/2D_Heat_inverse_problem.ipynb +122 -62
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow.ipynb +70 -76
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb +208 -140
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_SoftConstraints.ipynb +110 -104
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/PDE/2D_Poisson_inverse_problem.ipynb +88 -47
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/PDE/2D_non_stationary_OU.ipynb +13 -9
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/PDE/Reaction_Diffusion_2D_heterogenous_model.ipynb +209 -219
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/PDE/Reaction_Diffusion_2D_homogeneous_metamodel_hyperpinn_diffrax.ipynb +98 -75
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/Tutorials/1D_non_stationary_Burgers_JointEstimation_Vanilla.ipynb +15 -15
- jinns-1.5.1/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +769 -0
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/Tutorials/introducing_validation_loss.ipynb +2 -2
- {jinns-1.4.0 → jinns-1.5.1}/PKG-INFO +5 -2
- {jinns-1.4.0 → jinns-1.5.1}/README.md +4 -1
- jinns-1.5.1/docs/api/advanced/loss_weight_updates.md +7 -0
- jinns-1.5.1/docs/api/loss/loss_weights.md +13 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/changelog.md +13 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/__init__.py +7 -7
- {jinns-1.4.0 → jinns-1.5.1}/jinns/data/_CubicMeshPDENonStatio.py +156 -28
- {jinns-1.4.0 → jinns-1.5.1}/jinns/data/_CubicMeshPDEStatio.py +132 -24
- {jinns-1.4.0 → jinns-1.5.1}/jinns/loss/_DynamicLossAbstract.py +30 -2
- {jinns-1.4.0 → jinns-1.5.1}/jinns/loss/_LossODE.py +177 -64
- {jinns-1.4.0 → jinns-1.5.1}/jinns/loss/_LossPDE.py +146 -68
- {jinns-1.4.0 → jinns-1.5.1}/jinns/loss/__init__.py +4 -0
- jinns-1.5.1/jinns/loss/_abstract_loss.py +128 -0
- jinns-1.5.1/jinns/loss/_loss_components.py +43 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/loss/_loss_utils.py +34 -24
- jinns-1.5.1/jinns/loss/_loss_weight_updates.py +202 -0
- jinns-1.5.1/jinns/loss/_loss_weights.py +83 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/parameters/_params.py +8 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/solver/_solve.py +141 -46
- {jinns-1.4.0 → jinns-1.5.1}/jinns/utils/_containers.py +5 -2
- {jinns-1.4.0 → jinns-1.5.1}/jinns/utils/_types.py +12 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns.egg-info/PKG-INFO +5 -2
- {jinns-1.4.0 → jinns-1.5.1}/jinns.egg-info/SOURCES.txt +9 -0
- {jinns-1.4.0 → jinns-1.5.1}/mkdocs.yml +3 -1
- jinns-1.5.1/tests/adaptative_weight_tests/test_ReLoBRaLo_update.py +58 -0
- jinns-1.5.1/tests/adaptative_weight_tests/test_loss_weight_update.py +210 -0
- jinns-1.5.1/tests/adaptative_weight_tests/test_lr_annealing.py +69 -0
- jinns-1.5.1/tests/adaptative_weight_tests/test_soft_adapt.py +57 -0
- jinns-1.5.1/tests/dataGenerator_tests/test_sobol_method.py +94 -0
- jinns-1.5.1/tests/loss_tests/test_lossODE.py +120 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/loss_tests/test_lossPDEstatio.py +1 -2
- {jinns-1.4.0 → jinns-1.5.1}/tests/parameters_tests/test_DerivativeKeysODE.py +1 -2
- {jinns-1.4.0 → jinns-1.5.1}/tests/sharding_tests/test_Burgers_x32_multiple_shardings.py +2 -2
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests/test_Burgers_x32.py +2 -2
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests/test_Burgers_x64.py +1 -1
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests/test_Fisher_x32.py +1 -1
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests/test_Fisher_x64.py +1 -1
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests/test_GLV_x32.py +1 -1
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests/test_GLV_x64.py +1 -1
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests/test_NSPipeFlow_x32.py +1 -1
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests/test_NSPipeFlow_x64.py +1 -1
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests/test_OU1D_statio_x32.py +2 -6
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests/test_OU2D_x32.py +1 -1
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests/test_nan_params_catch.py +1 -2
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests/test_parameter_tracker.py +1 -1
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests/test_rar_algorithm.py +1 -2
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests_hyperpinn/test_NSPipeFlow_x32_hyperpinn.py +1 -1
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests_spinn/test_Burgers_x32_spinn.py +2 -2
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests_spinn/test_Fisher_x32_spinn.py +2 -2
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests_spinn/test_NSPipeFlow_x32_spinn.py +1 -1
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests_spinn/test_OU2D_x32_spinn.py +1 -1
- {jinns-1.4.0 → jinns-1.5.1}/tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64_spinn.py +2 -4
- {jinns-1.4.0 → jinns-1.5.1}/tests/validation_tests/test_vanilla_validation.py +1 -0
- jinns-1.4.0/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +0 -878
- jinns-1.4.0/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +0 -740
- jinns-1.4.0/docs/api/loss/loss_weights.md +0 -9
- jinns-1.4.0/jinns/loss/_abstract_loss.py +0 -15
- jinns-1.4.0/jinns/loss/_loss_weights.py +0 -27
- jinns-1.4.0/tests/loss_tests/test_lossODE.py +0 -41
- {jinns-1.4.0 → jinns-1.5.1}/.gitignore +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/.gitlab-ci.yml +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/.pre-commit-config.yaml +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/AUTHORS +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/LICENSE +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/ODE/sbinn_data/glucose.dat +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/ODE/sbinn_data/meal.dat +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/PDE/2d_nonstatio_ou_standardsampling.png +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/PDE/OU_1D_nonstatio_solution_grid.npy +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/Tutorials/burgers_solution_grid.npy +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/Notebooks/Tutorials/load_save_model.ipynb +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/codemeta.json +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/README.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/_static/custom_css.css +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/_static/favicon.png +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/api/advanced/derivative_keys.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/api/advanced/differential_operators.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/api/datagenerators/datagenerators_core.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/api/datagenerators/datagenerators_other.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/api/loss/dynamic_loss.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/api/loss/loss_xde.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/api/pinn/hyperpinn.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/api/pinn/pinn.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/api/pinn/ppinn.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/api/pinn/save_load.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/api/pinn/spinn.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/api/plot.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/api/solver.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/doc_requirements.txt +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/index.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/javascripts/katex.js +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/maths/fokker_planck.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/docs/maths/introduction_to_pinns.md +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/img/jinns-diagram.png +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/data/_AbstractDataGenerator.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/data/_Batchs.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/data/_DataGeneratorODE.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/data/_DataGeneratorObservations.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/data/_DataGeneratorParameter.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/data/__init__.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/data/_utils.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/experimental/__init__.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/experimental/_diffrax_solver.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/loss/_DynamicLoss.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/loss/_boundary_conditions.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/loss/_operators.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/nn/__init__.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/nn/_abstract_pinn.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/nn/_hyperpinn.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/nn/_mlp.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/nn/_pinn.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/nn/_ppinn.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/nn/_save_load.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/nn/_spinn.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/nn/_spinn_mlp.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/nn/_utils.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/parameters/__init__.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/parameters/_derivative_keys.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/plot/__init__.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/plot/_plot.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/solver/__init__.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/solver/_rar.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/solver/_utils.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/utils/__init__.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/utils/_utils.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/validation/__init__.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns/validation/_validation.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns.egg-info/dependency_links.txt +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns.egg-info/requires.txt +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/jinns.egg-info/top_level.txt +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/pyproject.toml +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/setup.cfg +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/conftest.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/dataGenerator_tests/test_CubicMeshPDEStatio.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/dataGenerator_tests/test_DataGeneratorODE.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/dataGenerator_tests/test_DataGeneratorParameter.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/loss_tests/test_lossPDEnonstatio.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/loss_tests/test_norm_loss.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/nn_tests/test_hyperpinns.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/nn_tests/test_mlp.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/nn_tests/test_pinn.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/nn_tests/test_ppinn_mlp.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/nn_tests/test_smlp.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/nn_tests/test_spinn.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/operator_tests/test_divergence_fwd.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/operator_tests/test_divergence_rev.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/operator_tests/test_laplacian_fwd.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/operator_tests/test_laplacian_rev.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/operator_tests/test_vectorial_laplacian_fwd.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/operator_tests/test_vectorial_laplacian_rev.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/plot_tests/test_plot1D.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/plot_tests/test_plot2D.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/save_load_tests/test_saving_loading_hyperpinn.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/save_load_tests/test_saving_loading_pinn.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/save_load_tests/test_saving_loading_spinn.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/utils_tests/test_solver_utils.py +0 -0
- {jinns-1.4.0 → jinns-1.5.1}/tests/utils_tests/test_subtract_with_check.py +0 -0
|
@@ -133,7 +133,7 @@
|
|
|
133
133
|
" (eqx.nn.Linear, 64, 64),\n",
|
|
134
134
|
" (jax.nn.tanh,),\n",
|
|
135
135
|
" (eqx.nn.Linear, 64, 3),\n",
|
|
136
|
-
" (jnp.exp,)
|
|
136
|
+
" (jnp.exp,),\n",
|
|
137
137
|
")\n",
|
|
138
138
|
"key, subkey = random.split(key)"
|
|
139
139
|
]
|
|
@@ -154,19 +154,14 @@
|
|
|
154
154
|
"outputs": [],
|
|
155
155
|
"source": [
|
|
156
156
|
"nt = 5000\n",
|
|
157
|
-
"method =
|
|
157
|
+
"method = \"uniform\"\n",
|
|
158
158
|
"tmin = 0\n",
|
|
159
159
|
"tmax = 1\n",
|
|
160
160
|
"\n",
|
|
161
161
|
"Tmax = 30\n",
|
|
162
162
|
"key, subkey = random.split(key)\n",
|
|
163
163
|
"train_data = jinns.data.DataGeneratorODE(\n",
|
|
164
|
-
" key=subkey
|
|
165
|
-
" nt=nt,\n",
|
|
166
|
-
" tmin=tmin,\n",
|
|
167
|
-
" tmax=tmax,\n",
|
|
168
|
-
" temporal_batch_size=512,\n",
|
|
169
|
-
" method=method\n",
|
|
164
|
+
" key=subkey, nt=nt, tmin=tmin, tmax=tmax, temporal_batch_size=512, method=method\n",
|
|
170
165
|
")"
|
|
171
166
|
]
|
|
172
167
|
},
|
|
@@ -185,7 +180,9 @@
|
|
|
185
180
|
"metadata": {},
|
|
186
181
|
"outputs": [],
|
|
187
182
|
"source": [
|
|
188
|
-
"u, init_nn_params = jinns.nn.PINN_MLP.create(
|
|
183
|
+
"u, init_nn_params = jinns.nn.PINN_MLP.create(\n",
|
|
184
|
+
" key=subkey, eqx_list=eqx_list, eq_type=\"ODE\"\n",
|
|
185
|
+
")"
|
|
189
186
|
]
|
|
190
187
|
},
|
|
191
188
|
{
|
|
@@ -207,7 +204,8 @@
|
|
|
207
204
|
"source": [
|
|
208
205
|
"# initial conditions for each species\n",
|
|
209
206
|
"import numpy as onp\n",
|
|
210
|
-
"
|
|
207
|
+
"\n",
|
|
208
|
+
"N_0 = onp.array([10.0, 7.0, 4.0])\n",
|
|
211
209
|
"# growth rates for each species\n",
|
|
212
210
|
"growth_rates = jnp.array([0.1, 0.5, 0.8])\n",
|
|
213
211
|
"# carrying capacity for each species\n",
|
|
@@ -215,10 +213,11 @@
|
|
|
215
213
|
"# interactions\n",
|
|
216
214
|
"# NOTE that interaction with oneself is 0\n",
|
|
217
215
|
"# NOTE minus sign\n",
|
|
218
|
-
"interactions = (
|
|
219
|
-
"
|
|
220
|
-
"
|
|
221
|
-
"
|
|
216
|
+
"interactions = (\n",
|
|
217
|
+
" -jnp.array([0, 0.001, 0.001]),\n",
|
|
218
|
+
" -jnp.array([0.001, 0, 0.001]),\n",
|
|
219
|
+
" -jnp.array([0.001, 0.001, 0]),\n",
|
|
220
|
+
")"
|
|
222
221
|
]
|
|
223
222
|
},
|
|
224
223
|
{
|
|
@@ -286,9 +285,21 @@
|
|
|
286
285
|
"source": [
|
|
287
286
|
"vectorized_u_init = vmap(lambda t: u(t, init_params), (0), 0)\n",
|
|
288
287
|
"\n",
|
|
289
|
-
"plt.plot(
|
|
290
|
-
"
|
|
291
|
-
"
|
|
288
|
+
"plt.plot(\n",
|
|
289
|
+
" train_data.times.sort(axis=0) * Tmax,\n",
|
|
290
|
+
" vectorized_u_init(train_data.times.sort(axis=0))[:, 0],\n",
|
|
291
|
+
" label=\"N1\",\n",
|
|
292
|
+
")\n",
|
|
293
|
+
"plt.plot(\n",
|
|
294
|
+
" train_data.times.sort(axis=0) * Tmax,\n",
|
|
295
|
+
" vectorized_u_init(train_data.times.sort(axis=0))[:, 1],\n",
|
|
296
|
+
" label=\"N2\",\n",
|
|
297
|
+
")\n",
|
|
298
|
+
"plt.plot(\n",
|
|
299
|
+
" train_data.times.sort(axis=0) * Tmax,\n",
|
|
300
|
+
" vectorized_u_init(train_data.times.sort(axis=0))[:, 2],\n",
|
|
301
|
+
" label=\"N3\",\n",
|
|
302
|
+
")\n",
|
|
292
303
|
"\n",
|
|
293
304
|
"plt.legend()"
|
|
294
305
|
]
|
|
@@ -333,7 +344,7 @@
|
|
|
333
344
|
" loss_weights=loss_weights,\n",
|
|
334
345
|
" dynamic_loss=dynamic_loss,\n",
|
|
335
346
|
" initial_condition=(float(tmin), jnp.array([N_0[0], N_0[1], N_0[2]])),\n",
|
|
336
|
-
" params=init_params
|
|
347
|
+
" params=init_params,\n",
|
|
337
348
|
")"
|
|
338
349
|
]
|
|
339
350
|
},
|
|
@@ -357,10 +368,7 @@
|
|
|
357
368
|
"train_data, batch = train_data.get_batch()\n",
|
|
358
369
|
"\n",
|
|
359
370
|
"losses_and_grad = jax.value_and_grad(loss.evaluate, 0, has_aux=True)\n",
|
|
360
|
-
"losses, grads = losses_and_grad(\n",
|
|
361
|
-
" init_params,\n",
|
|
362
|
-
" batch\n",
|
|
363
|
-
")\n",
|
|
371
|
+
"losses, grads = losses_and_grad(init_params, batch)\n",
|
|
364
372
|
"l_tot, d = losses\n",
|
|
365
373
|
"print(f\"total loss: {l_tot}\")\n",
|
|
366
374
|
"print(f\"Individual losses: { {key: f'{val:.2f}' for key, val in d.items()} }\")"
|
|
@@ -406,13 +414,14 @@
|
|
|
406
414
|
" init_value=start_learning_rate,\n",
|
|
407
415
|
" transition_begin=5000,\n",
|
|
408
416
|
" transition_steps=100,\n",
|
|
409
|
-
" decay_rate=0.99
|
|
417
|
+
" decay_rate=0.99,\n",
|
|
418
|
+
")\n",
|
|
410
419
|
"\n",
|
|
411
420
|
"tx = optax.chain(\n",
|
|
412
421
|
" optax.scale_by_adam(), # Use the updates from adam.\n",
|
|
413
422
|
" optax.scale_by_schedule(scheduler), # Use the learning rate from the scheduler.\n",
|
|
414
423
|
" # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.\n",
|
|
415
|
-
" optax.scale(-1.0)
|
|
424
|
+
" optax.scale(-1.0),\n",
|
|
416
425
|
")"
|
|
417
426
|
]
|
|
418
427
|
},
|
|
@@ -487,12 +496,8 @@
|
|
|
487
496
|
}
|
|
488
497
|
],
|
|
489
498
|
"source": [
|
|
490
|
-
"params, total_loss_list, loss_by_term_dict, data, loss, _, _ , _, _ = jinns.solve(\n",
|
|
491
|
-
" init_params=params
|
|
492
|
-
" data=train_data,\n",
|
|
493
|
-
" optimizer=tx,\n",
|
|
494
|
-
" loss=loss,\n",
|
|
495
|
-
" n_iter=n_iter\n",
|
|
499
|
+
"params, total_loss_list, loss_by_term_dict, data, loss, _, _, _, _, _ = jinns.solve(\n",
|
|
500
|
+
" init_params=params, data=train_data, optimizer=tx, loss=loss, n_iter=n_iter\n",
|
|
496
501
|
")"
|
|
497
502
|
]
|
|
498
503
|
},
|
|
@@ -575,20 +580,23 @@
|
|
|
575
580
|
}
|
|
576
581
|
],
|
|
577
582
|
"source": [
|
|
578
|
-
"u_est = vmap(lambda t:u(t, params), (0), 0)\n",
|
|
583
|
+
"u_est = vmap(lambda t: u(t, params), (0), 0)\n",
|
|
579
584
|
"\n",
|
|
580
585
|
"key, subkey = random.split(key, 2)\n",
|
|
581
|
-
"val_data = jinns.data.DataGeneratorODE(
|
|
586
|
+
"val_data = jinns.data.DataGeneratorODE(\n",
|
|
587
|
+
" key=subkey, nt=nt, tmin=tmin, tmax=tmax, method=method\n",
|
|
588
|
+
")\n",
|
|
582
589
|
"\n",
|
|
583
590
|
"import pandas as pd\n",
|
|
591
|
+
"\n",
|
|
584
592
|
"ts = val_data.times.sort(axis=0).squeeze()\n",
|
|
585
593
|
"df = pd.DataFrame(\n",
|
|
586
594
|
" {\n",
|
|
587
|
-
" \"t\": ts * Tmax,
|
|
595
|
+
" \"t\": ts * Tmax, # rescale time for plotting\n",
|
|
588
596
|
" \"N1\": u_est(ts)[:, 0],\n",
|
|
589
597
|
" \"N2\": u_est(ts)[:, 1],\n",
|
|
590
598
|
" \"N3\": u_est(ts)[:, 2],\n",
|
|
591
|
-
" \"Method\": \"PINN\"
|
|
599
|
+
" \"Method\": \"PINN\",\n",
|
|
592
600
|
" },\n",
|
|
593
601
|
")\n",
|
|
594
602
|
"df.plot(x=\"t\")"
|
|
@@ -615,6 +623,7 @@
|
|
|
615
623
|
"\n",
|
|
616
624
|
"# NOTE the following line is not accurate as it skips one batch\n",
|
|
617
625
|
"\n",
|
|
626
|
+
"\n",
|
|
618
627
|
"def lotka_volterra_log(y_log, t, eq_params):\n",
|
|
619
628
|
" \"\"\"\n",
|
|
620
629
|
" Generalized Lotka-Volterra model for N bacterial species, with logarithmic transformation for stability.\n",
|
|
@@ -633,14 +642,19 @@
|
|
|
633
642
|
" dydt = np.zeros(N)\n",
|
|
634
643
|
"\n",
|
|
635
644
|
" for i in range(N):\n",
|
|
636
|
-
" dydt[i] = y[i] * (
|
|
645
|
+
" dydt[i] = y[i] * (\n",
|
|
646
|
+
" alpha[i]\n",
|
|
647
|
+
" - beta[i] * np.sum(y)\n",
|
|
648
|
+
" - np.sum([gamma[j][i] * y[j] for j in range(N)])\n",
|
|
649
|
+
" )\n",
|
|
637
650
|
"\n",
|
|
638
651
|
" dydt_log = dydt / y\n",
|
|
639
652
|
"\n",
|
|
640
653
|
" return dydt_log\n",
|
|
641
654
|
"\n",
|
|
655
|
+
"\n",
|
|
642
656
|
"# Define name bacteria\n",
|
|
643
|
-
"names = [
|
|
657
|
+
"names = [\"N1\", \"N2\", \"N3\"]\n",
|
|
644
658
|
"N = len(names)\n",
|
|
645
659
|
"\n",
|
|
646
660
|
"# Define model parameters\n",
|
|
@@ -671,14 +685,11 @@
|
|
|
671
685
|
"# comparative plots\n",
|
|
672
686
|
"df_scipy = pd.DataFrame(\n",
|
|
673
687
|
" {\n",
|
|
674
|
-
" \"t\": ts * Tmax,
|
|
675
|
-
" \"Method\": \"Scipy solver\"
|
|
676
|
-
" }
|
|
677
|
-
" {\n",
|
|
678
|
-
"
|
|
679
|
-
" },\n",
|
|
680
|
-
")\n",
|
|
681
|
-
"\n"
|
|
688
|
+
" \"t\": ts * Tmax, # rescale time for plotting\n",
|
|
689
|
+
" \"Method\": \"Scipy solver\",\n",
|
|
690
|
+
" }\n",
|
|
691
|
+
" | {f\"N{i + 1}\": y[:, i] for i in range(3)},\n",
|
|
692
|
+
")"
|
|
682
693
|
]
|
|
683
694
|
},
|
|
684
695
|
{
|
|
@@ -710,8 +721,20 @@
|
|
|
710
721
|
],
|
|
711
722
|
"source": [
|
|
712
723
|
"import seaborn as sns\n",
|
|
713
|
-
"
|
|
714
|
-
"
|
|
724
|
+
"\n",
|
|
725
|
+
"df_plot = pd.concat((df, df_scipy)).melt(\n",
|
|
726
|
+
" id_vars=[\"Method\", \"t\"], var_name=\"Population\", value_name=\"Solution\"\n",
|
|
727
|
+
")\n",
|
|
728
|
+
"sns.relplot(\n",
|
|
729
|
+
" df_plot,\n",
|
|
730
|
+
" kind=\"line\",\n",
|
|
731
|
+
" x=\"t\",\n",
|
|
732
|
+
" y=\"Solution\",\n",
|
|
733
|
+
" hue=\"Population\",\n",
|
|
734
|
+
" style=\"Method\",\n",
|
|
735
|
+
" height=4,\n",
|
|
736
|
+
" aspect=2,\n",
|
|
737
|
+
")"
|
|
715
738
|
]
|
|
716
739
|
},
|
|
717
740
|
{
|
|
@@ -749,7 +772,7 @@
|
|
|
749
772
|
"name": "python",
|
|
750
773
|
"nbconvert_exporter": "python",
|
|
751
774
|
"pygments_lexer": "ipython3",
|
|
752
|
-
"version": "3.11.
|
|
775
|
+
"version": "3.11.11"
|
|
753
776
|
},
|
|
754
777
|
"vscode": {
|
|
755
778
|
"interpreter": {
|