cuequivariance-ops-cu12 0.8.1__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuequivariance_ops/VERSION +1 -0
- cuequivariance_ops/__init__.py +42 -0
- cuequivariance_ops/_version.py +20 -0
- cuequivariance_ops/common/common.hpp +98 -0
- cuequivariance_ops/common/cudart.hpp +286 -0
- cuequivariance_ops/common/error.hpp +66 -0
- cuequivariance_ops/common/error_raft.hpp +323 -0
- cuequivariance_ops/common/nvtx.hpp +29 -0
- cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
- cuequivariance_ops/equivariance/dtypes.hh +65 -0
- cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
- cuequivariance_ops/equivariance/indexed_linear.hh +41 -0
- cuequivariance_ops/equivariance/run_fmha.h +192 -0
- cuequivariance_ops/equivariance/run_fmha_cudafree.h +176 -0
- cuequivariance_ops/equivariance/run_fmha_sm100.h +135 -0
- cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
- cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
- cuequivariance_ops/gpu_timing_kernels.hh +42 -0
- cuequivariance_ops/lib/libcue_ops.so +0 -0
- cuequivariance_ops/sleep.hh +40 -0
- cuequivariance_ops/triton/__init__.py +66 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37142 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.12.0.json +37132 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.12.0.json +55692 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
- cuequivariance_ops/triton/cache_manager.py +336 -0
- cuequivariance_ops/triton/fused_layer_norm_triton.py +546 -0
- cuequivariance_ops/triton/gated_gemm_triton.py +394 -0
- cuequivariance_ops/triton/pair_bias.py +365 -0
- cuequivariance_ops/triton/tuning_decorator.py +188 -0
- cuequivariance_ops/triton/utils.py +29 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/METADATA +182 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/RECORD +46 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/WHEEL +6 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/licenses/LICENSE +142 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/licenses/Third_party_attr.txt +24 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- cuequivariance_ops_cu12.libs/libnvfatbin-b51d3b3f.so.12.8.90 +0 -0
cuequivariance_ops/VERSION
@@ -0,0 +1 @@
+0.8.1
cuequivariance_ops/__init__.py
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from ._version import __version__, __git_commit__
+import os
+import sys
+import ctypes
+
+PREFERRED_LOAD_FLAG = ctypes.RTLD_LOCAL
+
+
+def root_dir():
+    try:
+        import importlib.metadata
+
+        dist = importlib.metadata.distribution("cuequivariance_ops")
+        root = dist.locate_file("cuequivariance_ops")
+    except Exception:
+        # last resort, will fail with writeable install
+        root = os.path.dirname(__file__)
+    return root
+
+
+def load_library():
+    try:
+        ctypes.CDLL(
+            os.path.join(root_dir(), "lib/libcue_ops.so"), mode=PREFERRED_LOAD_FLAG
+        )
+    except Exception as e:
+        print(f"Error while loading libcue_ops.so: {e}", file=sys.stderr)
+
+
+load_library()
+
+__all__ = ["__version__", "__git_commit__", "root_dir", "load_library"]
cuequivariance_ops/_version.py
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+
+import importlib.resources
+
+__version__ = (
+    importlib.resources.files("cuequivariance_ops")
+    .joinpath("VERSION")
+    .read_text()
+    .strip()
+)
+__git_commit__ = "release"
cuequivariance_ops/common/common.hpp
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+#include <cstdint>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <vector>
+
+namespace kernelcatcher {
+
+namespace tensor_product {
+/**
+ * @brief a wrapper struct containing information
+ * about tensor-product paths
+ */
+template <typename MathT>
+struct __attribute__((aligned(16))) tp_info {
+  /** offsets into `path_offsets_and_dims` for each "target" */
+  const int32_t* __restrict__ path_csr_offsets{nullptr};
+  /** "sources" of all paths and their offsets and dimensions */
+  const int32_t* __restrict__ path_offsets_and_dims{nullptr};
+  /** clebsch-gordan values for each path */
+  const MathT* __restrict__ path_cg_values{nullptr};
+  /** number of "target" segments */
+  int32_t num_target_segments{0};
+  /** number of paths (i.e. all paths between segments) */
+  int32_t num_paths{0};
+};  // struct tp_info
+
+enum class ConnectionModeT : uint8_t {
+  kUVW = 0,
+  // UVW with U spherical harmonic
+  k1VW,  // NOLINT
+  // UVW with V spherical harmonic
+  kU1W,  // NOLINT
+  kUVU,
+  kUVV,
+  kUUW,
+  kUUU,
+  // FullTP, no weight
+  kUVUV,
+  // FullTP, U spherical harmonic
+  k1V1V,
+  // FullTP, V spherical harmonic
+  kU1U1,
+  // Linear
+  kUUVV,
+};
+}  // namespace tensor_product
+
+namespace symmetric_tensor_contraction {
+/**
+ * @brief a wrapper struct containing information
+ * about tensor-product paths
+ */
+template <typename DataT>
+struct __attribute__((aligned(16))) clebsch_gordan_tensor {
+  const DataT* __restrict__ cg_values{nullptr};
+  const int16_t* __restrict__ cg_indices{nullptr};
+  const int32_t* __restrict__ cg_offsets{nullptr};
+  int32_t total_output_irreps{0};
+};  // struct clebsch_gordan_tensor
+}  // namespace symmetric_tensor_contraction
+
+namespace batch_linear {
+struct __attribute__((aligned(8))) MatrixLayout {
+  int32_t size_row;  // uncontracted mode
+  int32_t size_col;  // contracted mode
+};
+
+struct __attribute__((aligned(8))) IndexOffset {
+  int32_t start;
+  int32_t end;
+};
+
+enum class GemvModeT : std::uint8_t { kUVV = 0, kUUV = 1 };
+enum class WeightSharedModeT : std::int32_t { kShared = 0, kIndexed = 1, kBatched = 2 };
+/**
+ * @brief a wrapper struct containing information
+ * about tensor-product paths
+ */
+template <typename DataT>
+struct __attribute__((aligned(16))) batch_linear_info {
+  const MatrixLayout* __restrict__ layouts{nullptr};
+  const IndexOffset* __restrict__ index_offsets{nullptr};
+  const int32_t* __restrict__ indices{nullptr};
+  const DataT* __restrict__ alpha{nullptr};
+};  // struct batch_linear_info
+
+}  // namespace batch_linear
+}  // namespace kernelcatcher
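For orientation, `tp_info` describes a CSR-style layout: `path_csr_offsets` delimits, per target segment, a half-open range of path indices into the per-path arrays. The short traversal below is purely illustrative; it assumes (CSR convention, not stated in the header) that `path_csr_offsets` holds `num_target_segments + 1` entries and that the arrays are host-accessible, and it leaves `path_offsets_and_dims` untouched because its per-path record layout is not documented in this file. The include path is also an assumption.

// Illustrative only: walks the CSR-style indexing implied by tp_info's comments.
#include <cstdint>
#include <cstdio>

#include "common.hpp"  // path assumed; adjust to the actual include layout

template <typename MathT>
void print_cg_sums(const kernelcatcher::tensor_product::tp_info<MathT>& info)
{
  for (int32_t t = 0; t < info.num_target_segments; ++t) {
    MathT acc{0};
    // paths contributing to target segment t live in [offsets[t], offsets[t + 1])
    for (int32_t p = info.path_csr_offsets[t]; p < info.path_csr_offsets[t + 1]; ++p) {
      acc += info.path_cg_values[p];
    }
    std::printf("target %d: cg sum %f\n", t, static_cast<double>(acc));
  }
}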
cuequivariance_ops/common/cudart.hpp
@@ -0,0 +1,286 @@
+/*
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#include "error.hpp"
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_runtime_api.h>
+
+#include <cstdint>
+#include <initializer_list>
+#include <iomanip>
+#include <iostream>
+#include <stdexcept>
+#include <string>
+
+namespace kernelcatcher::utils {
+
+__device__ constexpr int sm_arch()
+{
+#ifdef __CUDA_ARCH__
+  return __CUDA_ARCH__;
+#else
+  return -1;
+#endif
+}
+
+template <typename DataT>
+__device__ constexpr bool valid_data_type_for_arch()
+{
+  // we only support sm_arch >= 700 anyways, atomics cause issues <= 600 for double, too
+  // so guard against that
+  if (std::is_same<DataT, double>::value && (sm_arch() < 700)) { return false; }
+  if (std::is_same<DataT, __half>::value && (sm_arch() < 700)) { return false; }
+  if (std::is_same<DataT, __nv_bfloat16>::value && (sm_arch() < 800)) { return false; }
+  return true;
+}
+
+template <typename DataT>
+__host__ __device__ constexpr int32_t get_native_veclen()
+{
+  // simplified alignment checks: use VECLEN for packed types of reduced precision,
+  // otherwise 1 (e.g. FP16 would have up to 2, FP8 could have up to 4)
+  // we usually already use rather many registers and don't have too many dimensions
+  // to iterate over to begin with, so this should simplify things for now
+  if (static_cast<int32_t>(sizeof(DataT)) >= 4) { return int32_t{1}; }
+  return static_cast<int32_t>(4 / static_cast<int32_t>(sizeof(DataT)));
+}
+
+template <typename DataT>
+void copy(DataT* out, const DataT* in, size_t len, cudaMemcpyKind kind)
+{
+  RAFT_CUDA_TRY(cudaMemcpy(out, in, sizeof(DataT) * len, kind));
+}
+
+template <typename DataT>
+void copy_async(DataT* out, const DataT* in, size_t len, cudaMemcpyKind kind, cudaStream_t stream)
+{
+  RAFT_CUDA_TRY(cudaMemcpyAsync(out, in, sizeof(DataT) * len, kind, stream));
+}
+
+template <typename DataT>
+void copy_async_no_throw(
+  DataT* out, const DataT* in, size_t len, cudaMemcpyKind kind, cudaStream_t stream) noexcept
+{
+  RAFT_CUDA_TRY_NO_THROW(cudaMemcpyAsync(out, in, sizeof(DataT) * len, kind, stream));
+}
+
+inline void sync(cudaStream_t stream = nullptr) { RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); }
+
+inline void sync_no_throw(cudaStream_t stream = nullptr) noexcept
+{
+  RAFT_CUDA_TRY_NO_THROW(cudaStreamSynchronize(stream));
+}
+
+template <typename DataT>
+inline void memset(DataT* out, size_t len, uint8_t byte_value = 0)
+{
+  RAFT_CUDA_TRY(cudaMemset(out, byte_value, len * sizeof(DataT)));
+}
+
+template <typename DataT>
+inline void memset_async(DataT* out, size_t len, cudaStream_t stream, uint8_t byte_value = 0)
+{
+  RAFT_CUDA_TRY(cudaMemsetAsync(out, byte_value, len * sizeof(DataT), stream));
+}
+
+template <typename DataT>
+inline void memset_async_no_throw(DataT* out,
+                                  size_t len,
+                                  cudaStream_t stream,
+                                  uint8_t byte_value = 0) noexcept
+{
+  RAFT_CUDA_TRY_NO_THROW(cudaMemsetAsync(out, byte_value, len * sizeof(DataT), stream));
+}
+
+inline int get_sm_count()
+{
+  int dev_id;
+  RAFT_CUDA_TRY(cudaGetDevice(&dev_id));
+  int mp_count;
+  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&mp_count, cudaDevAttrMultiProcessorCount, dev_id));
+  return mp_count;
+}
+
+template <class FuncT>
+inline int get_max_blocks_per_sm(FuncT func, int block_size, size_t dynamic_smem_size)
+{
+  int nblks;
+  RAFT_CUDA_TRY(
+    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&nblks, func, block_size, dynamic_smem_size));
+  return nblks;
+}
+
+template <class FuncT>
+inline int get_max_grid_blocks(FuncT func, int block_size, size_t dynamic_smem_size)
+{
+  return get_max_blocks_per_sm(func, block_size, dynamic_smem_size) * get_sm_count();
+}
+
+template <class FuncT>
+inline size_t get_static_smem_size(FuncT func)
+{
+  cudaFuncAttributes attrs;
+  RAFT_CUDA_TRY(cudaFuncGetAttributes(&attrs, func));
+  return attrs.sharedSizeBytes;
+}
+
+inline void get_max_smem_sizes(size_t& smem_block, size_t& smem_optin)
+{
+  int dev_id, smem_blk, smem_max;
+  RAFT_CUDA_TRY(cudaGetDevice(&dev_id));
+  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&smem_blk, cudaDevAttrMaxSharedMemoryPerBlock, dev_id));
+  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&smem_max, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_id));
+  smem_block = static_cast<size_t>(smem_blk);
+  smem_optin = static_cast<size_t>(smem_max);
+}
+
+inline int get_max_smem_per_block_optin()
+{
+  int dev_id;
+  cudaGetDevice(&dev_id);
+  int available_smem;
+  cudaDeviceGetAttribute(&available_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_id);
+  return available_smem;
+}
+
+inline int get_max_smem_per_sm()
+{
+  int dev_id;
+  cudaGetDevice(&dev_id);
+  int available_smem;
+  cudaDeviceGetAttribute(&available_smem, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id);
+  return available_smem;
+}
+
+inline int get_max_smem_per_block()
+{
+  int dev_id;
+  cudaGetDevice(&dev_id);
+  int available_smem;
+  cudaDeviceGetAttribute(&available_smem, cudaDevAttrMaxSharedMemoryPerBlock, dev_id);
+  return available_smem;
+}
+
+template <typename FuncT>
+inline void set_smem_optin(int32_t required_size, FuncT func)
+{
+  // opt-in with actual required size
+  RAFT_CUDA_TRY(
+    cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, required_size));
+}
+
+template <typename DataT, int ALIGN>
+inline bool is_aligned(std::initializer_list<const void*> ptrs,
+                       std::initializer_list<size_t> sizes = {})
+{
+  bool ret = ALIGN % sizeof(DataT) == 0;
+  for (const auto* p : ptrs)
+    ret = ret && reinterpret_cast<uintptr_t>(p) % ALIGN == 0;
+  for (auto s : sizes)
+    ret = ret && (s * sizeof(DataT)) % ALIGN == 0;
+  return ret;
+}
+
+/**
+ * @brief Get the device ID associated with a CUDA stream.
+ *
+ * This function retrieves the device ID for the given CUDA stream, using the most
+ * appropriate API based on the CUDA runtime version:
+ * - For CUDA 12.8+: Uses cudaStreamGetDevice() from the Runtime API
+ * - For older versions: Falls back to CUDA Driver API methods
+ *
+ * @param stream The CUDA stream to query. Can be a user-created stream or nullptr
+ *               for the default stream.
+ *
+ * @return The device ID (0-based index) associated with the stream.
+ *
+ * @throws std::runtime_error If any CUDA Driver API calls fail during fallback.
+ *                            The exception message includes the specific operation
+ *                            that failed and the CUDA error code.
+ *
+ * @note This function is compatible across CUDA versions and automatically
+ *       selects the best available method for stream-to-device mapping.
+ */
+inline int getDeviceFromStream(cudaStream_t stream)
+{
+  int runtimeVersion;
+  RAFT_CUDA_TRY(cudaRuntimeGetVersion(&runtimeVersion));
+
+  // CUDA 12.8 corresponds to version 12080
+  if (runtimeVersion >= 12080) {
+    // Use the new cudaStreamGetDevice function
+    int deviceId;
+    RAFT_CUDA_TRY(cudaStreamGetDevice(stream, &deviceId));
+    return deviceId;
+  }
+
+  // Fallback to Driver API method for older CUDA versions or if Runtime API fails
+  CUstream cuStream = static_cast<CUstream>(stream);
+  CUcontext context;
+  CUdevice device;
+
+  CUresult result = cuStreamGetCtx(cuStream, &context);
+  if (result != CUDA_SUCCESS) {
+    throw std::runtime_error("Failed to get context from stream: " + std::to_string(result));
+  }
+
+  result = cuCtxPushCurrent(context);
+  if (result != CUDA_SUCCESS) {
+    throw std::runtime_error("Failed to push context: " + std::to_string(result));
+  }
+
+  result = cuCtxGetDevice(&device);
+  if (result != CUDA_SUCCESS) {
+    cuCtxPopCurrent(&context);  // Clean up before throwing
+    throw std::runtime_error("Failed to get device from context: " + std::to_string(result));
+  }
+
+  result = cuCtxPopCurrent(&context);
+  if (result != CUDA_SUCCESS) {
+    throw std::runtime_error("Failed to pop context: " + std::to_string(result));
+  }
+
+  return static_cast<int>(device);
+}
+
+/**
+ * @brief Get the device properties for the device associated with a CUDA stream.
+ *
+ * This function retrieves the device properties for the device that owns the given
+ * CUDA stream. It combines getDeviceFromStream() and cudaGetDeviceProperties() to
+ * provide a convenient way to query device capabilities based on a stream handle.
+ *
+ * @param stream The CUDA stream to query. Can be a user-created stream or nullptr
+ *               for the default stream.
+ *
+ * @return cudaDeviceProp structure containing the device properties for the device
+ *         associated with the stream.
+ *
+ * @throws std::runtime_error If getDeviceFromStream() fails (CUDA Driver API errors)
+ *                            or if cudaGetDeviceProperties() fails (via RAFT_CUDA_TRY).
+ *
+ * @note This function is useful when you have a stream handle and need to query
+ *       device-specific capabilities like shared memory size, compute capability,
+ *       multiprocessor count, etc.
+ */
+inline cudaDeviceProp getDevicePropFromStream(cudaStream_t stream)
+{
+  int deviceId = getDeviceFromStream(stream);
+  cudaDeviceProp prop;
+  RAFT_CUDA_TRY(cudaGetDeviceProperties(&prop, deviceId));
+  return prop;
+}
+
+}  // namespace kernelcatcher::utils
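Taken together, these helpers cover the usual pre-launch bookkeeping: check whether a dynamic shared-memory request needs the opt-in limit, raise the kernel attribute, and cap the grid at what the device can keep resident. The sketch below strings them together; the kernel, block size, shared-memory request, and include path are placeholders and not part of the package, so treat it as a minimal sketch under those assumptions.

// Hypothetical launch-preparation sketch built only from the helpers above.
// my_kernel, kBlockSize, and the requested smem size are placeholders.
#include <algorithm>
#include <cstddef>

#include "cudart.hpp"  // path assumed; brings in the kernelcatcher::utils helpers

__global__ void my_kernel(float* out, int n)
{
  // grid-stride loop so a capped grid still covers all n elements
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x) {
    out[i] = 1.0f;
  }
}

inline void launch_with_opt_in_smem(float* d_out, int n, cudaStream_t stream)
{
  namespace kcu = kernelcatcher::utils;
  constexpr int kBlockSize  = 256;
  const size_t dynamic_smem = 48 * 1024;  // placeholder request

  // Opt in only when the request exceeds the default per-block limit but fits the opt-in maximum.
  if (dynamic_smem > static_cast<size_t>(kcu::get_max_smem_per_block()) &&
      dynamic_smem <= static_cast<size_t>(kcu::get_max_smem_per_block_optin())) {
    kcu::set_smem_optin(static_cast<int32_t>(dynamic_smem), my_kernel);
  }

  // Cap the grid at the number of blocks the device can keep resident.
  int max_blocks = kcu::get_max_grid_blocks(my_kernel, kBlockSize, dynamic_smem);
  int blocks     = std::min((n + kBlockSize - 1) / kBlockSize, max_blocks);

  my_kernel<<<blocks, kBlockSize, dynamic_smem, stream>>>(d_out, n);
  kcu::sync(stream);
}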
cuequivariance_ops/common/error.hpp
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+#include "error_raft.hpp"
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+namespace kernelcatcher::utils {
+
+template <typename DataT>
+void inline assert_data_type_support()
+{
+  // default data type is always supported
+}
+
+template <>
+void inline assert_data_type_support<__half>()
+{
+  int device, major;
+  RAFT_CUDA_TRY(cudaGetDevice(&device));
+  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+  ASSERT(major >= 7,
+         "Detected compute capability < 7, however requested DataType __half requires >= 7.");
+}
+
+template <>
+void inline assert_data_type_support<__half2>()
+{
+  int device, major;
+  RAFT_CUDA_TRY(cudaGetDevice(&device));
+  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+  ASSERT(major >= 7,
+         "Detected compute capability < 7, however requested DataType __half2 requires >= 7.");
+}
+
+template <>
+void inline assert_data_type_support<__nv_bfloat16>()
+{
+  int device, major;
+  RAFT_CUDA_TRY(cudaGetDevice(&device));
+  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+  ASSERT(
+    major >= 8,
+    "Detected compute capability < 8, however requested DataType __nv_bfloat16 requires >= 8.");
+}
+
+template <>
+void inline assert_data_type_support<__nv_bfloat162>()
+{
+  int device, major;
+  RAFT_CUDA_TRY(cudaGetDevice(&device));
+  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+  ASSERT(
+    major >= 8,
+    "Detected compute capability < 8, however requested DataType __nv_bfloat162 requires >= 8.");
+}
+
+}  // namespace kernelcatcher::utils
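These specializations are host-side guards: called before dispatching a reduced-precision kernel, they fail fast on devices whose compute capability is too low (the device-side counterpart is valid_data_type_for_arch() in cudart.hpp). A minimal usage sketch follows; launch_typed_op is a hypothetical wrapper, the include path is assumed, and ASSERT comes from error_raft.hpp, whose contents are not reproduced in this diff excerpt.

// Hypothetical dispatch sketch: guard a typed code path with the host-side
// check from error.hpp before launching any kernel for DataT.
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "error.hpp"  // path assumed

template <typename DataT>
void launch_typed_op(const DataT* in, DataT* out, int n, cudaStream_t stream)
{
  // Fails (via ASSERT from error_raft.hpp) on GPUs below the minimum compute
  // capability for DataT; the unspecialized template is a no-op for float/double.
  kernelcatcher::utils::assert_data_type_support<DataT>();

  // ... the actual kernel launch for DataT would go here ...
}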