cuequivariance-ops-cu12 0.4.0__py3-none-manylinux_2_39_aarch64.whl → 0.5.1__py3-none-manylinux_2_39_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cuequivariance-ops-cu12 might be problematic.
- cuequivariance_ops/VERSION +1 -1
- cuequivariance_ops/__init__.py +3 -2
- cuequivariance_ops/equivariance/dtypes.hh +21 -0
- cuequivariance_ops/equivariance/indexed_linear.hh +36 -0
- cuequivariance_ops/equivariance/run_fmha.h +192 -0
- cuequivariance_ops/equivariance/run_fmha_cudafree.h +77 -0
- cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +17 -35
- cuequivariance_ops/lib/libcue_ops.so +0 -0
- cuequivariance_ops/triton/__init__.py +29 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37192 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
- cuequivariance_ops/triton/cache_manager.py +244 -0
- cuequivariance_ops/triton/fused_layer_norm_triton.py +324 -0
- cuequivariance_ops/triton/gated_gemm_triton.py +340 -0
- cuequivariance_ops/triton/tuning_decorator.py +272 -0
- {cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.1.dist-info}/METADATA +5 -1
- cuequivariance_ops_cu12-0.5.1.dist-info/RECORD +32 -0
- {cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.1.dist-info}/WHEEL +1 -1
- cuequivariance_ops_cu12-0.4.0.dist-info/RECORD +0 -13
- {cuequivariance_ops_cu12-0.4.0.dist-info → cuequivariance_ops_cu12-0.5.1.dist-info}/licenses/LICENSE +0 -0
cuequivariance_ops/VERSION
CHANGED
@@ -1 +1 @@
-0.4.0
+0.5.1
cuequivariance_ops/__init__.py
CHANGED
@@ -10,6 +10,7 @@

 from ._version import __version__, __git_commit__
 import os
+import sys
 import ctypes

 PREFERRED_LOAD_FLAG = ctypes.RTLD_LOCAL
@@ -32,8 +33,8 @@ def load_library():
         ctypes.CDLL(
             os.path.join(root_dir(), "lib/libcue_ops.so"), mode=PREFERRED_LOAD_FLAG
         )
-    except Exception:
-
+    except Exception as e:
+        print(f"Error while loading libcue_ops.so: {e}", file=sys.stderr)


 load_library()
cuequivariance_ops/equivariance/dtypes.hh
ADDED
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+namespace kernelcatcher::utils {
+
+enum class Datatype : int {
+  kFloat32 = 0,
+  kFloat64 = 1,
+  kFloat16 = 2,
+  kBFloat16 = 3,
+  kInt32 = 4,
+  kInt64 = 5
+};
+
+}  // namespace kernelcatcher::utils
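The new dtypes.hh centralizes the dtype codes that were previously declared inline in tensor_product_uniform_1d_jit.hh (see the hunk further below). As a minimal sketch of how a caller might consume the enum, the helper below maps each code to its element size in bytes; the function name, the example namespace, and the include path are illustrative assumptions, not part of the package.

// Hypothetical helper (not part of cuequivariance-ops): maps the Datatype
// codes declared in dtypes.hh to element sizes in bytes.
#include <cstddef>
#include "dtypes.hh"  // path assumed relative to this header

namespace example {
inline std::size_t dtype_size_bytes(kernelcatcher::utils::Datatype dt) {
  using kernelcatcher::utils::Datatype;
  switch (dt) {
    case Datatype::kFloat16:
    case Datatype::kBFloat16: return 2;  // half-precision types
    case Datatype::kFloat32:
    case Datatype::kInt32:    return 4;
    case Datatype::kFloat64:
    case Datatype::kInt64:    return 8;
  }
  return 0;  // unreachable for valid enum values
}
}  // namespace example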
cuequivariance_ops/equivariance/indexed_linear.hh
ADDED
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+#include "dtypes.hh"  // for Datatype
+
+namespace kernelcatcher::equivariance::indexed_linear {
+using namespace kernelcatcher::utils;  // for Datatype
+
+#define KC_INDEXED_LINEAR_DECL_ARGUMENTS                                                          \
+  const void *ptr_A, const void *ptr_B, const int *counts, void *ptr_C, Datatype dtype_A,        \
+    Datatype dtype_B, Datatype dtype_D, int Z, int C, int u, int v, double coefficient,          \
+    Datatype math_dtype, void *workspace, size_t workspace_size, void *stream
+
+#define KC_INDEXED_LINEAR_ARGUMENTS                                                               \
+  ptr_A, ptr_B, counts, ptr_C, dtype_A, dtype_B, dtype_D, Z, C, u, v, coefficient, math_dtype,   \
+    workspace, workspace_size, stream
+
+int run_indexed_linear_B(  // ptr_A Zu
+                           // ptr_B Cuv or Cvu
+                           // ptr_C Zv
+  bool transpose_B,
+  KC_INDEXED_LINEAR_DECL_ARGUMENTS);
+
+int run_indexed_linear_C(  // ptr_A Zu
+                           // ptr_B Zv
+                           // ptr_C Cuv
+  KC_INDEXED_LINEAR_DECL_ARGUMENTS);
+
+}  // namespace kernelcatcher::equivariance::indexed_linear
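Going by the shape comments in the declaration (ptr_A is Z×u, ptr_B is C×u×v or C×v×u, ptr_C is Z×v, and counts presumably partitions the Z rows across the C weight blocks), a call might look like the sketch below. The fp32 dtypes, the unit coefficient, and the empty workspace are illustrative assumptions; the header does not document how a required workspace size would be obtained.

// Hypothetical call sketch for run_indexed_linear_B (not from the package docs).
#include <cuda_runtime.h>
#include "indexed_linear.hh"  // path assumed

using namespace kernelcatcher::equivariance::indexed_linear;
using kernelcatcher::utils::Datatype;

void example_indexed_linear(const float* d_A,       // device, Z x u
                            const float* d_B,       // device, C x u x v
                            const int*   d_counts,  // device, C entries (presumably partitioning Z)
                            float*       d_C,       // device, Z x v
                            int Z, int C, int u, int v,
                            cudaStream_t stream) {
  // Assumption: no extra workspace is needed for this configuration.
  int status = run_indexed_linear_B(/*transpose_B=*/false,
                                    d_A, d_B, d_counts, d_C,
                                    Datatype::kFloat32,   // dtype_A
                                    Datatype::kFloat32,   // dtype_B
                                    Datatype::kFloat32,   // dtype_D
                                    Z, C, u, v,
                                    /*coefficient=*/1.0,
                                    Datatype::kFloat32,   // math_dtype
                                    /*workspace=*/nullptr,
                                    /*workspace_size=*/0,
                                    stream);
  (void)status;  // 0 presumably indicates success; error handling omitted
}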
cuequivariance_ops/equivariance/run_fmha.h
ADDED
@@ -0,0 +1,192 @@
+#ifndef CUDNN_FMHA_RUN_FMHA_H
+#define CUDNN_FMHA_RUN_FMHA_H
+
+#include <cstdint>  // for uint32_t
+#include <cuda_fp16.h>
+#include <cuda_fp8.h>
+
+namespace cudnn_fmha {
+using DataType_TriBias = float;
+
+/**
+ * @brief Performs Flash Multi-Head Attention computation on GPU using cuDNN
+ * @tparam DType Data type for the computation (float, __half, or __nv_bfloat16)
+ */
+template <typename DType>
+void run_fmha(DType* q_ptr,
+              DType* k_ptr,
+              DType* v_ptr,
+              DType* o_ptr,
+              bool* mask_bias_ptr,
+              DataType_TriBias* triangle_bias_ptr,
+              float* softmax_lse_ptr,
+              float* softmax_max_ptr,
+              const uint32_t B,
+              const uint32_t I,
+              const uint32_t H,
+              const uint32_t S_qo,
+              const uint32_t S_kv,
+              const uint32_t D,
+              const float bmm_scale,
+              bool use_tf32,
+              void* stream = nullptr);
+
+/**
+ * @brief Performs the backward pass of Flash Multi-Head Attention computation on GPU using cuDNN
+ * Note: Backward pass remains in float before fp16/bf16 integration
+ */
+template <typename DType>
+void run_fmha_bwd(DType* do_ptr,              // [B, N, H, S_qo, D]
+                  DType* o_ptr,               // [B, N, H, S_qo, D]
+                  float* softmax_lse_ptr,     // [B, N, H, S_qo, 1]
+                  DType* q_ptr,               // [B, N, H, S_qo, D]
+                  DType* k_ptr,               // [B, N, H, S_kv, D]
+                  DType* v_ptr,               // [B, N, H, S_kv, D]
+                  bool* mask_bias_ptr,        // [B, N, 1, 1, S_kv]
+                  float* triangle_bias_ptr,   // [B, 1, H, S_qo, S_kv]
+                  DType* dq_ptr,              // [B, N, H, S_qo, D] output
+                  DType* dk_ptr,              // [B, N, H, S_kv, D] output
+                  DType* dv_ptr,              // [B, N, H, S_kv, D] output
+                  float* triangle_dbias_ptr,  // [B, 1, H, S_qo, S_kv] output
+                  float* do_o_dot_ptr,
+                  float* dq_fp32_buf,         // [B, N, H, S_qo, D] worspace
+                  const uint32_t B,
+                  const uint32_t I,
+                  const uint32_t H,
+                  const uint32_t S_qo,
+                  const uint32_t S_kv,
+                  const uint32_t D,
+                  const float bmm_scale,
+                  bool use_tf32,
+                  void* stream = nullptr);
+
+// Explicit template declarations for supported types
+extern template void run_fmha<float>(float*,
+                                     float*,
+                                     float*,
+                                     float*,
+                                     bool*,
+                                     DataType_TriBias*,
+                                     float*,
+                                     float*,
+                                     uint32_t,
+                                     uint32_t,
+                                     uint32_t,
+                                     uint32_t,
+                                     uint32_t,
+                                     uint32_t,
+                                     float,
+                                     bool,
+                                     void*);
+
+extern template void run_fmha<__half>(__half*,
+                                      __half*,
+                                      __half*,
+                                      __half*,
+                                      bool*,
+                                      DataType_TriBias*,
+                                      float*,
+                                      float*,
+                                      uint32_t,
+                                      uint32_t,
+                                      uint32_t,
+                                      uint32_t,
+                                      uint32_t,
+                                      uint32_t,
+                                      float,
+                                      bool,
+                                      void*);
+
+extern template void run_fmha<__nv_bfloat16>(__nv_bfloat16*,
+                                             __nv_bfloat16*,
+                                             __nv_bfloat16*,
+                                             __nv_bfloat16*,
+                                             bool*,
+                                             DataType_TriBias*,
+                                             float*,
+                                             float*,
+                                             uint32_t,
+                                             uint32_t,
+                                             uint32_t,
+                                             uint32_t,
+                                             uint32_t,
+                                             uint32_t,
+                                             float,
+                                             bool,
+                                             void*);
+
+extern template void run_fmha_bwd<__half>(__half* do_ptr,
+                                          __half* o_ptr,
+                                          float* softmax_lse_ptr,
+                                          __half* q_ptr,
+                                          __half* k_ptr,
+                                          __half* v_ptr,
+                                          bool* mask_bias_ptr,
+                                          float* triangle_bias_ptr,
+                                          __half* dq_ptr,
+                                          __half* dk_ptr,
+                                          __half* dv_ptr,
+                                          float* triangle_dbias_ptr,
+                                          float* do_o_dot_ptr,
+                                          float* dq_fp32_buf,
+                                          const uint32_t B,
+                                          const uint32_t I,
+                                          const uint32_t H,
+                                          const uint32_t S_qo,
+                                          const uint32_t S_kv,
+                                          const uint32_t D,
+                                          const float bmm_scale,
+                                          bool use_tf32,
+                                          void* stream);
+
+extern template void run_fmha_bwd<__nv_bfloat16>(__nv_bfloat16* do_ptr,
+                                                 __nv_bfloat16* o_ptr,
+                                                 float* softmax_lse_ptr,
+                                                 __nv_bfloat16* q_ptr,
+                                                 __nv_bfloat16* k_ptr,
+                                                 __nv_bfloat16* v_ptr,
+                                                 bool* mask_bias_ptr,
+                                                 float* triangle_bias_ptr,
+                                                 __nv_bfloat16* dq_ptr,
+                                                 __nv_bfloat16* dk_ptr,
+                                                 __nv_bfloat16* dv_ptr,
+                                                 float* triangle_dbias_ptr,
+                                                 float* do_o_dot_ptr,
+                                                 float* dq_fp32_buf,
+                                                 const uint32_t B,
+                                                 const uint32_t I,
+                                                 const uint32_t H,
+                                                 const uint32_t S_qo,
+                                                 const uint32_t S_kv,
+                                                 const uint32_t D,
+                                                 const float bmm_scale,
+                                                 bool use_tf32,
+                                                 void* stream);
+
+extern template void run_fmha_bwd<float>(float* do_ptr,
+                                         float* o_ptr,
+                                         float* softmax_lse_ptr,
+                                         float* q_ptr,
+                                         float* k_ptr,
+                                         float* v_ptr,
+                                         bool* mask_bias_ptr,
+                                         float* triangle_bias_ptr,
+                                         float* dq_ptr,
+                                         float* dk_ptr,
+                                         float* dv_ptr,
+                                         float* triangle_dbias_ptr,
+                                         float* do_o_dot_ptr,
+                                         float* dq_fp32_buf,
+                                         const uint32_t B,
+                                         const uint32_t I,
+                                         const uint32_t H,
+                                         const uint32_t S_qo,
+                                         const uint32_t S_kv,
+                                         const uint32_t D,
+                                         const float bmm_scale,
+                                         bool use_tf32,
+                                         void* stream);
+
+}  // namespace cudnn_fmha
+
+#endif  // CUDNN_FMHA_RUN_FMHA_H
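A usage sketch for the templated forward entry point follows. The tensor layouts are taken from the comments on run_fmha_bwd ([B, N, H, S, D], with I apparently playing the role of N), and the 1/sqrt(D) value for bmm_scale is the conventional attention scaling, assumed here rather than stated by the header; all pointers are presumed to be device allocations.

// Hypothetical call sketch for the templated FP16 forward path (not from the
// package docs).
#include <cmath>
#include <cstdint>
#include <cuda_fp16.h>
#include "run_fmha.h"  // path assumed

void example_fmha_fp16(__half* d_q, __half* d_k, __half* d_v, __half* d_o,
                       bool* d_mask_bias,       // [B, N, 1, 1, S_kv]
                       float* d_triangle_bias,  // [B, 1, H, S_qo, S_kv]
                       float* d_softmax_lse,    // [B, N, H, S_qo, 1]
                       float* d_softmax_max,    // [B, N, H, S_qo, 1]
                       uint32_t B, uint32_t I, uint32_t H,
                       uint32_t S_qo, uint32_t S_kv, uint32_t D,
                       void* stream) {
  const float bmm_scale = 1.0f / std::sqrt(static_cast<float>(D));  // assumed scaling
  cudnn_fmha::run_fmha<__half>(d_q, d_k, d_v, d_o,
                               d_mask_bias, d_triangle_bias,
                               d_softmax_lse, d_softmax_max,
                               B, I, H, S_qo, S_kv, D,
                               bmm_scale,
                               /*use_tf32=*/false,
                               stream);
}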
cuequivariance_ops/equivariance/run_fmha_cudafree.h
ADDED
@@ -0,0 +1,77 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+#ifndef CUDNN_FMHA_RUN_FMHA_CUDAFREE_H
+#define CUDNN_FMHA_RUN_FMHA_CUDAFREE_H
+
+#include <cstdint>  // for uint32_t
+
+namespace cudnn_fmha {
+using DType = void;
+
+enum class Datatype : uint32_t {
+  kFloat32 = 0,
+  kFloat64 = 1,
+  kFloat16 = 2,
+  kBFloat16 = 3,
+  kInt32 = 4,
+  kInt64 = 5
+};
+
+__attribute__((visibility("default"))) void run_fmha_for_dtype(
+  Datatype dtype,
+  DType* q_ptr,              // [B, N, H, S_qo, D]
+  DType* k_ptr,              // [B, N, H, S_kv, D]
+  DType* v_ptr,              // [B, N, H, S_kv, D]
+  DType* o_ptr,              // [B, N, H, S_qo, D] output
+  bool* mask_bias_ptr,       // [B, N, 1, 1, S_kv]
+  float* triangle_bias_ptr,  // [B, 1, H, S_qo, S_kv]
+  float* softmax_lse_ptr,    // [B, N, H, S_qo, 1] output
+  float* softmax_max_ptr,    // [B, N, H, S_qo, 1] output
+  const uint32_t B,
+  const uint32_t I,
+  const uint32_t H,
+  const uint32_t S_qo,
+  const uint32_t S_kv,
+  const uint32_t D,
+  const float bmm_scale,
+  bool use_tf32,
+  void* stream = nullptr);
+
+__attribute__((visibility("default"))) void run_fmha_bwd_for_dtype(
+  Datatype dtype,
+  DType* do_ptr,              // [B, N, H, S_qo, D]
+  DType* o_ptr,               // [B, N, H, S_qo, D]
+  float* softmax_lse_ptr,     // [B, N, H, S_qo, 1]
+  DType* q_ptr,               // [B, N, H, S_qo, D]
+  DType* k_ptr,               // [B, N, H, S_kv, D]
+  DType* v_ptr,               // [B, N, H, S_kv, D]
+  bool* mask_bias_ptr,        // [B, N, 1, 1, S_kv]
+  float* triangle_bias_ptr,   // [B, 1, H, S_qo, S_kv]
+  DType* dq_ptr,              // [B, N, H, S_qo, D] output
+  DType* dk_ptr,              // [B, N, H, S_kv, D] output
+  DType* dv_ptr,              // [B, N, H, S_kv, D] output
+  float* triangle_dbias_ptr,  // [B, 1, H, S_qo, S_kv] output
+  float* do_o_dot_ptr,        // [B, N, H, S_qo, 1] worspace
+  float* dq_fp32_buf_ptr,     // [B, N, H, S_qo, D] workspace
+  const uint32_t B,
+  const uint32_t I,
+  const uint32_t H,
+  const uint32_t S_qo,
+  const uint32_t S_kv,
+  const uint32_t D,
+  const float bmm_scale,
+  bool use_tf32,
+  void* stream);
+
+}  // namespace cudnn_fmha
+
+#endif  // CUDNN_FMHA_RUN_FMHA_CUDAFREE_H
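run_fmha_for_dtype exposes the same computation without requiring CUDA headers on the caller's side: DType is aliased to void and the element type is selected at run time through the local Datatype enum. A sketch of the corresponding call, under the same layout and scaling assumptions as the templated example above:

// Hypothetical sketch of the CUDA-header-free entry point (not from the
// package docs): the same forward call, but selecting bf16 at runtime via
// the Datatype enum instead of a template parameter.
#include <cstdint>
#include "run_fmha_cudafree.h"  // path assumed

void example_fmha_bf16_untyped(void* d_q, void* d_k, void* d_v, void* d_o,
                               bool* d_mask_bias, float* d_triangle_bias,
                               float* d_softmax_lse, float* d_softmax_max,
                               uint32_t B, uint32_t I, uint32_t H,
                               uint32_t S_qo, uint32_t S_kv, uint32_t D,
                               float bmm_scale, void* stream) {
  cudnn_fmha::run_fmha_for_dtype(cudnn_fmha::Datatype::kBFloat16,
                                 d_q, d_k, d_v, d_o,
                                 d_mask_bias, d_triangle_bias,
                                 d_softmax_lse, d_softmax_max,
                                 B, I, H, S_qo, S_kv, D,
                                 bmm_scale,
                                 /*use_tf32=*/false,
                                 stream);
}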
cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh
CHANGED
@@ -8,49 +8,31 @@

 #pragma once

+#include "dtypes.hh"  // for Datatype
 #include <cstdint>
 #include <string>
 #include <vector>

 namespace kernelcatcher::equivariance::tensor_product_uniform_1d_jit {
+using namespace kernelcatcher::utils;

-enum class Datatype : int {
-  kFloat32 = 0,
-  kFloat64 = 1,
-  kFloat16 = 2,
-  kBFloat16 = 3,
-  kInt32 = 4,
-  kInt64 = 5
-};
 enum class Dimension : int { kScalar = 0, kOneDimensional = 1 };
 enum class BatchDimension : int { kBatched = 0, kShared = 1, kIndexed = 2 };

-
-std::string const&
-
-
-
-
-
-
-
-
-
-
-
-
-  std::vector<Datatype> const& dtypes,  // num_inputs + num_outputs + num_index
-  std::vector<std::vector<int>> const& operations,
-  std::vector<int> num_paths,                // num_operations
-  std::vector<int> path_indices_start,       // num_operations
-  std::vector<int> path_coefficients_start,  // num_operations
-  std::vector<int> const& path_indices,
-  std::vector<double> const& path_coefficients,
-  std::vector<int> const& batch_sizes,  // num_batch_axes
-  std::vector<void*> const& buffers,
-  std::vector<size_t> const& buffer_bytes,
-  bool zero_output_buffers,
-  void* stream);
+#define KC_UNIFORM_1D_DECL_ARGUMENTS                                                          \
+  std::string const &name, Datatype math_dtype, int operand_extent, int num_inputs,          \
+    int num_outputs, int num_index, std::vector<Dimension> const &buffer_dim,                \
+    std::vector<int> const &buffer_num_segments,                                             \
+    std::vector<std::vector<BatchDimension>> const &batch_dim,                               \
+    std::vector<std::vector<int>> const &index_buffer, std::vector<int> const &index_extent, \
+    std::vector<Datatype> const &dtypes, std::vector<std::vector<int>> const &operations,    \
+    std::vector<int> num_paths, std::vector<int> path_indices_start,                         \
+    std::vector<int> path_coefficients_start, std::vector<int> const &path_indices,          \
+    std::vector<double> const &path_coefficients, std::vector<int> const &batch_sizes,       \
+    std::vector<void*> const &buffers, std::vector<size_t> const &buffer_bytes,              \
+    bool zero_output_buffers
+
+extern int run_tensor_product_uniform_1d_jit(KC_UNIFORM_1D_DECL_ARGUMENTS, void* stream);
+extern int run_tensor_product_uniform_1d_cpu(KC_UNIFORM_1D_DECL_ARGUMENTS);

 } // namespace kernelcatcher::equivariance::tensor_product_uniform_1d_jit
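The refactor collapses the long explicit parameter list into KC_UNIFORM_1D_DECL_ARGUMENTS so that the JIT and CPU entry points share one signature. The sketch below shows how a caller-side dispatcher could forward that list; the KC_UNIFORM_1D_CALL_ARGUMENTS macro and the run_uniform_1d_auto wrapper are written here for illustration only (the header in this diff defines just the DECL variant).

// Hypothetical dispatcher (not part of the package): forwards the shared
// parameter list to the JIT path when a stream is given, else to the CPU path.
// KC_UNIFORM_1D_CALL_ARGUMENTS is defined here only for this sketch.
#include "tensor_product_uniform_1d_jit.hh"  // path assumed

#define KC_UNIFORM_1D_CALL_ARGUMENTS                                                   \
  name, math_dtype, operand_extent, num_inputs, num_outputs, num_index, buffer_dim,   \
    buffer_num_segments, batch_dim, index_buffer, index_extent, dtypes, operations,   \
    num_paths, path_indices_start, path_coefficients_start, path_indices,             \
    path_coefficients, batch_sizes, buffers, buffer_bytes, zero_output_buffers

namespace kernelcatcher::equivariance::tensor_product_uniform_1d_jit {

int run_uniform_1d_auto(KC_UNIFORM_1D_DECL_ARGUMENTS, void* stream) {
  if (stream != nullptr) {
    return run_tensor_product_uniform_1d_jit(KC_UNIFORM_1D_CALL_ARGUMENTS, stream);
  }
  return run_tensor_product_uniform_1d_cpu(KC_UNIFORM_1D_CALL_ARGUMENTS);
}

}  // namespace kernelcatcher::equivariance::tensor_product_uniform_1d_jit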
cuequivariance_ops/lib/libcue_ops.so
CHANGED
Binary file
cuequivariance_ops/triton/__init__.py
ADDED
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from .fused_layer_norm_triton import (
+    Layout,
+    layer_norm_transpose_backward_kernel,
+    layer_norm_transpose_forward_kernel,
+)
+
+from .gated_gemm_triton import (
+    Precision,
+    fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel,
+    fused_sigmoid_gated_dual_gemm_forward_kernel,
+)
+from .tuning_decorator import autotune_aot
+
+__all__ = [
+    "Precision",
+    "fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel",
+    "fused_sigmoid_gated_dual_gemm_forward_kernel",
+    "autotune_aot",
+]