jinns 0.8.4__tar.gz → 0.8.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jinns-0.8.4 → jinns-0.8.6}/PKG-INFO +1 -1
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/index.rst +8 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/solver/_solve.py +3 -1
- {jinns-0.8.4 → jinns-0.8.6}/jinns/utils/_hyperpinn.py +1 -13
- {jinns-0.8.4 → jinns-0.8.6}/jinns/utils/_pinn.py +1 -13
- {jinns-0.8.4 → jinns-0.8.6}/jinns/utils/_spinn.py +1 -7
- {jinns-0.8.4 → jinns-0.8.6}/jinns.egg-info/PKG-INFO +1 -1
- {jinns-0.8.4 → jinns-0.8.6}/jinns.egg-info/SOURCES.txt +4 -1
- {jinns-0.8.4 → jinns-0.8.6}/tests/runtests.sh +4 -0
- jinns-0.8.6/tests/utils_tests/test_hyperpinns.py +138 -0
- jinns-0.8.6/tests/utils_tests/test_pinn.py +131 -0
- jinns-0.8.6/tests/utils_tests/test_spinn.py +110 -0
- {jinns-0.8.4 → jinns-0.8.6}/.gitignore +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/.gitlab-ci.yml +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/.pre-commit-config.yaml +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/LICENSE +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/ODE/1D_Generalized_Lotka_Volterra.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/ODE/1D_Generalized_Lotka_Volterra_seq2seq.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/ODE/sbinn_data/glucose.dat +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/ODE/sbinn_data/meal.dat +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/1D_non_stationary_Burger.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/1D_non_stationary_Fisher_KPP_Bounded_Domain.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/2D_Heat_inverse_problem.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_SoftConstraints.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/2D_Poisson_inverse_problem.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/2D_non_stationary_OU.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/2d_nonstatio_ou_standardsampling.png +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/OU_1D_nonstatio_solution_grid.npy +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/Reaction_Diffusion_2D_heterogenous_model.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/Reaction_Diffusion_2D_homogeneous_metamodel_hyperpinn_diffrax.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/burger_solution_grid.npy +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/imperfect_modeling_sobolev_reg.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/Notebooks/Tutorials/load_save_model.ipynb +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/README.md +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/Makefile +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/boundary_conditions.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/conf.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/data.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/dynamic_loss.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/experimental.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/fokker_planck.qmd +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/loss.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/loss_ode.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/loss_pde.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/losses.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/math_pinn.qmd +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/operators.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/param_estim_pinn.qmd +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/rar.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/seq2seq.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/solve.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/solver.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/doc/source/utils.rst +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/__init__.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/data/_DataGenerators.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/data/__init__.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/data/_display.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/experimental/__init__.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/experimental/_diffrax_solver.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/loss/_DynamicLoss.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/loss/_DynamicLossAbstract.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/loss/_LossODE.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/loss/_LossPDE.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/loss/_Losses.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/loss/__init__.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/loss/_boundary_conditions.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/loss/_operators.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/solver/__init__.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/solver/_rar.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/solver/_seq2seq.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/utils/__init__.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/utils/_optim.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/utils/_save_load.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/utils/_utils.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns/utils/_utils_uspinn.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns.egg-info/dependency_links.txt +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns.egg-info/requires.txt +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/jinns.egg-info/top_level.txt +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/pyproject.toml +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/setup.cfg +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/conftest.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/dataGenerator_tests/test_CubicMeshPDEStatio.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/dataGenerator_tests/test_DataGeneratorODE.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/dataGenerator_tests/test_DataGeneratorParameter.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/save_load_tests/test_saving_loading_hyperpinn.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/save_load_tests/test_saving_loading_pinn.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/save_load_tests/test_saving_loading_spinn.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/sharding_tests/test_Burger_x32_multiple_shardings.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/sharding_tests/test_imperfect_sobolev_x32_multiple_shardings.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests/test_Burger_x32.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests/test_Burger_x64.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests/test_Fisher_x32.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests/test_Fisher_x64.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests/test_GLV_x32.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests/test_GLV_x64.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests/test_NSPipeFlow_x32.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests/test_NSPipeFlow_x64.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests/test_OU2D_x32.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests/test_imperfect_sobolev_x32.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests_spinn/test_Burger_x32.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests_spinn/test_Fisher_x32.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests_spinn/test_NSPipeFlow_x32_spinn.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests_spinn/test_OU2D_x32.py +0 -0
- {jinns-0.8.4 → jinns-0.8.6}/tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 0.8.
|
|
3
|
+
Version: 0.8.6
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -8,6 +8,14 @@ Welcome to jinn's documentation!
|
|
|
8
8
|
|
|
9
9
|
Changelog:
|
|
10
10
|
|
|
11
|
+
* v0.8.6:
|
|
12
|
+
|
|
13
|
+
- Merge `[!37] <https://gitlab.com/mia_jinns/jinns/-/merge_requests/37>`_
|
|
14
|
+
|
|
15
|
+
* v0.8.5:
|
|
16
|
+
|
|
17
|
+
- Merge `[!36] <https://gitlab.com/mia_jinns/jinns/-/merge_requests/36>`_
|
|
18
|
+
|
|
11
19
|
* v0.8.4:
|
|
12
20
|
|
|
13
21
|
- Fix a bug: wrong argument in the wrapper function for heterogeneous parameter evaluation of a PDEStatio
|
|
@@ -218,6 +218,7 @@ def solve(
|
|
|
218
218
|
batch, data, param_data, obs_data = get_batch(data, param_data, obs_data)
|
|
219
219
|
|
|
220
220
|
(
|
|
221
|
+
loss,
|
|
221
222
|
loss_val,
|
|
222
223
|
loss_terms,
|
|
223
224
|
params,
|
|
@@ -311,7 +312,7 @@ def solve(
|
|
|
311
312
|
)
|
|
312
313
|
|
|
313
314
|
|
|
314
|
-
@partial(jit, static_argnames=["
|
|
315
|
+
@partial(jit, static_argnames=["optimizer"])
|
|
315
316
|
def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params):
|
|
316
317
|
"""
|
|
317
318
|
loss and optimizer cannot be jit-ted.
|
|
@@ -330,6 +331,7 @@ def gradient_step(loss, optimizer, batch, params, opt_state, last_non_nan_params
|
|
|
330
331
|
)
|
|
331
332
|
|
|
332
333
|
return (
|
|
334
|
+
loss,
|
|
333
335
|
loss_val,
|
|
334
336
|
loss_terms,
|
|
335
337
|
params,
|
|
@@ -216,13 +216,7 @@ def create_HYPERPINN(
|
|
|
216
216
|
|
|
217
217
|
Returns
|
|
218
218
|
-------
|
|
219
|
-
|
|
220
|
-
A function which (re-)initializes the PINN parameters with the provided
|
|
221
|
-
jax random key
|
|
222
|
-
apply_fn
|
|
223
|
-
A function to apply the neural network on given inputs for given
|
|
224
|
-
parameters. A typical call will be of the form `u(t, params)` for
|
|
225
|
-
ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
|
|
219
|
+
`u`, a :class:`.HyperPINN` object which inherits from `eqx.Module` (hence callable).
|
|
226
220
|
|
|
227
221
|
Raises
|
|
228
222
|
------
|
|
@@ -289,7 +283,6 @@ def create_HYPERPINN(
|
|
|
289
283
|
|
|
290
284
|
if shared_pinn_outputs is not None:
|
|
291
285
|
hyperpinns = []
|
|
292
|
-
static = None
|
|
293
286
|
for output_slice in shared_pinn_outputs:
|
|
294
287
|
hyperpinn = HYPERPINN(
|
|
295
288
|
mlp,
|
|
@@ -302,11 +295,6 @@ def create_HYPERPINN(
|
|
|
302
295
|
hypernet_input_size,
|
|
303
296
|
output_slice,
|
|
304
297
|
)
|
|
305
|
-
# all the pinns are in fact the same so we share the same static
|
|
306
|
-
if static is None:
|
|
307
|
-
static = hyperpinn.static
|
|
308
|
-
else:
|
|
309
|
-
hyperpinn.static = static
|
|
310
298
|
hyperpinns.append(hyperpinn)
|
|
311
299
|
return hyperpinns
|
|
312
300
|
hyperpinn = HYPERPINN(
|
|
@@ -200,13 +200,7 @@ def create_PINN(
|
|
|
200
200
|
|
|
201
201
|
Returns
|
|
202
202
|
-------
|
|
203
|
-
|
|
204
|
-
A function which (re-)initializes the PINN parameters with the provided
|
|
205
|
-
jax random key
|
|
206
|
-
apply_fn
|
|
207
|
-
A function to apply the neural network on given inputs for given
|
|
208
|
-
parameters. A typical call will be of the form `u(t, params)` for
|
|
209
|
-
ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
|
|
203
|
+
`u`, a :class:`.PINN` object which inherits from `eqx.Module` (hence callable). This comes with a bound method :func:`u.init_params() <PINN.init_params>`. When `shared_pinn_ouput` is not None, a list of :class:`.PINN` with the same structure is returned, only differing by there final slicing of the network output.
|
|
210
204
|
|
|
211
205
|
Raises
|
|
212
206
|
------
|
|
@@ -253,7 +247,6 @@ def create_PINN(
|
|
|
253
247
|
|
|
254
248
|
if shared_pinn_outputs is not None:
|
|
255
249
|
pinns = []
|
|
256
|
-
static = None
|
|
257
250
|
for output_slice in shared_pinn_outputs:
|
|
258
251
|
pinn = PINN(
|
|
259
252
|
mlp,
|
|
@@ -263,11 +256,6 @@ def create_PINN(
|
|
|
263
256
|
output_transform,
|
|
264
257
|
output_slice,
|
|
265
258
|
)
|
|
266
|
-
# all the pinns are in fact the same so we share the same static
|
|
267
|
-
if static is None:
|
|
268
|
-
static = pinn.static
|
|
269
|
-
else:
|
|
270
|
-
pinn.static = static
|
|
271
259
|
pinns.append(pinn)
|
|
272
260
|
return pinns
|
|
273
261
|
pinn = PINN(mlp, slice_solution, eq_type, input_transform, output_transform, None)
|
|
@@ -194,13 +194,7 @@ def create_SPINN(key, d, r, eqx_list, eq_type, m=1):
|
|
|
194
194
|
|
|
195
195
|
Returns
|
|
196
196
|
-------
|
|
197
|
-
|
|
198
|
-
A function which (re-)initializes the SPINN parameters with the provided
|
|
199
|
-
jax random key
|
|
200
|
-
apply_fn
|
|
201
|
-
A function to apply the neural network on given inputs for given
|
|
202
|
-
parameters. A typical call will be of the form `u(t, params)` for
|
|
203
|
-
ODE or `u(t, x, params)` for nD PDEs (`x` being multidimensional)
|
|
197
|
+
`u`, a :class:`.SPINN` object which inherits from `eqx.Module` (hence callable).
|
|
204
198
|
|
|
205
199
|
Raises
|
|
206
200
|
------
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 0.8.
|
|
3
|
+
Version: 0.8.6
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -102,4 +102,7 @@ tests/solver_tests_spinn/test_Burger_x32.py
|
|
|
102
102
|
tests/solver_tests_spinn/test_Fisher_x32.py
|
|
103
103
|
tests/solver_tests_spinn/test_NSPipeFlow_x32_spinn.py
|
|
104
104
|
tests/solver_tests_spinn/test_OU2D_x32.py
|
|
105
|
-
tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64.py
|
|
105
|
+
tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64.py
|
|
106
|
+
tests/utils_tests/test_hyperpinns.py
|
|
107
|
+
tests/utils_tests/test_pinn.py
|
|
108
|
+
tests/utils_tests/test_spinn.py
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Test script for custom PINN eqx.Module
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
import jax
|
|
7
|
+
import jax.random as random
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
import equinox as eqx
|
|
10
|
+
|
|
11
|
+
import jinns
|
|
12
|
+
from jinns.utils import create_PINN
|
|
13
|
+
import jinns.utils
|
|
14
|
+
|
|
15
|
+
key = random.PRNGKey(2)
|
|
16
|
+
key, subkey = random.split(key)
|
|
17
|
+
|
|
18
|
+
d = 5
|
|
19
|
+
n_param = 42
|
|
20
|
+
hyperparams = [f"param {i}" for i in range(n_param)]
|
|
21
|
+
|
|
22
|
+
EQX_LIST = [
|
|
23
|
+
[jax.nn.swish],
|
|
24
|
+
[eqx.nn.Linear, 16, 16],
|
|
25
|
+
[jax.nn.swish],
|
|
26
|
+
[eqx.nn.Linear, 16, 16],
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
eqx_list_hyper = [
|
|
30
|
+
[eqx.nn.Linear, n_param, 32], # input is of size 42
|
|
31
|
+
[jax.nn.tanh],
|
|
32
|
+
[eqx.nn.Linear, 32, 16],
|
|
33
|
+
[jax.nn.tanh],
|
|
34
|
+
[
|
|
35
|
+
eqx.nn.Linear,
|
|
36
|
+
16,
|
|
37
|
+
1000,
|
|
38
|
+
], # 1000 is a random guess, it will automatically be filled with the correct value
|
|
39
|
+
]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@pytest.fixture
|
|
43
|
+
def create_pinn_ode():
|
|
44
|
+
eqx_list = [[eqx.nn.Linear, 1, 16]] + EQX_LIST
|
|
45
|
+
u_ode = jinns.utils.create_HYPERPINN(
|
|
46
|
+
subkey, eqx_list, "ODE", hyperparams, n_param, 0, eqx_list_hyper
|
|
47
|
+
)
|
|
48
|
+
return u_ode
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@pytest.fixture
|
|
52
|
+
def create_pinn_statio():
|
|
53
|
+
eqx_list = [[eqx.nn.Linear, d, 16]] + EQX_LIST
|
|
54
|
+
u_statio = jinns.utils.create_HYPERPINN(
|
|
55
|
+
subkey, eqx_list, "statio_PDE", hyperparams, n_param, d, eqx_list_hyper
|
|
56
|
+
)
|
|
57
|
+
return u_statio
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@pytest.fixture
|
|
61
|
+
def create_pinn_nonstatio():
|
|
62
|
+
eqx_list = [[eqx.nn.Linear, d + 1, 16]] + EQX_LIST
|
|
63
|
+
u_nonstatio = jinns.utils.create_HYPERPINN(
|
|
64
|
+
subkey, eqx_list, "nonstatio_PDE", hyperparams, n_param, d + 1, eqx_list_hyper
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
return u_nonstatio
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@pytest.fixture
|
|
71
|
+
def create_pinn_nonstatio_shared_output():
|
|
72
|
+
|
|
73
|
+
# specific argument since we want to have u1 and u2 as separate nns
|
|
74
|
+
shared_pinn_output = (
|
|
75
|
+
jnp.s_[:2],
|
|
76
|
+
jnp.s_[2],
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
eqx_list = [[eqx.nn.Linear, d, 16]] + EQX_LIST
|
|
80
|
+
u1, u2 = jinns.utils.create_HYPERPINN(
|
|
81
|
+
subkey,
|
|
82
|
+
eqx_list,
|
|
83
|
+
"nonstatio_PDE",
|
|
84
|
+
hyperparams,
|
|
85
|
+
n_param,
|
|
86
|
+
d,
|
|
87
|
+
eqx_list_hyper,
|
|
88
|
+
shared_pinn_outputs=shared_pinn_output,
|
|
89
|
+
)
|
|
90
|
+
return u1, u2, shared_pinn_output
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def test_ode_pinn_struct(create_pinn_ode):
|
|
94
|
+
|
|
95
|
+
u_ode = create_pinn_ode
|
|
96
|
+
assert isinstance(u_ode, jinns.utils._hyperpinn.HYPERPINN)
|
|
97
|
+
assert u_ode.eq_type == "ODE"
|
|
98
|
+
assert u_ode.output_slice is None
|
|
99
|
+
assert isinstance(u_ode.slice_solution, slice)
|
|
100
|
+
_ = u_ode.init_params()
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def test_statio_pinn_struct(create_pinn_statio):
|
|
104
|
+
|
|
105
|
+
u_statio = create_pinn_statio
|
|
106
|
+
assert u_statio.eq_type == "statio_PDE"
|
|
107
|
+
assert isinstance(u_statio, jinns.utils._hyperpinn.HYPERPINN)
|
|
108
|
+
assert u_statio.output_slice is None
|
|
109
|
+
assert isinstance(u_statio.slice_solution, slice)
|
|
110
|
+
_ = u_statio.init_params()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def test_nonstatio_pinn_struct(create_pinn_nonstatio):
|
|
114
|
+
|
|
115
|
+
u_nonstatio = create_pinn_nonstatio
|
|
116
|
+
assert u_nonstatio.eq_type == "nonstatio_PDE"
|
|
117
|
+
assert isinstance(u_nonstatio, jinns.utils._hyperpinn.HYPERPINN)
|
|
118
|
+
assert u_nonstatio.output_slice is None
|
|
119
|
+
assert isinstance(u_nonstatio.slice_solution, slice)
|
|
120
|
+
_ = u_nonstatio.init_params()
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def test_nonstatio_pinn_shared_output(create_pinn_nonstatio_shared_output):
|
|
124
|
+
|
|
125
|
+
u1, u2, shared_pinn_ouput = create_pinn_nonstatio_shared_output
|
|
126
|
+
assert u1.eq_type == "nonstatio_PDE"
|
|
127
|
+
assert u1.output_slice == shared_pinn_ouput[0]
|
|
128
|
+
assert isinstance(u1.slice_solution, slice)
|
|
129
|
+
|
|
130
|
+
assert u2.eq_type == "nonstatio_PDE"
|
|
131
|
+
assert u2.output_slice == shared_pinn_ouput[1]
|
|
132
|
+
assert isinstance(u2.slice_solution, slice)
|
|
133
|
+
|
|
134
|
+
param1 = u1.init_params()
|
|
135
|
+
param2 = u2.init_params()
|
|
136
|
+
|
|
137
|
+
# the init parameters for the 2 nns should be the same PyTree
|
|
138
|
+
assert eqx.tree_equal(param1, param2)
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Test script for custom PINN eqx.Module
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
import jax
|
|
7
|
+
import jax.random as random
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
import equinox as eqx
|
|
10
|
+
|
|
11
|
+
import jinns
|
|
12
|
+
from jinns.utils import create_PINN
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
d = 5
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@pytest.fixture
|
|
19
|
+
def create_pinn_ode():
|
|
20
|
+
key = random.PRNGKey(2)
|
|
21
|
+
eqx_list = [
|
|
22
|
+
[eqx.nn.Linear, 1, 128],
|
|
23
|
+
[jax.nn.tanh],
|
|
24
|
+
[eqx.nn.Linear, 128, 1],
|
|
25
|
+
]
|
|
26
|
+
key, subkey = random.split(key)
|
|
27
|
+
u_statio = create_PINN(subkey, eqx_list, "ODE")
|
|
28
|
+
|
|
29
|
+
return u_statio
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@pytest.fixture
|
|
33
|
+
def create_pinn_statio():
|
|
34
|
+
key = random.PRNGKey(2)
|
|
35
|
+
eqx_list = [
|
|
36
|
+
[eqx.nn.Linear, d, 128],
|
|
37
|
+
[jax.nn.tanh],
|
|
38
|
+
[eqx.nn.Linear, 128, 1],
|
|
39
|
+
]
|
|
40
|
+
key, subkey = random.split(key)
|
|
41
|
+
u_statio = create_PINN(subkey, eqx_list, "statio_PDE", d)
|
|
42
|
+
|
|
43
|
+
return u_statio
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@pytest.fixture
|
|
47
|
+
def create_pinn_nonstatio():
|
|
48
|
+
key = random.PRNGKey(2)
|
|
49
|
+
eqx_list = [
|
|
50
|
+
[eqx.nn.Linear, d, 128],
|
|
51
|
+
[jax.nn.tanh],
|
|
52
|
+
[eqx.nn.Linear, 128, 1],
|
|
53
|
+
]
|
|
54
|
+
key, subkey = random.split(key)
|
|
55
|
+
u_nonstatio = create_PINN(subkey, eqx_list, "nonstatio_PDE", d)
|
|
56
|
+
|
|
57
|
+
return u_nonstatio
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@pytest.fixture
|
|
61
|
+
def create_pinn_nonstatio_shared_output():
|
|
62
|
+
key = random.PRNGKey(2)
|
|
63
|
+
eqx_list = [
|
|
64
|
+
[eqx.nn.Linear, d, 128],
|
|
65
|
+
[jax.nn.tanh],
|
|
66
|
+
[eqx.nn.Linear, 128, 1],
|
|
67
|
+
]
|
|
68
|
+
key, subkey = random.split(key)
|
|
69
|
+
# specific argument since we want to have u1 and u2 as separate nns
|
|
70
|
+
shared_pinn_output = (
|
|
71
|
+
jnp.s_[:2],
|
|
72
|
+
jnp.s_[2],
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
u1, u2 = jinns.utils.create_PINN(
|
|
76
|
+
subkey,
|
|
77
|
+
eqx_list,
|
|
78
|
+
"nonstatio_PDE",
|
|
79
|
+
d,
|
|
80
|
+
shared_pinn_outputs=shared_pinn_output,
|
|
81
|
+
)
|
|
82
|
+
return u1, u2, shared_pinn_output
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def test_ode_pinn_struct(create_pinn_ode):
|
|
86
|
+
|
|
87
|
+
u_ode = create_pinn_ode
|
|
88
|
+
assert u_ode.eq_type == "ODE"
|
|
89
|
+
assert isinstance(u_ode, jinns.utils._pinn.PINN)
|
|
90
|
+
assert u_ode.output_slice is None
|
|
91
|
+
assert isinstance(u_ode.slice_solution, slice)
|
|
92
|
+
_ = u_ode.init_params()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def test_statio_pinn_struct(create_pinn_statio):
|
|
96
|
+
|
|
97
|
+
u_statio = create_pinn_statio
|
|
98
|
+
assert u_statio.eq_type == "statio_PDE"
|
|
99
|
+
assert isinstance(u_statio, jinns.utils._pinn.PINN)
|
|
100
|
+
|
|
101
|
+
assert u_statio.output_slice is None
|
|
102
|
+
assert isinstance(u_statio.slice_solution, slice)
|
|
103
|
+
_ = u_statio.init_params()
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def test_nonstatio_pinn_struct(create_pinn_nonstatio):
|
|
107
|
+
|
|
108
|
+
u_nonstatio = create_pinn_nonstatio
|
|
109
|
+
assert u_nonstatio.eq_type == "nonstatio_PDE"
|
|
110
|
+
assert isinstance(u_nonstatio, jinns.utils._pinn.PINN)
|
|
111
|
+
assert u_nonstatio.output_slice is None
|
|
112
|
+
assert isinstance(u_nonstatio.slice_solution, slice)
|
|
113
|
+
_ = u_nonstatio.init_params()
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def test_nonstatio_pinn_shared_output(create_pinn_nonstatio_shared_output):
|
|
117
|
+
|
|
118
|
+
u1, u2, shared_pinn_ouput = create_pinn_nonstatio_shared_output
|
|
119
|
+
assert u1.eq_type == "nonstatio_PDE"
|
|
120
|
+
assert u1.output_slice == shared_pinn_ouput[0]
|
|
121
|
+
assert isinstance(u1.slice_solution, slice)
|
|
122
|
+
|
|
123
|
+
assert u2.eq_type == "nonstatio_PDE"
|
|
124
|
+
assert u2.output_slice == shared_pinn_ouput[1]
|
|
125
|
+
assert isinstance(u2.slice_solution, slice)
|
|
126
|
+
|
|
127
|
+
param1 = u1.init_params()
|
|
128
|
+
param2 = u2.init_params()
|
|
129
|
+
|
|
130
|
+
# the init parameters for the 2 nns should be the same PyTree
|
|
131
|
+
assert eqx.tree_equal(param1, param2)
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Test script for custom PINN eqx.Module
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
import jax
|
|
7
|
+
import jax.random as random
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
import equinox as eqx
|
|
10
|
+
|
|
11
|
+
import jinns
|
|
12
|
+
from jinns.utils import create_SPINN
|
|
13
|
+
import jinns.utils
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
d = 5
|
|
17
|
+
r = 100 # embedding dim
|
|
18
|
+
m = 1 # output dim
|
|
19
|
+
eqx_list = [
|
|
20
|
+
[eqx.nn.Linear, 1, 128],
|
|
21
|
+
[jax.nn.tanh],
|
|
22
|
+
[eqx.nn.Linear, 128, 128],
|
|
23
|
+
[jax.nn.tanh],
|
|
24
|
+
[eqx.nn.Linear, 128, 128],
|
|
25
|
+
[jax.nn.tanh],
|
|
26
|
+
[eqx.nn.Linear, 128, r * m],
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _assert_attr_equal(u):
|
|
31
|
+
assert u.m == m
|
|
32
|
+
assert u.r == r
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@pytest.fixture
|
|
36
|
+
def create_SPINN_ode():
|
|
37
|
+
key = random.PRNGKey(2)
|
|
38
|
+
key, subkey = random.split(key)
|
|
39
|
+
u_statio = create_SPINN(subkey, 1, r, eqx_list, "ODE", m)
|
|
40
|
+
|
|
41
|
+
return u_statio
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@pytest.fixture
|
|
45
|
+
def create_SPINN_statio():
|
|
46
|
+
key = random.PRNGKey(2)
|
|
47
|
+
key, subkey = random.split(key)
|
|
48
|
+
u_statio = create_SPINN(subkey, d, r, eqx_list, "statio_PDE", m)
|
|
49
|
+
|
|
50
|
+
return u_statio
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@pytest.fixture
|
|
54
|
+
def create_SPINN_nonstatio():
|
|
55
|
+
key = random.PRNGKey(2)
|
|
56
|
+
key, subkey = random.split(key)
|
|
57
|
+
u_nonstatio = create_SPINN(subkey, d, r, eqx_list, "nonstatio_PDE", m)
|
|
58
|
+
|
|
59
|
+
return u_nonstatio
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def test_ode_pinn_struct(create_SPINN_ode):
|
|
63
|
+
|
|
64
|
+
u_ode = create_SPINN_ode
|
|
65
|
+
assert u_ode.eq_type == "ODE"
|
|
66
|
+
assert isinstance(u_ode, jinns.utils._spinn.SPINN)
|
|
67
|
+
assert u_ode.d == 1
|
|
68
|
+
_assert_attr_equal(u_ode)
|
|
69
|
+
_ = u_ode.init_params()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def test_statio_pinn_struct(create_SPINN_statio):
|
|
73
|
+
|
|
74
|
+
u_statio = create_SPINN_statio
|
|
75
|
+
assert u_statio.eq_type == "statio_PDE"
|
|
76
|
+
assert isinstance(u_statio, jinns.utils._spinn.SPINN)
|
|
77
|
+
|
|
78
|
+
assert u_statio.d == d
|
|
79
|
+
_assert_attr_equal(u_statio)
|
|
80
|
+
_ = u_statio.init_params()
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def test_nonstatio_pinn_struct(create_SPINN_nonstatio):
|
|
84
|
+
|
|
85
|
+
u_nonstatio = create_SPINN_nonstatio
|
|
86
|
+
assert u_nonstatio.eq_type == "nonstatio_PDE"
|
|
87
|
+
assert isinstance(u_nonstatio, jinns.utils._spinn.SPINN)
|
|
88
|
+
assert u_nonstatio.d == d # in non-statio SPINN user should include `t` in `d`
|
|
89
|
+
_assert_attr_equal(u_nonstatio)
|
|
90
|
+
_ = u_nonstatio.init_params()
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def test_raising_error_init_SPINN():
|
|
94
|
+
|
|
95
|
+
# output_dim != r*m
|
|
96
|
+
with pytest.raises(ValueError) as e:
|
|
97
|
+
wrong_eqx_list = [
|
|
98
|
+
[eqx.nn.Linear, 1, 128],
|
|
99
|
+
[jax.nn.tanh],
|
|
100
|
+
[eqx.nn.Linear, 128, r * m + 1], # output_dim != r*m
|
|
101
|
+
]
|
|
102
|
+
_ = create_SPINN(random.PRNGKey(1), d, r, wrong_eqx_list, "nonstatio_PDE", m)
|
|
103
|
+
|
|
104
|
+
# d > 24
|
|
105
|
+
with pytest.raises(ValueError) as e:
|
|
106
|
+
_ = create_SPINN(random.PRNGKey(1), 24, r, wrong_eqx_list, "nonstatio_PDE", m)
|
|
107
|
+
|
|
108
|
+
# d > 24
|
|
109
|
+
with pytest.raises(ValueError) as e:
|
|
110
|
+
_ = create_SPINN(random.PRNGKey(1), 24, r, wrong_eqx_list, "nonstatio_PDE", m)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jinns-0.8.4 → jinns-0.8.6}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jinns-0.8.4 → jinns-0.8.6}/tests/sharding_tests/test_imperfect_sobolev_x32_multiple_shardings.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|