cuequivariance-ops-cu12 0.6.0__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cuequivariance-ops-cu12 might be problematic.
- cuequivariance_ops/VERSION +1 -0
- cuequivariance_ops/__init__.py +42 -0
- cuequivariance_ops/_version.py +20 -0
- cuequivariance_ops/common/common.hpp +98 -0
- cuequivariance_ops/common/nvtx.hpp +29 -0
- cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
- cuequivariance_ops/equivariance/dtypes.hh +65 -0
- cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
- cuequivariance_ops/equivariance/indexed_linear.hh +36 -0
- cuequivariance_ops/equivariance/run_fmha.h +192 -0
- cuequivariance_ops/equivariance/run_fmha_cudafree.h +77 -0
- cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
- cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
- cuequivariance_ops/lib/libcue_ops.so +0 -0
- cuequivariance_ops/sleep.hh +18 -0
- cuequivariance_ops/triton/__init__.py +66 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37192 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
- cuequivariance_ops/triton/cache_manager.py +259 -0
- cuequivariance_ops/triton/fused_layer_norm_triton.py +518 -0
- cuequivariance_ops/triton/gated_gemm_triton.py +380 -0
- cuequivariance_ops/triton/pair_bias.py +324 -0
- cuequivariance_ops/triton/tuning_decorator.py +177 -0
- cuequivariance_ops/triton/utils.py +28 -0
- cuequivariance_ops_cu12-0.6.0.dist-info/METADATA +182 -0
- cuequivariance_ops_cu12-0.6.0.dist-info/RECORD +37 -0
- cuequivariance_ops_cu12-0.6.0.dist-info/WHEEL +6 -0
- cuequivariance_ops_cu12-0.6.0.dist-info/licenses/LICENSE +142 -0
- cuequivariance_ops_cu12-0.6.0.dist-info/licenses/Third_party_attr.txt +24 -0
cuequivariance_ops/VERSION
@@ -0,0 +1 @@
+0.6.0
cuequivariance_ops/__init__.py
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from ._version import __version__, __git_commit__
+import os
+import sys
+import ctypes
+
+PREFERRED_LOAD_FLAG = ctypes.RTLD_LOCAL
+
+
+def root_dir():
+    try:
+        import importlib.metadata
+
+        dist = importlib.metadata.distribution("cuequivariance_ops")
+        root = dist.locate_file("cuequivariance_ops")
+    except Exception:
+        # last resort, will fail with writeable install
+        root = os.path.dirname(__file__)
+    return root
+
+
+def load_library():
+    try:
+        ctypes.CDLL(
+            os.path.join(root_dir(), "lib/libcue_ops.so"), mode=PREFERRED_LOAD_FLAG
+        )
+    except Exception as e:
+        print(f"Error while loading libcue_ops.so: {e}", file=sys.stderr)
+
+
+load_library()
+
+__all__ = ["__version__", "__git_commit__", "root_dir", "load_library"]
cuequivariance_ops/_version.py
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+
+import importlib.resources
+
+__version__ = (
+    importlib.resources.files("cuequivariance_ops")
+    .joinpath("VERSION")
+    .read_text()
+    .strip()
+)
+__git_commit__ = "release"
cuequivariance_ops/common/common.hpp
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+#include <cstdint>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <vector>
+
+namespace kernelcatcher {
+
+namespace tensor_product {
+/**
+ * @brief a wrapper struct containing informations
+ * about tensor-product paths
+ */
+template <typename MathT>
+struct __attribute__((aligned(16))) tp_info {
+  /** offsets into `path_offsets_and_dims` for each "target" */
+  const int32_t* __restrict__ path_csr_offsets{nullptr};
+  /** "sources" of all paths and their offsets and dimensions */
+  const int32_t* __restrict__ path_offsets_and_dims{nullptr};
+  /** clebsch-gordan values for each path */
+  const MathT* __restrict__ path_cg_values{nullptr};
+  /** number of "target" segments */
+  int32_t num_target_segments{0};
+  /** number of path (i.e. all paths between segments) */
+  int32_t num_paths{0};
+};  // struct tp_info
+
+enum class ConnectionModeT : uint8_t {
+  kUVW = 0,
+  // UVW with U spherical harmonic
+  k1VW,  // NOLINT
+  // UVW with V spherical harmonic
+  kU1W,  // NOLINT
+  kUVU,
+  kUVV,
+  kUUW,
+  kUUU,
+  // FullTP, no weight
+  kUVUV,
+  // FullTP, U spherical harmonic
+  k1V1V,
+  // FullTP, V spherical harmonic
+  kU1U1,
+  // Linear
+  kUUVV,
+};
+}  // namespace tensor_product
+
+namespace symmetric_tensor_contraction {
+/**
+ * @brief a wrapper struct containing informations
+ * about tensor-product paths
+ */
+template <typename DataT>
+struct __attribute__((aligned(16))) clebsch_gordan_tensor {
+  const DataT* __restrict__ cg_values{nullptr};
+  const int16_t* __restrict__ cg_indices{nullptr};
+  const int32_t* __restrict__ cg_offsets{nullptr};
+  int32_t total_output_irreps{0};
+};  // struct clebsch_gordan_tensor
+}  // namespace symmetric_tensor_contraction
+
+namespace batch_linear {
+struct __attribute__((aligned(8))) MatrixLayout {
+  int32_t size_row;  // uncontracted mode
+  int32_t size_col;  // contracted mode
+};
+
+struct __attribute__((aligned(8))) IndexOffset {
+  int32_t start;
+  int32_t end;
+};
+
+enum class GemvModeT : std::uint8_t { kUVV = 0, kUUV = 1 };
+enum class WeightSharedModeT : std::int32_t { kShared = 0, kIndexed = 1, kBatched = 2 };
+/**
+ * @brief a wrapper struct containing informations
+ * about tensor-product paths
+ */
+template <typename DataT>
+struct __attribute__((aligned(16))) batch_linear_info {
+  const MatrixLayout* __restrict__ layouts{nullptr};
+  const IndexOffset* __restrict__ index_offsets{nullptr};
+  const int32_t* __restrict__ indices{nullptr};
+  const DataT* __restrict__ alpha{nullptr};
+};  // struct batch_linear_info
+
+}  // namespace batch_linear
+}  // namespace kernelcatcher
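The ConnectionModeT enum above is the mode selector shared by the fused tensor-product entry points declared later in this wheel. As a quick orientation aid, a small host-side helper that maps each defined enumerator to a printable label might look like the following sketch (illustrative only, not shipped in the package; the label strings are made up here, and compiling it needs the CUDA headers that common.hpp pulls in):

// Illustrative helper, not part of the package.
#include "common.hpp"

namespace kc_tp = kernelcatcher::tensor_product;

// Map each connection mode defined in common.hpp to a printable name (hypothetical labels).
inline const char* connection_mode_name(kc_tp::ConnectionModeT m)
{
  switch (m) {
    case kc_tp::ConnectionModeT::kUVW: return "uvw";
    case kc_tp::ConnectionModeT::k1VW: return "1vw";
    case kc_tp::ConnectionModeT::kU1W: return "u1w";
    case kc_tp::ConnectionModeT::kUVU: return "uvu";
    case kc_tp::ConnectionModeT::kUVV: return "uvv";
    case kc_tp::ConnectionModeT::kUUW: return "uuw";
    case kc_tp::ConnectionModeT::kUUU: return "uuu";
    case kc_tp::ConnectionModeT::kUVUV: return "uvuv";
    case kc_tp::ConnectionModeT::k1V1V: return "1v1v";
    case kc_tp::ConnectionModeT::kU1U1: return "u1u1";
    case kc_tp::ConnectionModeT::kUUVV: return "uuvv";
  }
  return "unknown";
}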
cuequivariance_ops/common/nvtx.hpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+namespace kernelcatcher::utils {
+
+/**
+ * @brief Push a named nvtx range
+ * @param name range name
+ */
+void push_range(const char* name);
+
+/** Pop the latest range */
+void pop_range();
+
+struct range_guard {
+  range_guard(const char* name) { push_range(name); }
+  ~range_guard() { pop_range(); }
+  range_guard(range_guard const&) = delete;
+  range_guard& operator=(range_guard const&) = delete;
+};
+
+}  // namespace kernelcatcher::utils
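nvtx.hpp only declares push_range/pop_range and the RAII range_guard; the definitions live in lib/libcue_ops.so. A minimal usage sketch (illustrative, not part of the package; assumes the header is on the include path and the program links against libcue_ops.so):

#include "nvtx.hpp"

void annotated_work()
{
  // Pushes the named NVTX range on construction and pops it when the scope ends,
  // so the enclosed work shows up as one range in Nsight-style profilers.
  kernelcatcher::utils::range_guard guard("annotated_work");
  // ... work to be attributed to this range ...
}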
cuequivariance_ops/equivariance/batch_dimension.hh
@@ -0,0 +1,15 @@
+/*
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+namespace kernelcatcher::utils {
+
+enum class BatchDimension : int { kBatched = 0, kShared = 1, kIndexed = 2 };
+
+}  // namespace kernelcatcher::utils
cuequivariance_ops/equivariance/dtypes.hh
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+#include <iostream>
+
+namespace kernelcatcher::utils {
+
+enum class Datatype : int {
+  kFloat32 = 0,
+  kFloat64 = 1,
+  kFloat16 = 2,
+  kBFloat16 = 3,
+  kInt32 = 4,
+  kInt64 = 5
+};
+
+inline int size_of(Datatype dtype)
+{
+  switch (dtype) {
+    case Datatype::kFloat32: return 4;
+    case Datatype::kFloat64: return 8;
+    case Datatype::kFloat16: return 2;
+    case Datatype::kBFloat16: return 2;
+    case Datatype::kInt32: return 4;
+    case Datatype::kInt64: return 8;
+    default: return -1;
+  }
+}
+
+inline std::ostream& operator<<(std::ostream& s, Datatype const& d)
+{
+  switch (d) {
+    case Datatype::kFloat32: return s << "float";
+    case Datatype::kFloat64: return s << "double";
+    case Datatype::kFloat16: return s << "k_fp16";
+    case Datatype::kBFloat16: return s << "k_bf16";
+    case Datatype::kInt32: return s << "kc_int32";
+    case Datatype::kInt64: return s << "kc_int64";
+  }
+  return s << "unknown_datatype";
+}
+
+inline bool is_real(Datatype const& d)
+{
+  switch (d) {
+    case Datatype::kFloat32:
+    case Datatype::kFloat64:
+    case Datatype::kFloat16:
+    case Datatype::kBFloat16: return true;
+    case Datatype::kInt32:
+    case Datatype::kInt64: return false;
+  }
+  return false;  // Default case, should not be reached
+}
+
+inline bool is_integral(Datatype const& d) { return !is_real(d); }
+
+}  // namespace kernelcatcher::utils
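Everything in dtypes.hh is header-only, so its helpers can be exercised without linking against libcue_ops.so. A small standalone check (illustrative, not part of the package):

#include "dtypes.hh"

#include <iostream>

int main()
{
  using kernelcatcher::utils::Datatype;
  using kernelcatcher::utils::is_real;
  using kernelcatcher::utils::size_of;

  Datatype dt = Datatype::kBFloat16;
  // Prints "k_bf16 2 1": the stream name, the element size in bytes, and the is_real flag.
  std::cout << dt << " " << size_of(dt) << " " << is_real(dt) << std::endl;
  return 0;
}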
cuequivariance_ops/equivariance/fused_tensor_product.cuh
@@ -0,0 +1,297 @@
+/*
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+#include "../common/common.hpp"
+
+#include <algorithm>
+#include <limits>
+
+namespace kernelcatcher::tensor_product {
+
+struct __attribute__((aligned(16))) tp_data_sizes {
+  int64_t batch_size;
+  bool shared_a;
+  bool shared_b;
+  bool shared_w;
+  int32_t stride_a;
+  int32_t stride_b;
+  int32_t stride_w;
+  int32_t stride_o;
+};  // struct tp_data_sizes
+
+template <typename DataAT, typename DataBT, typename DataWeightT, typename DataOutT, typename MathT>
+void fused_tensor_product_fwd(DataOutT* out,
+    const DataAT* in_a,
+    const DataBT* in_b,
+    const DataWeightT* weight,
+    ConnectionModeT mode,
+    const tp_info<MathT>& info,
+    const tp_data_sizes& sizes,
+    cudaStream_t stream);
+
+template <typename DataAT, typename DataBT, typename DataWeightT, typename DataOutT, typename MathT>
+void fused_tensor_product_bwd(DataAT* grad_in_a,
+    DataBT* grad_in_b,
+    DataWeightT* grad_weight,
+    const DataOutT* grad_out,
+    const DataAT* in_a,
+    const DataBT* in_b,
+    const DataWeightT* weight,
+    ConnectionModeT mode,
+    const tp_info<MathT>& info_bwd_dgrad_a,
+    const tp_info<MathT>& info_bwd_dgrad_b,
+    const tp_info<MathT>& info_bwd_dgrad_w,
+    const tp_data_sizes& sizes,
+    cudaStream_t stream);
+
+template <typename DataAT, typename DataBT, typename DataWeightT, typename DataOutT, typename MathT>
+void fused_tensor_product_bwd_bwd(DataAT* grad_in_a,
+    DataBT* grad_in_b,
+    DataWeightT* grad_weight,
+    DataOutT* grad_grad_out,
+    const DataAT* grad_grad_in_a,
+    const DataBT* grad_grad_in_b,
+    const DataWeightT* grad_grad_weight,
+    const DataOutT* grad_out,
+    const DataAT* in_a,
+    const DataBT* in_b,
+    const DataWeightT* weight,
+    ConnectionModeT mode,
+    const tp_info<MathT>& info_fwd,
+    const tp_info<MathT>& info_bwd_dgrad_a,
+    const tp_info<MathT>& info_bwd_dgrad_b,
+    const tp_info<MathT>& info_bwd_dgrad_w,
+    const tp_data_sizes& sizes,
+    cudaStream_t stream);
+
+extern template void fused_tensor_product_bwd_bwd<float, float, float, float, float>(
+    float*,
+    float*,
+    float*,
+    float*,
+    const float*,
+    const float*,
+    const float*,
+    const float*,
+    const float*,
+    const float*,
+    const float*,
+    ConnectionModeT,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+
+extern template void fused_tensor_product_bwd_bwd<float, float, float, float, double>(
+    float*,
+    float*,
+    float*,
+    float*,
+    const float*,
+    const float*,
+    const float*,
+    const float*,
+    const float*,
+    const float*,
+    const float*,
+    ConnectionModeT,
+    const tp_info<double>&,
+    const tp_info<double>&,
+    const tp_info<double>&,
+    const tp_info<double>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+
+extern template void fused_tensor_product_bwd_bwd<double, double, double, double, double>(
+    double*,
+    double*,
+    double*,
+    double*,
+    const double*,
+    const double*,
+    const double*,
+    const double*,
+    const double*,
+    const double*,
+    const double*,
+    ConnectionModeT,
+    const tp_info<double>&,
+    const tp_info<double>&,
+    const tp_info<double>&,
+    const tp_info<double>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+
+extern template void fused_tensor_product_bwd_bwd<__half, __half, __half, __half, float>(
+    __half*,
+    __half*,
+    __half*,
+    __half*,
+    const __half*,
+    const __half*,
+    const __half*,
+    const __half*,
+    const __half*,
+    const __half*,
+    const __half*,
+    ConnectionModeT,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+extern template void
+fused_tensor_product_bwd_bwd<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float>(
+    __nv_bfloat16*,
+    __nv_bfloat16*,
+    __nv_bfloat16*,
+    __nv_bfloat16*,
+    const __nv_bfloat16*,
+    const __nv_bfloat16*,
+    const __nv_bfloat16*,
+    const __nv_bfloat16*,
+    const __nv_bfloat16*,
+    const __nv_bfloat16*,
+    const __nv_bfloat16*,
+    ConnectionModeT,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+
+extern template void fused_tensor_product_bwd<float, float, float, float, float>(
+    float*,
+    float*,
+    float*,
+    const float*,
+    const float*,
+    const float*,
+    const float*,
+    ConnectionModeT,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+
+extern template void fused_tensor_product_bwd<float, float, float, float, double>(
+    float*,
+    float*,
+    float*,
+    const float*,
+    const float*,
+    const float*,
+    const float*,
+    ConnectionModeT,
+    const tp_info<double>&,
+    const tp_info<double>&,
+    const tp_info<double>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+
+extern template void fused_tensor_product_bwd<double, double, double, double, double>(
+    double*,
+    double*,
+    double*,
+    const double*,
+    const double*,
+    const double*,
+    const double*,
+    ConnectionModeT,
+    const tp_info<double>&,
+    const tp_info<double>&,
+    const tp_info<double>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+
+extern template void fused_tensor_product_bwd<__half, __half, __half, __half, float>(
+    __half*,
+    __half*,
+    __half*,
+    const __half*,
+    const __half*,
+    const __half*,
+    const __half*,
+    ConnectionModeT,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+extern template void
+fused_tensor_product_bwd<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float>(
+    __nv_bfloat16*,
+    __nv_bfloat16*,
+    __nv_bfloat16*,
+    const __nv_bfloat16*,
+    const __nv_bfloat16*,
+    const __nv_bfloat16*,
+    const __nv_bfloat16*,
+    ConnectionModeT,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_info<float>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+
+extern template void fused_tensor_product_fwd<float, float, float, float, float>(
+    float*,
+    const float*,
+    const float*,
+    const float*,
+    ConnectionModeT,
+    const tp_info<float>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+extern template void fused_tensor_product_fwd<float, float, float, float, double>(
+    float*,
+    const float*,
+    const float*,
+    const float*,
+    ConnectionModeT,
+    const tp_info<double>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+extern template void fused_tensor_product_fwd<double, double, double, double, double>(
+    double*,
+    const double*,
+    const double*,
+    const double*,
+    ConnectionModeT,
+    const tp_info<double>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+
+extern template void fused_tensor_product_fwd<__half, __half, __half, __half, float>(
+    __half*,
+    const __half*,
+    const __half*,
+    const __half*,
+    ConnectionModeT,
+    const tp_info<float>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+extern template void
+fused_tensor_product_fwd<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, float>(
+    __nv_bfloat16*,
+    const __nv_bfloat16*,
+    const __nv_bfloat16*,
+    const __nv_bfloat16*,
+    ConnectionModeT,
+    const tp_info<float>&,
+    const tp_data_sizes&,
+    cudaStream_t);
+
+}  // namespace kernelcatcher::tensor_product
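The extern template declarations above pin down which type combinations ship precompiled in libcue_ops.so: fp32, fp64, fp16 and bf16 data, each paired with fp32 or fp64 math. A rough sketch of the forward call shape for the fp32/fp32 instantiation (illustrative only, not from the package; the sharing flags, strides and the CSR path layout behind tp_info are assumptions, since this header does not document them):

#include "fused_tensor_product.cuh"

#include <cuda_runtime.h>

using namespace kernelcatcher::tensor_product;

// Illustrative: launch the precompiled fp32 forward path on already-prepared device buffers.
void launch_fwd_f32(float* d_out, const float* d_a, const float* d_b, const float* d_w,
                    const tp_info<float>& info,  // path description; pointer members assumed to be device memory
                    int64_t batch, int32_t sa, int32_t sb, int32_t sw, int32_t so,
                    cudaStream_t stream)
{
  tp_data_sizes sizes{};
  sizes.batch_size = batch;
  sizes.shared_a = false;  // per-batch operand A (assumption for this sketch)
  sizes.shared_b = false;  // per-batch operand B (assumption for this sketch)
  sizes.shared_w = true;   // one shared weight tensor (assumption for this sketch)
  sizes.stride_a = sa;
  sizes.stride_b = sb;
  sizes.stride_w = sw;
  sizes.stride_o = so;

  fused_tensor_product_fwd<float, float, float, float, float>(
      d_out, d_a, d_b, d_w, ConnectionModeT::kUVW, info, sizes, stream);
}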
cuequivariance_ops/equivariance/indexed_linear.hh
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * This source code and/or documentation ("Licensed Deliverables") are
+ * subject to NVIDIA intellectual property rights under U.S. and
+ * international Copyright laws.
+ */
+
+#pragma once
+
+#include "dtypes.hh"  // for Datatype
+
+namespace kernelcatcher::equivariance::indexed_linear {
+using namespace kernelcatcher::utils;  // for Datatype
+
+#define KC_INDEXED_LINEAR_DECL_ARGUMENTS \
+  const void *ptr_A, const void *ptr_B, const int *counts, void *ptr_C, Datatype dtype_A, \
+    Datatype dtype_B, Datatype dtype_D, int Z, int C, int u, int v, double coefficient, \
+    Datatype math_dtype, void *workspace, size_t workspace_size, void *stream
+
+#define KC_INDEXED_LINEAR_ARGUMENTS \
+  ptr_A, ptr_B, counts, ptr_C, dtype_A, dtype_B, dtype_D, Z, C, u, v, coefficient, math_dtype, \
+    workspace, workspace_size, stream
+
+int run_indexed_linear_B(  // ptr_A Zu
+                           // ptr_B Cuv or Cvu
+                           // ptr_C Zv
+    bool transpose_B,
+    KC_INDEXED_LINEAR_DECL_ARGUMENTS);
+
+int run_indexed_linear_C(  // ptr_A Zu
+                           // ptr_B Zv
+                           // ptr_C Cuv
+    KC_INDEXED_LINEAR_DECL_ARGUMENTS);
+
+}  // namespace kernelcatcher::equivariance::indexed_linear