cuequivariance-ops-cu12 0.8.1__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuequivariance_ops/VERSION +1 -0
- cuequivariance_ops/__init__.py +42 -0
- cuequivariance_ops/_version.py +20 -0
- cuequivariance_ops/common/common.hpp +98 -0
- cuequivariance_ops/common/cudart.hpp +286 -0
- cuequivariance_ops/common/error.hpp +66 -0
- cuequivariance_ops/common/error_raft.hpp +323 -0
- cuequivariance_ops/common/nvtx.hpp +29 -0
- cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
- cuequivariance_ops/equivariance/dtypes.hh +65 -0
- cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
- cuequivariance_ops/equivariance/indexed_linear.hh +41 -0
- cuequivariance_ops/equivariance/run_fmha.h +192 -0
- cuequivariance_ops/equivariance/run_fmha_cudafree.h +176 -0
- cuequivariance_ops/equivariance/run_fmha_sm100.h +135 -0
- cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
- cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
- cuequivariance_ops/gpu_timing_kernels.hh +42 -0
- cuequivariance_ops/lib/libcue_ops.so +0 -0
- cuequivariance_ops/sleep.hh +40 -0
- cuequivariance_ops/triton/__init__.py +66 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37142 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.12.0.json +37132 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.12.0.json +55692 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
- cuequivariance_ops/triton/cache_manager.py +336 -0
- cuequivariance_ops/triton/fused_layer_norm_triton.py +546 -0
- cuequivariance_ops/triton/gated_gemm_triton.py +394 -0
- cuequivariance_ops/triton/pair_bias.py +365 -0
- cuequivariance_ops/triton/tuning_decorator.py +188 -0
- cuequivariance_ops/triton/utils.py +29 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/METADATA +182 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/RECORD +46 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/WHEEL +6 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/licenses/LICENSE +142 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/licenses/Third_party_attr.txt +24 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- cuequivariance_ops_cu12.libs/libnvfatbin-b51d3b3f.so.12.8.90 +0 -0
cuequivariance_ops/equivariance/run_fmha_cudafree.h
@@ -0,0 +1,176 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+#ifndef CUDNN_FMHA_RUN_FMHA_CUDAFREE_H
+#define CUDNN_FMHA_RUN_FMHA_CUDAFREE_H
+
+#include <cstdint>  // for uint32_t
+#include <optional>
+
+namespace cudnn_fmha {
+using DType = void;
+
+enum class Datatype : uint32_t {
+  kFloat32 = 0,
+  kFloat64 = 1,
+  kFloat16 = 2,
+  kBFloat16 = 3,
+  kInt32 = 4,
+  kInt64 = 5
+};
+
+__attribute__((visibility("default"))) void run_fmha_for_dtype(
+    Datatype dtype,
+    DType* q_ptr,              // [B, N, H, S_qo, D]
+    DType* k_ptr,              // [B, N, H, S_kv, D]
+    DType* v_ptr,              // [B, N, H, S_kv, D]
+    DType* o_ptr,              // [B, N, H, S_qo, D] output
+    bool* mask_bias_ptr,       // [B, N, 1, 1, S_kv]
+    float* triangle_bias_ptr,  // [B, 1, H, S_qo, S_kv]
+    float* softmax_lse_ptr,    // [B, N, H, S_qo, 1] output
+    float* softmax_max_ptr,    // [B, N, H, S_qo, 1] output
+    const uint32_t B,
+    const uint32_t I,
+    const uint32_t H,
+    const uint32_t S_qo,
+    const uint32_t S_kv,
+    const uint32_t D,
+    const float bmm_scale,
+    bool use_tf32,
+    void* stream = nullptr);
+
+__attribute__((visibility("default"))) void run_fmha_bwd_for_dtype(
+    Datatype dtype,
+    DType* do_ptr,              // [B, N, H, S_qo, D]
+    DType* o_ptr,               // [B, N, H, S_qo, D]
+    float* softmax_lse_ptr,     // [B, N, H, S_qo, 1]
+    DType* q_ptr,               // [B, N, H, S_qo, D]
+    DType* k_ptr,               // [B, N, H, S_kv, D]
+    DType* v_ptr,               // [B, N, H, S_kv, D]
+    bool* mask_bias_ptr,        // [B, N, 1, 1, S_kv]
+    float* triangle_bias_ptr,   // [B, 1, H, S_qo, S_kv]
+    DType* dq_ptr,              // [B, N, H, S_qo, D] output
+    DType* dk_ptr,              // [B, N, H, S_kv, D] output
+    DType* dv_ptr,              // [B, N, H, S_kv, D] output
+    float* triangle_dbias_ptr,  // [B, 1, H, S_qo, S_kv] output
+    float* do_o_dot_ptr,        // [B, N, H, S_qo, 1] workspace
+    float* dq_fp32_buf_ptr,     // [B, N, H, S_qo, D] workspace
+    const uint32_t B,
+    const uint32_t I,
+    const uint32_t H,
+    const uint32_t S_qo,
+    const uint32_t S_kv,
+    const uint32_t D,
+    const float bmm_scale,
+    bool use_tf32,
+    void* stream,
+    bool zero_init_dbias_dq_buf = true);
+
+// Shared sanity checking functions for kernel requirements.
+// These can be called from bindings before invoking kernel functions.
+__attribute__((visibility("default"))) void validate_fmha_fwd_params(
+    int bits,        // Bit width of dtype (16 for FP16/BF16, 32 for FP32, 64 for FP64)
+    uint32_t D,      // Head dimension
+    bool use_tf32);  // Whether TF32 is requested
+
+__attribute__((visibility("default"))) void validate_fmha_bwd_params(
+    int bits,        // Bit width of dtype (16 for FP16/BF16, 32 for FP32, 64 for FP64)
+    bool use_tf32);  // Whether TF32 is requested
+
+// Shared kernel selection helpers.
+// Determine whether to use SM100 kernels based on conditions.
+__attribute__((visibility("default"))) bool should_use_sm100f_fwd(
+    int bits,                          // Bit width of dtype (16 for FP16/BF16, 32 for FP32, 64 for FP64)
+    bool has_triangle_bias_same_type,  // Whether triangle_bias has same dtype as q/k/v
+    uint32_t D,                        // Head dimension
+    uint32_t S_kv,                     // Key/value sequence length
+    bool mask_consistent,              // Whether mask_bias_ptr and actual_s_kv_ptr are both null or both non-null
+    void* stream,                      // CUDA stream (for device capability check)
+    const std::optional<std::vector<int>> device_cc =
+        std::nullopt);                 // Device capability (compute capability version)
+
+__attribute__((visibility("default"))) bool should_use_sm100f_bwd(
+    int bits,                          // Bit width of dtype (16 for FP16/BF16, 32 for FP32, 64 for FP64)
+    bool has_triangle_bias_same_type,  // Whether triangle_bias has same dtype as q/k/v
+    uint32_t D,                        // Head dimension
+    bool has_dbias_fp32_buf,           // Whether dbias_fp32_buf_ptr is provided
+    void* stream,                      // CUDA stream (for device capability check)
+    const std::optional<std::vector<int>> device_cc =
+        std::nullopt);                 // Device capability (compute capability version)
+
+// CUDA-free wrappers for SM100 kernels (can be called from non-CUDA code)
+__attribute__((visibility("default"))) void run_fmha_sm100_for_dtype(
+    Datatype dtype,
+    void* q_ptr,              // [B, N, H, S_qo, D]
+    void* k_ptr,              // [B, N, H, S_kv, D]
+    void* v_ptr,              // [B, N, H, S_kv, D]
+    void* o_ptr,              // [B, N, H, S_qo, D] output
+    bool* mask_bias_ptr,      // [B, N, 1, 1, S_kv] (can be nullptr)
+    int* actual_s_kv_ptr,     // [B, N] (can be nullptr, must match mask_bias_ptr)
+    void* triangle_bias_ptr,  // [B, 1, H, S_qo, S_kv] (same dtype as q/k/v)
+    float* softmax_lse_ptr,   // [B, N, H, S_qo, 1] output
+    float* softmax_max_ptr,   // [B, N, H, S_qo, 1] output
+    uint32_t B,
+    uint32_t I,
+    uint32_t H,
+    uint32_t S_qo,
+    uint32_t S_kv,
+    uint32_t D,
+    float bmm_scale,
+    void* stream);
+
+__attribute__((visibility("default"))) void run_fmha_bwd_sm100_for_dtype(
+    Datatype dtype,
+    void* q_ptr,                // [B, N, H, S_qo, D]
+    void* k_ptr,                // [B, N, H, S_kv, D]
+    void* v_ptr,                // [B, N, H, S_kv, D]
+    void* o_ptr,                // [B, N, H, S_qo, D]
+    bool* mask_bias_ptr,        // [B, N, 1, 1, S_kv] (can be nullptr)
+    void* triangle_bias_ptr,    // [B, 1, H, S_qo, S_kv] (same dtype as q/k/v)
+    float* softmax_lse_ptr,     // [B, N, H, S_qo, 1]
+    void* do_ptr,               // [B, N, H, S_qo, D]
+    void* dq_ptr,               // [B, N, H, S_qo, D] output
+    void* dk_ptr,               // [B, N, H, S_kv, D] output
+    void* dv_ptr,               // [B, N, H, S_kv, D] output
+    void* triangle_dbias_ptr,   // [B, 1, H, S_qo, S_kv] output (same dtype as q/k/v)
+    float* do_o_dot_ptr,        // [B, N, H, S_qo, 1] workspace
+    float* dq_fp32_buf_ptr,     // [B, N, H, S_qo, D] workspace
+    float* dbias_fp32_buf_ptr,  // [B, 1, H, padded_S_qo, padded_S_kv] workspace
+    uint32_t B,
+    uint32_t I,
+    uint32_t H,
+    uint32_t S_qo,
+    uint32_t S_kv,
+    uint32_t D,
+    float bmm_scale,
+    bool zero_init_dbias_dq_buf,  // Whether to zero initialize dq_fp32_buf and dbias_fp32_buf
+    void* stream);
+
+// Returns true if FP32 bias is needed for the forward pass, false if same-dtype bias is used (SM100
+// case). Takes minimal arguments needed to determine the bias dtype requirement for the forward pass.
+__attribute__((visibility("default"))) bool needs_fp32_bias_fwd(
+    int bits,       // Bit width of dtype (16 for FP16/BF16, 32 for FP32, 64 for FP64)
+    uint32_t D,     // Head dimension
+    uint32_t S_kv,  // Key/value sequence length (for forward pass constraint)
+    void* stream);  // CUDA stream (for device capability check)
+
+// Returns true if FP32 bias input is needed for the backward pass, false if same-dtype bias is used
+// (SM100 case). Takes minimal arguments needed to determine the bias input dtype requirement for the
+// backward pass.
+__attribute__((visibility("default"))) bool needs_fp32_bias_bwd(
+    int bits,                 // Bit width of dtype (16 for FP16/BF16, 32 for FP32, 64 for FP64)
+    uint32_t D,               // Head dimension
+    bool has_dbias_fp32_buf,  // Whether dbias FP32 workspace buffer is provided (required for SM100)
+    void* stream);            // CUDA stream (for device capability check)
+
+} // namespace cudnn_fmha
+
+#endif // CUDNN_FMHA_RUN_FMHA_CUDAFREE_H
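The header above pairs a generic entry point (run_fmha_for_dtype, FP32 triangle bias) with an SM100-specific wrapper (run_fmha_sm100_for_dtype, triangle bias in the q/k/v dtype), plus validate_*/should_use_*/needs_* helpers for choosing between them. A minimal caller-side sketch of that dispatch follows; the launch_fwd wrapper, its argument names, and the include path are illustrative assumptions, and only the cudnn_fmha declarations come from the header itself.

// Hypothetical caller-side dispatch; launch_fwd and its parameters are assumptions.
#include <cstdint>
#include <vector>
#include "run_fmha_cudafree.h"  // assumed include path

void launch_fwd(void* q, void* k, void* v, void* o,
                bool* mask_bias, void* bias_same_dtype, float* bias_fp32,
                int* actual_s_kv, float* lse, float* smax,
                uint32_t B, uint32_t I, uint32_t H,
                uint32_t S_qo, uint32_t S_kv, uint32_t D,
                float bmm_scale, bool use_tf32, void* stream) {
  using namespace cudnn_fmha;
  const int bits = 16;  // BF16 inputs in this sketch
  validate_fmha_fwd_params(bits, D, use_tf32);

  // mask_bias_ptr and actual_s_kv_ptr must be both null or both non-null.
  const bool mask_consistent = (mask_bias == nullptr) == (actual_s_kv == nullptr);

  if (should_use_sm100f_fwd(bits, /*has_triangle_bias_same_type=*/bias_same_dtype != nullptr,
                            D, S_kv, mask_consistent, stream)) {
    // SM100 path: triangle bias shares the q/k/v dtype.
    run_fmha_sm100_for_dtype(Datatype::kBFloat16, q, k, v, o, mask_bias, actual_s_kv,
                             bias_same_dtype, lse, smax, B, I, H, S_qo, S_kv, D,
                             bmm_scale, stream);
  } else {
    // Generic path: triangle bias is FP32 (see needs_fp32_bias_fwd).
    run_fmha_for_dtype(Datatype::kBFloat16, q, k, v, o, mask_bias, bias_fp32,
                       lse, smax, B, I, H, S_qo, S_kv, D, bmm_scale, use_tf32, stream);
  }
}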
cuequivariance_ops/equivariance/run_fmha_sm100.h
@@ -0,0 +1,135 @@
+#pragma once
+
+#include <cuda_fp16.h>
+
+namespace cudnn_fmha {
+
+template <typename DType>
+void run_fmha_sm100(DType* q_ptr,
+    DType* k_ptr,
+    DType* v_ptr,
+    DType* o_ptr,
+    bool* mask_bias_ptr,
+    int* actual_s_kv_ptr,
+    DType* triangle_bias_ptr,
+    float* softmax_lse_ptr,
+    float* softmax_max_ptr,
+    const uint32_t B,
+    const uint32_t I,
+    const uint32_t H,
+    const uint32_t S_qo,
+    const uint32_t S_kv,
+    const uint32_t D,
+    const float bmm_scale,
+    void* stream_ = nullptr);
+
+extern template void run_fmha_sm100<__nv_bfloat16>(__nv_bfloat16*,
+    __nv_bfloat16*,
+    __nv_bfloat16*,
+    __nv_bfloat16*,
+    bool*,
+    int*,
+    __nv_bfloat16*,
+    float*,
+    float*,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const float,
+    void*);
+
+extern template void run_fmha_sm100<__half>(__half*,
+    __half*,
+    __half*,
+    __half*,
+    bool*,
+    int*,
+    __half*,
+    float*,
+    float*,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const float,
+    void*);
+
+template <typename DType>
+void run_fmha_bwd_sm100(DType* q_ptr,
+    DType* k_ptr,
+    DType* v_ptr,
+    DType* o_ptr,
+    bool* mask_bias_ptr,
+    DType* triangle_bias_ptr,
+    float* softmax_lse_ptr,
+    DType* do_ptr,
+    DType* dq_ptr,
+    DType* dk_ptr,
+    DType* dv_ptr,
+    DType* dtriangle_bias_ptr,
+    float* do_o_dot_tmp_ptr,
+    float* dq_tmp_ptr,              // must be zero initialized
+    float* dtriangle_bias_tmp_ptr,  // must be zero initialized
+    const uint32_t B,
+    const uint32_t I,
+    const uint32_t H,
+    const uint32_t S_qo,
+    const uint32_t S_kv,
+    const uint32_t D,
+    const float bmm_scale,
+    void* stream_ = nullptr);
+
+extern template void run_fmha_bwd_sm100<__nv_bfloat16>(__nv_bfloat16*,
+    __nv_bfloat16*,
+    __nv_bfloat16*,
+    __nv_bfloat16*,
+    bool*,
+    __nv_bfloat16*,
+    float*,
+    __nv_bfloat16*,
+    __nv_bfloat16*,
+    __nv_bfloat16*,
+    __nv_bfloat16*,
+    __nv_bfloat16*,
+    float*,
+    float*,
+    float*,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const float,
+    void* stream_);
+
+extern template void run_fmha_bwd_sm100<__half>(__half*,
+    __half*,
+    __half*,
+    __half*,
+    bool*,
+    __half*,
+    float*,
+    __half*,
+    __half*,
+    __half*,
+    __half*,
+    __half*,
+    float*,
+    float*,
+    float*,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const uint32_t,
+    const float,
+    void* stream_);
+
+} // namespace cudnn_fmha
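Only the __half and __nv_bfloat16 instantiations of these templates are compiled into the shared library (hence the extern template declarations), so a CUDA-side caller either uses one of those two types directly or goes through the dtype-dispatching wrappers in run_fmha_cudafree.h. A small illustrative sketch, with a hypothetical fwd_fp16 wrapper and placeholder buffers:

// Hypothetical direct use of the exported __half instantiation; fwd_fp16 and the
// include path are assumptions. Other DType arguments would fail to link.
#include <cstdint>
#include <cuda_fp16.h>
#include "run_fmha_sm100.h"  // assumed include path

void fwd_fp16(__half* q, __half* k, __half* v, __half* o,
              bool* mask_bias, int* actual_s_kv, __half* triangle_bias,
              float* lse, float* smax,
              uint32_t B, uint32_t I, uint32_t H,
              uint32_t S_qo, uint32_t S_kv, uint32_t D,
              float bmm_scale, void* stream) {
  cudnn_fmha::run_fmha_sm100<__half>(q, k, v, o, mask_bias, actual_s_kv, triangle_bias,
                                     lse, smax, B, I, H, S_qo, S_kv, D, bmm_scale, stream);
}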
cuequivariance_ops/equivariance/segmented_transpose.cuh
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+#include "../common/common.hpp"
+
+namespace kernelcatcher::tensor_product {
+
+template <typename DataT>
+void segmented_transpose(DataT* tensor_transpose,
+    const DataT* tensor,
+    const int32_t* segment_info,
+    int32_t num_segments,
+    int64_t batch_size,
+    int64_t stride,
+    bool input_contiguous_as_info,
+    cudaStream_t stream);
+
+extern template void segmented_transpose<float>(
+    float*, const float*, const int32_t*, int32_t, int64_t, int64_t, bool, cudaStream_t);
+extern template void segmented_transpose<double>(
+    double*, const double*, const int32_t*, int32_t, int64_t, int64_t, bool, cudaStream_t);
+extern template void segmented_transpose<__nv_bfloat16>(__nv_bfloat16*,
+    const __nv_bfloat16*,
+    const int32_t*,
+    int32_t,
+    int64_t,
+    int64_t,
+    bool,
+    cudaStream_t);
+extern template void segmented_transpose<__half>(
+    __half*, const __half*, const int32_t*, int32_t, int64_t, int64_t, bool, cudaStream_t);
+
+} // namespace kernelcatcher::tensor_product
cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+#include "batch_dimension.hh"  // for BatchDimension
+#include "dtypes.hh"           // for Datatype
+#include <cstdint>
+#include <string>
+#include <vector>
+
+namespace kernelcatcher::equivariance::tensor_product_uniform_1d_jit {
+using namespace kernelcatcher::utils;
+
+enum class Dimension : int { kScalar = 0, kOneDimensional = 1 };
+
+#define KC_UNIFORM_1D_DECL_ARGUMENTS \
+  std::string const &name, Datatype math_dtype, int operand_extent, int num_inputs, \
+    int num_outputs, int num_index, std::vector<Dimension> const &buffer_dim, \
+    std::vector<int> const &buffer_num_segments, \
+    std::vector<std::vector<BatchDimension>> const &batch_dim, \
+    std::vector<std::vector<int>> const &index_buffer, std::vector<int> const &index_extent, \
+    std::vector<Datatype> const &dtypes, std::vector<std::vector<int>> const &operations, \
+    std::vector<int> const &num_paths, std::vector<int> const &path_indices_start, \
+    std::vector<int> const &path_coefficients_start, std::vector<int> const &path_indices, \
+    std::vector<double> const &path_coefficients, std::vector<int> const &batch_sizes, \
+    std::vector<void*> const &buffers, std::vector<size_t> const &buffer_bytes, \
+    bool zero_output_buffers
+
+extern int run_tensor_product_uniform_1d_jit(KC_UNIFORM_1D_DECL_ARGUMENTS, void* stream);
+extern int run_tensor_product_uniform_1d_cpu(KC_UNIFORM_1D_DECL_ARGUMENTS);
+
+} // namespace kernelcatcher::equivariance::tensor_product_uniform_1d_jit
cuequivariance_ops/gpu_timing_kernels.hh
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+/*
+ * GPU Events Implementation Attribution
+ *
+ * The CUDA event recording and timing functionality (record_event, event_elapsed)
+ * has been adapted from JAX's GPU events implementation that was removed in version 0.7.2.
+ *
+ * Original source: https://github.com/jax-ml/jax/
+ * License: Apache License 2.0
+ *
+ * JAX Copyright 2018 The JAX Authors.
+ * Licensed under the Apache License, Version 2.0.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+namespace kernelcatcher::gpu_utilities {
+
+int run_sleep(float* seconds, int64_t* elapsed_ticks, void* stream);
+int run_synchronize(float* elapsed_seconds, void* stream);
+
+// Record a CUDA event on the given stream
+// Returns 0 on success, non-zero on error
+int record_event(uint64_t* event_handle, void* stream, bool copy_before);
+
+// Calculate elapsed time between two events
+// Returns 0 on success, non-zero on error
+int event_elapsed(const uint64_t* start_event_handle,
+    const uint64_t* end_event_handle,
+    float* elapsed_ms,
+    void* stream);
+
+} // namespace kernelcatcher::gpu_utilities
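record_event and event_elapsed only return an error code, so a caller strings them together around the work it wants to time. A minimal sketch of that pattern follows; the time_region wrapper is hypothetical, and where the uint64_t event handles live (host vs. device memory) and the exact meaning of copy_before are assumptions not spelled out in this header.

// Hypothetical timing helper built on the declarations above; time_region and the
// interpretation of copy_before are assumptions, not part of the package.
#include <cstdint>
#include "gpu_timing_kernels.hh"  // assumed include path

float time_region(uint64_t* start_handle, uint64_t* end_handle, void* stream) {
  using namespace kernelcatcher::gpu_utilities;
  if (record_event(start_handle, stream, /*copy_before=*/true) != 0) return -1.0f;
  // ... enqueue the work to be timed on `stream` here ...
  if (record_event(end_handle, stream, /*copy_before=*/false) != 0) return -1.0f;
  float elapsed_ms = 0.0f;
  if (event_elapsed(start_handle, end_handle, &elapsed_ms, stream) != 0) return -1.0f;
  // Depending on the implementation, elapsed_ms may only be valid once the stream completes.
  return elapsed_ms;
}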
cuequivariance_ops/lib/libcue_ops.so
Binary file
cuequivariance_ops/sleep.hh
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+/*
+ * DEPRECATED HEADER: This header has been renamed to gpu_timing_kernels.hh
+ *
+ * This compatibility header provides backward compatibility for code that includes
+ * the old sleep.hh header name. The header has been renamed to better reflect
+ * its content (GPU timing and synchronization kernels).
+ *
+ * For new code, please use:
+ *   #include "gpu_timing_kernels.hh"
+ *
+ * OLD (deprecated but still works):
+ *   #include "sleep.hh"
+ *
+ * NEW (recommended):
+ *   #include "gpu_timing_kernels.hh"
+ *
+ * This compatibility header will be removed in a future version.
+ */
+
+#warning \
+  "The 'sleep.hh' header has been renamed to 'gpu_timing_kernels.hh'. Please update your #include statements. This compatibility header will be removed in a future version."
+
+// Include the new header to maintain compatibility
+#include "gpu_timing_kernels.hh"
+
+// Provide backward compatibility aliases
+namespace kernelcatcher::sleep {
+using kernelcatcher::gpu_utilities::run_sleep;
+using kernelcatcher::gpu_utilities::run_synchronize;
+} // namespace kernelcatcher::sleep
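Because the compatibility header re-exports the new functions under the old kernelcatcher::sleep namespace, existing call sites keep compiling while new code targets kernelcatcher::gpu_utilities directly. A small illustrative sketch, with a hypothetical wait_then_sync wrapper:

// Hypothetical migration example; wait_then_sync and its arguments are illustrative only.
#include <cstdint>
#include "gpu_timing_kernels.hh"  // recommended include
// #include "sleep.hh"            // deprecated spelling; emits the #warning above

int wait_then_sync(float* seconds, int64_t* elapsed_ticks, float* elapsed_seconds, void* stream) {
  // New code: call through the gpu_utilities namespace.
  int rc = kernelcatcher::gpu_utilities::run_sleep(seconds, elapsed_ticks, stream);
  if (rc != 0) return rc;
  // Old call sites spelled kernelcatcher::sleep::run_synchronize(...) keep working
  // via the using-declarations in sleep.hh until that header is removed.
  return kernelcatcher::gpu_utilities::run_synchronize(elapsed_seconds, stream);
}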
cuequivariance_ops/triton/__init__.py
@@ -0,0 +1,66 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from .fused_layer_norm_triton import (
+    Layout,
+    layer_norm_transpose_backward_kernel,
+    layer_norm_transpose_backward_single_pass_kernel,
+    layer_norm_transpose_forward_kernel,
+    layer_norm_transpose_forward_single_pass_kernel,
+)
+
+from .gated_gemm_triton import (
+    fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel,
+    fused_sigmoid_gated_dual_gemm_forward_kernel,
+)
+from .utils import Precision
+from .tuning_decorator import autotune_aot
+from .cache_manager import get_cache_manager
+
+cached_kernels = [ "fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel",
+                   "fused_sigmoid_gated_dual_gemm_forward_kernel",
+]
+
+def init_triton_cache():
+    """
+    Initializes Triton cache manager by pre-loading cache for all available kernels.
+    This function is useful to initialize cache in eager mode before running torch.compile()'d methods
+    that cannot handle cache initialization code
+    """
+    mgr = get_cache_manager()
+    for kernel in cached_kernels:
+        mgr.load_cache(kernel+'_wrapper')
+
+from .utils import (
+    Precision,
+)
+
+from .pair_bias import (
+    pair_bias_norm_linear_mask_forward_kernel,
+    pair_bias_linear_mask_forward_kernel,
+    pair_bias_mask_forward_kernel,
+)
+
+
+__all__ = [
+    "Precision",
+    "fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel",
+    "fused_sigmoid_gated_dual_gemm_forward_kernel",
+    "layer_norm_transpose_backward_kernel",
+    "layer_norm_transpose_backward_single_pass_kernel",
+    "layer_norm_transpose_forward_kernel",
+    "layer_norm_transpose_forward_single_pass_kernel",
+    "pair_bias_norm_linear_mask_forward_kernel",
+    "pair_bias_linear_mask_forward_kernel",
+    "pair_bias_mask_forward_kernel",
+    "autotune_aot",
+    "get_cache_manager",
+    "init_triton_cache"
+] + cached_kernels