spherical-deepkriging 1.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spherical_deepkriging-1.0.4/LICENSE +21 -0
- spherical_deepkriging-1.0.4/MANIFEST.in +5 -0
- spherical_deepkriging-1.0.4/PKG-INFO +115 -0
- spherical_deepkriging-1.0.4/README.md +40 -0
- spherical_deepkriging-1.0.4/pyproject.toml +145 -0
- spherical_deepkriging-1.0.4/setup.cfg +4 -0
- spherical_deepkriging-1.0.4/setup.py +112 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/__init__.py +0 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/__init__.py +0 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts/__init__.py +0 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts/mrts.py +39 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts/utils.py +127 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts/visualization.py +168 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts_sphere/__init__.py +0 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts_sphere/cpp_extensions/CMakeLists.txt +74 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts_sphere/cpp_extensions/__init__.py +38 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts_sphere/cpp_extensions/setup.py +28 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts_sphere/cpp_extensions/spherical_basis.cpp +478 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/mrts_sphere/sphere_cpp.py +130 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/utils.py +17 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/wendland/__init__.py +0 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/wendland/visualization.py +95 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/basis_functions/wendland/wenland.py +39 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/configs.py +54 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/logger.py +29 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/models/__init__.py +0 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/models/deep_kriging.py +139 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging/models/universal_kriging.py +362 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging.egg-info/PKG-INFO +115 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging.egg-info/SOURCES.txt +45 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging.egg-info/dependency_links.txt +1 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging.egg-info/not-zip-safe +1 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging.egg-info/requires.txt +64 -0
- spherical_deepkriging-1.0.4/spherical_deepkriging.egg-info/top_level.txt +1 -0
- spherical_deepkriging-1.0.4/tests/test_basis_utils_and_wendland.py +99 -0
- spherical_deepkriging-1.0.4/tests/test_basis_visualization.py +125 -0
- spherical_deepkriging-1.0.4/tests/test_configs.py +30 -0
- spherical_deepkriging-1.0.4/tests/test_deep_kriging_shapes.py +43 -0
- spherical_deepkriging-1.0.4/tests/test_deep_kriging_train.py +111 -0
- spherical_deepkriging-1.0.4/tests/test_logger.py +32 -0
- spherical_deepkriging-1.0.4/tests/test_mrts_modules.py +68 -0
- spherical_deepkriging-1.0.4/tests/test_sphere_cpp_unit.py +94 -0
- spherical_deepkriging-1.0.4/tests/test_sphere_wrapper_and_cpp_init.py +89 -0
- spherical_deepkriging-1.0.4/tests/test_spherical_cpp.py +233 -0
- spherical_deepkriging-1.0.4/tests/test_spherical_cpp_setup.py +52 -0
- spherical_deepkriging-1.0.4/tests/test_universal_kriging_coords.py +14 -0
- spherical_deepkriging-1.0.4/tests/test_universal_kriging_unit.py +236 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 STLABTW
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
prune spherical_deepkriging/basis_functions/mrts_sphere/cpp_extensions/build
|
|
2
|
+
global-exclude *.so *.pyd *.dylib *.dll *.a *.lib
|
|
3
|
+
|
|
4
|
+
# Include C++ sources for the spherical basis extension in sdist builds.
|
|
5
|
+
recursive-include spherical_deepkriging/basis_functions/mrts_sphere/cpp_extensions *.cpp *.h *.hpp *.txt *.py CMakeLists.txt
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: spherical-deepkriging
|
|
3
|
+
Version: 1.0.4
|
|
4
|
+
Summary: Deep learning package for FRK
|
|
5
|
+
Author-email: Wen-Ting Wang <egpivo@gmail.com>, Hao-Yun Huang <hhuscout@gms.ndhu.edu.tw>, Ping-Hsun Chiang <andrew501228@gmail.com>, Wu-Wei Ying <wuweiying1011@gms.ndhu.edu.tw>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Classifier: Development Status :: 3 - Alpha
|
|
8
|
+
Classifier: Intended Audience :: Developers
|
|
9
|
+
Classifier: Intended Audience :: Science/Research
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
14
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
15
|
+
Requires-Python: >=3.10
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
License-File: LICENSE
|
|
18
|
+
Requires-Dist: torch<3.0.0,>=2.0.0
|
|
19
|
+
Requires-Dist: pytorch-lightning<3.0.0,>=2.6.0
|
|
20
|
+
Requires-Dist: numpy<2.0.0,>=1.24.0
|
|
21
|
+
Requires-Dist: tensorflow<3.0.0,>=2.19.0
|
|
22
|
+
Requires-Dist: pandas>=2.0.0
|
|
23
|
+
Requires-Dist: tabulate>=0.9.0
|
|
24
|
+
Requires-Dist: scikit-learn>=1.3.0
|
|
25
|
+
Requires-Dist: scipy>=1.10.0
|
|
26
|
+
Requires-Dist: pybind11<3.0.0,>=2.11.0
|
|
27
|
+
Requires-Dist: tqdm>=4.65.0
|
|
28
|
+
Requires-Dist: tensorboard<2.20.0,>=2.13.0
|
|
29
|
+
Requires-Dist: protobuf<5.0.0,>=4.25.8
|
|
30
|
+
Requires-Dist: pykrige>=1.7.0
|
|
31
|
+
Requires-Dist: seaborn>=0.13.0
|
|
32
|
+
Requires-Dist: GPy>=1.10.0
|
|
33
|
+
Requires-Dist: optuna>=3.0.0
|
|
34
|
+
Requires-Dist: plotly>=5.0.0
|
|
35
|
+
Requires-Dist: ipykernel<7.0.0,>=6.13.0
|
|
36
|
+
Provides-Extra: vision
|
|
37
|
+
Requires-Dist: torchvision>=0.15.0; extra == "vision"
|
|
38
|
+
Provides-Extra: audio
|
|
39
|
+
Requires-Dist: torchaudio>=2.0.0; extra == "audio"
|
|
40
|
+
Provides-Extra: viz
|
|
41
|
+
Requires-Dist: matplotlib>=3.7.0; extra == "viz"
|
|
42
|
+
Requires-Dist: pillow<12.0.0,>=10.0.0; extra == "viz"
|
|
43
|
+
Requires-Dist: plotly>=5.0.0; extra == "viz"
|
|
44
|
+
Provides-Extra: data
|
|
45
|
+
Provides-Extra: mrts
|
|
46
|
+
Requires-Dist: jax<0.5.0,>=0.4.0; extra == "mrts"
|
|
47
|
+
Requires-Dist: jaxlib<0.5.0,>=0.4.0; extra == "mrts"
|
|
48
|
+
Provides-Extra: utils
|
|
49
|
+
Provides-Extra: jupyter
|
|
50
|
+
Requires-Dist: jupyter>=1.0.0; extra == "jupyter"
|
|
51
|
+
Requires-Dist: ipykernel<7.0.0,>=6.13.0; extra == "jupyter"
|
|
52
|
+
Requires-Dist: notebook>=7.0.0; extra == "jupyter"
|
|
53
|
+
Requires-Dist: matplotlib>=3.7.0; extra == "jupyter"
|
|
54
|
+
Provides-Extra: all
|
|
55
|
+
Requires-Dist: torchvision>=0.15.0; extra == "all"
|
|
56
|
+
Requires-Dist: torchaudio>=2.0.0; extra == "all"
|
|
57
|
+
Requires-Dist: matplotlib>=3.7.0; extra == "all"
|
|
58
|
+
Requires-Dist: scikit-learn>=1.3.0; extra == "all"
|
|
59
|
+
Requires-Dist: scipy>=1.10.0; extra == "all"
|
|
60
|
+
Requires-Dist: tqdm>=4.65.0; extra == "all"
|
|
61
|
+
Requires-Dist: tensorboard<2.20.0,>=2.13.0; extra == "all"
|
|
62
|
+
Requires-Dist: protobuf<5.0.0,>=4.25.8; extra == "all"
|
|
63
|
+
Requires-Dist: pillow<12.0.0,>=10.0.0; extra == "all"
|
|
64
|
+
Requires-Dist: pandas>=2.0.0; extra == "all"
|
|
65
|
+
Requires-Dist: plotly>=5.0.0; extra == "all"
|
|
66
|
+
Provides-Extra: dev
|
|
67
|
+
Requires-Dist: twine>=4.0.0; extra == "dev"
|
|
68
|
+
Requires-Dist: pytest>=7.4.0; extra == "dev"
|
|
69
|
+
Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
|
|
70
|
+
Requires-Dist: black>=23.7.0; extra == "dev"
|
|
71
|
+
Requires-Dist: ruff>=0.1.0; extra == "dev"
|
|
72
|
+
Requires-Dist: mypy>=1.5.0; extra == "dev"
|
|
73
|
+
Dynamic: license-file
|
|
74
|
+
Dynamic: requires-python
|
|
75
|
+
|
|
76
|
+
# Spherical DeepKriging
|
|
77
|
+
|
|
78
|
+
[Build status](https://github.com/STLABTW/spherical-deepkriging/actions)
|
|
79
|
+
[Coverage report](https://codecov.io/github/STLABTW/spherical-deepkriging)
|
|
80
|
+
|
|
81
|
+
Code for **DeepKriging on the Global Data** ([arXiv:2604.01689](https://arxiv.org/abs/2604.01689)): spherical spatial prediction with DeepKriging, MRTS-sphere / Wendland bases, and universal kriging. Implementation lives under `spherical_deepkriging/`.
|
|
82
|
+
|
|
83
|
+
## Setup
|
|
84
|
+
|
|
85
|
+
Needs [Miniconda](https://docs.conda.io/en/latest/miniconda.html). On Windows, use WSL.
|
|
86
|
+
|
|
87
|
+
```bash
|
|
88
|
+
git clone https://github.com/STLABTW/spherical-deepkriging.git
|
|
89
|
+
cd spherical-deepkriging
|
|
90
|
+
make install-dev
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
`make install-dev` creates the conda env and installs dependencies; `make build-spherical-cpp` builds the MRTS-sphere C++ extension.
|
|
94
|
+
|
|
95
|
+
## Examples
|
|
96
|
+
|
|
97
|
+
- Smoke test: `examples/toy/toy_sphere_deepkriging.ipynb`
|
|
98
|
+
- Simulations: `examples/simulation/`
|
|
99
|
+
- Real data: `examples/real_data/`
|
|
100
|
+
|
|
101
|
+
See `examples/README.md` for run notes.
|
|
102
|
+
|
|
103
|
+
## Citation
|
|
104
|
+
|
|
105
|
+
```bibtex
|
|
106
|
+
@misc{huang2026deepkrigingglobaldata,
|
|
107
|
+
title={DeepKriging on the Global Data},
|
|
108
|
+
author={Hao-Yun Huang and Wen-Ting Wang and Ping-Hsun Chiang and Wei-Ying Wu},
|
|
109
|
+
year={2026},
|
|
110
|
+
eprint={2604.01689},
|
|
111
|
+
archivePrefix={arXiv},
|
|
112
|
+
primaryClass={stat.ME},
|
|
113
|
+
url={https://arxiv.org/abs/2604.01689},
|
|
114
|
+
}
|
|
115
|
+
```
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# Spherical DeepKriging
|
|
2
|
+
|
|
3
|
+
[Build status](https://github.com/STLABTW/spherical-deepkriging/actions)
|
|
4
|
+
[Coverage report](https://codecov.io/github/STLABTW/spherical-deepkriging)
|
|
5
|
+
|
|
6
|
+
Code for **DeepKriging on the Global Data** ([arXiv:2604.01689](https://arxiv.org/abs/2604.01689)): spherical spatial prediction with DeepKriging, MRTS-sphere / Wendland bases, and universal kriging. Implementation lives under `spherical_deepkriging/`.
|
|
7
|
+
|
|
8
|
+
## Setup
|
|
9
|
+
|
|
10
|
+
Needs [Miniconda](https://docs.conda.io/en/latest/miniconda.html). On Windows, use WSL.
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
git clone https://github.com/STLABTW/spherical-deepkriging.git
|
|
14
|
+
cd spherical-deepkriging
|
|
15
|
+
make install-dev
|
|
16
|
+
```
|
|
17
|
+
|
|
18
|
+
`make install-dev` creates the conda env and installs dependencies; `make build-spherical-cpp` builds the MRTS-sphere C++ extension.
|
|
19
|
+
|
|
20
|
+
## Examples
|
|
21
|
+
|
|
22
|
+
- Smoke test: `examples/toy/toy_sphere_deepkriging.ipynb`
|
|
23
|
+
- Simulations: `examples/simulation/`
|
|
24
|
+
- Real data: `examples/real_data/`
|
|
25
|
+
|
|
26
|
+
See `examples/README.md` for run notes.
|
|
27
|
+
|
|
28
|
+
## Citation
|
|
29
|
+
|
|
30
|
+
```bibtex
|
|
31
|
+
@misc{huang2026deepkrigingglobaldata,
|
|
32
|
+
title={DeepKriging on the Global Data},
|
|
33
|
+
author={Hao-Yun Huang and Wen-Ting Wang and Ping-Hsun Chiang and Wei-Ying Wu},
|
|
34
|
+
year={2026},
|
|
35
|
+
eprint={2604.01689},
|
|
36
|
+
archivePrefix={arXiv},
|
|
37
|
+
primaryClass={stat.ME},
|
|
38
|
+
url={https://arxiv.org/abs/2604.01689},
|
|
39
|
+
}
|
|
40
|
+
```
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = [
|
|
3
|
+
"setuptools>=77.0.0",
|
|
4
|
+
"wheel",
|
|
5
|
+
"cmake>=3.10",
|
|
6
|
+
"pybind11>=2.11.0,<3.0.0",
|
|
7
|
+
"numpy>=1.24.0,<2.0.0",
|
|
8
|
+
"tomli>=2.0.0",
|
|
9
|
+
]
|
|
10
|
+
build-backend = "setuptools.build_meta"
|
|
11
|
+
|
|
12
|
+
[project]
|
|
13
|
+
name = "spherical-deepkriging"
|
|
14
|
+
version = "1.0.4"
|
|
15
|
+
description = "Deep learning package for FRK"
|
|
16
|
+
readme = "README.md"
|
|
17
|
+
requires-python = ">=3.10"
|
|
18
|
+
license = "MIT"
|
|
19
|
+
license-files = ["LICENSE"]
|
|
20
|
+
authors = [
|
|
21
|
+
{name = "Wen-Ting Wang", email = "egpivo@gmail.com"},
|
|
22
|
+
{name = "Hao-Yun Huang", email = "hhuscout@gms.ndhu.edu.tw"},
|
|
23
|
+
{name = "Ping-Hsun Chiang", email = "andrew501228@gmail.com"},
|
|
24
|
+
{name = "Wu-Wei Ying", email = "wuweiying1011@gms.ndhu.edu.tw"},
|
|
25
|
+
]
|
|
26
|
+
classifiers = [
|
|
27
|
+
"Development Status :: 3 - Alpha",
|
|
28
|
+
"Intended Audience :: Developers",
|
|
29
|
+
"Intended Audience :: Science/Research",
|
|
30
|
+
"Programming Language :: Python :: 3",
|
|
31
|
+
"Programming Language :: Python :: 3.10",
|
|
32
|
+
"Programming Language :: Python :: 3.11",
|
|
33
|
+
"Programming Language :: Python :: 3.12",
|
|
34
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
35
|
+
]
|
|
36
|
+
|
|
37
|
+
dependencies = [
|
|
38
|
+
"torch>=2.0.0,<3.0.0", # Avoid conflicts with neural-spatial (requires <2.3.0), but allow newer versions for DeepFRK
|
|
39
|
+
"pytorch-lightning>=2.6.0,<3.0.0",
|
|
40
|
+
"numpy>=1.24.0,<2.0.0", # Avoid conflicts with packages requiring numpy<2.0.0 (hangman, neural-spatial, tensorflow)
|
|
41
|
+
"tensorflow>=2.19.0,<3.0.0",
|
|
42
|
+
"pandas>=2.0.0",
|
|
43
|
+
"tabulate>=0.9.0",
|
|
44
|
+
"scikit-learn>=1.3.0",
|
|
45
|
+
"scipy>=1.10.0",
|
|
46
|
+
"pybind11>=2.11.0,<3.0.0", # Avoid conflicts with neural-spatial (requires <3.0.0)
|
|
47
|
+
"tqdm>=4.65.0",
|
|
48
|
+
"tensorboard>=2.13.0,<2.20.0", # Use <2.20.0 to avoid conflicts, but keep >=2.13.0 for compatibility
|
|
49
|
+
"protobuf>=4.25.8,<5.0.0",
|
|
50
|
+
"pykrige>=1.7.0", # Universal Kriging for spatial interpolation
|
|
51
|
+
"seaborn>=0.13.0", # Statistical data visualization
|
|
52
|
+
"GPy>=1.10.0", # Gaussian Processes for data generation (matching Lin paper)
|
|
53
|
+
"optuna>=3.0.0", # Bayesian hyperparameter optimization
|
|
54
|
+
"plotly>=5.0.0", # Required for optuna.visualization
|
|
55
|
+
"ipykernel>=6.13.0,<7.0.0", # Jupyter kernel support
|
|
56
|
+
]
|
|
57
|
+
|
|
58
|
+
[project.optional-dependencies]
|
|
59
|
+
vision = [
|
|
60
|
+
"torchvision>=0.15.0",
|
|
61
|
+
]
|
|
62
|
+
audio = [
|
|
63
|
+
"torchaudio>=2.0.0",
|
|
64
|
+
]
|
|
65
|
+
viz = [
|
|
66
|
+
"matplotlib>=3.7.0",
|
|
67
|
+
"pillow>=10.0.0,<12.0.0", # Avoid conflicts with streamlit (requires <12)
|
|
68
|
+
"plotly>=5.0.0", # Interactive visualization
|
|
69
|
+
]
|
|
70
|
+
data = [
|
|
71
|
+
# pandas and scikit-learn are now in main dependencies
|
|
72
|
+
]
|
|
73
|
+
mrts = [
|
|
74
|
+
"jax>=0.4.0,<0.5.0", # Avoid conflicts with neural-spatial (requires <0.5.0)
|
|
75
|
+
"jaxlib>=0.4.0,<0.5.0", # Avoid conflicts with neural-spatial (requires <0.5.0)
|
|
76
|
+
]
|
|
77
|
+
utils = [
|
|
78
|
+
# tqdm and tensorboard moved to main dependencies
|
|
79
|
+
]
|
|
80
|
+
jupyter = [
|
|
81
|
+
"jupyter>=1.0.0",
|
|
82
|
+
"ipykernel>=6.13.0,<7.0.0", # Jupyter kernel support
|
|
83
|
+
"notebook>=7.0.0",
|
|
84
|
+
"matplotlib>=3.7.0",
|
|
85
|
+
]
|
|
86
|
+
all = [
|
|
87
|
+
"torchvision>=0.15.0",
|
|
88
|
+
"torchaudio>=2.0.0",
|
|
89
|
+
"matplotlib>=3.7.0",
|
|
90
|
+
"scikit-learn>=1.3.0",
|
|
91
|
+
"scipy>=1.10.0",
|
|
92
|
+
"tqdm>=4.65.0",
|
|
93
|
+
"tensorboard>=2.13.0,<2.20.0",
|
|
94
|
+
"protobuf>=4.25.8,<5.0.0",
|
|
95
|
+
"pillow>=10.0.0,<12.0.0",
|
|
96
|
+
"pandas>=2.0.0",
|
|
97
|
+
"plotly>=5.0.0",
|
|
98
|
+
]
|
|
99
|
+
dev = [
|
|
100
|
+
"twine>=4.0.0",
|
|
101
|
+
"pytest>=7.4.0",
|
|
102
|
+
"pytest-cov>=4.1.0",
|
|
103
|
+
"black>=23.7.0",
|
|
104
|
+
"ruff>=0.1.0",
|
|
105
|
+
"mypy>=1.5.0",
|
|
106
|
+
]
|
|
107
|
+
|
|
108
|
+
[tool.setuptools]
|
|
109
|
+
|
|
110
|
+
[tool.setuptools.packages.find]
|
|
111
|
+
where = ["."]
|
|
112
|
+
include = ["spherical_deepkriging*"]
|
|
113
|
+
|
|
114
|
+
[tool.black]
|
|
115
|
+
line-length = 88
|
|
116
|
+
target-version = ['py310']
|
|
117
|
+
|
|
118
|
+
[tool.ruff]
|
|
119
|
+
line-length = 88
|
|
120
|
+
target-version = "py310"
|
|
121
|
+
|
|
122
|
+
[tool.pytest.ini_options]
|
|
123
|
+
testpaths = ["tests"]
|
|
124
|
+
python_files = "test_*.py"
|
|
125
|
+
python_classes = "Test*"
|
|
126
|
+
python_functions = "test_*"
|
|
127
|
+
|
|
128
|
+
[tool.coverage.run]
|
|
129
|
+
source = ["spherical_deepkriging"]
|
|
130
|
+
omit = [
|
|
131
|
+
"*/tests/*",
|
|
132
|
+
"*/__pycache__/*",
|
|
133
|
+
"*/examples/*",
|
|
134
|
+
]
|
|
135
|
+
|
|
136
|
+
[tool.coverage.report]
|
|
137
|
+
exclude_lines = [
|
|
138
|
+
"pragma: no cover",
|
|
139
|
+
"def __repr__",
|
|
140
|
+
"raise AssertionError",
|
|
141
|
+
"raise NotImplementedError",
|
|
142
|
+
"if __name__ == .__main__.:",
|
|
143
|
+
"if TYPE_CHECKING:",
|
|
144
|
+
"^\\.\\.\\.$", # Exclude ellipsis in overload signatures
|
|
145
|
+
]
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import subprocess
|
|
3
|
+
import sys
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from setuptools import Extension, find_packages, setup
|
|
7
|
+
from setuptools.command.build_ext import build_ext
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import tomllib # Python 3.11+
|
|
11
|
+
except ModuleNotFoundError: # pragma: no cover
|
|
12
|
+
import tomli as tomllib # type: ignore[no-redef]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
ROOT = Path(__file__).resolve().parent
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CMakeExtension(Extension):
    """Extension placeholder whose compilation is delegated to CMake.

    setuptools is given an empty source list; the instance only records the
    absolute path of the directory that holds the CMake project.
    """

    def __init__(self, name: str, sourcedir: str):
        # No sources: setuptools itself must not try to compile anything.
        no_sources: list = []
        super().__init__(name=name, sources=no_sources)
        # Resolve once so later use does not depend on the working directory.
        resolved_dir = Path(sourcedir).resolve()
        self.sourcedir = str(resolved_dir)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class CMakeBuild(build_ext):
    """``build_ext`` command that configures and builds CMake-based extensions."""

    def build_extension(self, ext: Extension) -> None:
        """Build ``ext`` with CMake, or defer to setuptools for other extensions.

        Args:
            ext: The extension to build. Only ``CMakeExtension`` instances are
                handled here; anything else goes to the stock ``build_ext``.

        Raises:
            subprocess.CalledProcessError: If the CMake configure or build
                step exits with a non-zero status.
        """
        if not isinstance(ext, CMakeExtension):
            return super().build_extension(ext)

        # Directory where setuptools expects the compiled module to appear.
        extdir = Path(self.get_ext_fullpath(ext.name)).parent.resolve()
        # One scratch build tree per extension, keyed by its last name part.
        build_temp = Path(self.build_temp) / ext.name.split(".")[-1]
        build_temp.mkdir(parents=True, exist_ok=True)

        cfg = "Release"
        python_exe = sys.executable
        # Leave one core free; fall back to a single job on small machines.
        jobs = max(1, (os.cpu_count() or 2) - 1)

        cmake_args = [
            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
            # Multi-config generators append a per-configuration subdirectory
            # to the plain variable above; the per-config variant pins the
            # output to ``extdir`` so the module lands where setuptools looks.
            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}",
            f"-DCMAKE_BUILD_TYPE={cfg}",
            f"-DPython3_EXECUTABLE={python_exe}",
        ]

        # Ensure CMake can locate pybind11Config.cmake in isolated build envs.
        try:
            import pybind11  # type: ignore

            pybind11_dir = Path(pybind11.get_cmake_dir()).resolve()
            cmake_args.append(f"-Dpybind11_DIR={pybind11_dir}")
        except Exception:
            # Deliberate best-effort: let CMake try its default discovery path
            # if pybind11 is unavailable or cannot report its CMake directory.
            pass

        subprocess.check_call(
            ["cmake", "-S", ext.sourcedir, "-B", str(build_temp), *cmake_args]
        )

        # ``--config`` is passed on every platform: it selects the
        # configuration on multi-config generators and is ignored by
        # single-config ones.
        build_cmd = ["cmake", "--build", str(build_temp), "--config", cfg]
        if os.name != "nt":
            # Forward a parallel-jobs flag to the native build tool. It is
            # skipped on Windows, presumably because the default build tool
            # there does not accept ``-j`` (confirm before changing).
            build_cmd += ["--", f"-j{jobs}"]
        subprocess.check_call(build_cmd)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _load_project_metadata() -> dict:
    """Return the ``[project]`` table of the ``pyproject.toml`` under ``ROOT``.

    Returns:
        The parsed ``project`` table, or an empty dict when the file has no
        such table.
    """
    pyproject_path = ROOT / "pyproject.toml"
    # The TOML parser requires a binary file object.
    with pyproject_path.open("rb") as handle:
        parsed = tomllib.load(handle)
    project_table = parsed.get("project", {})
    return project_table  # type: ignore[no-any-return]
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# Package metadata is read from pyproject.toml; each fallback below applies
# only when the corresponding key is missing from the [project] table.
project = _load_project_metadata()

name = project.get("name", "spherical-deepkriging")
version = project.get("version", "0.0.0")
dependencies = project.get("dependencies", [])
python_requires = project.get("requires-python", ">=3.10")

# Directory that holds the CMake project for the spherical-basis extension.
cpp_ext_spherical_dir = ROOT.joinpath(
    "spherical_deepkriging",
    "basis_functions",
    "mrts_sphere",
    "cpp_extensions",
)

setup(
    name=name,
    version=version,
    python_requires=python_requires,
    install_requires=dependencies,
    packages=find_packages(include=["spherical_deepkriging*"]),
    include_package_data=True,
    zip_safe=False,
    # Route extension builds through the CMake-aware build_ext command.
    cmdclass={"build_ext": CMakeBuild},
    ext_modules=[
        CMakeExtension(
            "spherical_deepkriging.basis_functions.mrts_sphere.cpp_extensions.spherical_basis",
            sourcedir=str(cpp_ext_spherical_dir),
        )
    ],
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from typing import Dict, Optional, Tuple, Union
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from jax.scipy.linalg import solve
|
|
5
|
+
|
|
6
|
+
from spherical_deepkriging.basis_functions.mrts.utils import (
|
|
7
|
+
build_extended_matrix,
|
|
8
|
+
compute_h,
|
|
9
|
+
dist,
|
|
10
|
+
predict_rabf,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def mrts0(
    knot: jnp.ndarray, k: int, x: Optional[jnp.ndarray] = None
) -> Union[Tuple[jnp.ndarray, Dict[str, jnp.ndarray], int], jnp.ndarray]:
    """Pre-compute the MRTS model attributes and optionally evaluate at ``x``.

    Args:
        knot: Knot locations, converted with ``jnp.asarray``; the length of
            the second axis is used as the spatial dimension ``ndims``.
        k: Requested number of basis functions; must be at least
            ``ndims + 1``.
        x: Optional locations forwarded to ``predict_rabf``.

    Returns:
        ``predict_rabf(obj_attrs, x, k)`` when ``x`` is given, otherwise the
        tuple ``(Xu, obj_attrs, k)`` where ``obj_attrs`` holds the keys
        ``"S"``, ``"UZ"``, ``"Xu"``, ``"BBBH"`` and ``"ndims"``.

    Raises:
        ValueError: If ``k`` is smaller than ``ndims + 1``.
    """
    Xu = jnp.asarray(knot)
    ndims = Xu.shape[1]

    if k < (ndims + 1):
        raise ValueError(f"Invalid k: {k}. It must be >= {ndims + 1}")

    AHA, UZ = build_extended_matrix(
        Xu, k=k, ndims=ndims, slice_size=k - ndims - 1
    )

    # Design matrix [1, Xu], built once and reused in the solve below.
    design = jnp.hstack([jnp.ones((Xu.shape[0], 1)), Xu])
    # Solves (design^T design) P = design^T for P.
    projector = solve(design.T @ design, design.T)
    BBBH = projector @ compute_h(dist(Xu, Xu), ndims)

    obj_attrs = {
        "S": AHA,
        "UZ": UZ,
        "Xu": Xu,
        "BBBH": BBBH,
        "ndims": ndims,
    }

    if x is None:
        return Xu, obj_attrs, k
    return predict_rabf(obj_attrs, x, k)
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
from typing import Dict, Optional, Tuple
|
|
3
|
+
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from jax import jit, lax
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@jit
def dist(A: jnp.ndarray, B: jnp.ndarray) -> jnp.ndarray:
    """Return the pairwise Euclidean distances between the rows of ``A`` and ``B``.

    Args:
        A: Points as rows of a 2-D array, or a 1-D array of scalar points.
        B: Points as rows of a 2-D array, or a 1-D array of scalar points.

    Returns:
        A 2-D array whose entry ``[i, j]`` is the Euclidean distance between
        row ``i`` of ``A`` and row ``j`` of ``B``.
    """
    # A 1-D input is treated as a column of one-dimensional points, so that
    # callers which accept 1-D locations (such as predict_rabf) can pass them
    # straight through instead of failing on the 3-axis indexing below.
    if A.ndim == 1:
        A = A[:, jnp.newaxis]
    if B.ndim == 1:
        B = B[:, jnp.newaxis]

    # Broadcast to (rows of A, rows of B, coordinates) before reducing.
    diff = A[:, jnp.newaxis, :] - B[jnp.newaxis, :, :]
    return jnp.sqrt(jnp.sum(diff**2, axis=-1))
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def compute_h(d: jnp.ndarray, ndims: int) -> jnp.ndarray:
    """Compute the H matrix from a distance matrix for 1, 2 or 3 dimensions.

    The kernel applied element-wise to ``d`` depends on ``ndims``:

    * ``ndims == 1``: ``d**3 / 12``
    * ``ndims == 2``: ``d**2 * log(d + 1e-8) / 8``
    * ``ndims == 3``: ``-d / 8``

    Args:
        d: Array of distances.
        ndims: Spatial dimension selecting the kernel; must be 1, 2 or 3.

    Returns:
        An array with the same shape as ``d``.

    Raises:
        ValueError: If ``ndims`` is a Python int outside ``1..3``.
    """
    # lax.switch clamps an out-of-range index to the nearest branch, which
    # would silently apply the wrong kernel; reject static values up front.
    # Non-int (e.g. traced) values cannot be checked here and fall through.
    if isinstance(ndims, int) and not 1 <= ndims <= 3:
        raise ValueError(f"Unsupported ndims: {ndims}. It must be 1, 2 or 3")

    def case_ndims_1(_: None) -> jnp.ndarray:
        return (1 / 12) * d**3

    def case_ndims_2(_: None) -> jnp.ndarray:
        # The small offset keeps log finite at zero distance.
        return (1 / 8) * d**2 * jnp.log(d + 1e-8)

    def case_ndims_3(_: None) -> jnp.ndarray:
        return (-1 / 8) * d

    cases = [case_ndims_1, case_ndims_2, case_ndims_3]
    # Branch index is zero-based, hence ndims - 1.
    return lax.switch(ndims - 1, cases, None)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@partial(jit, static_argnames=["ndims", "slice_size"])
def build_extended_matrix(
    Xu: jnp.ndarray, k: int, ndims: int, slice_size: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Build the ``S`` (returned as ``AHA``) and ``UZ`` matrices from the knots.

    Args:
        Xu: Knot locations, one row per knot.
        k: Number of basis functions; used only for the column offset of the
            scale entries written into ``UZ``.
        ndims: Spatial dimension passed to ``compute_h`` (static under jit).
        slice_size: Number of leading left singular vectors kept (static
            under jit).

    Returns:
        ``(AHA, UZ)``, where ``UZ`` has ``n + ndims + 1`` rows and
        ``slice_size + ndims + 1`` columns.
    """
    n = Xu.shape[0]

    # Design matrix [1, Xu] and the matrix A = I - B (B^T B)^{-1} B^T.
    intercept = jnp.ones((n, 1))
    B = jnp.hstack([intercept, Xu])
    BBB = jnp.linalg.solve(B.T @ B, B.T)
    A = jnp.eye(n) - B @ BBB

    # Pairwise knot distances feed the dimension-specific kernel.
    pair_diff = Xu[:, None, :] - Xu[None, :, :]
    d = jnp.sqrt(jnp.sum(pair_diff**2, axis=-1))
    H = compute_h(d, ndims)
    AH = H - (H @ B) @ BBB
    AHA = AH - BBB.T @ (B.T @ AH)

    # Keep the first `slice_size` left singular vectors of AHA.
    left_vectors = jnp.linalg.svd(AHA, full_matrices=False)[0]
    gamma0 = lax.dynamic_slice(
        left_vectors,
        start_indices=(0, 0),
        slice_sizes=(left_vectors.shape[0], slice_size),
    )

    # Rescale each kept vector by sqrt(n) over the column norm of AHA @ gamma0.
    trueBS = AHA @ gamma0
    rho = jnp.sqrt(jnp.sum(trueBS**2, axis=0))
    gammas = A @ gamma0 / rho * jnp.sqrt(n)

    extension_dim = ndims + 1

    # Embed `gammas` in the top-left corner of a zero-padded matrix.
    top_rows = jnp.hstack([gammas, jnp.zeros((gammas.shape[0], extension_dim))])
    bottom_rows = jnp.zeros((extension_dim, gammas.shape[1] + extension_dim))
    UZ = jnp.vstack([top_rows, bottom_rows])

    n_valid = min(extension_dim, Xu.shape[1])
    valid_indices = lax.stop_gradient(jnp.arange(n_valid))

    # Per-coordinate scale factors derived from the knot standard deviations.
    updates = 1 / jnp.std(Xu[:, valid_indices], axis=0) / jnp.sqrt((n - 1) / n)
    # NOTE(review): the column offset k - (ndims + 1) - 1 equals
    # slice_size - 1 when slice_size == k - ndims - 1, so the first scale
    # entry lands in the last gamma column -- confirm this offset is intended.
    UZ = UZ.at[n + valid_indices, k - extension_dim - 1 + valid_indices].set(updates)

    return AHA, UZ
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@partial(jit, static_argnames=["k"])
def predict_rabf(
    obj: Dict[str, jnp.ndarray],
    newx: Optional[jnp.ndarray] = None,
    k: Optional[int] = None,
) -> jnp.ndarray:
    """Evaluate the basis functions of a fitted model object at new locations.

    Args:
        obj: Model attributes; the keys ``"Xu"``, ``"UZ"`` and ``"BBBH"`` are
            read here.
        newx: Locations to evaluate at. When ``None``, ``obj`` is returned
            unchanged.
        k: Number of basis functions (static under jit). It is required
            whenever ``newx`` is given.

    Returns:
        ``obj`` itself when ``newx`` is ``None``; otherwise the horizontal
        stack of ``[1, newx]`` and the ``k - ndims - 1`` remaining columns.
    """
    if newx is None:
        return obj

    x0 = newx
    n_knots = obj["Xu"].shape[0]
    d = dist(x0, obj["Xu"])
    ndims = x0.shape[1] if x0.ndim > 1 else 1
    H = compute_h(d, ndims)

    # Columns left once the intercept and the `ndims` linear terms are taken.
    kstar = k - ndims - 1

    # Intercept plus coordinates; 1-D input is treated as a single column.
    coords = x0 if x0.ndim > 1 else x0[:, jnp.newaxis]
    X2 = jnp.hstack([jnp.ones((x0.shape[0], 1)), coords])

    def true_branch(_: None) -> jnp.ndarray:
        uz_block = lax.dynamic_slice(
            obj["UZ"], start_indices=(0, 0), slice_sizes=(n_knots, kstar)
        )
        smooth = H @ uz_block
        smooth = smooth - X2 @ obj["BBBH"] @ uz_block
        smooth = smooth / jnp.sqrt(n_knots)
        return smooth

    def false_branch(_: None) -> jnp.ndarray:
        return jnp.zeros((x0.shape[0], kstar))

    X1 = lax.cond(kstar > 0, true_branch, false_branch, operand=None)

    def concat_branch(_: None) -> jnp.ndarray:
        return jnp.hstack([X2, X1])

    def no_concat_branch(_: None) -> jnp.ndarray:
        # Zero block with the same width as X1 so both branches agree in shape.
        filler = jnp.zeros((x0.shape[0], X1.shape[1]))
        return jnp.hstack([X2, filler])

    return lax.cond(kstar > 0, concat_branch, no_concat_branch, operand=None)
|