cuequivariance-ops-cu12 0.6.0__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cuequivariance-ops-cu12 might be problematic.

Files changed (37):
  1. cuequivariance_ops/VERSION +1 -0
  2. cuequivariance_ops/__init__.py +42 -0
  3. cuequivariance_ops/_version.py +20 -0
  4. cuequivariance_ops/common/common.hpp +98 -0
  5. cuequivariance_ops/common/nvtx.hpp +29 -0
  6. cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
  7. cuequivariance_ops/equivariance/dtypes.hh +65 -0
  8. cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
  9. cuequivariance_ops/equivariance/indexed_linear.hh +36 -0
  10. cuequivariance_ops/equivariance/run_fmha.h +192 -0
  11. cuequivariance_ops/equivariance/run_fmha_cudafree.h +77 -0
  12. cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
  13. cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
  14. cuequivariance_ops/lib/libcue_ops.so +0 -0
  15. cuequivariance_ops/sleep.hh +18 -0
  16. cuequivariance_ops/triton/__init__.py +66 -0
  17. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37192 -0
  18. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
  19. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
  20. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
  21. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
  22. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
  23. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
  24. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
  25. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
  26. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
  27. cuequivariance_ops/triton/cache_manager.py +259 -0
  28. cuequivariance_ops/triton/fused_layer_norm_triton.py +518 -0
  29. cuequivariance_ops/triton/gated_gemm_triton.py +380 -0
  30. cuequivariance_ops/triton/pair_bias.py +324 -0
  31. cuequivariance_ops/triton/tuning_decorator.py +177 -0
  32. cuequivariance_ops/triton/utils.py +28 -0
  33. cuequivariance_ops_cu12-0.6.0.dist-info/METADATA +182 -0
  34. cuequivariance_ops_cu12-0.6.0.dist-info/RECORD +37 -0
  35. cuequivariance_ops_cu12-0.6.0.dist-info/WHEEL +6 -0
  36. cuequivariance_ops_cu12-0.6.0.dist-info/licenses/LICENSE +142 -0
  37. cuequivariance_ops_cu12-0.6.0.dist-info/licenses/Third_party_attr.txt +24 -0
@@ -0,0 +1 @@
+ 0.6.0
@@ -0,0 +1,42 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ from ._version import __version__, __git_commit__
+ import os
+ import sys
+ import ctypes
+
+ PREFERRED_LOAD_FLAG = ctypes.RTLD_LOCAL
+
+
+ def root_dir():
+     try:
+         import importlib.metadata
+
+         dist = importlib.metadata.distribution("cuequivariance_ops")
+         root = dist.locate_file("cuequivariance_ops")
+     except Exception:
+         # last resort, will fail with writeable install
+         root = os.path.dirname(__file__)
+     return root
+
+
+ def load_library():
+     try:
+         ctypes.CDLL(
+             os.path.join(root_dir(), "lib/libcue_ops.so"), mode=PREFERRED_LOAD_FLAG
+         )
+     except Exception as e:
+         print(f"Error while loading libcue_ops.so: {e}", file=sys.stderr)
+
+
+ load_library()
+
+ __all__ = ["__version__", "__git_commit__", "root_dir", "load_library"]
@@ -0,0 +1,20 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+
+ import importlib.resources
+
+ __version__ = (
+     importlib.resources.files("cuequivariance_ops")
+     .joinpath("VERSION")
+     .read_text()
+     .strip()
+ )
+ __git_commit__ = "release"
@@ -0,0 +1,98 @@
+ /*
+  * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+  *
+  * This source code and/or documentation ("Licensed Deliverables") are
+  * subject to NVIDIA intellectual property rights under U.S. and
+  * international Copyright laws.
+  */
+
+ #pragma once
+
+ #include <cstdint>
+ #include <cuda_bf16.h>
+ #include <cuda_fp16.h>
+ #include <vector>
+
+ namespace kernelcatcher {
+
+ namespace tensor_product {
+ /**
+  * @brief a wrapper struct containing information
+  * about tensor-product paths
+  */
+ template <typename MathT>
+ struct __attribute__((aligned(16))) tp_info {
+   /** offsets into `path_offsets_and_dims` for each "target" */
+   const int32_t* __restrict__ path_csr_offsets{nullptr};
+   /** "sources" of all paths and their offsets and dimensions */
+   const int32_t* __restrict__ path_offsets_and_dims{nullptr};
+   /** clebsch-gordan values for each path */
+   const MathT* __restrict__ path_cg_values{nullptr};
+   /** number of "target" segments */
+   int32_t num_target_segments{0};
+   /** number of paths (i.e. all paths between segments) */
+   int32_t num_paths{0};
+ }; // struct tp_info
+
+ enum class ConnectionModeT : uint8_t {
+   kUVW = 0,
+   // UVW with U spherical harmonic
+   k1VW, // NOLINT
+   // UVW with V spherical harmonic
+   kU1W, // NOLINT
+   kUVU,
+   kUVV,
+   kUUW,
+   kUUU,
+   // FullTP, no weight
+   kUVUV,
+   // FullTP, U spherical harmonic
+   k1V1V,
+   // FullTP, V spherical harmonic
+   kU1U1,
+   // Linear
+   kUUVV,
+ };
+ } // namespace tensor_product
+
+ namespace symmetric_tensor_contraction {
+ /**
+  * @brief a wrapper struct containing information
+  * about tensor-product paths
+  */
+ template <typename DataT>
+ struct __attribute__((aligned(16))) clebsch_gordan_tensor {
+   const DataT* __restrict__ cg_values{nullptr};
+   const int16_t* __restrict__ cg_indices{nullptr};
+   const int32_t* __restrict__ cg_offsets{nullptr};
+   int32_t total_output_irreps{0};
+ }; // struct clebsch_gordan_tensor
+ } // namespace symmetric_tensor_contraction
+
+ namespace batch_linear {
+ struct __attribute__((aligned(8))) MatrixLayout {
+   int32_t size_row; // uncontracted mode
+   int32_t size_col; // contracted mode
+ };
+
+ struct __attribute__((aligned(8))) IndexOffset {
+   int32_t start;
+   int32_t end;
+ };
+
+ enum class GemvModeT : std::uint8_t { kUVV = 0, kUUV = 1 };
+ enum class WeightSharedModeT : std::int32_t { kShared = 0, kIndexed = 1, kBatched = 2 };
+ /**
+  * @brief a wrapper struct containing information
+  * about tensor-product paths
+  */
+ template <typename DataT>
+ struct __attribute__((aligned(16))) batch_linear_info {
+   const MatrixLayout* __restrict__ layouts{nullptr};
+   const IndexOffset* __restrict__ index_offsets{nullptr};
+   const int32_t* __restrict__ indices{nullptr};
+   const DataT* __restrict__ alpha{nullptr};
+ }; // struct batch_linear_info
+
+ } // namespace batch_linear
+ } // namespace kernelcatcher
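
The tp_info struct above is a CSR-style descriptor over tensor-product paths. A minimal host-side sketch of how it might be filled, assuming one Clebsch-Gordan value per path and an offsets array of length num_target_segments + 1; the helper make_tp_info, the include path, and the host vectors are illustrative rather than part of the wheel, and real callers would pass device pointers:

    #include <cstdint>
    #include <vector>

    #include "cuequivariance_ops/common/common.hpp"  // assumed include path

    using kernelcatcher::tensor_product::tp_info;

    // Build a tp_info<float> view over host-side CSR data (illustration only).
    tp_info<float> make_tp_info(const std::vector<int32_t>& offsets,           // one entry per target segment, plus 1
                                const std::vector<int32_t>& offsets_and_dims,  // per-path source offsets and dims
                                const std::vector<float>& cg_values)           // assumed: one CG value per path
    {
      tp_info<float> info;
      info.path_csr_offsets      = offsets.data();
      info.path_offsets_and_dims = offsets_and_dims.data();
      info.path_cg_values        = cg_values.data();
      info.num_target_segments   = static_cast<int32_t>(offsets.size()) - 1;
      info.num_paths             = static_cast<int32_t>(cg_values.size());
      return info;
    }
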
@@ -0,0 +1,29 @@
+ /*
+  * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
+  *
+  * This source code and/or documentation ("Licensed Deliverables") are
+  * subject to NVIDIA intellectual property rights under U.S. and
+  * international Copyright laws.
+  */
+
+ #pragma once
+
+ namespace kernelcatcher::utils {
+
+ /**
+  * @brief Push a named nvtx range
+  * @param name range name
+  */
+ void push_range(const char* name);
+
+ /** Pop the latest range */
+ void pop_range();
+
+ struct range_guard {
+   range_guard(const char* name) { push_range(name); }
+   ~range_guard() { pop_range(); }
+   range_guard(range_guard const&) = delete;
+   range_guard& operator=(range_guard const&) = delete;
+ };
+
+ } // namespace kernelcatcher::utils
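
range_guard is the usual RAII wrapper over push_range/pop_range: the range is popped when the guard leaves scope, including on early return. A small usage sketch; the function name and range label are made up for illustration, and the include path is assumed:

    #include "cuequivariance_ops/common/nvtx.hpp"  // assumed include path

    void profiled_step()
    {
      // Pushes an NVTX range on construction; pops it when `guard` is destroyed.
      kernelcatcher::utils::range_guard guard("cue_ops::profiled_step");
      // ... launch kernels or do other work to be profiled ...
    }
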
@@ -0,0 +1,15 @@
+ /*
+  * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
+  *
+  * This source code and/or documentation ("Licensed Deliverables") are
+  * subject to NVIDIA intellectual property rights under U.S. and
+  * international Copyright laws.
+  */
+
+ #pragma once
+
+ namespace kernelcatcher::utils {
+
+ enum class BatchDimension : int { kBatched = 0, kShared = 1, kIndexed = 2 };
+
+ } // namespace kernelcatcher::utils
@@ -0,0 +1,65 @@
+ /*
+  * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
+  *
+  * This source code and/or documentation ("Licensed Deliverables") are
+  * subject to NVIDIA intellectual property rights under U.S. and
+  * international Copyright laws.
+  */
+
+ #pragma once
+
+ #include <iostream>
+
+ namespace kernelcatcher::utils {
+
+ enum class Datatype : int {
+   kFloat32 = 0,
+   kFloat64 = 1,
+   kFloat16 = 2,
+   kBFloat16 = 3,
+   kInt32 = 4,
+   kInt64 = 5
+ };
+
+ inline int size_of(Datatype dtype)
+ {
+   switch (dtype) {
+     case Datatype::kFloat32: return 4;
+     case Datatype::kFloat64: return 8;
+     case Datatype::kFloat16: return 2;
+     case Datatype::kBFloat16: return 2;
+     case Datatype::kInt32: return 4;
+     case Datatype::kInt64: return 8;
+     default: return -1;
+   }
+ }
+
+ inline std::ostream& operator<<(std::ostream& s, Datatype const& d)
+ {
+   switch (d) {
+     case Datatype::kFloat32: return s << "float";
+     case Datatype::kFloat64: return s << "double";
+     case Datatype::kFloat16: return s << "k_fp16";
+     case Datatype::kBFloat16: return s << "k_bf16";
+     case Datatype::kInt32: return s << "kc_int32";
+     case Datatype::kInt64: return s << "kc_int64";
+   }
+   return s << "unknown_datatype";
+ }
+
+ inline bool is_real(Datatype const& d)
+ {
+   switch (d) {
+     case Datatype::kFloat32:
+     case Datatype::kFloat64:
+     case Datatype::kFloat16:
+     case Datatype::kBFloat16: return true;
+     case Datatype::kInt32:
+     case Datatype::kInt64: return false;
+   }
+   return false; // Default case, should not be reached
+ }
+
+ inline bool is_integral(Datatype const& d) { return !is_real(d); }
+
+ } // namespace kernelcatcher::utils
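
size_of returns the element width in bytes (-1 for an unrecognized value), which makes buffer sizing a one-liner. A short sketch; the helper name buffer_bytes and the include path are illustrative:

    #include <cstddef>

    #include "cuequivariance_ops/equivariance/dtypes.hh"  // assumed include path

    using kernelcatcher::utils::Datatype;
    using kernelcatcher::utils::size_of;

    // Bytes needed for n_elements values of dtype; 0 if the dtype is unknown.
    std::size_t buffer_bytes(Datatype dtype, std::size_t n_elements)
    {
      int elem = size_of(dtype);
      return elem < 0 ? 0 : n_elements * static_cast<std::size_t>(elem);
    }

    // e.g. buffer_bytes(Datatype::kBFloat16, 1024) == 2048, and
    // is_real(Datatype::kInt64) == false while is_integral(Datatype::kInt64) == true.
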
@@ -0,0 +1,297 @@
+ /*
+  * Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved.
+  *
+  * This source code and/or documentation ("Licensed Deliverables") are
+  * subject to NVIDIA intellectual property rights under U.S. and
+  * international Copyright laws.
+  */
+
+ #pragma once
+
+ #include "../common/common.hpp"
+
+ #include <algorithm>
+ #include <limits>
+
+ namespace kernelcatcher::tensor_product {
+
+ struct __attribute__((aligned(16))) tp_data_sizes {
+   int64_t batch_size;
+   bool shared_a;
+   bool shared_b;
+   bool shared_w;
+   int32_t stride_a;
+   int32_t stride_b;
+   int32_t stride_w;
+   int32_t stride_o;
+ }; // struct tp_data_sizes
+
+ template <typename DataAT, typename DataBT, typename DataWeightT, typename DataOutT, typename MathT>
+ void fused_tensor_product_fwd(DataOutT* out,
+                               const DataAT* in_a,
+                               const DataBT* in_b,
+                               const DataWeightT* weight,
+                               ConnectionModeT mode,
+                               const tp_info<MathT>& info,
+                               const tp_data_sizes& sizes,
+                               cudaStream_t stream);
+
+ template <typename DataAT, typename DataBT, typename DataWeightT, typename DataOutT, typename MathT>
+ void fused_tensor_product_bwd(DataAT* grad_in_a,
+                               DataBT* grad_in_b,
+                               DataWeightT* grad_weight,
+                               const DataOutT* grad_out,
+                               const DataAT* in_a,
+                               const DataBT* in_b,
+                               const DataWeightT* weight,
+                               ConnectionModeT mode,
+                               const tp_info<MathT>& info_bwd_dgrad_a,
+                               const tp_info<MathT>& info_bwd_dgrad_b,
+                               const tp_info<MathT>& info_bwd_dgrad_w,
+                               const tp_data_sizes& sizes,
+                               cudaStream_t stream);
+
+ template <typename DataAT, typename DataBT, typename DataWeightT, typename DataOutT, typename MathT>
+ void fused_tensor_product_bwd_bwd(DataAT* grad_in_a,
+                                   DataBT* grad_in_b,
+                                   DataWeightT* grad_weight,
+                                   DataOutT* grad_grad_out,
+                                   const DataAT* grad_grad_in_a,
+                                   const DataBT* grad_grad_in_b,
+                                   const DataWeightT* grad_grad_weight,
+                                   const DataOutT* grad_out,
+                                   const DataAT* in_a,
+                                   const DataBT* in_b,
+                                   const DataWeightT* weight,
+                                   ConnectionModeT mode,
+                                   const tp_info<MathT>& info_fwd,
+                                   const tp_info<MathT>& info_bwd_dgrad_a,
+                                   const tp_info<MathT>& info_bwd_dgrad_b,
+                                   const tp_info<MathT>& info_bwd_dgrad_w,
+                                   const tp_data_sizes& sizes,
+                                   cudaStream_t stream);
+
+ extern template void fused_tensor_product_bwd_bwd<float, float, float, float, float>(
+     float*,
+     float*,
+     float*,
+     float*,
+     const float*,
+     const float*,
+     const float*,
+     const float*,
+     const float*,
+     const float*,
+     const float*,
+     ConnectionModeT,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+
+ extern template void fused_tensor_product_bwd_bwd<float, float, float, float, double>(
+     float*,
+     float*,
+     float*,
+     float*,
+     const float*,
+     const float*,
+     const float*,
+     const float*,
+     const float*,
+     const float*,
+     const float*,
+     ConnectionModeT,
+     const tp_info<double>&,
+     const tp_info<double>&,
+     const tp_info<double>&,
+     const tp_info<double>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+
+ extern template void fused_tensor_product_bwd_bwd<double, double, double, double, double>(
+     double*,
+     double*,
+     double*,
+     double*,
+     const double*,
+     const double*,
+     const double*,
+     const double*,
+     const double*,
+     const double*,
+     const double*,
+     ConnectionModeT,
+     const tp_info<double>&,
+     const tp_info<double>&,
+     const tp_info<double>&,
+     const tp_info<double>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+
+ extern template void fused_tensor_product_bwd_bwd<__half, __half, __half, __half, float>(
+     __half*,
+     __half*,
+     __half*,
+     __half*,
+     const __half*,
+     const __half*,
+     const __half*,
+     const __half*,
+     const __half*,
+     const __half*,
+     const __half*,
+     ConnectionModeT,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+ extern template void
+ fused_tensor_product_bwd_bwd<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float>(
+     __nv_bfloat16*,
+     __nv_bfloat16*,
+     __nv_bfloat16*,
+     __nv_bfloat16*,
+     const __nv_bfloat16*,
+     const __nv_bfloat16*,
+     const __nv_bfloat16*,
+     const __nv_bfloat16*,
+     const __nv_bfloat16*,
+     const __nv_bfloat16*,
+     const __nv_bfloat16*,
+     ConnectionModeT,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+
+ extern template void fused_tensor_product_bwd<float, float, float, float, float>(
+     float*,
+     float*,
+     float*,
+     const float*,
+     const float*,
+     const float*,
+     const float*,
+     ConnectionModeT,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+
+ extern template void fused_tensor_product_bwd<float, float, float, float, double>(
+     float*,
+     float*,
+     float*,
+     const float*,
+     const float*,
+     const float*,
+     const float*,
+     ConnectionModeT,
+     const tp_info<double>&,
+     const tp_info<double>&,
+     const tp_info<double>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+
+ extern template void fused_tensor_product_bwd<double, double, double, double, double>(
+     double*,
+     double*,
+     double*,
+     const double*,
+     const double*,
+     const double*,
+     const double*,
+     ConnectionModeT,
+     const tp_info<double>&,
+     const tp_info<double>&,
+     const tp_info<double>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+
+ extern template void fused_tensor_product_bwd<__half, __half, __half, __half, float>(
+     __half*,
+     __half*,
+     __half*,
+     const __half*,
+     const __half*,
+     const __half*,
+     const __half*,
+     ConnectionModeT,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+ extern template void
+ fused_tensor_product_bwd<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float>(
+     __nv_bfloat16*,
+     __nv_bfloat16*,
+     __nv_bfloat16*,
+     const __nv_bfloat16*,
+     const __nv_bfloat16*,
+     const __nv_bfloat16*,
+     const __nv_bfloat16*,
+     ConnectionModeT,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_info<float>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+
+ extern template void fused_tensor_product_fwd<float, float, float, float, float>(
+     float*,
+     const float*,
+     const float*,
+     const float*,
+     ConnectionModeT,
+     const tp_info<float>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+ extern template void fused_tensor_product_fwd<float, float, float, float, double>(
+     float*,
+     const float*,
+     const float*,
+     const float*,
+     ConnectionModeT,
+     const tp_info<double>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+ extern template void fused_tensor_product_fwd<double, double, double, double, double>(
+     double*,
+     const double*,
+     const double*,
+     const double*,
+     ConnectionModeT,
+     const tp_info<double>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+
+ extern template void fused_tensor_product_fwd<__half, __half, __half, __half, float>(
+     __half*,
+     const __half*,
+     const __half*,
+     const __half*,
+     ConnectionModeT,
+     const tp_info<float>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+ extern template void
+ fused_tensor_product_fwd<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float>(
+     __nv_bfloat16*,
+     const __nv_bfloat16*,
+     const __nv_bfloat16*,
+     const __nv_bfloat16*,
+     ConnectionModeT,
+     const tp_info<float>&,
+     const tp_data_sizes&,
+     cudaStream_t);
+
+ } // namespace kernelcatcher::tensor_product
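
The forward entry point takes raw device pointers plus a tp_info path descriptor and a tp_data_sizes block. A hedged host-side sketch of one possible call for batched float data with float math; the reading of stride_* as per-sample strides, the shared_* flags, the include path, and all variable names are assumptions for illustration only:

    #include <cuda_runtime.h>

    #include "cuequivariance_ops/equivariance/fused_tensor_product.cuh"  // assumed include path

    using namespace kernelcatcher::tensor_product;

    // d_out/d_a/d_b/d_w are device buffers allocated and filled elsewhere;
    // `info` is a populated tp_info<float> describing the paths.
    void forward_example(float* d_out, const float* d_a, const float* d_b, const float* d_w,
                         const tp_info<float>& info, int64_t batch,
                         int32_t stride_a, int32_t stride_b, int32_t stride_w, int32_t stride_o,
                         cudaStream_t stream)
    {
      tp_data_sizes sizes{};
      sizes.batch_size = batch;
      sizes.shared_a   = false;     // assumed: operands are batched, not broadcast
      sizes.shared_b   = false;
      sizes.shared_w   = false;
      sizes.stride_a   = stride_a;  // assumed: per-sample stride of each operand
      sizes.stride_b   = stride_b;
      sizes.stride_w   = stride_w;
      sizes.stride_o   = stride_o;

      fused_tensor_product_fwd<float, float, float, float, float>(
          d_out, d_a, d_b, d_w, ConnectionModeT::kUVW, info, sizes, stream);
    }
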
@@ -0,0 +1,36 @@
+ /*
+  * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+  *
+  * This source code and/or documentation ("Licensed Deliverables") are
+  * subject to NVIDIA intellectual property rights under U.S. and
+  * international Copyright laws.
+  */
+
+ #pragma once
+
+ #include "dtypes.hh" // for Datatype
+
+ namespace kernelcatcher::equivariance::indexed_linear {
+ using namespace kernelcatcher::utils; // for Datatype
+
+ #define KC_INDEXED_LINEAR_DECL_ARGUMENTS \
+   const void *ptr_A, const void *ptr_B, const int *counts, void *ptr_C, Datatype dtype_A, \
+   Datatype dtype_B, Datatype dtype_D, int Z, int C, int u, int v, double coefficient, \
+   Datatype math_dtype, void *workspace, size_t workspace_size, void *stream
+
+ #define KC_INDEXED_LINEAR_ARGUMENTS \
+   ptr_A, ptr_B, counts, ptr_C, dtype_A, dtype_B, dtype_D, Z, C, u, v, coefficient, math_dtype, \
+   workspace, workspace_size, stream
+
+ int run_indexed_linear_B( // ptr_A Zu
+                           // ptr_B Cuv or Cvu
+                           // ptr_C Zv
+   bool transpose_B,
+   KC_INDEXED_LINEAR_DECL_ARGUMENTS);
+
+ int run_indexed_linear_C( // ptr_A Zu
+                           // ptr_B Zv
+                           // ptr_C Cuv
+   KC_INDEXED_LINEAR_DECL_ARGUMENTS);
+
+ } // namespace kernelcatcher::equivariance::indexed_linear
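
Per the inline comments, run_indexed_linear_B contracts a [Z, u] input with per-index [C, u, v] (or [C, v, u] when transpose_B is set) weights into a [Z, v] output. A hedged sketch of a float32 call; the role of counts, the workspace handling, the include path, and every variable name here are assumptions, and stream is a cudaStream_t passed as void* per the declaration:

    #include <cstddef>

    #include "cuequivariance_ops/equivariance/indexed_linear.hh"  // assumed include path

    using kernelcatcher::utils::Datatype;

    // All pointers are device buffers owned by the caller.
    int indexed_linear_example(const void* d_A, const void* d_B, const int* d_counts, void* d_C,
                               int Z, int C, int u, int v,
                               void* d_workspace, std::size_t workspace_size, void* stream)
    {
      return kernelcatcher::equivariance::indexed_linear::run_indexed_linear_B(
          /*transpose_B=*/false,
          d_A, d_B, d_counts, d_C,
          Datatype::kFloat32, Datatype::kFloat32, Datatype::kFloat32,  // dtype_A, dtype_B, dtype_D
          Z, C, u, v,
          /*coefficient=*/1.0,
          Datatype::kFloat32,  // math_dtype
          d_workspace, workspace_size, stream);
    }
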