eigh 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. eigh-0.1.0/.github/workflows/publish.yml +70 -0
  2. eigh-0.1.0/.gitignore +134 -0
  3. eigh-0.1.0/CHANGES.md +33 -0
  4. eigh-0.1.0/CMakeLists.txt +150 -0
  5. eigh-0.1.0/MANIFEST.in +13 -0
  6. eigh-0.1.0/PKG-INFO +104 -0
  7. eigh-0.1.0/README.md +68 -0
  8. eigh-0.1.0/build.sh +80 -0
  9. eigh-0.1.0/debug/diag_jax.py +29 -0
  10. eigh-0.1.0/debug/test_init.py +40 -0
  11. eigh-0.1.0/docs/BUILD_FIXES.md +125 -0
  12. eigh-0.1.0/docs/DEGENERACY.md +179 -0
  13. eigh-0.1.0/docs/EXTRACTION_NOTES.md +178 -0
  14. eigh-0.1.0/docs/QUICKSTART.md +124 -0
  15. eigh-0.1.0/docs/VERIFICATION.md +196 -0
  16. eigh-0.1.0/example_complex.py +155 -0
  17. eigh-0.1.0/example_simple.py +16 -0
  18. eigh-0.1.0/img/eig.png +0 -0
  19. eigh-0.1.0/include/ffi_helpers.h +67 -0
  20. eigh-0.1.0/include/kernel_nanobind_helpers.h +25 -0
  21. eigh-0.1.0/pyproject.toml +74 -0
  22. eigh-0.1.0/requirements.txt +11 -0
  23. eigh-0.1.0/run_gpu.sh +13 -0
  24. eigh-0.1.0/scripts/install_build_deps.sh +41 -0
  25. eigh-0.1.0/setup_gpu_env.sh +65 -0
  26. eigh-0.1.0/setup_gpu_env_clean.sh +24 -0
  27. eigh-0.1.0/src/cpu/lapack.cc +58 -0
  28. eigh-0.1.0/src/cpu/lapack_kernels.cc +290 -0
  29. eigh-0.1.0/src/cpu/lapack_kernels.h +156 -0
  30. eigh-0.1.0/src/cuda/solver.cc +21 -0
  31. eigh-0.1.0/src/cuda/solver_kernels.cc +161 -0
  32. eigh-0.1.0/src/cuda/solver_kernels.h +79 -0
  33. eigh-0.1.0/src/python/eigh/__init__.py +32 -0
  34. eigh-0.1.0/src/python/eigh/_core.py +406 -0
  35. eigh-0.1.0/src/python/eigh/py.typed +2 -0
  36. eigh-0.1.0/tests/test_eigh.py +343 -0
  37. eigh-0.1.0/tests/test_eigh_compare.py +167 -0
  38. eigh-0.1.0/tests/test_eigh_gen.py +208 -0
  39. eigh-0.1.0/tests/test_eigh_hard.py +277 -0
  40. eigh-0.1.0/tests/test_eigh_jit.py +463 -0
  41. eigh-0.1.0/tests/test_eigh_performance.py +92 -0
