cuquantum-python-jax-cu12 0.0.5.post0__tar.gz → 0.0.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. cuquantum_python_jax_cu12-0.0.6/MANIFEST.in +44 -0
  2. {cuquantum_python_jax_cu12-0.0.5.post0/cuquantum_python_jax_cu12.egg-info → cuquantum_python_jax_cu12-0.0.6}/PKG-INFO +7 -4
  3. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/README.md +3 -1
  4. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/configure.sh +2 -2
  5. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/__init__.py +10 -1
  6. cuquantum_python_jax_cu12-0.0.6/cuquantum/densitymat/jax/_build_info.py +19 -0
  7. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/cppsrc/cudensitymat.h +993 -97
  8. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/cppsrc/cudensitymat_jax.cpp +15 -19
  9. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/cppsrc/cudensitymat_jax.h +1 -8
  10. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/cppsrc/pybind.cpp +1 -7
  11. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/cppsrc/utils.h +1 -1
  12. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/operator_action.py +31 -54
  13. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/pysrc/context.py +39 -18
  14. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/pysrc/elementary_operator.py +24 -10
  15. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/pysrc/matrix_operator.py +12 -6
  16. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/pysrc/operator.py +36 -24
  17. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/pysrc/operator_action_prim.py +80 -63
  18. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/pysrc/operator_term.py +46 -32
  19. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/utils.py +75 -72
  20. cuquantum_python_jax_cu12-0.0.6/cuquantum/stabilizer/jax/__init__.py +35 -0
  21. cuquantum_python_jax_cu12-0.0.6/cuquantum/stabilizer/jax/_build_info.py +19 -0
  22. cuquantum_python_jax_cu12-0.0.6/cuquantum/stabilizer/jax/cppsrc/CMakeLists.txt +73 -0
  23. cuquantum_python_jax_cu12-0.0.6/cuquantum/stabilizer/jax/cppsrc/custabilizer.h +605 -0
  24. cuquantum_python_jax_cu12-0.0.6/cuquantum/stabilizer/jax/cppsrc/custabilizer_jax.cpp +96 -0
  25. cuquantum_python_jax_cu12-0.0.6/cuquantum/stabilizer/jax/cppsrc/custabilizer_jax.h +13 -0
  26. cuquantum_python_jax_cu12-0.0.6/cuquantum/stabilizer/jax/cppsrc/pybind.cpp +27 -0
  27. cuquantum_python_jax_cu12-0.0.6/cuquantum/stabilizer/jax/pysrc/__init__.py +3 -0
  28. cuquantum_python_jax_cu12-0.0.6/cuquantum/stabilizer/jax/pysrc/_backend.py +36 -0
  29. cuquantum_python_jax_cu12-0.0.6/cuquantum/stabilizer/jax/pysrc/_ffi.py +55 -0
  30. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6/cuquantum_python_jax_cu12.egg-info}/PKG-INFO +7 -4
  31. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum_python_jax_cu12.egg-info/SOURCES.txt +11 -0
  32. cuquantum_python_jax_cu12-0.0.6/cuquantum_python_jax_cu12.egg-info/requires.txt +2 -0
  33. cuquantum_python_jax_cu12-0.0.6/cuquantum_python_jax_cu12.egg-info/top_level.txt +3 -0
  34. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/pyproject.toml +6 -4
  35. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/pyproject.toml.template +4 -2
  36. cuquantum_python_jax_cu12-0.0.6/setup.py +300 -0
  37. cuquantum_python_jax_cu12-0.0.5.post0/MANIFEST.in +0 -9
  38. cuquantum_python_jax_cu12-0.0.5.post0/cuquantum_python_jax_cu12.egg-info/requires.txt +0 -2
  39. cuquantum_python_jax_cu12-0.0.5.post0/cuquantum_python_jax_cu12.egg-info/top_level.txt +0 -2
  40. cuquantum_python_jax_cu12-0.0.5.post0/setup.py +0 -117
  41. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/LICENSE +0 -0
  42. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/NV.LICENSE +0 -0
  43. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/cppsrc/CMakeLists.txt +0 -0
  44. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/pysrc/__init__.py +0 -0
  45. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum/densitymat/jax/pysrc/base.py +0 -0
  46. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum_python_jax_cu12.egg-info/dependency_links.txt +0 -0
  47. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/cuquantum_python_jax_cu12.egg-info/not-zip-safe +0 -0
  48. {cuquantum_python_jax_cu12-0.0.5.post0 → cuquantum_python_jax_cu12-0.0.6}/setup.cfg +0 -0
