cuquantum-python-jax-cu12 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. cuquantum_python_jax_cu12-0.0.5/LICENSE +28 -0
  2. cuquantum_python_jax_cu12-0.0.5/MANIFEST.in +9 -0
  3. cuquantum_python_jax_cu12-0.0.5/PKG-INFO +113 -0
  4. cuquantum_python_jax_cu12-0.0.5/README.md +87 -0
  5. cuquantum_python_jax_cu12-0.0.5/configure.sh +122 -0
  6. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/__init__.py +18 -0
  7. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/cppsrc/CMakeLists.txt +68 -0
  8. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/cppsrc/cudensitymat.h +3109 -0
  9. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/cppsrc/cudensitymat_jax.cpp +437 -0
  10. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/cppsrc/cudensitymat_jax.h +30 -0
  11. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/cppsrc/pybind.cpp +51 -0
  12. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/cppsrc/utils.h +31 -0
  13. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/operator_action.py +427 -0
  14. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/__init__.py +8 -0
  15. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/base.py +98 -0
  16. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/context.py +332 -0
  17. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/elementary_operator.py +245 -0
  18. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/matrix_operator.py +194 -0
  19. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/operator.py +347 -0
  20. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/operator_action_prim.py +549 -0
  21. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/pysrc/operator_term.py +473 -0
  22. cuquantum_python_jax_cu12-0.0.5/cuquantum/densitymat/jax/utils.py +396 -0
  23. cuquantum_python_jax_cu12-0.0.5/cuquantum_python_jax_cu12.egg-info/PKG-INFO +113 -0
  24. cuquantum_python_jax_cu12-0.0.5/cuquantum_python_jax_cu12.egg-info/SOURCES.txt +30 -0
  25. cuquantum_python_jax_cu12-0.0.5/cuquantum_python_jax_cu12.egg-info/dependency_links.txt +1 -0
  26. cuquantum_python_jax_cu12-0.0.5/cuquantum_python_jax_cu12.egg-info/not-zip-safe +1 -0
  27. cuquantum_python_jax_cu12-0.0.5/cuquantum_python_jax_cu12.egg-info/requires.txt +3 -0
  28. cuquantum_python_jax_cu12-0.0.5/cuquantum_python_jax_cu12.egg-info/top_level.txt +2 -0
  29. cuquantum_python_jax_cu12-0.0.5/pyproject.toml +49 -0
  30. cuquantum_python_jax_cu12-0.0.5/pyproject.toml.template +49 -0
  31. cuquantum_python_jax_cu12-0.0.5/setup.cfg +4 -0
  32. cuquantum_python_jax_cu12-0.0.5/setup.py +117 -0
