morphottention 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- morphottention-0.1.0/CMakeLists.txt +113 -0
- morphottention-0.1.0/PKG-INFO +62 -0
- morphottention-0.1.0/README.md +27 -0
- morphottention-0.1.0/csrc/cuda/attention/attention.cpp +71 -0
- morphottention-0.1.0/csrc/cuda/attention/attention.cu +235 -0
- morphottention-0.1.0/csrc/cuda/attention/attention.cuh +36 -0
- morphottention-0.1.0/csrc/cuda/binder.cpp +12 -0
- morphottention-0.1.0/csrc/cuda/dispatch.cpp +13 -0
- morphottention-0.1.0/csrc/cuda/dispatch.h +22 -0
- morphottention-0.1.0/csrc/cuda/morfology/cube.cuh +81 -0
- morphottention-0.1.0/csrc/cuda/morfology/soft_morph.cuh +81 -0
- morphottention-0.1.0/csrc/cuda/sm120/matmul.cuh +58 -0
- morphottention-0.1.0/csrc/cuda/sm120/project.cuh +51 -0
- morphottention-0.1.0/csrc/cuda/sm120/smem.cuh +35 -0
- morphottention-0.1.0/csrc/cuda/utils/declarations.cuh +91 -0
- morphottention-0.1.0/csrc/cuda/utils/reductions.cuh +80 -0
- morphottention-0.1.0/csrc/cuda/utils/smem.cuh +26 -0
- morphottention-0.1.0/csrc/cuda/utils/utils.cuh +29 -0
- morphottention-0.1.0/pyproject.toml +80 -0
- morphottention-0.1.0/src/morphottention/_C.pyi +14 -0
- morphottention-0.1.0/src/morphottention/__init__.py +3 -0
- morphottention-0.1.0/src/morphottention/autograd.py +123 -0
- morphottention-0.1.0/src/morphottention/py.typed +0 -0
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
cmake_minimum_required(VERSION 3.24)
|
|
2
|
+
|
|
3
|
+
if(NOT DEFINED CMAKE_CUDA_COMPILER)
|
|
4
|
+
if(DEFINED ENV{CUDA_HOME})
|
|
5
|
+
set(CMAKE_CUDA_COMPILER "$ENV{CUDA_HOME}/bin/nvcc" CACHE FILEPATH "" FORCE)
|
|
6
|
+
else()
|
|
7
|
+
find_program(_NVCC nvcc)
|
|
8
|
+
if(_NVCC)
|
|
9
|
+
set(CMAKE_CUDA_COMPILER "${_NVCC}" CACHE FILEPATH "" FORCE)
|
|
10
|
+
endif()
|
|
11
|
+
endif()
|
|
12
|
+
endif()
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
project(morphottention LANGUAGES CXX CUDA)
|
|
16
|
+
|
|
17
|
+
# Language standards: C++23 for host (CXX) sources, C++20 for device (CUDA) sources.
|
|
18
|
+
set(CMAKE_CXX_STANDARD 23)
|
|
19
|
+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
|
20
|
+
set(CMAKE_CUDA_STANDARD 20)
|
|
21
|
+
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
|
|
22
|
+
|
|
23
|
+
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
|
24
|
+
set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF)
|
|
25
|
+
|
|
26
|
+
set(CMAKE_CUDA_ARCHITECTURES 90 100 120)
|
|
27
|
+
set(TORCH_CUDA_ARCH_LIST "9.0;10.0;12.0")
|
|
28
|
+
|
|
29
|
+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
|
30
|
+
|
|
31
|
+
set(Python_FIND_VIRTUALENV FIRST)
|
|
32
|
+
|
|
33
|
+
find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)
|
|
34
|
+
|
|
35
|
+
execute_process(
|
|
36
|
+
COMMAND ${Python_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
|
|
37
|
+
OUTPUT_VARIABLE TORCH_CMAKE_PREFIX
|
|
38
|
+
OUTPUT_STRIP_TRAILING_WHITESPACE
# Stop configuring right here if the interpreter cannot import torch; otherwise
# TORCH_CMAKE_PREFIX is empty and the failure only shows up later as an
# unrelated-looking find_package(Torch) error.
COMMAND_ERROR_IS_FATAL ANY
|
|
39
|
+
)
|
|
40
|
+
list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PREFIX}")
|
|
41
|
+
find_package(Torch REQUIRED CONFIG)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
if(WIN32)
|
|
45
|
+
set(_TORCH_PYTHON_LIB_NAME "torch_python.lib")
|
|
46
|
+
else()
|
|
47
|
+
set(_TORCH_PYTHON_LIB_NAME "libtorch_python.so")
|
|
48
|
+
endif()
|
|
49
|
+
|
|
50
|
+
execute_process(
|
|
51
|
+
COMMAND ${Python_EXECUTABLE} -c "import pathlib, torch; print(pathlib.Path(torch.__file__).resolve().parent / 'lib' / '${_TORCH_PYTHON_LIB_NAME}')"
|
|
52
|
+
OUTPUT_VARIABLE TORCH_PYTHON_LIBRARY
|
|
53
|
+
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
54
|
+
)
|
|
55
|
+
if(NOT EXISTS "${TORCH_PYTHON_LIBRARY}")
|
|
56
|
+
message(FATAL_ERROR "Could not locate ${_TORCH_PYTHON_LIB_NAME} at: ${TORCH_PYTHON_LIBRARY}")
|
|
57
|
+
endif()
|
|
58
|
+
get_filename_component(TORCH_LIB_DIR "${TORCH_PYTHON_LIBRARY}" DIRECTORY)
|
|
59
|
+
|
|
60
|
+
find_package(CUDAToolkit REQUIRED)
|
|
61
|
+
|
|
62
|
+
Python_add_library(_C MODULE WITH_SOABI
|
|
63
|
+
csrc/cuda/binder.cpp
|
|
64
|
+
csrc/cuda/dispatch.cpp
|
|
65
|
+
|
|
66
|
+
csrc/cuda/attention/attention.cpp
|
|
67
|
+
csrc/cuda/attention/attention.cu
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
# -lineinfo on Debug; full device debug (-G) only when explicitly sanitizing.
|
|
71
|
+
option(CUDA_SANITIZE "Build device code with -G for compute-sanitizer" OFF)
|
|
72
|
+
target_compile_options(_C PRIVATE
|
|
73
|
+
$<$<COMPILE_LANG_AND_ID:CUDA,NVIDIA>:--threads=0>
|
|
74
|
+
$<$<AND:$<COMPILE_LANG_AND_ID:CUDA,NVIDIA>,$<CONFIG:Debug>>:-lineinfo>
|
|
75
|
+
$<$<AND:$<COMPILE_LANG_AND_ID:CUDA,NVIDIA>,$<CONFIG:Debug>,$<BOOL:${CUDA_SANITIZE}>>:-G>
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
if(MSVC)
|
|
80
|
+
target_compile_options(_C PRIVATE
|
|
81
|
+
$<$<COMPILE_LANGUAGE:CXX>:/Zc:preprocessor>
|
|
82
|
+
$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=/Zc:preprocessor>
|
|
83
|
+
)
|
|
84
|
+
endif()
|
|
85
|
+
|
|
86
|
+
target_include_directories(_C PRIVATE
|
|
87
|
+
${CMAKE_CURRENT_SOURCE_DIR}/csrc
|
|
88
|
+
${TORCH_INCLUDE_DIRS}
|
|
89
|
+
${Python_INCLUDE_DIRS}
|
|
90
|
+
${CUDAToolkit_INCLUDE_DIRS}
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
target_compile_definitions(_C PRIVATE
|
|
94
|
+
TORCH_EXTENSION_NAME=_C
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
target_link_libraries(_C PRIVATE
|
|
98
|
+
${TORCH_LIBRARIES}
|
|
99
|
+
${TORCH_PYTHON_LIBRARY}
|
|
100
|
+
Python::Module
|
|
101
|
+
CUDA::cudart
|
|
102
|
+
CUDA::cublas
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
if(NOT WIN32)
|
|
106
|
+
target_link_options(_C PRIVATE "-Wl,--no-as-needed")
|
|
107
|
+
set_target_properties(_C PROPERTIES
|
|
108
|
+
BUILD_RPATH "${TORCH_LIB_DIR}"
|
|
109
|
+
INSTALL_RPATH "$ORIGIN/../torch/lib"
|
|
110
|
+
)
|
|
111
|
+
endif()
|
|
112
|
+
|
|
113
|
+
install(TARGETS _C LIBRARY DESTINATION morphottention)
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: morphottention
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: CUDA attention & morphology kernels for PyTorch (sm_90/100/120)
|
|
5
|
+
Keywords: attention,cuda,pytorch,transformer,morphology,flash-attention,ViT
|
|
6
|
+
Author-Email: Vedran Hrabar <vedran.hrabar@outlook.com>
|
|
7
|
+
License: MIT
|
|
8
|
+
Classifier: Development Status :: 2 - Pre-Alpha
|
|
9
|
+
Classifier: Environment :: GPU
|
|
10
|
+
Classifier: Environment :: GPU :: NVIDIA CUDA
|
|
11
|
+
Classifier: Environment :: GPU :: NVIDIA CUDA :: 13
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Operating System :: POSIX :: Linux
|
|
16
|
+
Classifier: Operating System :: Microsoft :: Windows
|
|
17
|
+
Classifier: Programming Language :: C++
|
|
18
|
+
Classifier: Programming Language :: Python :: 3
|
|
19
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
21
|
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
22
|
+
Classifier: Topic :: Scientific/Engineering
|
|
23
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
24
|
+
Classifier: Topic :: Scientific/Engineering :: Image Recognition
|
|
25
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
26
|
+
Classifier: Topic :: Software Development :: Libraries
|
|
27
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
28
|
+
Classifier: Typing :: Typed
|
|
29
|
+
Project-URL: repository, https://github.com/vhrabar/morphottention
|
|
30
|
+
Project-URL: documentation, https://github.com/vhrabar/morphottention/wiki
|
|
31
|
+
Project-URL: Bug Tracker, https://github.com/vhrabar/morphottention/issues
|
|
32
|
+
Requires-Python: ==3.14.*
|
|
33
|
+
Requires-Dist: torch>=2.12
|
|
34
|
+
Description-Content-Type: text/markdown
|
|
35
|
+
|
|
36
|
+
# Morphottention
|
|
37
|
+
Mathematical Morphology-based self-attention module for PyTorch using Flash-style kernel fusion.
|
|
38
|
+
|
|
39
|
+
## Install
|
|
40
|
+
|
|
41
|
+
Prebuilt wheels are published for CPython 3.14 on Linux (x86_64, aarch64) and
|
|
42
|
+
Windows (x86_64). A working CUDA-enabled PyTorch (`torch >= 2.12`) must already
|
|
43
|
+
be installed in the environment.
|
|
44
|
+
|
|
45
|
+
```bash
|
|
46
|
+
pip install morphottention
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
## Building from source
|
|
50
|
+
|
|
51
|
+
Requires the CUDA 13.X toolkit (`nvcc`) and a matching `torch` build:
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
uv sync --package morphottention --no-dev --group build
|
|
55
|
+
uv build --package morphottention --wheel --no-build-isolation
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
## License
|
|
59
|
+
|
|
60
|
+
MIT
|
|
61
|
+
|
|
62
|
+
Copyright © 2026 Vedran Hrabar.
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# Morphottention
|
|
2
|
+
Mathematical Morphology-based self-attention module for PyTorch using Flash-style kernel fusion.
|
|
3
|
+
|
|
4
|
+
## Install
|
|
5
|
+
|
|
6
|
+
Prebuilt wheels are published for CPython 3.14 on Linux (x86_64, aarch64) and
|
|
7
|
+
Windows (x86_64). A working CUDA-enabled PyTorch (`torch >= 2.12`) must already
|
|
8
|
+
be installed in the environment.
|
|
9
|
+
|
|
10
|
+
```bash
|
|
11
|
+
pip install morphottention
|
|
12
|
+
```
|
|
13
|
+
|
|
14
|
+
## Building from source
|
|
15
|
+
|
|
16
|
+
Requires the CUDA 13.X toolkit (`nvcc`) and a matching `torch` build:
|
|
17
|
+
|
|
18
|
+
```bash
|
|
19
|
+
uv sync --package morphottention --no-dev --group build
|
|
20
|
+
uv build --package morphottention --wheel --no-build-isolation
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
## License
|
|
24
|
+
|
|
25
|
+
MIT
|
|
26
|
+
|
|
27
|
+
Copyright © 2026 Vedran Hrabar.
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
#include "attention.cuh"
|
|
2
|
+
|
|
3
|
+
#include <cuda/dispatch.h>
|
|
4
|
+
|
|
5
|
+
#include <ATen/cuda/CUDAContext.h>
|
|
6
|
+
#include <c10/cuda/CUDAGuard.h>
|
|
7
|
+
#include <torch/extension.h>
|
|
8
|
+
|
|
9
|
+
auto check = [](const torch::Tensor& t, const char* name) {
|
|
10
|
+
TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
|
|
11
|
+
TORCH_CHECK(t.scalar_type() == torch::kHalf, name, " must be float16");
|
|
12
|
+
TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
|
|
13
|
+
};
|
|
14
|
+
|
|
15
|
+
std::vector<torch::Tensor> morpho_forward(const torch::Tensor& X, const torch::Tensor& W_phi,
|
|
16
|
+
const torch::Tensor& gate_q, const torch::Tensor& gate_k,
|
|
17
|
+
const torch::Tensor& W_V, const int64_t H, const int64_t cube_m,
|
|
18
|
+
const double scale, const bool causal) {
|
|
19
|
+
// checkers
|
|
20
|
+
check(X, "X");
|
|
21
|
+
check(W_phi, "W_phi");
|
|
22
|
+
check(gate_q, "gate_q");
|
|
23
|
+
check(gate_k, "gate_k");
|
|
24
|
+
check(W_V, "W_V");
|
|
25
|
+
|
|
26
|
+
// shape management: X is [B, N, D]; the weight/gate shapes are validated against D, H and cube_m
|
|
27
|
+
TORCH_CHECK(X.dim() == 3, "X must be [B, N, D]");
|
|
28
|
+
const int64_t B = X.size(0);
|
|
29
|
+
const int64_t N = X.size(1);
|
|
30
|
+
const int64_t D = X.size(2);
|
|
31
|
+
|
|
32
|
+
TORCH_CHECK(H > 0 && D % H == 0, "D must be divisible by H");
|
|
33
|
+
const int64_t head_dim_v = D / H;
|
|
34
|
+
|
|
35
|
+
TORCH_CHECK(W_phi.dim() == 2 && W_phi.size(0) == D && W_phi.size(1) == H * cube_m, "W_phi must be [D, H*cube_m]");
|
|
36
|
+
TORCH_CHECK(W_V.dim() == 2 && W_V.size(0) == D && W_V.size(1) == H * head_dim_v, "W_V must be [D, H*head_dim_v]");
|
|
37
|
+
TORCH_CHECK(gate_q.dim() == 2 && gate_q.size(0) == H && gate_q.size(1) == cube_m, "gate_q must be [H, cube_m]");
|
|
38
|
+
TORCH_CHECK(gate_k.dim() == 2 && gate_k.size(0) == H && gate_k.size(1) == cube_m, "gate_k must be [H, cube_m]");
|
|
39
|
+
|
|
40
|
+
auto out = torch::empty_like(X);
|
|
41
|
+
auto lse = torch::empty({B * H, N}, X.options().dtype(torch::kFloat32));
|
|
42
|
+
|
|
43
|
+
// launcher
|
|
44
|
+
const c10::cuda::CUDAStreamGuard guard(c10::cuda::getCurrentCUDAStream());
|
|
45
|
+
|
|
46
|
+
attention_forward_kernel_launcher(reinterpret_cast<const __half*>(X.data_ptr<at::Half>()),
|
|
47
|
+
reinterpret_cast<const __half*>(W_phi.data_ptr<at::Half>()),
|
|
48
|
+
reinterpret_cast<const __half*>(gate_q.data_ptr<at::Half>()),
|
|
49
|
+
reinterpret_cast<const __half*>(gate_k.data_ptr<at::Half>()),
|
|
50
|
+
reinterpret_cast<const __half*>(W_V.data_ptr<at::Half>()),
|
|
51
|
+
reinterpret_cast<__half*>(out.data_ptr<at::Half>()), lse.data_ptr<float>(),
|
|
52
|
+
static_cast<int>(B), static_cast<int>(N), static_cast<int>(D),
|
|
53
|
+
static_cast<int>(H), static_cast<int>(cube_m), static_cast<int>(head_dim_v),
|
|
54
|
+
static_cast<float>(scale), c10::cuda::getCurrentCUDAStream());
|
|
55
|
+
|
|
56
|
+
return {out, lse};
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
std::vector<torch::Tensor> morpho_backward(const torch::Tensor& X, const torch::Tensor& grad_out) {
|
|
60
|
+
TORCH_CHECK(X.is_cuda(), "Input tensor must be a CUDA tensor");
|
|
61
|
+
TORCH_CHECK(grad_out.is_cuda(), "Gradient output tensor must be a CUDA tensor");
|
|
62
|
+
|
|
63
|
+
auto X_contig = X.contiguous();
|
|
64
|
+
auto grad_out_contig = grad_out.contiguous();
|
|
65
|
+
|
|
66
|
+
const int B = static_cast<int>(X_contig.size(0));
|
|
67
|
+
const int N = static_cast<int>(X_contig.size(1));
|
|
68
|
+
const int D = static_cast<int>(X_contig.size(2));
|
|
69
|
+
|
|
70
|
+
return {X};
|
|
71
|
+
}
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
#include "attention.cuh"
|
|
2
|
+
|
|
3
|
+
#include <cuda_pipeline.h>
|
|
4
|
+
|
|
5
|
+
#include <cfloat>
|
|
6
|
+
|
|
7
|
+
template <int HEAD_DIM_V, int CUBE_M, int BR, int BC>
|
|
8
|
+
__global__ void morpho_attention_forward_kernel(const __half* __restrict__ X, const __half* __restrict__ W_phi,
|
|
9
|
+
const __half* __restrict__ gate_q, const __half* __restrict__ gate_k,
|
|
10
|
+
const __half* __restrict__ W_V, __half* __restrict__ out, float* lse,
|
|
11
|
+
int B, int N, int D, int H, float scale) {
|
|
12
|
+
|
|
13
|
+
static_assert(HEAD_DIM_V <= BR, "HEAD_DIM_V must be <= BR");
|
|
14
|
+
static_assert(HEAD_DIM_V <= BC, "HEAD_DIM_V must be <= BC");
|
|
15
|
+
|
|
16
|
+
// NOLINTNEXTLINE(readability-suspicious-call-argument)
|
|
17
|
+
auto [t, b, tid] = get_coords();
|
|
18
|
+
const unsigned int warp = tid / 32;
|
|
19
|
+
const unsigned int lane = tid & 31;
|
|
20
|
+
|
|
21
|
+
const unsigned int bh = t;
|
|
22
|
+
const unsigned int batch = bh / static_cast<unsigned int>(H);
|
|
23
|
+
const unsigned int head = bh % static_cast<unsigned int>(H);
|
|
24
|
+
|
|
25
|
+
const int q_row_start = static_cast<int>(b) * BR;
|
|
26
|
+
const int q_row_end = min(q_row_start + BR, N);
|
|
27
|
+
const int q_rows = q_row_end - q_row_start;
|
|
28
|
+
if (q_rows <= 0) {
|
|
29
|
+
return;
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
const int ldphi = H * CUBE_M;
|
|
33
|
+
const int ldv = H * HEAD_DIM_V;
|
|
34
|
+
|
|
35
|
+
// 64-bit offset: batch * N * D is otherwise evaluated in 32-bit unsigned arithmetic and wraps for large inputs.
const __half* X_b = X + static_cast<size_t>(batch) * N * D;
|
|
36
|
+
const __half* W_phi_h = W_phi + head * CUBE_M;
|
|
37
|
+
const __half* W_V_h = W_V + head * HEAD_DIM_V;
|
|
38
|
+
const __half* gate_q_h = gate_q + head * CUBE_M;
|
|
39
|
+
const __half* gate_k_h = gate_k + head * CUBE_M;
|
|
40
|
+
// 64-bit batch offset (batch * N * D can exceed 32 bits); the per-head column offset stays small.
__half* out_s = out + static_cast<size_t>(batch) * N * D + head * HEAD_DIM_V;
|
|
41
|
+
// lse is laid out as [B*H, N]; compute the row offset in 64 bits to avoid 32-bit wrap-around.
float* lse_s = lse + static_cast<size_t>(bh) * N;
|
|
42
|
+
|
|
43
|
+
// smem carve
|
|
44
|
+
constexpr int XT_ROWS = (BR > BC) ? BR : BC;
|
|
45
|
+
extern __shared__ __align__(16) unsigned char smem[];
|
|
46
|
+
auto* q_mem = reinterpret_cast<__half*>(smem); // q codes [BR, CUBE_M]
|
|
47
|
+
auto* k_mem = q_mem + BR * CUBE_M; // k codes [BC, CUBE_M]
|
|
48
|
+
auto* v_mem = k_mem + BC * CUBE_M; // V tile [BC, HEAD_DIM_V]
|
|
49
|
+
auto* xt_mem = v_mem + BC * HEAD_DIM_V; // token staging [max(BR,BC), D]
|
|
50
|
+
|
|
51
|
+
auto* s_mem = reinterpret_cast<float*>(xt_mem + XT_ROWS * D); // S/P fp32 tile [BR, BC]
|
|
52
|
+
auto* p_h_mem = reinterpret_cast<__half*>(s_mem + BR * BC); // P fp16 tile [BR, BC]
|
|
53
|
+
auto* state_mem = reinterpret_cast<float*>(p_h_mem + BR * BC);
|
|
54
|
+
|
|
55
|
+
float* max_mem = state_mem; // [BR]
|
|
56
|
+
float* lse_mem = state_mem + BR; // [BR]
|
|
57
|
+
float* corr_mem = state_mem + 2 * BR; // [BR]
|
|
58
|
+
float* qbias_mem = state_mem + 3 * BR; // [BR]
|
|
59
|
+
float* cbias_mem = state_mem + 4 * BR; // [BC]
|
|
60
|
+
float* o_acc = cbias_mem + BC; // [BR, HEAD_DIM_V]
|
|
61
|
+
float* pv_mem = s_mem; // PV scratch
|
|
62
|
+
|
|
63
|
+
// initialise the running state in SMEM: output accumulator, per-row running max and running sum
|
|
64
|
+
for (unsigned int i = tid; i < BR * HEAD_DIM_V; i += blockDim.x) {
|
|
65
|
+
o_acc[i] = 0.0f;
|
|
66
|
+
}
|
|
67
|
+
for (unsigned int row = tid; row < BR; row += blockDim.x) {
|
|
68
|
+
max_mem[row] = -FLT_MAX;
|
|
69
|
+
lse_mem[row] = 0.0f;
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
// move X to SRAM
|
|
73
|
+
arch::impl::smem_load(xt_mem, X_b + q_row_start * D, static_cast<unsigned int>(q_rows),
|
|
74
|
+
static_cast<unsigned int>(D), tid, BR);
|
|
75
|
+
__syncthreads();
|
|
76
|
+
// project Q to the gated unit-hypercube
|
|
77
|
+
// q = gate_q @ sigma(W_phi.t, x_q) -> q_mem, q_bias = gate_k x q_i
|
|
78
|
+
arch::impl::project_phi<BR, CUBE_M, WARPS>(xt_mem, W_phi_h, gate_q_h, gate_k_h, q_mem, qbias_mem, s_mem, D, ldphi,
|
|
79
|
+
warp, lane);
|
|
80
|
+
|
|
81
|
+
const float scale_log2 = scale * LOG2E;
|
|
82
|
+
|
|
83
|
+
// loop over KV blocks
|
|
84
|
+
unsigned int n_tiles = (static_cast<unsigned int>(N) + BC - 1) / BC;
|
|
85
|
+
const int q_global_max = q_row_start + q_rows - 1;
|
|
86
|
+
if (CAUSAL) {
|
|
87
|
+
// full mask after last row-q
|
|
88
|
+
n_tiles = min(n_tiles, static_cast<unsigned int>(q_global_max) / BC + 1u);
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
for (unsigned int tile = 0; tile < n_tiles; tile++) {
|
|
92
|
+
const int kv_block_start = static_cast<int>(tile) * BC;
|
|
93
|
+
const int kv_valid = min(BC, N - kv_block_start);
|
|
94
|
+
|
|
95
|
+
// stage K/V tokens
|
|
96
|
+
arch::impl::smem_load(xt_mem, X_b + kv_block_start * D, static_cast<unsigned int>(kv_valid),
|
|
97
|
+
static_cast<unsigned int>(D), tid, BC);
|
|
98
|
+
__syncthreads();
|
|
99
|
+
|
|
100
|
+
// project K to the gated hypercube
|
|
101
|
+
// k = gate_k @ sigma(W_phi.t, x_j) -> k_mem, c_bias = gate_q x k_j
|
|
102
|
+
arch::impl::project_phi<BC, CUBE_M, WARPS>(xt_mem, W_phi_h, gate_k_h, gate_q_h, k_mem, cbias_mem, s_mem, D,
|
|
103
|
+
ldphi, warp, lane);
|
|
104
|
+
|
|
105
|
+
// v = W_v x x_j -> s_mem -> v_curr
|
|
106
|
+
arch::impl::matmul<BC, HEAD_DIM_V, WARPS, false>(s_mem, xt_mem, W_V_h, D, D, ldv, HEAD_DIM_V, 1.0f);
|
|
107
|
+
__syncthreads();
|
|
108
|
+
for (unsigned int i = tid; i < BC * HEAD_DIM_V; i += blockDim.x) {
|
|
109
|
+
v_mem[i] = __float2half(s_mem[i]);
|
|
110
|
+
}
|
|
111
|
+
__syncthreads();
|
|
112
|
+
|
|
113
|
+
// S_raw = QK.T -> s_mem
|
|
114
|
+
arch::impl::matmul<BR, BC, WARPS, true>(s_mem, q_mem, k_mem, CUBE_M, CUBE_M, CUBE_M, BC, 1.0f);
|
|
115
|
+
__syncthreads();
|
|
116
|
+
|
|
117
|
+
// S = (2S_raw − qbias[i] − cbias[j])·scale_log2e
|
|
118
|
+
// symmetry -> -inf mask (dropped entries are written as -FLT_MAX)
|
|
119
|
+
for (unsigned int row = 0; row < BR / WARPS; row++) {
|
|
120
|
+
const unsigned int row_cor = row + warp * (BR / WARPS);
|
|
121
|
+
const int q_global = q_row_start + static_cast<int>(row_cor);
|
|
122
|
+
|
|
123
|
+
float tile_max = -FLT_MAX;
|
|
124
|
+
for (unsigned int c = lane; c < BC; c += 32) {
|
|
125
|
+
float s = (2.0f * s_mem[row_cor * BC + c] - qbias_mem[row_cor] - cbias_mem[c]) * scale_log2;
|
|
126
|
+
const int k_global = kv_block_start + static_cast<int>(c);
|
|
127
|
+
|
|
128
|
+
bool drop = (static_cast<int>(c) >= kv_valid) || (static_cast<int>(row_cor) >= q_rows);
|
|
129
|
+
if (CAUSAL && k_global > q_global)
|
|
130
|
+
drop = true;
|
|
131
|
+
if (MASK_DIAG && k_global == q_global)
|
|
132
|
+
drop = true;
|
|
133
|
+
|
|
134
|
+
s = drop ? -FLT_MAX : s;
|
|
135
|
+
s_mem[row_cor * BC + c] = s;
|
|
136
|
+
tile_max = fmaxf(tile_max, s);
|
|
137
|
+
}
|
|
138
|
+
tile_max = warpMax(tile_max);
|
|
139
|
+
|
|
140
|
+
const float max_prev = max_mem[row_cor];
|
|
141
|
+
const float lse_prev = lse_mem[row_cor];
|
|
142
|
+
const float max_new = fmaxf(max_prev, tile_max);
|
|
143
|
+
const float corr = exp2f(max_prev - max_new);
|
|
144
|
+
|
|
145
|
+
float partial = 0.0f;
|
|
146
|
+
for (unsigned int c = lane; c < BC; c += 32) {
|
|
147
|
+
const float s = s_mem[row_cor * BC + c];
|
|
148
|
+
const float p = (s == -FLT_MAX) ? 0.0f : exp2f(s - max_new);
|
|
149
|
+
p_h_mem[row_cor * BC + c] = __float2half(p);
|
|
150
|
+
partial += p;
|
|
151
|
+
}
|
|
152
|
+
const float lse_new = lse_prev * corr + warpAllReduceSum(partial);
|
|
153
|
+
|
|
154
|
+
if (lane == 0) {
|
|
155
|
+
corr_mem[row_cor] = corr;
|
|
156
|
+
lse_mem[row_cor] = lse_new;
|
|
157
|
+
max_mem[row_cor] = max_new;
|
|
158
|
+
}
|
|
159
|
+
}
|
|
160
|
+
__syncthreads();
|
|
161
|
+
|
|
162
|
+
// PV -> s_mem -> pv_mem
|
|
163
|
+
arch::impl::matmul<BR, HEAD_DIM_V, WARPS, false>(pv_mem, p_h_mem, v_mem, BC, BC, HEAD_DIM_V, HEAD_DIM_V, 1.0f);
|
|
164
|
+
__syncthreads();
|
|
165
|
+
|
|
166
|
+
// O = alpha * O + PV, where alpha is the per-row rescale factor stored in corr_mem
|
|
167
|
+
for (unsigned int i = tid; i < BR * HEAD_DIM_V; i += blockDim.x) {
|
|
168
|
+
o_acc[i] = o_acc[i] * corr_mem[i / HEAD_DIM_V] + pv_mem[i];
|
|
169
|
+
}
|
|
170
|
+
__syncthreads();
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
// store to GMEM
|
|
174
|
+
for (unsigned int row = 0; row < BR / WARPS; row++) {
|
|
175
|
+
const unsigned int row_cor = row + warp * (BR / WARPS);
|
|
176
|
+
const unsigned int g_row = q_row_start + row_cor;
|
|
177
|
+
if (row_cor >= static_cast<unsigned int>(q_rows) || g_row >= static_cast<unsigned int>(N)) {
|
|
178
|
+
continue;
|
|
179
|
+
}
|
|
180
|
+
const float denom = lse_mem[row_cor];
|
|
181
|
+
const float inv = (denom > 0.0f) ? 1.0f / denom : 0.0f;
|
|
182
|
+
for (unsigned int d = lane; d < HEAD_DIM_V; d += 32) {
|
|
183
|
+
out_s[g_row * D + d] = __float2half(o_acc[row_cor * HEAD_DIM_V + d] * inv);
|
|
184
|
+
}
|
|
185
|
+
if (lane == 0) {
|
|
186
|
+
lse_s[g_row] = (denom > 0.0f) ? max_mem[row_cor] + log2f(denom) : -FLT_MAX;
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
void attention_forward_kernel_launcher(const __half* X, const __half* W_phi, const __half* gate_q, const __half* gate_k,
|
|
192
|
+
const __half* W_V, __half* out, float* lse, const int B, const int N,
|
|
193
|
+
const int D, const int H, const int cube_m, const int head_dim_v,
|
|
194
|
+
const float scale, cudaStream_t stream) {
|
|
195
|
+
|
|
196
|
+
// checkers
|
|
197
|
+
TORCH_CHECK(B > 0 && N > 0 && D > 0 && H > 0, "Invalid dimensions");
|
|
198
|
+
TORCH_CHECK(D % H == 0, "D must be divisible by H");
|
|
199
|
+
TORCH_CHECK(H * head_dim_v == D, "H * head_dim_v must equal D");
|
|
200
|
+
TORCH_CHECK(cube_m == CUBE_M_FWD && head_dim_v == HEAD_DIM_V_FWD, "kernel built for fixed (cube_m, head_dim_v)");
|
|
201
|
+
TORCH_CHECK(X && W_phi && gate_q && gate_k && W_V && out && lse, "Null pointer");
|
|
202
|
+
|
|
203
|
+
// smem carve
|
|
204
|
+
const size_t smem = sizeof(__half) * (BR_FWD * CUBE_M_FWD + // q codes
|
|
205
|
+
BC_FWD * CUBE_M_FWD + // k codes
|
|
206
|
+
BC_FWD * HEAD_DIM_V_FWD + // V tile
|
|
207
|
+
BR_FWD * BC_FWD) // P fp16 tile
|
|
208
|
+
+ sizeof(__half) * (XT_ROWS * D) // shared raw-token staging tile
|
|
209
|
+
+ sizeof(float) * (BR_FWD * BC_FWD + // S/raw fp32 tile
|
|
210
|
+
4 * BR_FWD + // max + lse + corr + qbias
|
|
211
|
+
BC_FWD + // column bias
|
|
212
|
+
BR_FWD * HEAD_DIM_V_FWD); // O accumulator
|
|
213
|
+
|
|
214
|
+
// kernel instance
|
|
215
|
+
const auto kernel = morpho_attention_forward_kernel<HEAD_DIM_V_FWD, CUBE_M_FWD, BR_FWD, BC_FWD>;
|
|
216
|
+
|
|
217
|
+
// smem alloc check
|
|
218
|
+
static int fwd_cached = -1;
|
|
219
|
+
configure_kernel_smem(kernel, smem, fwd_cached, "forward_kernel");
|
|
220
|
+
|
|
221
|
+
// launch
|
|
222
|
+
dim3 grid(B * H, (N + BR_FWD - 1) / BR_FWD);
|
|
223
|
+
dim3 block(BLOCK_SIZE);
|
|
224
|
+
|
|
225
|
+
kernel<<<grid, block, smem, stream>>>(X, W_phi, gate_q, gate_k, W_V, out, lse, B, N, D, H, scale);
|
|
226
|
+
|
|
227
|
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
void attention_backward_kernel_launcher(const __half* grad_out, const __half* X, const __half* dX, const __half* d_se,
|
|
231
|
+
int B, int N, int D, cudaStream_t stream) {
|
|
232
|
+
TORCH_CHECK(B > 0 && N > 0 && D > 0, "Invalid dimensions");
|
|
233
|
+
|
|
234
|
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
|
235
|
+
}
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
#ifndef MORPHOTTENTION_ATTENTION_CUH
|
|
2
|
+
#define MORPHOTTENTION_ATTENTION_CUH
|
|
3
|
+
|
|
4
|
+
#include <cuda_runtime.h>
|
|
5
|
+
|
|
6
|
+
#include <cuda_fp16.h>
|
|
7
|
+
|
|
8
|
+
#ifdef __CUDACC__
|
|
9
|
+
#include <cuda/sm120/matmul.cuh>
|
|
10
|
+
#include <cuda/sm120/project.cuh>
|
|
11
|
+
#include <cuda/sm120/smem.cuh>
|
|
12
|
+
#include <cuda/utils/declarations.cuh>
|
|
13
|
+
#include <cuda/utils/smem.cuh>
|
|
14
|
+
#include <cuda/utils/utils.cuh>
|
|
15
|
+
|
|
16
|
+
#include <c10/cuda/CUDAException.h>
|
|
17
|
+
#include <c10/util/Exception.h>
|
|
18
|
+
|
|
19
|
+
namespace arch {
|
|
20
|
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 120)
|
|
21
|
+
namespace impl = sm120; // Consumer Blackwell
|
|
22
|
+
#else
|
|
23
|
+
namespace impl = sm120; // fallback
|
|
24
|
+
#endif
|
|
25
|
+
} // namespace arch
|
|
26
|
+
#endif // __CUDACC__
|
|
27
|
+
|
|
28
|
+
void attention_forward_kernel_launcher(const __half* X, const __half* W_phi, const __half* gate_q, const __half* gate_k,
|
|
29
|
+
const __half* W_V, __half* out, float* lse, const int B, const int N,
|
|
30
|
+
const int D, const int H, const int cube_m, const int head_dim_v,
|
|
31
|
+
const float scale, cudaStream_t stream);
|
|
32
|
+
|
|
33
|
+
void attention_backward_kernel_launcher(const __half* grad_out, const __half* X, const __half* dX, const __half* d_se,
|
|
34
|
+
int B, int N, int D, cudaStream_t stream);
|
|
35
|
+
|
|
36
|
+
#endif // MORPHOTTENTION_ATTENTION_CUH
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
#include "dispatch.h"
|
|
2
|
+
|
|
3
|
+
#include <torch/extension.h>
|
|
4
|
+
|
|
5
|
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
6
|
+
m.doc() = "Morphottention CUDA attention kernels";
|
|
7
|
+
|
|
8
|
+
m.def("forward", &forward, "Attention forward dispatcher", py::arg("X"), py::arg("W_phi"), py::arg("gate_q"),
|
|
9
|
+
py::arg("gate_k"), py::arg("W_V"), py::arg("H"), py::arg("cube_m"), py::arg("scale"), py::arg("causal"));
|
|
10
|
+
|
|
11
|
+
m.def("backward", &backward, "Attention backward dispatcher", py::arg("grad_out"), py::arg("X"));
|
|
12
|
+
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
#include "dispatch.h"
|
|
2
|
+
|
|
3
|
+
std::vector<torch::Tensor> forward(const torch::Tensor& X, const torch::Tensor& W_phi, const torch::Tensor& gate_q,
|
|
4
|
+
const torch::Tensor& gate_k, const torch::Tensor& W_V, const int64_t H,
|
|
5
|
+
const int64_t cube_m, const double scale, const bool causal) {
|
|
6
|
+
return morpho_forward(X, W_phi, gate_q, gate_k, W_V, H, cube_m, scale, causal);
|
|
7
|
+
}
|
|
8
|
+
|
|
9
|
+
std::vector<torch::Tensor> backward(const torch::Tensor& grad_out, const torch::Tensor& X) {
|
|
10
|
+
TORCH_CHECK(grad_out.is_cuda() && X.is_cuda(), "Gradient output and X must be a CUDA tensors");
|
|
11
|
+
|
|
12
|
+
return morpho_backward(X, grad_out);
|
|
13
|
+
}
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
#ifndef MORPHOTTENTION_DISPATCH_H
|
|
2
|
+
#define MORPHOTTENTION_DISPATCH_H
|
|
3
|
+
|
|
4
|
+
#include <torch/extension.h>
|
|
5
|
+
|
|
6
|
+
// py-facing dispatchers
|
|
7
|
+
std::vector<torch::Tensor> forward(const torch::Tensor& X, const torch::Tensor& W_phi, const torch::Tensor& gate_q,
|
|
8
|
+
const torch::Tensor& gate_k, const torch::Tensor& W_V, int64_t H, int64_t cube_m,
|
|
9
|
+
double scale, bool causal);
|
|
10
|
+
|
|
11
|
+
std::vector<torch::Tensor> backward(const torch::Tensor& grad_out, const torch::Tensor& X);
|
|
12
|
+
|
|
13
|
+
// CUDA-facing dispatchers
|
|
14
|
+
|
|
15
|
+
std::vector<torch::Tensor> morpho_forward(const torch::Tensor& X, const torch::Tensor& W_phi,
|
|
16
|
+
const torch::Tensor& gate_q, const torch::Tensor& gate_k,
|
|
17
|
+
const torch::Tensor& W_V, int64_t H, int64_t cube_m, double scale,
|
|
18
|
+
bool causal);
|
|
19
|
+
|
|
20
|
+
// Parameter order is (X, grad_out): this matches the definition in attention/attention.cpp and the call in
// dispatch.cpp. Both parameters have the same type, so a swapped declaration would compile silently.
std::vector<torch::Tensor> morpho_backward(const torch::Tensor& X, const torch::Tensor& grad_out);
|
|
21
|
+
|
|
22
|
+
#endif // MORPHOTTENTION_DISPATCH_H
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
#ifndef MORPHOTTENTION_CUBE_CUH
|
|
2
|
+
#define MORPHOTTENTION_CUBE_CUH
|
|
3
|
+
|
|
4
|
+
#include <cuda_runtime.h>
|
|
5
|
+
|
|
6
|
+
#include <math.h>
|
|
7
|
+
|
|
8
|
+
namespace morph {
|
|
9
|
+
|
|
10
|
+
enum class CubeProjection { Sigmoid, HardTanh01, Identity };
|
|
11
|
+
|
|
12
|
+
__device__ __forceinline__ float cube_sigmoid(const float x) {
|
|
13
|
+
return 1.0f / (1.0f + __expf(-x));
|
|
14
|
+
}
|
|
15
|
+
__device__ __forceinline__ float cube_hardtanh01(const float x) {
|
|
16
|
+
return fminf(1.0f, fmaxf(0.0f, 0.5f * (x + 1.0f)));
|
|
17
|
+
}
|
|
18
|
+
__device__ __forceinline__ float cube_project(const float x, CubeProjection p) {
|
|
19
|
+
switch (p) {
|
|
20
|
+
case CubeProjection::Sigmoid:
|
|
21
|
+
return cube_sigmoid(x);
|
|
22
|
+
case CubeProjection::HardTanh01:
|
|
23
|
+
return cube_hardtanh01(x);
|
|
24
|
+
default:
|
|
25
|
+
return x;
|
|
26
|
+
}
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
__device__ __forceinline__ float cube_sigmoid_grad(float x) {
|
|
30
|
+
const float s = cube_sigmoid(x);
|
|
31
|
+
return s * (1.0f - s);
|
|
32
|
+
}
|
|
33
|
+
__device__ __forceinline__ float cube_hardtanh01_grad(float x) {
|
|
34
|
+
return (x > -1.0f && x < 1.0f) ? 0.5f : 0.0f;
|
|
35
|
+
}
|
|
36
|
+
__device__ __forceinline__ float cube_project_grad(float x, CubeProjection p) {
|
|
37
|
+
switch (p) {
|
|
38
|
+
case CubeProjection::Sigmoid:
|
|
39
|
+
return cube_sigmoid_grad(x);
|
|
40
|
+
case CubeProjection::HardTanh01:
|
|
41
|
+
return cube_hardtanh01_grad(x);
|
|
42
|
+
default:
|
|
43
|
+
return 1.0f;
|
|
44
|
+
}
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
template <CubeProjection P>
|
|
48
|
+
struct Cube;
|
|
49
|
+
template <>
|
|
50
|
+
struct Cube<CubeProjection::Sigmoid> {
|
|
51
|
+
static __device__ __forceinline__ float project(const float x) {
|
|
52
|
+
return cube_sigmoid(x);
|
|
53
|
+
}
|
|
54
|
+
static __device__ __forceinline__ float grad(const float x) {
|
|
55
|
+
return cube_sigmoid_grad(x);
|
|
56
|
+
}
|
|
57
|
+
};
|
|
58
|
+
|
|
59
|
+
template <>
|
|
60
|
+
struct Cube<CubeProjection::HardTanh01> {
|
|
61
|
+
static __device__ __forceinline__ float project(const float x) {
|
|
62
|
+
return cube_hardtanh01(x);
|
|
63
|
+
}
|
|
64
|
+
static __device__ __forceinline__ float grad(const float x) {
|
|
65
|
+
return cube_hardtanh01_grad(x);
|
|
66
|
+
}
|
|
67
|
+
};
|
|
68
|
+
|
|
69
|
+
template <>
|
|
70
|
+
struct Cube<CubeProjection::Identity> {
|
|
71
|
+
static __device__ __forceinline__ float project(const float x) {
|
|
72
|
+
return x;
|
|
73
|
+
}
|
|
74
|
+
static __device__ __forceinline__ float grad(float) {
|
|
75
|
+
return 1.0f;
|
|
76
|
+
}
|
|
77
|
+
};
|
|
78
|
+
|
|
79
|
+
} // namespace morph
|
|
80
|
+
|
|
81
|
+
#endif // MORPHOTTENTION_CUBE_CUH
|