cuequivariance-ops-cu12 0.8.1__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuequivariance_ops/VERSION +1 -0
- cuequivariance_ops/__init__.py +42 -0
- cuequivariance_ops/_version.py +20 -0
- cuequivariance_ops/common/common.hpp +98 -0
- cuequivariance_ops/common/cudart.hpp +286 -0
- cuequivariance_ops/common/error.hpp +66 -0
- cuequivariance_ops/common/error_raft.hpp +323 -0
- cuequivariance_ops/common/nvtx.hpp +29 -0
- cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
- cuequivariance_ops/equivariance/dtypes.hh +65 -0
- cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
- cuequivariance_ops/equivariance/indexed_linear.hh +41 -0
- cuequivariance_ops/equivariance/run_fmha.h +192 -0
- cuequivariance_ops/equivariance/run_fmha_cudafree.h +176 -0
- cuequivariance_ops/equivariance/run_fmha_sm100.h +135 -0
- cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
- cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
- cuequivariance_ops/gpu_timing_kernels.hh +42 -0
- cuequivariance_ops/lib/libcue_ops.so +0 -0
- cuequivariance_ops/sleep.hh +40 -0
- cuequivariance_ops/triton/__init__.py +66 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37142 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.12.0.json +37132 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.12.0.json +55692 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
- cuequivariance_ops/triton/cache_manager.py +336 -0
- cuequivariance_ops/triton/fused_layer_norm_triton.py +546 -0
- cuequivariance_ops/triton/gated_gemm_triton.py +394 -0
- cuequivariance_ops/triton/pair_bias.py +365 -0
- cuequivariance_ops/triton/tuning_decorator.py +188 -0
- cuequivariance_ops/triton/utils.py +29 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/METADATA +182 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/RECORD +46 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/WHEEL +6 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/licenses/LICENSE +142 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/licenses/Third_party_attr.txt +24 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- cuequivariance_ops_cu12.libs/libnvfatbin-b51d3b3f.so.12.8.90 +0 -0
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#ifndef __RAFT_RT_ERROR
|
|
18
|
+
#define __RAFT_RT_ERROR
|
|
19
|
+
|
|
20
|
+
#pragma once
|
|
21
|
+
|
|
22
|
+
#if defined(__GNUC__) && __has_include(<cxxabi.h>) && __has_include(<execinfo.h>)
|
|
23
|
+
#define ENABLE_COLLECT_CALLSTACK
|
|
24
|
+
#endif
|
|
25
|
+
|
|
26
|
+
#include <cstdio>
|
|
27
|
+
#include <iostream>
|
|
28
|
+
#include <memory>
|
|
29
|
+
#include <stdexcept>
|
|
30
|
+
#include <string>
|
|
31
|
+
#include <vector>
|
|
32
|
+
|
|
33
|
+
#ifdef ENABLE_COLLECT_CALLSTACK
|
|
34
|
+
#include <cxxabi.h>
|
|
35
|
+
#include <execinfo.h>
|
|
36
|
+
|
|
37
|
+
#include <sstream>
|
|
38
|
+
#endif
|
|
39
|
+
|
|
40
|
+
namespace raft {

/**
 * @defgroup error_handling Exceptions & Error Handling
 * @{
 */

/**
 * @brief Base exception class for the whole of raft.
 *
 * Holds a message string. When ENABLE_COLLECT_CALLSTACK is defined, every
 * constructor except the default one appends a (demangled where possible)
 * backtrace to that message.
 */
class exception : public std::exception {
 public:
  /** default ctor: empty message, no call stack is collected */
  explicit exception() noexcept : std::exception(), msg_() {}

  /**
   * copy ctor
   *
   * Copies the full message of @p src and then appends the call stack of the
   * copy site, so (with call-stack collection enabled) the copy's what() is
   * longer than the source's.
   */
  exception(exception const& src) noexcept : std::exception(), msg_(src.what())
  {
    collect_call_stack();
  }

  /**
   * ctor from an input message
   *
   * The parameter is a non-const by-value sink: with a const parameter
   * std::move() would silently fall back to a copy.
   */
  explicit exception(std::string msg) noexcept : std::exception(), msg_(std::move(msg))
  {
    collect_call_stack();
  }

