jinns 1.4.0__tar.gz → 1.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172) hide show
  1. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/ODE/1D_Generalized_Lotka_Volterra.ipynb +70 -47
  2. jinns-1.5.0/Notebooks/ODE/MS_model_Verhulst.ipynb +1951 -0
  3. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/ODE/linear_fo_equation.ipynb +28 -23
  4. jinns-1.5.0/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +899 -0
  5. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/PDE/1D_non_stationary_Burgers.ipynb +4 -6
  6. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/PDE/1D_non_stationary_Fisher_KPP_Bounded_Domain.ipynb +91 -44
  7. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/PDE/2D_Heat_inverse_problem.ipynb +122 -62
  8. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow.ipynb +70 -76
  9. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb +208 -140
  10. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_SoftConstraints.ipynb +110 -104
  11. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/PDE/2D_Poisson_inverse_problem.ipynb +88 -47
  12. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/PDE/2D_non_stationary_OU.ipynb +13 -9
  13. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/PDE/Reaction_Diffusion_2D_heterogenous_model.ipynb +209 -219
  14. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/PDE/Reaction_Diffusion_2D_homogeneous_metamodel_hyperpinn_diffrax.ipynb +98 -75
  15. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/Tutorials/1D_non_stationary_Burgers_JointEstimation_Vanilla.ipynb +15 -15
  16. jinns-1.5.0/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +769 -0
  17. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/Tutorials/introducing_validation_loss.ipynb +2 -2
  18. {jinns-1.4.0 → jinns-1.5.0}/PKG-INFO +5 -2
  19. {jinns-1.4.0 → jinns-1.5.0}/README.md +4 -1
  20. jinns-1.5.0/docs/api/advanced/loss_weight_updates.md +7 -0
  21. jinns-1.5.0/docs/api/loss/loss_weights.md +13 -0
  22. {jinns-1.4.0 → jinns-1.5.0}/docs/changelog.md +6 -0
  23. {jinns-1.4.0 → jinns-1.5.0}/jinns/loss/_DynamicLossAbstract.py +30 -2
  24. {jinns-1.4.0 → jinns-1.5.0}/jinns/loss/_LossODE.py +88 -39
  25. {jinns-1.4.0 → jinns-1.5.0}/jinns/loss/_LossPDE.py +143 -56
  26. {jinns-1.4.0 → jinns-1.5.0}/jinns/loss/__init__.py +4 -0
  27. jinns-1.5.0/jinns/loss/_abstract_loss.py +128 -0
  28. jinns-1.5.0/jinns/loss/_loss_components.py +43 -0
  29. {jinns-1.4.0 → jinns-1.5.0}/jinns/loss/_loss_utils.py +11 -24
  30. jinns-1.5.0/jinns/loss/_loss_weight_updates.py +202 -0
  31. jinns-1.5.0/jinns/loss/_loss_weights.py +83 -0
  32. {jinns-1.4.0 → jinns-1.5.0}/jinns/solver/_solve.py +130 -41
  33. {jinns-1.4.0 → jinns-1.5.0}/jinns/utils/_containers.py +5 -2
  34. {jinns-1.4.0 → jinns-1.5.0}/jinns/utils/_types.py +12 -0
  35. {jinns-1.4.0 → jinns-1.5.0}/jinns.egg-info/PKG-INFO +5 -2
  36. {jinns-1.4.0 → jinns-1.5.0}/jinns.egg-info/SOURCES.txt +8 -0
  37. {jinns-1.4.0 → jinns-1.5.0}/mkdocs.yml +3 -1
  38. jinns-1.5.0/tests/adaptative_weight_tests/test_ReLoBRaLo_update.py +58 -0
  39. jinns-1.5.0/tests/adaptative_weight_tests/test_loss_weight_update.py +210 -0
  40. jinns-1.5.0/tests/adaptative_weight_tests/test_lr_annealing.py +69 -0
  41. jinns-1.5.0/tests/adaptative_weight_tests/test_soft_adapt.py +57 -0
  42. {jinns-1.4.0 → jinns-1.5.0}/tests/loss_tests/test_lossPDEstatio.py +1 -2
  43. {jinns-1.4.0 → jinns-1.5.0}/tests/parameters_tests/test_DerivativeKeysODE.py +1 -2
  44. {jinns-1.4.0 → jinns-1.5.0}/tests/sharding_tests/test_Burgers_x32_multiple_shardings.py +2 -2
  45. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests/test_Burgers_x32.py +1 -1
  46. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests/test_Burgers_x64.py +1 -1
  47. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests/test_Fisher_x32.py +1 -1
  48. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests/test_Fisher_x64.py +1 -1
  49. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests/test_GLV_x32.py +1 -1
  50. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests/test_GLV_x64.py +1 -1
  51. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests/test_NSPipeFlow_x32.py +1 -1
  52. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests/test_NSPipeFlow_x64.py +1 -1
  53. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests/test_OU1D_statio_x32.py +2 -6
  54. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests/test_OU2D_x32.py +1 -1
  55. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests/test_nan_params_catch.py +1 -2
  56. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests/test_parameter_tracker.py +1 -1
  57. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests/test_rar_algorithm.py +1 -2
  58. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests_hyperpinn/test_NSPipeFlow_x32_hyperpinn.py +1 -1
  59. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests_spinn/test_Burgers_x32_spinn.py +1 -1
  60. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests_spinn/test_Fisher_x32_spinn.py +1 -1
  61. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests_spinn/test_NSPipeFlow_x32_spinn.py +1 -1
  62. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests_spinn/test_OU2D_x32_spinn.py +1 -1
  63. {jinns-1.4.0 → jinns-1.5.0}/tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64_spinn.py +1 -2
  64. {jinns-1.4.0 → jinns-1.5.0}/tests/validation_tests/test_vanilla_validation.py +1 -0
  65. jinns-1.4.0/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +0 -878
  66. jinns-1.4.0/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +0 -740
  67. jinns-1.4.0/docs/api/loss/loss_weights.md +0 -9
  68. jinns-1.4.0/jinns/loss/_abstract_loss.py +0 -15
  69. jinns-1.4.0/jinns/loss/_loss_weights.py +0 -27
  70. {jinns-1.4.0 → jinns-1.5.0}/.gitignore +0 -0
  71. {jinns-1.4.0 → jinns-1.5.0}/.gitlab-ci.yml +0 -0
  72. {jinns-1.4.0 → jinns-1.5.0}/.pre-commit-config.yaml +0 -0
  73. {jinns-1.4.0 → jinns-1.5.0}/AUTHORS +0 -0
  74. {jinns-1.4.0 → jinns-1.5.0}/LICENSE +0 -0
  75. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/ODE/sbinn_data/glucose.dat +0 -0
  76. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/ODE/sbinn_data/meal.dat +0 -0
  77. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/PDE/2d_nonstatio_ou_standardsampling.png +0 -0
  78. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/PDE/OU_1D_nonstatio_solution_grid.npy +0 -0
  79. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/Tutorials/burgers_solution_grid.npy +0 -0
  80. {jinns-1.4.0 → jinns-1.5.0}/Notebooks/Tutorials/load_save_model.ipynb +0 -0
  81. {jinns-1.4.0 → jinns-1.5.0}/codemeta.json +0 -0
  82. {jinns-1.4.0 → jinns-1.5.0}/docs/README.md +0 -0
  83. {jinns-1.4.0 → jinns-1.5.0}/docs/_static/custom_css.css +0 -0
  84. {jinns-1.4.0 → jinns-1.5.0}/docs/_static/favicon.png +0 -0
  85. {jinns-1.4.0 → jinns-1.5.0}/docs/api/advanced/derivative_keys.md +0 -0
  86. {jinns-1.4.0 → jinns-1.5.0}/docs/api/advanced/differential_operators.md +0 -0
  87. {jinns-1.4.0 → jinns-1.5.0}/docs/api/datagenerators/datagenerators_core.md +0 -0
  88. {jinns-1.4.0 → jinns-1.5.0}/docs/api/datagenerators/datagenerators_other.md +0 -0
  89. {jinns-1.4.0 → jinns-1.5.0}/docs/api/loss/dynamic_loss.md +0 -0
  90. {jinns-1.4.0 → jinns-1.5.0}/docs/api/loss/loss_xde.md +0 -0
  91. {jinns-1.4.0 → jinns-1.5.0}/docs/api/pinn/hyperpinn.md +0 -0
  92. {jinns-1.4.0 → jinns-1.5.0}/docs/api/pinn/pinn.md +0 -0
  93. {jinns-1.4.0 → jinns-1.5.0}/docs/api/pinn/ppinn.md +0 -0
  94. {jinns-1.4.0 → jinns-1.5.0}/docs/api/pinn/save_load.md +0 -0
  95. {jinns-1.4.0 → jinns-1.5.0}/docs/api/pinn/spinn.md +0 -0
  96. {jinns-1.4.0 → jinns-1.5.0}/docs/api/plot.md +0 -0
  97. {jinns-1.4.0 → jinns-1.5.0}/docs/api/solver.md +0 -0
  98. {jinns-1.4.0 → jinns-1.5.0}/docs/doc_requirements.txt +0 -0
  99. {jinns-1.4.0 → jinns-1.5.0}/docs/index.md +0 -0
  100. {jinns-1.4.0 → jinns-1.5.0}/docs/javascripts/katex.js +0 -0
  101. {jinns-1.4.0 → jinns-1.5.0}/docs/maths/fokker_planck.md +0 -0
  102. {jinns-1.4.0 → jinns-1.5.0}/docs/maths/introduction_to_pinns.md +0 -0
  103. {jinns-1.4.0 → jinns-1.5.0}/img/jinns-diagram.png +0 -0
  104. {jinns-1.4.0 → jinns-1.5.0}/jinns/__init__.py +0 -0
  105. {jinns-1.4.0 → jinns-1.5.0}/jinns/data/_AbstractDataGenerator.py +0 -0
  106. {jinns-1.4.0 → jinns-1.5.0}/jinns/data/_Batchs.py +0 -0
  107. {jinns-1.4.0 → jinns-1.5.0}/jinns/data/_CubicMeshPDENonStatio.py +0 -0
  108. {jinns-1.4.0 → jinns-1.5.0}/jinns/data/_CubicMeshPDEStatio.py +0 -0
  109. {jinns-1.4.0 → jinns-1.5.0}/jinns/data/_DataGeneratorODE.py +0 -0
  110. {jinns-1.4.0 → jinns-1.5.0}/jinns/data/_DataGeneratorObservations.py +0 -0
  111. {jinns-1.4.0 → jinns-1.5.0}/jinns/data/_DataGeneratorParameter.py +0 -0
  112. {jinns-1.4.0 → jinns-1.5.0}/jinns/data/__init__.py +0 -0
  113. {jinns-1.4.0 → jinns-1.5.0}/jinns/data/_utils.py +0 -0
  114. {jinns-1.4.0 → jinns-1.5.0}/jinns/experimental/__init__.py +0 -0
  115. {jinns-1.4.0 → jinns-1.5.0}/jinns/experimental/_diffrax_solver.py +0 -0
  116. {jinns-1.4.0 → jinns-1.5.0}/jinns/loss/_DynamicLoss.py +0 -0
  117. {jinns-1.4.0 → jinns-1.5.0}/jinns/loss/_boundary_conditions.py +0 -0
  118. {jinns-1.4.0 → jinns-1.5.0}/jinns/loss/_operators.py +0 -0
  119. {jinns-1.4.0 → jinns-1.5.0}/jinns/nn/__init__.py +0 -0
  120. {jinns-1.4.0 → jinns-1.5.0}/jinns/nn/_abstract_pinn.py +0 -0
  121. {jinns-1.4.0 → jinns-1.5.0}/jinns/nn/_hyperpinn.py +0 -0
  122. {jinns-1.4.0 → jinns-1.5.0}/jinns/nn/_mlp.py +0 -0
  123. {jinns-1.4.0 → jinns-1.5.0}/jinns/nn/_pinn.py +0 -0
  124. {jinns-1.4.0 → jinns-1.5.0}/jinns/nn/_ppinn.py +0 -0
  125. {jinns-1.4.0 → jinns-1.5.0}/jinns/nn/_save_load.py +0 -0
  126. {jinns-1.4.0 → jinns-1.5.0}/jinns/nn/_spinn.py +0 -0
  127. {jinns-1.4.0 → jinns-1.5.0}/jinns/nn/_spinn_mlp.py +0 -0
  128. {jinns-1.4.0 → jinns-1.5.0}/jinns/nn/_utils.py +0 -0
  129. {jinns-1.4.0 → jinns-1.5.0}/jinns/parameters/__init__.py +0 -0
  130. {jinns-1.4.0 → jinns-1.5.0}/jinns/parameters/_derivative_keys.py +0 -0
  131. {jinns-1.4.0 → jinns-1.5.0}/jinns/parameters/_params.py +0 -0
  132. {jinns-1.4.0 → jinns-1.5.0}/jinns/plot/__init__.py +0 -0
  133. {jinns-1.4.0 → jinns-1.5.0}/jinns/plot/_plot.py +0 -0
  134. {jinns-1.4.0 → jinns-1.5.0}/jinns/solver/__init__.py +0 -0
  135. {jinns-1.4.0 → jinns-1.5.0}/jinns/solver/_rar.py +0 -0
  136. {jinns-1.4.0 → jinns-1.5.0}/jinns/solver/_utils.py +0 -0
  137. {jinns-1.4.0 → jinns-1.5.0}/jinns/utils/__init__.py +0 -0
  138. {jinns-1.4.0 → jinns-1.5.0}/jinns/utils/_utils.py +0 -0
  139. {jinns-1.4.0 → jinns-1.5.0}/jinns/validation/__init__.py +0 -0
  140. {jinns-1.4.0 → jinns-1.5.0}/jinns/validation/_validation.py +0 -0
  141. {jinns-1.4.0 → jinns-1.5.0}/jinns.egg-info/dependency_links.txt +0 -0
  142. {jinns-1.4.0 → jinns-1.5.0}/jinns.egg-info/requires.txt +0 -0
  143. {jinns-1.4.0 → jinns-1.5.0}/jinns.egg-info/top_level.txt +0 -0
  144. {jinns-1.4.0 → jinns-1.5.0}/pyproject.toml +0 -0
  145. {jinns-1.4.0 → jinns-1.5.0}/setup.cfg +0 -0
  146. {jinns-1.4.0 → jinns-1.5.0}/tests/conftest.py +0 -0
  147. {jinns-1.4.0 → jinns-1.5.0}/tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py +0 -0
  148. {jinns-1.4.0 → jinns-1.5.0}/tests/dataGenerator_tests/test_CubicMeshPDEStatio.py +0 -0
  149. {jinns-1.4.0 → jinns-1.5.0}/tests/dataGenerator_tests/test_DataGeneratorODE.py +0 -0
  150. {jinns-1.4.0 → jinns-1.5.0}/tests/dataGenerator_tests/test_DataGeneratorParameter.py +0 -0
  151. {jinns-1.4.0 → jinns-1.5.0}/tests/loss_tests/test_lossODE.py +0 -0
  152. {jinns-1.4.0 → jinns-1.5.0}/tests/loss_tests/test_lossPDEnonstatio.py +0 -0
  153. {jinns-1.4.0 → jinns-1.5.0}/tests/loss_tests/test_norm_loss.py +0 -0
  154. {jinns-1.4.0 → jinns-1.5.0}/tests/nn_tests/test_hyperpinns.py +0 -0
  155. {jinns-1.4.0 → jinns-1.5.0}/tests/nn_tests/test_mlp.py +0 -0
  156. {jinns-1.4.0 → jinns-1.5.0}/tests/nn_tests/test_pinn.py +0 -0
  157. {jinns-1.4.0 → jinns-1.5.0}/tests/nn_tests/test_ppinn_mlp.py +0 -0
  158. {jinns-1.4.0 → jinns-1.5.0}/tests/nn_tests/test_smlp.py +0 -0
  159. {jinns-1.4.0 → jinns-1.5.0}/tests/nn_tests/test_spinn.py +0 -0
  160. {jinns-1.4.0 → jinns-1.5.0}/tests/operator_tests/test_divergence_fwd.py +0 -0
  161. {jinns-1.4.0 → jinns-1.5.0}/tests/operator_tests/test_divergence_rev.py +0 -0
  162. {jinns-1.4.0 → jinns-1.5.0}/tests/operator_tests/test_laplacian_fwd.py +0 -0
  163. {jinns-1.4.0 → jinns-1.5.0}/tests/operator_tests/test_laplacian_rev.py +0 -0
  164. {jinns-1.4.0 → jinns-1.5.0}/tests/operator_tests/test_vectorial_laplacian_fwd.py +0 -0
  165. {jinns-1.4.0 → jinns-1.5.0}/tests/operator_tests/test_vectorial_laplacian_rev.py +0 -0
  166. {jinns-1.4.0 → jinns-1.5.0}/tests/plot_tests/test_plot1D.py +0 -0
  167. {jinns-1.4.0 → jinns-1.5.0}/tests/plot_tests/test_plot2D.py +0 -0
  168. {jinns-1.4.0 → jinns-1.5.0}/tests/save_load_tests/test_saving_loading_hyperpinn.py +0 -0
  169. {jinns-1.4.0 → jinns-1.5.0}/tests/save_load_tests/test_saving_loading_pinn.py +0 -0
  170. {jinns-1.4.0 → jinns-1.5.0}/tests/save_load_tests/test_saving_loading_spinn.py +0 -0
  171. {jinns-1.4.0 → jinns-1.5.0}/tests/utils_tests/test_solver_utils.py +0 -0
  172. {jinns-1.4.0 → jinns-1.5.0}/tests/utils_tests/test_subtract_with_check.py +0 -0
