cuequivariance-ops-cu12 0.8.1__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. cuequivariance_ops/VERSION +1 -0
  2. cuequivariance_ops/__init__.py +42 -0
  3. cuequivariance_ops/_version.py +20 -0
  4. cuequivariance_ops/common/common.hpp +98 -0
  5. cuequivariance_ops/common/cudart.hpp +286 -0
  6. cuequivariance_ops/common/error.hpp +66 -0
  7. cuequivariance_ops/common/error_raft.hpp +323 -0
  8. cuequivariance_ops/common/nvtx.hpp +29 -0
  9. cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
  10. cuequivariance_ops/equivariance/dtypes.hh +65 -0
  11. cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
  12. cuequivariance_ops/equivariance/indexed_linear.hh +41 -0
  13. cuequivariance_ops/equivariance/run_fmha.h +192 -0
  14. cuequivariance_ops/equivariance/run_fmha_cudafree.h +176 -0
  15. cuequivariance_ops/equivariance/run_fmha_sm100.h +135 -0
  16. cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
  17. cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
  18. cuequivariance_ops/gpu_timing_kernels.hh +42 -0
  19. cuequivariance_ops/lib/libcue_ops.so +0 -0
  20. cuequivariance_ops/sleep.hh +40 -0
  21. cuequivariance_ops/triton/__init__.py +66 -0
  22. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37142 -0
  23. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.12.0.json +37132 -0
  24. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
  25. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
  26. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
  27. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
  28. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
  29. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.12.0.json +55692 -0
  30. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
  31. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
  32. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
  33. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
  34. cuequivariance_ops/triton/cache_manager.py +336 -0
  35. cuequivariance_ops/triton/fused_layer_norm_triton.py +546 -0
  36. cuequivariance_ops/triton/gated_gemm_triton.py +394 -0
  37. cuequivariance_ops/triton/pair_bias.py +365 -0
  38. cuequivariance_ops/triton/tuning_decorator.py +188 -0
  39. cuequivariance_ops/triton/utils.py +29 -0
  40. cuequivariance_ops_cu12-0.8.1.dist-info/METADATA +182 -0
  41. cuequivariance_ops_cu12-0.8.1.dist-info/RECORD +46 -0
  42. cuequivariance_ops_cu12-0.8.1.dist-info/WHEEL +6 -0
  43. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/LICENSE +142 -0
  44. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/Third_party_attr.txt +24 -0
  45. cuequivariance_ops_cu12-0.8.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  46. cuequivariance_ops_cu12.libs/libnvfatbin-b51d3b3f.so.12.8.90 +0 -0
@@ -0,0 +1,323 @@
1
+ /*
2
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef __RAFT_RT_ERROR
18
+ #define __RAFT_RT_ERROR
19
+
20
+ #pragma once
21
+
22
+ #if defined(__GNUC__) && __has_include(<cxxabi.h>) && __has_include(<execinfo.h>)
23
+ #define ENABLE_COLLECT_CALLSTACK
24
+ #endif
25
+
26
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
32
+
33
+ #ifdef ENABLE_COLLECT_CALLSTACK
34
+ #include <cxxabi.h>
35
+ #include <execinfo.h>
36
+
37
+ #include <sstream>
38
+ #endif
39
+
40
namespace raft {

/**
 * @defgroup error_handling Exceptions & Error Handling
 * @{
 */

/** base exception class for the whole of raft */
class exception : public std::exception {
 public:
  /** default ctor: empty message, no call stack is collected */
  explicit exception() noexcept : std::exception(), msg_() {}

  /**
   * copy ctor
   *
   * Copies the message verbatim. The source message already contains the call stack captured
   * when it was first constructed; capturing again here would append one more trace (of the
   * copy site) every time the exception object is copied, e.g. by `throw e;` or a
   * catch-by-value handler.
   */
  exception(exception const& src) noexcept : std::exception(src), msg_(src.msg_) {}

  /**
   * ctor from an input message
   *
   * The message is taken by value and moved into place. When ENABLE_COLLECT_CALLSTACK is
   * defined, a backtrace of the construction site is appended to it.
   */
  explicit exception(std::string msg) noexcept : std::exception(), msg_(std::move(msg))
  {
    collect_call_stack();
  }

  /** get the message associated with this exception */
  char const* what() const noexcept override { return msg_.c_str(); }

