cuequivariance-ops-cu12 0.6.0__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cuequivariance-ops-cu12 might be problematic.

Files changed (37)
  1. cuequivariance_ops/VERSION +1 -0
  2. cuequivariance_ops/__init__.py +42 -0
  3. cuequivariance_ops/_version.py +20 -0
  4. cuequivariance_ops/common/common.hpp +98 -0
  5. cuequivariance_ops/common/nvtx.hpp +29 -0
  6. cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
  7. cuequivariance_ops/equivariance/dtypes.hh +65 -0
  8. cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
  9. cuequivariance_ops/equivariance/indexed_linear.hh +36 -0
  10. cuequivariance_ops/equivariance/run_fmha.h +192 -0
  11. cuequivariance_ops/equivariance/run_fmha_cudafree.h +77 -0
  12. cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
  13. cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
  14. cuequivariance_ops/lib/libcue_ops.so +0 -0
  15. cuequivariance_ops/sleep.hh +18 -0
  16. cuequivariance_ops/triton/__init__.py +66 -0
  17. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37192 -0
  18. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
  19. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
  20. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
  21. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
  22. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
  23. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
  24. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
  25. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
  26. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
  27. cuequivariance_ops/triton/cache_manager.py +259 -0
  28. cuequivariance_ops/triton/fused_layer_norm_triton.py +518 -0
  29. cuequivariance_ops/triton/gated_gemm_triton.py +380 -0
  30. cuequivariance_ops/triton/pair_bias.py +324 -0
  31. cuequivariance_ops/triton/tuning_decorator.py +177 -0
  32. cuequivariance_ops/triton/utils.py +28 -0
  33. cuequivariance_ops_cu12-0.6.0.dist-info/METADATA +182 -0
  34. cuequivariance_ops_cu12-0.6.0.dist-info/RECORD +37 -0
  35. cuequivariance_ops_cu12-0.6.0.dist-info/WHEEL +6 -0
  36. cuequivariance_ops_cu12-0.6.0.dist-info/licenses/LICENSE +142 -0
  37. cuequivariance_ops_cu12-0.6.0.dist-info/licenses/Third_party_attr.txt +24 -0
cuequivariance_ops/equivariance/run_fmha.h
@@ -0,0 +1,192 @@
+ #ifndef CUDNN_FMHA_RUN_FMHA_H
+ #define CUDNN_FMHA_RUN_FMHA_H
+
+ #include <cstdint> // for uint32_t
+ #include <cuda_fp16.h>
+ #include <cuda_fp8.h>
+
+ namespace cudnn_fmha {
+ using DataType_TriBias = float;
+
+ /**
+ * @brief Performs Flash Multi-Head Attention computation on GPU using cuDNN
+ * @tparam DType Data type for the computation (float, __half, or __nv_bfloat16)
+ */
+ template <typename DType>
+ void run_fmha(DType* q_ptr,
+ DType* k_ptr,
+ DType* v_ptr,
+ DType* o_ptr,
+ bool* mask_bias_ptr,
+ DataType_TriBias* triangle_bias_ptr,
+ float* softmax_lse_ptr,
+ float* softmax_max_ptr,
+ const uint32_t B,
+ const uint32_t I,
+ const uint32_t H,
+ const uint32_t S_qo,
+ const uint32_t S_kv,
+ const uint32_t D,
+ const float bmm_scale,
+ bool use_tf32,
+ void* stream = nullptr);
+
+ /**
+ * @brief Performs the backward pass of Flash Multi-Head Attention computation on GPU using cuDNN
+ * Note: Backward pass remains in float before fp16/bf16 integration
+ */
+ template <typename DType>
+ void run_fmha_bwd(DType* do_ptr, // [B, N, H, S_qo, D]
+ DType* o_ptr, // [B, N, H, S_qo, D]
+ float* softmax_lse_ptr, // [B, N, H, S_qo, 1]
+ DType* q_ptr, // [B, N, H, S_qo, D]
+ DType* k_ptr, // [B, N, H, S_kv, D]
+ DType* v_ptr, // [B, N, H, S_kv, D]
+ bool* mask_bias_ptr, // [B, N, 1, 1, S_kv]
+ float* triangle_bias_ptr, // [B, 1, H, S_qo, S_kv]
+ DType* dq_ptr, // [B, N, H, S_qo, D] output
+ DType* dk_ptr, // [B, N, H, S_kv, D] output
+ DType* dv_ptr, // [B, N, H, S_kv, D] output
+ float* triangle_dbias_ptr, // [B, 1, H, S_qo, S_kv] output
+ float* do_o_dot_ptr,
+ float* dq_fp32_buf, // [B, N, H, S_qo, D] workspace
+ const uint32_t B,
+ const uint32_t I,
+ const uint32_t H,
+ const uint32_t S_qo,
+ const uint32_t S_kv,
+ const uint32_t D,
+ const float bmm_scale,
+ bool use_tf32,
+ void* stream = nullptr);
+
+ // Explicit template declarations for supported types
+ extern template void run_fmha<float>(float*,
+ float*,
+ float*,
+ float*,
+ bool*,
+ DataType_TriBias*,
+ float*,
+ float*,
+ uint32_t,
+ uint32_t,
+ uint32_t,
+ uint32_t,
+ uint32_t,
+ uint32_t,
+ float,
+ bool,
+ void*);
+
+ extern template void run_fmha<__half>(__half*,
+ __half*,
+ __half*,
+ __half*,
+ bool*,
+ DataType_TriBias*,
+ float*,
+ float*,
+ uint32_t,
+ uint32_t,
+ uint32_t,
+ uint32_t,
+ uint32_t,
+ uint32_t,
+ float,
+ bool,
+ void*);
+
+ extern template void run_fmha<__nv_bfloat16>(__nv_bfloat16*,
+ __nv_bfloat16*,
+ __nv_bfloat16*,
+ __nv_bfloat16*,
+ bool*,
+ DataType_TriBias*,
+ float*,
+ float*,
+ uint32_t,
+ uint32_t,
+ uint32_t,
+ uint32_t,
+ uint32_t,
+ uint32_t,
+ float,
+ bool,
+ void*);
+
+ extern template void run_fmha_bwd<__half>(__half* do_ptr,
+ __half* o_ptr,
+ float* softmax_lse_ptr,
+ __half* q_ptr,
+ __half* k_ptr,
+ __half* v_ptr,
+ bool* mask_bias_ptr,
+ float* triangle_bias_ptr,
+ __half* dq_ptr,
+ __half* dk_ptr,
+ __half* dv_ptr,
+ float* triangle_dbias_ptr,
+ float* do_o_dot_ptr,
+ float* dq_fp32_buf,
+ const uint32_t B,
+ const uint32_t I,
+ const uint32_t H,
+ const uint32_t S_qo,
+ const uint32_t S_kv,
+ const uint32_t D,
+ const float bmm_scale,
+ bool use_tf32,
+ void* stream);
+
+ extern template void run_fmha_bwd<__nv_bfloat16>(__nv_bfloat16* do_ptr,
+ __nv_bfloat16* o_ptr,
+ float* softmax_lse_ptr,
+ __nv_bfloat16* q_ptr,
+ __nv_bfloat16* k_ptr,
+ __nv_bfloat16* v_ptr,
+ bool* mask_bias_ptr,
+ float* triangle_bias_ptr,
+ __nv_bfloat16* dq_ptr,
+ __nv_bfloat16* dk_ptr,
+ __nv_bfloat16* dv_ptr,
+ float* triangle_dbias_ptr,
+ float* do_o_dot_ptr,
+ float* dq_fp32_buf,
+ const uint32_t B,
+ const uint32_t I,
+ const uint32_t H,
+ const uint32_t S_qo,
+ const uint32_t S_kv,
+ const uint32_t D,
+ const float bmm_scale,
+ bool use_tf32,
+ void* stream);
+
+ extern template void run_fmha_bwd<float>(float* do_ptr,
+ float* o_ptr,
+ float* softmax_lse_ptr,
+ float* q_ptr,
+ float* k_ptr,
+ float* v_ptr,
+ bool* mask_bias_ptr,
+ float* triangle_bias_ptr,
+ float* dq_ptr,
+ float* dk_ptr,
+ float* dv_ptr,
+ float* triangle_dbias_ptr,
+ float* do_o_dot_ptr,
+ float* dq_fp32_buf,
+ const uint32_t B,
+ const uint32_t I,
+ const uint32_t H,
+ const uint32_t S_qo,
+ const uint32_t S_kv,
+ const uint32_t D,
+ const float bmm_scale,
+ bool use_tf32,
+ void* stream);
+
+ } // namespace cudnn_fmha
+
+ #endif // CUDNN_FMHA_RUN_FMHA_H
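
For orientation, here is a minimal host-side sketch (not taken from the package) of how the forward entry point declared above might be invoked. Buffer sizes follow the [B, N(=I), H, S, D] layout comments from the backward declaration, the 1/sqrt(D) value for bmm_scale is an assumed (conventional) softmax scaling, and the include path is illustrative.

#include <cmath>
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "run_fmha.h"  // illustrative include path

int main() {
  const uint32_t B = 1, I = 1, H = 4, S_qo = 128, S_kv = 128, D = 32;
  const size_t qo = size_t(B) * I * H * S_qo * D;  // elements in q/o
  const size_t kv = size_t(B) * I * H * S_kv * D;  // elements in k/v

  __half *q, *k, *v, *o;
  bool* mask;        // [B, N, 1, 1, S_kv]
  float* tri_bias;   // [B, 1, H, S_qo, S_kv]
  float *lse, *smax; // [B, N, H, S_qo, 1]
  cudaMalloc(&q, qo * sizeof(__half));
  cudaMalloc(&k, kv * sizeof(__half));
  cudaMalloc(&v, kv * sizeof(__half));
  cudaMalloc(&o, qo * sizeof(__half));
  cudaMalloc(&mask, size_t(B) * I * S_kv * sizeof(bool));
  cudaMalloc(&tri_bias, size_t(B) * H * S_qo * S_kv * sizeof(float));
  cudaMalloc(&lse, size_t(B) * I * H * S_qo * sizeof(float));
  cudaMalloc(&smax, size_t(B) * I * H * S_qo * sizeof(float));
  // ... copy real q/k/v, mask, and triangle-bias data into the device buffers here ...

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  const float bmm_scale = 1.0f / sqrtf(float(D));  // assumed scaling; not documented in the header
  cudnn_fmha::run_fmha<__half>(q, k, v, o, mask, tri_bias, lse, smax,
                               B, I, H, S_qo, S_kv, D, bmm_scale,
                               /*use_tf32=*/false, stream);
  cudaStreamSynchronize(stream);
  // cudaFree(...) calls omitted for brevity.
  return 0;
}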
cuequivariance_ops/equivariance/run_fmha_cudafree.h
@@ -0,0 +1,77 @@
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+ #ifndef CUDNN_FMHA_RUN_FMHA_CUDAFREE_H
+ #define CUDNN_FMHA_RUN_FMHA_CUDAFREE_H
+
+ #include <cstdint> // for uint32_t
+
+ namespace cudnn_fmha {
+ using DType = void;
+
+ enum class Datatype : uint32_t {
+ kFloat32 = 0,
+ kFloat64 = 1,
+ kFloat16 = 2,
+ kBFloat16 = 3,
+ kInt32 = 4,
+ kInt64 = 5
+ };
+
+ __attribute__((visibility("default"))) void run_fmha_for_dtype(
+ Datatype dtype,
+ DType* q_ptr, // [B, N, H, S_qo, D]
+ DType* k_ptr, // [B, N, H, S_kv, D]
+ DType* v_ptr, // [B, N, H, S_kv, D]
+ DType* o_ptr, // [B, N, H, S_qo, D] output
+ bool* mask_bias_ptr, // [B, N, 1, 1, S_kv]
+ float* triangle_bias_ptr, // [B, 1, H, S_qo, S_kv]
+ float* softmax_lse_ptr, // [B, N, H, S_qo, 1] output
+ float* softmax_max_ptr, // [B, N, H, S_qo, 1] output
+ const uint32_t B,
+ const uint32_t I,
+ const uint32_t H,
+ const uint32_t S_qo,
+ const uint32_t S_kv,
+ const uint32_t D,
+ const float bmm_scale,
+ bool use_tf32,
+ void* stream = nullptr);
+
+ __attribute__((visibility("default"))) void run_fmha_bwd_for_dtype(
+ Datatype dtype,
+ DType* do_ptr, // [B, N, H, S_qo, D]
+ DType* o_ptr, // [B, N, H, S_qo, D]
+ float* softmax_lse_ptr, // [B, N, H, S_qo, 1]
+ DType* q_ptr, // [B, N, H, S_qo, D]
+ DType* k_ptr, // [B, N, H, S_kv, D]
+ DType* v_ptr, // [B, N, H, S_kv, D]
+ bool* mask_bias_ptr, // [B, N, 1, 1, S_kv]
+ float* triangle_bias_ptr, // [B, 1, H, S_qo, S_kv]
+ DType* dq_ptr, // [B, N, H, S_qo, D] output
+ DType* dk_ptr, // [B, N, H, S_kv, D] output
+ DType* dv_ptr, // [B, N, H, S_kv, D] output
+ float* triangle_dbias_ptr, // [B, 1, H, S_qo, S_kv] output
+ float* do_o_dot_ptr, // [B, N, H, S_qo, 1] workspace
+ float* dq_fp32_buf_ptr, // [B, N, H, S_qo, D] workspace
+ const uint32_t B,
+ const uint32_t I,
+ const uint32_t H,
+ const uint32_t S_qo,
+ const uint32_t S_kv,
+ const uint32_t D,
+ const float bmm_scale,
+ bool use_tf32,
+ void* stream);
+
+ } // namespace cudnn_fmha
+
+ #endif // CUDNN_FMHA_RUN_FMHA_CUDAFREE_H
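
A short sketch (not from the package) of calling the type-erased wrapper declared above. DType is an alias for void in this header, so the caller passes raw device pointers and selects the element type at runtime via the Datatype enum; the wrapper function and its parameter names below are illustrative.

#include <cstdint>
#include "run_fmha_cudafree.h"  // illustrative include path

// Forwards pre-allocated device buffers to the fp16 path without requiring
// any CUDA headers at this call site.
void launch_fmha_fp16(void* q, void* k, void* v, void* o,
                      bool* mask, float* tri_bias,
                      float* lse, float* smax,
                      uint32_t B, uint32_t I, uint32_t H,
                      uint32_t S_qo, uint32_t S_kv, uint32_t D,
                      float bmm_scale, void* stream) {
  cudnn_fmha::run_fmha_for_dtype(cudnn_fmha::Datatype::kFloat16,
                                 q, k, v, o, mask, tri_bias, lse, smax,
                                 B, I, H, S_qo, S_kv, D, bmm_scale,
                                 /*use_tf32=*/false, stream);
}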
cuequivariance_ops/equivariance/segmented_transpose.cuh
@@ -0,0 +1,40 @@
+ /*
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+ #pragma once
+
+ #include "../common/common.hpp"
+
+ namespace kernelcatcher::tensor_product {
+
+ template <typename DataT>
+ void segmented_transpose(DataT* tensor_transpose,
+ const DataT* tensor,
+ const int32_t* segment_info,
+ int32_t num_segments,
+ int64_t batch_size,
+ int64_t stride,
+ bool input_contiguous_as_info,
+ cudaStream_t stream);
+
+ extern template void segmented_transpose<float>(
+ float*, const float*, const int32_t*, int32_t, int64_t, int64_t, bool, cudaStream_t);
+ extern template void segmented_transpose<double>(
+ double*, const double*, const int32_t*, int32_t, int64_t, int64_t, bool, cudaStream_t);
+ extern template void segmented_transpose<__nv_bfloat16>(__nv_bfloat16*,
+ const __nv_bfloat16*,
+ const int32_t*,
+ int32_t,
+ int64_t,
+ int64_t,
+ bool,
+ cudaStream_t);
+ extern template void segmented_transpose<__half>(
+ __half*, const __half*, const int32_t*, int32_t, int64_t, int64_t, bool, cudaStream_t);
+
+ } // namespace kernelcatcher::tensor_product
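
A minimal sketch (not from the package) of forwarding device buffers to the float instantiation declared above. The header does not document the encoding of segment_info or the exact meaning of input_contiguous_as_info, so both are treated as opaque, caller-provided inputs here; the function name and include path are illustrative.

#include <cstdint>
#include <cuda_runtime.h>
#include "segmented_transpose.cuh"  // illustrative include path

// Launches the segmented transpose of d_in into d_out asynchronously on the
// given stream, using a caller-prepared segment_info array on the device.
void transpose_segments(float* d_out, const float* d_in,
                        const int32_t* d_segment_info, int32_t num_segments,
                        int64_t batch_size, int64_t stride,
                        cudaStream_t stream) {
  kernelcatcher::tensor_product::segmented_transpose<float>(
      d_out, d_in, d_segment_info, num_segments, batch_size, stride,
      /*input_contiguous_as_info=*/true, stream);
}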
cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh
@@ -0,0 +1,38 @@
+ /*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+ #pragma once
+
+ #include "batch_dimension.hh" // for BatchDimension
+ #include "dtypes.hh" // for Datatype
+ #include <cstdint>
+ #include <string>
+ #include <vector>
+
+ namespace kernelcatcher::equivariance::tensor_product_uniform_1d_jit {
+ using namespace kernelcatcher::utils;
+
+ enum class Dimension : int { kScalar = 0, kOneDimensional = 1 };
+
+ #define KC_UNIFORM_1D_DECL_ARGUMENTS \
+ std::string const &name, Datatype math_dtype, int operand_extent, int num_inputs, \
+ int num_outputs, int num_index, std::vector<Dimension> const &buffer_dim, \
+ std::vector<int> const &buffer_num_segments, \
+ std::vector<std::vector<BatchDimension>> const &batch_dim, \
+ std::vector<std::vector<int>> const &index_buffer, std::vector<int> const &index_extent, \
+ std::vector<Datatype> const &dtypes, std::vector<std::vector<int>> const &operations, \
+ std::vector<int> num_paths, std::vector<int> path_indices_start, \
+ std::vector<int> path_coefficients_start, std::vector<int> const &path_indices, \
+ std::vector<double> const &path_coefficients, std::vector<int> const &batch_sizes, \
+ std::vector<void*> const &buffers, std::vector<size_t> const &buffer_bytes, \
+ bool zero_output_buffers
+
+ extern int run_tensor_product_uniform_1d_jit(KC_UNIFORM_1D_DECL_ARGUMENTS, void* stream);
+ extern int run_tensor_product_uniform_1d_cpu(KC_UNIFORM_1D_DECL_ARGUMENTS);
+
+ } // namespace kernelcatcher::equivariance::tensor_product_uniform_1d_jit
cuequivariance_ops/lib/libcue_ops.so: Binary file (contents not shown)
cuequivariance_ops/sleep.hh
@@ -0,0 +1,18 @@
+ /*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+ #pragma once
+
+ #include <cstdint>
+
+ namespace kernelcatcher::sleep {
+
+ int run_sleep(float* seconds, int64_t* elapsed_ticks, void* stream);
+ int run_synchronize(float* elapsed_seconds, void* stream);
+
+ } // namespace kernelcatcher::sleep
cuequivariance_ops/triton/__init__.py
@@ -0,0 +1,66 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ from .fused_layer_norm_triton import (
+     Layout,
+     layer_norm_transpose_backward_kernel,
+     layer_norm_transpose_backward_single_pass_kernel,
+     layer_norm_transpose_forward_kernel,
+     layer_norm_transpose_forward_single_pass_kernel,
+ )
+
+ from .gated_gemm_triton import (
+     fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel,
+     fused_sigmoid_gated_dual_gemm_forward_kernel,
+ )
+ from .utils import Precision
+ from .tuning_decorator import autotune_aot
+ from .cache_manager import get_cache_manager
+
+ cached_kernels = [ "fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel",
+                    "fused_sigmoid_gated_dual_gemm_forward_kernel",
+ ]
+
+ def init_triton_cache():
+     """
+     Initializes Triton cache manager by pre-loading cache for all available kernels.
+     This function is useful to initialize cache in eager mode before running torch.compile()'d methods
+     that cannot handle cache initialization code
+     """
+     mgr = get_cache_manager()
+     for kernel in cached_kernels:
+         mgr.load_cache(kernel + '_wrapper')
+
+ from .utils import (
+     Precision,
+ )
+
+ from .pair_bias import (
+     pair_bias_norm_linear_mask_forward_kernel,
+     pair_bias_linear_mask_forward_kernel,
+     pair_bias_mask_forward_kernel,
+ )
+
+
+ __all__ = [
+     "Precision",
+     "fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel",
+     "fused_sigmoid_gated_dual_gemm_forward_kernel",
+     "layer_norm_transpose_backward_kernel",
+     "layer_norm_transpose_backward_single_pass_kernel",
+     "layer_norm_transpose_forward_kernel",
+     "layer_norm_transpose_forward_single_pass_kernel",
+     "pair_bias_norm_linear_mask_forward_kernel",
+     "pair_bias_linear_mask_forward_kernel",
+     "pair_bias_mask_forward_kernel",
+     "autotune_aot",
+     "get_cache_manager",
+     "init_triton_cache"
+ ] + cached_kernels