@@ -0,0 +1,44 @@
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ #
5
+ # Redistribution and use in source and binary forms, with or without
6
+ # modification, are permitted provided that the following conditions are met:
7
+ #
8
+ # 1. Redistributions of source code must retain the above copyright notice, this
9
+ # list of conditions and the following disclaimer.
10
+ #
11
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ # this list of conditions and the following disclaimer in the documentation
13
+ # and/or other materials provided with the distribution.
14
+ #
15
+ # 3. Neither the name of the copyright holder nor the names of its
16
+ # contributors may be used to endorse or promote products derived from
17
+ # this software without specific prior written permission.
18
+ #
19
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ #
30
+
31
+ include pyproject.toml.template
32
+ include configure.sh
33
+ include cuquantum/densitymat/jax/cppsrc/CMakeLists.txt
34
+ include cuquantum/densitymat/jax/cppsrc/cudensitymat.h
35
+ include cuquantum/densitymat/jax/cppsrc/cudensitymat_jax.cpp
36
+ include cuquantum/densitymat/jax/cppsrc/cudensitymat_jax.h
37
+ include cuquantum/densitymat/jax/cppsrc/pybind.cpp
38
+ include cuquantum/densitymat/jax/cppsrc/utils.h
39
+ include cuquantum/stabilizer/jax/cppsrc/CMakeLists.txt
40
+ include cuquantum/stabilizer/jax/cppsrc/custabilizer.h
41
+ include cuquantum/stabilizer/jax/cppsrc/custabilizer_jax.cpp
42
+ include cuquantum/stabilizer/jax/cppsrc/custabilizer_jax.h
43
+ include cuquantum/stabilizer/jax/cppsrc/pybind.cpp
44
+ prune tests*
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cuquantum-python-jax-cu12
3
- Version: 0.0.5.post0
3
+ Version: 0.0.6
4
4
  Summary: NVIDIA cuQuantum Python JAX
5
5
  Author-email: NVIDIA Corporation <cuquantum-python@nvidia.com>
6
6
  License-Expression: BSD-3-Clause
@@ -13,6 +13,7 @@ Classifier: Programming Language :: Python :: 3 :: Only
13
13
  Classifier: Programming Language :: Python :: 3.11
14
14
  Classifier: Programming Language :: Python :: 3.12
15
15
  Classifier: Programming Language :: Python :: 3.13
16
+ Classifier: Programming Language :: Python :: 3.14
16
17
  Classifier: Programming Language :: Python :: Implementation :: CPython
17
18
  Classifier: Environment :: GPU :: NVIDIA CUDA
18
19
  Classifier: Environment :: GPU :: NVIDIA CUDA :: 12
@@ -20,13 +21,13 @@ Requires-Python: >=3.11.0
20
21
  Description-Content-Type: text/markdown
21
22
  License-File: LICENSE
22
23
  License-File: NV.LICENSE
23
- Requires-Dist: cuquantum-python-cu12~=26.3.0
24
- Requires-Dist: jax[cuda12-local]<0.7,>=0.5
24
+ Requires-Dist: cuquantum-python-cu12~=26.6.0
25
+ Requires-Dist: jax[cuda12-local]>=0.8
25
26
  Dynamic: license-file
26
27
 
27
28
  # cuQuantum Python JAX
28
29
 
29
- cuQuantum Python JAX provides a JAX extension for cuQuantum Python. It exposes selected functionality of cuQuantum SDK in a JAX-compatible way that enables JAX frameworks to directly interface with the exposed cuQuantum API. In the current release, cuQuantum JAX exposes a JAX interface to the Operator Action API from the cuDensityMat library.
30
+ cuQuantum Python JAX provides a JAX extension for cuQuantum Python. It exposes selected functionality of cuQuantum SDK in a JAX-compatible way that enables JAX frameworks to directly interface with the exposed cuQuantum API. In the current release, cuQuantum JAX exposes a JAX interface to the Operator Action API from the cuDensityMat library and the GF(2) sparse-dense matrix multiply from the cuStabilizer library.
30
31
 
31
32
  ## Documentation
32
33
 
