jinns 1.5.0__tar.gz → 1.5.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jinns-1.5.0 → jinns-1.5.1}/PKG-INFO +1 -1
- {jinns-1.5.0 → jinns-1.5.1}/docs/changelog.md +7 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/__init__.py +7 -7
- {jinns-1.5.0 → jinns-1.5.1}/jinns/data/_CubicMeshPDENonStatio.py +156 -28
- {jinns-1.5.0 → jinns-1.5.1}/jinns/data/_CubicMeshPDEStatio.py +132 -24
- {jinns-1.5.0 → jinns-1.5.1}/jinns/loss/_LossODE.py +95 -31
- {jinns-1.5.0 → jinns-1.5.1}/jinns/loss/_LossPDE.py +6 -15
- {jinns-1.5.0 → jinns-1.5.1}/jinns/loss/_loss_utils.py +23 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/parameters/_params.py +8 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/solver/_solve.py +11 -5
- {jinns-1.5.0 → jinns-1.5.1}/jinns.egg-info/PKG-INFO +1 -1
- {jinns-1.5.0 → jinns-1.5.1}/jinns.egg-info/SOURCES.txt +1 -0
- jinns-1.5.1/tests/dataGenerator_tests/test_sobol_method.py +94 -0
- jinns-1.5.1/tests/loss_tests/test_lossODE.py +120 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests/test_Burgers_x32.py +1 -1
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests_spinn/test_Burgers_x32_spinn.py +1 -1
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests_spinn/test_Fisher_x32_spinn.py +1 -1
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests_spinn/test_ReactionDiffusion_nonhomo_x64_spinn.py +1 -2
- jinns-1.5.0/tests/loss_tests/test_lossODE.py +0 -41
- {jinns-1.5.0 → jinns-1.5.1}/.gitignore +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/.gitlab-ci.yml +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/.pre-commit-config.yaml +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/AUTHORS +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/LICENSE +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/ODE/1D_Generalized_Lotka_Volterra.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/ODE/MS_model_Verhulst.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/ODE/linear_fo_equation.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/ODE/sbinn_data/glucose.dat +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/ODE/sbinn_data/meal.dat +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/ODE/systems_biology_informed_neural_network.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/PDE/1D_non_stationary_Burgers.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/PDE/1D_non_stationary_Fisher_KPP_Bounded_Domain.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/PDE/2D_Heat_inverse_problem.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_Metamodel_hyperpinn.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/PDE/2D_Navier_Stokes_PipeFlow_SoftConstraints.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/PDE/2D_Poisson_inverse_problem.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/PDE/2D_non_stationary_OU.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/PDE/2d_nonstatio_ou_standardsampling.png +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/PDE/OU_1D_nonstatio_solution_grid.npy +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/PDE/Reaction_Diffusion_2D_heterogenous_model.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/PDE/Reaction_Diffusion_2D_homogeneous_metamodel_hyperpinn_diffrax.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/Tutorials/1D_non_stationary_Burgers_JointEstimation_Vanilla.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/Tutorials/burgers_solution_grid.npy +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/Tutorials/implementing_your_own_PDE_problem.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/Tutorials/introducing_validation_loss.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/Notebooks/Tutorials/load_save_model.ipynb +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/README.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/codemeta.json +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/README.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/_static/custom_css.css +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/_static/favicon.png +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/advanced/derivative_keys.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/advanced/differential_operators.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/advanced/loss_weight_updates.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/datagenerators/datagenerators_core.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/datagenerators/datagenerators_other.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/loss/dynamic_loss.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/loss/loss_weights.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/loss/loss_xde.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/pinn/hyperpinn.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/pinn/pinn.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/pinn/ppinn.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/pinn/save_load.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/pinn/spinn.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/plot.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/api/solver.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/doc_requirements.txt +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/index.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/javascripts/katex.js +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/maths/fokker_planck.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/docs/maths/introduction_to_pinns.md +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/img/jinns-diagram.png +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/data/_AbstractDataGenerator.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/data/_Batchs.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/data/_DataGeneratorODE.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/data/_DataGeneratorObservations.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/data/_DataGeneratorParameter.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/data/__init__.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/data/_utils.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/experimental/__init__.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/experimental/_diffrax_solver.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/loss/_DynamicLoss.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/loss/_DynamicLossAbstract.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/loss/__init__.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/loss/_abstract_loss.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/loss/_boundary_conditions.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/loss/_loss_components.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/loss/_loss_weight_updates.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/loss/_loss_weights.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/loss/_operators.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/nn/__init__.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/nn/_abstract_pinn.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/nn/_hyperpinn.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/nn/_mlp.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/nn/_pinn.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/nn/_ppinn.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/nn/_save_load.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/nn/_spinn.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/nn/_spinn_mlp.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/nn/_utils.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/parameters/__init__.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/parameters/_derivative_keys.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/plot/__init__.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/plot/_plot.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/solver/__init__.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/solver/_rar.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/solver/_utils.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/utils/__init__.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/utils/_containers.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/utils/_types.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/utils/_utils.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/validation/__init__.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns/validation/_validation.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns.egg-info/dependency_links.txt +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns.egg-info/requires.txt +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/jinns.egg-info/top_level.txt +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/mkdocs.yml +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/pyproject.toml +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/setup.cfg +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/adaptative_weight_tests/test_ReLoBRaLo_update.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/adaptative_weight_tests/test_loss_weight_update.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/adaptative_weight_tests/test_lr_annealing.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/adaptative_weight_tests/test_soft_adapt.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/conftest.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/dataGenerator_tests/test_CubicMeshPDENonStatio.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/dataGenerator_tests/test_CubicMeshPDEStatio.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/dataGenerator_tests/test_DataGeneratorODE.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/dataGenerator_tests/test_DataGeneratorParameter.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/loss_tests/test_lossPDEnonstatio.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/loss_tests/test_lossPDEstatio.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/loss_tests/test_norm_loss.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/nn_tests/test_hyperpinns.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/nn_tests/test_mlp.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/nn_tests/test_pinn.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/nn_tests/test_ppinn_mlp.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/nn_tests/test_smlp.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/nn_tests/test_spinn.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/operator_tests/test_divergence_fwd.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/operator_tests/test_divergence_rev.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/operator_tests/test_laplacian_fwd.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/operator_tests/test_laplacian_rev.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/operator_tests/test_vectorial_laplacian_fwd.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/operator_tests/test_vectorial_laplacian_rev.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/parameters_tests/test_DerivativeKeysODE.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/plot_tests/test_plot1D.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/plot_tests/test_plot2D.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/save_load_tests/test_saving_loading_hyperpinn.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/save_load_tests/test_saving_loading_pinn.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/save_load_tests/test_saving_loading_spinn.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/sharding_tests/test_Burgers_x32_multiple_shardings.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests/test_Burgers_x64.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests/test_Fisher_x32.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests/test_Fisher_x64.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests/test_GLV_x32.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests/test_GLV_x64.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests/test_NSPipeFlow_x32.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests/test_NSPipeFlow_x64.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests/test_OU1D_statio_x32.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests/test_OU2D_x32.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests/test_nan_params_catch.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests/test_parameter_tracker.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests/test_rar_algorithm.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests_hyperpinn/test_NSPipeFlow_x32_hyperpinn.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests_spinn/test_NSPipeFlow_x32_spinn.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/solver_tests_spinn/test_OU2D_x32_spinn.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/utils_tests/test_solver_utils.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/utils_tests/test_subtract_with_check.py +0 -0
- {jinns-1.5.0 → jinns-1.5.1}/tests/validation_tests/test_vanilla_validation.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jinns
|
|
3
|
-
Version: 1.5.
|
|
3
|
+
Version: 1.5.1
|
|
4
4
|
Summary: Physics Informed Neural Network with JAX
|
|
5
5
|
Author-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
6
6
|
Maintainer-email: Hugo Gangloff <hugo.gangloff@inrae.fr>, Nicolas Jouvin <nicolas.jouvin@inrae.fr>
|
|
@@ -1,5 +1,12 @@
|
|
|
1
1
|
# Changelog
|
|
2
2
|
|
|
3
|
+
* Unreleased (currently on `main`)
|
|
4
|
+
|
|
5
|
+
- More (initial) conditions for LossODE: it is now possible to handle, for example, a condition at both t0 and tmax [!77](https://gitlab.com/mia_jinns/jinns/-/merge_requests/77)
|
|
6
|
+
- Fix [inconsistent timings](https://gitlab.com/mia_jinns/jinns/-/issues?show=eyJpaWQiOiIxOCIsImZ1bGxfcGF0aCI6Im1pYV9qaW5ucy9qaW5ucyIsImlkIjoxNzAwMjcyNzh9) in `jinns.solve()`: the total elapsed time could previously be far off from the sum of compilation + training time. This was due to a wasted call to (non-JIT) `loss.evaluate` during initialization. We no longer waste compute time initializing the training loop, and we now print an additional "initialization time" so that the reported timings are consistent with the user's elapsed time.
|
|
7
|
+
- Silence the `equinox` warning about `field(init=False)` [!83](https://gitlab.com/mia_jinns/jinns/-/merge_requests/83).
|
|
8
|
+
- Add Quasi-random samplers [!84](https://gitlab.com/mia_jinns/jinns/-/merge_requests/84).
|
|
9
|
+
|
|
3
10
|
* v1.5.0
|
|
4
11
|
|
|
5
12
|
- Adaptive loss weights following a user-defined update scheme, see the updated intro tutorial. Breaking change: the return signature of `jinns.solve()` has changed. Moreover, we do not support vectorial loss weights any more. Users are expected to weight the components of their vectorial dynamic loss directly inside the dynamic loss definition (and, obviously, this new loss weight update feature will not apply to those components).
|
|
@@ -1,10 +1,3 @@
|
|
|
1
|
-
# import jinns.data
|
|
2
|
-
# import jinns.loss
|
|
3
|
-
# import jinns.solver
|
|
4
|
-
# import jinns.utils
|
|
5
|
-
# import jinns.experimental
|
|
6
|
-
# import jinns.parameters
|
|
7
|
-
# import jinns.plot
|
|
8
1
|
from jinns import data as data
|
|
9
2
|
from jinns import loss as loss
|
|
10
3
|
from jinns import solver as solver
|
|
@@ -16,3 +9,10 @@ from jinns import nn as nn
|
|
|
16
9
|
from jinns.solver._solve import solve
|
|
17
10
|
|
|
18
11
|
__all__ = ["nn", "solve"]
|
|
12
|
+
|
|
13
|
+
import warnings
|
|
14
|
+
|
|
15
|
+
warnings.filterwarnings(
|
|
16
|
+
action="ignore",
|
|
17
|
+
message=r"Using `field\(init=False\)`",
|
|
18
|
+
)
|
|
@@ -7,8 +7,10 @@ from __future__ import (
|
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
import warnings
|
|
9
9
|
import equinox as eqx
|
|
10
|
+
import numpy as np
|
|
10
11
|
import jax
|
|
11
12
|
import jax.numpy as jnp
|
|
13
|
+
from scipy.stats import qmc
|
|
12
14
|
from jaxtyping import Key, Array, Float
|
|
13
15
|
from jinns.data._Batchs import PDENonStatioBatch
|
|
14
16
|
from jinns.data._utils import (
|
|
@@ -65,11 +67,13 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
65
67
|
The minimum value of the time domain to consider
|
|
66
68
|
tmax : float
|
|
67
69
|
The maximum value of the time domain to consider
|
|
68
|
-
method :
|
|
70
|
+
method : Literal["uniform", "grid", "sobol", "halton"], default="uniform"
|
|
69
71
|
Either `grid`, `uniform`, `sobol` or `halton`, default is `uniform`.
|
|
70
72
|
The method that generates the `nt` time points. `grid` means
|
|
71
73
|
regularly spaced points over the domain. `uniform` means uniformly
|
|
72
|
-
sampled points over the domain
|
|
74
|
+
sampled points over the domain.
|
|
75
|
+
**Note** that Sobol and Halton approaches use scipy modules and will not
|
|
76
|
+
be JIT compatible.
|
|
73
77
|
rar_parameters : Dict[str, int], default=None
|
|
74
78
|
Defaults to None: do not use Residual Adaptative Resampling.
|
|
75
79
|
Otherwise a dictionary with keys
|
|
@@ -150,9 +154,11 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
150
154
|
elif self.method == "uniform":
|
|
151
155
|
self.key, domain_times = self.generate_time_data(self.key, self.n)
|
|
152
156
|
self.domain = jnp.concatenate([domain_times, self.omega], axis=1)
|
|
157
|
+
elif self.method in ["sobol", "halton"]:
|
|
158
|
+
self.key, self.domain = self.qmc_in_time_omega_domain(self.key, self.n)
|
|
153
159
|
else:
|
|
154
160
|
raise ValueError(
|
|
155
|
-
f'Bad value for method. Got {self.method}, expected "grid" or "uniform"'
|
|
161
|
+
f'Bad value for method. Got {self.method}, expected "grid" or "uniform" or "sobol" or "halton"'
|
|
156
162
|
)
|
|
157
163
|
|
|
158
164
|
if self.domain_batch_size is None:
|
|
@@ -182,21 +188,28 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
182
188
|
"number of points per facets (nb//2*self.dim)"
|
|
183
189
|
" cannot be lower than border batch size"
|
|
184
190
|
)
|
|
185
|
-
self.
|
|
186
|
-
self.key,
|
|
187
|
-
|
|
188
|
-
boundary_times = boundary_times.reshape(-1, 1, 1)
|
|
189
|
-
boundary_times = jnp.repeat(
|
|
190
|
-
boundary_times, self.omega_border.shape[-1], axis=2
|
|
191
|
-
)
|
|
192
|
-
if self.dim == 1:
|
|
193
|
-
self.border = make_cartesian_product(
|
|
194
|
-
boundary_times, self.omega_border[None, None]
|
|
191
|
+
if self.method in ["grid", "uniform"]:
|
|
192
|
+
self.key, boundary_times = self.generate_time_data(
|
|
193
|
+
self.key, self.nb // (2 * self.dim)
|
|
195
194
|
)
|
|
195
|
+
boundary_times = boundary_times.reshape(-1, 1, 1)
|
|
196
|
+
boundary_times = jnp.repeat(
|
|
197
|
+
boundary_times, self.omega_border.shape[-1], axis=2
|
|
198
|
+
)
|
|
199
|
+
if self.dim == 1:
|
|
200
|
+
self.border = make_cartesian_product(
|
|
201
|
+
boundary_times, self.omega_border[None, None]
|
|
202
|
+
)
|
|
203
|
+
else:
|
|
204
|
+
self.border = jnp.concatenate(
|
|
205
|
+
[boundary_times, self.omega_border], axis=1
|
|
206
|
+
)
|
|
196
207
|
else:
|
|
197
|
-
self.border =
|
|
198
|
-
|
|
208
|
+
self.key, self.border = self.qmc_in_time_omega_border_domain(
|
|
209
|
+
self.key,
|
|
210
|
+
self.nb, # type: ignore (see inside the fun)
|
|
199
211
|
)
|
|
212
|
+
|
|
200
213
|
if self.border_batch_size is None:
|
|
201
214
|
self.curr_border_idx = 0
|
|
202
215
|
else:
|
|
@@ -209,14 +222,30 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
209
222
|
self.curr_border_idx = 0
|
|
210
223
|
|
|
211
224
|
if self.ni is not None:
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
225
|
+
if self.method == "grid":
|
|
226
|
+
perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
|
|
227
|
+
if self.ni != perfect_sq:
|
|
228
|
+
warnings.warn(
|
|
229
|
+
"Grid sampling is requested in dimension 2 with a non"
|
|
230
|
+
f" perfect square dataset size (self.ni = {self.ni})."
|
|
231
|
+
f" Modifying self.ni to self.ni = {perfect_sq}."
|
|
232
|
+
)
|
|
233
|
+
self.ni = perfect_sq
|
|
234
|
+
if self.method in ["sobol", "halton"]:
|
|
235
|
+
log2_n = jnp.log2(self.ni)
|
|
236
|
+
lower_pow = 2 ** jnp.floor(log2_n)
|
|
237
|
+
higher_pow = 2 ** jnp.ceil(log2_n)
|
|
238
|
+
closest_two_power = (
|
|
239
|
+
lower_pow
|
|
240
|
+
if (self.ni - lower_pow) < (higher_pow - self.ni)
|
|
241
|
+
else higher_pow
|
|
218
242
|
)
|
|
219
|
-
|
|
243
|
+
if self.n != closest_two_power:
|
|
244
|
+
warnings.warn(
|
|
245
|
+
f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
|
|
246
|
+
f"Modfiying self.n from {self.ni} to {closest_two_power}.",
|
|
247
|
+
)
|
|
248
|
+
self.ni = int(closest_two_power)
|
|
220
249
|
self.key, self.initial = self.generate_omega_data(
|
|
221
250
|
self.key, data_size=self.ni
|
|
222
251
|
)
|
|
@@ -245,16 +274,115 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
245
274
|
if self.method == "grid":
|
|
246
275
|
partial_times = (self.tmax - self.tmin) / nt
|
|
247
276
|
return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
|
|
248
|
-
|
|
277
|
+
elif self.method in ["uniform", "sobol", "halton"]:
|
|
249
278
|
return key, self.sample_in_time_domain(subkey, nt)
|
|
250
279
|
raise ValueError("Method " + self.method + " is not implemented.")
|
|
251
280
|
|
|
252
281
|
def sample_in_time_domain(self, key: Key, nt: int) -> Float[Array, " nt 1"]:
|
|
253
|
-
return jax.random.uniform(
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
282
|
+
return jax.random.uniform(key, (nt, 1), minval=self.tmin, maxval=self.tmax)
|
|
283
|
+
|
|
284
|
+
def qmc_in_time_omega_domain(
|
|
285
|
+
self, key: Key, sample_size: int
|
|
286
|
+
) -> tuple[Key, Float[Array, "n 1+dim"]]:
|
|
287
|
+
"""
|
|
288
|
+
Because in Quasi-Monte Carlo sampling we cannot concatenate two vectors generated independently
|
|
289
|
+
We generate time and omega samples jointly
|
|
290
|
+
"""
|
|
291
|
+
key, subkey = jax.random.split(key, 2)
|
|
292
|
+
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
293
|
+
sampler = qmc_generator(
|
|
294
|
+
d=self.dim + 1, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
|
|
295
|
+
)
|
|
296
|
+
samples = sampler.random(n=sample_size)
|
|
297
|
+
samples[:, 1:] = qmc.scale(
|
|
298
|
+
samples[:, 1:], l_bounds=self.min_pts, u_bounds=self.max_pts
|
|
299
|
+
) # We scale omega domain to be in (min_pts, max_pts)
|
|
300
|
+
return key, jnp.array(samples)
|
|
301
|
+
|
|
302
|
+
def qmc_in_time_omega_border_domain(
|
|
303
|
+
self, key: Key, sample_size: int | None = None
|
|
304
|
+
) -> tuple[Key, Float[Array, "n 1+dim"]] | None:
|
|
305
|
+
"""
|
|
306
|
+
For each facet of the border we generate Quasi-Monte Carlo sequences jointly with time.
|
|
307
|
+
|
|
308
|
+
We need to do some type ignore in this function because we have lost
|
|
309
|
+
the type narrowing from post_init, type checkers only narrow at function level and because we cannot narrow a class attribute.
|
|
310
|
+
"""
|
|
311
|
+
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
312
|
+
sample_size = self.nb if sample_size is None else sample_size
|
|
313
|
+
if sample_size is None:
|
|
314
|
+
return None
|
|
315
|
+
if self.dim == 1:
|
|
316
|
+
key, subkey = jax.random.split(key, 2)
|
|
317
|
+
qmc_seq = qmc_generator(
|
|
318
|
+
d=1, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
|
|
319
|
+
)
|
|
320
|
+
boundary_times = jnp.array(
|
|
321
|
+
qmc_seq.random(self.nb // (2 * self.dim)) # type: ignore
|
|
322
|
+
)
|
|
323
|
+
boundary_times = boundary_times.reshape(-1, 1, 1)
|
|
324
|
+
boundary_times = jnp.repeat(
|
|
325
|
+
boundary_times,
|
|
326
|
+
self.omega_border.shape[-1], # type: ignore
|
|
327
|
+
axis=2,
|
|
328
|
+
)
|
|
329
|
+
return key, make_cartesian_product(
|
|
330
|
+
boundary_times,
|
|
331
|
+
self.omega_border[None, None], # type: ignore
|
|
332
|
+
)
|
|
333
|
+
if self.dim == 2:
|
|
334
|
+
# currently hard-coded the 4 edges for d==2
|
|
335
|
+
# TODO : find a general & efficient way to sample from the border
|
|
336
|
+
# (facets) of the hypercube in general dim.
|
|
337
|
+
key, *subkeys = jax.random.split(key, 5)
|
|
338
|
+
facet_n = sample_size // (2 * self.dim)
|
|
339
|
+
|
|
340
|
+
def generate_qmc_sample(key, min_val, max_val):
|
|
341
|
+
qmc_seq = qmc_generator(
|
|
342
|
+
d=2,
|
|
343
|
+
scramble=True,
|
|
344
|
+
rng=np.random.default_rng(np.uint32(key)),
|
|
345
|
+
)
|
|
346
|
+
u = qmc_seq.random(n=facet_n)
|
|
347
|
+
u[:, 1:2] = qmc.scale(u[:, 1:2], l_bounds=min_val, u_bounds=max_val)
|
|
348
|
+
return jnp.array(u)
|
|
349
|
+
|
|
350
|
+
xmin_sample = generate_qmc_sample(
|
|
351
|
+
subkeys[0], self.min_pts[1], self.max_pts[1]
|
|
352
|
+
) # [t,x,y]
|
|
353
|
+
xmin = jnp.hstack(
|
|
354
|
+
[
|
|
355
|
+
xmin_sample[:, 0:1],
|
|
356
|
+
self.min_pts[0] * jnp.ones((facet_n, 1)),
|
|
357
|
+
xmin_sample[:, 1:2],
|
|
358
|
+
]
|
|
359
|
+
)
|
|
360
|
+
xmax_sample = generate_qmc_sample(
|
|
361
|
+
subkeys[1], self.min_pts[1], self.max_pts[1]
|
|
362
|
+
)
|
|
363
|
+
xmax = jnp.hstack(
|
|
364
|
+
[
|
|
365
|
+
xmax_sample[:, 0:1],
|
|
366
|
+
self.max_pts[0] * jnp.ones((facet_n, 1)),
|
|
367
|
+
xmax_sample[:, 1:2],
|
|
368
|
+
]
|
|
369
|
+
)
|
|
370
|
+
ymin = jnp.hstack(
|
|
371
|
+
[
|
|
372
|
+
generate_qmc_sample(subkeys[2], self.min_pts[0], self.max_pts[0]),
|
|
373
|
+
self.min_pts[1] * jnp.ones((facet_n, 1)),
|
|
374
|
+
]
|
|
375
|
+
)
|
|
376
|
+
ymax = jnp.hstack(
|
|
377
|
+
[
|
|
378
|
+
generate_qmc_sample(subkeys[3], self.min_pts[0], self.max_pts[0]),
|
|
379
|
+
self.max_pts[1] * jnp.ones((facet_n, 1)),
|
|
380
|
+
]
|
|
381
|
+
)
|
|
382
|
+
return key, jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
|
|
383
|
+
raise NotImplementedError(
|
|
384
|
+
"Generation of the border of a cube in dimension > 2 is not "
|
|
385
|
+
+ f"implemented yet. You are asking for generation in dimension d={self.dim}."
|
|
258
386
|
)
|
|
259
387
|
|
|
260
388
|
def _get_domain_operands(
|
|
@@ -8,8 +8,11 @@ from __future__ import (
|
|
|
8
8
|
import warnings
|
|
9
9
|
import equinox as eqx
|
|
10
10
|
import jax
|
|
11
|
+
import numpy as np
|
|
11
12
|
import jax.numpy as jnp
|
|
13
|
+
from scipy.stats import qmc
|
|
12
14
|
from jaxtyping import Key, Array, Float
|
|
15
|
+
from typing import Literal
|
|
13
16
|
from jinns.data._Batchs import PDEStatioBatch
|
|
14
17
|
from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
|
|
15
18
|
from jinns.data._AbstractDataGenerator import AbstractDataGenerator
|
|
@@ -50,11 +53,13 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
50
53
|
A tuple of maximum values of the domain along each dimension. For a sampling
|
|
51
54
|
in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
|
|
52
55
|
x_{n,max})$
|
|
53
|
-
method :
|
|
54
|
-
Either
|
|
56
|
+
method : Literal["grid", "uniform", "sobol", "halton"], default="uniform"
|
|
57
|
+
Either "grid", "uniform", "sobol" or "halton", default is `uniform`.
|
|
55
58
|
The method that generates the `nt` time points. `grid` means
|
|
56
59
|
regularly spaced points over the domain. `uniform` means uniformly
|
|
57
|
-
sampled points over the domain
|
|
60
|
+
sampled points over the domain.
|
|
61
|
+
**Note** that Sobol and Halton approaches use scipy modules and will not
|
|
62
|
+
be JIT compatible.
|
|
58
63
|
rar_parameters : dict[str, int], default=None
|
|
59
64
|
Defaults to None: do not use Residual Adaptative Resampling.
|
|
60
65
|
Otherwise a dictionary with keys
|
|
@@ -94,7 +99,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
94
99
|
# shape in jax.lax.dynamic_slice
|
|
95
100
|
min_pts: tuple[float, ...] = eqx.field(kw_only=True)
|
|
96
101
|
max_pts: tuple[float, ...] = eqx.field(kw_only=True)
|
|
97
|
-
method:
|
|
102
|
+
method: Literal["grid", "uniform", "sobol", "halton"] = eqx.field(
|
|
98
103
|
kw_only=True, static=True, default_factory=lambda: "uniform"
|
|
99
104
|
)
|
|
100
105
|
rar_parameters: dict[str, int] = eqx.field(kw_only=True, default=None)
|
|
@@ -132,6 +137,22 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
132
137
|
)
|
|
133
138
|
self.n = perfect_sq
|
|
134
139
|
|
|
140
|
+
if self.method in ["sobol", "halton"]:
|
|
141
|
+
log2_n = jnp.log2(self.n)
|
|
142
|
+
lower_pow = 2 ** jnp.floor(log2_n)
|
|
143
|
+
higher_pow = 2 ** jnp.ceil(log2_n)
|
|
144
|
+
closest_two_power = (
|
|
145
|
+
lower_pow
|
|
146
|
+
if (self.n - lower_pow) < (higher_pow - self.n)
|
|
147
|
+
else higher_pow
|
|
148
|
+
)
|
|
149
|
+
if self.n != closest_two_power:
|
|
150
|
+
warnings.warn(
|
|
151
|
+
f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
|
|
152
|
+
f"Modfiying self.n from {self.n} to {closest_two_power}.",
|
|
153
|
+
)
|
|
154
|
+
self.n = int(closest_two_power)
|
|
155
|
+
|
|
135
156
|
if self.omega_batch_size is None:
|
|
136
157
|
self.curr_omega_idx = 0
|
|
137
158
|
else:
|
|
@@ -176,24 +197,48 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
176
197
|
def sample_in_omega_domain(
|
|
177
198
|
self, keys: Key, sample_size: int
|
|
178
199
|
) -> Float[Array, " n dim"]:
|
|
200
|
+
if self.method == "uniform":
|
|
201
|
+
if self.dim == 1:
|
|
202
|
+
xmin, xmax = self.min_pts[0], self.max_pts[0]
|
|
203
|
+
return jax.random.uniform(
|
|
204
|
+
keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
return jnp.concatenate(
|
|
208
|
+
[
|
|
209
|
+
jax.random.uniform(
|
|
210
|
+
keys[i],
|
|
211
|
+
(sample_size, 1),
|
|
212
|
+
minval=self.min_pts[i],
|
|
213
|
+
maxval=self.max_pts[i],
|
|
214
|
+
)
|
|
215
|
+
for i in range(self.dim)
|
|
216
|
+
],
|
|
217
|
+
axis=-1,
|
|
218
|
+
)
|
|
219
|
+
else:
|
|
220
|
+
return self._qmc_in_omega_domain(keys, sample_size)
|
|
221
|
+
|
|
222
|
+
def _qmc_in_omega_domain(
|
|
223
|
+
self, subkey: Key, sample_size: int
|
|
224
|
+
) -> Float[Array, "n dim"]:
|
|
225
|
+
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
179
226
|
if self.dim == 1:
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
227
|
+
qmc_seq = qmc_generator(
|
|
228
|
+
d=self.dim,
|
|
229
|
+
scramble=True,
|
|
230
|
+
rng=np.random.default_rng(np.uint32(subkey)),
|
|
183
231
|
)
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
minval=self.min_pts[i],
|
|
191
|
-
maxval=self.max_pts[i],
|
|
192
|
-
)
|
|
193
|
-
for i in range(self.dim)
|
|
194
|
-
],
|
|
195
|
-
axis=-1,
|
|
232
|
+
u = qmc_seq.random(n=sample_size)
|
|
233
|
+
return jnp.array(
|
|
234
|
+
qmc.scale(u, l_bounds=self.min_pts[0], u_bounds=self.max_pts[0])
|
|
235
|
+
)
|
|
236
|
+
sampler = qmc.Sobol(
|
|
237
|
+
d=self.dim, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
|
|
196
238
|
)
|
|
239
|
+
samples = sampler.random(n=sample_size)
|
|
240
|
+
samples = qmc.scale(samples, l_bounds=self.min_pts, u_bounds=self.max_pts)
|
|
241
|
+
return jnp.array(samples)
|
|
197
242
|
|
|
198
243
|
def sample_in_omega_border_domain(
|
|
199
244
|
self, keys: Key, sample_size: int | None = None
|
|
@@ -260,6 +305,62 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
260
305
|
+ f"implemented yet. You are asking for generation in dimension d={self.dim}."
|
|
261
306
|
)
|
|
262
307
|
|
|
308
|
+
def qmc_in_omega_border_domain(
|
|
309
|
+
self, keys: Key, sample_size: int | None = None
|
|
310
|
+
) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
|
|
311
|
+
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
312
|
+
sample_size = self.nb if sample_size is None else sample_size
|
|
313
|
+
if sample_size is None:
|
|
314
|
+
return None
|
|
315
|
+
if self.dim == 1:
|
|
316
|
+
xmin = self.min_pts[0]
|
|
317
|
+
xmax = self.max_pts[0]
|
|
318
|
+
return jnp.array([xmin, xmax]).astype(float)
|
|
319
|
+
if self.dim == 2:
|
|
320
|
+
# currently hard-coded the 4 edges for d==2
|
|
321
|
+
# TODO : find a general & efficient way to sample from the border
|
|
322
|
+
# (facets) of the hypercube in general dim.
|
|
323
|
+
facet_n = sample_size // (2 * self.dim)
|
|
324
|
+
|
|
325
|
+
def generate_qmc_sample(key, min_val, max_val):
|
|
326
|
+
qmc_seq = qmc_generator(
|
|
327
|
+
d=1,
|
|
328
|
+
scramble=True,
|
|
329
|
+
rng=np.random.default_rng(np.uint32(key)),
|
|
330
|
+
)
|
|
331
|
+
u = qmc_seq.random(n=facet_n)
|
|
332
|
+
return jnp.array(qmc.scale(u, l_bounds=min_val, u_bounds=max_val))
|
|
333
|
+
|
|
334
|
+
xmin = jnp.hstack(
|
|
335
|
+
[
|
|
336
|
+
self.min_pts[0] * jnp.ones((facet_n, 1)),
|
|
337
|
+
generate_qmc_sample(keys[0], self.min_pts[1], self.max_pts[1]),
|
|
338
|
+
]
|
|
339
|
+
)
|
|
340
|
+
xmax = jnp.hstack(
|
|
341
|
+
[
|
|
342
|
+
self.max_pts[0] * jnp.ones((facet_n, 1)),
|
|
343
|
+
generate_qmc_sample(keys[1], self.min_pts[1], self.max_pts[1]),
|
|
344
|
+
]
|
|
345
|
+
)
|
|
346
|
+
ymin = jnp.hstack(
|
|
347
|
+
[
|
|
348
|
+
generate_qmc_sample(keys[2], self.min_pts[0], self.max_pts[0]),
|
|
349
|
+
self.min_pts[1] * jnp.ones((facet_n, 1)),
|
|
350
|
+
]
|
|
351
|
+
)
|
|
352
|
+
ymax = jnp.hstack(
|
|
353
|
+
[
|
|
354
|
+
generate_qmc_sample(keys[3], self.min_pts[0], self.max_pts[0]),
|
|
355
|
+
self.max_pts[1] * jnp.ones((facet_n, 1)),
|
|
356
|
+
]
|
|
357
|
+
)
|
|
358
|
+
return jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
|
|
359
|
+
raise NotImplementedError(
|
|
360
|
+
"Generation of the border of a cube in dimension > 2 is not "
|
|
361
|
+
+ f"implemented yet. You are asking for generation in dimension d={self.dim}."
|
|
362
|
+
)
|
|
363
|
+
|
|
263
364
|
def generate_omega_data(
|
|
264
365
|
self, key: Key, data_size: int | None = None
|
|
265
366
|
) -> tuple[
|
|
@@ -290,8 +391,8 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
290
391
|
)
|
|
291
392
|
xyz_ = [a.reshape((data_size, 1)) for a in xyz_]
|
|
292
393
|
omega = jnp.concatenate(xyz_, axis=-1)
|
|
293
|
-
elif self.method
|
|
294
|
-
if self.dim == 1:
|
|
394
|
+
elif self.method in ["uniform", "sobol", "halton"]:
|
|
395
|
+
if self.dim == 1 or self.method in ["sobol", "halton"]:
|
|
295
396
|
key, subkeys = jax.random.split(key, 2)
|
|
296
397
|
else:
|
|
297
398
|
key, *subkeys = jax.random.split(key, self.dim + 1)
|
|
@@ -317,10 +418,17 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
317
418
|
key, *subkeys = jax.random.split(key, 5)
|
|
318
419
|
else:
|
|
319
420
|
subkeys = None
|
|
320
|
-
omega_border = self.sample_in_omega_border_domain(
|
|
321
|
-
subkeys, sample_size=data_size
|
|
322
|
-
)
|
|
323
421
|
|
|
422
|
+
if self.method in ["grid", "uniform"]:
|
|
423
|
+
omega_border = self.sample_in_omega_border_domain(
|
|
424
|
+
subkeys, sample_size=data_size
|
|
425
|
+
)
|
|
426
|
+
elif self.method in ["sobol", "halton"]:
|
|
427
|
+
omega_border = self.qmc_in_omega_border_domain(
|
|
428
|
+
subkeys, sample_size=data_size
|
|
429
|
+
)
|
|
430
|
+
else:
|
|
431
|
+
raise ValueError("Method " + self.method + " is not implemented.")
|
|
324
432
|
return key, omega_border
|
|
325
433
|
|
|
326
434
|
def _get_omega_operands(
|