@@ -133,7 +133,7 @@
133
133
  " (eqx.nn.Linear, 64, 64),\n",
134
134
  " (jax.nn.tanh,),\n",
135
135
  " (eqx.nn.Linear, 64, 3),\n",
136
- " (jnp.exp,)\n",
136
+ " (jnp.exp,),\n",
137
137
  ")\n",
138
138
  "key, subkey = random.split(key)"
139
139
  ]
@@ -154,19 +154,14 @@
154
154
  "outputs": [],
155
155
  "source": [
156
156
  "nt = 5000\n",
157
- "method = 'uniform'\n",
157
+ "method = \"uniform\"\n",
158
158
  "tmin = 0\n",
159
159
  "tmax = 1\n",
160
160
  "\n",
161
161
  "Tmax = 30\n",
162
162
  "key, subkey = random.split(key)\n",
163
163
  "train_data = jinns.data.DataGeneratorODE(\n",
164
- " key=subkey,\n",
165
- " nt=nt,\n",
166
- " tmin=tmin,\n",
167
- " tmax=tmax,\n",
168
- " temporal_batch_size=512,\n",
169
- " method=method\n",
164
+ " key=subkey, nt=nt, tmin=tmin, tmax=tmax, temporal_batch_size=512, method=method\n",
170
165
  ")"
171
166
  ]