@@ -38,6 +39,7 @@ Please visit the [NVIDIA cuQuantum Python documentation](https://docs.nvidia.com
38
39
 
39
40
  The build-time dependencies of the cuQuantum Python JAX package include:
40
41
 
42
+ * CUDA Toolkit 12.x or 13.x
41
43
  * jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
42
44
  * pybind11
43
45
  * wheel
@@ -100,6 +102,7 @@ The CUDA version is detected automatically from `$CUDA_PATH` and the wheel will
100
102
  Runtime dependencies of the cuQuantum Python JAX package include:
101
103
 
102
104
  * An NVIDIA GPU with compute capability 7.5+
105
+ * CUDA Toolkit 12.x or 13.x
103
106
  * cuquantum-python-cu12~=26.3.0 for CUDA 12 or cuquantum-python-cu13~=26.3.0 for CUDA 13
104
107
  * jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
105
108
 
@@ -1,6 +1,6 @@
1
1
  # cuQuantum Python JAX
2
2
 
3
- cuQuantum Python JAX provides a JAX extension for cuQuantum Python. It exposes selected functionality of cuQuantum SDK in a JAX-compatible way that enables JAX frameworks to directly interface with the exposed cuQuantum API. In the current release, cuQuantum JAX exposes a JAX interface to the Operator Action API from the cuDensityMat library.
3
+ cuQuantum Python JAX provides a JAX extension for cuQuantum Python. It exposes selected functionality of cuQuantum SDK in a JAX-compatible way that enables JAX frameworks to directly interface with the exposed cuQuantum API. In the current release, cuQuantum JAX exposes a JAX interface to the Operator Action API from the cuDensityMat library and the GF(2) sparse-dense matrix multiply from the cuStabilizer library.
4
4
 
5
5
  ## Documentation
6
6
 
@@ -12,6 +12,7 @@ Please visit the [NVIDIA cuQuantum Python documentation](https://docs.nvidia.com
12
12
 
13
13
  The build-time dependencies of the cuQuantum Python JAX package include:
14
14
 
15
+ * CUDA Toolkit 12.x or 13.x
15
16
  * jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
16
17
  * pybind11
17
18
  * wheel
@@ -74,6 +75,7 @@ The CUDA version is detected automatically from `$CUDA_PATH` and the wheel will
74
75
  Runtime dependencies of the cuQuantum Python JAX package include:
75
76
 
76
77
  * An NVIDIA GPU with compute capability 7.5+
78
+ * CUDA Toolkit 12.x or 13.x
77
79
  * cuquantum-python-cu12~=26.3.0 for CUDA 12 or cuquantum-python-cu13~=26.3.0 for CUDA 13
78
80
  * jax[cuda12-local]>=0.5,<0.7 for CUDA 12 or jax[cuda13-local]>=0.8,<0.9 for CUDA 13
79
81
 
@@ -73,11 +73,11 @@ echo "Detected CUDA_VERSION=${CUDA_VERSION_RAW} (major=${CUDA_MAJOR}) from ${CUD
73
73
  # -------------------------------------------------------------------
74
74
  case "${CUDA_MAJOR}" in
75
75
  12)
76
- JAX_VERSION_SPEC=">=0.5,<0.7"
76
+ JAX_VERSION_SPEC=">=0.8"
77
77
  CUDA_CLASSIFIER="Environment :: GPU :: NVIDIA CUDA :: 12"
78
78
  ;;
79
79
  13)
80
- JAX_VERSION_SPEC=">=0.8,<0.9"
80
+ JAX_VERSION_SPEC=">=0.8"
81
81
  CUDA_CLASSIFIER="Environment :: GPU :: NVIDIA CUDA :: 13"
82
82
  ;;
83
83
  *)
@@ -6,8 +6,17 @@ from cuquantum.bindings._internal import cudensitymat as _cudm
6
6
  _cudm._inspect_function_pointers() # for loading libcudensitymat.so
7
7
 
8
8
  import jax
9
+
10
+ try:
11
+ from ._build_info import check_jax_abi as _check_jax_abi
12
+ except ImportError:
13
+ pass
14
+ else:
15
+ _check_jax_abi()
16
+ del _check_jax_abi
17
+
9
18
  if not jax.config.jax_enable_x64:
10
- raise RuntimeError(f"jax_enable_x64 must be set to True to use cuQuantum Python JAX")
19
+ raise RuntimeError("jax_enable_x64 must be set to True to use cuQuantum Python JAX")
11
20
 
12
21
  from .operator_action import operator_action
13
22
  from .pysrc import (
@@ -0,0 +1,19 @@
1
+ # Auto-generated by setup.py at build time; do not edit.
2
+ JAX_BUILD_VERSION = ""
3
+ PROJECT_NAME = "cuquantum-python-jax-cu12"
4
+
5
+ def check_jax_abi() -> None:
6
+ """Warn whenever the runtime jax version differs from the
7
+ build-time jax version. JAX uses effort-based versioning, so
8
+ any difference -- patch bumps included -- may shift the FFI
9
+ capsule layout these bindings were compiled against."""
10
+ if not JAX_BUILD_VERSION:
11
+ return
12
+ import jax, warnings
13
+ if jax.__version__ != JAX_BUILD_VERSION:
14
+ warnings.warn(
15
+ f"JAX version mismatch: {PROJECT_NAME} was built using "
16
+ f"JAX {JAX_BUILD_VERSION}, current runtime is JAX "
17
+ f"{jax.__version__}. Rebuild against the runtime jax if "
18
+ f"you see issues.",
19
+ RuntimeWarning, stacklevel=3)