cuequivariance-ops-cu12 0.8.1 (cuequivariance_ops_cu12-0.8.1-py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. cuequivariance_ops/VERSION +1 -0
  2. cuequivariance_ops/__init__.py +42 -0
  3. cuequivariance_ops/_version.py +20 -0
  4. cuequivariance_ops/common/common.hpp +98 -0
  5. cuequivariance_ops/common/cudart.hpp +286 -0
  6. cuequivariance_ops/common/error.hpp +66 -0
  7. cuequivariance_ops/common/error_raft.hpp +323 -0
  8. cuequivariance_ops/common/nvtx.hpp +29 -0
  9. cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
  10. cuequivariance_ops/equivariance/dtypes.hh +65 -0
  11. cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
  12. cuequivariance_ops/equivariance/indexed_linear.hh +41 -0
  13. cuequivariance_ops/equivariance/run_fmha.h +192 -0
  14. cuequivariance_ops/equivariance/run_fmha_cudafree.h +176 -0
  15. cuequivariance_ops/equivariance/run_fmha_sm100.h +135 -0
  16. cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
  17. cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
  18. cuequivariance_ops/gpu_timing_kernels.hh +42 -0
  19. cuequivariance_ops/lib/libcue_ops.so +0 -0
  20. cuequivariance_ops/sleep.hh +40 -0
  21. cuequivariance_ops/triton/__init__.py +66 -0
  22. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37142 -0
  23. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.12.0.json +37132 -0
  24. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
  25. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
  26. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
  27. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
  28. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
  29. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.12.0.json +55692 -0
  30. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
  31. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
  32. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
  33. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
  34. cuequivariance_ops/triton/cache_manager.py +336 -0
  35. cuequivariance_ops/triton/fused_layer_norm_triton.py +546 -0
  36. cuequivariance_ops/triton/gated_gemm_triton.py +394 -0
  37. cuequivariance_ops/triton/pair_bias.py +365 -0
  38. cuequivariance_ops/triton/tuning_decorator.py +188 -0
  39. cuequivariance_ops/triton/utils.py +29 -0
  40. cuequivariance_ops_cu12-0.8.1.dist-info/METADATA +182 -0
  41. cuequivariance_ops_cu12-0.8.1.dist-info/RECORD +46 -0
  42. cuequivariance_ops_cu12-0.8.1.dist-info/WHEEL +6 -0
  43. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/LICENSE +142 -0
  44. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/Third_party_attr.txt +24 -0
  45. cuequivariance_ops_cu12-0.8.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  46. cuequivariance_ops_cu12.libs/libnvfatbin-b51d3b3f.so.12.8.90 +0 -0
@@ -0,0 +1 @@
+ 0.8.1
@@ -0,0 +1,42 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ from ._version import __version__, __git_commit__
+ import os
+ import sys
+ import ctypes
+
+ PREFERRED_LOAD_FLAG = ctypes.RTLD_LOCAL
+
+
+ def root_dir():
+     try:
+         import importlib.metadata
+
+         dist = importlib.metadata.distribution("cuequivariance_ops")
+         root = dist.locate_file("cuequivariance_ops")
+     except Exception:
+         # last resort, will fail with writeable install
+         root = os.path.dirname(__file__)
+     return root
+
+
+ def load_library():
+     try:
+         ctypes.CDLL(
+             os.path.join(root_dir(), "lib/libcue_ops.so"), mode=PREFERRED_LOAD_FLAG
+         )
+     except Exception as e:
+         print(f"Error while loading libcue_ops.so: {e}", file=sys.stderr)
+
+
+ load_library()
+
+ __all__ = ["__version__", "__git_commit__", "root_dir", "load_library"]
@@ -0,0 +1,20 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+
+ import importlib.resources
+
+ __version__ = (
+     importlib.resources.files("cuequivariance_ops")
+     .joinpath("VERSION")
+     .read_text()
+     .strip()
+ )
+ __git_commit__ = "release"
@@ -0,0 +1,98 @@
+ /*
+  * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+  *
+  * This source code and/or documentation ("Licensed Deliverables") are
+  * subject to NVIDIA intellectual property rights under U.S. and
+  * international Copyright laws.
+  */
+
+ #pragma once
+
+ #include <cstdint>
+ #include <cuda_bf16.h>
+ #include <cuda_fp16.h>
+ #include <vector>
+
+ namespace kernelcatcher {
+
+ namespace tensor_product {
+ /**
+  * @brief a wrapper struct containing information
+  * about tensor-product paths
+  */
+ template <typename MathT>
+ struct __attribute__((aligned(16))) tp_info {
+   /** offsets into `path_offsets_and_dims` for each "target" */
+   const int32_t* __restrict__ path_csr_offsets{nullptr};
+   /** "sources" of all paths and their offsets and dimensions */
+   const int32_t* __restrict__ path_offsets_and_dims{nullptr};
+   /** clebsch-gordan values for each path */
+   const MathT* __restrict__ path_cg_values{nullptr};
+   /** number of "target" segments */
+   int32_t num_target_segments{0};
+   /** number of paths (i.e. all paths between segments) */
+   int32_t num_paths{0};
+ };  // struct tp_info
+
+ enum class ConnectionModeT : uint8_t {
+   kUVW = 0,
+   // UVW with U spherical harmonic
+   k1VW,  // NOLINT
+   // UVW with V spherical harmonic
+   kU1W,  // NOLINT
+   kUVU,
+   kUVV,
+   kUUW,
+   kUUU,
+   // FullTP, no weight
+   kUVUV,
+   // FullTP, U spherical harmonic
+   k1V1V,
+   // FullTP, V spherical harmonic
+   kU1U1,
+   // Linear
+   kUUVV,
+ };
+ }  // namespace tensor_product
+
+ namespace symmetric_tensor_contraction {
+ /**
+  * @brief a wrapper struct containing information
+  * about tensor-product paths
+  */
+ template <typename DataT>
+ struct __attribute__((aligned(16))) clebsch_gordan_tensor {
+   const DataT* __restrict__ cg_values{nullptr};
+   const int16_t* __restrict__ cg_indices{nullptr};
+   const int32_t* __restrict__ cg_offsets{nullptr};
+   int32_t total_output_irreps{0};
+ };  // struct clebsch_gordan_tensor
+ }  // namespace symmetric_tensor_contraction
+
+ namespace batch_linear {
+ struct __attribute__((aligned(8))) MatrixLayout {
+   int32_t size_row;  // uncontracted mode
+   int32_t size_col;  // contracted mode
+ };
+
+ struct __attribute__((aligned(8))) IndexOffset {
+   int32_t start;
+   int32_t end;
+ };
+
+ enum class GemvModeT : std::uint8_t { kUVV = 0, kUUV = 1 };
+ enum class WeightSharedModeT : std::int32_t { kShared = 0, kIndexed = 1, kBatched = 2 };
+ /**
+  * @brief a wrapper struct containing information
+  * about tensor-product paths
+  */
+ template <typename DataT>
+ struct __attribute__((aligned(16))) batch_linear_info {
+   const MatrixLayout* __restrict__ layouts{nullptr};
+   const IndexOffset* __restrict__ index_offsets{nullptr};
+   const int32_t* __restrict__ indices{nullptr};
+   const DataT* __restrict__ alpha{nullptr};
+ };  // struct batch_linear_info
+
+ }  // namespace batch_linear
+ }  // namespace kernelcatcher
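
For orientation, below is a minimal host-side sketch (not part of the package) of how a caller might populate `tp_info<float>` before handing it to a kernel. The include path and the exact packing of `path_offsets_and_dims` are assumptions; error handling is omitted.

// Hypothetical helper: fills tp_info<float> from host CSR-style path data.
#include <cuda_runtime.h>
#include <vector>
#include "common/common.hpp"  // path assumed; header shown above

using kernelcatcher::tensor_product::tp_info;

tp_info<float> make_tp_info(const std::vector<int32_t>& csr_offsets,   // size: num_target_segments + 1
                            const std::vector<int32_t>& offsets_dims,  // per-path source offsets/dims (layout assumed)
                            const std::vector<float>& cg_values)       // one Clebsch-Gordan value per path
{
  // Copy a host buffer to freshly allocated device memory (no error checks in this sketch).
  auto to_device = [](const void* src, size_t bytes) {
    void* dst = nullptr;
    cudaMalloc(&dst, bytes);
    cudaMemcpy(dst, src, bytes, cudaMemcpyHostToDevice);
    return dst;
  };

  tp_info<float> info{};
  info.path_csr_offsets =
    static_cast<const int32_t*>(to_device(csr_offsets.data(), csr_offsets.size() * sizeof(int32_t)));
  info.path_offsets_and_dims =
    static_cast<const int32_t*>(to_device(offsets_dims.data(), offsets_dims.size() * sizeof(int32_t)));
  info.path_cg_values =
    static_cast<const float*>(to_device(cg_values.data(), cg_values.size() * sizeof(float)));
  info.num_target_segments = static_cast<int32_t>(csr_offsets.size()) - 1;
  info.num_paths           = static_cast<int32_t>(cg_values.size());
  return info;
}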
@@ -0,0 +1,286 @@
+ /*
+  * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
+  *
+  * This source code and/or documentation ("Licensed Deliverables") are
+  * subject to NVIDIA intellectual property rights under U.S. and
+  * international Copyright laws.
+  */
+
+ #pragma once
+
+ #include <cuda_bf16.h>
+ #include <cuda_fp16.h>
+
+ #include "error.hpp"
+
+ #include <cuda.h>
+ #include <cuda_runtime.h>
+ #include <cuda_runtime_api.h>
+
+ #include <cstdint>
+ #include <initializer_list>
+ #include <iomanip>
+ #include <iostream>
+ #include <stdexcept>
+ #include <string>
+
+ namespace kernelcatcher::utils {
+
+ __device__ constexpr int sm_arch()
+ {
+ #ifdef __CUDA_ARCH__
+   return __CUDA_ARCH__;
+ #else
+   return -1;
+ #endif
+ }
+
+ template <typename DataT>
+ __device__ constexpr bool valid_data_type_for_arch()
+ {
+   // we only support sm_arch >= 700 anyway, and double atomics cause issues on sm <= 600 as well,
+   // so guard against that
+   if (std::is_same<DataT, double>::value && (sm_arch() < 700)) { return false; }
+   if (std::is_same<DataT, __half>::value && (sm_arch() < 700)) { return false; }
+   if (std::is_same<DataT, __nv_bfloat16>::value && (sm_arch() < 800)) { return false; }
+   return true;
+ }
+
+ template <typename DataT>
+ __host__ __device__ constexpr int32_t get_native_veclen()
+ {
+   // simplified alignment checks: use VECLEN for packed types of reduced precision,
+   // otherwise 1 (e.g. FP16 would have up to 2, FP8 could have up to 4)
+   // we usually already use rather many registers and don't have too many dimensions
+   // to iterate over to begin with, so this should simplify things for now
+   if (static_cast<int32_t>(sizeof(DataT)) >= 4) { return int32_t{1}; }
+   return static_cast<int32_t>(4 / static_cast<int32_t>(sizeof(DataT)));
+ }
+
+ template <typename DataT>
+ void copy(DataT* out, const DataT* in, size_t len, cudaMemcpyKind kind)
+ {
+   RAFT_CUDA_TRY(cudaMemcpy(out, in, sizeof(DataT) * len, kind));
+ }
+
+ template <typename DataT>
+ void copy_async(DataT* out, const DataT* in, size_t len, cudaMemcpyKind kind, cudaStream_t stream)
+ {
+   RAFT_CUDA_TRY(cudaMemcpyAsync(out, in, sizeof(DataT) * len, kind, stream));
+ }
+
+ template <typename DataT>
+ void copy_async_no_throw(
+   DataT* out, const DataT* in, size_t len, cudaMemcpyKind kind, cudaStream_t stream) noexcept
+ {
+   RAFT_CUDA_TRY_NO_THROW(cudaMemcpyAsync(out, in, sizeof(DataT) * len, kind, stream));
+ }
+
+ inline void sync(cudaStream_t stream = nullptr) { RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); }
+
+ inline void sync_no_throw(cudaStream_t stream = nullptr) noexcept
+ {
+   RAFT_CUDA_TRY_NO_THROW(cudaStreamSynchronize(stream));
+ }
+
+ template <typename DataT>
+ inline void memset(DataT* out, size_t len, uint8_t byte_value = 0)
+ {
+   RAFT_CUDA_TRY(cudaMemset(out, byte_value, len * sizeof(DataT)));
+ }
+
+ template <typename DataT>
+ inline void memset_async(DataT* out, size_t len, cudaStream_t stream, uint8_t byte_value = 0)
+ {
+   RAFT_CUDA_TRY(cudaMemsetAsync(out, byte_value, len * sizeof(DataT), stream));
+ }
+
+ template <typename DataT>
+ inline void memset_async_no_throw(DataT* out,
+                                   size_t len,
+                                   cudaStream_t stream,
+                                   uint8_t byte_value = 0) noexcept
+ {
+   RAFT_CUDA_TRY_NO_THROW(cudaMemsetAsync(out, byte_value, len * sizeof(DataT), stream));
+ }
+
+ inline int get_sm_count()
+ {
+   int dev_id;
+   RAFT_CUDA_TRY(cudaGetDevice(&dev_id));
+   int mp_count;
+   RAFT_CUDA_TRY(cudaDeviceGetAttribute(&mp_count, cudaDevAttrMultiProcessorCount, dev_id));
+   return mp_count;
+ }
+
+ template <class FuncT>
+ inline int get_max_blocks_per_sm(FuncT func, int block_size, size_t dynamic_smem_size)
+ {
+   int nblks;
+   RAFT_CUDA_TRY(
+     cudaOccupancyMaxActiveBlocksPerMultiprocessor(&nblks, func, block_size, dynamic_smem_size));
+   return nblks;
+ }
+
+ template <class FuncT>
+ inline int get_max_grid_blocks(FuncT func, int block_size, size_t dynamic_smem_size)
+ {
+   return get_max_blocks_per_sm(func, block_size, dynamic_smem_size) * get_sm_count();
+ }
+
+ template <class FuncT>
+ inline size_t get_static_smem_size(FuncT func)
+ {
+   cudaFuncAttributes attrs;
+   RAFT_CUDA_TRY(cudaFuncGetAttributes(&attrs, func));
+   return attrs.sharedSizeBytes;
+ }
+
+ inline void get_max_smem_sizes(size_t& smem_block, size_t& smem_optin)
+ {
+   int dev_id, smem_blk, smem_max;
+   RAFT_CUDA_TRY(cudaGetDevice(&dev_id));
+   RAFT_CUDA_TRY(cudaDeviceGetAttribute(&smem_blk, cudaDevAttrMaxSharedMemoryPerBlock, dev_id));
+   RAFT_CUDA_TRY(cudaDeviceGetAttribute(&smem_max, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_id));
+   smem_block = static_cast<size_t>(smem_blk);
+   smem_optin = static_cast<size_t>(smem_max);
+ }
+
+ inline int get_max_smem_per_block_optin()
+ {
+   int dev_id;
+   cudaGetDevice(&dev_id);
+   int available_smem;
+   cudaDeviceGetAttribute(&available_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_id);
+   return available_smem;
+ }
+
+ inline int get_max_smem_per_sm()
+ {
+   int dev_id;
+   cudaGetDevice(&dev_id);
+   int available_smem;
+   cudaDeviceGetAttribute(&available_smem, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id);
+   return available_smem;
+ }
+
+ inline int get_max_smem_per_block()
+ {
+   int dev_id;
+   cudaGetDevice(&dev_id);
+   int available_smem;
+   cudaDeviceGetAttribute(&available_smem, cudaDevAttrMaxSharedMemoryPerBlock, dev_id);
+   return available_smem;
+ }
+
+ template <typename FuncT>
+ inline void set_smem_optin(int32_t required_size, FuncT func)
+ {
+   // opt-in with actual required size
+   RAFT_CUDA_TRY(
+     cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, required_size));
+ }
+
+ template <typename DataT, int ALIGN>
+ inline bool is_aligned(std::initializer_list<const void*> ptrs,
+                        std::initializer_list<size_t> sizes = {})
+ {
+   bool ret = ALIGN % sizeof(DataT) == 0;
+   for (const auto* p : ptrs)
+     ret = ret && reinterpret_cast<uintptr_t>(p) % ALIGN == 0;
+   for (auto s : sizes)
+     ret = ret && (s * sizeof(DataT)) % ALIGN == 0;
+   return ret;
+ }
+
+ /**
+  * @brief Get the device ID associated with a CUDA stream.
+  *
+  * This function retrieves the device ID for the given CUDA stream, using the most
+  * appropriate API based on the CUDA runtime version:
+  * - For CUDA 12.8+: Uses cudaStreamGetDevice() from the Runtime API
+  * - For older versions: Falls back to CUDA Driver API methods
+  *
+  * @param stream The CUDA stream to query. Can be a user-created stream or nullptr
+  *               for the default stream.
+  *
+  * @return The device ID (0-based index) associated with the stream.
+  *
+  * @throws std::runtime_error If any CUDA Driver API calls fail during fallback.
+  *                            The exception message includes the specific operation
+  *                            that failed and the CUDA error code.
+  *
+  * @note This function is compatible across CUDA versions and automatically
+  *       selects the best available method for stream-to-device mapping.
+  */
+ inline int getDeviceFromStream(cudaStream_t stream)
+ {
+   int runtimeVersion;
+   RAFT_CUDA_TRY(cudaRuntimeGetVersion(&runtimeVersion));
+
+   // CUDA 12.8 corresponds to version 12080
+   if (runtimeVersion >= 12080) {
+     // Use the new cudaStreamGetDevice function
+     int deviceId;
+     RAFT_CUDA_TRY(cudaStreamGetDevice(stream, &deviceId));
+     return deviceId;
+   }
+
+   // Fallback to Driver API method for older CUDA versions or if Runtime API fails
+   CUstream cuStream = static_cast<CUstream>(stream);
+   CUcontext context;
+   CUdevice device;
+
+   CUresult result = cuStreamGetCtx(cuStream, &context);
+   if (result != CUDA_SUCCESS) {
+     throw std::runtime_error("Failed to get context from stream: " + std::to_string(result));
+   }
+
+   result = cuCtxPushCurrent(context);
+   if (result != CUDA_SUCCESS) {
+     throw std::runtime_error("Failed to push context: " + std::to_string(result));
+   }
+
+   result = cuCtxGetDevice(&device);
+   if (result != CUDA_SUCCESS) {
+     cuCtxPopCurrent(&context);  // Clean up before throwing
+     throw std::runtime_error("Failed to get device from context: " + std::to_string(result));
+   }
+
+   result = cuCtxPopCurrent(&context);
+   if (result != CUDA_SUCCESS) {
+     throw std::runtime_error("Failed to pop context: " + std::to_string(result));
+   }
+
+   return static_cast<int>(device);
+ }
+
+ /**
+  * @brief Get the device properties for the device associated with a CUDA stream.
+  *
+  * This function retrieves the device properties for the device that owns the given
+  * CUDA stream. It combines getDeviceFromStream() and cudaGetDeviceProperties() to
+  * provide a convenient way to query device capabilities based on a stream handle.
+  *
+  * @param stream The CUDA stream to query. Can be a user-created stream or nullptr
+  *               for the default stream.
+  *
+  * @return cudaDeviceProp structure containing the device properties for the device
+  *         associated with the stream.
+  *
+  * @throws std::runtime_error If getDeviceFromStream() fails (CUDA Driver API errors)
+  *                            or if cudaGetDeviceProperties() fails (via RAFT_CUDA_TRY).
+  *
+  * @note This function is useful when you have a stream handle and need to query
+  *       device-specific capabilities like shared memory size, compute capability,
+  *       multiprocessor count, etc.
+  */
+ inline cudaDeviceProp getDevicePropFromStream(cudaStream_t stream)
+ {
+   int deviceId = getDeviceFromStream(stream);
+   cudaDeviceProp prop;
+   RAFT_CUDA_TRY(cudaGetDeviceProperties(&prop, deviceId));
+   return prop;
+ }
+
+ }  // namespace kernelcatcher::utils
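
As a rough illustration (not part of the package), the sketch below shows how these helpers could be combined in a CUDA translation unit: resolve device properties from a stream, opt a kernel in to the maximum dynamic shared memory, and launch. The include path and `my_kernel` are placeholders.

// Hypothetical usage of the cudart.hpp helpers; error handling beyond the helpers omitted.
#include <cuda_runtime.h>
#include <cstdio>
#include "common/cudart.hpp"  // path assumed; header shown above

__global__ void my_kernel() {}  // placeholder kernel

int main()
{
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Device properties resolved from the stream handle
  // (uses cudaStreamGetDevice on CUDA 12.8+, driver-API fallback otherwise).
  cudaDeviceProp prop = kernelcatcher::utils::getDevicePropFromStream(stream);
  std::printf("sm_%d%d, %d SMs\n", prop.major, prop.minor, kernelcatcher::utils::get_sm_count());

  // Opt in to the largest dynamic shared-memory carveout the device allows for this kernel.
  int smem = kernelcatcher::utils::get_max_smem_per_block_optin();
  kernelcatcher::utils::set_smem_optin(smem, my_kernel);

  my_kernel<<<1, 128, smem, stream>>>();
  kernelcatcher::utils::sync(stream);
  cudaStreamDestroy(stream);
  return 0;
}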
@@ -0,0 +1,66 @@
+ /*
+  * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
+  *
+  * This source code and/or documentation ("Licensed Deliverables") are
+  * subject to NVIDIA intellectual property rights under U.S. and
+  * international Copyright laws.
+  */
+
+ #pragma once
+
+ #include "error_raft.hpp"
+
+ #include <cuda_bf16.h>
+ #include <cuda_fp16.h>
+
+ namespace kernelcatcher::utils {
+
+ template <typename DataT>
+ void inline assert_data_type_support()
+ {
+   // default data type is always supported
+ }
+
+ template <>
+ void inline assert_data_type_support<__half>()
+ {
+   int device, major;
+   RAFT_CUDA_TRY(cudaGetDevice(&device));
+   RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+   ASSERT(major >= 7,
+          "Detected compute capability < 7, however requested DataType __half requires >= 7.");
+ }
+
+ template <>
+ void inline assert_data_type_support<__half2>()
+ {
+   int device, major;
+   RAFT_CUDA_TRY(cudaGetDevice(&device));
+   RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+   ASSERT(major >= 7,
+          "Detected compute capability < 7, however requested DataType __half2 requires >= 7.");
+ }
+
+ template <>
+ void inline assert_data_type_support<__nv_bfloat16>()
+ {
+   int device, major;
+   RAFT_CUDA_TRY(cudaGetDevice(&device));
+   RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+   ASSERT(
+     major >= 8,
+     "Detected compute capability < 8, however requested DataType __nv_bfloat16 requires >= 8.");
+ }
+
+ template <>
+ void inline assert_data_type_support<__nv_bfloat162>()
+ {
+   int device, major;
+   RAFT_CUDA_TRY(cudaGetDevice(&device));
+   RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+   ASSERT(
+     major >= 8,
+     "Detected compute capability < 8, however requested DataType __nv_bfloat162 requires >= 8.");
+ }
+
+ }  // namespace kernelcatcher::utils
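
A short sketch (not part of the package) of how a caller might guard a reduced-precision code path with these checks; the include path and `launch_bf16_path` are placeholders, and `ASSERT`/`RAFT_CUDA_TRY` come from error_raft.hpp, which error.hpp includes.

// Hypothetical guard before dispatching bf16 kernels.
#include <cuda_bf16.h>
#include "common/error.hpp"  // path assumed; header shown above

void launch_bf16_path(/* ... */)
{
  // Fails via ASSERT when the current device has compute capability < 8.
  kernelcatcher::utils::assert_data_type_support<__nv_bfloat16>();
  // ... safe to dispatch __nv_bfloat16 kernels from here ...
}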