  /** get the message associated with this exception */
  char const* what() const noexcept override { return msg_.c_str(); }

 private:
  /** message associated with this exception */
  std::string msg_;

  /** append call stack info to this exception's message for ease of debug */
  // Courtesy: https://www.gnu.org/software/libc/manual/html_node/Backtraces.html
  void collect_call_stack() noexcept
  {
#ifdef ENABLE_COLLECT_CALLSTACK
    constexpr int kSkipFrames    = 1;  // the first frame is never printed
    constexpr int kMaxStackDepth = 64;
    void* stack[kMaxStackDepth];  // NOLINT
    int const depth = backtrace(stack, kMaxStackDepth);
    // backtrace() may return 0 frames; never report a negative count.
    int const reported_frames = depth > kSkipFrames ? depth - kSkipFrames : 0;

    std::ostringstream oss;
    oss << '\n' << "Obtained " << reported_frames << " stack frames" << '\n';

    char** strings = backtrace_symbols(stack, depth);
    if (strings == nullptr) {
      oss << "But no stack trace could be found!" << '\n';
      msg_ += oss.str();
      return;
    }

    // Courtesy: https://panthema.net/2008/0901-stacktrace-demangled/
    for (int i = kSkipFrames; i < depth; i++) {
      oss << "#" << i << " in ";  // beginning of the backtrace line

      char* mangled_name  = nullptr;
      char* offset_begin  = nullptr;
      char* offset_end    = nullptr;
      auto backtrace_line = strings[i];

      // Find parentheses and +address offset surrounding mangled name
      // e.g. ./module(function+0x15c) [0x8048a6d]
      for (char* p = backtrace_line; *p != 0; p++) {
        if (*p == '(') {
          mangled_name = p;
        } else if (*p == '+') {
          offset_begin = p;
        } else if (*p == ')') {
          offset_end = p;
          break;
        }
      }

      // Attempt to demangle the symbol; requires a non-empty name between '(' and '+'
      if (mangled_name != nullptr && offset_begin != nullptr && offset_end != nullptr &&
          mangled_name + 1 < offset_begin) {
        // Split the backtrace_line in place into module / name / offset / remainder
        *mangled_name++ = 0;
        *offset_begin++ = 0;
        *offset_end++   = 0;

        // Demangle the name part
        int status      = 0;
        char* real_name = abi::__cxa_demangle(mangled_name, nullptr, nullptr, &status);

        if (status == 0) {  // Success: substitute the real name
          oss << backtrace_line << ": " << real_name << " +" << offset_begin << offset_end;
        } else {  // Couldn't demangle
          oss << backtrace_line << ": " << mangled_name << " +" << offset_begin << offset_end;
        }
        free(real_name);  // __cxa_demangle result is malloc-allocated (nullptr on failure)
      } else {            // Couldn't match the symbol name
        oss << backtrace_line;
      }
      oss << '\n';
    }
    free(strings);  // backtrace_symbols returns one malloc-allocated block
    msg_ += oss.str();
#endif
  }
};

/**
 * @brief Exception thrown when logical precondition is violated.
 *
 * This exception should not be thrown directly and is instead thrown by the
 * RAFT_EXPECTS and RAFT_FAIL macros.
 *
 */
struct logic_error : public raft::exception {
  explicit logic_error(char const* const message) : raft::exception(message) {}
  explicit logic_error(std::string const& message) : raft::exception(message) {}
};

/**
 * @brief Exception thrown when attempting to use CUDA features from a non-CUDA
 * build
 *
 */
struct non_cuda_build_error : public raft::exception {
  explicit non_cuda_build_error(char const* const message) : raft::exception(message) {}
  explicit non_cuda_build_error(std::string const& message) : raft::exception(message) {}
};

/**
 * @brief Exception thrown when a CUDA error is encountered.
 */
struct cuda_error : public raft::exception {
  explicit cuda_error(char const* const message) : raft::exception(message) {}
  explicit cuda_error(std::string const& message) : raft::exception(message) {}
};

/**
 * @}
 */

}  // namespace raft
|
|
175
|
+
|
|
176
|
+
// FIXME: Need to be replaced with RAFT_FAIL
|
|
177
|
+
/**
 * @brief Macro to throw a raft::exception with a printf-style message.
 *
 * The message is "exception occurred! file=<file> line=<line>: " followed by
 * the formatted @p fmt.
 *
 * @param fmt printf-style format string
 * @param ... arguments for @p fmt; they are evaluated twice (once to measure,
 *            once to format), so they must not have side effects
 * @throw raft::exception always
 *
 * All locals carry a `raft_throw_` prefix so that format arguments naming
 * common identifiers (`size`, `buf`, `msg`, ...) still refer to the caller's
 * variables rather than to the macro's own.
 */
