cuequivariance_ops_jax_cu13-0.7.0-py3-none-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
+ 0.7.0
@@ -0,0 +1,40 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ from ._version import __version__, __git_commit__
+ from ._tensor_product_uniform_1d_jit import (
+     tensor_product_uniform_1d_jit,
+     Operation,
+     Path,
+ )
+ from ._indexed_linear import indexed_linear
+ from ._triangle_attention import (
+     triangle_attention_cuda_fwd,
+     triangle_attention_cuda_bwd,
+     triangle_attention_jax_fwd,
+ )
+ from ._gpu_utilities import noop, sleep, synchronize, event_record, event_elapsed
+
+ __all__ = [
+     "__version__",
+     "__git_commit__",
+     "tensor_product_uniform_1d_jit",
+     "Operation",
+     "Path",
+     "indexed_linear",
+     "triangle_attention_cuda_fwd",
+     "triangle_attention_cuda_bwd",
+     "triangle_attention_jax_fwd",
+     "noop",
+     "sleep",
+     "synchronize",
+     "event_record",
+     "event_elapsed",
+ ]
@@ -0,0 +1,80 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+ import ctypes
+ import importlib.metadata
+ import os
+ from enum import IntEnum
+
+ import jax.numpy as jnp
+ from jax import ffi
+
+ import cuequivariance_ops  # noqa: F401
+
+ # Load libcue_ops_jax.so
+ try:
+     dist = importlib.metadata.distribution("cuequivariance_ops_jax")
+     root = dist.locate_file("cuequivariance_ops_jax")
+ except Exception:
+     # last resort, will fail with writeable install
+     root = os.path.dirname(__file__)
+
+ path = os.path.join(root, "lib/libcue_ops_jax.so")
+ library = ctypes.cdll.LoadLibrary(path)
+
+ # Register the C++ functions with JAX
+ CUSTOM_FUNCS = [
+     (
+         "tensor_product_uniform_1d_jit",
+         "tensor_product_uniform_1d_jit",
+         "tensor_product_uniform_1d_cpu",
+     ),
+     ("indexed_linear_B", "indexed_linear_B", None),
+     ("indexed_linear_C", "indexed_linear_C", None),
+     ("triangle_attention_cuda_fwd", "triangle_attention_cuda_fwd", None),
+     ("triangle_attention_cuda_bwd", "triangle_attention_cuda_bwd", None),
+     ("noop", "noop_gpu", "noop_cpu"),
+     ("sleep", "sleep_gpu", "sleep_cpu"),
+     ("synchronize", "synchronize_gpu", "synchronize_cpu"),
+     ("event_record", "event_record_gpu", "event_record_cpu"),
+     ("event_elapsed", "event_elapsed_gpu", "event_elapsed_cpu"),
+ ]
+
+ for name, cuda_fn, cpu_fn in CUSTOM_FUNCS:
+     if cuda_fn is not None:
+         ffi.register_ffi_target(
+             name=name, fn=ffi.pycapsule(getattr(library, cuda_fn)), platform="CUDA"
+         )
+     if cpu_fn is not None:
+         ffi.register_ffi_target(
+             name=name, fn=ffi.pycapsule(getattr(library, cpu_fn)), platform="cpu"
+         )
+
+
+ class DataType(IntEnum):
+     FLOAT32 = 0
+     FLOAT64 = 1
+     FLOAT16 = 2
+     BFLOAT16 = 3
+     INT32 = 4
+     INT64 = 5
+
+
+ def _dtype(jax_dtype: jnp.dtype) -> DataType:
+     try:
+         return {
+             jnp.float32: DataType.FLOAT32,
+             jnp.float64: DataType.FLOAT64,
+             jnp.float16: DataType.FLOAT16,
+             jnp.bfloat16: DataType.BFLOAT16,
+             jnp.int32: DataType.INT32,
+             jnp.int64: DataType.INT64,
+         }[jnp.dtype(jax_dtype).type]
+     except KeyError:
+         raise ValueError(f"Unsupported dtype: {jax_dtype}")
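A sketch of how the `_dtype` helper behaves, based purely on the mapping above. `_common` is a private module and importing it loads the shared library as a side effect, so this is illustration only:

    import jax.numpy as jnp
    from cuequivariance_ops_jax._common import DataType, _dtype  # private path, assumed

    assert _dtype(jnp.float32) is DataType.FLOAT32
    assert _dtype(jnp.bfloat16) is DataType.BFLOAT16
    _dtype(jnp.int8)  # raises ValueError("Unsupported dtype: ...")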
@@ -0,0 +1,32 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ # CUBLAS Compute Types
+ # Converted from cublas_api.h cublasComputeType_t enum
+
+
+ class CublasComputeTypes:
+     """Class to provide getattr access to CUBLAS compute types."""
+
+     CUBLAS_COMPUTE_16F = 64
+     CUBLAS_COMPUTE_16F_PEDANTIC = 65
+     CUBLAS_COMPUTE_32F = 68
+     CUBLAS_COMPUTE_32F_PEDANTIC = 69
+     CUBLAS_COMPUTE_32F_FAST_16F = 74
+     CUBLAS_COMPUTE_32F_FAST_16BF = 75
+     CUBLAS_COMPUTE_32F_FAST_TF32 = 77
+     CUBLAS_COMPUTE_64F = 70
+     CUBLAS_COMPUTE_64F_PEDANTIC = 71
+     CUBLAS_COMPUTE_32I = 72
+     CUBLAS_COMPUTE_32I_PEDANTIC = 73
+
+
+ # Create instance for getattr access
+ cublas_compute_types = CublasComputeTypes()
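The singleton is used as a string-keyed lookup table: `indexed_linear` (next file) resolves a user-supplied `math_dtype` string via `getattr`. A small sketch of that pattern, using the private import path shown in the diff:

    from cuequivariance_ops_jax._cublas_enums import cublas_compute_types

    name = "CUBLAS_COMPUTE_32F_FAST_TF32"
    if hasattr(cublas_compute_types, name):
        compute_type = getattr(cublas_compute_types, name)
        print(compute_type)  # 77, the cublasComputeType_t value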
@@ -0,0 +1,157 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+ """GPU utility functions including no-op, sleep, synchronize, and GPU events.
+
+ GPU Events Implementation Attribution:
+ The GPU events functionality (event_record, event_elapsed) has been adapted from
+ JAX's GPU events implementation, which was removed in JAX 0.7.2.
+
+ Original source: https://github.com/jax-ml/jax/
+ License: Apache License 2.0
+
+ JAX Copyright 2018 The JAX Authors.
+ Licensed under the Apache License, Version 2.0.
+ """
+
+ from typing import Any
+
+ import jax
+ import jax.numpy as jnp
+ from jax import ffi
+
+
+ def _flatten(pytree: Any):
+     """Flatten a pytree and return its jax.Array leaves plus an unflatten closure."""
+     leaves, treedef = jax.tree.flatten(pytree)
+     arrays = [(i, leaf) for i, leaf in enumerate(leaves) if isinstance(leaf, jax.Array)]
+     _, values = zip(*arrays)
+
+     def unflatten(outputs):
+         for idx, (leaf_idx, _) in enumerate(arrays):
+             leaves[leaf_idx] = outputs[idx]
+         return jax.tree.unflatten(treedef, leaves)
+
+     return values, unflatten
+
+
+ def noop(pytree: Any) -> Any:
+     """
+     No-op function that returns input pytree unchanged through FFI.
+
+     Args:
+         pytree: Any pytree structure containing JAX arrays
+
+     Returns:
+         The same pytree structure with arrays passed through FFI
+     """
+     vals, unflatten = _flatten(pytree)
+
+     vals = ffi.ffi_call(
+         "noop",
+         [jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in vals],
+         input_output_aliases={i: i for i in range(len(vals))},
+     )(*vals)
+
+     return unflatten(vals)
+
+
+ def sleep(seconds: jax.Array, pytree: Any) -> tuple[jax.Array, Any]:
+     """
+     Sleep for the specified number of seconds and return input pytree unchanged.
+
+     Args:
+         seconds: Number of seconds to sleep (as a JAX array)
+         pytree: Any pytree structure containing JAX arrays
+
+     Returns:
+         A tuple of (elapsed_ticks, pytree) where elapsed_ticks is the number of
+         clock ticks that elapsed during the sleep operation
+     """
+     seconds = jnp.asarray(seconds, dtype=jnp.float32)
+     vals, unflatten = _flatten(pytree)
+
+     outputs = ffi.ffi_call(
+         "sleep",
+         [jax.ShapeDtypeStruct((), jnp.int64)]
+         + [jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in vals],
+         input_output_aliases={i: i for i in range(1, len(vals) + 1)},
+     )(seconds, *vals)
+     elapsed_ticks, vals = outputs[0], outputs[1:]
+
+     return elapsed_ticks, unflatten(vals)
+
+
+ def synchronize(pytree: Any) -> tuple[jax.Array, Any]:
+     """
+     Synchronize the current CUDA stream and return input pytree unchanged.
+
+     Args:
+         pytree: Any pytree structure containing JAX arrays
+
+     Returns:
+         A tuple of (elapsed_seconds, pytree) where elapsed_seconds is the time
+         in seconds it took to synchronize the CUDA stream
+     """
+     vals, unflatten = _flatten(pytree)
+
+     outputs = ffi.ffi_call(
+         "synchronize",
+         [jax.ShapeDtypeStruct((), jnp.float32)]
+         + [jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in vals],
+         input_output_aliases={i: i + 1 for i in range(len(vals))},
+     )(*vals)
+     elapsed_seconds, vals = outputs[0], outputs[1:]
+
+     return elapsed_seconds, unflatten(vals)
+
+
+ def event_record(pytree: Any, *, copy_before: bool = False) -> tuple[jax.Array, Any]:
+     """
+     Record a GPU event on the current CUDA stream and return the event handle.
+
+     Args:
+         pytree: Any pytree structure containing JAX arrays
+         copy_before: If True, copy event handle to device before recording.
+             If False, copy after recording (default).
+
+     Returns:
+         A tuple of (event_handle, pytree) where event_handle is a uint64
+         representing the CUDA event, and pytree is passed through unchanged.
+     """
+     vals, unflatten = _flatten(pytree)
+
+     outputs = ffi.ffi_call(
+         "event_record",
+         [jax.ShapeDtypeStruct((), jnp.uint64)]  # event_handle
+         + [jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in vals],
+         input_output_aliases={i: i + 1 for i in range(len(vals))},
+     )(*vals, copy_before=copy_before)
+     event_handle, vals = outputs[0], outputs[1:]
+
+     return event_handle, unflatten(vals)
+
+
+ def event_elapsed(start_event: jax.Array, end_event: jax.Array) -> jax.Array:
+     """
+     Calculate elapsed time between two GPU events.
+
+     Args:
+         start_event: uint64 event handle from event_record()
+         end_event: uint64 event handle from event_record()
+
+     Returns:
+         Elapsed time in milliseconds as a float32 scalar.
+     """
+     elapsed_ms = ffi.ffi_call(
+         "event_elapsed",
+         jax.ShapeDtypeStruct((), jnp.float32),
+     )(start_event, end_event)
+
+     return elapsed_ms
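A hedged timing sketch built from the helpers above, mirroring the JAX GPU-events pattern this module credits. It assumes a CUDA-backed JAX runtime with the FFI targets registered (i.e., the package imported), eager execution, and the convention that the start event is recorded with copy_before=True:

    import jax.numpy as jnp
    from cuequivariance_ops_jax import event_record, event_elapsed

    x = jnp.ones((4096, 4096))

    start, x = event_record(x, copy_before=True)  # event before the matmul
    y = x @ x                                     # threading x sequences the event
    end, y = event_record(y)                      # event after the matmul
    print(float(event_elapsed(start, end)))       # elapsed milliseconds

Threading the pytree through `event_record` is what creates the data dependency that orders the events around the timed computation on the stream.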
@@ -0,0 +1,148 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+ from __future__ import annotations
+
+ import jax
+ import jax.numpy as jnp
+ from jax import ffi
+
+ from ._common import _dtype
+ from ._cublas_enums import cublas_compute_types
+
+
+ def indexed_linear(
+     A: jax.Array,
+     B: jax.Array,
+     D: jax.ShapeDtypeStruct,
+     counts: jax.Array,
+     u: int,
+     v: int,
+     C: int,
+     Z: int,
+     subscripts: tuple[str, str, str],
+     coefficient: float,
+     math_dtype: str | None = None,
+ ) -> jax.Array:
+     """
+     Performance benchmarks for compatible dtype + math_dtype combinations
+     (Problem size: u=512, v=512, Z=10,000):
+
+     | Rank | Dtype + Math Dtype                      | Time (ms) | Speedup vs float64 |
+     |------|-----------------------------------------|-----------|--------------------|
+     | 1st  | bfloat16 + CUBLAS_COMPUTE_32F           | 0.0891    | 83.9x              |
+     | 2nd  | float16 + CUBLAS_COMPUTE_32F            | 0.0911    | 82.1x              |
+     | 3rd  | float32 + CUBLAS_COMPUTE_32F_FAST_TF32  | 0.1403    | 53.3x              |
+     | 4th  | float32 + CUBLAS_COMPUTE_32F_PEDANTIC   | 0.1987    | 37.6x              |
+     | 5th  | float32 + CUBLAS_COMPUTE_32F            | 0.2038    | 36.7x              |
+     | 6th  | float64 + CUBLAS_COMPUTE_64F            | 7.4742    | 1.0x (baseline)    |
+     """
+     subscripts = tuple(subscripts)
+     original_subscripts = subscripts
+     assert len(subscripts) == 3
+     swap_u_v = False
+     swap_A_B = False
+
+     dtype = jnp.dtype(A.dtype)
+     assert dtype == B.dtype
+     assert dtype == D.dtype
+
+     if math_dtype is None:
+         dtype_to_compute_type = {
+             jnp.bfloat16: cublas_compute_types.CUBLAS_COMPUTE_32F,  # use 32F compute type for bf16
+             jnp.float16: cublas_compute_types.CUBLAS_COMPUTE_32F,  # use 32F compute type for fp16
+             jnp.float32: cublas_compute_types.CUBLAS_COMPUTE_32F,
+             jnp.float64: cublas_compute_types.CUBLAS_COMPUTE_64F,
+         }
+         if dtype.type not in dtype_to_compute_type:
+             raise ValueError(
+                 f"For dtype '{dtype}', please specify math_dtype manually and check CUBLAS documentation "
+                 f"for compatible compute types."
+             )
+         compute_type = dtype_to_compute_type[dtype.type]
+     else:
+         if not hasattr(cublas_compute_types, math_dtype):
+             supported_types = [
+                 attr
+                 for attr in dir(cublas_compute_types)
+                 if attr.startswith("CUBLAS_COMPUTE_")
+             ]
+             raise ValueError(
+                 f"Unsupported math_dtype '{math_dtype}'. "
+                 f"The supported compute types are: {supported_types}. "
+                 f"Be careful as they are not compatible with all I/O dtype combinations. "
+                 f"Have a look at the CUBLAS documentation for compatibility details."
+             )
+         compute_type = getattr(cublas_compute_types, math_dtype)
+
+     if subscripts in [("u", "v", "vu"), ("uv", "v", "u"), ("vu", "v", "u")]:
+         swap_A_B = True
+         swap_u_v = True
+     if subscripts in [("v", "uv", "u"), ("v", "vu", "u"), ("v", "u", "vu")]:
+         swap_u_v = True
+     if subscripts in [("v", "u", "uv"), ("uv", "u", "v"), ("vu", "u", "v")]:
+         swap_A_B = True
+
+     if swap_u_v:
+         subscripts = tuple(
+             x.replace("u", "q").replace("v", "u").replace("q", "v") for x in subscripts
+         )
+         u, v = v, u
+
+     if swap_A_B:
+         subscripts = (subscripts[1], subscripts[0], subscripts[2])
+         A, B = B, A
+
+     temp_storage_bytes_cub_ExclusiveSum = 1024  # TODO: appears sufficient, but not guaranteed for all use cases and GPUs
+     workspace_size = (
+         counts.size * (3 + 1) * jnp.dtype(jnp.int64).itemsize
+         + temp_storage_bytes_cub_ExclusiveSum
+     )
+     workspace = jnp.empty((workspace_size,), dtype=jnp.int8)
+
+     if subscripts == ("u", "v", "uv"):
+         (D, _) = ffi.ffi_call("indexed_linear_C", (D, workspace))(
+             A,
+             B,
+             counts,
+             compute_type=compute_type,
+             u=u,
+             v=v,
+             C=C,
+             Z=Z,
+             coefficient=coefficient,
+             dtype_A=_dtype(A.dtype),
+             dtype_B=_dtype(B.dtype),
+             dtype_D=_dtype(D.dtype),
+         )
+         return D
+
+     if subscripts == ("u", "uv", "v"):
+         transpose_B = False
+     elif subscripts == ("u", "vu", "v"):
+         transpose_B = True
+     else:
+         raise ValueError(f"Invalid subscripts: {original_subscripts}.")
+
+     (D, _) = ffi.ffi_call("indexed_linear_B", (D, workspace))(
+         A,
+         B,
+         counts,
+         compute_type=compute_type,
+         u=u,
+         v=v,
+         C=C,
+         Z=Z,
+         transpose_B=transpose_B,
+         coefficient=coefficient,
+         dtype_A=_dtype(A.dtype),
+         dtype_B=_dtype(B.dtype),
+         dtype_D=_dtype(D.dtype),
+     )
+     return D
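A usage sketch under stated assumptions: my reading of the code is that `counts[i]` consecutive rows of `A` are transformed by the i-th of `C` weight blocks stored in `B` (a grouped GEMM), with `subscripts` describing the operand layouts. The shapes and the `counts` dtype below follow that reading and are not taken from package documentation:

    import jax
    import jax.numpy as jnp
    from cuequivariance_ops_jax import indexed_linear

    Z, C, u, v = 10_000, 4, 512, 512
    A = jnp.ones((Z, u), dtype=jnp.float32)           # inputs, u features per row (assumed layout)
    B = jnp.ones((C, u * v), dtype=jnp.float32)       # C stacked u-by-v weight blocks (assumed layout)
    counts = jnp.full((C,), Z // C, dtype=jnp.int32)  # rows assigned to each block
    D = jax.ShapeDtypeStruct((Z, v), jnp.float32)     # requested output shape/dtype

    out = indexed_linear(
        A, B, D, counts, u, v, C, Z,
        subscripts=("u", "uv", "v"),
        coefficient=1.0,
        math_dtype="CUBLAS_COMPUTE_32F_FAST_TF32",  # optional; defaults per dtype as shown above
    )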
@@ -0,0 +1,210 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+ from __future__ import annotations
+
+ from enum import IntEnum
+ from typing import NamedTuple
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from jax import ffi
+
+ from ._common import _dtype
+
+
+ class BatchDimension(IntEnum):
+     BATCHED = 0
+     SHARED = 1
+     INDEXED = 2
+
+
+ class SegmentDimension(IntEnum):
+     SCALAR = 0
+     VECTOR = 1
+
+
+ def _batch_dim(idx: int, size: int, batch_size: int) -> BatchDimension:
+     if idx < 0:
+         if size == batch_size:
+             return BatchDimension.BATCHED
+         else:
+             return BatchDimension.SHARED
+     else:
+         return BatchDimension.INDEXED
+
+
+ def _seg_dim(buffer: jax.ShapeDtypeStruct, operand_extent: int) -> SegmentDimension:
+     if buffer.shape[-1] == operand_extent:
+         return SegmentDimension.VECTOR
+     else:
+         return SegmentDimension.SCALAR
+
+
+ class Operation(NamedTuple):
+     buffers: list[int]
+     start_path: int
+     num_paths: int
+
+
+ class Path(NamedTuple):
+     indices: list[int]
+     coefficient: float
+
+
+ def _batch_size(sizes: list[int]) -> int:
+     batch_size = 1
+     for size in sizes:
+         if size != 1:
+             assert batch_size in {1, size}
+             batch_size = size
+     return batch_size
+
+
+ def _operand_extent(
+     buffers: list[jax.ShapeDtypeStruct],
+ ):
+     operand_extent = max(x.shape[-1] for x in buffers)
+     for x in buffers:
+         assert x.shape[-1] in {1, operand_extent}, x.shape[-1]
+     return operand_extent
+
+
+ def _operation_start_indices(
+     paths: list[Path], operation_start_paths: np.ndarray
+ ) -> np.ndarray:
+     path_num_operands = np.array([len(path.indices) for path in paths], dtype=np.int32)
+     start_indices = np.append(0, np.cumsum(path_num_operands))
+     return start_indices[operation_start_paths].astype(np.int64)
+
+
+ def tensor_product_uniform_1d_jit(
+     input_buffers: list[jax.Array],  # ndim = num_batch_axes + 2
+     output_buffers_shape_dtype: list[jax.ShapeDtypeStruct],  # ndim = num_batch_axes + 2
+     index_buffers: list[jax.Array],  # ndim = num_batch_axes
+     buffer_index: list[list[int]],  # -1 if not indexed
+     *,
+     operations: list[Operation],
+     paths: list[Path],
+     math_dtype: jnp.dtype,
+     name: str = "untitled",
+ ) -> list[jax.Array]:
+     """JIT-compiled CUDA implementation of tensor_product_uniform_1d."""
+     input_buffers = list(input_buffers)
+     output_buffers_shape_dtype = list(output_buffers_shape_dtype)
+     index_buffers = list(index_buffers)
+     buffer_index = np.array(buffer_index, dtype=np.int32)
+     operations, paths = list(operations), list(paths)
+
+     io_buffers = input_buffers + output_buffers_shape_dtype
+     buffers = io_buffers + index_buffers
+     assert buffer_index.shape[0] == len(buffers)
+
+     # trick: ensure all outputs are "used" by adding dummy operations for unused outputs
+     # this ensures that the kernel writes zeros to unused outputs (as expected by the XLA bindings)
+     for i in range(len(input_buffers), len(io_buffers)):
+         if not any(i in op.buffers for op in operations):
+             operations.append(Operation([i], 0, 0))
+
+     num_batch_axes = buffer_index.shape[1]
+     for x in io_buffers:
+         assert x.ndim == num_batch_axes + 2
+     for i in index_buffers:
+         assert i.ndim == num_batch_axes
+         assert i.dtype.type in {jnp.int32, jnp.int64}
+
+     batch_sizes = [
+         _batch_size(
+             [x.shape[i] for x, idx in zip(buffers, buffer_index[:, i]) if idx < 0],
+         )
+         for i in range(num_batch_axes)
+     ]
+
+     buffer_batch_dim = np.array(
+         [
+             [
+                 _batch_dim(idx, size, batch_size)
+                 for idx, size, batch_size in zip(idxs, x.shape, batch_sizes)
+             ]
+             for x, idxs in zip(buffers, buffer_index)
+         ]
+     )
+
+     index_extent = [set() for _ in range(len(index_buffers))]
+     for x, idxs in zip(buffers, buffer_index):
+         for size, idx in zip(x.shape, idxs):
+             if idx >= 0:
+                 index_extent[idx].add(size)
+     assert all(len(x) == 1 for x in index_extent), index_extent
+     index_extent = [x.pop() for x in index_extent]
+
+     operand_extent = _operand_extent(io_buffers)
+
+     math_dtype = jnp.dtype(math_dtype)
+     assert math_dtype.type in {jnp.float32, jnp.float64}
+
+     def ii(items):
+         return np.array([i for i in items], dtype=np.int64)
+
+     buffer_batch_dim = ii(buffer_batch_dim.flatten())
+     buffer_num_segments = ii(x.shape[-2] for x in io_buffers)
+     buffer_segments_dim = ii(_seg_dim(x, operand_extent) for x in io_buffers)
+     buffer_index = ii(buffer_index.flatten())
+     index_extent = ii(index_extent)
+     buffer_dtype = ii(_dtype(x.dtype) for x in io_buffers + index_buffers)
+     operation_num_operands = ii(len(op.buffers) for op in operations)
+     operation_buffers = ii(b for op in operations for b in op.buffers)
+     operation_num_paths = ii(op.num_paths for op in operations)
+     operation_start_coeffs = ii(op.start_path for op in operations)
+     operation_start_indices = _operation_start_indices(paths, operation_start_coeffs)
+     path_indices = ii(i for path in paths for i in path.indices)
+     path_coefficients = np.array([path.coefficient for path in paths], dtype=np.float64)
+     batch_sizes = ii(batch_sizes)
+
+     # print(f"{operand_extent=}")
+     # print(f"num_indices={len(index_buffers)}")
+     # print(f"{buffer_batch_dim=}")
+     # print(f"{buffer_num_segments=}")
+     # print(f"{buffer_segments_dim=}")
+     # print(f"{buffer_index=}")
+     # print(f"{index_extent=}")
+     # print(f"{buffer_dtype=}")
+     # print(f"{operation_num_operands=}")
+     # print(f"{operation_buffers=}")
+     # print(f"{operation_num_paths=}")
+     # print(f"{operation_start_indices=}")
+     # print(f"{operation_start_coeffs=}")
+     # print(f"{path_indices=}")
+     # print(f"{path_coefficients=}")
+     # print(f"{batch_sizes=}", flush=True)
+
+     call = ffi.ffi_call("tensor_product_uniform_1d_jit", output_buffers_shape_dtype)
+     return call(
+         *input_buffers,
+         *index_buffers,
+         name=name,
+         math_dtype=_dtype(math_dtype),
+         operand_extent=operand_extent,
+         num_indices=len(index_buffers),
+         buffer_batch_dim=buffer_batch_dim,
+         buffer_num_segments=buffer_num_segments,
+         buffer_segments_dim=buffer_segments_dim,
+         buffer_index=buffer_index,
+         index_extent=index_extent,
+         buffer_dtype=buffer_dtype,
+         operation_num_operands=operation_num_operands,
+         operation_buffers=operation_buffers,
+         operation_num_paths=operation_num_paths,
+         operation_start_indices=operation_start_indices,
+         operation_start_coeffs=operation_start_coeffs,
+         path_indices=path_indices,
+         path_coefficients=path_coefficients.view(np.int64),
+         batch_sizes=batch_sizes,
+     )
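A minimal call sketch inferred from the shape checks above: one batch axis of size Z, no index buffers, and a single operation combining input buffers 0 and 1 into output buffer 2 through one path. What the kernel computes for a given Operation/Path list is not documented in this diff, so this only illustrates how the arguments fit together:

    import jax
    import jax.numpy as jnp
    from cuequivariance_ops_jax import tensor_product_uniform_1d_jit, Operation, Path

    Z, extent = 128, 32
    x = jnp.ones((Z, 1, extent), dtype=jnp.float32)  # (batch, num_segments, operand_extent)
    y = jnp.ones((Z, 1, extent), dtype=jnp.float32)
    out_spec = jax.ShapeDtypeStruct((Z, 1, extent), jnp.float32)

    (out,) = tensor_product_uniform_1d_jit(
        [x, y],              # input buffers 0 and 1
        [out_spec],          # output buffer 2
        [],                  # no index buffers
        [[-1], [-1], [-1]],  # -1: batched (not indexed) along the single batch axis
        operations=[Operation(buffers=[0, 1, 2], start_path=0, num_paths=1)],
        paths=[Path(indices=[0, 0, 0], coefficient=1.0)],
        math_dtype=jnp.float32,
    )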