cuequivariance-ops-jax-cu13 0.7.0 (cuequivariance_ops_jax_cu13-0.7.0-py3-none-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuequivariance_ops_jax/VERSION +1 -0
- cuequivariance_ops_jax/__init__.py +40 -0
- cuequivariance_ops_jax/_common.py +80 -0
- cuequivariance_ops_jax/_cublas_enums.py +32 -0
- cuequivariance_ops_jax/_gpu_utilities.py +157 -0
- cuequivariance_ops_jax/_indexed_linear.py +148 -0
- cuequivariance_ops_jax/_tensor_product_uniform_1d_jit.py +210 -0
- cuequivariance_ops_jax/_triangle_attention.py +205 -0
- cuequivariance_ops_jax/_version.py +20 -0
- cuequivariance_ops_jax/lib/libcue_ops_jax.so +0 -0
- cuequivariance_ops_jax_cu13-0.7.0.dist-info/METADATA +229 -0
- cuequivariance_ops_jax_cu13-0.7.0.dist-info/RECORD +15 -0
- cuequivariance_ops_jax_cu13-0.7.0.dist-info/WHEEL +6 -0
- cuequivariance_ops_jax_cu13-0.7.0.dist-info/licenses/LICENSE +173 -0
- cuequivariance_ops_jax_cu13-0.7.0.dist-info/licenses/Third_party_attr.txt +409 -0
cuequivariance_ops_jax/VERSION
@@ -0,0 +1 @@
0.7.0
cuequivariance_ops_jax/__init__.py
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from ._version import __version__, __git_commit__
from ._tensor_product_uniform_1d_jit import (
    tensor_product_uniform_1d_jit,
    Operation,
    Path,
)
from ._indexed_linear import indexed_linear
from ._triangle_attention import (
    triangle_attention_cuda_fwd,
    triangle_attention_cuda_bwd,
    triangle_attention_jax_fwd,
)
from ._gpu_utilities import noop, sleep, synchronize, event_record, event_elapsed

__all__ = [
    "__version__",
    "__git_commit__",
    "tensor_product_uniform_1d_jit",
    "Operation",
    "Path",
    "indexed_linear",
    "triangle_attention_cuda_fwd",
    "triangle_attention_cuda_bwd",
    "triangle_attention_jax_fwd",
    "noop",
    "sleep",
    "synchronize",
    "event_record",
    "event_elapsed",
]
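
The package's public surface is what `__init__.py` re-exports above. A minimal usage sketch (illustrative, not part of the wheel; it assumes the wheel and its `cuequivariance_ops` dependency are installed, since importing the package loads `libcue_ops_jax.so` and registers the FFI targets):

import jax.numpy as jnp
import cuequivariance_ops_jax as ops

print(ops.__version__)  # expected to be "0.7.0" for this release

# noop routes every array leaf of a pytree through the FFI target and
# returns the structure unchanged (CPU and CUDA targets are registered).
tree = {"x": jnp.ones((4,)), "y": jnp.arange(3)}
tree = ops.noop(tree)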
cuequivariance_ops_jax/_common.py
@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import ctypes
import importlib.metadata
import os
from enum import IntEnum

import jax.numpy as jnp
from jax import ffi

import cuequivariance_ops  # noqa: F401

# Load libcue_ops_jax.so
try:
    dist = importlib.metadata.distribution("cuequivariance_ops_jax")
    root = dist.locate_file("cuequivariance_ops_jax")
except Exception:
    # last resort, will fail with writeable install
    root = os.path.dirname(__file__)

path = os.path.join(root, "lib/libcue_ops_jax.so")
library = ctypes.cdll.LoadLibrary(path)

# Register the c++ functions with JAX
CUSTOM_FUNCS = [
    (
        "tensor_product_uniform_1d_jit",
        "tensor_product_uniform_1d_jit",
        "tensor_product_uniform_1d_cpu",
    ),
    ("indexed_linear_B", "indexed_linear_B", None),
    ("indexed_linear_C", "indexed_linear_C", None),
    ("triangle_attention_cuda_fwd", "triangle_attention_cuda_fwd", None),
    ("triangle_attention_cuda_bwd", "triangle_attention_cuda_bwd", None),
    ("noop", "noop_gpu", "noop_cpu"),
    ("sleep", "sleep_gpu", "sleep_cpu"),
    ("synchronize", "synchronize_gpu", "synchronize_cpu"),
    ("event_record", "event_record_gpu", "event_record_cpu"),
    ("event_elapsed", "event_elapsed_gpu", "event_elapsed_cpu"),
]

for name, cuda_fn, cpu_fn in CUSTOM_FUNCS:
    if cuda_fn is not None:
        ffi.register_ffi_target(
            name=name, fn=ffi.pycapsule(getattr(library, cuda_fn)), platform="CUDA"
        )
    if cpu_fn is not None:
        ffi.register_ffi_target(
            name=name, fn=ffi.pycapsule(getattr(library, cpu_fn)), platform="cpu"
        )


class DataType(IntEnum):
    FLOAT32 = 0
    FLOAT64 = 1
    FLOAT16 = 2
    BFLOAT16 = 3
    INT32 = 4
    INT64 = 5


def _dtype(jax_dtype: jnp.dtype) -> DataType:
    try:
        return {
            jnp.float32: DataType.FLOAT32,
            jnp.float64: DataType.FLOAT64,
            jnp.float16: DataType.FLOAT16,
            jnp.bfloat16: DataType.BFLOAT16,
            jnp.int32: DataType.INT32,
            jnp.int64: DataType.INT64,
        }[jnp.dtype(jax_dtype).type]
    except KeyError:
        raise ValueError(f"Unsupported dtype: {jax_dtype}")
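
`DataType` and `_dtype` translate JAX dtypes into the integer codes that the shared library expects as FFI attributes. A small sketch of the mapping (these are private helpers, and importing `_common` loads the native library, so this only runs with the wheel installed):

import jax.numpy as jnp
from cuequivariance_ops_jax._common import DataType, _dtype

assert _dtype(jnp.float32) is DataType.FLOAT32
assert _dtype(jnp.bfloat16) is DataType.BFLOAT16
assert int(_dtype(jnp.int64)) == 5  # the raw code passed to the kernels

# Anything outside the six supported dtypes raises ValueError.
try:
    _dtype(jnp.dtype("uint8"))
except ValueError as err:
    print(err)  # Unsupported dtype: uint8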
cuequivariance_ops_jax/_cublas_enums.py
@@ -0,0 +1,32 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

# CUBLAS Compute Types
# Converted from cublas_api.h cublasComputeType_t enum


class CublasComputeTypes:
    """Class to provide getattr access to CUBLAS compute types."""

    CUBLAS_COMPUTE_16F = 64
    CUBLAS_COMPUTE_16F_PEDANTIC = 65
    CUBLAS_COMPUTE_32F = 68
    CUBLAS_COMPUTE_32F_PEDANTIC = 69
    CUBLAS_COMPUTE_32F_FAST_16F = 74
    CUBLAS_COMPUTE_32F_FAST_16BF = 75
    CUBLAS_COMPUTE_32F_FAST_TF32 = 77
    CUBLAS_COMPUTE_64F = 70
    CUBLAS_COMPUTE_64F_PEDANTIC = 71
    CUBLAS_COMPUTE_32I = 72
    CUBLAS_COMPUTE_32I_PEDANTIC = 73


# Create instance for getattr access
cublas_compute_types = CublasComputeTypes()
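
`cublas_compute_types` exists so that a compute type can be resolved from its string name with `getattr`, which is how `indexed_linear` below handles its `math_dtype` argument. A quick sketch:

from cuequivariance_ops_jax._cublas_enums import cublas_compute_types

# String-to-value lookup as used by indexed_linear(math_dtype="...").
assert getattr(cublas_compute_types, "CUBLAS_COMPUTE_32F_FAST_TF32") == 77

names = [a for a in dir(cublas_compute_types) if a.startswith("CUBLAS_COMPUTE_")]
print(len(names))  # 11 compute types, the ones listed in the class above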
cuequivariance_ops_jax/_gpu_utilities.py
@@ -0,0 +1,157 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""GPU utility functions including no-op, sleep, synchronize, and GPU events.

GPU Events Implementation Attribution:
The GPU events functionality (event_record, event_elapsed) has been adapted from
JAX's GPU events implementation that was removed in version 0.7.2.

Original source: https://github.com/jax-ml/jax/
License: Apache License 2.0

JAX Copyright 2018 The JAX Authors.
Licensed under the Apache License, Version 2.0.
"""

from typing import Any

import jax
import jax.numpy as jnp
from jax import ffi


def _flatten(pytree: Any):
    """Helper to apply FFI function to JAX arrays in a pytree."""
    leaves, treedef = jax.tree.flatten(pytree)
    arrays = [(i, leaf) for i, leaf in enumerate(leaves) if isinstance(leaf, jax.Array)]
    _, values = zip(*arrays)

    def unflatten(outputs):
        for idx, (leaf_idx, _) in enumerate(arrays):
            leaves[leaf_idx] = outputs[idx]
        return jax.tree.unflatten(treedef, leaves)

    return values, unflatten


def noop(pytree: Any) -> Any:
    """
    No-op function that returns input pytree unchanged through FFI.

    Args:
        pytree: Any pytree structure containing JAX arrays

    Returns:
        The same pytree structure with arrays passed through FFI
    """
    vals, unflatten = _flatten(pytree)

    vals = ffi.ffi_call(
        "noop",
        [jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in vals],
        input_output_aliases={i: i for i in range(len(vals))},
    )(*vals)

    return unflatten(vals)


def sleep(seconds: jax.Array, pytree: Any) -> tuple[jax.Array, Any]:
    """
    Sleep for the specified number of seconds and return input pytree unchanged.

    Args:
        seconds: Number of seconds to sleep (as a JAX array)
        pytree: Any pytree structure containing JAX arrays

    Returns:
        A tuple of (elapsed_ticks, pytree) where elapsed_ticks is the number of
        clock ticks that elapsed during the sleep operation
    """
    seconds = jnp.asarray(seconds, dtype=jnp.float32)
    vals, unflatten = _flatten(pytree)

    outputs = ffi.ffi_call(
        "sleep",
        [jax.ShapeDtypeStruct((), jnp.int64)]
        + [jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in vals],
        input_output_aliases={i: i for i in range(1, len(vals) + 1)},
    )(seconds, *vals)
    elapsed_ticks, vals = outputs[0], outputs[1:]

    return elapsed_ticks, unflatten(vals)


def synchronize(pytree: Any) -> tuple[jax.Array, Any]:
    """
    Synchronize the current CUDA stream and return input pytree unchanged.

    Args:
        pytree: Any pytree structure containing JAX arrays

    Returns:
        A tuple of (elapsed_seconds, pytree) where elapsed_seconds is the time
        in seconds it took to synchronize the CUDA stream
    """
    vals, unflatten = _flatten(pytree)

    outputs = ffi.ffi_call(
        "synchronize",
        [jax.ShapeDtypeStruct((), jnp.float32)]
        + [jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in vals],
        input_output_aliases={i: i + 1 for i in range(len(vals))},
    )(*vals)
    elapsed_seconds, vals = outputs[0], outputs[1:]

    return elapsed_seconds, unflatten(vals)


def event_record(pytree: Any, *, copy_before: bool = False) -> tuple[jax.Array, Any]:
    """
    Record a GPU event on the current CUDA stream and return the event handle.

    Args:
        pytree: Any pytree structure containing JAX arrays
        copy_before: If True, copy event handle to device before recording.
            If False, copy after recording (default).

    Returns:
        A tuple of (event_handle, pytree) where event_handle is a uint64
        representing the CUDA event, and pytree is passed through unchanged.
    """
    vals, unflatten = _flatten(pytree)

    outputs = ffi.ffi_call(
        "event_record",
        [jax.ShapeDtypeStruct((), jnp.uint64)]  # event_handle
        + [jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in vals],
        input_output_aliases={i: i + 1 for i in range(len(vals))},
    )(*vals, copy_before=copy_before)
    event_handle, vals = outputs[0], outputs[1:]

    return event_handle, unflatten(vals)


def event_elapsed(start_event: jax.Array, end_event: jax.Array) -> jax.Array:
    """
    Calculate elapsed time between two GPU events.

    Args:
        start_event: uint64 event handle from event_record()
        end_event: uint64 event handle from event_record()

    Returns:
        Elapsed time in milliseconds as a float32 scalar.
    """
    elapsed_ms = ffi.ffi_call(
        "event_elapsed",
        jax.ShapeDtypeStruct((), jnp.float32),
    )(start_event, end_event)

    return elapsed_ms
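
Together, `event_record` and `event_elapsed` give CUDA-event timing of work on the JAX stream, in the style of the removed JAX GPU-events API that this file credits. A hedged timing sketch (requires a CUDA device; recording the start event with `copy_before=True` mirrors the original JAX recipe and is an assumption here, not something this diff documents):

import jax
import jax.numpy as jnp
from cuequivariance_ops_jax import event_record, event_elapsed

x = jnp.ones((4096, 4096))

start, x = event_record(x, copy_before=True)  # event recorded before the timed work
y = x @ x                                     # work submitted to the same stream
end, y = event_record(y)                      # event recorded after the timed work

ms = event_elapsed(start, end)                # float32 scalar, milliseconds
print(float(jax.block_until_ready(ms)))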
cuequivariance_ops_jax/_indexed_linear.py
@@ -0,0 +1,148 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
from __future__ import annotations

import jax
import jax.numpy as jnp
from jax import ffi

from ._common import _dtype
from ._cublas_enums import cublas_compute_types


def indexed_linear(
    A: jax.Array,
    B: jax.Array,
    D: jax.ShapeDtypeStruct,
    counts: jax.Array,
    u: int,
    v: int,
    C: int,
    Z: int,
    subscripts: tuple[str, str, str],
    coefficient: float,
    math_dtype: str | None = None,
) -> jax.Array:
    """
    Performance benchmarks for compatible dtype + math_dtype combinations
    (Problem size: u=512, v=512, Z=10,000):

    | Rank | Dtype + Math Dtype                     | Time (ms) | Speedup vs float64 |
    |------|----------------------------------------|-----------|--------------------|
    | 🥇   | bfloat16 + CUBLAS_COMPUTE_32F          | 0.0891    | 83.9x              |
    | 🥈   | float16 + CUBLAS_COMPUTE_32F           | 0.0911    | 82.1x              |
    | 🥉   | float32 + CUBLAS_COMPUTE_32F_FAST_TF32 | 0.1403    | 53.3x              |
    | 4th  | float32 + CUBLAS_COMPUTE_32F_PEDANTIC  | 0.1987    | 37.6x              |
    | 5th  | float32 + CUBLAS_COMPUTE_32F           | 0.2038    | 36.7x              |
    | 6th  | float64 + CUBLAS_COMPUTE_64F           | 7.4742    | 1.0x (baseline)    |
    """
    subscripts = tuple(subscripts)
    original_subscripts = subscripts
    assert len(subscripts) == 3
    swap_u_v = False
    swap_A_B = False

    dtype = jnp.dtype(A.dtype)
    assert dtype == B.dtype
    assert dtype == D.dtype

    if math_dtype is None:
        dtype_to_compute_type = {
            jnp.bfloat16: cublas_compute_types.CUBLAS_COMPUTE_32F,  # Use 32F compute type for bf16
            jnp.float16: cublas_compute_types.CUBLAS_COMPUTE_32F,  # Use 32F compute type for fp16
            jnp.float32: cublas_compute_types.CUBLAS_COMPUTE_32F,
            jnp.float64: cublas_compute_types.CUBLAS_COMPUTE_64F,
        }
        if dtype.type not in dtype_to_compute_type:
            raise ValueError(
                f"For dtype '{dtype}', please specify math_dtype manually and check CUBLAS documentation "
                f"for compatible compute types."
            )
        compute_type = dtype_to_compute_type[dtype.type]
    else:
        if not hasattr(cublas_compute_types, math_dtype):
            supported_types = [
                attr
                for attr in dir(cublas_compute_types)
                if attr.startswith("CUBLAS_COMPUTE_")
            ]
            raise ValueError(
                f"Unsupported math_dtype '{math_dtype}'. "
                f"The supported compute types are: {supported_types}. "
                f"Be careful as they are not compatible with all I/O dtype combinations. "
                f"Have a look at the CUBLAS documentation for compatibility details."
            )
        compute_type = getattr(cublas_compute_types, math_dtype)

    if subscripts in [("u", "v", "vu"), ("uv", "v", "u"), ("vu", "v", "u")]:
        swap_A_B = True
        swap_u_v = True
    if subscripts in [("v", "uv", "u"), ("v", "vu", "u"), ("v", "u", "vu")]:
        swap_u_v = True
    if subscripts in [("v", "u", "uv"), ("uv", "u", "v"), ("vu", "u", "v")]:
        swap_A_B = True

    if swap_u_v:
        subscripts = tuple(
            x.replace("u", "q").replace("v", "u").replace("q", "v") for x in subscripts
        )
        u, v = v, u

    if swap_A_B:
        subscripts = (subscripts[1], subscripts[0], subscripts[2])
        A, B = B, A

    temp_storage_bytes_cub_ExclusiveSum = 1024  # TODO this seems to be sufficient but we never know if it's enough for all use cases and GPUs
    workspace_size = (
        counts.size * (3 + 1) * jnp.dtype(jnp.int64).itemsize
        + temp_storage_bytes_cub_ExclusiveSum
    )
    workspace = jnp.empty((workspace_size,), dtype=jnp.int8)

    if subscripts == ("u", "v", "uv"):
        (D, _) = ffi.ffi_call("indexed_linear_C", (D, workspace))(
            A,
            B,
            counts,
            compute_type=compute_type,
            u=u,
            v=v,
            C=C,
            Z=Z,
            coefficient=coefficient,
            dtype_A=_dtype(A.dtype),
            dtype_B=_dtype(B.dtype),
            dtype_D=_dtype(D.dtype),
        )
        return D

    if subscripts == ("u", "uv", "v"):
        transpose_B = False
    elif subscripts == ("u", "vu", "v"):
        transpose_B = True
    else:
        raise ValueError(f"Invalid subscripts: {original_subscripts}.")

    (D, _) = ffi.ffi_call("indexed_linear_B", (D, workspace))(
        A,
        B,
        counts,
        compute_type=compute_type,
        u=u,
        v=v,
        C=C,
        Z=Z,
        transpose_B=transpose_B,
        coefficient=coefficient,
        dtype_A=_dtype(A.dtype),
        dtype_B=_dtype(B.dtype),
        dtype_D=_dtype(D.dtype),
    )
    return D
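
`indexed_linear` dispatches to the `indexed_linear_B`/`indexed_linear_C` kernels depending on which operand carries the `uv` weight layout; the benchmark table in its docstring is the only layout documentation in this diff. The call sketch below is therefore illustrative only: the shapes `A: (Z, u)`, `B: (C, u, v)`, `D: (Z, v)` and the int32 `counts` (rows per group, summing to Z) are assumptions, not something the diff confirms:

import jax
import jax.numpy as jnp
from cuequivariance_ops_jax import indexed_linear

Z, C, u, v = 1000, 4, 64, 32
key = jax.random.key(0)
A = jax.random.normal(key, (Z, u), dtype=jnp.bfloat16)     # assumed input layout
B = jax.random.normal(key, (C, u, v), dtype=jnp.bfloat16)  # assumed per-group weights
counts = jnp.array([250, 250, 250, 250], dtype=jnp.int32)  # assumed: rows per group

D = indexed_linear(
    A,
    B,
    jax.ShapeDtypeStruct((Z, v), jnp.bfloat16),  # output description, assumed shape
    counts,
    u=u,
    v=v,
    C=C,
    Z=Z,
    subscripts=("u", "uv", "v"),      # one of the layouts handled above
    coefficient=1.0,
    math_dtype="CUBLAS_COMPUTE_32F",  # bf16 I/O with fp32 accumulation, the fastest row in the table
)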
cuequivariance_ops_jax/_tensor_product_uniform_1d_jit.py
@@ -0,0 +1,210 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
from __future__ import annotations

from enum import IntEnum
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
from jax import ffi

from ._common import _dtype


class BatchDimension(IntEnum):
    BATCHED = 0
    SHARED = 1
    INDEXED = 2


class SegmentDimension(IntEnum):
    SCALAR = 0
    VECTOR = 1


def _batch_dim(idx: int, size: int, batch_size: int) -> BatchDimension:
    if idx < 0:
        if size == batch_size:
            return BatchDimension.BATCHED
        else:
            return BatchDimension.SHARED
    else:
        return BatchDimension.INDEXED


def _seg_dim(buffer: jax.ShapeDtypeStruct, operand_extent: int) -> SegmentDimension:
    if buffer.shape[-1] == operand_extent:
        return SegmentDimension.VECTOR
    else:
        return SegmentDimension.SCALAR


class Operation(NamedTuple):
    buffers: list[int]
    start_path: int
    num_paths: int


class Path(NamedTuple):
    indices: list[int]
    coefficient: float


def _batch_size(sizes: list[int]) -> int:
    batch_size = 1
    for size in sizes:
        if size != 1:
            assert batch_size in {1, size}
            batch_size = size
    return batch_size


def _operand_extent(
    buffers: list[jax.ShapeDtypeStruct],
):
    operand_extent = max(x.shape[-1] for x in buffers)
    for x in buffers:
        assert x.shape[-1] in {1, operand_extent}, x.shape[-1]
    return operand_extent


def _operation_start_indices(
    paths: list[Path], operation_start_paths: np.ndarray
) -> np.ndarray:
    path_num_operands = np.array([len(path.indices) for path in paths], dtype=np.int32)
    start_indices = np.append(0, np.cumsum(path_num_operands))
    return start_indices[operation_start_paths].astype(np.int64)


def tensor_product_uniform_1d_jit(
    input_buffers: list[jax.Array],  # ndim = num_batch_axes + 2
    output_buffers_shape_dtype: list[jax.ShapeDtypeStruct],  # ndim = num_batch_axes + 2
    index_buffers: list[jax.Array],  # ndim = num_batch_axes
    buffer_index: list[list[int]],  # -1 if not indexed
    *,
    operations: list[Operation],
    paths: list[Path],
    math_dtype: jnp.dtype,
    name: str = "untitled",
) -> list[jax.Array]:
    """JIT-compiled CUDA implementation of tensor_product_uniform_1d."""
    input_buffers = list(input_buffers)
    output_buffers_shape_dtype = list(output_buffers_shape_dtype)
    index_buffers = list(index_buffers)
    buffer_index = np.array(buffer_index, dtype=np.int32)
    operations, paths = list(operations), list(paths)

    io_buffers = input_buffers + output_buffers_shape_dtype
    buffers = io_buffers + index_buffers
    assert buffer_index.shape[0] == len(buffers)

    # trick: ensure all outputs are "used" by adding dummy operations for unused outputs
    # this ensures that the kernel writes zeros to unused outputs (as expected by the XLA bindings)
    for i in range(len(input_buffers), len(io_buffers)):
        if not any(i in op.buffers for op in operations):
            operations.append(Operation([i], 0, 0))

    num_batch_axes = buffer_index.shape[1]
    for x in io_buffers:
        assert x.ndim == num_batch_axes + 2
    for i in index_buffers:
        assert i.ndim == num_batch_axes
        assert i.dtype.type in {jnp.int32, jnp.int64}

    batch_sizes = [
        _batch_size(
            [x.shape[i] for x, idx in zip(buffers, buffer_index[:, i]) if idx < 0],
        )
        for i in range(num_batch_axes)
    ]

    buffer_batch_dim = np.array(
        [
            [
                _batch_dim(idx, size, batch_size)
                for idx, size, batch_size in zip(idxs, x.shape, batch_sizes)
            ]
            for x, idxs in zip(buffers, buffer_index)
        ]
    )

    index_extent = [set() for _ in range(len(index_buffers))]
    for x, idxs in zip(buffers, buffer_index):
        for size, idx in zip(x.shape, idxs):
            if idx >= 0:
                index_extent[idx].add(size)
    assert all(len(x) == 1 for x in index_extent), index_extent
    index_extent = [x.pop() for x in index_extent]

    operand_extent = _operand_extent(io_buffers)

    math_dtype = jnp.dtype(math_dtype)
    assert math_dtype.type in {jnp.float32, jnp.float64}

    def ii(items):
        return np.array([i for i in items], dtype=np.int64)

    buffer_batch_dim = ii(buffer_batch_dim.flatten())
    buffer_num_segments = ii(x.shape[-2] for x in io_buffers)
    buffer_segments_dim = ii(_seg_dim(x, operand_extent) for x in io_buffers)
    buffer_index = ii(buffer_index.flatten())
    index_extent = ii(index_extent)
    buffer_dtype = ii(_dtype(x.dtype) for x in io_buffers + index_buffers)
    operation_num_operands = ii(len(op.buffers) for op in operations)
    operation_buffers = ii(b for op in operations for b in op.buffers)
    operation_num_paths = ii(op.num_paths for op in operations)
    operation_start_coeffs = ii(op.start_path for op in operations)
    operation_start_indices = _operation_start_indices(paths, operation_start_coeffs)
    path_indices = ii(i for path in paths for i in path.indices)
    path_coefficients = np.array([path.coefficient for path in paths], dtype=np.float64)
    batch_sizes = ii(batch_sizes)

    # print(f"{operand_extent=}")
    # print(f"num_indices={len(index_buffers)}")
    # print(f"{buffer_batch_dim=}")
    # print(f"{buffer_num_segments=}")
    # print(f"{buffer_segments_dim=}")
    # print(f"{buffer_index=}")
    # print(f"{index_extent=}")
    # print(f"{buffer_dtype=}")
    # print(f"{operation_num_operands=}")
    # print(f"{operation_buffers=}")
    # print(f"{operation_num_paths=}")
    # print(f"{operation_start_indices=}")
    # print(f"{operation_start_coeffs=}")
    # print(f"{path_indices=}")
    # print(f"{path_coefficients=}")
    # print(f"{batch_sizes=}", flush=True)

    call = ffi.ffi_call("tensor_product_uniform_1d_jit", output_buffers_shape_dtype)
    return call(
        *input_buffers,
        *index_buffers,
        name=name,
        math_dtype=_dtype(math_dtype),
        operand_extent=operand_extent,
        num_indices=len(index_buffers),
        buffer_batch_dim=buffer_batch_dim,
        buffer_num_segments=buffer_num_segments,
        buffer_segments_dim=buffer_segments_dim,
        buffer_index=buffer_index,
        index_extent=index_extent,
        buffer_dtype=buffer_dtype,
        operation_num_operands=operation_num_operands,
        operation_buffers=operation_buffers,
        operation_num_paths=operation_num_paths,
        operation_start_indices=operation_start_indices,
        operation_start_coeffs=operation_start_coeffs,
        path_indices=path_indices,
        path_coefficients=path_coefficients.view(np.int64),
        batch_sizes=batch_sizes,
    )
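
`tensor_product_uniform_1d_jit` mostly packs metadata (batch layout, segment counts, operations, paths) into flat int64 attributes for a single FFI call. The sketch below only exercises the argument structure that the wrapper itself validates (one batch axis, two inputs, one output, no index buffers); the numerical semantics of `Operation` and `Path` live in the CUDA kernel and are not documented in this diff:

import jax
import jax.numpy as jnp
from cuequivariance_ops_jax import Operation, Path, tensor_product_uniform_1d_jit

Z, u = 100, 128
x0 = jnp.ones((Z, 1, u), dtype=jnp.float32)  # (batch, num_segments, operand_extent)
x1 = jnp.ones((Z, 1, u), dtype=jnp.float32)
out = jax.ShapeDtypeStruct((Z, 1, u), jnp.float32)

(y,) = tensor_product_uniform_1d_jit(
    [x0, x1],            # input buffers
    [out],               # output buffer descriptions
    [],                  # no index buffers
    [[-1], [-1], [-1]],  # one row per buffer, one column per batch axis, -1 = not indexed
    operations=[Operation(buffers=[0, 1, 2], start_path=0, num_paths=1)],
    paths=[Path(indices=[0, 0, 0], coefficient=1.0)],  # one segment index per operand (assumed)
    math_dtype=jnp.float32,
    name="example",
)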