@@ -0,0 +1,28 @@
1
+ Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+
3
+ BSD-3-Clause
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,9 @@
1
+ include pyproject.toml.template
2
+ include configure.sh
3
+ include cuquantum/densitymat/jax/cppsrc/CMakeLists.txt
4
+ include cuquantum/densitymat/jax/cppsrc/cudensitymat.h
5
+ include cuquantum/densitymat/jax/cppsrc/cudensitymat_jax.cpp
6
+ include cuquantum/densitymat/jax/cppsrc/cudensitymat_jax.h
7
+ include cuquantum/densitymat/jax/cppsrc/pybind.cpp
8
+ include cuquantum/densitymat/jax/cppsrc/utils.h
9
+ prune tests*
@@ -0,0 +1,113 @@
1
+ Metadata-Version: 2.4
2
+ Name: cuquantum-python-jax-cu12
3
+ Version: 0.0.5
4
+ Summary: NVIDIA cuQuantum Python JAX
5
+ Author-email: NVIDIA Corporation <cuquantum-python@nvidia.com>
6
+ License-Expression: BSD-3-Clause
7
+ Project-URL: Homepage, https://developer.nvidia.com/cuquantum-sdk
8
+ Classifier: Development Status :: 5 - Production/Stable
9
+ Classifier: Operating System :: POSIX :: Linux
10
+ Classifier: Topic :: Education
11
+ Classifier: Topic :: Scientific/Engineering
12
+ Classifier: Programming Language :: Python :: 3 :: Only
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Programming Language :: Python :: 3.13
16
+ Classifier: Programming Language :: Python :: Implementation :: CPython
17
+ Classifier: Environment :: GPU :: NVIDIA CUDA
18
+ Classifier: Environment :: GPU :: NVIDIA CUDA :: 12
19
+ Requires-Python: >=3.11.0
20
+ Description-Content-Type: text/markdown
21
+ License-File: LICENSE
22
+ Requires-Dist: pybind11
23
+ Requires-Dist: cuquantum-python-cu12~=26.1.0
24
+ Requires-Dist: jax[cuda12-local]<0.7,>=0.5
25
+ Dynamic: license-file
26
+
27
+ # cuQuantum Python JAX
28
+
29
+ cuQuantum Python JAX provides a JAX extension for cuQuantum Python. It exposes selected functionality of cuQuantum SDK in a JAX-compatible way that enables JAX frameworks to directly interface with the exposed cuQuantum API. In the current release, cuQuantum JAX exposes a JAX interface to the Operator Action API from the cuDensityMat library.
30
+
31
+ ## Documentation
32
+
33
+ Please visit the [NVIDIA cuQuantum Python documentation](https://docs.nvidia.com/cuda/cuquantum/latest/python).
34
+
35
+ ## Building and installing cuQuantum Python JAX
36
+
37
+ ### Requirements
38
+
39
+ The install-time dependencies of the cuQuantum Python JAX package include:
40
+
41
+ * cuquantum-python-cu12~=26.1.0 for CUDA 12 or cuquantum-python-cu13~=26.1.0 for CUDA 13
42
+ * jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
43
+ * pybind11
44
+ * setuptools>=77.0.3
45
+
46
+ Note:
47
+ 1. cuQuantum Python JAX is only supported with CUDA 12 and CUDA 13.
48
+ 2. cuQuantum Python JAX installation does not support build isolation. The user needs to pass in `--no-build-isolation` to `pip` when installing cuQuantum Python JAX.
49
+ 3. cuQuantum Python JAX wheels are CUDA-versioned: `cuquantum-python-jax-cu12` for CUDA 12 and `cuquantum-python-jax-cu13` for CUDA 13.
50
+
51
+ #### Installation using `jax[cudaXX-local]`
52
+
53
+ `cuquantum-python-jax-cu12` (or `cuquantum-python-jax-cu13`) depends explicitly on `jax[cudaXX-local]`. Installing the package will also install `jax[cudaXX-local]`.
54
+
55
+ Using `jax[cudaXX-local]` assumes the user provides both cuDNN and the CUDA Toolkit. cuDNN is not part of the CUDA Toolkit and must be installed separately. The user must also set `LD_LIBRARY_PATH` to include the directories containing `libcudnn.so` and `libcupti.so`.
56
+
57
+ `libcupti.so` is provided by the CUDA Toolkit. If the CUDA Toolkit is installed under `/usr/local/cuda`, `libcupti.so` is located under `/usr/local/cuda/extras/CUPTI/lib64` and `LD_LIBRARY_PATH` should contain this path.
58
+
59
+ `libcudnn.so` is installed separately from the CUDA Toolkit. The default installation location is `/usr/local/cuda/lib64`, and `LD_LIBRARY_PATH` should contain this path.
60
+
61
+ Both `libcudnn.so` and `libcupti.so` can also be installed with pip (the commands below are for CUDA 12):
62
+
63
+ ```
64
+ pip install nvidia-cudnn-cu12
65
+ pip install nvidia-cuda-cupti-cu12
66
+ ```
67
+
68
+ After installing cuDNN and cuPTI, the user may install cuQuantum Python JAX with `pip` using either:
69
+
70
+ ```
71
+ pip install --no-build-isolation cuquantum-python-jax-cu12 # for CUDA 12
72
+ pip install --no-build-isolation cuquantum-python-jax-cu13 # for CUDA 13
73
+ ```
74
+
75
+ or one of
76
+
77
+ ```
78
+ pip install --no-build-isolation cuquantum-python-cu12[jax]
79
+ pip install --no-build-isolation cuquantum-python-cu13[jax]
80
+ ```
81
+
82
+ where the CUDA version is specified explicitly in the cuquantum-python package name.
83
+
84
+ Note:
85
+ 1. If cuDNN and cuPTI are installed with `pip`, the user does not need to specify library folders in `LD_LIBRARY_PATH`.
86
+ 2. When the latter command `pip install --no-build-isolation cuquantum-python-cu12[jax]`/`pip install --no-build-isolation cuquantum-python-cu13[jax]` is used, `--no-build-isolation` applies to both cuquantum-python and cuquantum-python-jax. The user needs to ensure cuquantum-python's build dependencies are installed before the installation.
87
+
88
+ #### Installing from source
89
+
90
+ To install cuQuantum Python JAX from source, first compile cuQuantum Python from source using the [instructions on GitHub](https://github.com/NVIDIA/cuQuantum/blob/main/python/README.md). Once complete, navigate to `python/extensions`, then:
91
+
92
+ ```
93
+ pip install .
94
+ ```
95
+
96
+ The CUDA version is detected automatically from `$CUDA_PATH` and the wheel will be named accordingly (`cuquantum-python-jax-cu12` or `cuquantum-python-jax-cu13`).
97
+
98
+ ## Running
99
+
100
+ ### Requirements
101
+
102
+ Runtime dependencies of the cuQuantum Python JAX package include:
103
+
104
+ * An NVIDIA GPU with compute capability 7.5+
105
+ * cuquantum-python-cu12~=26.1.0 for CUDA 12 or cuquantum-python-cu13~=26.1.0 for CUDA 13
106
+ * jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
107
+ * pybind11
108
+
109
+ ## Developer Notes
110
+
111
+ * cuQuantum Python JAX does not support editable installation.
112
+ * Both cuQuantum Python and cuQuantum Python JAX need to be installed into `site-packages` for proper import of the library.
113
+ * cuQuantum Python JAX assumes cuQuantum Python will be available under the current `site-packages` directory.
@@ -0,0 +1,87 @@
1
+ # cuQuantum Python JAX
2
+
3
+ cuQuantum Python JAX provides a JAX extension for cuQuantum Python. It exposes selected functionality of cuQuantum SDK in a JAX-compatible way that enables JAX frameworks to directly interface with the exposed cuQuantum API. In the current release, cuQuantum JAX exposes a JAX interface to the Operator Action API from the cuDensityMat library.
4
+
5
+ ## Documentation
6
+
7
+ Please visit the [NVIDIA cuQuantum Python documentation](https://docs.nvidia.com/cuda/cuquantum/latest/python).
8
+
9
+ ## Building and installing cuQuantum Python JAX
10
+
11
+ ### Requirements
12
+
13
+ The install-time dependencies of the cuQuantum Python JAX package include:
14
+
15
+ * cuquantum-python-cu12~=26.1.0 for CUDA 12 or cuquantum-python-cu13~=26.1.0 for CUDA 13
16
+ * jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
17
+ * pybind11
18
+ * setuptools>=77.0.3
19
+
20
+ Note:
21
+ 1. cuQuantum Python JAX is only supported with CUDA 12 and CUDA 13.
22
+ 2. cuQuantum Python JAX installation does not support build isolation. The user needs to pass in `--no-build-isolation` to `pip` when installing cuQuantum Python JAX.
23
+ 3. cuQuantum Python JAX wheels are CUDA-versioned: `cuquantum-python-jax-cu12` for CUDA 12 and `cuquantum-python-jax-cu13` for CUDA 13.
24
+
25
+ #### Installation using `jax[cudaXX-local]`
26
+
27
+ `cuquantum-python-jax-cu12` (or `cuquantum-python-jax-cu13`) depends explicitly on `jax[cudaXX-local]`. Installing the package will also install `jax[cudaXX-local]`.
28
+
29
+ Using `jax[cudaXX-local]` assumes the user provides both cuDNN and the CUDA Toolkit. cuDNN is not part of the CUDA Toolkit and must be installed separately. The user must also set `LD_LIBRARY_PATH` to include the directories containing `libcudnn.so` and `libcupti.so`.
30
+
31
+ `libcupti.so` is provided by the CUDA Toolkit. If the CUDA Toolkit is installed under `/usr/local/cuda`, `libcupti.so` is located under `/usr/local/cuda/extras/CUPTI/lib64` and `LD_LIBRARY_PATH` should contain this path.
32
+
33
+ `libcudnn.so` is installed separately from the CUDA Toolkit. The default installation location is `/usr/local/cuda/lib64`, and `LD_LIBRARY_PATH` should contain this path.
34
+
35
+ Both `libcudnn.so` and `libcupti.so` can also be installed with pip (the commands below are for CUDA 12):
36
+
37
+ ```
38
+ pip install nvidia-cudnn-cu12
39
+ pip install nvidia-cuda-cupti-cu12
40
+ ```
41
+
42
+ After installing cuDNN and cuPTI, the user may install cuQuantum Python JAX with `pip` using either:
43
+
44
+ ```
45
+ pip install --no-build-isolation cuquantum-python-jax-cu12 # for CUDA 12
46
+ pip install --no-build-isolation cuquantum-python-jax-cu13 # for CUDA 13
47
+ ```
48
+
49
+ or one of
50
+
51
+ ```
52
+ pip install --no-build-isolation cuquantum-python-cu12[jax]
53
+ pip install --no-build-isolation cuquantum-python-cu13[jax]
54
+ ```
55
+
56
+ where the CUDA version is specified explicitly in the cuquantum-python package name.
57
+
58
+ Note:
59
+ 1. If cuDNN and cuPTI are installed with `pip`, the user does not need to specify library folders in `LD_LIBRARY_PATH`.
60
+ 2. When the latter command `pip install --no-build-isolation cuquantum-python-cu12[jax]`/`pip install --no-build-isolation cuquantum-python-cu13[jax]` is used, `--no-build-isolation` applies to both cuquantum-python and cuquantum-python-jax. The user needs to ensure cuquantum-python's build dependencies are installed before the installation.
61
+
62
+ #### Installing from source
63
+
64
+ To install cuQuantum Python JAX from source, first compile cuQuantum Python from source using the [instructions on GitHub](https://github.com/NVIDIA/cuQuantum/blob/main/python/README.md). Once complete, navigate to `python/extensions`, then:
65
+
66
+ ```
67
+ pip install .
68
+ ```
69
+
70
+ The CUDA version is detected automatically from `$CUDA_PATH` and the wheel will be named accordingly (`cuquantum-python-jax-cu12` or `cuquantum-python-jax-cu13`).
71
+
72
+ ## Running
73
+
74
+ ### Requirements
75
+
76
+ Runtime dependencies of the cuQuantum Python JAX package include:
77
+
78
+ * An NVIDIA GPU with compute capability 7.5+
79
+ * cuquantum-python-cu12~=26.1.0 for CUDA 12 or cuquantum-python-cu13~=26.1.0 for CUDA 13
80
+ * jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
81
+ * pybind11
82
+
83
+ ## Developer Notes
84
+
85
+ * cuQuantum Python JAX does not support editable installation.
86
+ * Both cuQuantum Python and cuQuantum Python JAX need to be installed into `site-packages` for proper import of the library.
87
+ * cuQuantum Python JAX assumes cuQuantum Python will be available under the current `site-packages` directory.
@@ -0,0 +1,122 @@
1
+ #!/usr/bin/env bash
2
+ # Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES
3
+ #
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+
6
# Generate pyproject.toml from pyproject.toml.template by detecting
# CUDA version and substituting template placeholders.
#
# Usage:
#   CUDA_PATH=/usr/local/cuda bash configure.sh
#
# Placeholders substituted in the template:
#   @CUDA_MAJOR_VER@    CUDA major version (12 or 13)
#   @JAX_VERSION_SPEC@  version specifier for jax[cudaXX-local]
#   @CUDA_CLASSIFIER@   trove classifier for the CUDA major version

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
TEMPLATE="${SCRIPT_DIR}/pyproject.toml.template"
OUTPUT="${SCRIPT_DIR}/pyproject.toml"

# -------------------------------------------------------------------
# 1. Validate template exists (early, before doing any real work)
# -------------------------------------------------------------------
if [[ ! -f "${TEMPLATE}" ]]; then
    echo "ERROR: Template not found at ${TEMPLATE}" >&2
    exit 1
fi

# -------------------------------------------------------------------
# 2. Validate CUDA_PATH
# -------------------------------------------------------------------
if [[ -z "${CUDA_PATH:-}" ]]; then
    echo "ERROR: CUDA_PATH is not set. Please set it to your CUDA toolkit root." >&2
    exit 1
fi

CUDA_H="${CUDA_PATH}/include/cuda.h"
if [[ ! -f "${CUDA_H}" ]]; then
    echo "ERROR: Cannot find ${CUDA_H}. Is CUDA_PATH set correctly?" >&2
    exit 1
fi

# -------------------------------------------------------------------
# 3. Parse CUDA_VERSION from cuda.h (mirrors setup.py logic)
#    Example line: #define CUDA_VERSION 12020
#    12020 => major = 12020 / 1000 = 12
# -------------------------------------------------------------------
# Read grep output into a variable to avoid a grep|head|sed pipeline,
# which can behave inconsistently with pipefail across bash versions
# (head closes the pipe early, causing grep to receive SIGPIPE).
# `|| true` keeps a no-match (grep exit 1) from aborting under `set -e`;
# the empty result is reported explicitly below.
CUDA_H_MATCH=$(grep -E -m1 '^\s*#define\s+CUDA_VERSION\s+[0-9]+' "${CUDA_H}" || true)

if [[ -z "${CUDA_H_MATCH}" ]]; then
    echo "ERROR: Could not find CUDA_VERSION definition in ${CUDA_H}" >&2
    exit 1
fi

CUDA_VERSION_RAW=$(echo "${CUDA_H_MATCH}" | sed -E 's/.*#define\s+CUDA_VERSION\s+([0-9]+).*/\1/')

if [[ -z "${CUDA_VERSION_RAW}" ]]; then
    echo "ERROR: Could not parse CUDA_VERSION from ${CUDA_H}" >&2
    exit 1
fi

if (( CUDA_VERSION_RAW < 1000 )); then
    echo "ERROR: CUDA_VERSION ${CUDA_VERSION_RAW} from ${CUDA_H} is unexpectedly small (< 1000)" >&2
    exit 1
fi

CUDA_MAJOR=$(( CUDA_VERSION_RAW / 1000 ))

echo "Detected CUDA_VERSION=${CUDA_VERSION_RAW} (major=${CUDA_MAJOR}) from ${CUDA_H}"

# -------------------------------------------------------------------
# 4. Map CUDA major version to template variable values
# -------------------------------------------------------------------
case "${CUDA_MAJOR}" in
    12)
        JAX_VERSION_SPEC=">=0.5,<0.7"
        CUDA_CLASSIFIER="Environment :: GPU :: NVIDIA CUDA :: 12"
        ;;
    13)
        JAX_VERSION_SPEC=">=0.8,<0.9"
        CUDA_CLASSIFIER="Environment :: GPU :: NVIDIA CUDA :: 13"
        ;;
    *)
        echo "ERROR: Unsupported CUDA major version: ${CUDA_MAJOR}" >&2
        exit 1
        ;;
esac

# -------------------------------------------------------------------
# 5. Escape characters that are special in a sed replacement:
#    '&' (whole match), '\' (escape), and the '|' delimiter used by the
#    s-commands in step 6. '/' is escaped too; that is harmless and keeps
#    the helper safe for a '/'-delimited command as well.
# -------------------------------------------------------------------
sed_escape() {
    printf '%s' "$1" | sed -e 's/[&\\/|]/\\&/g'
}

CUDA_MAJOR_ESC=$(sed_escape "${CUDA_MAJOR}")
JAX_VERSION_SPEC_ESC=$(sed_escape "${JAX_VERSION_SPEC}")
CUDA_CLASSIFIER_ESC=$(sed_escape "${CUDA_CLASSIFIER}")

# -------------------------------------------------------------------
# 6. Generate pyproject.toml via sed substitution
#    Write to a temp file then atomically move into place so a partial
#    write (e.g. Ctrl+C, disk full) never leaves a broken pyproject.toml.
# -------------------------------------------------------------------
TMPFILE=$(mktemp "${OUTPUT}.XXXXXX")
trap 'rm -f "${TMPFILE}"' EXIT

sed -e "s|@CUDA_MAJOR_VER@|${CUDA_MAJOR_ESC}|g" \
    -e "s|@JAX_VERSION_SPEC@|${JAX_VERSION_SPEC_ESC}|g" \
    -e "s|@CUDA_CLASSIFIER@|${CUDA_CLASSIFIER_ESC}|g" \
    -- "${TEMPLATE}" > "${TMPFILE}"

# Skip overwrite if output already exists and is identical (idempotency).
if [[ -f "${OUTPUT}" ]] && cmp -s "${TMPFILE}" "${OUTPUT}"; then
    echo "pyproject.toml is already up to date for CUDA ${CUDA_MAJOR}"
    exit 0
fi

# mktemp creates the file with mode 0600, and mv keeps that mode. Give it the
# mode a normally created file would get under the current umask, so the
# generated pyproject.toml is not left owner-only.
chmod "$(printf '%o' $(( 0666 & ~$(umask) )))" "${TMPFILE}"

mv -f "${TMPFILE}" "${OUTPUT}"
trap - EXIT

echo "Generated ${OUTPUT} for CUDA ${CUDA_MAJOR}"
@@ -0,0 +1,18 @@
1
+ # Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+
5
from cuquantum.bindings._internal import cudensitymat as _cudm

# Called for its side effect of loading libcudensitymat.so. Presumably this must
# happen before the native extension behind the imports below is loaded —
# NOTE(review): confirm the required ordering.
_cudm._inspect_function_pointers()

import jax

# Fail fast at import time unless JAX's 64-bit mode is enabled; this package
# refuses to work without it.
if not jax.config.jax_enable_x64:
    raise RuntimeError("jax_enable_x64 must be set to True to use cuQuantum Python JAX")

# Public API of cuQuantum Python JAX.
from .operator_action import operator_action
from .pysrc import (
    ElementaryOperator,
    MatrixOperator,
    OperatorTerm,
    Operator
)
@@ -0,0 +1,68 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+
5
cmake_minimum_required(VERSION 3.22)
project(cudensitymat_jax LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# A Python extension module only needs the module part of the Python
# development files. The full `Development` component would additionally
# require the embedding library (libpython), which is not needed here.
find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module)
message(STATUS "Python executable: ${Python3_EXECUTABLE}")

find_package(CUDAToolkit REQUIRED)
message(STATUS "CUDA toolkit directory: ${CUDAToolkit_INCLUDE_DIRS}")

# cudensitymat_jax_query_python(<out_var> <description> <code>)
#
# Runs `<code>` with the Python interpreter found above and stores its
# whitespace-stripped standard output in <out_var> in the caller's scope.
# Configuration stops with a fatal error naming <description>, and including
# the interpreter's stderr, if the interpreter exits unsuccessfully or prints
# nothing.
#
# Example:
#   cudensitymat_jax_query_python(XLA_DIR "XLA directory"
#     "import jax; print(jax.ffi.include_dir())")
function(cudensitymat_jax_query_python out_var description code)
  execute_process(
    COMMAND "${Python3_EXECUTABLE}" -c "${code}"
    RESULT_VARIABLE query_result
    OUTPUT_VARIABLE query_output
    ERROR_VARIABLE query_error
    OUTPUT_STRIP_TRAILING_WHITESPACE
    ERROR_STRIP_TRAILING_WHITESPACE
  )
  # RESULT_VARIABLE is an exit code, or an error string if the process could
  # not be run; EQUAL is false for a non-integer, so both cases fail here.
  if(NOT "${query_result}" EQUAL 0 OR "${query_output}" STREQUAL "")
    message(FATAL_ERROR
      "${description} not found (Python exited with: ${query_result})\n${query_error}")
  endif()
  message(STATUS "${description}: ${query_output}")
  set(${out_var} "${query_output}" PARENT_SCOPE)
endfunction()

# Header directory of the XLA FFI API, as reported by JAX.
cudensitymat_jax_query_python(XLA_DIR "XLA directory"
  "import jax; print(jax.ffi.include_dir())")

cudensitymat_jax_query_python(pybind11_INCLUDE_DIR "Pybind11 include directory"
  "import pybind11; print(pybind11.get_include())")

# Ask pybind11 where its CMake package config lives (its documented API)
# rather than assuming a fixed layout relative to the include directory.
cudensitymat_jax_query_python(pybind11_DIR "Pybind11 CMake directory"
  "import pybind11; print(pybind11.get_cmake_dir())")
find_package(pybind11 REQUIRED CONFIG)

pybind11_add_module(
  ${PROJECT_NAME}
  cudensitymat_jax.cpp
  pybind.cpp
)

# No other target in this project links the module, so its include paths are
# only needed to compile it.
target_include_directories(
  ${PROJECT_NAME}
  PRIVATE
    ${CUDAToolkit_INCLUDE_DIRS}
    "${XLA_DIR}"
    "${pybind11_INCLUDE_DIR}"
    "${CMAKE_CURRENT_SOURCE_DIR}"
)

# In the build tree, search the module's own directory for shared libraries.
set_target_properties(
  ${PROJECT_NAME}
  PROPERTIES
    BUILD_RPATH "$ORIGIN"
)

# Link the CUDA runtime statically.
target_link_libraries(
  ${PROJECT_NAME}
  PRIVATE
    CUDA::cudart_static
)