eigh 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eigh-0.1.0/.github/workflows/publish.yml +70 -0
- eigh-0.1.0/.gitignore +134 -0
- eigh-0.1.0/CHANGES.md +33 -0
- eigh-0.1.0/CMakeLists.txt +150 -0
- eigh-0.1.0/MANIFEST.in +13 -0
- eigh-0.1.0/PKG-INFO +104 -0
- eigh-0.1.0/README.md +68 -0
- eigh-0.1.0/build.sh +80 -0
- eigh-0.1.0/debug/diag_jax.py +29 -0
- eigh-0.1.0/debug/test_init.py +40 -0
- eigh-0.1.0/docs/BUILD_FIXES.md +125 -0
- eigh-0.1.0/docs/DEGENERACY.md +179 -0
- eigh-0.1.0/docs/EXTRACTION_NOTES.md +178 -0
- eigh-0.1.0/docs/QUICKSTART.md +124 -0
- eigh-0.1.0/docs/VERIFICATION.md +196 -0
- eigh-0.1.0/example_complex.py +155 -0
- eigh-0.1.0/example_simple.py +16 -0
- eigh-0.1.0/img/eig.png +0 -0
- eigh-0.1.0/include/ffi_helpers.h +67 -0
- eigh-0.1.0/include/kernel_nanobind_helpers.h +25 -0
- eigh-0.1.0/pyproject.toml +74 -0
- eigh-0.1.0/requirements.txt +11 -0
- eigh-0.1.0/run_gpu.sh +13 -0
- eigh-0.1.0/scripts/install_build_deps.sh +41 -0
- eigh-0.1.0/setup_gpu_env.sh +65 -0
- eigh-0.1.0/setup_gpu_env_clean.sh +24 -0
- eigh-0.1.0/src/cpu/lapack.cc +58 -0
- eigh-0.1.0/src/cpu/lapack_kernels.cc +290 -0
- eigh-0.1.0/src/cpu/lapack_kernels.h +156 -0
- eigh-0.1.0/src/cuda/solver.cc +21 -0
- eigh-0.1.0/src/cuda/solver_kernels.cc +161 -0
- eigh-0.1.0/src/cuda/solver_kernels.h +79 -0
- eigh-0.1.0/src/python/eigh/__init__.py +32 -0
- eigh-0.1.0/src/python/eigh/_core.py +406 -0
- eigh-0.1.0/src/python/eigh/py.typed +2 -0
- eigh-0.1.0/tests/test_eigh.py +343 -0
- eigh-0.1.0/tests/test_eigh_compare.py +167 -0
- eigh-0.1.0/tests/test_eigh_gen.py +208 -0
- eigh-0.1.0/tests/test_eigh_hard.py +277 -0
- eigh-0.1.0/tests/test_eigh_jit.py +463 -0
- eigh-0.1.0/tests/test_eigh_performance.py +92 -0
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
name: Build and publish wheels
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
workflow_dispatch:
|
|
5
|
+
push:
|
|
6
|
+
tags:
|
|
7
|
+
- 'v*'
|
|
8
|
+
|
|
9
|
+
jobs:
|
|
10
|
+
build_wheels:
|
|
11
|
+
name: Build wheels on ${{ matrix.os }}
|
|
12
|
+
runs-on: ${{ matrix.os }}
|
|
13
|
+
strategy:
|
|
14
|
+
matrix:
|
|
15
|
+
os: [ubuntu-latest, macos-latest]
|
|
16
|
+
# os: [ubuntu-latest, macos-latest, windows-latest]
|
|
17
|
+
|
|
18
|
+
steps:
|
|
19
|
+
- uses: actions/checkout@v4
|
|
20
|
+
|
|
21
|
+
- name: Build wheels
|
|
22
|
+
uses: pypa/cibuildwheel@v2.22.0
|
|
23
|
+
env:
|
|
24
|
+
CIBW_ARCHS_MACOS: x86_64 arm64
|
|
25
|
+
# Linux: only x86_64 for now (aarch64 requires QEMU and is slow)
|
|
26
|
+
CIBW_ARCHS_LINUX: x86_64
|
|
27
|
+
# Use manylinux_2_28 (AlmaLinux 8) - CentOS 7 repos are dead
|
|
28
|
+
CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_28
|
|
29
|
+
# Skip: PyPy (pp*), Python 3.8/3.9 (jaxlib compat), musllinux (Alpine)
|
|
30
|
+
CIBW_SKIP: "pp* cp38-* cp39-* *-musllinux_*"
|
|
31
|
+
|
|
32
|
+
- uses: actions/upload-artifact@v4
|
|
33
|
+
with:
|
|
34
|
+
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
|
|
35
|
+
path: ./wheelhouse/*.whl
|
|
36
|
+
|
|
37
|
+
build_sdist:
|
|
38
|
+
name: Build source distribution
|
|
39
|
+
runs-on: ubuntu-latest
|
|
40
|
+
steps:
|
|
41
|
+
- uses: actions/checkout@v4
|
|
42
|
+
|
|
43
|
+
- name: Build sdist
|
|
44
|
+
run: pipx run build --sdist
|
|
45
|
+
|
|
46
|
+
- uses: actions/upload-artifact@v4
|
|
47
|
+
with:
|
|
48
|
+
name: cibw-sdist
|
|
49
|
+
path: dist/*.tar.gz
|
|
50
|
+
|
|
51
|
+
upload_pypi:
|
|
52
|
+
needs: [build_wheels, build_sdist]
|
|
53
|
+
runs-on: ubuntu-latest
|
|
54
|
+
environment: pypi
|
|
55
|
+
permissions:
|
|
56
|
+
id-token: write
|
|
57
|
+
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
|
|
58
|
+
steps:
|
|
59
|
+
- uses: actions/download-artifact@v4
|
|
60
|
+
with:
|
|
61
|
+
pattern: cibw-*
|
|
62
|
+
path: dist
|
|
63
|
+
merge-multiple: true
|
|
64
|
+
|
|
65
|
+
- uses: pypa/gh-action-pypi-publish@release/v1
|
|
66
|
+
with:
|
|
67
|
+
# repository-url: https://test.pypi.org/legacy/
|
|
68
|
+
# To use trusted publishing, remove 'password' and configure a trusted publisher for this workflow on PyPI
|
|
69
|
+
# Or use secrets.PYPI_API_TOKEN
|
|
70
|
+
password: ${{ secrets.PYPI_API_TOKEN }}
|
eigh-0.1.0/.gitignore
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
.DS_Store
|
|
7
|
+
|
|
8
|
+
# C extensions
|
|
9
|
+
*.so
|
|
10
|
+
|
|
11
|
+
# Distribution / packaging
|
|
12
|
+
.Python
|
|
13
|
+
build/
|
|
14
|
+
develop-eggs/
|
|
15
|
+
dist/
|
|
16
|
+
downloads/
|
|
17
|
+
eggs/
|
|
18
|
+
.eggs/
|
|
19
|
+
#lib/
|
|
20
|
+
lib64/
|
|
21
|
+
parts/
|
|
22
|
+
sdist/
|
|
23
|
+
var/
|
|
24
|
+
wheels/
|
|
25
|
+
pip-wheel-metadata/
|
|
26
|
+
share/python-wheels/
|
|
27
|
+
*.egg-info/
|
|
28
|
+
.installed.cfg
|
|
29
|
+
*.egg
|
|
30
|
+
MANIFEST
|
|
31
|
+
|
|
32
|
+
# PyInstaller
|
|
33
|
+
# Usually these files are written by a python script from a template
|
|
34
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
35
|
+
*.manifest
|
|
36
|
+
*.spec
|
|
37
|
+
|
|
38
|
+
# Installer logs
|
|
39
|
+
pip-log.txt
|
|
40
|
+
pip-delete-this-directory.txt
|
|
41
|
+
|
|
42
|
+
# Unit test / coverage reports
|
|
43
|
+
htmlcov/
|
|
44
|
+
.tox/
|
|
45
|
+
.nox/
|
|
46
|
+
.coverage
|
|
47
|
+
.coverage.*
|
|
48
|
+
.cache
|
|
49
|
+
nosetests.xml
|
|
50
|
+
coverage.xml
|
|
51
|
+
*.cover
|
|
52
|
+
*.py,cover
|
|
53
|
+
.hypothesis/
|
|
54
|
+
.pytest_cache/
|
|
55
|
+
|
|
56
|
+
# Translations
|
|
57
|
+
*.mo
|
|
58
|
+
*.pot
|
|
59
|
+
|
|
60
|
+
# Django stuff:
|
|
61
|
+
*.log
|
|
62
|
+
local_settings.py
|
|
63
|
+
db.sqlite3
|
|
64
|
+
db.sqlite3-journal
|
|
65
|
+
|
|
66
|
+
# Flask stuff:
|
|
67
|
+
instance/
|
|
68
|
+
.webassets-cache
|
|
69
|
+
|
|
70
|
+
# Scrapy stuff:
|
|
71
|
+
.scrapy
|
|
72
|
+
|
|
73
|
+
# Sphinx documentation
|
|
74
|
+
docs/_build/
|
|
75
|
+
|
|
76
|
+
# PyBuilder
|
|
77
|
+
target/
|
|
78
|
+
|
|
79
|
+
# Jupyter Notebook
|
|
80
|
+
.ipynb_checkpoints
|
|
81
|
+
|
|
82
|
+
# IPython
|
|
83
|
+
profile_default/
|
|
84
|
+
ipython_config.py
|
|
85
|
+
|
|
86
|
+
# pyenv
|
|
87
|
+
.python-version
|
|
88
|
+
|
|
89
|
+
# pipenv
|
|
90
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
91
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
92
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
93
|
+
# install all needed dependencies.
|
|
94
|
+
#Pipfile.lock
|
|
95
|
+
|
|
96
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
|
97
|
+
__pypackages__/
|
|
98
|
+
|
|
99
|
+
# Celery stuff
|
|
100
|
+
celerybeat-schedule
|
|
101
|
+
celerybeat.pid
|
|
102
|
+
|
|
103
|
+
# SageMath parsed files
|
|
104
|
+
*.sage.py
|
|
105
|
+
|
|
106
|
+
# Environments
|
|
107
|
+
.env
|
|
108
|
+
.venv
|
|
109
|
+
env/
|
|
110
|
+
venv/
|
|
111
|
+
ENV/
|
|
112
|
+
env.bak/
|
|
113
|
+
venv.bak/
|
|
114
|
+
|
|
115
|
+
# Spyder project settings
|
|
116
|
+
.spyderproject
|
|
117
|
+
.spyproject
|
|
118
|
+
|
|
119
|
+
# Rope project settings
|
|
120
|
+
.ropeproject
|
|
121
|
+
|
|
122
|
+
# mkdocs documentation
|
|
123
|
+
/site
|
|
124
|
+
|
|
125
|
+
# mypy
|
|
126
|
+
.mypy_cache/
|
|
127
|
+
.dmypy.json
|
|
128
|
+
dmypy.json
|
|
129
|
+
|
|
130
|
+
# Pyre type checker
|
|
131
|
+
.pyre/
|
|
132
|
+
|
|
133
|
+
# VSCode
|
|
134
|
+
.vscode/*
|
eigh-0.1.0/CHANGES.md
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Changes Made
|
|
2
|
+
|
|
3
|
+
The FFI target names were renamed (prefixed with `eigh_`) to prevent conflicts with the targets registered by JAX's standard `eigh`.
|
|
4
|
+
|
|
5
|
+
| File | Old Name | New Name |
|
|
6
|
+
| :--- | :--- | :--- |
|
|
7
|
+
| `lapack.cc` | `lapack_ssygvd` | `eigh_lapack_ssygvd` |
|
|
8
|
+
| | `lapack_dsygvd` | `eigh_lapack_dsygvd` |
|
|
9
|
+
| | `lapack_chegvd` | `eigh_lapack_chegvd` |
|
|
10
|
+
| | `lapack_zhegvd` | `eigh_lapack_zhegvd` |
|
|
11
|
+
| | `lapack_*_ffi` variants | `eigh_lapack_*_ffi` |
|
|
12
|
+
| `solver.cc` | `cusolver_sygvd_ffi` | `eigh_cusolver_sygvd_ffi` |
|
|
13
|
+
| `_core.py` | Updated `prepare_lapack_call` and CUDA target name | - |
|
|
14
|
+
|
|
15
|
+
The FFI target names are just string identifiers used to look up the registered function pointers.
|
|
16
|
+
The flow is:
|
|
17
|
+
|
|
18
|
+
1. C++ registers function pointer with name `"eigh_lapack_ssygvd_ffi"`
|
|
19
|
+
2. Python calls `ffi.register_ffi_target("eigh_lapack_ssygvd_ffi", ...)`
|
|
20
|
+
3. During JIT lowering, Python requests target `"eigh_lapack_ssygvd_ffi"`
|
|
21
|
+
4. XLA looks up the function pointer by that name
|
|
22
|
+
|
|
23
|
+
As long as both sides agree on the name (which they now do), everything works. The actual kernel code is unchanged.
|
|
24
|
+
|
|
25
|
+
# To Apply
|
|
26
|
+
|
|
27
|
+
You need to rebuild the package since the C++ files changed:
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
pip install -e .
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
After rebuilding, both JAX's standard `eigh` and this package's `eigh` can coexist without conflicts.
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
cmake_minimum_required(VERSION 3.18)
project(eigh_standalone LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Fix for NVHPC compiler compatibility with nanobind
if(CMAKE_CXX_COMPILER_ID MATCHES "NVHPC|PGI")
  add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)
  add_compile_definitions(__GXX_ABI_VERSION=1016)
endif()

# Find required packages
# Use Development.Module instead of Development - manylinux containers
# don't have full Python development files (libpython), but extension
# modules don't need them anyway
find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)

# Find nanobind: ask the installed nanobind Python package where its CMake
# config lives. If that probe fails, fall back to whatever CMAKE_PREFIX_PATH /
# nanobind_DIR already provide and let find_package report the final error.
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
  OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR
  RESULT_VARIABLE NB_RESULT
)
if(NB_RESULT EQUAL 0 AND EXISTS "${NB_DIR}")
  list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
else()
  message(WARNING "'${Python_EXECUTABLE} -m nanobind --cmake_dir' failed. "
                  "Make sure nanobind is installed: pip install nanobind")
endif()
find_package(nanobind CONFIG REQUIRED)

# Find BLAS and LAPACK
# Prefer the NVHPC SDK BLAS/LAPACK when BOTH libraries exist under
# NVHPC_LIB_PATH; otherwise use CMake's standard BLAS/LAPACK discovery.
# The default path is site-specific; override it with -DNVHPC_LIB_PATH=<dir>.
set(NVHPC_LIB_PATH "/softs/nvidia/hpc_sdk/Linux_x86_64/22.1/compilers/lib"
  CACHE PATH "Directory containing NVHPC libblas.so, liblapack.so and libnvf.so")
set(EIGH_USE_NVHPC_BLAS OFF)
if(EXISTS "${NVHPC_LIB_PATH}/libblas.so" AND EXISTS "${NVHPC_LIB_PATH}/liblapack.so")
  set(EIGH_USE_NVHPC_BLAS ON)
  set(BLA_VENDOR "NVHPC")
  set(BLAS_LIBRARIES "${NVHPC_LIB_PATH}/libblas.so")
  set(LAPACK_LIBRARIES "${NVHPC_LIB_PATH}/liblapack.so")
  set(BLAS_FOUND TRUE)
  set(LAPACK_FOUND TRUE)
  message(STATUS "Using NVHPC BLAS: ${BLAS_LIBRARIES}")
  message(STATUS "Using NVHPC LAPACK: ${LAPACK_LIBRARIES}")
else()
  find_package(BLAS REQUIRED)
  find_package(LAPACK REQUIRED)
endif()

# XLA headers (from jaxlib)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -c "import jaxlib; import os; print(os.path.join(jaxlib.__path__[0], 'include'))"
  OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_INCLUDE_DIR
  RESULT_VARIABLE XLA_RESULT
)

if(NOT XLA_RESULT EQUAL 0 OR NOT EXISTS "${XLA_INCLUDE_DIR}")
  message(FATAL_ERROR "Could not find XLA headers. Make sure jaxlib is installed: pip install jaxlib")
endif()

message(STATUS "XLA headers found at: ${XLA_INCLUDE_DIR}")

# Include directories
include_directories(
  ${CMAKE_SOURCE_DIR}/include
  ${CMAKE_SOURCE_DIR}/src/cpu
  ${CMAKE_SOURCE_DIR}/src/cuda
  ${XLA_INCLUDE_DIR}
)

# CPU LAPACK module
nanobind_add_module(
  eigh_lapack
  STABLE_ABI
  src/cpu/lapack.cc
  src/cpu/lapack_kernels.cc
)
target_link_libraries(eigh_lapack PRIVATE ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES})

# Add NVHPC Fortran runtime, but only when we actually link against the
# NVHPC BLAS/LAPACK selected above (it is what needs libnvf).
if(EIGH_USE_NVHPC_BLAS AND EXISTS "${NVHPC_LIB_PATH}/libnvf.so")
  find_library(NVF_LIB NAMES nvf PATHS "${NVHPC_LIB_PATH}" NO_DEFAULT_PATH)
  if(NVF_LIB)
    target_link_libraries(eigh_lapack PRIVATE ${NVF_LIB} rt)
    message(STATUS "Found NVHPC Fortran runtime: ${NVF_LIB}")
  endif()
endif()

set_target_properties(eigh_lapack PROPERTIES
  LIBRARY_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/src/python/eigh
)

# CUDA module (optional - only if CUDA is available)
include(CheckLanguage)
check_language(CUDA)

if(CMAKE_CUDA_COMPILER)
  enable_language(CUDA)
  set(CMAKE_CUDA_STANDARD 17)
  set(CMAKE_CUDA_STANDARD_REQUIRED ON)

  find_package(CUDAToolkit REQUIRED)

  nanobind_add_module(
    eigh_cuda
    STABLE_ABI
    src/cuda/solver.cc
    src/cuda/solver_kernels.cc
  )

  # Add CUDA include directories
  target_include_directories(eigh_cuda PRIVATE
    ${CUDAToolkit_INCLUDE_DIRS}
  )

  set_source_files_properties(
    src/cuda/solver_kernels.cc
    PROPERTIES LANGUAGE CUDA
  )

  # Link CUDA libraries - handle both imported targets and direct paths
  if(TARGET CUDA::cudart AND TARGET CUDA::cusolver AND TARGET CUDA::cublas)
    target_link_libraries(eigh_cuda PRIVATE
      CUDA::cudart
      CUDA::cusolver
      CUDA::cublas
    )
  else()
    # Fallback for NVHPC SDK where imported targets may not be available
    target_link_libraries(eigh_cuda PRIVATE
      ${CUDAToolkit_LIBRARY_DIR}/libcudart.so
      ${CUDAToolkit_LIBRARY_DIR}/libcusolver.so
      ${CUDAToolkit_LIBRARY_DIR}/libcublas.so
    )
  endif()

  set_target_properties(eigh_cuda PROPERTIES
    LIBRARY_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/src/python/eigh
    CUDA_SEPARABLE_COMPILATION ON
  )

  message(STATUS "CUDA support enabled")
else()
  message(STATUS "CUDA not found - GPU support will be disabled")
endif()

# Install target - scikit-build-core will handle the final install location
install(TARGETS eigh_lapack
  LIBRARY DESTINATION eigh
)

if(TARGET eigh_cuda)
  install(TARGETS eigh_cuda
    LIBRARY DESTINATION eigh
  )
endif()
|
eigh-0.1.0/MANIFEST.in
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Include source files for building
|
|
2
|
+
recursive-include include *.h
|
|
3
|
+
recursive-include src/cpu *.cc *.h
|
|
4
|
+
recursive-include src/cuda *.cc *.h
|
|
5
|
+
include CMakeLists.txt
|
|
6
|
+
include README.md
|
|
7
|
+
include requirements.txt
|
|
8
|
+
|
|
9
|
+
# Exclude build artifacts
|
|
10
|
+
global-exclude *.so
|
|
11
|
+
global-exclude *.dylib
|
|
12
|
+
global-exclude __pycache__
|
|
13
|
+
global-exclude *.pyc
|
eigh-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: eigh
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Differentiable eigenvalue decomposition with JAX (CPU/GPU)
|
|
5
|
+
Keywords: jax,eigenvalue,linear-algebra,gpu,cuda,autodiff
|
|
6
|
+
Author-Email: Xing Zhang <fishjojo@gmail.com>
|
|
7
|
+
License: Apache-2.0
|
|
8
|
+
Classifier: Development Status :: 4 - Beta
|
|
9
|
+
Classifier: Intended Audience :: Science/Research
|
|
10
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
11
|
+
Classifier: Operating System :: POSIX :: Linux
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
19
|
+
Project-URL: Homepage, https://github.com/fishjojo/pyscfad
|
|
20
|
+
Project-URL: Repository, https://github.com/fishjojo/pyscfad
|
|
21
|
+
Requires-Python: >=3.10
|
|
22
|
+
Requires-Dist: jax>=0.4.0
|
|
23
|
+
Requires-Dist: jaxlib>=0.4.0
|
|
24
|
+
Requires-Dist: numpy>=1.20.0
|
|
25
|
+
Provides-Extra: cuda
|
|
26
|
+
Requires-Dist: jax[cuda12]>=0.4.0; extra == "cuda"
|
|
27
|
+
Provides-Extra: cuda-local
|
|
28
|
+
Requires-Dist: jax[cuda12-local]>=0.4.0; extra == "cuda-local"
|
|
29
|
+
Provides-Extra: test
|
|
30
|
+
Requires-Dist: pytest>=7.0.0; extra == "test"
|
|
31
|
+
Requires-Dist: scipy>=1.7.0; extra == "test"
|
|
32
|
+
Provides-Extra: dev
|
|
33
|
+
Requires-Dist: pytest>=7.0.0; extra == "dev"
|
|
34
|
+
Requires-Dist: scipy>=1.7.0; extra == "dev"
|
|
35
|
+
Description-Content-Type: text/markdown
|
|
36
|
+
|
|
37
|
+
# Differentiable Generalized Eigenvalue Decomposition
|
|
38
|
+
|
|
39
|
+
<img src="img/eig.png" alt="Eigh Logo" width="400">
|
|
40
|
+
|
|
41
|
+
Standalone implementation of differentiable eigenvalue decomposition with CPU (LAPACK) and GPU (cuSOLVER) backends. Extracted from [pyscfad](https://github.com/fishjojo/pyscfad).
|
|
42
|
+
|
|
43
|
+
## Features
|
|
44
|
+
- **Generalized Problems**: `A @ V = B @ V @ diag(W)`, etc.
|
|
45
|
+
- **JAX Integrated**: Full support for `jit`, `vmap`, `grad`, and `jvp`.
|
|
46
|
+
- **High Performance**: Optimized LAPACK (CPU) and cuSOLVER (GPU) kernels.
|
|
47
|
+
- **Precision**: `float32/64` and `complex64/128`.
|
|
48
|
+
- **Degeneracy Handling**: Configurable `deg_thresh` for stable gradients.
|
|
49
|
+
|
|
50
|
+
## Installation & Quick Start
|
|
51
|
+
|
|
52
|
+
```bash
|
|
53
|
+
# Install from source
|
|
54
|
+
pip install .
|
|
55
|
+
|
|
56
|
+
# For GPU support in this environment
|
|
57
|
+
pip install .[cuda-local]
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
### Usage Example
|
|
61
|
+
```python
|
|
62
|
+
import jax
|
|
63
|
+
import jax.numpy as jnp
|
|
64
|
+
from eigh import eigh
|
|
65
|
+
|
|
66
|
+
jax.config.update("jax_enable_x64", True)
|
|
67
|
+
A = jnp.array([[2., 1.], [1., 2.]])
|
|
68
|
+
w, v = eigh(A) # Standard
|
|
69
|
+
grad = jax.grad(lambda A: eigh(A)[0].sum())(A) # Differentiable
|
|
70
|
+
```
|
|
71
|
+
|
|
72
|
+
## API Reference
|
|
73
|
+
- **`eigh(a, b=None, *, lower=True, eigvals_only=False, type=1, deg_thresh=1e-9)`**
|
|
74
|
+
Scipy-compatible interface. `type` supports 1: `A@v=B@v@λ`, 2: `A@B@v=v@λ`, 3: `B@A@v=v@λ`.
|
|
75
|
+
- **`eigh_gen(a, b, *, lower=True, itype=1, deg_thresh=1e-9)`**
|
|
76
|
+
Lower-level generalized solver.
|
|
77
|
+
|
|
78
|
+
## Degenerate Eigenvalues & Gradients
|
|
79
|
+
Individual eigenvalue gradients are ill-defined for degenerate (repeated) eigenvalues. However, symmetric functions (like `sum`, `var`, `trace`) have stable gradients. The `deg_thresh` parameter (default `1e-9`) masks divisions by near-zero gaps to maintain stability.
|
|
80
|
+
|
|
81
|
+
## Development & Testing
|
|
82
|
+
- **Requirements**: CMake 3.18+, C++17, JAX, NumPy, LAPACK/CUDA.
|
|
83
|
+
- **Tests**:
|
|
84
|
+
```bash
|
|
85
|
+
pytest tests/test_eigh.py # Core functionality
|
|
86
|
+
pytest tests/test_eigh_gen.py # Generalized itypes
|
|
87
|
+
pytest tests/test_eigh_jit.py # JIT & vmap
|
|
88
|
+
```
|
|
89
|
+
- **GPU Setup**:
|
|
90
|
+
```bash
|
|
91
|
+
source setup_gpu_env_clean.sh
|
|
92
|
+
./run_gpu.sh python example_simple.py
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
## License & Citation
|
|
96
|
+
Apache License 2.0. If used in research, please cite:
|
|
97
|
+
```bibtex
|
|
98
|
+
@software{pyscfad,
|
|
99
|
+
author = {Zhang, Xing},
|
|
100
|
+
title = {PySCFad: Automatic Differentiation for PySCF},
|
|
101
|
+
url = {https://github.com/fishjojo/pyscfad},
|
|
102
|
+
year = {2021-2025}
|
|
103
|
+
}
|
|
104
|
+
```
|
eigh-0.1.0/README.md
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# Differentiable Generalized Eigenvalue Decomposition
|
|
2
|
+
|
|
3
|
+
<img src="img/eig.png" alt="Eigh Logo" width="400">
|
|
4
|
+
|
|
5
|
+
Standalone implementation of differentiable eigenvalue decomposition with CPU (LAPACK) and GPU (cuSOLVER) backends. Extracted from [pyscfad](https://github.com/fishjojo/pyscfad).
|
|
6
|
+
|
|
7
|
+
## Features
|
|
8
|
+
- **Generalized Problems**: `A @ V = B @ V @ diag(W)`, etc.
|
|
9
|
+
- **JAX Integrated**: Full support for `jit`, `vmap`, `grad`, and `jvp`.
|
|
10
|
+
- **High Performance**: Optimized LAPACK (CPU) and cuSOLVER (GPU) kernels.
|
|
11
|
+
- **Precision**: `float32/64` and `complex64/128`.
|
|
12
|
+
- **Degeneracy Handling**: Configurable `deg_thresh` for stable gradients.
|
|
13
|
+
|
|
14
|
+
## Installation & Quick Start
|
|
15
|
+
|
|
16
|
+
```bash
|
|
17
|
+
# Install from source
|
|
18
|
+
pip install .
|
|
19
|
+
|
|
20
|
+
# For GPU support using a locally installed CUDA toolkit (installs jax[cuda12-local])
|
|
21
|
+
pip install ".[cuda-local]"
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
### Usage Example
|
|
25
|
+
```python
|
|
26
|
+
import jax
|
|
27
|
+
import jax.numpy as jnp
|
|
28
|
+
from eigh import eigh
|
|
29
|
+
|
|
30
|
+
jax.config.update("jax_enable_x64", True)
|
|
31
|
+
A = jnp.array([[2., 1.], [1., 2.]])
|
|
32
|
+
w, v = eigh(A) # Standard
|
|
33
|
+
grad = jax.grad(lambda A: eigh(A)[0].sum())(A) # Differentiable
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
## API Reference
|
|
37
|
+
- **`eigh(a, b=None, *, lower=True, eigvals_only=False, type=1, deg_thresh=1e-9)`**
|
|
38
|
+
Scipy-compatible interface. `type` supports 1: `A@v=B@v@λ`, 2: `A@B@v=v@λ`, 3: `B@A@v=v@λ`.
|
|
39
|
+
- **`eigh_gen(a, b, *, lower=True, itype=1, deg_thresh=1e-9)`**
|
|
40
|
+
Lower-level generalized solver.
|
|
41
|
+
|
|
42
|
+
## Degenerate Eigenvalues & Gradients
|
|
43
|
+
Individual eigenvalue gradients are ill-defined for degenerate (repeated) eigenvalues. However, symmetric functions (like `sum`, `var`, `trace`) have stable gradients. The `deg_thresh` parameter (default `1e-9`) masks divisions by near-zero gaps to maintain stability.
|
|
44
|
+
|
|
45
|
+
## Development & Testing
|
|
46
|
+
- **Requirements**: CMake 3.18+, C++17, JAX, NumPy, LAPACK/CUDA.
|
|
47
|
+
- **Tests**:
|
|
48
|
+
```bash
|
|
49
|
+
pytest tests/test_eigh.py # Core functionality
|
|
50
|
+
pytest tests/test_eigh_gen.py # Generalized itypes
|
|
51
|
+
pytest tests/test_eigh_jit.py # JIT & vmap
|
|
52
|
+
```
|
|
53
|
+
- **GPU Setup**:
|
|
54
|
+
```bash
|
|
55
|
+
source setup_gpu_env_clean.sh
|
|
56
|
+
./run_gpu.sh python example_simple.py
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
## License & Citation
|
|
60
|
+
Apache License 2.0. If used in research, please cite:
|
|
61
|
+
```bibtex
|
|
62
|
+
@software{pyscfad,
|
|
63
|
+
author = {Zhang, Xing},
|
|
64
|
+
title = {PySCFad: Automatic Differentiation for PySCF},
|
|
65
|
+
url = {https://github.com/fishjojo/pyscfad},
|
|
66
|
+
year = {2021-2025}
|
|
67
|
+
}
|
|
68
|
+
```
|
eigh-0.1.0/build.sh
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
#!/bin/bash
# Build script for eigh_standalone.
#
# Installs the Python build requirements, configures a fresh CMake build in
# ./build and compiles the extension modules. CMakeLists.txt places the built
# modules in src/python/eigh.

set -e

echo "==================================="
echo "Building eigh_standalone"
echo "==================================="

# Detect Python (prefer .venv if it exists)
if [ -f "../.venv/bin/python" ]; then
    PYTHON="../.venv/bin/python"
    echo "Using virtual environment Python"
elif [ -f ".venv/bin/python" ]; then
    PYTHON=".venv/bin/python"
    echo "Using virtual environment Python"
elif command -v python &> /dev/null; then
    PYTHON="python"
else
    echo "Error: Python not found"
    exit 1
fi

# Resolve to an absolute interpreter path. The script cd's into build/ before
# invoking CMake, where a relative path such as ../.venv/bin/python would
# point at a different (or non-existent) interpreter.
PYTHON="$("$PYTHON" -c 'import sys; print(sys.executable)')"

echo "Python: $("$PYTHON" --version)"

# Check CMake
if ! command -v cmake &> /dev/null; then
    echo "Error: CMake not found. Please install CMake 3.18+"
    exit 1
fi

echo "CMake: $(cmake --version | head -n1)"

# Check CUDA (optional)
if command -v nvcc &> /dev/null; then
    echo "CUDA: $(nvcc --version | grep release)"
    CUDA_AVAILABLE=1
else
    echo "CUDA: Not found (GPU support disabled)"
    CUDA_AVAILABLE=0
fi

# Install Python dependencies
echo ""
echo "Installing Python dependencies..."
"$PYTHON" -m pip install -r requirements.txt

# Create build directory
echo ""
echo "Creating build directory..."
rm -rf build
mkdir build
cd build

# Configure
echo ""
echo "Configuring CMake..."
if [ "$CUDA_AVAILABLE" -eq 1 ]; then
    cmake -DPython_EXECUTABLE="$PYTHON" ..
else
    # NOTE(review): CMakeLists.txt detects CUDA with check_language(CUDA) and
    # find_package(CUDAToolkit); CMAKE_DISABLE_FIND_PACKAGE_CUDA only affects
    # find_package(CUDA), so it does not by itself disable the CUDA module.
    cmake -DPython_EXECUTABLE="$PYTHON" -DCMAKE_DISABLE_FIND_PACKAGE_CUDA=TRUE ..
fi

# Build (generator-agnostic; equivalent to `make -jN` with the default
# Unix Makefiles generator)
echo ""
echo "Building..."
JOBS="$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)"
cmake --build . --parallel "$JOBS"

# Check output: LIBRARY_OUTPUT_DIRECTORY in CMakeLists.txt is src/python/eigh
MODULE_DIR="../src/python/eigh"
echo ""
echo "==================================="
echo "Build complete!"
echo "==================================="
echo ""
echo "Built modules:"
ls -lh "$MODULE_DIR"/*.so 2>/dev/null || ls -lh "$MODULE_DIR"/*.dylib 2>/dev/null || echo "No modules found"

echo ""
echo "To test, run:"
echo " python tests/test_eigh.py"
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""Print diagnostic information about the Python / JAX runtime environment.

Reports the interpreter version, ``LD_LIBRARY_PATH``, the installed JAX
version, the visible JAX devices and the default backend platform. Each
probe is guarded separately so that one failure does not hide the rest of
the output.
"""
import os
import sys
import traceback


def main():
    """Run all diagnostics, printing results and any errors to stdout."""
    print(f"Python version: {sys.version}")
    print(f"LD_LIBRARY_PATH: {os.environ.get('LD_LIBRARY_PATH', 'Not set')}")

    try:
        import jax

        print(f"JAX version: {jax.__version__}")

        # Try to get devices
        try:
            devices = jax.devices()
            print(f"Devices: {devices}")
        except Exception as dev_err:
            print(f"Error getting devices: {dev_err}")

        # Check backends
        try:
            import jax.extend as jex

            backend = jex.backend.get_backend()
            print(f"Default backend: {backend.platform}")
        except Exception as back_err:
            print(f"Error getting backend: {back_err}")

    except Exception as e:
        # Broad on purpose: this is a diagnostic tool, so any failure while
        # importing/inspecting jax is reported (with traceback) instead of
        # aborting the script.
        print(f"Main Error: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    main()
|