172
167
  },
@@ -185,7 +180,9 @@
185
180
  "metadata": {},
186
181
  "outputs": [],
187
182
  "source": [
188
- "u, init_nn_params = jinns.nn.PINN_MLP.create(key=subkey, eqx_list=eqx_list, eq_type=\"ODE\")"
183
+ "u, init_nn_params = jinns.nn.PINN_MLP.create(\n",
184
+ " key=subkey, eqx_list=eqx_list, eq_type=\"ODE\"\n",
185
+ ")"
189
186
  ]
190
187
  },
191
188
  {
@@ -207,7 +204,8 @@
207
204
  "source": [
208
205
  "# initial conditions for each species\n",
209
206
  "import numpy as onp\n",
210
- "N_0 = onp.array([10., 7., 4.])\n",
207
+ "\n",
208
+ "N_0 = onp.array([10.0, 7.0, 4.0])\n",
211
209
  "# growth rates for each species\n",
212
210
  "growth_rates = jnp.array([0.1, 0.5, 0.8])\n",
213
211
  "# carrying capacity for each species\n",
@@ -215,10 +213,11 @@
215
213
  "# interactions\n",
216
214
  "# NOTE that interaction with oneself is 0\n",
217
215
  "# NOTE minus sign\n",
218
- "interactions = (-jnp.array([0, 0.001, 0.001]),\n",
219
- " -jnp.array([0.001, 0, 0.001]), \n",
220
- " -jnp.array([0.001, 0.001, 0])\n",
221
- " )"
216
+ "interactions = (\n",
217
+ " -jnp.array([0, 0.001, 0.001]),\n",
218
+ " -jnp.array([0.001, 0, 0.001]),\n",
219
+ " -jnp.array([0.001, 0.001, 0]),\n",
220
+ ")"
222
221
  ]
223
222
  },