 private:
  /** message associated with this exception (including the call stack, if one was collected) */
  std::string msg_;

  /** append call stack info to this exception's message for ease of debug */
  // Courtesy: https://www.gnu.org/software/libc/manual/html_node/Backtraces.html
  void collect_call_stack() noexcept
  {
#ifdef ENABLE_COLLECT_CALLSTACK
    // backtrace_symbols() and abi::__cxa_demangle() return malloc'ed buffers that must be
    // released with free(); owning them with unique_ptr releases them on every exit path.
    struct free_deleter {
      void operator()(void* ptr) const noexcept { std::free(ptr); }
    };
    // Formatting the trace allocates and may throw. This function is noexcept, so the trace is
    // best effort: on failure the message is left unchanged instead of calling std::terminate.
    try {
      constexpr int kSkipFrames = 1;
      constexpr int kMaxStackDepth = 64;
      void* stack[kMaxStackDepth];  // NOLINT
      int const depth = backtrace(stack, kMaxStackDepth);
      std::ostringstream oss;
      oss << '\n' << "Obtained " << (depth - kSkipFrames) << " stack frames" << '\n';
      std::unique_ptr<char*, free_deleter> strings(backtrace_symbols(stack, depth));
      if (!strings) {
        oss << "But no stack trace could be found!" << '\n';
        msg_ += oss.str();
        return;
      }
      // Courtesy: https://panthema.net/2008/0901-stacktrace-demangled/
      for (int i = kSkipFrames; i < depth; i++) {
        oss << "#" << i << " in ";  // beginning of the backtrace line

        char* mangled_name = nullptr;
        char* offset_begin = nullptr;
        char* offset_end = nullptr;
        char* backtrace_line = strings.get()[i];

        // Find parentheses and +address offset surrounding mangled name
        // e.g. ./module(function+0x15c) [0x8048a6d]
        for (char* p = backtrace_line; *p != 0; p++) {
          if (*p == '(') {
            mangled_name = p;
          } else if (*p == '+') {
            offset_begin = p;
          } else if (*p == ')') {
            offset_end = p;
            break;
          }
        }

        // Attempt to demangle the symbol
        if (mangled_name != nullptr && offset_begin != nullptr && offset_end != nullptr &&
            mangled_name + 1 < offset_begin) {
          // Split backtrace_line in place by terminating it at '(', '+' and ')'
          *mangled_name++ = 0;
          *offset_begin++ = 0;
          *offset_end++ = 0;

          int status = 0;
          std::unique_ptr<char, free_deleter> real_name(
            abi::__cxa_demangle(mangled_name, nullptr, nullptr, &status));
          // On success substitute the demangled name, otherwise print the raw symbol
          char const* shown_name = (status == 0 && real_name) ? real_name.get() : mangled_name;
          oss << backtrace_line << ": " << shown_name << " +" << offset_begin << offset_end;
        } else {  // Couldn't match the symbol name
          oss << backtrace_line;
        }
        oss << '\n';
      }
      msg_ += oss.str();
    } catch (...) {
      // best effort: keep the message as it was before the trace was attempted
    }
#endif
  }
};

/**
 * @brief Exception thrown when logical precondition is violated.
 *
 * This exception should not be thrown directly and is instead thrown by the
 * RAFT_EXPECTS and RAFT_FAIL macros.
 *
 */
struct logic_error : public raft::exception {
  explicit logic_error(char const* const message) : raft::exception(message) {}
  explicit logic_error(std::string const& message) : raft::exception(message) {}
};

/**
 * @brief Exception thrown when attempting to use CUDA features from a non-CUDA
 * build
 *
 */
struct non_cuda_build_error : public raft::exception {
  explicit non_cuda_build_error(char const* const message) : raft::exception(message) {}
  explicit non_cuda_build_error(std::string const& message) : raft::exception(message) {}
};

/**
 * @brief Exception thrown when a CUDA error is encountered.
 */
struct cuda_error : public raft::exception {
  explicit cuda_error(char const* const message) : raft::exception(message) {}
  explicit cuda_error(std::string const& message) : raft::exception(message) {}
};

