jinns 0.4.2__tar.gz → 0.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103) hide show
  1. jinns-0.5.1/Notebooks/ODE/1D_Generalized_Lotka_Volterra.ipynb +761 -0
  2. {jinns-0.4.2 → jinns-0.5.1}/Notebooks/ODE/1D_Generalized_Lotka_Volterra_seq2seq.ipynb +3 -3
  3. jinns-0.5.1/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +1258 -0
  4. jinns-0.5.1/Notebooks/PDE/1D_non_stationary_Burger.ipynb +768 -0
  5. jinns-0.5.1/Notebooks/PDE/1D_non_stationary_Fisher_KPP_Bounded_Domain.ipynb +858 -0
  6. {jinns-0.4.2 → jinns-0.5.1}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow.ipynb +59 -73
  7. jinns-0.5.1/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel.ipynb +845 -0
  8. jinns-0.5.1/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_SoftConstraints.ipynb +1171 -0
  9. jinns-0.5.1/Notebooks/PDE/2D_non_stationary_OU.ipynb +1152 -0
  10. jinns-0.5.1/Notebooks/PDE/Reaction_Diffusion_2D_heterogenous_model.ipynb +1836 -0
  11. jinns-0.5.1/Notebooks/PDE/imperfect_modeling_sobolev_reg.ipynb +1053 -0
  12. jinns-0.5.1/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +714 -0
  13. {jinns-0.4.2 → jinns-0.5.1}/PKG-INFO +13 -2
  14. {jinns-0.4.2 → jinns-0.5.1}/README.md +12 -1
  15. {jinns-0.4.2 → jinns-0.5.1}/doc/source/conf.py +2 -0
  16. {jinns-0.4.2 → jinns-0.5.1}/doc/source/index.rst +14 -0
  17. {jinns-0.4.2 → jinns-0.5.1}/jinns/data/_display.py +78 -21
  18. jinns-0.5.1/jinns/loss/_DynamicLoss.py +974 -0
  19. {jinns-0.4.2 → jinns-0.5.1}/jinns/loss/_DynamicLossAbstract.py +17 -10
  20. {jinns-0.4.2 → jinns-0.5.1}/jinns/loss/_LossODE.py +8 -14
  21. {jinns-0.4.2 → jinns-0.5.1}/jinns/loss/_LossPDE.py +351 -202
  22. {jinns-0.4.2 → jinns-0.5.1}/jinns/loss/__init__.py +0 -6
  23. jinns-0.5.1/jinns/loss/_boundary_conditions.py +468 -0
  24. jinns-0.5.1/jinns/loss/_operators.py +315 -0
  25. {jinns-0.4.2 → jinns-0.5.1}/jinns/solver/_solve.py +10 -5
  26. {jinns-0.4.2 → jinns-0.5.1}/jinns/utils/__init__.py +2 -2
  27. jinns-0.5.1/jinns/utils/_pinn.py +298 -0
  28. jinns-0.5.1/jinns/utils/_spinn.py +238 -0
  29. jinns-0.5.1/jinns/utils/_utils.py +157 -0
  30. {jinns-0.4.2 → jinns-0.5.1}/jinns.egg-info/PKG-INFO +13 -2
  31. {jinns-0.4.2 → jinns-0.5.1}/jinns.egg-info/SOURCES.txt +11 -3
  32. jinns-0.5.1/tests/runtests.sh +11 -0
  33. {jinns-0.4.2 → jinns-0.5.1}/tests/solver_tests/test_Burger_x32.py +2 -3
  34. {jinns-0.4.2 → jinns-0.5.1}/tests/solver_tests/test_Burger_x64.py +2 -3
  35. {jinns-0.4.2 → jinns-0.5.1}/tests/solver_tests/test_Fisher_x32.py +2 -3
  36. {jinns-0.4.2 → jinns-0.5.1}/tests/solver_tests/test_Fisher_x64.py +2 -3
  37. {jinns-0.4.2 → jinns-0.5.1}/tests/solver_tests/test_GLV_x32.py +4 -5
  38. {jinns-0.4.2 → jinns-0.5.1}/tests/solver_tests/test_GLV_x64.py +4 -5
  39. {jinns-0.4.2 → jinns-0.5.1}/tests/solver_tests/test_NSPipeFlow_x32.py +38 -33
  40. {jinns-0.4.2 → jinns-0.5.1}/tests/solver_tests/test_NSPipeFlow_x64.py +38 -32
  41. jinns-0.5.1/tests/solver_tests/test_OU2D_x32.py +152 -0
  42. jinns-0.5.1/tests/solver_tests/test_imperfect_sobolev_x32.py +158 -0
  43. jinns-0.5.1/tests/solver_tests_spinn/test_Burger_x32.py +114 -0
  44. jinns-0.5.1/tests/solver_tests_spinn/test_Fisher_x32.py +140 -0
  45. jinns-0.5.1/tests/solver_tests_spinn/test_NSPipeFlow_x32_spinn.py +165 -0
  46. jinns-0.5.1/tests/solver_tests_spinn/test_OU2D_x32.py +144 -0
  47. jinns-0.5.1/tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64.py +173 -0
  48. jinns-0.4.2/Notebooks/ODE/1D_Generalized_Lotka_Volterra.ipynb +0 -707
  49. jinns-0.4.2/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +0 -1143
  50. jinns-0.4.2/Notebooks/PDE/1D_non_stationary_Burger.ipynb +0 -692
  51. jinns-0.4.2/Notebooks/PDE/1D_non_stationary_Fisher_KPP_Bounded_Domain.ipynb +0 -693
  52. jinns-0.4.2/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel.ipynb +0 -800
  53. jinns-0.4.2/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_SoftConstraints.ipynb +0 -795
  54. jinns-0.4.2/Notebooks/PDE/2D_non_stationary_OU_RAR.ipynb +0 -754
  55. jinns-0.4.2/Notebooks/PDE/Reaction_Diffusion_2D_heterogenous_model.ipynb +0 -1085
  56. jinns-0.4.2/Notebooks/PDE/imperfect_modeling_sobolev_reg.ipynb +0 -1060
  57. jinns-0.4.2/Notebooks/Tutorials/implementing_your_own_ODE_problem.ipynb +0 -59
  58. jinns-0.4.2/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +0 -716
  59. jinns-0.4.2/jinns/loss/_DynamicLoss.py +0 -1472
  60. jinns-0.4.2/jinns/loss/_boundary_conditions.py +0 -302
  61. jinns-0.4.2/jinns/loss/_operators.py +0 -162
  62. jinns-0.4.2/jinns/utils/_utils.py +0 -542
  63. jinns-0.4.2/tests/runtests.sh +0 -5
  64. {jinns-0.4.2 → jinns-0.5.1}/.gitignore +0 -0
  65. {jinns-0.4.2 → jinns-0.5.1}/.gitlab-ci.yml +0 -0
  66. {jinns-0.4.2 → jinns-0.5.1}/.pre-commit-config.yaml +0 -0
  67. {jinns-0.4.2 → jinns-0.5.1}/LICENSE +0 -0
  68. {jinns-0.4.2 → jinns-0.5.1}/Notebooks/PDE/1D_non_stationary_Burger_JointEstimation_Vanilla.ipynb +0 -0
  69. {jinns-0.4.2 → jinns-0.5.1}/Notebooks/PDE/1D_non_stationary_OU.ipynb +0 -0
  70. {jinns-0.4.2 → jinns-0.5.1}/Notebooks/PDE/2d_nonstatio_ou_standardsampling.png +0 -0
  71. {jinns-0.4.2 → jinns-0.5.1}/Notebooks/PDE/OU_1D_nonstatio_solution_grid.npy +0 -0
  72. {jinns-0.4.2 → jinns-0.5.1}/Notebooks/PDE/burger_solution_grid.npy +0 -0
  73. {jinns-0.4.2 → jinns-0.5.1}/doc/Makefile +0 -0
  74. {jinns-0.4.2 → jinns-0.5.1}/doc/source/PinnSolver.rst +0 -0
  75. {jinns-0.4.2 → jinns-0.5.1}/doc/source/boundary_conditions.rst +0 -0
  76. {jinns-0.4.2 → jinns-0.5.1}/doc/source/data.rst +0 -0
  77. {jinns-0.4.2 → jinns-0.5.1}/doc/source/dynamic_loss.rst +0 -0
  78. {jinns-0.4.2 → jinns-0.5.1}/doc/source/fokker_planck.qmd +0 -0
  79. {jinns-0.4.2 → jinns-0.5.1}/doc/source/loss.rst +0 -0
  80. {jinns-0.4.2 → jinns-0.5.1}/doc/source/loss_ode.rst +0 -0
  81. {jinns-0.4.2 → jinns-0.5.1}/doc/source/loss_pde.rst +0 -0
  82. {jinns-0.4.2 → jinns-0.5.1}/doc/source/math_pinn.qmd +0 -0
  83. {jinns-0.4.2 → jinns-0.5.1}/doc/source/operators.rst +0 -0
  84. {jinns-0.4.2 → jinns-0.5.1}/doc/source/param_estim_pinn.qmd +0 -0
  85. {jinns-0.4.2 → jinns-0.5.1}/doc/source/rar.rst +0 -0
  86. {jinns-0.4.2 → jinns-0.5.1}/doc/source/seq2seq.rst +0 -0
  87. {jinns-0.4.2 → jinns-0.5.1}/doc/source/solver.rst +0 -0
  88. {jinns-0.4.2 → jinns-0.5.1}/doc/source/utils.rst +0 -0
  89. {jinns-0.4.2 → jinns-0.5.1}/jinns/__init__.py +0 -0
  90. {jinns-0.4.2 → jinns-0.5.1}/jinns/data/_DataGenerators.py +0 -0
  91. {jinns-0.4.2 → jinns-0.5.1}/jinns/data/__init__.py +0 -0
  92. {jinns-0.4.2 → jinns-0.5.1}/jinns/solver/__init__.py +0 -0
  93. {jinns-0.4.2 → jinns-0.5.1}/jinns/solver/_rar.py +0 -0
  94. {jinns-0.4.2 → jinns-0.5.1}/jinns/solver/_seq2seq.py +0 -0
  95. {jinns-0.4.2 → jinns-0.5.1}/jinns.egg-info/dependency_links.txt +0 -0
  96. {jinns-0.4.2 → jinns-0.5.1}/jinns.egg-info/requires.txt +0 -0
  97. {jinns-0.4.2 → jinns-0.5.1}/jinns.egg-info/top_level.txt +0 -0
  98. {jinns-0.4.2 → jinns-0.5.1}/pyproject.toml +0 -0
  99. {jinns-0.4.2 → jinns-0.5.1}/setup.cfg +0 -0
  100. {jinns-0.4.2 → jinns-0.5.1}/tests/conftest.py +0 -0
  101. {jinns-0.4.2 → jinns-0.5.1}/tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py +0 -0
  102. {jinns-0.4.2 → jinns-0.5.1}/tests/dataGenerator_tests/test_CubicMeshPDEStatio.py +0 -0
  103. {jinns-0.4.2 → jinns-0.5.1}/tests/dataGenerator_tests/test_DataGeneratorODE.py +0 -0
@@ -0,0 +1,761 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "40925fec",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Generalized Lotka Volterra"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "b825efcb",
14
+ "metadata": {},
15
+ "source": [
16
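+ "We consider a Generalized Lotka-Volterra system with $3$ populations, where $r_i$ denotes the growth rate, $c_i$ the carrying capacity and $\\alpha_{ij}$ the interaction coefficients of population $i$:\n",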
17
+ "$$\n",
18
+ "\\frac{\\partial}{\\partial t}u_i(t) = r_iu_i(t) - \\sum_{j\\neq i}\\alpha_{ij}u_j(t)\n",
19
+ "-\\alpha_{ii}u_i(t) + c_iu_i(t) + \\sum_{j \\neq i} c_ju_j(t), \\quad i\\in\\{1, 2, 3\\}\n",
20
+ "$$"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "markdown",
25
+ "id": "f337b94d",
26
+ "metadata": {},
27
+ "source": [
28
+ "More information on this ODE system can be found at [https://stefanoallesina.github.io/Sao_Paulo_School/intro.html#basic-formulation](https://stefanoallesina.github.io/Sao_Paulo_School/intro.html#basic-formulation)"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 1,
34
+ "id": "8bf8bebc-b311-4eb4-ad63-11447f62b280",
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "%load_ext autoreload\n",
39
+ "%autoreload 2\n",
40
+ "%matplotlib inline"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "id": "ddee93b7",
46
+ "metadata": {},
47
+ "source": [
48
+ "Float64 and GPU settings"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 2,
54
+ "id": "5cdc87e2",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "#import os; os.environ[\"JAX_ENABLE_X64\"] = \"TRUE\" # comment/uncomment to disable/enable float64 for JAX\n",
59
+ "#import os; os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"\" # If uncommented then GPU is disabled"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "markdown",
64
+ "id": "e42b1b48",
65
+ "metadata": {},
66
+ "source": [
67
+ "Import our package"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": 3,
73
+ "id": "fbdd16f7",
74
+ "metadata": {
75
+ "scrolled": true
76
+ },
77
+ "outputs": [],
78
+ "source": [
79
+ "import jinns"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "markdown",
84
+ "id": "09955058",
85
+ "metadata": {},
86
+ "source": [
87
+ "Import other dependencies"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 4,
93
+ "id": "3abe5254-7556-424e-a57e-d364d67244a1",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "import jax\n",
98
+ "from jax import random, vmap\n",
99
+ "import jax.numpy as jnp\n",
100
+ "import equinox as eqx\n",
101
+ "\n",
102
+ "import matplotlib.pyplot as plt\n",
103
+ "\n",
104
+ "key = random.PRNGKey(2)\n",
105
+ "key, subkey = random.split(key)"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "id": "2bfbd766",
111
+ "metadata": {},
112
+ "source": [
113
+ "Create the neural network architecture for the PINN with `equinox`. Note that we will use the same architecture for the 3 populations."
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": 5,
119
+ "id": "9396d007-04f1-4893-a3c8-c58c36845ee0",
120
+ "metadata": {},
121
+ "outputs": [],
122
+ "source": [
123
+ "eqx_list = [\n",
124
+ " [eqx.nn.Linear, 1, 20],\n",
125
+ " [jax.nn.tanh],\n",
126
+ " [eqx.nn.Linear, 20, 20],\n",
127
+ " [jax.nn.tanh],\n",
128
+ " [eqx.nn.Linear, 20, 20],\n",
129
+ " [jax.nn.tanh],\n",
130
+ " [eqx.nn.Linear, 20, 1],\n",
131
+ " [jnp.exp]\n",
132
+ "]\n",
133
+ "key, subkey = random.split(key)\n",
134
+ "u = jinns.utils.create_PINN(subkey, eqx_list, \"ODE\")"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": 6,
140
+ "id": "1e47cbca-3af2-4ab2-a379-4b763c383843",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "init_nn_params = u.init_params()"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "markdown",
149
+ "id": "0a5d567b",
150
+ "metadata": {},
151
+ "source": [
152
+ "Create a DataGenerator object"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 7,
158
+ "id": "15088440",
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": [
162
+ "n = 320\n",
163
+ "batch_size = 32\n",
164
+ "method = 'uniform'\n",
165
+ "tmin = 0\n",
166
+ "tmax = 1\n",
167
+ "\n",
168
+ "Tmax = 30\n",
169
+ "key, subkey = random.split(key)\n",
170
+ "train_data = jinns.data.DataGeneratorODE(\n",
171
+ " subkey,\n",
172
+ " n,\n",
173
+ " tmin,\n",
174
+ " tmax,\n",
175
+ " batch_size, \n",
176
+ " method=method\n",
177
+ ")"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "markdown",
182
+ "id": "4f1ac783",
183
+ "metadata": {},
184
+ "source": [
185
+ "Initialize 3 sets of neural network parameters, one for each of the 3 populations"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": 8,
191
+ "id": "4fec8c54",
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "init_nn_params_list = []\n",
196
+ "for _ in range(3):\n",
197
+ " key, subkey = random.split(key)\n",
198
+ " u = jinns.utils.create_PINN(subkey, eqx_list, \"ODE\", 0)\n",
199
+ " init_nn_params = u.init_params()\n",
200
+ " init_nn_params_list.append(init_nn_params)"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "markdown",
205
+ "id": "e595cbcd",
206
+ "metadata": {},
207
+ "source": [
208
+ "Visualize the output of the neural networks before the parameter learning step"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": 9,
214
+ "id": "d18c73d2-ff23-4019-a7a1-40cc023dbf53",
215
+ "metadata": {},
216
+ "outputs": [
217
+ {
218
+ "data": {
219
+ "text/plain": [
220
+ "<matplotlib.legend.Legend at 0x7f16fc34c790>"
221
+ ]
222
+ },
223
+ "execution_count": 9,
224
+ "metadata": {},
225
+ "output_type": "execute_result"
226
+ },
227
+ {
228
+ "data": {
229
+ "image/png": "",
230
+ "text/plain": [
231
+ "<Figure size 640x480 with 1 Axes>"
232
+ ]
233
+ },
234
+ "metadata": {},
235
+ "output_type": "display_data"
236
+ }
237
+ ],
238
+ "source": [
239
+ "vectorized_u_init = vmap(lambda t: u(t, init_nn_params_list[0]), (0), 0)\n",
240
+ "vectorized_v_init = vmap(lambda t: u(t, init_nn_params_list[1]), (0), 0)\n",
241
+ "vectorized_w_init = vmap(lambda t: u(t, init_nn_params_list[2]), (0), 0)\n",
242
+ "\n",
243
+ "\n",
244
+ "plt.plot(train_data.times.sort(axis=0) * Tmax, vectorized_u_init(train_data.times.sort(axis=0)), label=\"N1\")\n",
245
+ "plt.plot(train_data.times.sort(axis=0) * Tmax, vectorized_v_init(train_data.times.sort(axis=0)), label=\"N2\")\n",
246
+ "plt.plot(train_data.times.sort(axis=0) * Tmax, vectorized_w_init(train_data.times.sort(axis=0)), label=\"N3\")\n",
247
+ "\n",
248
+ "plt.legend()"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "markdown",
253
+ "id": "f0701671",
254
+ "metadata": {},
255
+ "source": [
256
+ "## Model parameters"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": 10,
262
+ "id": "8c609f60",
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": [
266
+ "# initial conditions for each species\n",
267
+ "N_0 = jnp.array([10., 7., 4.])\n",
268
+ "# growth rates for each species\n",
269
+ "growth_rates = jnp.array([0.1, 0.5, 0.8])\n",
270
+ "# carrying capacity for each species\n",
271
+ "carrying_capacities = jnp.array([0.04, 0.02, 0.02])\n",
272
+ "# interactions\n",
273
+ "# NOTE that the interaction of each species **with itself** is always at position 0\n",
274
+ "# NOTE minus sign \n",
275
+ "interactions = -jnp.array([[0, 0.001, 0.001], [0, 0.001, 0.001], [0, 0.001, 0.001]])"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "markdown",
280
+ "id": "bbe7f24b",
281
+ "metadata": {},
282
+ "source": [
283
+ "## Loss construction"
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "markdown",
288
+ "id": "33e3c866",
289
+ "metadata": {},
290
+ "source": [
291
+ "A set of parameters as required by the losses' `evaluate` functions is a dictionary with the neural network parameters `nn_params` and the equation parameters `eq_params`. Here we construct this dictionary.\n",
292
+ "\n",
293
+ "__Note__ that `nn_params` and `eq_params` must always be top level keys but can be nested dictionaries.\n",
294
+ "\n",
295
+ "__Note__ that the keys of the sub-dictionaries `nn_params` and `eq_params` (here `str(i)`) can differ!"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": 11,
301
+ "id": "11d93e85",
302
+ "metadata": {},
303
+ "outputs": [],
304
+ "source": [
305
+ "# initiate parameters dictionary\n",
306
+ "init_params = {}\n",
307
+ "\n",
308
+ "init_params[\"nn_params\"] = {\n",
309
+ " str(i): init_nn_params_list[i]\n",
310
+ " for i in range(3)\n",
311
+ "}\n",
312
+ "\n",
313
+ "init_params[\"eq_params\"] = {\n",
314
+ " str(i):{\n",
315
+ " \"carrying_capacity\": carrying_capacities[i],\n",
316
+ " \"growth_rate\": growth_rates[i],\n",
317
+ " \"interactions\": interactions[i, :]\n",
318
+ " }\n",
319
+ " for i in range(3)\n",
320
+ "}"
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "markdown",
325
+ "id": "aa14a602-1b0e-4582-876e-99d0322c57a0",
326
+ "metadata": {},
327
+ "source": [
328
+ "We construct a `SystemLossODE` with `GeneralizedLotkaVolterra` losses for each population. Here `key_main` refers to the key in `params[\"nn_params\"]` holding the parameters of the main PINN of the equation (the PINN representing the solution that is differentiated with respect to `t`). `keys_other` refers to the keys in `params[\"nn_params\"]` holding the parameters of the PINNs which interact with `key_main`."
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": 12,
334
+ "id": "b6a65062",
335
+ "metadata": {},
336
+ "outputs": [],
337
+ "source": [
338
+ "N1_dynamic_loss = jinns.loss.GeneralizedLotkaVolterra(key_main=\"0\", keys_other=[\"1\", \"2\"], Tmax=Tmax)\n",
339
+ "N2_dynamic_loss = jinns.loss.GeneralizedLotkaVolterra(key_main=\"1\", keys_other=[\"0\", \"2\"], Tmax=Tmax)\n",
340
+ "N3_dynamic_loss = jinns.loss.GeneralizedLotkaVolterra(key_main=\"2\", keys_other=[\"0\", \"1\"], Tmax=Tmax)"
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "code",
345
+ "execution_count": 13,
346
+ "id": "22b2647f",
347
+ "metadata": {},
348
+ "outputs": [],
349
+ "source": [
350
+ "loss_weights = {\"dyn_loss\":1, \"initial_condition\":1 * Tmax}\n",
351
+ "\n",
352
+ "loss = jinns.loss.SystemLossODE(\n",
353
+ " u_dict={\"0\":u, \"1\":u, \"2\":u},\n",
354
+ " loss_weights=loss_weights,\n",
355
+ " dynamic_loss_dict={\"0\": N1_dynamic_loss, \"1\":N2_dynamic_loss, \"2\":N3_dynamic_loss},\n",
356
+ " initial_condition_dict={\"0\":(float(tmin), N_0[0]), \"1\":(float(tmin), N_0[1]), \"2\":(float(tmin), N_0[2])}\n",
357
+ ")"
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "execution_count": 14,
363
+ "id": "5f5418d5-629f-4745-ad0f-3778020cc635",
364
+ "metadata": {},
365
+ "outputs": [
366
+ {
367
+ "name": "stdout",
368
+ "output_type": "stream",
369
+ "text": [
370
+ "total loss: 4579.1962890625\n",
371
+ "Individual losses: {'dyn_loss': '676.71', 'initial_condition': '3902.49', 'observations': '0.00'}\n"
372
+ ]
373
+ }
374
+ ],
375
+ "source": [
376
+ "# Testing the loss function\n",
377
+ "losses_and_grad = jax.value_and_grad(loss.evaluate, 0, has_aux=True)\n",
378
+ "losses, grads = losses_and_grad(\n",
379
+ " init_params,\n",
380
+ " train_data.get_batch()\n",
381
+ ")\n",
382
+ "l_tot, d = losses\n",
383
+ "print(f\"total loss: {l_tot}\")\n",
384
+ "print(f\"Individual losses: { {key: f'{val:.2f}' for key, val in d.items()} }\")"
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "markdown",
389
+ "id": "64835b79-0bce-4f06-bd57-5ee051796663",
390
+ "metadata": {},
391
+ "source": [
392
+ "## Learning the neural network parameters\n",
393
+ "The learning process here considers the equation parameters `eq_params` as known. We thus only update `nn_params`."
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "code",
398
+ "execution_count": 15,
399
+ "id": "4e2c75a4-e3de-4d10-9424-4ee4ae206da3",
400
+ "metadata": {},
401
+ "outputs": [],
402
+ "source": [
403
+ "params = init_params"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": 16,
409
+ "id": "8d0106ad-d1e4-4fa8-958d-c8ebd4572d76",
410
+ "metadata": {},
411
+ "outputs": [],
412
+ "source": [
413
+ "# Optimizer\n",
414
+ "import optax\n",
415
+ "tx = optax.adam(learning_rate=1e-3)"
416
+ ]
417
+ },
418
+ {
419
+ "cell_type": "code",
420
+ "execution_count": 17,
421
+ "id": "055a7e63-4d0e-4246-b792-2007a0deeaab",
422
+ "metadata": {},
423
+ "outputs": [],
424
+ "source": [
425
+ "n_iter = int(50000)"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "code",
430
+ "execution_count": 18,
431
+ "id": "9284d4ed",
432
+ "metadata": {},
433
+ "outputs": [],
434
+ "source": [
435
+ "key, subkey = random.split(key)"
436
+ ]
437
+ },
438
+ {
439
+ "cell_type": "code",
440
+ "execution_count": 19,
441
+ "id": "df0ab21d-bfc1-4e81-8708-df8b30d0173b",
442
+ "metadata": {},
443
+ "outputs": [
444
+ {
445
+ "name": "stdout",
446
+ "output_type": "stream",
447
+ "text": [
448
+ "Iteration 0: loss value = 4555.5380859375\n"
449
+ ]
450
+ },
451
+ {
452
+ "data": {
453
+ "application/vnd.jupyter.widget-view+json": {
454
+ "model_id": "0885ff2c80a84f4693344948349fcbe0",
455
+ "version_major": 2,
456
+ "version_minor": 0
457
+ },
458
+ "text/plain": [
459
+ " 0%| | 0/50000 [00:00<?, ?it/s]"
460
+ ]
461
+ },
462
+ "metadata": {},
463
+ "output_type": "display_data"
464
+ },
465
+ {
466
+ "name": "stdout",
467
+ "output_type": "stream",
468
+ "text": [
469
+ "Iteration 1000: loss value = 318.8741149902344\n",
470
+ "Iteration 2000: loss value = 224.96556091308594\n",
471
+ "Iteration 3000: loss value = 150.76829528808594\n",
472
+ "Iteration 4000: loss value = 91.44490051269531\n",
473
+ "Iteration 5000: loss value = 56.90504455566406\n",
474
+ "Iteration 6000: loss value = 35.36555099487305\n",
475
+ "Iteration 7000: loss value = 23.368562698364258\n",
476
+ "Iteration 8000: loss value = 15.524669647216797\n",
477
+ "Iteration 9000: loss value = 10.52753734588623\n",
478
+ "Iteration 10000: loss value = 7.434139251708984\n",
479
+ "Iteration 11000: loss value = 5.210718154907227\n",
480
+ "Iteration 12000: loss value = 3.7904136180877686\n",
481
+ "Iteration 13000: loss value = 2.4596922397613525\n",
482
+ "Iteration 14000: loss value = 1.9706025123596191\n",
483
+ "Iteration 15000: loss value = 1.482211709022522\n",
484
+ "Iteration 16000: loss value = 0.9678654670715332\n",
485
+ "Iteration 17000: loss value = 1.0758025646209717\n",
486
+ "Iteration 18000: loss value = 0.4116891026496887\n",
487
+ "Iteration 19000: loss value = 0.32329803705215454\n",
488
+ "Iteration 20000: loss value = 0.5885492563247681\n",
489
+ "Iteration 21000: loss value = 0.31351974606513977\n",
490
+ "Iteration 22000: loss value = 0.178497776389122\n",
491
+ "Iteration 23000: loss value = 0.10003294795751572\n",
492
+ "Iteration 24000: loss value = 0.21409007906913757\n",
493
+ "Iteration 25000: loss value = 0.10440698266029358\n",
494
+ "Iteration 26000: loss value = 0.14562144875526428\n",
495
+ "Iteration 27000: loss value = 0.09297636151313782\n",
496
+ "Iteration 28000: loss value = 0.04384230822324753\n",
497
+ "Iteration 29000: loss value = 0.06380896270275116\n",
498
+ "Iteration 30000: loss value = 0.047628261148929596\n",
499
+ "Iteration 31000: loss value = 0.08742117881774902\n",
500
+ "Iteration 32000: loss value = 0.2274346649646759\n",
501
+ "Iteration 33000: loss value = 0.039586469531059265\n",
502
+ "Iteration 34000: loss value = 0.041193120181560516\n",
503
+ "Iteration 35000: loss value = 0.0384555421769619\n",
504
+ "Iteration 36000: loss value = 0.02351181022822857\n",
505
+ "Iteration 37000: loss value = 0.03427871689200401\n",
506
+ "Iteration 38000: loss value = 0.04472379386425018\n",
507
+ "Iteration 39000: loss value = 0.018797673285007477\n",
508
+ "Iteration 40000: loss value = 0.04032493382692337\n",
509
+ "Iteration 41000: loss value = 0.01733144372701645\n",
510
+ "Iteration 42000: loss value = 0.10980413109064102\n",
511
+ "Iteration 43000: loss value = 0.19737961888313293\n",
512
+ "Iteration 44000: loss value = 0.024736665189266205\n",
513
+ "Iteration 45000: loss value = 0.0491604208946228\n",
514
+ "Iteration 46000: loss value = 0.021411903202533722\n",
515
+ "Iteration 47000: loss value = 0.02852526120841503\n",
516
+ "Iteration 48000: loss value = 0.6865979433059692\n",
517
+ "Iteration 49000: loss value = 0.02246144972741604\n",
518
+ "Iteration 50000: loss value = 0.09908341616392136\n"
519
+ ]
520
+ }
521
+ ],
522
+ "source": [
523
+ "params, total_loss_list, loss_by_term_dict, data, loss, _, _ = jinns.solve(\n",
524
+ " init_params=params,\n",
525
+ " data=train_data,\n",
526
+ " optimizer=tx,\n",
527
+ " loss=loss,\n",
528
+ " n_iter=n_iter\n",
529
+ ")"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "code",
534
+ "execution_count": 20,
535
+ "id": "b65dfc33",
536
+ "metadata": {},
537
+ "outputs": [
538
+ {
539
+ "data": {
540
+ "text/plain": [
541
+ "Array(0.09908342, dtype=float32)"
542
+ ]
543
+ },
544
+ "execution_count": 20,
545
+ "metadata": {},
546
+ "output_type": "execute_result"
547
+ }
548
+ ],
549
+ "source": [
550
+ "total_loss_list[-1]"
551
+ ]
552
+ },
553
+ {
554
+ "cell_type": "markdown",
555
+ "id": "1d0a1757",
556
+ "metadata": {},
557
+ "source": [
558
+ "## Results"
559
+ ]
560
+ },
561
+ {
562
+ "cell_type": "markdown",
563
+ "id": "64c794ff",
564
+ "metadata": {},
565
+ "source": [
566
+ "Plot the loss values"
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "code",
571
+ "execution_count": 21,
572
+ "id": "2cd778b4-d9d9-4f69-ad02-2a3f7eacf59d",
573
+ "metadata": {
574
+ "scrolled": true
575
+ },
576
+ "outputs": [
577
+ {
578
+ "data": {
579
+ "image/png": "",
580
+ "text/plain": [
581
+ "<Figure size 640x480 with 1 Axes>"
582
+ ]
583
+ },
584
+ "metadata": {},
585
+ "output_type": "display_data"
586
+ }
587
+ ],
588
+ "source": [
589
+ "for loss_name, loss_values in loss_by_term_dict.items():\n",
590
+ " plt.plot(jnp.log10(loss_values), label=loss_name)\n",
591
+ "plt.plot(jnp.log10(total_loss_list), label=\"total loss\")\n",
592
+ "plt.legend()\n",
593
+ "plt.show();"
594
+ ]
595
+ },
596
+ {
597
+ "cell_type": "markdown",
598
+ "id": "a6247171",
599
+ "metadata": {},
600
+ "source": [
601
+ "Plot the ODE solutions learned by the PINN"
602
+ ]
603
+ },
604
+ {
605
+ "cell_type": "code",
606
+ "execution_count": 22,
607
+ "id": "6d473743-c9a8-4406-b18c-256496cfde59",
608
+ "metadata": {},
609
+ "outputs": [
610
+ {
611
+ "data": {
612
+ "text/plain": [
613
+ "<matplotlib.legend.Legend at 0x7f16280d1050>"
614
+ ]
615
+ },
616
+ "execution_count": 22,
617
+ "metadata": {},
618
+ "output_type": "execute_result"
619
+ },
620
+ {
621
+ "data": {
622
+ "image/png": "",
623
+ "text/plain": [
624
+ "<Figure size 640x480 with 1 Axes>"
625
+ ]
626
+ },
627
+ "metadata": {},
628
+ "output_type": "display_data"
629
+ }
630
+ ],
631
+ "source": [
632
+ "u_est_fp = vmap(lambda t:u(t, params[\"nn_params\"][\"0\"]), (0), 0)\n",
633
+ "v_est_fp = vmap(lambda t:u(t, params[\"nn_params\"][\"1\"]), (0), 0)\n",
634
+ "w_est_fp = vmap(lambda t:u(t, params[\"nn_params\"][\"2\"]), (0), 0)\n",
635
+ "\n",
636
+ "\n",
637
+ "key, subkey = random.split(key, 2)\n",
638
+ "val_data = jinns.data.DataGeneratorODE(subkey, n, tmin, tmax, batch_size, method)\n",
639
+ "\n",
640
+ "plt.plot(val_data.times.sort(axis=0) * Tmax, u_est_fp(val_data.times.sort(axis=0)), label=\"N1\")\n",
641
+ "plt.plot(val_data.times.sort(axis=0) * Tmax, v_est_fp(val_data.times.sort(axis=0)), label=\"N2\")\n",
642
+ "plt.plot(val_data.times.sort(axis=0) * Tmax, w_est_fp(val_data.times.sort(axis=0)), label=\"N3\")\n",
643
+ "\n",
644
+ "plt.legend()"
645
+ ]
646
+ },
647
+ {
648
+ "cell_type": "markdown",
649
+ "id": "aed49c41",
650
+ "metadata": {},
651
+ "source": [
652
+ "## Compare with the scipy solver\n",
653
+ "Code from Lorenzo Sala"
654
+ ]
655
+ },
656
+ {
657
+ "cell_type": "code",
658
+ "execution_count": 23,
659
+ "id": "484380a5",
660
+ "metadata": {},
661
+ "outputs": [
662
+ {
663
+ "data": {
664
+ "image/png": "",
665
+ "text/plain": [
666
+ "<Figure size 640x480 with 1 Axes>"
667
+ ]
668
+ },
669
+ "metadata": {},
670
+ "output_type": "display_data"
671
+ }
672
+ ],
673
+ "source": [
674
+ "import numpy as np\n",
675
+ "from scipy.integrate import odeint\n",
676
+ "\n",
677
+ "def lotka_volterra_log(y_log, t, params):\n",
678
+ " \"\"\"\n",
679
+ " Generalized Lotka-Volterra model for N bacterial species, with logarithmic transformation for stability.\n",
680
+ " \n",
681
+ " Parameters:\n",
682
+ " y_log (array): Array of log-transformed bacterial populations.\n",
683
+ " t (float): Time.\n",
684
+ " params (tuple): Tuple of model parameters.\n",
685
+ " \n",
686
+ " Returns:\n",
687
+ " dydt (array): Array of derivative of log-transformed bacterial populations with respect to time.\n",
688
+ " \"\"\"\n",
689
+ " alpha, beta, gamma, _ = params\n",
690
+ " N = len(y_log)\n",
691
+ " y = np.exp(y_log)\n",
692
+ " dydt = np.zeros(N)\n",
693
+ " \n",
694
+ " for i in range(N):\n",
695
+ " dydt[i] = y[i] * (alpha[i] - beta[i] * np.sum(y) - np.sum([gamma[j][i] * y[j] for j in range(N)]))\n",
696
+ " \n",
697
+ " dydt_log = dydt / y\n",
698
+ " \n",
699
+ " return dydt_log\n",
700
+ "\n",
701
+ "# Define name bacteria\n",
702
+ "names = ['N1', 'N2', 'N3']\n",
703
+ "N = len(names)\n",
704
+ "\n",
705
+ "# Define model parameters\n",
706
+ "death_rates = None\n",
707
+ "params = (growth_rates, carrying_capacities, interactions, death_rates)\n",
708
+ "\n",
709
+ "# Define initial bacterial populations\n",
710
+ "y0 = [10, 7, 4] #[0.26, 0.37, 0.57] #\n",
711
+ "\n",
712
+ "# Define time points\n",
713
+ "Tmax = 30\n",
714
+ "t = np.linspace(0, Tmax, 1000)\n",
715
+ "\n",
716
+ "############################\n",
717
+ "\n",
718
+ "y0_log = np.log(y0)\n",
719
+ "y_log = odeint(lotka_volterra_log, y0_log, t, args=(params,))\n",
720
+ "y = np.exp(y_log)\n",
721
+ "\n",
722
+ "for i in range(N): \n",
723
+ " plt.plot(t, y[:,i], label=names[i])"
724
+ ]
725
+ },
726
+ {
727
+ "cell_type": "code",
728
+ "execution_count": null,
729
+ "id": "e962c046",
730
+ "metadata": {},
731
+ "outputs": [],
732
+ "source": []
733
+ }
734
+ ],
735
+ "metadata": {
736
+ "kernelspec": {
737
+ "display_name": "Python 3 (ipykernel)",
738
+ "language": "python",
739
+ "name": "python3"
740
+ },
741
+ "language_info": {
742
+ "codemirror_mode": {
743
+ "name": "ipython",
744
+ "version": 3
745
+ },
746
+ "file_extension": ".py",
747
+ "mimetype": "text/x-python",
748
+ "name": "python",
749
+ "nbconvert_exporter": "python",
750
+ "pygments_lexer": "ipython3",
751
+ "version": "3.11.2"
752
+ },
753
+ "vscode": {
754
+ "interpreter": {
755
+ "hash": "991718e94fb5d91fa62c7598521d2199c208ff1ff700f1ac060f334be0bee194"
756
+ }
757
+ }
758
+ },
759
+ "nbformat": 4,
760
+ "nbformat_minor": 5
761
+ }