224
223
  {
@@ -286,9 +285,21 @@
286
285
  "source": [
287
286
  "vectorized_u_init = vmap(lambda t: u(t, init_params), (0), 0)\n",
288
287
  "\n",
289
- "plt.plot(train_data.times.sort(axis=0) * Tmax, vectorized_u_init(train_data.times.sort(axis=0))[:, 0], label=\"N1\")\n",
290
- "plt.plot(train_data.times.sort(axis=0) * Tmax, vectorized_u_init(train_data.times.sort(axis=0))[:, 1], label=\"N2\")\n",
291
- "plt.plot(train_data.times.sort(axis=0) * Tmax, vectorized_u_init(train_data.times.sort(axis=0))[:, 2], label=\"N3\")\n",
288
+ "plt.plot(\n",
289
+ " train_data.times.sort(axis=0) * Tmax,\n",
290
+ " vectorized_u_init(train_data.times.sort(axis=0))[:, 0],\n",
291
+ " label=\"N1\",\n",
292
+ ")\n",
293
+ "plt.plot(\n",
294
+ " train_data.times.sort(axis=0) * Tmax,\n",
295
+ " vectorized_u_init(train_data.times.sort(axis=0))[:, 1],\n",
296
+ " label=\"N2\",\n",
297
+ ")\n",
298
+ "plt.plot(\n",
299
+ " train_data.times.sort(axis=0) * Tmax,\n",
300
+ " vectorized_u_init(train_data.times.sort(axis=0))[:, 2],\n",
301
+ " label=\"N3\",\n",
302
+ ")\n",
292
303
  "\n",
293
304
  "plt.legend()"
294
305
  ]
@@ -333,7 +344,7 @@
333
344
  " loss_weights=loss_weights,\n",
334
345
  " dynamic_loss=dynamic_loss,\n",
335
346
  " initial_condition=(float(tmin), jnp.array([N_0[0], N_0[1], N_0[2]])),\n",
336
- " params=init_params\n",
347
+ " params=init_params,\n",
337
348
  ")"
338
349
  ]
