cuequivariance-ops-cu12 0.8.1__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. cuequivariance_ops/VERSION +1 -0
  2. cuequivariance_ops/__init__.py +42 -0
  3. cuequivariance_ops/_version.py +20 -0
  4. cuequivariance_ops/common/common.hpp +98 -0
  5. cuequivariance_ops/common/cudart.hpp +286 -0
  6. cuequivariance_ops/common/error.hpp +66 -0
  7. cuequivariance_ops/common/error_raft.hpp +323 -0
  8. cuequivariance_ops/common/nvtx.hpp +29 -0
  9. cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
  10. cuequivariance_ops/equivariance/dtypes.hh +65 -0
  11. cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
  12. cuequivariance_ops/equivariance/indexed_linear.hh +41 -0
  13. cuequivariance_ops/equivariance/run_fmha.h +192 -0
  14. cuequivariance_ops/equivariance/run_fmha_cudafree.h +176 -0
  15. cuequivariance_ops/equivariance/run_fmha_sm100.h +135 -0
  16. cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
  17. cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
  18. cuequivariance_ops/gpu_timing_kernels.hh +42 -0
  19. cuequivariance_ops/lib/libcue_ops.so +0 -0
  20. cuequivariance_ops/sleep.hh +40 -0
  21. cuequivariance_ops/triton/__init__.py +66 -0
  22. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37142 -0
  23. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.12.0.json +37132 -0
  24. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
  25. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
  26. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
  27. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
  28. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
  29. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.12.0.json +55692 -0
  30. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
  31. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
  32. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
  33. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
  34. cuequivariance_ops/triton/cache_manager.py +336 -0
  35. cuequivariance_ops/triton/fused_layer_norm_triton.py +546 -0
  36. cuequivariance_ops/triton/gated_gemm_triton.py +394 -0
  37. cuequivariance_ops/triton/pair_bias.py +365 -0
  38. cuequivariance_ops/triton/tuning_decorator.py +188 -0
  39. cuequivariance_ops/triton/utils.py +29 -0
  40. cuequivariance_ops_cu12-0.8.1.dist-info/METADATA +182 -0
  41. cuequivariance_ops_cu12-0.8.1.dist-info/RECORD +46 -0
  42. cuequivariance_ops_cu12-0.8.1.dist-info/WHEEL +6 -0
  43. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/LICENSE +142 -0
  44. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/Third_party_attr.txt +24 -0
  45. cuequivariance_ops_cu12-0.8.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  46. cuequivariance_ops_cu12.libs/libnvfatbin-b51d3b3f.so.12.8.90 +0 -0
cuequivariance_ops/equivariance/fused_tensor_product.cuh
@@ -0,0 +1,297 @@
+ /*
+  * Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved.
+  *
+  * This source code and/or documentation ("Licensed Deliverables") are
+  * subject to NVIDIA intellectual property rights under U.S. and
+  * international Copyright laws.
+  */
+
+ #pragma once
+
+ #include "../common/common.hpp"
+
+ #include <algorithm>
+ #include <limits>
+
+ namespace kernelcatcher::tensor_product {
+
+ struct __attribute__((aligned(16))) tp_data_sizes {
+   int64_t batch_size;
+   bool shared_a;
+   bool shared_b;
+   bool shared_w;
+   int32_t stride_a;
+   int32_t stride_b;
+   int32_t stride_w;
+   int32_t stride_o;
+ };  // struct tp_data_sizes
+
+ template <typename DataAT, typename DataBT, typename DataWeightT, typename DataOutT, typename MathT>
+ void fused_tensor_product_fwd(DataOutT* out,
+                               const DataAT* in_a,
+                               const DataBT* in_b,
+                               const DataWeightT* weight,
+                               ConnectionModeT mode,
+                               const tp_info<MathT>& info,
+                               const tp_data_sizes& sizes,
+                               cudaStream_t stream);
+
+ template <typename DataAT, typename DataBT, typename DataWeightT, typename DataOutT, typename MathT>
+ void fused_tensor_product_bwd(DataAT* grad_in_a,
+                               DataBT* grad_in_b,
+                               DataWeightT* grad_weight,
+                               const DataOutT* grad_out,
+                               const DataAT* in_a,
+                               const DataBT* in_b,
+                               const DataWeightT* weight,
+                               ConnectionModeT mode,
+                               const tp_info<MathT>& info_bwd_dgrad_a,
+                               const tp_info<MathT>& info_bwd_dgrad_b,
+                               const tp_info<MathT>& info_bwd_dgrad_w,
+                               const tp_data_sizes& sizes,
+                               cudaStream_t stream);
+
+ template <typename DataAT, typename DataBT, typename DataWeightT, typename DataOutT, typename MathT>
+ void fused_tensor_product_bwd_bwd(DataAT* grad_in_a,
+                                   DataBT* grad_in_b,
+                                   DataWeightT* grad_weight,
+                                   DataOutT* grad_grad_out,
+                                   const DataAT* grad_grad_in_a,
+                                   const DataBT* grad_grad_in_b,
+                                   const DataWeightT* grad_grad_weight,
+                                   const DataOutT* grad_out,
+                                   const DataAT* in_a,
+                                   const DataBT* in_b,
+                                   const DataWeightT* weight,
+                                   ConnectionModeT mode,
+                                   const tp_info<MathT>& info_fwd,
+                                   const tp_info<MathT>& info_bwd_dgrad_a,
+                                   const tp_info<MathT>& info_bwd_dgrad_b,
+                                   const tp_info<MathT>& info_bwd_dgrad_w,
+                                   const tp_data_sizes& sizes,
+                                   cudaStream_t stream);
+
+ extern template void fused_tensor_product_bwd_bwd<float, float, float, float, float>(
+   float*,
+   float*,
+   float*,
+   float*,
+   const float*,
+   const float*,
+   const float*,
+   const float*,
+   const float*,
+   const float*,
+   const float*,
+   ConnectionModeT,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+
+ extern template void fused_tensor_product_bwd_bwd<float, float, float, float, double>(
+   float*,
+   float*,
+   float*,
+   float*,
+   const float*,
+   const float*,
+   const float*,
+   const float*,
+   const float*,
+   const float*,
+   const float*,
+   ConnectionModeT,
+   const tp_info<double>&,
+   const tp_info<double>&,
+   const tp_info<double>&,
+   const tp_info<double>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+
+ extern template void fused_tensor_product_bwd_bwd<double, double, double, double, double>(
+   double*,
+   double*,
+   double*,
+   double*,
+   const double*,
+   const double*,
+   const double*,
+   const double*,
+   const double*,
+   const double*,
+   const double*,
+   ConnectionModeT,
+   const tp_info<double>&,
+   const tp_info<double>&,
+   const tp_info<double>&,
+   const tp_info<double>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+
+ extern template void fused_tensor_product_bwd_bwd<__half, __half, __half, __half, float>(
+   __half*,
+   __half*,
+   __half*,
+   __half*,
+   const __half*,
+   const __half*,
+   const __half*,
+   const __half*,
+   const __half*,
+   const __half*,
+   const __half*,
+   ConnectionModeT,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+ extern template void
+ fused_tensor_product_bwd_bwd<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float>(
+   __nv_bfloat16*,
+   __nv_bfloat16*,
+   __nv_bfloat16*,
+   __nv_bfloat16*,
+   const __nv_bfloat16*,
+   const __nv_bfloat16*,
+   const __nv_bfloat16*,
+   const __nv_bfloat16*,
+   const __nv_bfloat16*,
+   const __nv_bfloat16*,
+   const __nv_bfloat16*,
+   ConnectionModeT,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+
+ extern template void fused_tensor_product_bwd<float, float, float, float, float>(
+   float*,
+   float*,
+   float*,
+   const float*,
+   const float*,
+   const float*,
+   const float*,
+   ConnectionModeT,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+
+ extern template void fused_tensor_product_bwd<float, float, float, float, double>(
+   float*,
+   float*,
+   float*,
+   const float*,
+   const float*,
+   const float*,
+   const float*,
+   ConnectionModeT,
+   const tp_info<double>&,
+   const tp_info<double>&,
+   const tp_info<double>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+
+ extern template void fused_tensor_product_bwd<double, double, double, double, double>(
+   double*,
+   double*,
+   double*,
+   const double*,
+   const double*,
+   const double*,
+   const double*,
+   ConnectionModeT,
+   const tp_info<double>&,
+   const tp_info<double>&,
+   const tp_info<double>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+
+ extern template void fused_tensor_product_bwd<__half, __half, __half, __half, float>(
+   __half*,
+   __half*,
+   __half*,
+   const __half*,
+   const __half*,
+   const __half*,
+   const __half*,
+   ConnectionModeT,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+ extern template void
+ fused_tensor_product_bwd<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float>(
+   __nv_bfloat16*,
+   __nv_bfloat16*,
+   __nv_bfloat16*,
+   const __nv_bfloat16*,
+   const __nv_bfloat16*,
+   const __nv_bfloat16*,
+   const __nv_bfloat16*,
+   ConnectionModeT,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_info<float>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+
+ extern template void fused_tensor_product_fwd<float, float, float, float, float>(
+   float*,
+   const float*,
+   const float*,
+   const float*,
+   ConnectionModeT,
+   const tp_info<float>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+ extern template void fused_tensor_product_fwd<float, float, float, float, double>(
+   float*,
+   const float*,
+   const float*,
+   const float*,
+   ConnectionModeT,
+   const tp_info<double>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+ extern template void fused_tensor_product_fwd<double, double, double, double, double>(
+   double*,
+   const double*,
+   const double*,
+   const double*,
+   ConnectionModeT,
+   const tp_info<double>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+
+ extern template void fused_tensor_product_fwd<__half, __half, __half, __half, float>(
+   __half*,
+   const __half*,
+   const __half*,
+   const __half*,
+   ConnectionModeT,
+   const tp_info<float>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+ extern template void
+ fused_tensor_product_fwd<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float>(
+   __nv_bfloat16*,
+   const __nv_bfloat16*,
+   const __nv_bfloat16*,
+   const __nv_bfloat16*,
+   ConnectionModeT,
+   const tp_info<float>&,
+   const tp_data_sizes&,
+   cudaStream_t);
+
+ }  // namespace kernelcatcher::tensor_product
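For orientation, here is a minimal host-side sketch (not part of the package) of how the fp32 forward entry point above might be driven. `ConnectionModeT` and `tp_info` are declared in `common.hpp` (+98 lines, not shown in this diff), so they are passed through opaquely, and the meaning attached to the `tp_data_sizes` fields in the comments is inferred from the field names rather than documented.

// Hedged sketch: assumes ConnectionModeT and tp_info<float> (from common.hpp,
// not shown in this diff) are visible under these names at this scope.
#include <cuda_runtime.h>
#include "fused_tensor_product.cuh"

namespace kc = kernelcatcher::tensor_product;

void launch_fwd_fp32(float* d_out, const float* d_in_a, const float* d_in_b,
                     const float* d_weight, ConnectionModeT mode,
                     const tp_info<float>& info, cudaStream_t stream)
{
  kc::tp_data_sizes sizes{};
  sizes.batch_size = 1024;   // number of batch elements (illustrative)
  sizes.shared_a   = false;  // presumably: A carries a batch dimension
  sizes.shared_b   = false;
  sizes.shared_w   = true;   // presumably: one weight set broadcast over the batch
  sizes.stride_a   = 128;    // illustrative per-element strides
  sizes.stride_b   = 96;
  sizes.stride_w   = 64;
  sizes.stride_o   = 160;
  // fp32 data with fp32 math, matching the first extern instantiation above.
  kc::fused_tensor_product_fwd<float, float, float, float, float>(
    d_out, d_in_a, d_in_b, d_weight, mode, info, sizes, stream);
}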
cuequivariance_ops/equivariance/indexed_linear.hh
@@ -0,0 +1,41 @@
+ /*
+  * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+  *
+  * This source code and/or documentation ("Licensed Deliverables") are
+  * subject to NVIDIA intellectual property rights under U.S. and
+  * international Copyright laws.
+  */
+
+ #pragma once
+
+ #include "dtypes.hh"  // for Datatype
+
+ #ifndef CUEQUIVARIANCE_OPS_WITH_CUBLAS
+ #pragma message( \
+   "indexed_linear functions require cuBLAS support. Set CUEQUIVARIANCE_OPS_WITH_CUBLAS=1 environment variable before building to enable indexed_linear functionality. When cuBLAS is disabled, indexed_linear functions will return error codes.")
+ #endif
+
+ namespace kernelcatcher::equivariance::indexed_linear {
+ using namespace kernelcatcher::utils;  // for Datatype
+
+ #define KC_INDEXED_LINEAR_DECL_ARGUMENTS \
+   const void *ptr_A, const void *ptr_B, const int *counts, void *ptr_C, Datatype dtype_A, \
+   Datatype dtype_B, Datatype dtype_D, int Z, int C, int u, int v, double coefficient, \
+   int compute_type, void *workspace, size_t workspace_size, void *stream
+
+ #define KC_INDEXED_LINEAR_ARGUMENTS \
+   ptr_A, ptr_B, counts, ptr_C, dtype_A, dtype_B, dtype_D, Z, C, u, v, coefficient, compute_type, \
+   workspace, workspace_size, stream
+
+ int run_indexed_linear_B(  // ptr_A Zu
+                            // ptr_B Cuv or Cvu
+                            // ptr_C Zv
+   bool transpose_B,
+   KC_INDEXED_LINEAR_DECL_ARGUMENTS);
+
+ int run_indexed_linear_C(  // ptr_A Zu
+                            // ptr_B Zv
+                            // ptr_C Cuv
+   KC_INDEXED_LINEAR_DECL_ARGUMENTS);
+
+ }  // namespace kernelcatcher::equivariance::indexed_linear
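A hedged sketch of calling `run_indexed_linear_B` with the macro-expanded argument list, using the shapes from the comments above (A is Zu, B is Cuv when `transpose_B` is false, C is Zv). `Datatype`'s enumerators live in `dtypes.hh` (+65 lines, not shown in this diff), so `Datatype::kFloat32` below is a hypothetical name, and the `compute_type` encoding is likewise assumed.

// Hedged sketch: Datatype::kFloat32 is a hypothetical enumerator; see dtypes.hh.
#include <cstddef>
#include <cuda_runtime.h>
#include "indexed_linear.hh"

using namespace kernelcatcher::equivariance::indexed_linear;

int launch_indexed_linear(const float* d_A,     // [Z, u]
                          const float* d_B,     // [C, u, v]
                          const int* d_counts,  // per-segment row counts, length C
                          float* d_C,           // [Z, v]
                          int Z, int C, int u, int v,
                          void* d_workspace, size_t workspace_size,
                          cudaStream_t stream)
{
  // Returns a nonzero error code if the package was built without cuBLAS
  // (see the #pragma message above).
  return run_indexed_linear_B(/*transpose_B=*/false,
                              d_A, d_B, d_counts, d_C,
                              Datatype::kFloat32,  // dtype_A (hypothetical name)
                              Datatype::kFloat32,  // dtype_B
                              Datatype::kFloat32,  // dtype_D
                              Z, C, u, v,
                              /*coefficient=*/1.0,
                              /*compute_type=*/0,  // encoding assumed
                              d_workspace, workspace_size,
                              static_cast<void*>(stream));
}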
cuequivariance_ops/equivariance/run_fmha.h
@@ -0,0 +1,192 @@
+ #ifndef CUDNN_FMHA_RUN_FMHA_H
+ #define CUDNN_FMHA_RUN_FMHA_H
+
+ #include <cstdint>  // for uint32_t
+ #include <cuda_fp16.h>
+ #include <cuda_fp8.h>
+
+ namespace cudnn_fmha {
+ using DataType_TriBias = float;
+
+ /**
+  * @brief Performs Flash Multi-Head Attention on the GPU using cuDNN
+  * @tparam DType Data type for the computation (float, __half, or __nv_bfloat16)
+  */
+ template <typename DType>
+ void run_fmha(DType* q_ptr,
+               DType* k_ptr,
+               DType* v_ptr,
+               DType* o_ptr,
+               bool* mask_bias_ptr,
+               DataType_TriBias* triangle_bias_ptr,
+               float* softmax_lse_ptr,
+               float* softmax_max_ptr,
+               const uint32_t B,
+               const uint32_t I,
+               const uint32_t H,
+               const uint32_t S_qo,
+               const uint32_t S_kv,
+               const uint32_t D,
+               const float bmm_scale,
+               bool use_tf32,
+               void* stream = nullptr);
+
+ /**
+  * @brief Performs the backward pass of Flash Multi-Head Attention on the GPU using cuDNN
+  * Note: the backward pass stays in float until fp16/bf16 support is integrated
+  */
+ template <typename DType>
+ void run_fmha_bwd(DType* do_ptr,              // [B, N, H, S_qo, D]
+                   DType* o_ptr,               // [B, N, H, S_qo, D]
+                   float* softmax_lse_ptr,     // [B, N, H, S_qo, 1]
+                   DType* q_ptr,               // [B, N, H, S_qo, D]
+                   DType* k_ptr,               // [B, N, H, S_kv, D]
+                   DType* v_ptr,               // [B, N, H, S_kv, D]
+                   bool* mask_bias_ptr,        // [B, N, 1, 1, S_kv]
+                   float* triangle_bias_ptr,   // [B, 1, H, S_qo, S_kv]
+                   DType* dq_ptr,              // [B, N, H, S_qo, D] output
+                   DType* dk_ptr,              // [B, N, H, S_kv, D] output
+                   DType* dv_ptr,              // [B, N, H, S_kv, D] output
+                   float* triangle_dbias_ptr,  // [B, 1, H, S_qo, S_kv] output
+                   float* do_o_dot_ptr,
+                   float* dq_fp32_buf,         // [B, N, H, S_qo, D] workspace
+                   const uint32_t B,
+                   const uint32_t I,
+                   const uint32_t H,
+                   const uint32_t S_qo,
+                   const uint32_t S_kv,
+                   const uint32_t D,
+                   const float bmm_scale,
+                   bool use_tf32,
+                   void* stream = nullptr);
+
+ // Explicit template declarations for supported types
+ extern template void run_fmha<float>(float*,
+                                      float*,
+                                      float*,
+                                      float*,
+                                      bool*,
+                                      DataType_TriBias*,
+                                      float*,
+                                      float*,
+                                      uint32_t,
+                                      uint32_t,
+                                      uint32_t,
+                                      uint32_t,
+                                      uint32_t,
+                                      uint32_t,
+                                      float,
+                                      bool,
+                                      void*);
+
+ extern template void run_fmha<__half>(__half*,
+                                       __half*,
+                                       __half*,
+                                       __half*,
+                                       bool*,
+                                       DataType_TriBias*,
+                                       float*,
+                                       float*,
+                                       uint32_t,
+                                       uint32_t,
+                                       uint32_t,
+                                       uint32_t,
+                                       uint32_t,
+                                       uint32_t,
+                                       float,
+                                       bool,
+                                       void*);
+
+ extern template void run_fmha<__nv_bfloat16>(__nv_bfloat16*,
+                                              __nv_bfloat16*,
+                                              __nv_bfloat16*,
+                                              __nv_bfloat16*,
+                                              bool*,
+                                              DataType_TriBias*,
+                                              float*,
+                                              float*,
+                                              uint32_t,
+                                              uint32_t,
+                                              uint32_t,
+                                              uint32_t,
+                                              uint32_t,
+                                              uint32_t,
+                                              float,
+                                              bool,
+                                              void*);
+
+ extern template void run_fmha_bwd<__half>(__half* do_ptr,
+                                           __half* o_ptr,
+                                           float* softmax_lse_ptr,
+                                           __half* q_ptr,
+                                           __half* k_ptr,
+                                           __half* v_ptr,
+                                           bool* mask_bias_ptr,
+                                           float* triangle_bias_ptr,
+                                           __half* dq_ptr,
+                                           __half* dk_ptr,
+                                           __half* dv_ptr,
+                                           float* triangle_dbias_ptr,
+                                           float* do_o_dot_ptr,
+                                           float* dq_fp32_buf,
+                                           const uint32_t B,
+                                           const uint32_t I,
+                                           const uint32_t H,
+                                           const uint32_t S_qo,
+                                           const uint32_t S_kv,
+                                           const uint32_t D,
+                                           const float bmm_scale,
+                                           bool use_tf32,
+                                           void* stream);
+
+ extern template void run_fmha_bwd<__nv_bfloat16>(__nv_bfloat16* do_ptr,
+                                                  __nv_bfloat16* o_ptr,
+                                                  float* softmax_lse_ptr,
+                                                  __nv_bfloat16* q_ptr,
+                                                  __nv_bfloat16* k_ptr,
+                                                  __nv_bfloat16* v_ptr,
+                                                  bool* mask_bias_ptr,
+                                                  float* triangle_bias_ptr,
+                                                  __nv_bfloat16* dq_ptr,
+                                                  __nv_bfloat16* dk_ptr,
+                                                  __nv_bfloat16* dv_ptr,
+                                                  float* triangle_dbias_ptr,
+                                                  float* do_o_dot_ptr,
+                                                  float* dq_fp32_buf,
+                                                  const uint32_t B,
+                                                  const uint32_t I,
+                                                  const uint32_t H,
+                                                  const uint32_t S_qo,
+                                                  const uint32_t S_kv,
+                                                  const uint32_t D,
+                                                  const float bmm_scale,
+                                                  bool use_tf32,
+                                                  void* stream);
+
+ extern template void run_fmha_bwd<float>(float* do_ptr,
+                                          float* o_ptr,
+                                          float* softmax_lse_ptr,
+                                          float* q_ptr,
+                                          float* k_ptr,
+                                          float* v_ptr,
+                                          bool* mask_bias_ptr,
+                                          float* triangle_bias_ptr,
+                                          float* dq_ptr,
+                                          float* dk_ptr,
+                                          float* dv_ptr,
+                                          float* triangle_dbias_ptr,
+                                          float* do_o_dot_ptr,
+                                          float* dq_fp32_buf,
+                                          const uint32_t B,
+                                          const uint32_t I,
+                                          const uint32_t H,
+                                          const uint32_t S_qo,
+                                          const uint32_t S_kv,
+                                          const uint32_t D,
+                                          const float bmm_scale,
+                                          bool use_tf32,
+                                          void* stream);
+
+ }  // namespace cudnn_fmha
+
+ #endif  // CUDNN_FMHA_RUN_FMHA_H
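A minimal sketch of invoking the fp32 forward path above, not taken from the package. Tensor layouts follow the `run_fmha_bwd` comments ([B, N, H, S, D], with `I` here apparently playing the role of `N`), the dimension values are illustrative, and the 1/sqrt(D) scale is a conventional choice since the header leaves `bmm_scale` to the caller.

// Hedged sketch: buffers are assumed to be device pointers already laid out
// as the run_fmha_bwd comments describe.
#include <cmath>
#include <cstdint>
#include <cuda_runtime.h>
#include "run_fmha.h"

void launch_fmha_fp32(float* d_q, float* d_k, float* d_v, float* d_o,
                      bool* d_mask,       // [B, N, 1, 1, S_kv]
                      float* d_tri_bias,  // [B, 1, H, S_qo, S_kv]
                      float* d_lse, float* d_max,
                      cudaStream_t stream)
{
  const uint32_t B = 2, I = 4, H = 8, S_qo = 256, S_kv = 256, D = 32;
  const float bmm_scale = 1.0f / std::sqrt(static_cast<float>(D));
  cudnn_fmha::run_fmha<float>(d_q, d_k, d_v, d_o,
                              d_mask, d_tri_bias, d_lse, d_max,
                              B, I, H, S_qo, S_kv, D, bmm_scale,
                              /*use_tf32=*/true,
                              static_cast<void*>(stream));
}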