jinns 0.8.2__tar.gz → 0.8.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jinns-0.8.2 → jinns-0.8.4}/.pre-commit-config.yaml +2 -2
- {jinns-0.8.2 → jinns-0.8.4}/PKG-INFO +1 -1
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/index.rst +9 -1
- {jinns-0.8.2 → jinns-0.8.4}/jinns/data/_DataGenerators.py +62 -26
- {jinns-0.8.2 → jinns-0.8.4}/jinns/loss/_DynamicLossAbstract.py +1 -1
- {jinns-0.8.2 → jinns-0.8.4}/jinns.egg-info/PKG-INFO +1 -1
- {jinns-0.8.2 → jinns-0.8.4}/jinns.egg-info/SOURCES.txt +1 -0
- jinns-0.8.4/tests/dataGenerator_tests/test_DataGeneratorParameter.py +80 -0
- {jinns-0.8.2 → jinns-0.8.4}/.gitignore +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/.gitlab-ci.yml +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/LICENSE +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/ODE/1D_Generalized_Lotka_Volterra.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/ODE/1D_Generalized_Lotka_Volterra_seq2seq.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/ODE/sbinn_data/glucose.dat +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/ODE/sbinn_data/meal.dat +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/1D_non_stationary_Burger.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/1D_non_stationary_Fisher_KPP_Bounded_Domain.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/2D_Heat_inverse_problem.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_SoftConstraints.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/2D_Poisson_inverse_problem.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/2D_non_stationary_OU.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/2d_nonstatio_ou_standardsampling.png +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/OU_1D_nonstatio_solution_grid.npy +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/Reaction_Diffusion_2D_heterogenous_model.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/Reaction_Diffusion_2D_homogeneous_metamodel_hyperpinn_diffrax.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/burger_solution_grid.npy +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/imperfect_modeling_sobolev_reg.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/Notebooks/Tutorials/load_save_model.ipynb +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/README.md +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/Makefile +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/boundary_conditions.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/conf.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/data.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/dynamic_loss.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/experimental.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/fokker_planck.qmd +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/loss.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/loss_ode.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/loss_pde.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/losses.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/math_pinn.qmd +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/operators.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/param_estim_pinn.qmd +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/rar.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/seq2seq.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/solve.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/solver.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/doc/source/utils.rst +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/__init__.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/data/__init__.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/data/_display.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/experimental/__init__.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/experimental/_diffrax_solver.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/loss/_DynamicLoss.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/loss/_LossODE.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/loss/_LossPDE.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/loss/_Losses.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/loss/__init__.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/loss/_boundary_conditions.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/loss/_operators.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/solver/__init__.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/solver/_rar.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/solver/_seq2seq.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/solver/_solve.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/utils/__init__.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/utils/_hyperpinn.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/utils/_optim.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/utils/_pinn.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/utils/_save_load.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/utils/_spinn.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/utils/_utils.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns/utils/_utils_uspinn.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns.egg-info/dependency_links.txt +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns.egg-info/requires.txt +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/jinns.egg-info/top_level.txt +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/pyproject.toml +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/setup.cfg +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/conftest.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/dataGenerator_tests/test_CubicMeshPDEStatio.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/dataGenerator_tests/test_DataGeneratorODE.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/runtests.sh +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/save_load_tests/test_saving_loading_hyperpinn.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/save_load_tests/test_saving_loading_pinn.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/save_load_tests/test_saving_loading_spinn.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/sharding_tests/test_Burger_x32_multiple_shardings.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/sharding_tests/test_imperfect_sobolev_x32_multiple_shardings.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests/test_Burger_x32.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests/test_Burger_x64.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests/test_Fisher_x32.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests/test_Fisher_x64.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests/test_GLV_x32.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests/test_GLV_x64.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests/test_NSPipeFlow_x32.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests/test_NSPipeFlow_x64.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests/test_OU2D_x32.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests/test_imperfect_sobolev_x32.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests_spinn/test_Burger_x32.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests_spinn/test_Fisher_x32.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests_spinn/test_NSPipeFlow_x32_spinn.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests_spinn/test_OU2D_x32.py +0 -0
- {jinns-0.8.2 → jinns-0.8.4}/tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64.py +0 -0
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
default_stages: [commit]
|
|
2
2
|
repos:
|
|
3
3
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
4
|
-
rev: v4.
|
|
4
|
+
rev: v4.6.0
|
|
5
5
|
hooks:
|
|
6
6
|
- id: trailing-whitespace
|
|
7
7
|
stages: [commit]
|
|
8
8
|
- id: end-of-file-fixer
|
|
9
9
|
stages: [commit]
|
|
10
10
|
- repo: https://github.com/psf/black
|
|
11
|
-
rev: 24.
|
|
11
|
+
rev: 24.3.0
|
|
12
12
|
hooks:
|
|
13
13
|
- id: black
|
|
14
14
|
stages: [commit]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 0.8.2
|
|
3
|
+
Version: 0.8.4
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -8,9 +8,17 @@ Welcome to jinn's documentation!
|
|
|
8
8
|
|
|
9
9
|
Changelog:
|
|
10
10
|
|
|
11
|
+
* v0.8.4:
|
|
12
|
+
|
|
13
|
+
- Fix a bug: wrong argument in the wrapper function for heterogeneous parameter evaluation of a PDEStatio
|
|
14
|
+
|
|
15
|
+
* v0.8.3:
|
|
16
|
+
|
|
17
|
+
- Add the possibility to load user-provided tables of parameters in DataGeneratorParameter and not only to randomly sample them.
|
|
18
|
+
|
|
11
19
|
* v0.8.2:
|
|
12
20
|
|
|
13
|
-
- Fix a bug: it was not possible to jit a reloaded HyperPINN model
|
|
21
|
+
- Fix a bug: it was not possible to jit a reloaded HyperPINN model.
|
|
14
22
|
|
|
15
23
|
* v0.8.1:
|
|
16
24
|
|
|
@@ -1101,8 +1101,9 @@ class DataGeneratorParameter:
|
|
|
1101
1101
|
key,
|
|
1102
1102
|
n,
|
|
1103
1103
|
param_batch_size,
|
|
1104
|
-
param_ranges,
|
|
1104
|
+
param_ranges=None,
|
|
1105
1105
|
method="grid",
|
|
1106
|
+
user_data=None,
|
|
1106
1107
|
data_exists=False,
|
|
1107
1108
|
):
|
|
1108
1109
|
r"""
|
|
@@ -1140,25 +1141,47 @@ class DataGeneratorParameter:
|
|
|
1140
1141
|
Must be left to `False` when created by the user. Avoids the
|
|
1141
1142
|
regeneration of :math:`\Omega`, :math:`\partial\Omega` and
|
|
1142
1143
|
time points at each pytree flattening and unflattening.
|
|
1144
|
+
user_data
|
|
1145
|
+
A dictionary containing user-provided data for parameters.
|
|
1146
|
+
As for `param_ranges`, the key corresponds to the parameter name,
|
|
1147
|
+
the keys must match the keys in `params["eq_params"]` and only
|
|
1148
|
+
unidimensional arrays are supported. Therefore, the jnp arrays
|
|
1149
|
+
found at `user_data[k]` must have shape `(n, 1)` or `(n,)`.
|
|
1150
|
+
Note that if the same key appears in `param_ranges` and `user_data`
|
|
1151
|
+
priority goes for the content in `user_data`.
|
|
1152
|
+
Defaults to None.
|
|
1143
1153
|
"""
|
|
1144
1154
|
self.data_exists = data_exists
|
|
1145
1155
|
self.method = method
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
|
|
1156
|
+
|
|
1157
|
+
if n < param_batch_size:
|
|
1158
|
+
raise ValueError(
|
|
1159
|
+
f"Number of data points ({n}) is smaller than the"
|
|
1160
|
+
f"number of batch points ({param_batch_size})."
|
|
1149
1161
|
)
|
|
1162
|
+
|
|
1163
|
+
if user_data is None:
|
|
1164
|
+
user_data = {}
|
|
1165
|
+
if param_ranges is None:
|
|
1166
|
+
param_ranges = {}
|
|
1167
|
+
if not isinstance(key, dict):
|
|
1168
|
+
all_keys = set().union(param_ranges, user_data)
|
|
1169
|
+
self._keys = dict(zip(all_keys, jax.random.split(key, len(all_keys))))
|
|
1150
1170
|
else:
|
|
1151
1171
|
self._keys = key
|
|
1152
1172
|
self.n = n
|
|
1153
1173
|
self.param_batch_size = param_batch_size
|
|
1154
1174
|
self.param_ranges = param_ranges
|
|
1175
|
+
self.user_data = user_data
|
|
1155
1176
|
|
|
1156
1177
|
if not self.data_exists:
|
|
1157
1178
|
self.generate_data()
|
|
1158
1179
|
# The previous call to self.generate_data() has created
|
|
1159
|
-
# the dict self.param_n_samples
|
|
1180
|
+
# the dict self.param_n_samples and then we will only use this one
|
|
1181
|
+
# because it has merged the scattered data between `user_data` and
|
|
1182
|
+
# `param_ranges`
|
|
1160
1183
|
self.curr_param_idx = {}
|
|
1161
|
-
for k in self.
|
|
1184
|
+
for k in self.param_n_samples.keys():
|
|
1162
1185
|
self.curr_param_idx[k] = 0
|
|
1163
1186
|
(
|
|
1164
1187
|
self._keys[k],
|
|
@@ -1167,22 +1190,40 @@ class DataGeneratorParameter:
|
|
|
1167
1190
|
) = _reset_batch_idx_and_permute(self._get_param_operands(k))
|
|
1168
1191
|
|
|
1169
1192
|
def generate_data(self):
|
|
1170
|
-
|
|
1193
|
+
"""
|
|
1194
|
+
Generate parameter samples, either through generation
|
|
1195
|
+
or using user-provided data.
|
|
1196
|
+
"""
|
|
1171
1197
|
self.param_n_samples = {}
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1198
|
+
|
|
1199
|
+
all_keys = set().union(self.param_ranges, self.user_data)
|
|
1200
|
+
for k in all_keys:
|
|
1201
|
+
if self.user_data and k in self.user_data:
|
|
1202
|
+
if self.user_data[k].shape == (self.n, 1):
|
|
1203
|
+
self.param_n_samples[k] = self.user_data[k]
|
|
1204
|
+
if self.user_data[k].shape == (self.n,):
|
|
1205
|
+
self.param_n_samples[k] = self.user_data[k][:, None]
|
|
1206
|
+
else:
|
|
1207
|
+
raise ValueError(
|
|
1208
|
+
"Wrong shape for user provided parameters"
|
|
1209
|
+
f" in user_data dictionary at key='{k}'"
|
|
1210
|
+
)
|
|
1184
1211
|
else:
|
|
1185
|
-
|
|
1212
|
+
if self.method == "grid":
|
|
1213
|
+
xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
|
|
1214
|
+
self.partial = (xmax - xmin) / self.n
|
|
1215
|
+
# shape (n, 1)
|
|
1216
|
+
self.param_n_samples[k] = jnp.arange(xmin, xmax, self.partial)[
|
|
1217
|
+
:, None
|
|
1218
|
+
]
|
|
1219
|
+
elif self.method == "uniform":
|
|
1220
|
+
xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
|
|
1221
|
+
self._keys[k], subkey = random.split(self._keys[k], 2)
|
|
1222
|
+
self.param_n_samples[k] = random.uniform(
|
|
1223
|
+
subkey, shape=(self.n, 1), minval=xmin, maxval=xmax
|
|
1224
|
+
)
|
|
1225
|
+
else:
|
|
1226
|
+
raise ValueError("Method " + self.method + " is not implemented.")
|
|
1186
1227
|
|
|
1187
1228
|
def _get_param_operands(self, k):
|
|
1188
1229
|
return (
|
|
@@ -1247,12 +1288,7 @@ class DataGeneratorParameter:
|
|
|
1247
1288
|
)
|
|
1248
1289
|
aux_data = {
|
|
1249
1290
|
k: vars(self)[k]
|
|
1250
|
-
for k in [
|
|
1251
|
-
"n",
|
|
1252
|
-
"param_batch_size",
|
|
1253
|
-
"method",
|
|
1254
|
-
"param_ranges",
|
|
1255
|
-
]
|
|
1291
|
+
for k in ["n", "param_batch_size", "method", "param_ranges", "user_data"]
|
|
1256
1292
|
}
|
|
1257
1293
|
return (children, aux_data)
|
|
1258
1294
|
|
|
@@ -154,7 +154,7 @@ class PDEStatio(DynamicLoss):
|
|
|
154
154
|
_params = {
|
|
155
155
|
"nn_params": params["nn_params"],
|
|
156
156
|
"eq_params": self.eval_heterogeneous_parameters(
|
|
157
|
-
|
|
157
|
+
x, u, params, self.eq_params_heterogeneity
|
|
158
158
|
),
|
|
159
159
|
}
|
|
160
160
|
new_args = args[:-1] + (_params,)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 0.8.2
|
|
3
|
+
Version: 0.8.4
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -82,6 +82,7 @@ tests/runtests.sh
|
|
|
82
82
|
tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py
|
|
83
83
|
tests/dataGenerator_tests/test_CubicMeshPDEStatio.py
|
|
84
84
|
tests/dataGenerator_tests/test_DataGeneratorODE.py
|
|
85
|
+
tests/dataGenerator_tests/test_DataGeneratorParameter.py
|
|
85
86
|
tests/save_load_tests/test_saving_loading_hyperpinn.py
|
|
86
87
|
tests/save_load_tests/test_saving_loading_pinn.py
|
|
87
88
|
tests/save_load_tests/test_saving_loading_spinn.py
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import jax
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
import jinns
|
|
5
|
+
|
|
6
|
+
n = 64
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@pytest.fixture
|
|
10
|
+
def create_DataGeneratorParameter():
|
|
11
|
+
key = jax.random.PRNGKey(2)
|
|
12
|
+
key, subkey = jax.random.split(key)
|
|
13
|
+
|
|
14
|
+
param_batch_size = 64
|
|
15
|
+
method = "uniform"
|
|
16
|
+
param_ranges = {"theta": (10.0, 11.0)}
|
|
17
|
+
user_data = {"nu": jnp.arange(n)}
|
|
18
|
+
|
|
19
|
+
return jinns.data.DataGeneratorParameter(
|
|
20
|
+
subkey,
|
|
21
|
+
n,
|
|
22
|
+
param_batch_size,
|
|
23
|
+
param_ranges=param_ranges,
|
|
24
|
+
method=method,
|
|
25
|
+
user_data=user_data,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@pytest.fixture
|
|
30
|
+
def create_DataGeneratorParameter_only_user_data():
|
|
31
|
+
key = jax.random.PRNGKey(2)
|
|
32
|
+
key, subkey = jax.random.split(key)
|
|
33
|
+
|
|
34
|
+
param_batch_size = 64
|
|
35
|
+
method = "uniform"
|
|
36
|
+
user_data = {"nu": jnp.arange(n)}
|
|
37
|
+
|
|
38
|
+
return jinns.data.DataGeneratorParameter(
|
|
39
|
+
subkey,
|
|
40
|
+
n,
|
|
41
|
+
param_batch_size,
|
|
42
|
+
method=method,
|
|
43
|
+
user_data=user_data,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def test_get_batch(create_DataGeneratorParameter):
|
|
48
|
+
data_generator_parameters = create_DataGeneratorParameter
|
|
49
|
+
param_batch = data_generator_parameters.get_batch()
|
|
50
|
+
assert jnp.allclose(jnp.sort(jnp.unique(param_batch["nu"])), jnp.arange(n)) and (
|
|
51
|
+
jnp.all(param_batch["theta"] >= 10.0) and jnp.all(param_batch["theta"] <= 11.0)
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def test_get_batch_only_user_data(create_DataGeneratorParameter_only_user_data):
|
|
56
|
+
data_generator_parameters = create_DataGeneratorParameter_only_user_data
|
|
57
|
+
param_batch = data_generator_parameters.get_batch()
|
|
58
|
+
assert jnp.allclose(jnp.sort(jnp.unique(param_batch["nu"])), jnp.arange(n))
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def test_raise_error_with_wrong_shape_for_user_data():
|
|
62
|
+
key = jax.random.PRNGKey(2)
|
|
63
|
+
key, subkey = jax.random.split(key)
|
|
64
|
+
|
|
65
|
+
param_batch_size = 64
|
|
66
|
+
method = "uniform"
|
|
67
|
+
param_ranges = {"theta": (10.0, 11.0)}
|
|
68
|
+
# user_data is not (n,) or (n,1)
|
|
69
|
+
user_data = {"nu": jnp.ones((n, 1, 1))}
|
|
70
|
+
|
|
71
|
+
with pytest.raises(ValueError) as e_info:
|
|
72
|
+
# __init__ calls self.generate_data() that we are testing for
|
|
73
|
+
data_generator_parameters = jinns.data.DataGeneratorParameter(
|
|
74
|
+
subkey,
|
|
75
|
+
n,
|
|
76
|
+
param_batch_size,
|
|
77
|
+
param_ranges=param_ranges,
|
|
78
|
+
method=method,
|
|
79
|
+
user_data=user_data,
|
|
80
|
+
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jinns-0.8.2 → jinns-0.8.4}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{jinns-0.8.2 → jinns-0.8.4}/tests/sharding_tests/test_imperfect_sobolev_x32_multiple_shardings.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|