339
350
  },
@@ -357,10 +368,7 @@
357
368
  "train_data, batch = train_data.get_batch()\n",
358
369
  "\n",
359
370
  "losses_and_grad = jax.value_and_grad(loss.evaluate, 0, has_aux=True)\n",
360
- "losses, grads = losses_and_grad(\n",
361
- " init_params,\n",
362
- " batch\n",
363
- ")\n",
371
+ "losses, grads = losses_and_grad(init_params, batch)\n",
364
372
  "l_tot, d = losses\n",
365
373
  "print(f\"total loss: {l_tot}\")\n",
366
374
  "print(f\"Individual losses: { {key: f'{val:.2f}' for key, val in d.items()} }\")"
@@ -406,13 +414,14 @@
406
414
  " init_value=start_learning_rate,\n",
407
415
  " transition_begin=5000,\n",
408
416
  " transition_steps=100,\n",
409
- " decay_rate=0.99)\n",
417
+ " decay_rate=0.99,\n",
418
+ ")\n",
410
419
  "\n",
411
420
  "tx = optax.chain(\n",
412
421
  " optax.scale_by_adam(), # Use the updates from adam.\n",
413
422
  " optax.scale_by_schedule(scheduler), # Use the learning rate from the scheduler.\n",
414
423
  " # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.\n",
