cuquantum-python-jax-cu12 0.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuquantum_python_jax_cu12-0.0.5/LICENSE +28 -0
- cuquantum_python_jax_cu12-0.0.5/MANIFEST.in +9 -0
- cuquantum_python_jax_cu12-0.0.5/PKG-INFO +113 -0
- cuquantum_python_jax_cu12-0.0.5/README.md +87 -0
- cuquantum_python_jax_cu12-0.0.5/configure.sh +122 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/__init__.py +18 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/cppsrc/CMakeLists.txt +68 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/cppsrc/cudensitymat.h +3109 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/cppsrc/cudensitymat_jax.cpp +437 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/cppsrc/cudensitymat_jax.h +30 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/cppsrc/pybind.cpp +51 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/cppsrc/utils.h +31 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/operator_action.py +427 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/__init__.py +8 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/base.py +98 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/context.py +332 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/elementary_operator.py +245 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/matrix_operator.py +194 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/operator.py +347 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/operator_action_prim.py +549 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/operator_term.py +473 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/utils.py +396 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum_python_jax_cu12.egg-info/PKG-INFO +113 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum_python_jax_cu12.egg-info/SOURCES.txt +30 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum_python_jax_cu12.egg-info/dependency_links.txt +1 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum_python_jax_cu12.egg-info/not-zip-safe +1 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum_python_jax_cu12.egg-info/requires.txt +3 -0
- cuquantum_python_jax_cu12-0.0.5/cuquantum_python_jax_cu12.egg-info/top_level.txt +2 -0
- cuquantum_python_jax_cu12-0.0.5/pyproject.toml +49 -0
- cuquantum_python_jax_cu12-0.0.5/pyproject.toml.template +49 -0
- cuquantum_python_jax_cu12-0.0.5/setup.cfg +4 -0
- cuquantum_python_jax_cu12-0.0.5/setup.py +117 -0
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
|
|
3
|
+
BSD-3-Clause
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without
|
|
6
|
+
modification, are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation
|
|
13
|
+
and/or other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its
|
|
16
|
+
contributors may be used to endorse or promote products derived from
|
|
17
|
+
this software without specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
20
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
21
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
23
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
24
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
25
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
26
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
27
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
28
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
include pyproject.toml.template
|
|
2
|
+
include configure.sh
|
|
3
|
+
include cuquantum/densitymat/jax/cppsrc/CMakeLists.txt
|
|
4
|
+
include cuquantum/densitymat/jax/cppsrc/cudensitymat.h
|
|
5
|
+
include cuquantum/densitymat/jax/cppsrc/cudensitymat_jax.cpp
|
|
6
|
+
include cuquantum/densitymat/jax/cppsrc/cudensitymat_jax.h
|
|
7
|
+
include cuquantum/densitymat/jax/cppsrc/pybind.cpp
|
|
8
|
+
include cuquantum/densitymat/jax/cppsrc/utils.h
|
|
9
|
+
prune tests*
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: cuquantum-python-jax-cu12
|
|
3
|
+
Version: 0.0.5
|
|
4
|
+
Summary: NVIDIA cuQuantum Python JAX
|
|
5
|
+
Author-email: NVIDIA Corporation <cuquantum-python@nvidia.com>
|
|
6
|
+
License-Expression: BSD-3-Clause
|
|
7
|
+
Project-URL: Homepage, https://developer.nvidia.com/cuquantum-sdk
|
|
8
|
+
Classifier: Development Status :: 5 - Production/Stable
|
|
9
|
+
Classifier: Operating System :: POSIX :: Linux
|
|
10
|
+
Classifier: Topic :: Education
|
|
11
|
+
Classifier: Topic :: Scientific/Engineering
|
|
12
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
17
|
+
Classifier: Environment :: GPU :: NVIDIA CUDA
|
|
18
|
+
Classifier: Environment :: GPU :: NVIDIA CUDA :: 12
|
|
19
|
+
Requires-Python: >=3.11.0
|
|
20
|
+
Description-Content-Type: text/markdown
|
|
21
|
+
License-File: LICENSE
|
|
22
|
+
Requires-Dist: pybind11
|
|
23
|
+
Requires-Dist: cuquantum-python-cu12~=26.1.0
|
|
24
|
+
Requires-Dist: jax[cuda12-local]<0.7,>=0.5
|
|
25
|
+
Dynamic: license-file
|
|
26
|
+
|
|
27
|
+
# cuQuantum Python JAX
|
|
28
|
+
|
|
29
|
+
cuQuantum Python JAX provides a JAX extension for cuQuantum Python. It exposes selected functionality of cuQuantum SDK in a JAX-compatible way that enables JAX frameworks to directly interface with the exposed cuQuantum API. In the current release, cuQuantum JAX exposes a JAX interface to the Operator Action API from the cuDensityMat library.
|
|
30
|
+
|
|
31
|
+
## Documentation
|
|
32
|
+
|
|
33
|
+
Please visit the [NVIDIA cuQuantum Python documentation](https://docs.nvidia.com/cuda/cuquantum/latest/python).
|
|
34
|
+
|
|
35
|
+
## Building and installing cuQuantum Python JAX
|
|
36
|
+
|
|
37
|
+
### Requirements
|
|
38
|
+
|
|
39
|
+
The install-time dependencies of the cuQuantum Python JAX package include:
|
|
40
|
+
|
|
41
|
+
* cuquantum-python-cu12~=26.1.0 for CUDA 12 or cuquantum-python-cu13~=26.1.0 for CUDA 13
|
|
42
|
+
* jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
|
|
43
|
+
* pybind11
|
|
44
|
+
* setuptools>=77.0.3
|
|
45
|
+
|
|
46
|
+
Note:
|
|
47
|
+
1. cuQuantum Python JAX is only supported with CUDA 12 and CUDA 13.
|
|
48
|
+
2. cuQuantum Python JAX installation does not support build isolation. The user needs to pass in `--no-build-isolation` to `pip` when installing cuQuantum Python JAX.
|
|
49
|
+
3. cuQuantum Python JAX wheels are CUDA-versioned: `cuquantum-python-jax-cu12` for CUDA 12 and `cuquantum-python-jax-cu13` for CUDA 13.
|
|
50
|
+
|
|
51
|
+
#### Installation using `jax[cudaXX-local]`
|
|
52
|
+
|
|
53
|
+
`cuquantum-python-jax-cu12` (or `cuquantum-python-jax-cu13`) depends explicitly on `jax[cudaXX-local]`. Installing the package will also install `jax[cudaXX-local]`.
|
|
54
|
+
|
|
55
|
+
Using `jax[cudaXX-local]` assumes the user provides both cuDNN and the CUDA Toolkit. cuDNN is not a part of the CUDA Toolkit and requires an additional installation. The user must also specify `LD_LIBRARY_PATH`, including the library folders containing `libcudnn.so` and `libcupti.so`.
|
|
56
|
+
|
|
57
|
+
`libcupti.so` is provided by the CUDA Toolkit. If the CUDA Toolkit is installed under `/usr/local/cuda`, `libcupti.so` is located under `/usr/local/cuda/extras/CUPTI/lib64` and `LD_LIBRARY_PATH` should contain this path.
|
|
58
|
+
|
|
59
|
+
`libcudnn.so` is installed separately from the CUDA Toolkit. The default installation location is `/usr/local/cuda/lib64`, and `LD_LIBRARY_PATH` should contain this path.
|
|
60
|
+
|
|
61
|
+
Both `libcudnn.so` and `libcupti.so` are installable with pip:
|
|
62
|
+
|
|
63
|
+
```
|
|
64
|
+
pip install nvidia-cudnn-cu12
|
|
65
|
+
pip install nvidia-cuda-cupti-cu12
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
After installing cuDNN and cuPTI, the user may install cuQuantum Python JAX with `pip` using either:
|
|
69
|
+
|
|
70
|
+
```
|
|
71
|
+
pip install --no-build-isolation cuquantum-python-jax-cu12 # for CUDA 12
|
|
72
|
+
pip install --no-build-isolation cuquantum-python-jax-cu13 # for CUDA 13
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
or one of
|
|
76
|
+
|
|
77
|
+
```
|
|
78
|
+
pip install --no-build-isolation cuquantum-python-cu12[jax]
|
|
79
|
+
pip install --no-build-isolation cuquantum-python-cu13[jax]
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
where the CUDA version is explicitly specified in the `cuquantum-python` package name.
|
|
83
|
+
|
|
84
|
+
Note:
|
|
85
|
+
1. If cuDNN and cuPTI are installed with `pip`, the user does not need to specify library folders in `LD_LIBRARY_PATH`.
|
|
86
|
+
2. When the latter command `pip install --no-build-isolation cuquantum-python-cu12[jax]`/`pip install --no-build-isolation cuquantum-python-cu13[jax]` is used, `--no-build-isolation` applies to both cuquantum-python and cuquantum-python-jax. The user needs to ensure cuquantum-python's build dependencies are installed before the installation.
|
|
87
|
+
|
|
88
|
+
#### Installing from source
|
|
89
|
+
|
|
90
|
+
To install cuQuantum Python JAX from source, first compile cuQuantum Python from source using the [instructions on GitHub](https://github.com/NVIDIA/cuQuantum/blob/main/python/README.md). Once complete, navigate to `python/extensions`, then:
|
|
91
|
+
|
|
92
|
+
```
|
|
93
|
+
pip install .
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
The CUDA version is detected automatically from `$CUDA_PATH` and the wheel will be named accordingly (`cuquantum-python-jax-cu12` or `cuquantum-python-jax-cu13`).
|
|
97
|
+
|
|
98
|
+
## Running
|
|
99
|
+
|
|
100
|
+
### Requirements
|
|
101
|
+
|
|
102
|
+
Runtime dependencies of the cuQuantum Python JAX package include:
|
|
103
|
+
|
|
104
|
+
* An NVIDIA GPU with compute capability 7.5+
|
|
105
|
+
* cuquantum-python-cu12~=26.1.0 for CUDA 12 or cuquantum-python-cu13~=26.1.0 for CUDA 13
|
|
106
|
+
* jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
|
|
107
|
+
* pybind11
|
|
108
|
+
|
|
109
|
+
## Developer Notes
|
|
110
|
+
|
|
111
|
+
* cuQuantum Python JAX does not support editable installation.
|
|
112
|
+
* Both cuQuantum Python and cuQuantum Python JAX need to be installed into `site-packages` for proper import of the library.
|
|
113
|
+
* cuQuantum Python JAX assumes cuQuantum Python will be available under the current `site-packages` directory.
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# cuQuantum Python JAX
|
|
2
|
+
|
|
3
|
+
cuQuantum Python JAX provides a JAX extension for cuQuantum Python. It exposes selected functionality of cuQuantum SDK in a JAX-compatible way that enables JAX frameworks to directly interface with the exposed cuQuantum API. In the current release, cuQuantum JAX exposes a JAX interface to the Operator Action API from the cuDensityMat library.
|
|
4
|
+
|
|
5
|
+
## Documentation
|
|
6
|
+
|
|
7
|
+
Please visit the [NVIDIA cuQuantum Python documentation](https://docs.nvidia.com/cuda/cuquantum/latest/python).
|
|
8
|
+
|
|
9
|
+
## Building and installing cuQuantum Python JAX
|
|
10
|
+
|
|
11
|
+
### Requirements
|
|
12
|
+
|
|
13
|
+
The install-time dependencies of the cuQuantum Python JAX package include:
|
|
14
|
+
|
|
15
|
+
* cuquantum-python-cu12~=26.1.0 for CUDA 12 or cuquantum-python-cu13~=26.1.0 for CUDA 13
|
|
16
|
+
* jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
|
|
17
|
+
* pybind11
|
|
18
|
+
* setuptools>=77.0.3
|
|
19
|
+
|
|
20
|
+
Note:
|
|
21
|
+
1. cuQuantum Python JAX is only supported with CUDA 12 and CUDA 13.
|
|
22
|
+
2. cuQuantum Python JAX installation does not support build isolation. The user needs to pass in `--no-build-isolation` to `pip` when installing cuQuantum Python JAX.
|
|
23
|
+
3. cuQuantum Python JAX wheels are CUDA-versioned: `cuquantum-python-jax-cu12` for CUDA 12 and `cuquantum-python-jax-cu13` for CUDA 13.
|
|
24
|
+
|
|
25
|
+
#### Installation using `jax[cudaXX-local]`
|
|
26
|
+
|
|
27
|
+
`cuquantum-python-jax-cu12` (or `cuquantum-python-jax-cu13`) depends explicitly on `jax[cudaXX-local]`. Installing the package will also install `jax[cudaXX-local]`.
|
|
28
|
+
|
|
29
|
+
Using `jax[cudaXX-local]` assumes the user provides both cuDNN and the CUDA Toolkit. cuDNN is not a part of the CUDA Toolkit and requires an additional installation. The user must also specify `LD_LIBRARY_PATH`, including the library folders containing `libcudnn.so` and `libcupti.so`.
|
|
30
|
+
|
|
31
|
+
`libcupti.so` is provided by the CUDA Toolkit. If the CUDA Toolkit is installed under `/usr/local/cuda`, `libcupti.so` is located under `/usr/local/cuda/extras/CUPTI/lib64` and `LD_LIBRARY_PATH` should contain this path.
|
|
32
|
+
|
|
33
|
+
`libcudnn.so` is installed separately from the CUDA Toolkit. The default installation location is `/usr/local/cuda/lib64`, and `LD_LIBRARY_PATH` should contain this path.
|
|
34
|
+
|
|
35
|
+
Both `libcudnn.so` and `libcupti.so` are installable with pip:
|
|
36
|
+
|
|
37
|
+
```
|
|
38
|
+
pip install nvidia-cudnn-cu12
|
|
39
|
+
pip install nvidia-cuda-cupti-cu12
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
After installing cuDNN and cuPTI, the user may install cuQuantum Python JAX with `pip` using either:
|
|
43
|
+
|
|
44
|
+
```
|
|
45
|
+
pip install --no-build-isolation cuquantum-python-jax-cu12 # for CUDA 12
|
|
46
|
+
pip install --no-build-isolation cuquantum-python-jax-cu13 # for CUDA 13
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
or one of
|
|
50
|
+
|
|
51
|
+
```
|
|
52
|
+
pip install --no-build-isolation cuquantum-python-cu12[jax]
|
|
53
|
+
pip install --no-build-isolation cuquantum-python-cu13[jax]
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
where the CUDA version is explicitly specified in the `cuquantum-python` package name.
|
|
57
|
+
|
|
58
|
+
Note:
|
|
59
|
+
1. If cuDNN and cuPTI are installed with `pip`, the user does not need to specify library folders in `LD_LIBRARY_PATH`.
|
|
60
|
+
2. When the latter command `pip install --no-build-isolation cuquantum-python-cu12[jax]`/`pip install --no-build-isolation cuquantum-python-cu13[jax]` is used, `--no-build-isolation` applies to both cuquantum-python and cuquantum-python-jax. The user needs to ensure cuquantum-python's build dependencies are installed before the installation.
|
|
61
|
+
|
|
62
|
+
#### Installing from source
|
|
63
|
+
|
|
64
|
+
To install cuQuantum Python JAX from source, first compile cuQuantum Python from source using the [instructions on GitHub](https://github.com/NVIDIA/cuQuantum/blob/main/python/README.md). Once complete, navigate to `python/extensions`, then:
|
|
65
|
+
|
|
66
|
+
```
|
|
67
|
+
pip install .
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
The CUDA version is detected automatically from `$CUDA_PATH` and the wheel will be named accordingly (`cuquantum-python-jax-cu12` or `cuquantum-python-jax-cu13`).
|
|
71
|
+
|
|
72
|
+
## Running
|
|
73
|
+
|
|
74
|
+
### Requirements
|
|
75
|
+
|
|
76
|
+
Runtime dependencies of the cuQuantum Python JAX package include:
|
|
77
|
+
|
|
78
|
+
* An NVIDIA GPU with compute capability 7.5+
|
|
79
|
+
* cuquantum-python-cu12~=26.1.0 for CUDA 12 or cuquantum-python-cu13~=26.1.0 for CUDA 13
|
|
80
|
+
* jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
|
|
81
|
+
* pybind11
|
|
82
|
+
|
|
83
|
+
## Developer Notes
|
|
84
|
+
|
|
85
|
+
* cuQuantum Python JAX does not support editable installation.
|
|
86
|
+
* Both cuQuantum Python and cuQuantum Python JAX need to be installed into `site-packages` for proper import of the library.
|
|
87
|
+
* cuQuantum Python JAX assumes cuQuantum Python will be available under the current `site-packages` directory.
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
#!/usr/bin/env bash
# Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES
#
# SPDX-License-Identifier: BSD-3-Clause

# Generate pyproject.toml from pyproject.toml.template by detecting
# CUDA version and substituting template placeholders.
#
# Usage:
#   CUDA_PATH=/usr/local/cuda bash configure.sh

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
TEMPLATE="${SCRIPT_DIR}/pyproject.toml.template"
OUTPUT="${SCRIPT_DIR}/pyproject.toml"

# -------------------------------------------------------------------
# 1. Validate template exists (early, before doing any real work)
# -------------------------------------------------------------------
if [[ ! -f "${TEMPLATE}" ]]; then
    echo "ERROR: Template not found at ${TEMPLATE}" >&2
    exit 1
fi

# -------------------------------------------------------------------
# 2. Validate CUDA_PATH
# -------------------------------------------------------------------
if [[ -z "${CUDA_PATH:-}" ]]; then
    echo "ERROR: CUDA_PATH is not set. Please set it to your CUDA toolkit root." >&2
    exit 1
fi

CUDA_H="${CUDA_PATH}/include/cuda.h"
if [[ ! -f "${CUDA_H}" ]]; then
    echo "ERROR: Cannot find ${CUDA_H}. Is CUDA_PATH set correctly?" >&2
    exit 1
fi

# -------------------------------------------------------------------
# 3. Parse CUDA_VERSION from cuda.h (mirrors setup.py logic)
#    Example line: #define CUDA_VERSION 12020
#    12020 => major = 12020 / 1000 = 12
# -------------------------------------------------------------------
# Read grep output into a variable to avoid a grep|head|sed pipeline,
# which can behave inconsistently with pipefail across bash versions
# (head closes the pipe early, causing grep to receive SIGPIPE).
CUDA_H_MATCH=$(grep -E -m1 '^\s*#define\s+CUDA_VERSION\s+[0-9]+' "${CUDA_H}" || true)

if [[ -z "${CUDA_H_MATCH}" ]]; then
    echo "ERROR: Could not find CUDA_VERSION definition in ${CUDA_H}" >&2
    exit 1
fi

CUDA_VERSION_RAW=$(echo "${CUDA_H_MATCH}" | sed -E 's/.*#define\s+CUDA_VERSION\s+([0-9]+).*/\1/')

if [[ -z "${CUDA_VERSION_RAW}" ]]; then
    echo "ERROR: Could not parse CUDA_VERSION from ${CUDA_H}" >&2
    exit 1
fi

if (( CUDA_VERSION_RAW < 1000 )); then
    echo "ERROR: CUDA_VERSION ${CUDA_VERSION_RAW} from ${CUDA_H} is unexpectedly small (< 1000)" >&2
    exit 1
fi

CUDA_MAJOR=$(( CUDA_VERSION_RAW / 1000 ))

echo "Detected CUDA_VERSION=${CUDA_VERSION_RAW} (major=${CUDA_MAJOR}) from ${CUDA_H}"

# -------------------------------------------------------------------
# 4. Map CUDA major version to template variable values
# -------------------------------------------------------------------
case "${CUDA_MAJOR}" in
    12)
        JAX_VERSION_SPEC=">=0.5,<0.7"
        CUDA_CLASSIFIER="Environment :: GPU :: NVIDIA CUDA :: 12"
        ;;
    13)
        JAX_VERSION_SPEC=">=0.8,<0.9"
        CUDA_CLASSIFIER="Environment :: GPU :: NVIDIA CUDA :: 13"
        ;;
    *)
        echo "ERROR: Unsupported CUDA major version: ${CUDA_MAJOR}" >&2
        exit 1
        ;;