eigh-0.1.0/.github/workflows/publish.yml ADDED
@@ -0,0 +1,70 @@
1
+ name: Build and publish wheels
2
+
3
+ on:
4
+ workflow_dispatch:
5
+ push:
6
+ tags:
7
+ - 'v*'
8
+
9
+ jobs:
10
+ build_wheels:
11
+ name: Build wheels on ${{ matrix.os }}
12
+ runs-on: ${{ matrix.os }}
13
+ strategy:
14
+ matrix:
15
+ os: [ubuntu-latest, macos-latest]
16
+ # os: [ubuntu-latest, macos-latest, windows-latest]
17
+
18
+ steps:
19
+ - uses: actions/checkout@v4
20
+
21
+ - name: Build wheels
22
+ uses: pypa/cibuildwheel@v2.22.0
23
+ env:
24
+ CIBW_ARCHS_MACOS: x86_64 arm64
25
+ # Linux: only x86_64 for now (aarch64 requires QEMU and is slow)
26
+ CIBW_ARCHS_LINUX: x86_64
27
+ # Use manylinux_2_28 (AlmaLinux 8) - CentOS 7 repos are dead
28
+ CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_28
29
+ # Skip: PyPy (pp*), Python 3.8/3.9 (jaxlib compat), musllinux (Alpine)
30
+ CIBW_SKIP: "pp* cp38-* cp39-* *-musllinux_*"
31
+
32
+ - uses: actions/upload-artifact@v4
33
+ with:
34
+ name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
35
+ path: ./wheelhouse/*.whl
36
+
37
+ build_sdist:
38
+ name: Build source distribution
39
+ runs-on: ubuntu-latest
40
+ steps:
41
+ - uses: actions/checkout@v4
42
+
43
+ - name: Build sdist
44
+ run: pipx run build --sdist
45
+
46
+ - uses: actions/upload-artifact@v4
47
+ with:
48
+ name: cibw-sdist
49
+ path: dist/*.tar.gz
50
+
51
+ upload_pypi:
52
+ needs: [build_wheels, build_sdist]
53
+ runs-on: ubuntu-latest
54
+ environment: pypi
55
+ permissions:
56
+ id-token: write
57
+ if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
58
+ steps:
59
+ - uses: actions/download-artifact@v4
60
+ with:
61
+ pattern: cibw-*
62
+ path: dist
63
+ merge-multiple: true
64
+
65
+ - uses: pypa/gh-action-pypi-publish@release/v1
66
+ with:
67
+ # registry-url: https://test.pypi.org/legacy/
68
+ # For trusted publishing (OIDC), remove 'password' and set up a trusted publisher on PyPI
69
+ # Otherwise, authenticate with the PYPI_API_TOKEN repository secret:
70
+ password: ${{ secrets.PYPI_API_TOKEN }}
eigh-0.1.0/.gitignore ADDED
@@ -0,0 +1,134 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ .DS_Store
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ #lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ pip-wheel-metadata/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ db.sqlite3-journal
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # IPython
83
+ profile_default/
84
+ ipython_config.py
85
+
86
+ # pyenv
87
+ .python-version
88
+
89
+ # pipenv
90
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
92
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
93
+ # install all needed dependencies.
94
+ #Pipfile.lock
95
+
96
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97
+ __pypackages__/
98
+
99
+ # Celery stuff
100
+ celerybeat-schedule
101
+ celerybeat.pid
102
+
103
+ # SageMath parsed files
104
+ *.sage.py
105
+
106
+ # Environments
107
+ .env
108
+ .venv
109
+ env/
110
+ venv/
111
+ ENV/
112
+ env.bak/
113
+ venv.bak/
114
+
115
+ # Spyder project settings
116
+ .spyderproject
117
+ .spyproject
118
+
119
+ # Rope project settings
120
+ .ropeproject
121
+
122
+ # mkdocs documentation
123
+ /site
124
+
125
+ # mypy
126
+ .mypy_cache/
127
+ .dmypy.json
128
+ dmypy.json
129
+
130
+ # Pyre type checker
131
+ .pyre/
132
+
133
+ # VSCode
134
+ .vscode/*
eigh-0.1.0/CHANGES.md ADDED
@@ -0,0 +1,33 @@
1
+ # Changes Made
2
+
3
+ The FFI targets were renamed with an `eigh_` prefix to prevent conflicts with the targets registered by JAX's standard `eigh`.
4
+
5
+ | File | Old Name | New Name |
6
+ | :--- | :--- | :--- |
7
+ | `lapack.cc` | `lapack_ssygvd` | `eigh_lapack_ssygvd` |
8
+ | | `lapack_dsygvd` | `eigh_lapack_dsygvd` |
9
+ | | `lapack_chegvd` | `eigh_lapack_chegvd` |
10
+ | | `lapack_zhegvd` | `eigh_lapack_zhegvd` |
11
+ | | `lapack_*_ffi` variants | `eigh_lapack_*_ffi` |
12
+ | `solver.cc` | `cusolver_sygvd_ffi` | `eigh_cusolver_sygvd_ffi` |
13
+ | `_core.py` | Updated `prepare_lapack_call` and CUDA target name | - |
14
+
15
+ The FFI target names are just string identifiers used to look up the registered function pointers.
16
+ The flow is:
17
+
18
+ 1. C++ registers function pointer with name `"eigh_lapack_ssygvd_ffi"`
19
+ 2. Python calls `ffi.register_ffi_target("eigh_lapack_ssygvd_ffi", ...)`
20
+ 3. During JIT lowering, Python requests target `"eigh_lapack_ssygvd_ffi"`
21
+ 4. XLA looks up the function pointer by that name
22
+
23
+ As long as both sides agree on the name (which they now do), everything works. The actual kernel code is unchanged.
24
+
25
+ # To Apply
26
+
27
+ You need to rebuild the package since the C++ files changed:
28
+
29
+ ```bash
30
+ pip install -e .
31
+ ```
32
+
33
+ After rebuilding, both JAX's standard `eigh` and this package's `eigh` can coexist without conflicts.
eigh-0.1.0/CMakeLists.txt ADDED
@@ -0,0 +1,150 @@
1
+ cmake_minimum_required(VERSION 3.18)
2
+ project(eigh_standalone LANGUAGES CXX)
3
+
4
+ set(CMAKE_CXX_STANDARD 17)
5
+ set(CMAKE_CXX_STANDARD_REQUIRED ON)
6
+ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
7
+
8
+ # Fix for NVHPC compiler compatibility with nanobind
9
+ if(CMAKE_CXX_COMPILER_ID MATCHES "NVHPC|PGI")
10
+ add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)
11
+ add_compile_definitions(__GXX_ABI_VERSION=1016)
12
+ endif()
13
+
14
+ # Find required packages
15
+ # Use Development.Module instead of Development - manylinux containers
16
+ # don't have full Python development files (libpython), but extension
17
+ # modules don't need them anyway
18
+ find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)
19
+
20
+ # Find nanobind
21
+ execute_process(
22
+ COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
23
+ OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR
24
+ )
25
+ list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
26
+ find_package(nanobind CONFIG REQUIRED)
27
+
28
+ # Find BLAS and LAPACK
29
+ # If NVHPC SDK paths are in environment, set hints for CMake
30
+ set(NVHPC_LIB_PATH "/softs/nvidia/hpc_sdk/Linux_x86_64/22.1/compilers/lib")
31
+ if(EXISTS "${NVHPC_LIB_PATH}/libblas.so")
32
+ set(BLA_VENDOR "NVHPC")
33
+ set(BLAS_LIBRARIES "${NVHPC_LIB_PATH}/libblas.so")
34
+ set(LAPACK_LIBRARIES "${NVHPC_LIB_PATH}/liblapack.so")
35
+ set(BLAS_FOUND TRUE)
36
+ set(LAPACK_FOUND TRUE)
37
+ message(STATUS "Using NVHPC BLAS: ${BLAS_LIBRARIES}")
38
+ message(STATUS "Using NVHPC LAPACK: ${LAPACK_LIBRARIES}")
39
+ else()
40
+ find_package(BLAS REQUIRED)
41
+ find_package(LAPACK REQUIRED)
42
+ endif()
43
+
44
+ # XLA headers (from jaxlib)
45
+ execute_process(
46
+ COMMAND "${Python_EXECUTABLE}" -c "import jaxlib; import os; print(os.path.join(jaxlib.__path__[0], 'include'))"
47
+ OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_INCLUDE_DIR
48
+ RESULT_VARIABLE XLA_RESULT
49
+ )
50
+
51
+ if(NOT XLA_RESULT EQUAL 0 OR NOT EXISTS "${XLA_INCLUDE_DIR}")
52
+ message(FATAL_ERROR "Could not find XLA headers. Make sure jaxlib is installed: pip install jaxlib")
53
+ endif()
54
+
55
+ message(STATUS "XLA headers found at: ${XLA_INCLUDE_DIR}")
56
+
57
+ # Include directories
58
+ include_directories(
59
+ ${CMAKE_SOURCE_DIR}/include
60
+ ${CMAKE_SOURCE_DIR}/src/cpu
61
+ ${CMAKE_SOURCE_DIR}/src/cuda
62
+ ${XLA_INCLUDE_DIR}
63
+ )
64
+
65
+ # CPU LAPACK module
66
+ nanobind_add_module(
67
+ eigh_lapack
68
+ STABLE_ABI
69
+ src/cpu/lapack.cc
70
+ src/cpu/lapack_kernels.cc
71
+ )
72
+ target_link_libraries(eigh_lapack PRIVATE ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES})
73
+
74
+ # Add NVHPC Fortran runtime when linking against NVHPC BLAS/LAPACK
75
+ if(EXISTS "${NVHPC_LIB_PATH}/libnvf.so")
76
+ find_library(NVF_LIB NAMES nvf PATHS "${NVHPC_LIB_PATH}" NO_DEFAULT_PATH)
77
+ if(NVF_LIB)
78
+ target_link_libraries(eigh_lapack PRIVATE ${NVF_LIB} rt)
79
+ message(STATUS "Found NVHPC Fortran runtime: ${NVF_LIB}")
80
+ endif()
81
+ endif()
82
+
83
+ set_target_properties(eigh_lapack PROPERTIES
84
+ LIBRARY_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/src/python/eigh
85
+ )
86
+
87
+ # CUDA module (optional - only if CUDA is available)
88
+ include(CheckLanguage)
89
+ check_language(CUDA)
90
+
91
+ if(CMAKE_CUDA_COMPILER)
92
+ enable_language(CUDA)
93
+ set(CMAKE_CUDA_STANDARD 17)
94
+ set(CMAKE_CUDA_STANDARD_REQUIRED ON)
95
+
96
+ find_package(CUDAToolkit REQUIRED)
97
+
98
+ nanobind_add_module(
99
+ eigh_cuda
100
+ STABLE_ABI
101
+ src/cuda/solver.cc
102
+ src/cuda/solver_kernels.cc
103
+ )
104
+
105
+ # Add CUDA include directories
106
+ target_include_directories(eigh_cuda PRIVATE
107
+ ${CUDAToolkit_INCLUDE_DIRS}
108
+ )
109
+
110
+ set_source_files_properties(
111
+ src/cuda/solver_kernels.cc
112
+ PROPERTIES LANGUAGE CUDA
113
+ )
114
+
115
+ # Link CUDA libraries - handle both imported targets and direct paths
116
+ if(TARGET CUDA::cudart AND TARGET CUDA::cusolver AND TARGET CUDA::cublas)
117
+ target_link_libraries(eigh_cuda PRIVATE
118
+ CUDA::cudart
119
+ CUDA::cusolver
120
+ CUDA::cublas
121
+ )
122
+ else()
123
+ # Fallback for NVHPC SDK where imported targets may not be available
124
+ target_link_libraries(eigh_cuda PRIVATE
125
+ ${CUDAToolkit_LIBRARY_DIR}/libcudart.so
126
+ ${CUDAToolkit_LIBRARY_DIR}/libcusolver.so
127
+ ${CUDAToolkit_LIBRARY_DIR}/libcublas.so
128
+ )
129
+ endif()
130
+
131
+ set_target_properties(eigh_cuda PROPERTIES
132
+ LIBRARY_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/src/python/eigh
133
+ CUDA_SEPARABLE_COMPILATION ON
134
+ )
135
+
136
+ message(STATUS "CUDA support enabled")
137
+ else()
138
+ message(STATUS "CUDA not found - GPU support will be disabled")
139
+ endif()
140
+
141
+ # Install target - scikit-build-core will handle the final install location
142
+ install(TARGETS eigh_lapack
143
+ LIBRARY DESTINATION eigh
144
+ )
145
+
146
+ if(TARGET eigh_cuda)
147
+ install(TARGETS eigh_cuda
148
+ LIBRARY DESTINATION eigh
149
+ )
150
+ endif()
eigh-0.1.0/MANIFEST.in ADDED
@@ -0,0 +1,13 @@
1
+ # Include source files for building
2
+ recursive-include include *.h
3
+ recursive-include src/cpu *.cc *.h
4
+ recursive-include src/cuda *.cc *.h
5
+ include CMakeLists.txt
6
+ include README.md
7
+ include requirements.txt
8
+
9
+ # Exclude build artifacts
10
+ global-exclude *.so
11
+ global-exclude *.dylib
12
+ global-exclude __pycache__
13
+ global-exclude *.pyc
eigh-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,104 @@
1
+ Metadata-Version: 2.2
2
+ Name: eigh
3
+ Version: 0.1.0
4
+ Summary: Differentiable eigenvalue decomposition with JAX (CPU/GPU)
5
+ Keywords: jax,eigenvalue,linear-algebra,gpu,cuda,autodiff
6
+ Author-Email: Xing Zhang <fishjojo@gmail.com>
7
+ License: Apache-2.0
8
+ Classifier: Development Status :: 4 - Beta
9
+ Classifier: Intended Audience :: Science/Research
10
+ Classifier: License :: OSI Approved :: Apache Software License
11
+ Classifier: Operating System :: POSIX :: Linux
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.9
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Topic :: Scientific/Engineering
18
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
19
+ Project-URL: Homepage, https://github.com/fishjojo/pyscfad
20
+ Project-URL: Repository, https://github.com/fishjojo/pyscfad
21
+ Requires-Python: >=3.10
22
+ Requires-Dist: jax>=0.4.0
23
+ Requires-Dist: jaxlib>=0.4.0
24
+ Requires-Dist: numpy>=1.20.0
25
+ Provides-Extra: cuda
26
+ Requires-Dist: jax[cuda12]>=0.4.0; extra == "cuda"
27
+ Provides-Extra: cuda-local
28
+ Requires-Dist: jax[cuda12-local]>=0.4.0; extra == "cuda-local"
29
+ Provides-Extra: test
30
+ Requires-Dist: pytest>=7.0.0; extra == "test"
31
+ Requires-Dist: scipy>=1.7.0; extra == "test"
32
+ Provides-Extra: dev
33
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
34
+ Requires-Dist: scipy>=1.7.0; extra == "dev"
35
+ Description-Content-Type: text/markdown
36
+
37
+ # Differentiable Generalized Eigenvalue Decomposition
38
+
39
+ <img src="img/eig.png" alt="Eigh Logo" width="400">
40
+
41
+ Standalone implementation of differentiable eigenvalue decomposition with CPU (LAPACK) and GPU (cuSOLVER) backends. Extracted from [pyscfad](https://github.com/fishjojo/pyscfad).
42
+
43
+ ## Features
44
+ - **Generalized Problems**: `A @ V = B @ V @ diag(W)`, etc.
45
+ - **JAX Integrated**: Full support for `jit`, `vmap`, `grad`, and `jvp`.
46
+ - **High Performance**: Optimized LAPACK (CPU) and cuSOLVER (GPU) kernels.
47
+ - **Precision**: `float32/64` and `complex64/128`.
48
+ - **Degeneracy Handling**: Configurable `deg_thresh` for stable gradients.
49
+
50
+ ## Installation & Quick Start
51
+
52
+ ```bash
53
+ # Install from source
54
+ pip install .
55
+
56
+ # For GPU support using a locally installed CUDA 12 toolkit (jax[cuda12-local])
57
+ pip install ".[cuda-local]"
58
+ ```
59
+
60
+ ### Usage Example
61
+ ```python
62
+ import jax
63
+ import jax.numpy as jnp
64
+ from eigh import eigh
65
+
66
+ jax.config.update("jax_enable_x64", True)
67
+ A = jnp.array([[2., 1.], [1., 2.]])
68
+ w, v = eigh(A) # Standard
69
+ grad = jax.grad(lambda A: eigh(A)[0].sum())(A) # Differentiable
70
+ ```
71
+
72
+ ## API Reference
73
+ - **`eigh(a, b=None, *, lower=True, eigvals_only=False, type=1, deg_thresh=1e-9)`**
74
+ Scipy-compatible interface. `type` supports 1: `A@v=B@v@λ`, 2: `A@B@v=v@λ`, 3: `B@A@v=v@λ`.
75
+ - **`eigh_gen(a, b, *, lower=True, itype=1, deg_thresh=1e-9)`**
76
+ Lower-level generalized solver.
77
+
78
+ ## Degenerate Eigenvalues & Gradients
79
+ Individual eigenvalue gradients are ill-defined for degenerate (repeated) eigenvalues. However, symmetric functions (like `sum`, `var`, `trace`) have stable gradients. The `deg_thresh` parameter (default `1e-9`) masks divisions by near-zero gaps to maintain stability.
80
+
81
+ ## Development & Testing
82
+ - **Requirements**: CMake 3.18+, a C++17 compiler, nanobind, JAX/jaxlib, NumPy, BLAS/LAPACK; a CUDA toolkit is optional (GPU support).
83
+ - **Tests**:
84
+ ```bash
85
+ pytest tests/test_eigh.py # Core functionality
86
+ pytest tests/test_eigh_gen.py # Generalized itypes
87
+ pytest tests/test_eigh_jit.py # JIT & vmap
88
+ ```
89
+ - **GPU Setup**:
90
+ ```bash
91
+ source setup_gpu_env_clean.sh
92
+ ./run_gpu.sh python example_simple.py
93
+ ```
94
+
95
+ ## License & Citation
96
+ Apache License 2.0. If used in research, please cite:
97
+ ```bibtex
98
+ @software{pyscfad,
99
+ author = {Zhang, Xing},
100
+ title = {PySCFad: Automatic Differentiation for PySCF},
101
+ url = {https://github.com/fishjojo/pyscfad},
102
+ year = {2021-2025}
103
+ }
104
+ ```
eigh-0.1.0/README.md ADDED
@@ -0,0 +1,68 @@
1
+ # Differentiable Generalized Eigenvalue Decomposition
2
+
3
+ <img src="img/eig.png" alt="Eigh Logo" width="400">
4
+
5
+ Standalone implementation of differentiable eigenvalue decomposition with CPU (LAPACK) and GPU (cuSOLVER) backends. Extracted from [pyscfad](https://github.com/fishjojo/pyscfad).
6
+
7
+ ## Features
8
+ - **Generalized Problems**: `A @ V = B @ V @ diag(W)`, etc.
9
+ - **JAX Integrated**: Full support for `jit`, `vmap`, `grad`, and `jvp`.
10
+ - **High Performance**: Optimized LAPACK (CPU) and cuSOLVER (GPU) kernels.
11
+ - **Precision**: `float32/64` and `complex64/128`.
12
+ - **Degeneracy Handling**: Configurable `deg_thresh` for stable gradients.
13
+
14
+ ## Installation & Quick Start
15
+
16
+ ```bash
17
+ # Install from source
18
+ pip install .
19
+
20
+ # For GPU support using a locally installed CUDA 12 toolkit (jax[cuda12-local])
20
+ pip install ".[cuda-local]"
22
+ ```
23
+
24
+ ### Usage Example
25
+ ```python
26
+ import jax
27
+ import jax.numpy as jnp
28
+ from eigh import eigh
29
+
30
+ jax.config.update("jax_enable_x64", True)
31
+ A = jnp.array([[2., 1.], [1., 2.]])
32
+ w, v = eigh(A) # Standard
33
+ grad = jax.grad(lambda A: eigh(A)[0].sum())(A) # Differentiable
34
+ ```
35
+
36
+ ## API Reference
37
+ - **`eigh(a, b=None, *, lower=True, eigvals_only=False, type=1, deg_thresh=1e-9)`**
38
+ Scipy-compatible interface. `type` supports 1: `A@v=B@v@λ`, 2: `A@B@v=v@λ`, 3: `B@A@v=v@λ`.
39
+ - **`eigh_gen(a, b, *, lower=True, itype=1, deg_thresh=1e-9)`**
40
+ Lower-level generalized solver.
41
+
42
+ ## Degenerate Eigenvalues & Gradients
43
+ Individual eigenvalue gradients are ill-defined for degenerate (repeated) eigenvalues. However, symmetric functions of the eigenvalues (e.g. their `sum` or `var`, or the matrix `trace`) have stable gradients. The `deg_thresh` parameter (default `1e-9`) masks divisions by near-zero eigenvalue gaps to maintain stability.
44
+
45
+ ## Development & Testing
46
+ - **Requirements**: CMake 3.18+, a C++17 compiler, nanobind, JAX/jaxlib, NumPy, BLAS/LAPACK; a CUDA toolkit is optional (GPU support).
47
+ - **Tests**:
48
+ ```bash
49
+ pytest tests/test_eigh.py # Core functionality
50
+ pytest tests/test_eigh_gen.py # Generalized itypes
51
+ pytest tests/test_eigh_jit.py # JIT & vmap
52
+ ```
53
+ - **GPU Setup**:
54
+ ```bash
55
+ source setup_gpu_env_clean.sh
56
+ ./run_gpu.sh python example_simple.py
57
+ ```
58
+
59
+ ## License & Citation
60
+ Apache License 2.0. If used in research, please cite:
61
+ ```bibtex
62
+ @software{pyscfad,
63
+ author = {Zhang, Xing},
64
+ title = {PySCFad: Automatic Differentiation for PySCF},
65
+ url = {https://github.com/fishjojo/pyscfad},
66
+ year = {2021-2025}
67
+ }
68
+ ```
eigh-0.1.0/build.sh ADDED
@@ -0,0 +1,80 @@
1
+ #!/bin/bash
2
+ # Build script for eigh_standalone
3
+
4
+ set -e
5
+
6
+ echo "==================================="
7
+ echo "Building eigh_standalone"
8
+ echo "==================================="
9
+
10
+ # Detect Python (prefer .venv if it exists)
11
+ if [ -f "../.venv/bin/python" ]; then
12
+ PYTHON="../.venv/bin/python"
13
+ echo "Using virtual environment Python"
14
+ elif [ -f ".venv/bin/python" ]; then
15
+ PYTHON=".venv/bin/python"
16
+ echo "Using virtual environment Python"
17
+ elif command -v python &> /dev/null; then
18
+ PYTHON="python"
19
+ else
20
+ echo "Error: Python not found"
21
+ exit 1
22
+ fi
23
+
24
+ echo "Python: $($PYTHON --version)"
25
+
26
+ # Check CMake
27
+ if ! command -v cmake &> /dev/null; then
28
+ echo "Error: CMake not found. Please install CMake 3.18+"
29
+ exit 1
30
+ fi
31
+
32
+ echo "CMake: $(cmake --version | head -n1)"
33
+
34
+ # Check CUDA (optional)
35
+ if command -v nvcc &> /dev/null; then
36
+ echo "CUDA: $(nvcc --version | grep release)"
37
+ CUDA_AVAILABLE=1
38
+ else
39
+ echo "CUDA: Not found (GPU support disabled)"
40
+ CUDA_AVAILABLE=0
41
+ fi
42
+
43
+ # Install Python dependencies
44
+ echo ""
45
+ echo "Installing Python dependencies..."
46
+ $PYTHON -m pip install -r requirements.txt
47
+
48
+ # Create build directory
49
+ echo ""
50
+ echo "Creating build directory..."
51
+ rm -rf build
52
+ mkdir build
53
+ cd build
54
+
55
+ # Configure
56
+ echo ""
57
+ echo "Configuring CMake..."
58
+ if [ $CUDA_AVAILABLE -eq 1 ]; then
59
+ cmake -DPython_EXECUTABLE=$PYTHON ..
60
+ else
61
+ cmake -DPython_EXECUTABLE=$PYTHON -DCMAKE_DISABLE_FIND_PACKAGE_CUDA=TRUE ..
62
+ fi
63
+
64
+ # Build
65
+ echo ""
66
+ echo "Building..."
67
+ make -j$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
68
+
69
+ # Check output
70
+ echo ""
71
+ echo "==================================="
72
+ echo "Build complete!"
73
+ echo "==================================="
74
+ echo ""
75
+ echo "Built modules:"
76
+ ls -lh ../src/python/*.so 2>/dev/null || ls -lh ../src/python/*.dylib 2>/dev/null || echo "No modules found"
77
+
78
+ echo ""
79
+ echo "To test, run:"
80
+ echo " python tests/test_eigh.py"
eigh-0.1.0/debug/diag_jax.py ADDED
@@ -0,0 +1,29 @@
1
+ import os
2
+ import sys
3
+
4
+ print(f"Python version: {sys.version}")
5
+ print(f"LD_LIBRARY_PATH: {os.environ.get('LD_LIBRARY_PATH', 'Not set')}")
6
+
7
+ try:
8
+ import jax
9
+ print(f"JAX version: {jax.__version__}")
10
+
11
+ # Try to get devices
12
+ try:
13
+ devices = jax.devices()
14
+ print(f"Devices: {devices}")
15
+ except Exception as dev_err:
16
+ print(f"Error getting devices: {dev_err}")
17
+
18
+ # Check backends
19
+ try:
20
+ import jax.extend as jex
21
+ backend = jex.backend.get_backend()
22
+ print(f"Default backend: {backend.platform}")
23
+ except Exception as back_err:
24
+ print(f"Error getting backend: {back_err}")
25
+
26
+ except Exception as e:
27
+ print(f"Main Error: {e}")
28
+ import traceback
29
+ traceback.print_exc()