415
- " optax.scale(-1.0)\n",
424
+ " optax.scale(-1.0),\n",
416
425
  ")"
417
426
  ]
418
427
  },
@@ -487,12 +496,8 @@
487
496
  }
488
497
  ],
489
498
  "source": [
490
- "params, total_loss_list, loss_by_term_dict, data, loss, _, _ , _, _ = jinns.solve(\n",
491
- " init_params=params,\n",
492
- " data=train_data,\n",
493
- " optimizer=tx,\n",
494
- " loss=loss,\n",
495
- " n_iter=n_iter\n",
499
+ "params, total_loss_list, loss_by_term_dict, data, loss, _, _, _, _, _ = jinns.solve(\n",
500
+ " init_params=params, data=train_data, optimizer=tx, loss=loss, n_iter=n_iter\n",
496
501
  ")"
497
502
  ]
498
503
  },
@@ -575,20 +580,23 @@
575
580
  }
576
581
  ],
577
582
  "source": [
578
- "u_est = vmap(lambda t:u(t, params), (0), 0)\n",
583
+ "u_est = vmap(lambda t: u(t, params), (0), 0)\n",
579
584
  "\n",
580
585
  "key, subkey = random.split(key, 2)\n",
581
- "val_data = jinns.data.DataGeneratorODE(key=subkey, nt=nt, tmin=tmin, tmax=tmax, method=method)\n",
586
+ "val_data = jinns.data.DataGeneratorODE(\n",
587
+ " key=subkey, nt=nt, tmin=tmin, tmax=tmax, method=method\n",
588
+ ")\n",
582
589
  "\n",
583
590
  "import pandas as pd\n",
591
+ "\n",
584
592
  "ts = val_data.times.sort(axis=0).squeeze()\n",
585
593
  "df = pd.DataFrame(\n",
586
594
  " {\n",
587
- " \"t\": ts * Tmax, # rescale time for plotting\n",
595
+ " \"t\": ts * Tmax, # rescale time for plotting\n",
588
596
  " \"N1\": u_est(ts)[:, 0],\n",
589
597
  " \"N2\": u_est(ts)[:, 1],\n",
590
598
  " \"N3\": u_est(ts)[:, 2],\n",
591
- " \"Method\": \"PINN\"\n",
599
+ " \"Method\": \"PINN\",\n",
592
600
  " },\n",
593
601
  ")\n",
594
602
  "df.plot(x=\"t\")"
@@ -615,6 +623,7 @@
615
623
  "\n",
616
624
  "# NOTE the following line is not accurate as it skips one batch\n",
617
625
  "\n",
626
+ "\n",
618
627
  "def lotka_volterra_log(y_log, t, eq_params):\n",
619
628
  " \"\"\"\n",
620
629
  " Generalized Lotka-Volterra model for N bacterial species, with logarithmic transformation for stability.\n",
@@ -633,14 +642,19 @@
633
642
  " dydt = np.zeros(N)\n",
634
643
  "\n",
635
644
  " for i in range(N):\n",
636
- " dydt[i] = y[i] * (alpha[i] - beta[i] * np.sum(y) - np.sum([gamma[j][i] * y[j] for j in range(N)]))\n",
645
+ " dydt[i] = y[i] * (\n",
646
+ " alpha[i]\n",
647
+ " - beta[i] * np.sum(y)\n",
648
+ " - np.sum([gamma[j][i] * y[j] for j in range(N)])\n",
649
+ " )\n",
637
650
  "\n",
638
651
  " dydt_log = dydt / y\n",
639
652
  "\n",
640
653
  " return dydt_log\n",
641
654
  "\n",
655
+ "\n",
642
656
  "# Define name bacteria\n",
643
- "names = ['N1', 'N2', 'N3']\n",
657
+ "names = [\"N1\", \"N2\", \"N3\"]\n",
644
658
  "N = len(names)\n",
645
659
  "\n",
646
660
  "# Define model parameters\n",
@@ -671,14 +685,11 @@
671
685
  "# comparative plots\n",
672
686
  "df_scipy = pd.DataFrame(\n",
673
687
  " {\n",
674
- " \"t\": ts * Tmax, # rescale time for plotting\n",
675
- " \"Method\": \"Scipy solver\"\n",
676
- " } |\n",
677
- " {\n",
678
- " f\"N{i+1}\": y[:,i] for i in range(3)\n",
679
- " },\n",
680
- ")\n",
681
- "\n"
688
+ " \"t\": ts * Tmax, # rescale time for plotting\n",
689
+ " \"Method\": \"Scipy solver\",\n",
690
+ " }\n",
691
+ " | {f\"N{i + 1}\": y[:, i] for i in range(3)},\n",
692
+ ")"
682
693
  ]
