cuequivariance-ops-cu12 0.8.1__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registries. Since every file below is newly added (+N -0), the diff amounts to the full contents of the 0.8.1 wheel.
- cuequivariance_ops/VERSION +1 -0
- cuequivariance_ops/__init__.py +42 -0
- cuequivariance_ops/_version.py +20 -0
- cuequivariance_ops/common/common.hpp +98 -0
- cuequivariance_ops/common/cudart.hpp +286 -0
- cuequivariance_ops/common/error.hpp +66 -0
- cuequivariance_ops/common/error_raft.hpp +323 -0
- cuequivariance_ops/common/nvtx.hpp +29 -0
- cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
- cuequivariance_ops/equivariance/dtypes.hh +65 -0
- cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
- cuequivariance_ops/equivariance/indexed_linear.hh +41 -0
- cuequivariance_ops/equivariance/run_fmha.h +192 -0
- cuequivariance_ops/equivariance/run_fmha_cudafree.h +176 -0
- cuequivariance_ops/equivariance/run_fmha_sm100.h +135 -0
- cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
- cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
- cuequivariance_ops/gpu_timing_kernels.hh +42 -0
- cuequivariance_ops/lib/libcue_ops.so +0 -0
- cuequivariance_ops/sleep.hh +40 -0
- cuequivariance_ops/triton/__init__.py +66 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37142 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.12.0.json +37132 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.12.0.json +55692 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
- cuequivariance_ops/triton/cache_manager.py +336 -0
- cuequivariance_ops/triton/fused_layer_norm_triton.py +546 -0
- cuequivariance_ops/triton/gated_gemm_triton.py +394 -0
- cuequivariance_ops/triton/pair_bias.py +365 -0
- cuequivariance_ops/triton/tuning_decorator.py +188 -0
- cuequivariance_ops/triton/utils.py +29 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/METADATA +182 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/RECORD +46 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/WHEEL +6 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/licenses/LICENSE +142 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/licenses/Third_party_attr.txt +24 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- cuequivariance_ops_cu12.libs/libnvfatbin-b51d3b3f.so.12.8.90 +0 -0
cuequivariance_ops/equivariance/fused_tensor_product.cuh
@@ -0,0 +1,297 @@
+/*
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+#include "../common/common.hpp"
+
+#include <algorithm>
+#include <limits>
+
+namespace kernelcatcher::tensor_product {
+
+struct __attribute__((aligned(16))) tp_data_sizes {
+  int64_t batch_size;
+  bool shared_a;
+  bool shared_b;
+  bool shared_w;
+  int32_t stride_a;
+  int32_t stride_b;
+  int32_t stride_w;
+  int32_t stride_o;
+};  // struct tp_data_sizes
+
+template <typename DataAT, typename DataBT, typename DataWeightT, typename DataOutT, typename MathT>
+void fused_tensor_product_fwd(DataOutT* out,
+                              const DataAT* in_a,
+                              const DataBT* in_b,
+                              const DataWeightT* weight,
+                              ConnectionModeT mode,
+                              const tp_info<MathT>& info,
+                              const tp_data_sizes& sizes,
+                              cudaStream_t stream);
+
+template <typename DataAT, typename DataBT, typename DataWeightT, typename DataOutT, typename MathT>
+void fused_tensor_product_bwd(DataAT* grad_in_a,
+                              DataBT* grad_in_b,
+                              DataWeightT* grad_weight,
+                              const DataOutT* grad_out,
+                              const DataAT* in_a,
+                              const DataBT* in_b,
+                              const DataWeightT* weight,
+                              ConnectionModeT mode,
+                              const tp_info<MathT>& info_bwd_dgrad_a,
+                              const tp_info<MathT>& info_bwd_dgrad_b,
+                              const tp_info<MathT>& info_bwd_dgrad_w,
+                              const tp_data_sizes& sizes,
+                              cudaStream_t stream);
+
+template <typename DataAT, typename DataBT, typename DataWeightT, typename DataOutT, typename MathT>
+void fused_tensor_product_bwd_bwd(DataAT* grad_in_a,
+                                  DataBT* grad_in_b,
+                                  DataWeightT* grad_weight,
+                                  DataOutT* grad_grad_out,
+                                  const DataAT* grad_grad_in_a,
+                                  const DataBT* grad_grad_in_b,
+                                  const DataWeightT* grad_grad_weight,
+                                  const DataOutT* grad_out,
+                                  const DataAT* in_a,
+                                  const DataBT* in_b,
+                                  const DataWeightT* weight,
+                                  ConnectionModeT mode,
+                                  const tp_info<MathT>& info_fwd,
+                                  const tp_info<MathT>& info_bwd_dgrad_a,
+                                  const tp_info<MathT>& info_bwd_dgrad_b,
+                                  const tp_info<MathT>& info_bwd_dgrad_w,
+                                  const tp_data_sizes& sizes,
+                                  cudaStream_t stream);
+
+extern template void fused_tensor_product_bwd_bwd<float, float, float, float, float>(
+  float*,
+  float*,
+  float*,
+  float*,
+  const float*,
+  const float*,
+  const float*,
+  const float*,
+  const float*,
+  const float*,
+  const float*,
+  ConnectionModeT,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+
+extern template void fused_tensor_product_bwd_bwd<float, float, float, float, double>(
+  float*,
+  float*,
+  float*,
+  float*,
+  const float*,
+  const float*,
+  const float*,
+  const float*,
+  const float*,
+  const float*,
+  const float*,
+  ConnectionModeT,
+  const tp_info<double>&,
+  const tp_info<double>&,
+  const tp_info<double>&,
+  const tp_info<double>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+
+extern template void fused_tensor_product_bwd_bwd<double, double, double, double, double>(
+  double*,
+  double*,
+  double*,
+  double*,
+  const double*,
+  const double*,
+  const double*,
+  const double*,
+  const double*,
+  const double*,
+  const double*,
+  ConnectionModeT,
+  const tp_info<double>&,
+  const tp_info<double>&,
+  const tp_info<double>&,
+  const tp_info<double>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+
+extern template void fused_tensor_product_bwd_bwd<__half, __half, __half, __half, float>(
+  __half*,
+  __half*,
+  __half*,
+  __half*,
+  const __half*,
+  const __half*,
+  const __half*,
+  const __half*,
+  const __half*,
+  const __half*,
+  const __half*,
+  ConnectionModeT,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+extern template void
+fused_tensor_product_bwd_bwd<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float>(
+  __nv_bfloat16*,
+  __nv_bfloat16*,
+  __nv_bfloat16*,
+  __nv_bfloat16*,
+  const __nv_bfloat16*,
+  const __nv_bfloat16*,
+  const __nv_bfloat16*,
+  const __nv_bfloat16*,
+  const __nv_bfloat16*,
+  const __nv_bfloat16*,
+  const __nv_bfloat16*,
+  ConnectionModeT,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+
+extern template void fused_tensor_product_bwd<float, float, float, float, float>(
+  float*,
+  float*,
+  float*,
+  const float*,
+  const float*,
+  const float*,
+  const float*,
+  ConnectionModeT,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+
+extern template void fused_tensor_product_bwd<float, float, float, float, double>(
+  float*,
+  float*,
+  float*,
+  const float*,
+  const float*,
+  const float*,
+  const float*,
+  ConnectionModeT,
+  const tp_info<double>&,
+  const tp_info<double>&,
+  const tp_info<double>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+
+extern template void fused_tensor_product_bwd<double, double, double, double, double>(
+  double*,
+  double*,
+  double*,
+  const double*,
+  const double*,
+  const double*,
+  const double*,
+  ConnectionModeT,
+  const tp_info<double>&,
+  const tp_info<double>&,
+  const tp_info<double>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+
+extern template void fused_tensor_product_bwd<__half, __half, __half, __half, float>(
+  __half*,
+  __half*,
+  __half*,
+  const __half*,
+  const __half*,
+  const __half*,
+  const __half*,
+  ConnectionModeT,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+extern template void
+fused_tensor_product_bwd<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float>(
+  __nv_bfloat16*,
+  __nv_bfloat16*,
+  __nv_bfloat16*,
+  const __nv_bfloat16*,
+  const __nv_bfloat16*,
+  const __nv_bfloat16*,
+  const __nv_bfloat16*,
+  ConnectionModeT,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_info<float>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+
+extern template void fused_tensor_product_fwd<float, float, float, float, float>(
+  float*,
+  const float*,
+  const float*,
+  const float*,
+  ConnectionModeT,
+  const tp_info<float>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+extern template void fused_tensor_product_fwd<float, float, float, float, double>(
+  float*,
+  const float*,
+  const float*,
+  const float*,
+  ConnectionModeT,
+  const tp_info<double>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+extern template void fused_tensor_product_fwd<double, double, double, double, double>(
+  double*,
+  const double*,
+  const double*,
+  const double*,
+  ConnectionModeT,
+  const tp_info<double>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+
+extern template void fused_tensor_product_fwd<__half, __half, __half, __half, float>(
+  __half*,
+  const __half*,
+  const __half*,
+  const __half*,
+  ConnectionModeT,
+  const tp_info<float>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+extern template void
+fused_tensor_product_fwd<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float>(
+  __nv_bfloat16*,
+  const __nv_bfloat16*,
+  const __nv_bfloat16*,
+  const __nv_bfloat16*,
+  ConnectionModeT,
+  const tp_info<float>&,
+  const tp_data_sizes&,
+  cudaStream_t);
+
+}  // namespace kernelcatcher::tensor_product
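
The header above only declares the fused tensor-product entry points; for orientation, a host-side call might be wired up as in the sketch below. This is illustrative only: `ConnectionModeT` and `tp_info` come from other package headers (e.g. `common/common.hpp` and related files not shown in full here), and the batch size, stride values, and sharing flags are invented for the example.

    // Hedged usage sketch (not part of the package): calling the forward
    // entry point declared above. ConnectionModeT and tp_info<> are defined
    // in package headers not shown here; all sizes below are assumptions.
    #include "cuequivariance_ops/equivariance/fused_tensor_product.cuh"

    using namespace kernelcatcher::tensor_product;

    void launch_tp_fwd(float* out, const float* in_a, const float* in_b,
                       const float* weight, ConnectionModeT mode,
                       const tp_info<float>& info, cudaStream_t stream) {
      tp_data_sizes sizes{};
      sizes.batch_size = 4096;   // number of batch elements (assumed)
      sizes.shared_a   = false;  // presumably: operand not broadcast over batch
      sizes.shared_b   = false;
      sizes.shared_w   = true;   // one weight set shared across the batch (assumed)
      sizes.stride_a   = 128;    // per-element strides, values assumed
      sizes.stride_b   = 96;
      sizes.stride_w   = 512;
      sizes.stride_o   = 160;
      // DataAT=DataBT=DataWeightT=DataOutT=float with MathT=float matches one
      // of the explicit instantiations declared above.
      fused_tensor_product_fwd<float, float, float, float, float>(
        out, in_a, in_b, weight, mode, info, sizes, stream);
    }
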
cuequivariance_ops/equivariance/indexed_linear.hh
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+#include "dtypes.hh"  // for Datatype
+
+#ifndef CUEQUIVARIANCE_OPS_WITH_CUBLAS
+#pragma message( \
+  "indexed_linear functions require cuBLAS support. Set CUEQUIVARIANCE_OPS_WITH_CUBLAS=1 environment variable before building to enable indexed_linear functionality. When cuBLAS is disabled, indexed_linear functions will return error codes.")
+#endif
+
+namespace kernelcatcher::equivariance::indexed_linear {
+using namespace kernelcatcher::utils;  // for Datatype
+
+#define KC_INDEXED_LINEAR_DECL_ARGUMENTS \
+  const void *ptr_A, const void *ptr_B, const int *counts, void *ptr_C, Datatype dtype_A, \
+    Datatype dtype_B, Datatype dtype_D, int Z, int C, int u, int v, double coefficient, \
+    int compute_type, void *workspace, size_t workspace_size, void *stream
+
+#define KC_INDEXED_LINEAR_ARGUMENTS \
+  ptr_A, ptr_B, counts, ptr_C, dtype_A, dtype_B, dtype_D, Z, C, u, v, coefficient, compute_type, \
+    workspace, workspace_size, stream
+
+int run_indexed_linear_B(  // ptr_A Zu
+                           // ptr_B Cuv or Cvu
+                           // ptr_C Zv
+                           bool transpose_B,
+                           KC_INDEXED_LINEAR_DECL_ARGUMENTS);
+
+int run_indexed_linear_C(  // ptr_A Zu
+                           // ptr_B Zv
+                           // ptr_C Cuv
+                           KC_INDEXED_LINEAR_DECL_ARGUMENTS);
+
+}  // namespace kernelcatcher::equivariance::indexed_linear
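
The shape comments on the declarations suggest these wrap an indexed (segmented) GEMM: `ptr_A` holds Z rows of width u, `ptr_B` a stack of C weight matrices of shape u x v (or v x u when `transpose_B` is set), `counts` assigns rows to weight indices, and `ptr_C` receives Z rows of width v. Below is a hedged call sketch; the `Datatype` enumerator and `compute_type` encoding come from `dtypes.hh`, which this diff lists but does not reproduce, so the values used here are assumptions.

    // Hedged call sketch (not part of the package). The Datatype value and
    // compute_type encoding are assumptions taken on faith from dtypes.hh.
    #include "cuequivariance_ops/equivariance/indexed_linear.hh"

    using namespace kernelcatcher::equivariance::indexed_linear;

    int launch_indexed_linear(const void* A,      // (Z, u) input rows
                              const void* B,      // (C, u, v) weight stack
                              const int* counts,  // rows per weight index
                              void* C_out,        // (Z, v) output
                              Datatype f32,       // assumed float32 enumerator
                              void* workspace, size_t workspace_size,
                              void* stream) {
      const int Z = 8192, C = 16, u = 64, v = 32;  // assumed problem sizes
      // Per the #pragma message above, a build without
      // CUEQUIVARIANCE_OPS_WITH_CUBLAS makes this return an error code.
      return run_indexed_linear_B(/*transpose_B=*/false,
                                  A, B, counts, C_out,
                                  /*dtype_A=*/f32, /*dtype_B=*/f32, /*dtype_D=*/f32,
                                  Z, C, u, v,
                                  /*coefficient=*/1.0,
                                  /*compute_type=*/0,  // encoding assumed
                                  workspace, workspace_size, stream);
    }
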
cuequivariance_ops/equivariance/run_fmha.h
@@ -0,0 +1,192 @@
+#ifndef CUDNN_FMHA_RUN_FMHA_H
+#define CUDNN_FMHA_RUN_FMHA_H
+
+#include <cstdint>  // for uint32_t
+#include <cuda_fp16.h>
+#include <cuda_fp8.h>
+
+namespace cudnn_fmha {
+using DataType_TriBias = float;
+
+/**
+ * @brief Performs Flash Multi-Head Attention computation on GPU using cuDNN
+ * @tparam DType Data type for the computation (float, __half, or __nv_bfloat16)
+ */
+template <typename DType>
+void run_fmha(DType* q_ptr,
+              DType* k_ptr,
+              DType* v_ptr,
+              DType* o_ptr,
+              bool* mask_bias_ptr,
+              DataType_TriBias* triangle_bias_ptr,
+              float* softmax_lse_ptr,
+              float* softmax_max_ptr,
+              const uint32_t B,
+              const uint32_t I,
+              const uint32_t H,
+              const uint32_t S_qo,
+              const uint32_t S_kv,
+              const uint32_t D,
+              const float bmm_scale,
+              bool use_tf32,
+              void* stream = nullptr);
+
+/**
+ * @brief Performs the backward pass of Flash Multi-Head Attention computation on GPU using cuDNN
+ * Note: Backward pass remains in float before fp16/bf16 integration
+ */
+template <typename DType>
+void run_fmha_bwd(DType* do_ptr,              // [B, N, H, S_qo, D]
+                  DType* o_ptr,               // [B, N, H, S_qo, D]
+                  float* softmax_lse_ptr,     // [B, N, H, S_qo, 1]
+                  DType* q_ptr,               // [B, N, H, S_qo, D]
+                  DType* k_ptr,               // [B, N, H, S_kv, D]
+                  DType* v_ptr,               // [B, N, H, S_kv, D]
+                  bool* mask_bias_ptr,        // [B, N, 1, 1, S_kv]
+                  float* triangle_bias_ptr,   // [B, 1, H, S_qo, S_kv]
+                  DType* dq_ptr,              // [B, N, H, S_qo, D] output
+                  DType* dk_ptr,              // [B, N, H, S_kv, D] output
+                  DType* dv_ptr,              // [B, N, H, S_kv, D] output
+                  float* triangle_dbias_ptr,  // [B, 1, H, S_qo, S_kv] output
+                  float* do_o_dot_ptr,
+                  float* dq_fp32_buf,         // [B, N, H, S_qo, D] workspace
+                  const uint32_t B,
+                  const uint32_t I,
+                  const uint32_t H,
+                  const uint32_t S_qo,
+                  const uint32_t S_kv,
+                  const uint32_t D,
+                  const float bmm_scale,
+                  bool use_tf32,
+                  void* stream = nullptr);
+
+// Explicit template declarations for supported types
+extern template void run_fmha<float>(float*,
+                                     float*,
+                                     float*,
+                                     float*,
+                                     bool*,
+                                     DataType_TriBias*,
+                                     float*,
+                                     float*,
+                                     uint32_t,
+                                     uint32_t,
+                                     uint32_t,
+                                     uint32_t,
+                                     uint32_t,
+                                     uint32_t,
+                                     float,
+                                     bool,
+                                     void*);
+
+extern template void run_fmha<__half>(__half*,
+                                      __half*,
+                                      __half*,
+                                      __half*,
+                                      bool*,
+                                      DataType_TriBias*,
+                                      float*,
+                                      float*,
+                                      uint32_t,
+                                      uint32_t,
+                                      uint32_t,
+                                      uint32_t,
+                                      uint32_t,
+                                      uint32_t,
+                                      float,
+                                      bool,
+                                      void*);
+
+extern template void run_fmha<__nv_bfloat16>(__nv_bfloat16*,
+                                             __nv_bfloat16*,
+                                             __nv_bfloat16*,
+                                             __nv_bfloat16*,
+                                             bool*,
+                                             DataType_TriBias*,
+                                             float*,
+                                             float*,
+                                             uint32_t,
+                                             uint32_t,
+                                             uint32_t,
+                                             uint32_t,
+                                             uint32_t,
+                                             uint32_t,
+                                             float,
+                                             bool,
+                                             void*);
+
+extern template void run_fmha_bwd<__half>(__half* do_ptr,
+                                          __half* o_ptr,
+                                          float* softmax_lse_ptr,
+                                          __half* q_ptr,
+                                          __half* k_ptr,
+                                          __half* v_ptr,
+                                          bool* mask_bias_ptr,
+                                          float* triangle_bias_ptr,
+                                          __half* dq_ptr,
+                                          __half* dk_ptr,
+                                          __half* dv_ptr,
+                                          float* triangle_dbias_ptr,
+                                          float* do_o_dot_ptr,
+                                          float* dq_fp32_buf,
+                                          const uint32_t B,
+                                          const uint32_t I,
+                                          const uint32_t H,
+                                          const uint32_t S_qo,
+                                          const uint32_t S_kv,
+                                          const uint32_t D,
+                                          const float bmm_scale,
+                                          bool use_tf32,
+                                          void* stream);
+
+extern template void run_fmha_bwd<__nv_bfloat16>(__nv_bfloat16* do_ptr,
+                                                 __nv_bfloat16* o_ptr,
+                                                 float* softmax_lse_ptr,
+                                                 __nv_bfloat16* q_ptr,
+                                                 __nv_bfloat16* k_ptr,
+                                                 __nv_bfloat16* v_ptr,
+                                                 bool* mask_bias_ptr,
+                                                 float* triangle_bias_ptr,
+                                                 __nv_bfloat16* dq_ptr,
+                                                 __nv_bfloat16* dk_ptr,
+                                                 __nv_bfloat16* dv_ptr,
+                                                 float* triangle_dbias_ptr,
+                                                 float* do_o_dot_ptr,
+                                                 float* dq_fp32_buf,
+                                                 const uint32_t B,
+                                                 const uint32_t I,
+                                                 const uint32_t H,
+                                                 const uint32_t S_qo,
+                                                 const uint32_t S_kv,
+                                                 const uint32_t D,
+                                                 const float bmm_scale,
+                                                 bool use_tf32,
+                                                 void* stream);
+
+extern template void run_fmha_bwd<float>(float* do_ptr,
+                                         float* o_ptr,
+                                         float* softmax_lse_ptr,
+                                         float* q_ptr,
+                                         float* k_ptr,
+                                         float* v_ptr,
+                                         bool* mask_bias_ptr,
+                                         float* triangle_bias_ptr,
+                                         float* dq_ptr,
+                                         float* dk_ptr,
+                                         float* dv_ptr,
+                                         float* triangle_dbias_ptr,
+                                         float* do_o_dot_ptr,
+                                         float* dq_fp32_buf,
+                                         const uint32_t B,
+                                         const uint32_t I,
+                                         const uint32_t H,
+                                         const uint32_t S_qo,
+                                         const uint32_t S_kv,
+                                         const uint32_t D,
+                                         const float bmm_scale,
+                                         bool use_tf32,
+                                         void* stream);
+
+}  // namespace cudnn_fmha
+
+#endif  // CUDNN_FMHA_RUN_FMHA_H
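
To make the parameter layout concrete, a hedged sketch of a forward call follows. The dimension names mirror the declaration (B, I, H, S_qo, S_kv, D); the backward-pass shape comments write the second dimension as N, which presumably corresponds to I. The concrete sizes and the 1/sqrt(D) scale are illustrative assumptions, not values mandated by the package.

    // Hedged forward-call sketch (not part of the package). Allocation and
    // error handling are omitted; all dimensions are assumptions.
    #include <cmath>
    #include "cuequivariance_ops/equivariance/run_fmha.h"

    void launch_fmha(__half* q, __half* k, __half* v, __half* o,
                     bool* mask_bias, float* triangle_bias,
                     float* softmax_lse, float* softmax_max,
                     cudaStream_t stream) {
      const uint32_t B = 1, I = 4, H = 8;     // batch, instance dim (N), heads
      const uint32_t S_qo = 256, S_kv = 256;  // query/output and key/value lengths
      const uint32_t D = 32;                  // head dimension
      // 1/sqrt(D) is the usual attention scale; the package may expect a
      // different value here.
      const float bmm_scale = 1.0f / std::sqrt(static_cast<float>(D));
      cudnn_fmha::run_fmha<__half>(q, k, v, o, mask_bias, triangle_bias,
                                   softmax_lse, softmax_max,
                                   B, I, H, S_qo, S_kv, D, bmm_scale,
                                   /*use_tf32=*/false, stream);
    }
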