esac

# -------------------------------------------------------------------
# 5. Escape characters that are special in a sed *replacement* string:
#    backslash, ampersand, and the delimiter used in the substitutions
#    below (|). Escaping the delimiter is required; escaping '/' would
#    produce '\/' in the replacement, which is undefined per POSIX when
#    '/' is not the delimiter.
# -------------------------------------------------------------------
sed_escape() {
    printf '%s' "$1" | sed -e 's/[&\\|]/\\&/g'
}

CUDA_MAJOR_ESC=$(sed_escape "${CUDA_MAJOR}")
JAX_VERSION_SPEC_ESC=$(sed_escape "${JAX_VERSION_SPEC}")
CUDA_CLASSIFIER_ESC=$(sed_escape "${CUDA_CLASSIFIER}")

# -------------------------------------------------------------------
# 6. Generate pyproject.toml via sed substitution
#    Write to a temp file then atomically move into place so a partial
#    write (e.g. Ctrl+C, disk full) never leaves a broken pyproject.toml.
# -------------------------------------------------------------------
TMPFILE=$(mktemp "${OUTPUT}.XXXXXX")
trap 'rm -f "${TMPFILE}"' EXIT

sed -e "s|@CUDA_MAJOR_VER@|${CUDA_MAJOR_ESC}|g" \
    -e "s|@JAX_VERSION_SPEC@|${JAX_VERSION_SPEC_ESC}|g" \
    -e "s|@CUDA_CLASSIFIER@|${CUDA_CLASSIFIER_ESC}|g" \
    -- "${TEMPLATE}" > "${TMPFILE}"