/**
 * @}
 */

}  // namespace raft
175
+
176
namespace raft::detail {

/**
 * @brief Append printf-style formatted text to a string.
 *
 * Helper for the error macros below. Routing their format arguments through a function means
 * each argument expression is evaluated exactly once. It also means no macro-local variable can
 * shadow a name used in the caller's arguments.
 *
 * @param[inout] out string the formatted text is appended to; unchanged on failure
 * @param[in] fmt printf-style format string, followed by its arguments
 * @return false if vsnprintf reported an encoding error, true otherwise
 */
#if defined(__GNUC__)
__attribute__((format(printf, 2, 3)))
#endif
inline bool append_formatted(std::string& out, char const* fmt, ...)
{
  std::va_list args;
  va_start(args, fmt);
  // vsnprintf consumes its va_list, so keep a copy for the second (writing) pass.
  std::va_list args_copy;
  va_copy(args_copy, args);
  int const size = std::vsnprintf(nullptr, 0, fmt, args);
  va_end(args);
  if (size < 0) {
    va_end(args_copy);
    return false;
  }
  std::vector<char> buf(static_cast<std::size_t>(size) + 1);  // +1 for final '\0'
  std::vsnprintf(buf.data(), buf.size(), fmt, args_copy);
  va_end(args_copy);
  out.append(buf.data(), static_cast<std::size_t>(size));
  return true;
}

}  // namespace raft::detail

// FIXME: Need to be replaced with RAFT_FAIL
/**
 * macro to throw a runtime error
 *
 * Throws raft::exception with message "exception occurred! file=<file> line=<line>: " followed
 * by the printf-formatted (fmt, ...). Each argument is evaluated once.
 */
