jinns 0.8.1.tar.gz → 0.8.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. {jinns-0.8.1 → jinns-0.8.3}/.pre-commit-config.yaml +2 -2
  2. {jinns-0.8.1 → jinns-0.8.3}/PKG-INFO +1 -1
  3. {jinns-0.8.1 → jinns-0.8.3}/doc/source/index.rst +8 -0
  4. {jinns-0.8.1 → jinns-0.8.3}/jinns/data/_DataGenerators.py +62 -26
  5. {jinns-0.8.1 → jinns-0.8.3}/jinns/utils/_hyperpinn.py +2 -2
  6. {jinns-0.8.1 → jinns-0.8.3}/jinns.egg-info/PKG-INFO +1 -1
  7. {jinns-0.8.1 → jinns-0.8.3}/jinns.egg-info/SOURCES.txt +1 -0
  8. jinns-0.8.3/tests/dataGenerator_tests/test_DataGeneratorParameter.py +80 -0
  9. {jinns-0.8.1 → jinns-0.8.3}/tests/save_load_tests/test_saving_loading_hyperpinn.py +23 -2
  10. {jinns-0.8.1 → jinns-0.8.3}/tests/save_load_tests/test_saving_loading_pinn.py +21 -2
  11. {jinns-0.8.1 → jinns-0.8.3}/tests/save_load_tests/test_saving_loading_spinn.py +20 -2
  12. {jinns-0.8.1 → jinns-0.8.3}/.gitignore +0 -0
  13. {jinns-0.8.1 → jinns-0.8.3}/.gitlab-ci.yml +0 -0
  14. {jinns-0.8.1 → jinns-0.8.3}/LICENSE +0 -0
  15. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/ODE/1D_Generalized_Lotka_Volterra.ipynb +0 -0
  16. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/ODE/1D_Generalized_Lotka_Volterra_seq2seq.ipynb +0 -0
  17. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/ODE/sbinn_data/glucose.dat +0 -0
  18. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/ODE/sbinn_data/meal.dat +0 -0
  19. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +0 -0
  20. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/1D_non_stationary_Burger.ipynb +0 -0
  21. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/1D_non_stationary_Fisher_KPP_Bounded_Domain.ipynb +0 -0
  22. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/2D_Heat_inverse_problem.ipynb +0 -0
  23. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow.ipynb +0 -0
  24. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb +0 -0
  25. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_SoftConstraints.ipynb +0 -0
  26. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/2D_Poisson_inverse_problem.ipynb +0 -0
  27. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/2D_non_stationary_OU.ipynb +0 -0
  28. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/2d_nonstatio_ou_standardsampling.png +0 -0
  29. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/OU_1D_nonstatio_solution_grid.npy +0 -0
  30. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/Reaction_Diffusion_2D_heterogenous_model.ipynb +0 -0
  31. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/Reaction_Diffusion_2D_homogeneous_metamodel_hyperpinn_diffrax.ipynb +0 -0
  32. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/burger_solution_grid.npy +0 -0
  33. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/PDE/imperfect_modeling_sobolev_reg.ipynb +0 -0
  34. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +0 -0
  35. {jinns-0.8.1 → jinns-0.8.3}/Notebooks/Tutorials/load_save_model.ipynb +0 -0
  36. {jinns-0.8.1 → jinns-0.8.3}/README.md +0 -0
  37. {jinns-0.8.1 → jinns-0.8.3}/doc/Makefile +0 -0
  38. {jinns-0.8.1 → jinns-0.8.3}/doc/source/boundary_conditions.rst +0 -0
  39. {jinns-0.8.1 → jinns-0.8.3}/doc/source/conf.py +0 -0
  40. {jinns-0.8.1 → jinns-0.8.3}/doc/source/data.rst +0 -0
  41. {jinns-0.8.1 → jinns-0.8.3}/doc/source/dynamic_loss.rst +0 -0
  42. {jinns-0.8.1 → jinns-0.8.3}/doc/source/experimental.rst +0 -0
  43. {jinns-0.8.1 → jinns-0.8.3}/doc/source/fokker_planck.qmd +0 -0
  44. {jinns-0.8.1 → jinns-0.8.3}/doc/source/loss.rst +0 -0
  45. {jinns-0.8.1 → jinns-0.8.3}/doc/source/loss_ode.rst +0 -0
  46. {jinns-0.8.1 → jinns-0.8.3}/doc/source/loss_pde.rst +0 -0
  47. {jinns-0.8.1 → jinns-0.8.3}/doc/source/losses.rst +0 -0
  48. {jinns-0.8.1 → jinns-0.8.3}/doc/source/math_pinn.qmd +0 -0
  49. {jinns-0.8.1 → jinns-0.8.3}/doc/source/operators.rst +0 -0
  50. {jinns-0.8.1 → jinns-0.8.3}/doc/source/param_estim_pinn.qmd +0 -0
  51. {jinns-0.8.1 → jinns-0.8.3}/doc/source/rar.rst +0 -0
  52. {jinns-0.8.1 → jinns-0.8.3}/doc/source/seq2seq.rst +0 -0
  53. {jinns-0.8.1 → jinns-0.8.3}/doc/source/solve.rst +0 -0
  54. {jinns-0.8.1 → jinns-0.8.3}/doc/source/solver.rst +0 -0
  55. {jinns-0.8.1 → jinns-0.8.3}/doc/source/utils.rst +0 -0
  56. {jinns-0.8.1 → jinns-0.8.3}/jinns/__init__.py +0 -0
  57. {jinns-0.8.1 → jinns-0.8.3}/jinns/data/__init__.py +0 -0
  58. {jinns-0.8.1 → jinns-0.8.3}/jinns/data/_display.py +0 -0
  59. {jinns-0.8.1 → jinns-0.8.3}/jinns/experimental/__init__.py +0 -0
  60. {jinns-0.8.1 → jinns-0.8.3}/jinns/experimental/_diffrax_solver.py +0 -0
  61. {jinns-0.8.1 → jinns-0.8.3}/jinns/loss/_DynamicLoss.py +0 -0
  62. {jinns-0.8.1 → jinns-0.8.3}/jinns/loss/_DynamicLossAbstract.py +0 -0
  63. {jinns-0.8.1 → jinns-0.8.3}/jinns/loss/_LossODE.py +0 -0
  64. {jinns-0.8.1 → jinns-0.8.3}/jinns/loss/_LossPDE.py +0 -0
  65. {jinns-0.8.1 → jinns-0.8.3}/jinns/loss/_Losses.py +0 -0
  66. {jinns-0.8.1 → jinns-0.8.3}/jinns/loss/__init__.py +0 -0
  67. {jinns-0.8.1 → jinns-0.8.3}/jinns/loss/_boundary_conditions.py +0 -0
  68. {jinns-0.8.1 → jinns-0.8.3}/jinns/loss/_operators.py +0 -0
  69. {jinns-0.8.1 → jinns-0.8.3}/jinns/solver/__init__.py +0 -0
  70. {jinns-0.8.1 → jinns-0.8.3}/jinns/solver/_rar.py +0 -0
  71. {jinns-0.8.1 → jinns-0.8.3}/jinns/solver/_seq2seq.py +0 -0
  72. {jinns-0.8.1 → jinns-0.8.3}/jinns/solver/_solve.py +0 -0
  73. {jinns-0.8.1 → jinns-0.8.3}/jinns/utils/__init__.py +0 -0
  74. {jinns-0.8.1 → jinns-0.8.3}/jinns/utils/_optim.py +0 -0
  75. {jinns-0.8.1 → jinns-0.8.3}/jinns/utils/_pinn.py +0 -0
  76. {jinns-0.8.1 → jinns-0.8.3}/jinns/utils/_save_load.py +0 -0
  77. {jinns-0.8.1 → jinns-0.8.3}/jinns/utils/_spinn.py +0 -0
  78. {jinns-0.8.1 → jinns-0.8.3}/jinns/utils/_utils.py +0 -0
  79. {jinns-0.8.1 → jinns-0.8.3}/jinns/utils/_utils_uspinn.py +0 -0
  80. {jinns-0.8.1 → jinns-0.8.3}/jinns.egg-info/dependency_links.txt +0 -0
  81. {jinns-0.8.1 → jinns-0.8.3}/jinns.egg-info/requires.txt +0 -0
  82. {jinns-0.8.1 → jinns-0.8.3}/jinns.egg-info/top_level.txt +0 -0
  83. {jinns-0.8.1 → jinns-0.8.3}/pyproject.toml +0 -0
  84. {jinns-0.8.1 → jinns-0.8.3}/setup.cfg +0 -0
  85. {jinns-0.8.1 → jinns-0.8.3}/tests/conftest.py +0 -0
  86. {jinns-0.8.1 → jinns-0.8.3}/tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py +0 -0
  87. {jinns-0.8.1 → jinns-0.8.3}/tests/dataGenerator_tests/test_CubicMeshPDEStatio.py +0 -0
  88. {jinns-0.8.1 → jinns-0.8.3}/tests/dataGenerator_tests/test_DataGeneratorODE.py +0 -0
  89. {jinns-0.8.1 → jinns-0.8.3}/tests/runtests.sh +0 -0
  90. {jinns-0.8.1 → jinns-0.8.3}/tests/sharding_tests/test_Burger_x32_multiple_shardings.py +0 -0
  91. {jinns-0.8.1 → jinns-0.8.3}/tests/sharding_tests/test_imperfect_sobolev_x32_multiple_shardings.py +0 -0
  92. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests/test_Burger_x32.py +0 -0
  93. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests/test_Burger_x64.py +0 -0
  94. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests/test_Fisher_x32.py +0 -0
  95. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests/test_Fisher_x64.py +0 -0
  96. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests/test_GLV_x32.py +0 -0
  97. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests/test_GLV_x64.py +0 -0
  98. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests/test_NSPipeFlow_x32.py +0 -0
  99. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests/test_NSPipeFlow_x64.py +0 -0
  100. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests/test_OU2D_x32.py +0 -0
  101. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests/test_imperfect_sobolev_x32.py +0 -0
  102. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests_spinn/test_Burger_x32.py +0 -0
  103. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests_spinn/test_Fisher_x32.py +0 -0
  104. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests_spinn/test_NSPipeFlow_x32_spinn.py +0 -0
  105. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests_spinn/test_OU2D_x32.py +0 -0
  106. {jinns-0.8.1 → jinns-0.8.3}/tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64.py +0 -0
.pre-commit-config.yaml
@@ -1,14 +1,14 @@
 default_stages: [commit]
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.5.0
+  rev: v4.6.0
   hooks:
   - id: trailing-whitespace
     stages: [commit]
   - id: end-of-file-fixer
     stages: [commit]
 - repo: https://github.com/psf/black
-  rev: 24.2.0
+  rev: 24.3.0
   hooks:
   - id: black
     stages: [commit]
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: jinns
-Version: 0.8.1
+Version: 0.8.3
 Summary: Physics Informed Neural Network with JAX
 Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
 Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
doc/source/index.rst
@@ -8,6 +8,14 @@ Welcome to jinn's documentation!
 
 Changelog:
 
+* v0.8.3:
+
+  - Add the possibility to load user-provided tables of parameters in DataGeneratorParameter and not only to randomly sample them.
+
+* v0.8.2:
+
+  - Fix a bug: it was not possible to jit a reloaded HyperPINN model.
+
 * v0.8.1:
 
   - New feature: `save_pinn` and `load_pinn` in `jinns.utils` for pre-trained
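
To illustrate the v0.8.3 entry above, here is a minimal usage sketch of the new `user_data` option (based on the updated `DataGeneratorParameter` signature and the new tests shown further down; parameter names and values are illustrative):

    import jax
    import jax.numpy as jnp
    import jinns

    n = 64
    key = jax.random.PRNGKey(0)

    # "theta" is sampled uniformly in (10, 11); "nu" is read from a user-provided table.
    data_gen = jinns.data.DataGeneratorParameter(
        key,
        n,
        param_batch_size=64,
        param_ranges={"theta": (10.0, 11.0)},
        method="uniform",
        user_data={"nu": jnp.arange(n)},  # arrays of shape (n,) or (n, 1), per the docstring
    )
    param_batch = data_gen.get_batch()  # dict with one batch of values per parameter name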
jinns/data/_DataGenerators.py
@@ -1101,8 +1101,9 @@ class DataGeneratorParameter:
         key,
         n,
         param_batch_size,
-        param_ranges,
+        param_ranges=None,
         method="grid",
+        user_data=None,
         data_exists=False,
     ):
         r"""
@@ -1140,25 +1141,47 @@ class DataGeneratorParameter:
             Must be left to `False` when created by the user. Avoids the
             regeneration of :math:`\Omega`, :math:`\partial\Omega` and
             time points at each pytree flattening and unflattening.
+        user_data
+            A dictionary containing user-provided data for parameters.
+            As for `param_ranges`, the key corresponds to the parameter name,
+            the keys must match the keys in `params["eq_params"]` and only
+            unidimensional arrays are supported. Therefore, the jnp arrays
+            found at `user_data[k]` must have shape `(n, 1)` or `(n,)`.
+            Note that if the same key appears in `param_ranges` and `user_data`
+            priority goes for the content in `user_data`.
+            Defaults to None.
         """
         self.data_exists = data_exists
         self.method = method
-        if not isinstance(key, dict):
-            self._keys = dict(
-                zip(param_ranges.keys(), jax.random.split(key, len(param_ranges)))
+
+        if n < param_batch_size:
+            raise ValueError(
+                f"Number of data points ({n}) is smaller than the"
+                f"number of batch points ({param_batch_size})."
             )
+
+        if user_data is None:
+            user_data = {}
+        if param_ranges is None:
+            param_ranges = {}
+        if not isinstance(key, dict):
+            all_keys = set().union(param_ranges, user_data)
+            self._keys = dict(zip(all_keys, jax.random.split(key, len(all_keys))))
         else:
             self._keys = key
         self.n = n
         self.param_batch_size = param_batch_size
         self.param_ranges = param_ranges
+        self.user_data = user_data
 
         if not self.data_exists:
             self.generate_data()
             # The previous call to self.generate_data() has created
-            # the dict self.param_n_samples
+            # the dict self.param_n_samples and then we will only use this one
+            # because it has merged the scattered data between `user_data` and
+            # `param_ranges`
             self.curr_param_idx = {}
-            for k in self.param_ranges.keys():
+            for k in self.param_n_samples.keys():
                 self.curr_param_idx[k] = 0
                 (
                     self._keys[k],
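
A short sketch of the precedence rule stated in the docstring above, with illustrative names; the behaviour follows from the generate_data hunk just below:

    import jax
    import jax.numpy as jnp
    import jinns

    n = 8
    dg = jinns.data.DataGeneratorParameter(
        jax.random.PRNGKey(0),
        n,
        n,  # param_batch_size
        param_ranges={"nu": (0.0, 1.0)},              # "nu" would be sampled uniformly...
        method="uniform",
        user_data={"nu": jnp.linspace(2.0, 3.0, n)},  # ...but the user table takes priority
    )
    # dg.param_n_samples["nu"] holds the user-provided values, reshaped to (n, 1)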
@@ -1167,22 +1190,40 @@ class DataGeneratorParameter:
                 ) = _reset_batch_idx_and_permute(self._get_param_operands(k))
 
     def generate_data(self):
-        # Generate param n samples
+        """
+        Generate parameter samples, either through generation
+        or using user-provided data.
+        """
         self.param_n_samples = {}
-        for k, e in self.param_ranges.items():
-            if self.method == "grid":
-                xmin, xmax = e[0], e[1]
-                self.partial = (xmax - xmin) / self.n
-                # shape (n, 1)
-                self.param_n_samples[k] = jnp.arange(xmin, xmax, self.partial)[:, None]
-            elif self.method == "uniform":
-                xmin, xmax = e[0], e[1]
-                self._keys[k], subkey = random.split(self._keys[k], 2)
-                self.param_n_samples[k] = random.uniform(
-                    subkey, shape=(self.n, 1), minval=xmin, maxval=xmax
-                )
+
+        all_keys = set().union(self.param_ranges, self.user_data)
+        for k in all_keys:
+            if self.user_data and k in self.user_data:
+                if self.user_data[k].shape == (self.n, 1):
+                    self.param_n_samples[k] = self.user_data[k]
+                if self.user_data[k].shape == (self.n,):
+                    self.param_n_samples[k] = self.user_data[k][:, None]
+                else:
+                    raise ValueError(
+                        "Wrong shape for user provided parameters"
+                        f" in user_data dictionary at key='{k}'"
+                    )
             else:
-                raise ValueError("Method " + self.method + " is not implemented.")
+                if self.method == "grid":
+                    xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
+                    self.partial = (xmax - xmin) / self.n
+                    # shape (n, 1)
+                    self.param_n_samples[k] = jnp.arange(xmin, xmax, self.partial)[
+                        :, None
+                    ]
+                elif self.method == "uniform":
+                    xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
+                    self._keys[k], subkey = random.split(self._keys[k], 2)
+                    self.param_n_samples[k] = random.uniform(
+                        subkey, shape=(self.n, 1), minval=xmin, maxval=xmax
+                    )
+                else:
+                    raise ValueError("Method " + self.method + " is not implemented.")
 
     def _get_param_operands(self, k):
         return (
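
As a quick numeric illustration of the "grid" branch above (plain jax.numpy, independent of jinns):

    import jax.numpy as jnp

    n = 4
    xmin, xmax = 0.0, 1.0
    step = (xmax - xmin) / n                      # "self.partial" in the code above
    grid = jnp.arange(xmin, xmax, step)[:, None]  # shape (4, 1): [[0.], [0.25], [0.5], [0.75]]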
@@ -1247,12 +1288,7 @@ class DataGeneratorParameter:
         )
         aux_data = {
             k: vars(self)[k]
-            for k in [
-                "n",
-                "param_batch_size",
-                "method",
-                "param_ranges",
-            ]
+            for k in ["n", "param_batch_size", "method", "param_ranges", "user_data"]
         }
         return (children, aux_data)
 
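
For context, the tree_flatten method above belongs to the pytree registration of the data generator: array leaves go into children while plain metadata goes into aux_data, and tree_unflatten rebuilds the object from the two. A minimal sketch of that pattern with a hypothetical toy class (not the jinns implementation):

    import jax
    import jax.numpy as jnp

    @jax.tree_util.register_pytree_node_class
    class ToyGenerator:
        def __init__(self, samples, n, method):
            self.samples = samples  # array leaf, visible to JAX transformations
            self.n = n              # metadata carried in aux_data
            self.method = method

        def tree_flatten(self):
            children = (self.samples,)
            aux_data = {"n": self.n, "method": self.method}
            return (children, aux_data)

        @classmethod
        def tree_unflatten(cls, aux_data, children):
            return cls(children[0], **aux_data)

    gen = ToyGenerator(jnp.ones(3), n=3, method="grid")
    leaves, treedef = jax.tree_util.tree_flatten(gen)
    same_gen = jax.tree_util.tree_unflatten(treedef, leaves)  # metadata round-trips via aux_data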
jinns/utils/_hyperpinn.py
@@ -39,8 +39,8 @@ class HYPERPINN(PINN):
     static_hyper: eqx.Module
     hyperparams: list = eqx.field(static=True)
     hypernet_input_size: int
-    pinn_params_sum: ArrayLike
-    pinn_params_cumsum: ArrayLike
+    pinn_params_sum: ArrayLike = eqx.field(static=True)
+    pinn_params_cumsum: ArrayLike = eqx.field(static=True)
 
     def __init__(
         self,
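
For background on this fix, a minimal sketch (hypothetical module, not jinns code) of what eqx.field(static=True) changes: static fields are stored in the pytree structure rather than as traced leaves, so jax.jit treats them as compile-time constants:

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    class TinyModule(eqx.Module):
        weights: jax.Array
        n_kept: int = eqx.field(static=True)  # plain metadata, excluded from the traced leaves

        def __call__(self, x):
            # n_kept stays a Python int under jit, so it can be used for slicing
            return jnp.dot(self.weights[: self.n_kept], x[: self.n_kept])

    @jax.jit
    def forward(module, x):
        return module(x)

    m = TinyModule(weights=jnp.ones(4), n_kept=3)
    print(forward(m, jnp.arange(4.0)))  # 3.0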
jinns.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: jinns
-Version: 0.8.1
+Version: 0.8.3
 Summary: Physics Informed Neural Network with JAX
 Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
 Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
jinns.egg-info/SOURCES.txt
@@ -82,6 +82,7 @@ tests/runtests.sh
 tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py
 tests/dataGenerator_tests/test_CubicMeshPDEStatio.py
 tests/dataGenerator_tests/test_DataGeneratorODE.py
+tests/dataGenerator_tests/test_DataGeneratorParameter.py
 tests/save_load_tests/test_saving_loading_hyperpinn.py
 tests/save_load_tests/test_saving_loading_pinn.py
 tests/save_load_tests/test_saving_loading_spinn.py
tests/dataGenerator_tests/test_DataGeneratorParameter.py (new file)
@@ -0,0 +1,80 @@
+import pytest
+import jax
+import jax.numpy as jnp
+import jinns
+
+n = 64
+
+
+@pytest.fixture
+def create_DataGeneratorParameter():
+    key = jax.random.PRNGKey(2)
+    key, subkey = jax.random.split(key)
+
+    param_batch_size = 64
+    method = "uniform"
+    param_ranges = {"theta": (10.0, 11.0)}
+    user_data = {"nu": jnp.arange(n)}
+
+    return jinns.data.DataGeneratorParameter(
+        subkey,
+        n,
+        param_batch_size,
+        param_ranges=param_ranges,
+        method=method,
+        user_data=user_data,
+    )
+
+
+@pytest.fixture
+def create_DataGeneratorParameter_only_user_data():
+    key = jax.random.PRNGKey(2)
+    key, subkey = jax.random.split(key)
+
+    param_batch_size = 64
+    method = "uniform"
+    user_data = {"nu": jnp.arange(n)}
+
+    return jinns.data.DataGeneratorParameter(
+        subkey,
+        n,
+        param_batch_size,
+        method=method,
+        user_data=user_data,
+    )
+
+
+def test_get_batch(create_DataGeneratorParameter):
+    data_generator_parameters = create_DataGeneratorParameter
+    param_batch = data_generator_parameters.get_batch()
+    assert jnp.allclose(jnp.sort(jnp.unique(param_batch["nu"])), jnp.arange(n)) and (
+        jnp.all(param_batch["theta"] >= 10.0) and jnp.all(param_batch["theta"] <= 11.0)
+    )
+
+
+def test_get_batch_only_user_data(create_DataGeneratorParameter_only_user_data):
+    data_generator_parameters = create_DataGeneratorParameter_only_user_data
+    param_batch = data_generator_parameters.get_batch()
+    assert jnp.allclose(jnp.sort(jnp.unique(param_batch["nu"])), jnp.arange(n))
+
+
+def test_raise_error_with_wrong_shape_for_user_data():
+    key = jax.random.PRNGKey(2)
+    key, subkey = jax.random.split(key)
+
+    param_batch_size = 64
+    method = "uniform"
+    param_ranges = {"theta": (10.0, 11.0)}
+    # user_data is not (n,) or (n,1)
+    user_data = {"nu": jnp.ones((n, 1, 1))}
+
+    with pytest.raises(ValueError) as e_info:
+        # __init__ calls self.generate_data() that we are testing for
+        data_generator_parameters = jinns.data.DataGeneratorParameter(
+            subkey,
+            n,
+            param_batch_size,
+            param_ranges=param_ranges,
+            method=method,
+            user_data=user_data,
+        )
tests/save_load_tests/test_saving_loading_hyperpinn.py
@@ -9,7 +9,7 @@ from jinns.utils import save_pinn, load_pinn
 
 
 @pytest.fixture
-def save_reload():
+def save_reload(tmpdir):
     jax.config.update("jax_enable_x64", False)
     key = random.PRNGKey(2)
 
@@ -73,7 +73,7 @@ def save_reload():
     }
 
     # Save
-    filename = "./test"
+    filename = str(tmpdir.join("test"))
     kwargs_creation = {
         "key": subkey,
         "eqx_list": eqx_list,
@@ -107,3 +107,24 @@ def test_equality_save_reload(save_reload):
         v_u_reloaded(test_points[:, 0:1], test_points[:, 1:3], params_reloaded),
         atol=1e-3,
     )
+
+
+def test_jitting_reloaded_hyperpinn(save_reload):
+    """
+    This test ensures that the reloaded hyperpinn is jit-able.
+    Some conversion of onp.array nodes can arise when reloading.
+    See this MR : https://gitlab.com/mia_jinns/jinns/-/merge_requests/32
+    jinns v0.8.2 uses eqx.field(static=True) to solve the problem.
+    This tests is here for testimony.
+    """
+
+    key, _, _, params_reloaded, u_reloaded = save_reload
+
+    v_u_reloaded = jax.vmap(
+        u_reloaded, (0, 0, {"nn_params": None, "eq_params": {"D": 0, "r": 0}})
+    )
+    v_u_reloaded_jitted = jax.jit(v_u_reloaded)
+
+    key, subkey = jax.random.split(key, 2)
+    test_points = jax.random.normal(subkey, shape=(10, 5))
+    v_u_reloaded_jitted(test_points[:, 0:1], test_points[:, 1:3], params_reloaded)
tests/save_load_tests/test_saving_loading_pinn.py
@@ -9,7 +9,7 @@ from jinns.utils import save_pinn, load_pinn
 
 
 @pytest.fixture
-def save_reload():
+def save_reload(tmpdir):
     jax.config.update("jax_enable_x64", False)
     key = random.PRNGKey(2)
     eqx_list = [
@@ -28,7 +28,7 @@ def save_reload():
     params = {"nn_params": params, "eq_params": {}}
 
     # Save
-    filename = "./test"
+    filename = str(tmpdir.join("test"))
     kwargs_creation = {
         "key": subkey,
         "eqx_list": eqx_list,
@@ -57,3 +57,22 @@ def test_equality_save_reload(save_reload):
         v_u_reloaded(test_points[:, 0:1], test_points[:, 1:], params_reloaded),
         atol=1e-3,
     )
+
+
+def test_jitting_reloaded_pinn(save_reload):
+    """
+    This test ensures that the reloaded pinn is jit-able.
+    Some conversion of onp.array nodes can arise when reloading.
+    See this MR : https://gitlab.com/mia_jinns/jinns/-/merge_requests/32
+    jinns v0.8.2 uses eqx.field(static=True) to solve the problem.
+    This tests is here for testimony.
+    """
+
+    key, _, _, params_reloaded, u_reloaded = save_reload
+
+    key, subkey = jax.random.split(key, 2)
+    test_points = jax.random.normal(subkey, shape=(10, 2))
+    v_u_reloaded = jax.vmap(u_reloaded, (0, 0, None))
+    v_u_reloaded_jitted = jax.jit(v_u_reloaded)
+
+    v_u_reloaded_jitted(test_points[:, 0:1], test_points[:, 1:3], params_reloaded)
tests/save_load_tests/test_saving_loading_spinn.py
@@ -9,7 +9,7 @@ from jinns.utils import save_pinn, load_pinn
 
 
 @pytest.fixture
-def save_reload():
+def save_reload(tmpdir):
     jax.config.update("jax_enable_x64", False)
     key = random.PRNGKey(2)
     d = 2
@@ -30,7 +30,7 @@ def save_reload():
     params = {"nn_params": params, "eq_params": {}}
 
     # Save
-    filename = "./test"
+    filename = str(tmpdir.join("test"))
     kwargs_creation = {
         "key": subkey,
         "d": d,
@@ -58,3 +58,21 @@ def test_equality_save_reload(save_reload):
         u_reloaded(test_points[:, 0:1], test_points[:, 1:], params_reloaded),
         atol=1e-3,
     )
+
+
+def test_jitting_reloaded_spinn(save_reload):
+    """
+    This test ensures that the reloaded spinn is jit-able.
+    Some conversion of onp.array nodes can arise when reloading.
+    See this MR : https://gitlab.com/mia_jinns/jinns/-/merge_requests/32
+    jinns v0.8.2 uses eqx.field(static=True) to solve the problem.
+    This tests is here for testimony.
+    """
+
+    key, _, _, params_reloaded, u_reloaded = save_reload
+
+    key, subkey = jax.random.split(key, 2)
+    test_points = jax.random.normal(subkey, shape=(10, 2))
+
+    u_reloaded_jitted = jax.jit(u_reloaded.__call__)
+    u_reloaded_jitted(test_points[:, 0:1], test_points[:, 1:], params_reloaded)