# Skip overwrite if output already exists and is identical (idempotency).
if [[ -f "${OUTPUT}" ]] && cmp -s "${TMPFILE}" "${OUTPUT}"; then
    echo "pyproject.toml is already up to date for CUDA ${CUDA_MAJOR}"
    exit 0
fi

mv -f "${TMPFILE}" "${OUTPUT}"
trap - EXIT

echo "Generated ${OUTPUT} for CUDA ${CUDA_MAJOR}"
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES
#
# SPDX-License-Identifier: BSD-3-Clause

# Import the low-level cuDensityMat bindings and resolve their function
# pointers up front so that libcudensitymat.so is loaded at import time.
from cuquantum.bindings._internal import cudensitymat as _cudm
_cudm._inspect_function_pointers()  # for loading libcudensitymat.so

import jax

# Fail fast at import time: this extension requires JAX's 64-bit mode.
# Note: a plain string is used here — the original had an extraneous
# f-prefix with no placeholders.
if not jax.config.jax_enable_x64:
    raise RuntimeError("jax_enable_x64 must be set to True to use cuQuantum Python JAX")

# Public API re-exports.
from .operator_action import operator_action
from .pysrc import (
    ElementaryOperator,
    MatrixOperator,
    OperatorTerm,
    Operator
)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES
#
# SPDX-License-Identifier: BSD-3-Clause