#define THROW(fmt, ...)                                                                        \
  do {                                                                                         \
    /* first pass: measure both pieces without writing anything */                             \
    int const raft_throw_len1_ =                                                               \
      std::snprintf(nullptr, 0, "exception occurred! file=%s line=%d: ", __FILE__, __LINE__);  \
    int const raft_throw_len2_ = std::snprintf(nullptr, 0, fmt, ##__VA_ARGS__);                \
    if (raft_throw_len1_ < 0 || raft_throw_len2_ < 0)                                          \
      throw raft::exception("Error in snprintf, cannot handle raft exception.");               \
    auto const raft_throw_total_ =                                                             \
      raft_throw_len1_ + raft_throw_len2_ + 1; /* +1 for final '\0' */                         \
    auto raft_throw_buf_ =                                                                     \
      std::make_unique<char[]>(static_cast<std::size_t>(raft_throw_total_));                   \
    /* second pass: write prefix, then the user message right behind it */                     \
    std::snprintf(raft_throw_buf_.get(),                                                       \
                  raft_throw_len1_ + 1 /* +1 for '\0' */,                                      \
                  "exception occurred! file=%s line=%d: ",                                     \
                  __FILE__,                                                                    \
                  __LINE__);                                                                   \
    std::snprintf(raft_throw_buf_.get() + raft_throw_len1_,                                    \
                  raft_throw_len2_ + 1 /* +1 for '\0' */,                                      \
                  fmt,                                                                         \
                  ##__VA_ARGS__);                                                              \
    throw raft::exception(                                                                     \
      std::string(raft_throw_buf_.get(),                                                       \
                  raft_throw_buf_.get() + raft_throw_total_ - 1)); /* -1 drops final '\0' */   \
  } while (0)
|
|
196
|
+
|
|
197
|
+
// FIXME: Need to be replaced with RAFT_EXPECTS
|
|
198
|
+
/**
 * @brief Macro that evaluates @p check and, when it is false, raises through
 * THROW with the given printf-style message.
 *
 * @param check condition expected to hold
 * @param fmt   printf-style format string forwarded to THROW
 * @param ...   arguments for @p fmt
 */
#define ASSERT(check, fmt, ...)    \
  do {                             \
    if (!(check)) {                \
      THROW(fmt, ##__VA_ARGS__);   \
    }                              \
  } while (0)
|
|
203
|
+
|
|
204
|
+
/**
 * Macro to append error message to first argument.
 * This should only be called in contexts where it is OK to throw exceptions!
 *
 * Appends "<location_prefix>file=<file> line=<line>: <formatted fmt>" to
 * @p msg. Throws raft::exception if snprintf reports an encoding error.
 *
 * @param msg             std::string lvalue that receives the text
 * @param location_prefix C string placed in front of the file/line part
 * @param fmt             printf-style format string
 * @param ...             arguments for @p fmt; evaluated twice (measure, then
 *                        format), so they must not have side effects
 *
 * All locals carry a `raft_sem_` prefix so that format arguments naming
 * common identifiers (`size`, `buf`, ...) still refer to the caller's
 * variables rather than to the macro's own.
 */
#define SET_ERROR_MSG(msg, location_prefix, fmt, ...)                                         \
  do {                                                                                        \
    /* first pass: measure the three pieces */                                                \
    int const raft_sem_len1_ = std::snprintf(nullptr, 0, "%s", location_prefix);              \
    int const raft_sem_len2_ =                                                                \
      std::snprintf(nullptr, 0, "file=%s line=%d: ", __FILE__, __LINE__);                     \
    int const raft_sem_len3_ = std::snprintf(nullptr, 0, fmt, ##__VA_ARGS__);                 \
    if (raft_sem_len1_ < 0 || raft_sem_len2_ < 0 || raft_sem_len3_ < 0)                       \
      throw raft::exception("Error in snprintf, cannot handle raft exception.");              \
    auto const raft_sem_total_ =                                                              \
      raft_sem_len1_ + raft_sem_len2_ + raft_sem_len3_ + 1; /* +1 for final '\0' */           \
    std::vector<char> raft_sem_buf_(static_cast<std::size_t>(raft_sem_total_));               \
    /* second pass: write the pieces back to back */                                          \
    std::snprintf(                                                                            \
      raft_sem_buf_.data(), raft_sem_len1_ + 1 /* +1 for '\0' */, "%s", location_prefix);     \
    std::snprintf(raft_sem_buf_.data() + raft_sem_len1_,                                      \
                  raft_sem_len2_ + 1 /* +1 for '\0' */,                                       \
                  "file=%s line=%d: ",                                                        \
                  __FILE__,                                                                   \
                  __LINE__);                                                                  \
    std::snprintf(raft_sem_buf_.data() + raft_sem_len1_ + raft_sem_len2_,                     \
                  raft_sem_len3_ + 1 /* +1 for '\0' */,                                       \
                  fmt,                                                                        \
                  ##__VA_ARGS__);                                                             \
    (msg) += std::string(raft_sem_buf_.data(),                                                \
                         raft_sem_buf_.data() + raft_sem_total_ - 1); /* drop final '\0' */   \
  } while (0)
|
|
223
|
+
|
|
224
|
+
/**
|
|
225
|
+
* @defgroup assertion Assertion and error macros
|
|
226
|
+
* @{
|
|
227
|
+
*/
|
|
228
|
+
|
|
229
|
+
/**
 * @brief Macro for checking (pre-)conditions that throws an exception when a condition is false
 *
 * @param[in] cond Expression that evaluates to true or false
 * @param[in] fmt String literal description of the reason that cond is expected to be true with
 * optional format tags
 * @throw raft::logic_error if the condition evaluates to false.
 *
 * The message is built by SET_ERROR_MSG with the prefix "RAFT failure at ".
 */
#define RAFT_EXPECTS(cond, fmt, ...)                                               \
  do {                                                                             \
    if (!(cond)) {                                                                 \
      /* uniquely named so that format arguments mentioning `msg` still */         \
      /* refer to the caller's variable, not to this local */                      \
      std::string raft_expects_msg_{};                                             \
      SET_ERROR_MSG(raft_expects_msg_, "RAFT failure at ", fmt, ##__VA_ARGS__);    \
      throw raft::logic_error(raft_expects_msg_);                                  \
    }                                                                              \
  } while (0)
|
|
245
|
+
|
|
246
|
+
/**
 * @brief Indicates that an erroneous code path has been taken.
 *
 * @param[in] fmt String literal description of the reason that this code path is erroneous with
 * optional format tags
 * @throw always throws raft::logic_error
 *
 * The message is built by SET_ERROR_MSG with the prefix "RAFT failure at ".
 */
#define RAFT_FAIL(fmt, ...)                                                     \
  do {                                                                          \
    /* uniquely named so that format arguments mentioning `msg` still */        \
    /* refer to the caller's variable, not to this local */                     \
    std::string raft_fail_msg_{};                                               \
    SET_ERROR_MSG(raft_fail_msg_, "RAFT failure at ", fmt, ##__VA_ARGS__);      \
    throw raft::logic_error(raft_fail_msg_);                                    \
  } while (0)
|
|
259
|
+
|
|
260
|
+
/**
|
|
261
|
+
* @}
|
|
262
|
+
*/
|
|
263
|
+
|
|
264
|
+
#endif
|
|
265
|
+
|
|
266
|
+
/**
 * @brief Error checking macro for CUDA runtime API functions.
 *
 * Invokes a CUDA runtime API function call, if the call does not return
 * cudaSuccess, invokes cudaGetLastError() to clear the error and throws an
 * exception detailing the CUDA error that occurred
 *
 * @param call expression yielding a cudaError_t; evaluated exactly once
 * @throw raft::cuda_error if the result is not cudaSuccess
 *
 * The locals carry a `raft_cuda_try_` prefix: with a local named `status`,
 * a call expression that itself mentions `status` would read the macro's
 * not-yet-initialized variable instead of the caller's.
 */
#define RAFT_CUDA_TRY(call)                                             \
  do {                                                                  \
    cudaError_t const raft_cuda_try_status_ = (call);                   \
    if (raft_cuda_try_status_ != cudaSuccess) {                         \
      cudaGetLastError(); /* result intentionally ignored */            \
      std::string raft_cuda_try_msg_{};                                 \
      SET_ERROR_MSG(raft_cuda_try_msg_,                                 \
                    "CUDA error encountered at: ",                      \
                    "call='%s', Reason=%s:%s",                          \
                    #call,                                              \
                    cudaGetErrorName(raft_cuda_try_status_),            \
                    cudaGetErrorString(raft_cuda_try_status_));         \
      throw raft::cuda_error(raft_cuda_try_msg_);                       \
    }                                                                   \
  } while (0)
|
|
289
|
+
|
|
290
|
+
/**
 * @brief Debug macro to check for CUDA errors
 *
 * When NDEBUG is defined (release build) the macro only inspects the pending
 * error state via cudaPeekAtLastError() and does not use @p stream. Otherwise
 * it synchronizes @p stream with cudaStreamSynchronize() and checks that
 * call's result. Either way the check goes through RAFT_CUDA_TRY, so a
 * reported error is thrown as an exception detailing the CUDA error.
 *
 * The intent of this macro is to provide a mechanism for synchronous and
 * deterministic execution for debugging asynchronous CUDA execution. It should
 * be used after any asynchronous CUDA call, e.g., cudaMemcpyAsync, or an
 * asynchronous kernel launch.
 *
 * NOTE(review): both definitions end in ';', so `RAFT_CHECK_CUDA(s);` expands
 * to two statements and cannot be the body of an unbraced if/else. Kept as-is
 * for compatibility with existing call sites.
 */
#ifdef NDEBUG
#define RAFT_CHECK_CUDA(stream) RAFT_CUDA_TRY(cudaPeekAtLastError());
#else
#define RAFT_CHECK_CUDA(stream) RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
#endif
|
|
308
|
+
|
|
309
|
+
/**
 * @brief check for cuda runtime API errors but log error instead of raising
 *        exception.
 *
 * @param call expression yielding a cudaError_t; evaluated exactly once
 *
 * On failure one line naming the call, file, line and error string is
 * printed to stdout. The local carries a `raft_cuda_nothrow_` prefix: with a
 * local named `status`, a call expression that itself mentions `status`
 * would read the macro's not-yet-initialized variable.
 */
#define RAFT_CUDA_TRY_NO_THROW(call)                                          \
  do {                                                                        \
    cudaError_t const raft_cuda_nothrow_status_ = (call);                     \
    if (cudaSuccess != raft_cuda_nothrow_status_) {                           \
      std::printf("CUDA call='%s' at file=%s line=%d failed with %s\n",       \
                  #call,                                                      \
                  __FILE__,                                                   \
                  __LINE__,                                                   \
                  cudaGetErrorString(raft_cuda_nothrow_status_));             \
    }                                                                         \
  } while (0)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
|
3
|
+
*
|
|
4
|
+
* This source code and/or documentation ("Licensed Deliverables") are
|
|
5
|
+
* subject to NVIDIA intellectual property rights under U.S. and
|
|
6
|
+
* international Copyright laws.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
#pragma once
|
|
10
|
+
|
|
11
|
+
namespace kernelcatcher::utils {

/**
 * @brief Open a named nvtx range.
 * @param name range name
 */
void push_range(const char* name);

/** @brief Close the most recently opened range. */
void pop_range();

/**
 * @brief Scope guard around push_range / pop_range.
 *
 * The constructor calls push_range(name) and the destructor calls
 * pop_range(). The type is non-copyable.
 *
 * NOTE(review): the constructor is not `explicit`, so a `const char*`
 * converts implicitly to a guard; left unchanged to keep the interface.
 */
struct range_guard {
  range_guard(range_guard const&)            = delete;
  range_guard& operator=(range_guard const&) = delete;

  range_guard(const char* name) { push_range(name); }
  ~range_guard() { pop_range(); }
};

}  // namespace kernelcatcher::utils
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
|
|
3
|
+
*
|
|
4
|
+
* This source code and/or documentation ("Licensed Deliverables") are
|
|
5
|
+
* subject to NVIDIA intellectual property rights under U.S. and
|
|
6
|
+
* international Copyright laws.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
#pragma once
|
|
10
|
+
|
|
11
|
+
namespace kernelcatcher::utils {

/**
 * @brief Batch-dimension mode with fixed integer codes.
 *
 * NOTE(review): the meaning of each mode is not defined in this header;
 * presumably it describes how an operand's batch axis is addressed —
 * confirm against the callers.
 */
enum class BatchDimension : int {
  kBatched = 0,
  kShared  = 1,
  kIndexed = 2
};

}  // namespace kernelcatcher::utils
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
|
|
3
|
+
*
|
|
4
|
+
* This source code and/or documentation ("Licensed Deliverables") are
|
|
5
|
+
* subject to NVIDIA intellectual property rights under U.S. and
|
|
6
|
+
* international Copyright laws.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
#pragma once
|
|
10
|
+
|
|
11
|
+
#include <iostream>
|
|
12
|
+
|
|
13
|
+
namespace kernelcatcher::utils {

/** @brief Element data types, each with a fixed integer code. */
enum class Datatype : int {
  kFloat32  = 0,
  kFloat64  = 1,
  kFloat16  = 2,
  kBFloat16 = 3,
  kInt32    = 4,
  kInt64    = 5
};

/**
 * @brief Size in bytes of one element of @p dtype.
 * @return 2, 4 or 8 for the known enumerators; -1 for any other value.
 */
inline int size_of(Datatype dtype)
{
  // No `default:` label, so the compiler can warn when an enumerator is added
  // but not handled here (same convention as the other switches below).
  switch (dtype) {
    case Datatype::kFloat16:
    case Datatype::kBFloat16: return 2;
    case Datatype::kFloat32:
    case Datatype::kInt32: return 4;
    case Datatype::kFloat64:
    case Datatype::kInt64: return 8;
  }
  return -1;  // value outside the enumerator set
}

/**
 * @brief Stream the textual name of @p d.
 * @return @p s, after writing the name ("unknown_datatype" for values outside
 *         the enumerator set).
 */
inline std::ostream& operator<<(std::ostream& s, Datatype const& d)
{
  char const* name = "unknown_datatype";
  switch (d) {
    case Datatype::kFloat32: name = "float"; break;
    case Datatype::kFloat64: name = "double"; break;
    case Datatype::kFloat16: name = "k_fp16"; break;
    case Datatype::kBFloat16: name = "k_bf16"; break;
    case Datatype::kInt32: name = "kc_int32"; break;
    case Datatype::kInt64: name = "kc_int64"; break;
  }
  return s << name;
}

/**
 * @brief Whether @p d is one of the floating-point enumerators.
 * @return true for kFloat32/kFloat64/kFloat16/kBFloat16, false otherwise
 *         (including values outside the enumerator set).
 */
inline bool is_real(Datatype const& d)
{
  switch (d) {
    case Datatype::kFloat32:
    case Datatype::kFloat64:
    case Datatype::kFloat16:
    case Datatype::kBFloat16: return true;
    case Datatype::kInt32:
    case Datatype::kInt64: return false;
  }
  return false;  // value outside the enumerator set
}

/**
 * @brief Whether @p d is one of the integer enumerators.
 * @return true for kInt32/kInt64, false otherwise.
 *
 * Implemented as its own switch rather than `!is_real(d)`: negating is_real
 * would classify a value outside the enumerator set as integral.
 */
inline bool is_integral(Datatype const& d)
{
  switch (d) {
    case Datatype::kInt32:
    case Datatype::kInt64: return true;
    case Datatype::kFloat32:
    case Datatype::kFloat64:
    case Datatype::kFloat16:
    case Datatype::kBFloat16: return false;
  }
  return false;  // value outside the enumerator set
}

}  // namespace kernelcatcher::utils
|