spherical-deepkriging 1.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. spherical_deepkriging-1.0.4/LICENSE +21 -0
  2. spherical_deepkriging-1.0.4/MANIFEST.in +5 -0
  3. spherical_deepkriging-1.0.4/PKG-INFO +115 -0
  4. spherical_deepkriging-1.0.4/README.md +40 -0
  5. spherical_deepkriging-1.0.4/pyproject.toml +145 -0
  6. spherical_deepkriging-1.0.4/setup.cfg +4 -0
  7. spherical_deepkriging-1.0.4/setup.py +112 -0
  8. spherical_deepkriging-1.0.4/spherical_deepkriging/__init__.py +0 -0
  9. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/__init__.py +0 -0
  10. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts/__init__.py +0 -0
  11. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts/mrts.py +39 -0
  12. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts/utils.py +127 -0
  13. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts/visualization.py +168 -0
  14. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts_sphere/__init__.py +0 -0
  15. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts_sphere/cpp_extensions/CMakeLists.txt +74 -0
  16. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts_sphere/cpp_extensions/__init__.py +38 -0
  17. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts_sphere/cpp_extensions/setup.py +28 -0
  18. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts_sphere/cpp_extensions/spherical_basis.cpp +478 -0
  19. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts_sphere/sphere_cpp.py +130 -0
  20. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/utils.py +17 -0
  21. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/wendland/__init__.py +0 -0
  22. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/wendland/visualization.py +95 -0
  23. spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/wendland/wenland.py +39 -0
  24. spherical_deepkriging-1.0.4/spherical_deepkriging/configs.py +54 -0
  25. spherical_deepkriging-1.0.4/spherical_deepkriging/logger.py +29 -0
  26. spherical_deepkriging-1.0.4/spherical_deepkriging/models/__init__.py +0 -0
  27. spherical_deepkriging-1.0.4/spherical_deepkriging/models/deep_kriging.py +139 -0
  28. spherical_deepkriging-1.0.4/spherical_deepkriging/models/universal_kriging.py +362 -0
  29. spherical_deepkriging-1.0.4/spherical_deepkriging.egg-info/PKG-INFO +115 -0
  30. spherical_deepkriging-1.0.4/spherical_deepkriging.egg-info/SOURCES.txt +45 -0
  31. spherical_deepkriging-1.0.4/spherical_deepkriging.egg-info/dependency_links.txt +1 -0
  32. spherical_deepkriging-1.0.4/spherical_deepkriging.egg-info/not-zip-safe +1 -0
  33. spherical_deepkriging-1.0.4/spherical_deepkriging.egg-info/requires.txt +64 -0
  34. spherical_deepkriging-1.0.4/spherical_deepkriging.egg-info/top_level.txt +1 -0
  35. spherical_deepkriging-1.0.4/tests/test_basis_utils_and_wendland.py +99 -0
  36. spherical_deepkriging-1.0.4/tests/test_basis_visualization.py +125 -0
  37. spherical_deepkriging-1.0.4/tests/test_configs.py +30 -0
  38. spherical_deepkriging-1.0.4/tests/test_deep_kriging_shapes.py +43 -0
  39. spherical_deepkriging-1.0.4/tests/test_deep_kriging_train.py +111 -0
  40. spherical_deepkriging-1.0.4/tests/test_logger.py +32 -0
  41. spherical_deepkriging-1.0.4/tests/test_mrts_modules.py +68 -0
  42. spherical_deepkriging-1.0.4/tests/test_sphere_cpp_unit.py +94 -0
  43. spherical_deepkriging-1.0.4/tests/test_sphere_wrapper_and_cpp_init.py +89 -0
  44. spherical_deepkriging-1.0.4/tests/test_spherical_cpp.py +233 -0
  45. spherical_deepkriging-1.0.4/tests/test_spherical_cpp_setup.py +52 -0
  46. spherical_deepkriging-1.0.4/tests/test_universal_kriging_coords.py +14 -0
  47. spherical_deepkriging-1.0.4/tests/test_universal_kriging_unit.py +236 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 STLABTW
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,5 @@
1
+ prune spherical_deepkriging/basis_functions/mrts_sphere/cpp_extensions/build
2
+ global-exclude *.so *.pyd *.dylib *.dll *.a *.lib
3
+
4
+ # Include C++ sources for the spherical basis extension in sdist builds.
5
+ recursive-include spherical_deepkriging/basis_functions/mrts_sphere/cpp_extensions *.cpp *.h *.hpp *.txt *.py CMakeLists.txt
@@ -0,0 +1,115 @@
1
+ Metadata-Version: 2.4
2
+ Name: spherical-deepkriging
3
+ Version: 1.0.4
4
+ Summary: Deep learning package for FRK
5
+ Author-email: Wen-Ting Wang <egpivo@gmail.com>, Hao-Yun Huang <hhuscout@gms.ndhu.edu.tw>, Ping-Hsun Chiang <andrew501228@gmail.com>, Wu-Wei Ying <wuweiying1011@gms.ndhu.edu.tw>
6
+ License-Expression: MIT
7
+ Classifier: Development Status :: 3 - Alpha
8
+ Classifier: Intended Audience :: Developers
9
+ Classifier: Intended Audience :: Science/Research
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.10
12
+ Classifier: Programming Language :: Python :: 3.11
13
+ Classifier: Programming Language :: Python :: 3.12
14
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
15
+ Requires-Python: >=3.10
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: torch<3.0.0,>=2.0.0
19
+ Requires-Dist: pytorch-lightning<3.0.0,>=2.6.0
20
+ Requires-Dist: numpy<2.0.0,>=1.24.0
21
+ Requires-Dist: tensorflow<3.0.0,>=2.19.0
22
+ Requires-Dist: pandas>=2.0.0
23
+ Requires-Dist: tabulate>=0.9.0
24
+ Requires-Dist: scikit-learn>=1.3.0
25
+ Requires-Dist: scipy>=1.10.0
26
+ Requires-Dist: pybind11<3.0.0,>=2.11.0
27
+ Requires-Dist: tqdm>=4.65.0
28
+ Requires-Dist: tensorboard<2.20.0,>=2.13.0
29
+ Requires-Dist: protobuf<5.0.0,>=4.25.8
30
+ Requires-Dist: pykrige>=1.7.0
31
+ Requires-Dist: seaborn>=0.13.0
32
+ Requires-Dist: GPy>=1.10.0
33
+ Requires-Dist: optuna>=3.0.0
34
+ Requires-Dist: plotly>=5.0.0
35
+ Requires-Dist: ipykernel<7.0.0,>=6.13.0
36
+ Provides-Extra: vision
37
+ Requires-Dist: torchvision>=0.15.0; extra == "vision"
38
+ Provides-Extra: audio
39
+ Requires-Dist: torchaudio>=2.0.0; extra == "audio"
40
+ Provides-Extra: viz
41
+ Requires-Dist: matplotlib>=3.7.0; extra == "viz"
42
+ Requires-Dist: pillow<12.0.0,>=10.0.0; extra == "viz"
43
+ Requires-Dist: plotly>=5.0.0; extra == "viz"
44
+ Provides-Extra: data
45
+ Provides-Extra: mrts
46
+ Requires-Dist: jax<0.5.0,>=0.4.0; extra == "mrts"
47
+ Requires-Dist: jaxlib<0.5.0,>=0.4.0; extra == "mrts"
48
+ Provides-Extra: utils
49
+ Provides-Extra: jupyter
50
+ Requires-Dist: jupyter>=1.0.0; extra == "jupyter"
51
+ Requires-Dist: ipykernel<7.0.0,>=6.13.0; extra == "jupyter"
52
+ Requires-Dist: notebook>=7.0.0; extra == "jupyter"
53
+ Requires-Dist: matplotlib>=3.7.0; extra == "jupyter"
54
+ Provides-Extra: all
55
+ Requires-Dist: torchvision>=0.15.0; extra == "all"
56
+ Requires-Dist: torchaudio>=2.0.0; extra == "all"
57
+ Requires-Dist: matplotlib>=3.7.0; extra == "all"
58
+ Requires-Dist: scikit-learn>=1.3.0; extra == "all"
59
+ Requires-Dist: scipy>=1.10.0; extra == "all"
60
+ Requires-Dist: tqdm>=4.65.0; extra == "all"
61
+ Requires-Dist: tensorboard<2.20.0,>=2.13.0; extra == "all"
62
+ Requires-Dist: protobuf<5.0.0,>=4.25.8; extra == "all"
63
+ Requires-Dist: pillow<12.0.0,>=10.0.0; extra == "all"
64
+ Requires-Dist: pandas>=2.0.0; extra == "all"
65
+ Requires-Dist: plotly>=5.0.0; extra == "all"
66
+ Provides-Extra: dev
67
+ Requires-Dist: twine>=4.0.0; extra == "dev"
68
+ Requires-Dist: pytest>=7.4.0; extra == "dev"
69
+ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
70
+ Requires-Dist: black>=23.7.0; extra == "dev"
71
+ Requires-Dist: ruff>=0.1.0; extra == "dev"
72
+ Requires-Dist: mypy>=1.5.0; extra == "dev"
73
+ Dynamic: license-file
74
+ Dynamic: requires-python
75
+
76
+ # Spherical DeepKriging
77
+
78
+ [![Tests](https://github.com/STLABTW/spherical-deepkriging/workflows/Test/badge.svg)](https://github.com/STLABTW/spherical-deepkriging/actions)
79
+ [![codecov](https://codecov.io/github/STLABTW/spherical-deepkriging/graph/badge.svg?token=OF0LKVDII6)](https://codecov.io/github/STLABTW/spherical-deepkriging)
80
+
81
+ Code for **DeepKriging on the Global Data** ([arXiv:2604.01689](https://arxiv.org/abs/2604.01689)): spherical spatial prediction with DeepKriging, MRTS-sphere / Wendland bases, and universal kriging. Implementation lives under `spherical_deepkriging/`.
82
+
83
+ ## Setup
84
+
85
+ Needs [Miniconda](https://docs.conda.io/en/latest/miniconda.html). On Windows, use WSL.
86
+
87
+ ```bash
88
+ git clone https://github.com/STLABTW/spherical-deepkriging.git
89
+ cd spherical-deepkriging
90
+ make install-dev
91
+ ```
92
+
93
+ `make install-dev` creates the conda env and installs dependencies; `make build-spherical-cpp` builds the MRTS-sphere C++ extension.
94
+
95
+ ## Examples
96
+
97
+ - Smoke test: `examples/toy/toy_sphere_deepkriging.ipynb`
98
+ - Simulations: `examples/simulation/`
99
+ - Real data: `examples/real_data/`
100
+
101
+ See `examples/README.md` for run notes.
102
+
103
+ ## Citation
104
+
105
+ ```bibtex
106
+ @misc{huang2026deepkrigingglobaldata,
107
+ title={DeepKriging on the Global Data},
108
+ author={Hao-Yun Huang and Wen-Ting Wang and Ping-Hsun Chiang and Wei-Ying Wu},
109
+ year={2026},
110
+ eprint={2604.01689},
111
+ archivePrefix={arXiv},
112
+ primaryClass={stat.ME},
113
+ url={https://arxiv.org/abs/2604.01689},
114
+ }
115
+ ```
@@ -0,0 +1,40 @@
1
+ # Spherical DeepKriging
2
+
3
+ [![Tests](https://github.com/STLABTW/spherical-deepkriging/workflows/Test/badge.svg)](https://github.com/STLABTW/spherical-deepkriging/actions)
4
+ [![codecov](https://codecov.io/github/STLABTW/spherical-deepkriging/graph/badge.svg?token=OF0LKVDII6)](https://codecov.io/github/STLABTW/spherical-deepkriging)
5
+
6
+ Code for **DeepKriging on the Global Data** ([arXiv:2604.01689](https://arxiv.org/abs/2604.01689)): spherical spatial prediction with DeepKriging, MRTS-sphere / Wendland bases, and universal kriging. Implementation lives under `spherical_deepkriging/`.
7
+
8
+ ## Setup
9
+
10
+ Needs [Miniconda](https://docs.conda.io/en/latest/miniconda.html). On Windows, use WSL.
11
+
12
+ ```bash
13
+ git clone https://github.com/STLABTW/spherical-deepkriging.git
14
+ cd spherical-deepkriging
15
+ make install-dev
16
+ ```
17
+
18
+ `make install-dev` creates the conda env and installs dependencies; `make build-spherical-cpp` builds the MRTS-sphere C++ extension.
19
+
20
+ ## Examples
21
+
22
+ - Smoke test: `examples/toy/toy_sphere_deepkriging.ipynb`
23
+ - Simulations: `examples/simulation/`
24
+ - Real data: `examples/real_data/`
25
+
26
+ See `examples/README.md` for run notes.
27
+
28
+ ## Citation
29
+
30
+ ```bibtex
31
+ @misc{huang2026deepkrigingglobaldata,
32
+ title={DeepKriging on the Global Data},
33
+ author={Hao-Yun Huang and Wen-Ting Wang and Ping-Hsun Chiang and Wei-Ying Wu},
34
+ year={2026},
35
+ eprint={2604.01689},
36
+ archivePrefix={arXiv},
37
+ primaryClass={stat.ME},
38
+ url={https://arxiv.org/abs/2604.01689},
39
+ }
40
+ ```
@@ -0,0 +1,145 @@
1
+ [build-system]
2
+ requires = [
3
+ "setuptools>=77.0.0",
4
+ "wheel",
5
+ "cmake>=3.10",
6
+ "pybind11>=2.11.0,<3.0.0",
7
+ "numpy>=1.24.0,<2.0.0",
8
+ "tomli>=2.0.0",
9
+ ]
10
+ build-backend = "setuptools.build_meta"
11
+
12
+ [project]
13
+ name = "spherical-deepkriging"
14
+ version = "1.0.4"
15
+ description = "Deep learning package for FRK"
16
+ readme = "README.md"
17
+ requires-python = ">=3.10"
18
+ license = "MIT"
19
+ license-files = ["LICENSE"]
20
+ authors = [
21
+ {name = "Wen-Ting Wang", email = "egpivo@gmail.com"},
22
+ {name = "Hao-Yun Huang", email = "hhuscout@gms.ndhu.edu.tw"},
23
+ {name = "Ping-Hsun Chiang", email = "andrew501228@gmail.com"},
24
+ {name = "Wu-Wei Ying", email = "wuweiying1011@gms.ndhu.edu.tw"},
25
+ ]
26
+ classifiers = [
27
+ "Development Status :: 3 - Alpha",
28
+ "Intended Audience :: Developers",
29
+ "Intended Audience :: Science/Research",
30
+ "Programming Language :: Python :: 3",
31
+ "Programming Language :: Python :: 3.10",
32
+ "Programming Language :: Python :: 3.11",
33
+ "Programming Language :: Python :: 3.12",
34
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
35
+ ]
36
+
37
+ dependencies = [
38
+ "torch>=2.0.0,<3.0.0", # Avoid conflicts with neural-spatial (requires <2.3.0), but allow newer versions for DeepFRK
39
+ "pytorch-lightning>=2.6.0,<3.0.0",
40
+ "numpy>=1.24.0,<2.0.0", # Avoid conflicts with packages requiring numpy<2.0.0 (hangman, neural-spatial, tensorflow)
41
+ "tensorflow>=2.19.0,<3.0.0",
42
+ "pandas>=2.0.0",
43
+ "tabulate>=0.9.0",
44
+ "scikit-learn>=1.3.0",
45
+ "scipy>=1.10.0",
46
+ "pybind11>=2.11.0,<3.0.0", # Avoid conflicts with neural-spatial (requires <3.0.0)
47
+ "tqdm>=4.65.0",
48
+ "tensorboard>=2.13.0,<2.20.0", # Use <2.20.0 to avoid conflicts, but keep >=2.13.0 for compatibility
49
+ "protobuf>=4.25.8,<5.0.0",
50
+ "pykrige>=1.7.0", # Universal Kriging for spatial interpolation
51
+ "seaborn>=0.13.0", # Statistical data visualization
52
+ "GPy>=1.10.0", # Gaussian Processes for data generation (matching Lin paper)
53
+ "optuna>=3.0.0", # Bayesian hyperparameter optimization
54
+ "plotly>=5.0.0", # Required for optuna.visualization
55
+ "ipykernel>=6.13.0,<7.0.0", # Jupyter kernel support
56
+ ]
57
+
58
+ [project.optional-dependencies]
59
+ vision = [
60
+ "torchvision>=0.15.0",
61
+ ]
62
+ audio = [
63
+ "torchaudio>=2.0.0",
64
+ ]
65
+ viz = [
66
+ "matplotlib>=3.7.0",
67
+ "pillow>=10.0.0,<12.0.0", # Avoid conflicts with streamlit (requires <12)
68
+ "plotly>=5.0.0", # Interactive visualization
69
+ ]
70
+ data = [
71
+ # pandas and scikit-learn are now in main dependencies
72
+ ]
73
+ mrts = [
74
+ "jax>=0.4.0,<0.5.0", # Avoid conflicts with neural-spatial (requires <0.5.0)
75
+ "jaxlib>=0.4.0,<0.5.0", # Avoid conflicts with neural-spatial (requires <0.5.0)
76
+ ]
77
+ utils = [
78
+ # tqdm and tensorboard moved to main dependencies
79
+ ]
80
+ jupyter = [
81
+ "jupyter>=1.0.0",
82
+ "ipykernel>=6.13.0,<7.0.0", # Jupyter kernel support
83
+ "notebook>=7.0.0",
84
+ "matplotlib>=3.7.0",
85
+ ]
86
+ all = [
87
+ "torchvision>=0.15.0",
88
+ "torchaudio>=2.0.0",
89
+ "matplotlib>=3.7.0",
90
+ "scikit-learn>=1.3.0",
91
+ "scipy>=1.10.0",
92
+ "tqdm>=4.65.0",
93
+ "tensorboard>=2.13.0,<2.20.0",
94
+ "protobuf>=4.25.8,<5.0.0",
95
+ "pillow>=10.0.0,<12.0.0",
96
+ "pandas>=2.0.0",
97
+ "plotly>=5.0.0",
98
+ ]
99
+ dev = [
100
+ "twine>=4.0.0",
101
+ "pytest>=7.4.0",
102
+ "pytest-cov>=4.1.0",
103
+ "black>=23.7.0",
104
+ "ruff>=0.1.0",
105
+ "mypy>=1.5.0",
106
+ ]
107
+
108
+ [tool.setuptools]
109
+
110
+ [tool.setuptools.packages.find]
111
+ where = ["."]
112
+ include = ["spherical_deepkriging*"]
113
+
114
+ [tool.black]
115
+ line-length = 88
116
+ target-version = ['py310']
117
+
118
+ [tool.ruff]
119
+ line-length = 88
120
+ target-version = "py310"
121
+
122
+ [tool.pytest.ini_options]
123
+ testpaths = ["tests"]
124
+ python_files = "test_*.py"
125
+ python_classes = "Test*"
126
+ python_functions = "test_*"
127
+
128
+ [tool.coverage.run]
129
+ source = ["spherical_deepkriging"]
130
+ omit = [
131
+ "*/tests/*",
132
+ "*/__pycache__/*",
133
+ "*/examples/*",
134
+ ]
135
+
136
+ [tool.coverage.report]
137
+ exclude_lines = [
138
+ "pragma: no cover",
139
+ "def __repr__",
140
+ "raise AssertionError",
141
+ "raise NotImplementedError",
142
+ "if __name__ == .__main__.:",
143
+ "if TYPE_CHECKING:",
144
+ "^\\.\\.\\.$", # Exclude ellipsis in overload signatures
145
+ ]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,112 @@
1
+ import os
2
+ import subprocess
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ from setuptools import Extension, find_packages, setup
7
+ from setuptools.command.build_ext import build_ext
8
+
9
+ try:
10
+ import tomllib # Python 3.11+
11
+ except ModuleNotFoundError: # pragma: no cover
12
+ import tomli as tomllib # type: ignore[no-redef]
13
+
14
+
15
# Absolute directory holding this setup.py (symlinks resolved); every
# project-relative path below is derived from it.
ROOT = Path(__file__).resolve().parents[0]
16
+
17
+
18
class CMakeExtension(Extension):
    """A setuptools ``Extension`` whose compilation is delegated to CMake.

    ``sources`` is deliberately empty: setuptools never compiles anything for
    this extension itself. Instead ``sourcedir`` records the absolute path of
    the CMake project that produces the module.
    """

    def __init__(self, name: str, sourcedir: str):
        absolute_source = Path(sourcedir).resolve()
        Extension.__init__(self, name=name, sources=[])
        self.sourcedir = f"{absolute_source}"
22
+
23
+
24
class CMakeBuild(build_ext):
    """``build_ext`` command that builds :class:`CMakeExtension` modules with CMake.

    Any other extension type is handed back to the stock setuptools
    implementation.
    """

    def build_extension(self, ext: Extension) -> None:
        if not isinstance(ext, CMakeExtension):
            return super().build_extension(ext)

        # CMake must drop the built module where setuptools expects to find it.
        extdir = Path(self.get_ext_fullpath(ext.name)).parent.resolve()
        build_temp = Path(self.build_temp) / ext.name.split(".")[-1]
        build_temp.mkdir(parents=True, exist_ok=True)

        # Honour `build_ext --debug`; otherwise build an optimised module.
        cfg = "Debug" if self.debug else "Release"
        jobs = max(1, (os.cpu_count() or 2) - 1)

        cmake_args = [
            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
            # Multi-config generators (e.g. Visual Studio) otherwise append a
            # per-configuration subdirectory such as `Release/` to the path,
            # leaving the module where setuptools cannot find it.
            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}",
            f"-DCMAKE_BUILD_TYPE={cfg}",
            f"-DPython3_EXECUTABLE={sys.executable}",
        ]

        # Ensure CMake can locate pybind11Config.cmake in isolated build envs.
        try:
            import pybind11  # type: ignore

            pybind11_dir = Path(pybind11.get_cmake_dir()).resolve()
        except (ImportError, AttributeError):
            # Let CMake try its default discovery path if pybind11 (or its
            # CMake files) is unavailable.
            pass
        else:
            cmake_args.append(f"-Dpybind11_DIR={pybind11_dir}")

        # `-S`/`-B` require CMake >= 3.13, which also provides `--parallel`.
        # NOTE(review): pyproject's build requirement only pins cmake>=3.10.
        subprocess.check_call(
            ["cmake", "-S", ext.sourcedir, "-B", str(build_temp), *cmake_args]
        )

        # `--config` selects the configuration on multi-config generators and
        # is ignored by single-config ones. `--parallel` works with every
        # generator, unlike the Make/Ninja-only native `-- -jN` flag.
        subprocess.check_call(
            [
                "cmake",
                "--build",
                str(build_temp),
                "--config",
                cfg,
                "--parallel",
                str(jobs),
            ]
        )
74
+
75
+
76
def _load_project_metadata() -> dict:
    """Return the ``[project]`` table of ROOT/pyproject.toml, or ``{}`` if absent."""
    pyproject_path = ROOT / "pyproject.toml"
    # tomllib/tomli require a binary file handle.
    with pyproject_path.open("rb") as toml_file:
        parsed = tomllib.load(toml_file)
    return parsed.get("project", {})  # type: ignore[no-any-return]
80
+
81
+
82
project = _load_project_metadata()

# Core metadata is read from the [project] table of pyproject.toml, with a
# fallback for each key in case it is missing there.
# NOTE(review): pyproject.toml declares these fields statically too; newer
# setuptools may prefer those values and warn about the duplicates here.
name, version, dependencies, python_requires = (
    project.get(key, default)
    for key, default in (
        ("name", "spherical-deepkriging"),
        ("version", "0.0.0"),
        ("dependencies", []),
        ("requires-python", ">=3.10"),
    )
)

# CMake project that builds the MRTS-sphere pybind11 module.
cpp_ext_spherical_dir = ROOT.joinpath(
    "spherical_deepkriging", "basis_functions", "mrts_sphere", "cpp_extensions"
)

spherical_basis_ext = CMakeExtension(
    "spherical_deepkriging.basis_functions.mrts_sphere.cpp_extensions.spherical_basis",
    sourcedir=str(cpp_ext_spherical_dir),
)

setup(
    name=name,
    version=version,
    python_requires=python_requires,
    install_requires=dependencies,
    packages=find_packages(include=["spherical_deepkriging*"]),
    include_package_data=True,
    ext_modules=[spherical_basis_ext],
    cmdclass={"build_ext": CMakeBuild},
    zip_safe=False,
)
@@ -0,0 +1,39 @@
1
+ from typing import Dict, Optional, Tuple, Union
2
+
3
+ import jax.numpy as jnp
4
+ from jax.scipy.linalg import solve
5
+
6
+ from spherical_deepkriging.basis_functions.mrts.utils import (
7
+ build_extended_matrix,
8
+ compute_h,
9
+ dist,
10
+ predict_rabf,
11
+ )
12
+
13
+
14
def mrts0(
    knot: jnp.ndarray, k: int, x: Optional[jnp.ndarray] = None
) -> Union[Tuple[jnp.ndarray, Dict[str, jnp.ndarray], int], jnp.ndarray]:
    """Build multi-resolution thin-plate spline (MRTS) basis functions on knots.

    Args:
        knot: Knot locations, shape (n, ndims). A 1-D array is treated as n
            points in one dimension.
        k: Total number of basis functions. Must be at least ``ndims + 1``
            (an intercept plus one linear term per coordinate).
        x: Optional locations at which to evaluate the basis, shape
            (m, ndims), or (m,) for one-dimensional data.

    Returns:
        If ``x`` is given, the basis matrix produced by ``predict_rabf`` for
        ``x``. Otherwise the tuple ``(Xu, obj_attrs, k)``, where ``Xu`` is the
        2-D knot array and ``obj_attrs`` holds the fitted attributes ``"S"``,
        ``"UZ"``, ``"Xu"``, ``"BBBH"`` and ``"ndims"``.

    Raises:
        ValueError: If ``k < ndims + 1``.
    """
    Xu = jnp.asarray(knot)
    if Xu.ndim == 1:
        # A flat array is n points in one dimension; downstream code needs 2-D.
        Xu = Xu[:, jnp.newaxis]
    ndims = Xu.shape[1]

    if k < (ndims + 1):
        raise ValueError(f"Invalid k: {k}. It must be >= {ndims + 1}")

    slice_size = k - ndims - 1
    AHA, UZ = build_extended_matrix(Xu, k=k, ndims=ndims, slice_size=slice_size)

    # Design matrix of the linear trend [1, Xu], built once and reused.
    B = jnp.hstack([jnp.ones((Xu.shape[0], 1)), Xu])
    # BBBH = (B^T B)^{-1} B^T H, with H the kernel matrix among the knots.
    BBBH = solve(B.T @ B, B.T) @ compute_h(dist(Xu, Xu), ndims)
    obj_attrs = {
        "S": AHA,
        "UZ": UZ,
        "Xu": Xu,
        "BBBH": BBBH,
        "ndims": ndims,
    }
    if x is None:
        return Xu, obj_attrs, k

    x = jnp.asarray(x)
    if x.ndim == 1:
        # Match the knot handling: a flat array is m one-dimensional points.
        x = x[:, jnp.newaxis]
    return predict_rabf(obj_attrs, x, k)
@@ -0,0 +1,127 @@
1
+ from functools import partial
2
+ from typing import Dict, Optional, Tuple
3
+
4
+ import jax.numpy as jnp
5
+ from jax import jit, lax
6
+
7
+
8
@jit
def dist(A: jnp.ndarray, B: jnp.ndarray) -> jnp.ndarray:
    """Return the matrix of Euclidean distances between the rows of A and B.

    Entry ``[i, j]`` is the distance between ``A[i]`` and ``B[j]``. Both inputs
    are indexed as 2-D arrays (points x coordinates).
    """
    # (m, 1, p) - (1, q, p) -> (m, q, p): every row of A minus every row of B.
    row_diffs = A[:, None, :] - B[None, :, :]
    squared = jnp.sum(row_diffs**2, axis=-1)
    return jnp.sqrt(squared)
13
+
14
+
15
def compute_h(d: jnp.ndarray, ndims: int) -> jnp.ndarray:
    """Computes the H matrix by applying the MRTS radial kernel to distances.

    The kernel applied elementwise to ``d`` depends on the spatial dimension:

    * ``ndims == 1``: ``d**3 / 12``
    * ``ndims == 2``: ``d**2 * log(d) / 8``. ``1e-8`` is added inside the log
      so that ``d == 0`` yields 0 instead of ``0 * -inf = nan``.
    * ``ndims == 3``: ``-d / 8``

    Args:
        d: Array of distances (typically the output of ``dist``).
        ndims: Spatial dimension: 1, 2 or 3.

    Returns:
        Array of the same shape as ``d``.

    Raises:
        ValueError: If ``ndims`` is a Python int outside 1..3. For a traced
            ``ndims`` the kernel is chosen with ``lax.switch``, which clamps
            out-of-range indices instead of raising.
    """
    d = jnp.asarray(d)

    def case_ndims_1(_: None) -> jnp.ndarray:
        return (1 / 12) * d**3

    def case_ndims_2(_: None) -> jnp.ndarray:
        return (1 / 8) * d**2 * jnp.log(d + 1e-8)

    def case_ndims_3(_: None) -> jnp.ndarray:
        return (-1 / 8) * d

    cases = [case_ndims_1, case_ndims_2, case_ndims_3]
    if isinstance(ndims, int):
        # Static dimension: validate, then trace only the kernel that is used.
        if not 1 <= ndims <= len(cases):
            raise ValueError(f"Unsupported ndims: {ndims}. It must be 1, 2 or 3")
        return cases[ndims - 1](None)
    return lax.switch(ndims - 1, cases, None)
29
+
30
+
31
@partial(jit, static_argnames=["ndims", "slice_size"])
def build_extended_matrix(
    Xu: jnp.ndarray, k: int, ndims: int, slice_size: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Builds the S and UZ matrices of the MRTS basis.

    Args:
        Xu: Knot locations. Indexed as ``Xu[:, None, :]`` below, so expected to
            be 2-D (n knots x coordinates).
        k: Total number of basis functions. Not static under ``jit``; only used
            in index arithmetic for the ``UZ`` update.
        ndims: Spatial dimension forwarded to ``compute_h`` (static).
        slice_size: Number of leading singular vectors of ``AHA`` to keep
            (static). ``mrts0`` in this package passes ``k - ndims - 1``.

    Returns:
        ``(AHA, UZ)``. ``AHA`` is ``A @ H @ A``, where ``H`` is the kernel matrix
        among the knots and ``A`` projects onto the orthogonal complement of the
        columns of ``B = [1, Xu]``. ``UZ`` has shape
        ``(n + ndims + 1, slice_size + ndims + 1)``; its top-left
        ``n x slice_size`` block holds the normalised basis coefficients.
    """
    n = Xu.shape[0]
    # Design matrix of the linear trend: intercept column plus coordinates.
    B = jnp.hstack([jnp.ones((n, 1)), Xu])
    # BBB = (B^T B)^{-1} B^T, so B @ BBB is the hat matrix onto col(B).
    BBB = jnp.linalg.solve(B.T @ B, B.T)
    # A = I - hat matrix: projector onto the orthogonal complement of col(B).
    A = jnp.eye(n) - B @ BBB

    # Pairwise knot distances (same computation as `dist(Xu, Xu)`).
    d = jnp.sqrt(jnp.sum((Xu[:, None, :] - Xu[None, :, :]) ** 2, axis=-1))
    H = compute_h(d, ndims)
    # AH = H - H B BBB = H A.
    AH = H - (H @ B) @ BBB
    # AHA = A H A, using BBB.T @ B.T == B @ BBB because B^T B is symmetric.
    AHA = AH - BBB.T @ (B.T @ AH)

    # Left singular vectors of AHA; jnp.linalg.svd orders them by decreasing
    # singular value.
    gamma0, _, _ = jnp.linalg.svd(AHA, full_matrices=False)

    # Keep the leading `slice_size` singular vectors (columns).
    gamma0 = lax.dynamic_slice(
        gamma0, start_indices=(0, 0), slice_sizes=(gamma0.shape[0], slice_size)
    )

    trueBS = AHA @ gamma0
    # Column norms of AHA @ gamma0, used to normalise each basis function.
    # NOTE(review): a zero norm (rank-deficient AHA) would produce inf/nan below;
    # not guarded.
    rho = jnp.sqrt(jnp.sum(trueBS**2, axis=0))
    gammas = A @ gamma0 / rho * jnp.sqrt(n)

    extension_dim = ndims + 1

    # Zero-pad gammas with `extension_dim` extra columns and rows.
    UZ = jnp.vstack(
        [
            jnp.hstack([gammas, jnp.zeros((gammas.shape[0], extension_dim))]),
            jnp.zeros((extension_dim, gammas.shape[1] + extension_dim)),
        ]
    )
    pad_size = ndims + 1

    # arange(min(ndims + 1, Xu.shape[1])). When ndims == Xu.shape[1] (as mrts0
    # passes it) this is arange(ndims): one index per coordinate.
    def compute_valid_indices() -> jnp.ndarray:
        max_valid_index = min(pad_size, Xu.shape[1])
        return jnp.arange(max_valid_index)

    # stop_gradient on an integer index array appears to be a no-op here.
    valid_indices = lax.stop_gradient(compute_valid_indices())

    # Per-coordinate inverse scale. jnp.std defaults to ddof=0, so this equals
    # n / ((n - 1) * sample_sd).
    # NOTE(review): if the intent was sqrt(n / (n - 1)) / sample_sd (i.e.
    # 1 / population std), this is off by a factor n / (n - 1); confirm against
    # the reference MRTS implementation.
    updates = 1 / jnp.std(Xu[:, valid_indices], axis=0) / jnp.sqrt((n - 1) / n)
    # Writes row n + i, column k - ndims - 2 + i for each coordinate i; with
    # slice_size == k - ndims - 1 the first column is slice_size - 1.
    # NOTE(review): for slice_size == 0 that column index is -1, which JAX
    # wraps to the last column. Confirm the intended placement. predict_rabf in
    # this file only reads the top-left n x kstar block of UZ.
    UZ = UZ.at[n + valid_indices, k - pad_size - 1 + valid_indices].set(updates)

    return AHA, UZ
76
+
77
+
78
@partial(jit, static_argnames=["k"])
def predict_rabf(
    obj: Dict[str, jnp.ndarray],
    newx: Optional[jnp.ndarray] = None,
    k: Optional[int] = None,
) -> jnp.ndarray:
    """Evaluate the MRTS basis functions at new locations.

    Args:
        obj: Fitted basis attributes as produced by ``mrts0``. The keys read
            here are ``"Xu"`` (knots), ``"UZ"`` and ``"BBBH"``.
        newx: Locations to evaluate, shape (m, ndims), or (m,) for
            one-dimensional data. When ``None``, ``obj`` is returned unchanged.
        k: Total number of basis functions (static under ``jit``).

    Returns:
        Array of shape (m, k): the polynomial columns ``[1, newx]`` followed by
        the ``k - ndims - 1`` radial basis columns.

    Raises:
        ValueError: If ``newx`` is given but ``k`` is missing or smaller than
            ``ndims + 1``.
    """
    if newx is None:
        return obj
    if k is None:
        raise ValueError("k must be provided when newx is given")

    # Promote 1-D input to a single-coordinate column so that `dist`, which
    # indexes its arguments as 2-D, can be applied.
    x0 = newx if newx.ndim > 1 else newx[:, jnp.newaxis]
    ndims = x0.shape[1]
    n_knots = obj["Xu"].shape[0]

    # kstar is a plain Python int (k is static, ndims comes from a shape), so
    # ordinary control flow is used instead of lax.cond.
    kstar = k - ndims - 1
    if kstar < 0:
        raise ValueError(f"Invalid k: {k}. It must be >= {ndims + 1}")

    # Polynomial (linear trend) part: [1, x0].
    X2 = jnp.hstack([jnp.ones((x0.shape[0], 1)), x0])
    if kstar == 0:
        return X2

    # Coefficients of the kstar radial basis functions on the knots.
    UZ_knots = obj["UZ"][:n_knots, :kstar]
    H = compute_h(dist(x0, obj["Xu"]), ndims)
    # Remove the part explained by the polynomial trend, then divide by
    # sqrt(n), matching the sqrt(n) factor applied to the coefficients in
    # build_extended_matrix.
    X1 = (H @ UZ_knots - X2 @ obj["BBBH"] @ UZ_knots) / jnp.sqrt(n_knots)
    return jnp.hstack([X2, X1])