# Builds the cudensitymat_jax Python extension module, which bridges the
# cuDensityMat Operator Action API into JAX via the XLA FFI.

cmake_minimum_required(VERSION 3.22)
project(cudensitymat_jax LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
message(STATUS "Python executable: ${Python3_EXECUTABLE}")

find_package(CUDAToolkit REQUIRED)
message(STATUS "CUDA toolkit directory: ${CUDAToolkit_INCLUDE_DIRS}")

# Find the XLA FFI headers shipped inside the installed jax package.
# Check the interpreter's exit status as well as the output: an import
# failure must not silently fall through with an empty path.
execute_process(
    COMMAND ${Python3_EXECUTABLE} -c "import jax; print(jax.ffi.include_dir())"
    RESULT_VARIABLE _xla_dir_result
    OUTPUT_STRIP_TRAILING_WHITESPACE
    OUTPUT_VARIABLE XLA_DIR
)
if(NOT _xla_dir_result EQUAL 0 OR NOT XLA_DIR)
    message(FATAL_ERROR "XLA directory not found")
else()
    message(STATUS "XLA directory: ${XLA_DIR}")
endif()

# Locate pybind11's CMake package directory via its official Python API
# (pybind11.get_cmake_dir()) instead of guessing a relative
# ../share/cmake/pybind11 layout from the include directory, which is
# not guaranteed across install schemes (virtualenv, user, conda).
execute_process(
    COMMAND ${Python3_EXECUTABLE} -c "import pybind11; print(pybind11.get_cmake_dir())"
    RESULT_VARIABLE _pybind11_dir_result
    OUTPUT_STRIP_TRAILING_WHITESPACE
    OUTPUT_VARIABLE pybind11_DIR
)
if(NOT _pybind11_dir_result EQUAL 0 OR NOT pybind11_DIR)
    message(FATAL_ERROR "Pybind11 CMake directory not found")
else()
    message(STATUS "Pybind11 CMake directory: ${pybind11_DIR}")
endif()
find_package(pybind11 REQUIRED)

pybind11_add_module(
    ${PROJECT_NAME}
    cudensitymat_jax.cpp
    pybind.cpp
)

# A Python extension module is never linked against by other CMake
# targets, so its include directories are PRIVATE build requirements.
# pybind11's own headers are propagated automatically by
# pybind11_add_module; they need not be added by hand.
target_include_directories(
    ${PROJECT_NAME}
    PRIVATE
    ${CUDAToolkit_INCLUDE_DIRS}
    ${XLA_DIR}
    ${CMAKE_CURRENT_SOURCE_DIR}
)

# Let the module resolve co-located shared libraries at runtime.
set_target_properties(
    ${PROJECT_NAME}
    PROPERTIES
    BUILD_RPATH "$ORIGIN"
)

target_link_libraries(
    ${PROJECT_NAME}
    PRIVATE
    CUDA::cudart_static
)
|