683
694
  },
684
695
  {
@@ -710,8 +721,20 @@
710
721
  ],
711
722
  "source": [
712
723
  "import seaborn as sns\n",
713
- "df_plot = pd.concat((df, df_scipy)).melt(id_vars=['Method', \"t\"], var_name=\"Population\", value_name=\"Solution\")\n",
714
- "sns.relplot(df_plot, kind='line', x='t', y='Solution', hue='Population', style='Method', height=4, aspect=2)"
724
+ "\n",
725
+ "df_plot = pd.concat((df, df_scipy)).melt(\n",
726
+ " id_vars=[\"Method\", \"t\"], var_name=\"Population\", value_name=\"Solution\"\n",
727
+ ")\n",
728
+ "sns.relplot(\n",
729
+ " df_plot,\n",
730
+ " kind=\"line\",\n",
731
+ " x=\"t\",\n",
732
+ " y=\"Solution\",\n",
733
+ " hue=\"Population\",\n",
734
+ " style=\"Method\",\n",
735
+ " height=4,\n",
736
+ " aspect=2,\n",
737
+ ")"
715
738
  ]
716
739
  },
717
740
  {
@@ -749,7 +772,7 @@
749
772
  "name": "python",
750
773
  "nbconvert_exporter": "python",
751
774
  "pygments_lexer": "ipython3",
752
- "version": "3.11.2"
775
+ "version": "3.11.11"
753
776
  },
754
777
  "vscode": {
755
778
  "interpreter": {