#define THROW(fmt, ...)                                                                       \
  do {                                                                                        \
    std::string raft_throw_msg_;                                                              \
    if (!raft::detail::append_formatted(                                                      \
          raft_throw_msg_, "exception occurred! file=%s line=%d: ", __FILE__, __LINE__) ||    \
        !raft::detail::append_formatted(raft_throw_msg_, fmt, ##__VA_ARGS__))                 \
      throw raft::exception("Error in snprintf, cannot handle raft exception.");              \
    throw raft::exception(raft_throw_msg_);                                                   \
  } while (0)

// FIXME: Need to be replaced with RAFT_EXPECTS
/** macro to check for a conditional and assert on failure */
#define ASSERT(check, fmt, ...)                  \
  do {                                           \
    if (!(check)) THROW(fmt, ##__VA_ARGS__);     \
  } while (0)

/**
 * Macro to append error message to first argument.
 *
 * Appends location_prefix, then "file=<file> line=<line>: ", then the printf-formatted
 * (fmt, ...) to the std::string `msg`. Every argument is evaluated once. If formatting fails,
 * raft::exception is thrown and `msg` is left unchanged.
 * This should only be called in contexts where it is OK to throw exceptions!
 */
#define SET_ERROR_MSG(msg, location_prefix, fmt, ...)                                          \
  do {                                                                                         \
    std::string raft_error_msg_tmp_;                                                           \
    if (!raft::detail::append_formatted(raft_error_msg_tmp_, "%s", location_prefix) ||         \
        !raft::detail::append_formatted(                                                       \
          raft_error_msg_tmp_, "file=%s line=%d: ", __FILE__, __LINE__) ||                     \
        !raft::detail::append_formatted(raft_error_msg_tmp_, fmt, ##__VA_ARGS__))              \
      throw raft::exception("Error in snprintf, cannot handle raft exception.");               \
    (msg) += raft_error_msg_tmp_;                                                              \
  } while (0)

/**
 * @defgroup assertion Assertion and error macros
 * @{
 */

/**
 * @brief Macro for checking (pre-)conditions that throws an exception when a condition is false
 *
 * @param[in] cond Expression that evaluates to true or false
 * @param[in] fmt String literal description of the reason that cond is expected to be true with
 * optional format tags
 * @throw raft::logic_error if the condition evaluates to false.
 */
#define RAFT_EXPECTS(cond, fmt, ...)                                                 \
  do {                                                                               \
    if (!(cond)) {                                                                   \
      std::string raft_expects_msg_{};                                               \
      SET_ERROR_MSG(raft_expects_msg_, "RAFT failure at ", fmt, ##__VA_ARGS__);      \
      throw raft::logic_error(raft_expects_msg_);                                    \
    }                                                                                \
  } while (0)

/**
 * @brief Indicates that an erroneous code path has been taken.
 *
 * @param[in] fmt String literal description of the reason that this code path is erroneous with
 * optional format tags
 * @throw always throws raft::logic_error
 */
#define RAFT_FAIL(fmt, ...)                                                   \
  do {                                                                        \
    std::string raft_fail_msg_{};                                             \
    SET_ERROR_MSG(raft_fail_msg_, "RAFT failure at ", fmt, ##__VA_ARGS__);    \
    throw raft::logic_error(raft_fail_msg_);                                  \
  } while (0)
259
+
260
+ /**
261
+ * @}
262
+ */
263
+
264
+ #endif
265
+
266
/**
 * @brief Error checking macro for CUDA runtime API functions.
 *
 * Evaluates the CUDA runtime API call once. If it does not return cudaSuccess, invokes
 * cudaGetLastError() to clear the error and throws raft::cuda_error. The error message contains
 * the stringified call and the CUDA error name and description.
 *
 * The macro's locals use distinctive names. Otherwise an argument such as
 * `RAFT_CUDA_TRY(foo(status))` would read the macro's own, not yet initialized, `status`
 * instead of the caller's variable.
 */
#define RAFT_CUDA_TRY(call)                                                 \
  do {                                                                      \
    cudaError_t const raft_cuda_try_status_ = (call);                       \
    if (raft_cuda_try_status_ != cudaSuccess) {                             \
      cudaGetLastError();                                                   \
      std::string raft_cuda_try_msg_{};                                     \
      SET_ERROR_MSG(raft_cuda_try_msg_,                                     \
                    "CUDA error encountered at: ",                          \
                    "call='%s', Reason=%s:%s",                              \
                    #call,                                                  \
                    cudaGetErrorName(raft_cuda_try_status_),                \
                    cudaGetErrorString(raft_cuda_try_status_));             \
      throw raft::cuda_error(raft_cuda_try_msg_);                           \
    }                                                                       \
  } while (0)
289
+
290
/**
 * @brief Debug macro to check for CUDA errors.
 *
 * Release builds (NDEBUG defined) only check for a pending CUDA error via cudaPeekAtLastError().
 * All other builds call cudaStreamSynchronize(stream) and check its result. That makes execution
 * synchronous, so errors from asynchronous work (e.g. cudaMemcpyAsync or a kernel launch) surface
 * at this call site. In both cases a failure throws through RAFT_CUDA_TRY.
 *
 * NOTE(review): the expansion already ends in ';', so `RAFT_CHECK_CUDA(s);` adds an empty
 * statement; do not use it as the unbraced body of an if that has an else branch.
 */
#ifdef NDEBUG
#define RAFT_CHECK_CUDA(stream) RAFT_CUDA_TRY(cudaPeekAtLastError());
#else
#define RAFT_CHECK_CUDA(stream) RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
#endif
308
+
309
/**
 * @brief check for cuda runtime API errors but log error instead of raising
 * exception.
 *
 * Evaluates the call once. On failure it prints the stringified call, file, line and
 * cudaGetErrorString() text with printf. It never throws.
 *
 * The local has a distinctive name so that an argument mentioning `status` refers to the
 * caller's variable rather than the macro's own, not yet initialized, local.
 */
#define RAFT_CUDA_TRY_NO_THROW(call)                                    \
  do {                                                                  \
    cudaError_t const raft_cuda_try_no_throw_status_ = (call);          \
    if (cudaSuccess != raft_cuda_try_no_throw_status_) {                \
      printf("CUDA call='%s' at file=%s line=%d failed with %s\n",      \
             #call,                                                     \
             __FILE__,                                                  \
             __LINE__,                                                  \
             cudaGetErrorString(raft_cuda_try_no_throw_status_));       \
    }                                                                   \
  } while (0)
@@ -0,0 +1,29 @@
1
+ /*
2
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * This source code and/or documentation ("Licensed Deliverables") are
5
+ * subject to NVIDIA intellectual property rights under U.S. and
6
+ * international Copyright laws.
7
+ */
8
+
9
+ #pragma once
10
+
11
namespace kernelcatcher {
namespace utils {

/**
 * @brief Open a named NVTX range.
 * @param name label for the range
 */
void push_range(const char* name);

/** Close the most recently opened range. */
void pop_range();

/**
 * @brief Scope guard pairing one push_range() with one pop_range().
 *
 * The range is pushed in the constructor and popped in the destructor. Copying is deleted
 * (which also rules out moving), so each guard owns exactly one range.
 * NOTE(review): the constructor is not explicit; kept that way for source compatibility.
 */
struct range_guard {
  range_guard(range_guard const&) = delete;
  range_guard& operator=(range_guard const&) = delete;

  range_guard(const char* name) { push_range(name); }
  ~range_guard() { pop_range(); }
};

}  // namespace utils
}  // namespace kernelcatcher
@@ -0,0 +1,15 @@
1
+ /*
2
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * This source code and/or documentation ("Licensed Deliverables") are
5
+ * subject to NVIDIA intellectual property rights under U.S. and
6
+ * international Copyright laws.
7
+ */
8
+
9
+ #pragma once
10
+
11
namespace kernelcatcher {
namespace utils {

/**
 * Batch-dimension mode tag, stored as an int with explicitly fixed values.
 * NOTE(review): presumably selects how an operand is laid out along the batch dimension
 * (per-batch, shared, or indexed) — confirm against the kernels that consume it.
 */
enum class BatchDimension : int {
  kBatched = 0,
  kShared = 1,
  kIndexed = 2,
};

}  // namespace utils
}  // namespace kernelcatcher
@@ -0,0 +1,65 @@
1
+ /*
2
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * This source code and/or documentation ("Licensed Deliverables") are
5
+ * subject to NVIDIA intellectual property rights under U.S. and
6
+ * international Copyright laws.
7
+ */
8
+
9
+ #pragma once
10
+
11
+ #include <iostream>
12
+
13
namespace kernelcatcher::utils {

/** Element data types, stored as an int with explicitly fixed values. */
enum class Datatype : int {
  kFloat32 = 0,
  kFloat64 = 1,
  kFloat16 = 2,
  kBFloat16 = 3,
  kInt32 = 4,
  kInt64 = 5
};

/**
 * @brief Size in bytes of one element of the given data type.
 * @param dtype data type to query
 * @return the element size in bytes, or -1 if dtype is not one of the enumerators above
 */
inline constexpr int size_of(Datatype dtype)
{
  // No `default:` label, so -Wswitch flags this switch when a new enumerator is added.
  switch (dtype) {
    case Datatype::kFloat32: return 4;
    case Datatype::kFloat64: return 8;
    case Datatype::kFloat16: return 2;
    case Datatype::kBFloat16: return 2;
    case Datatype::kInt32: return 4;
    case Datatype::kInt64: return 8;
  }
  return -1;  // out-of-range value
}

/**
 * @brief Write the type's name to the stream.
 *
 * The names are written exactly as listed (e.g. "float", "k_fp16", "kc_int32"); values outside
 * the enum print "unknown_datatype".
 * NOTE(review): these spellings may be consumed elsewhere (e.g. as generated type names) —
 * confirm before changing them.
 */
inline std::ostream& operator<<(std::ostream& s, Datatype const& d)
{
  switch (d) {
    case Datatype::kFloat32: return s << "float";
    case Datatype::kFloat64: return s << "double";
    case Datatype::kFloat16: return s << "k_fp16";
    case Datatype::kBFloat16: return s << "k_bf16";
    case Datatype::kInt32: return s << "kc_int32";
    case Datatype::kInt64: return s << "kc_int64";
  }
  return s << "unknown_datatype";
}

/** @return true for the floating-point types, false for integer types and out-of-range values */
inline constexpr bool is_real(Datatype const& d)
{
  switch (d) {
    case Datatype::kFloat32:
    case Datatype::kFloat64:
    case Datatype::kFloat16:
    case Datatype::kBFloat16: return true;
    case Datatype::kInt32:
    case Datatype::kInt64: return false;
  }
  return false;  // out-of-range value
}

/**
 * @return true for the integer types, false for floating-point types and out-of-range values
 *
 * Spelled out explicitly rather than as `!is_real(d)` so that an out-of-range value is reported
 * as neither real nor integral.
 */
inline constexpr bool is_integral(Datatype const& d)
{
  switch (d) {
    case Datatype::kInt32:
    case Datatype::kInt64: return true;
    case Datatype::kFloat32:
    case Datatype::kFloat64:
    case Datatype::kFloat16:
    case Datatype::kBFloat16: return false;
  }
  return false;  // out-of-range value
}

}  // namespace kernelcatcher::utils