morphottention 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,113 @@
1
+ cmake_minimum_required(VERSION 3.24)
2
+
3
+ if(NOT DEFINED CMAKE_CUDA_COMPILER)
4
+ if(DEFINED ENV{CUDA_HOME})
5
+ set(CMAKE_CUDA_COMPILER "$ENV{CUDA_HOME}/bin/nvcc" CACHE FILEPATH "" FORCE)
6
+ else()
7
+ find_program(_NVCC nvcc)
8
+ if(_NVCC)
9
+ set(CMAKE_CUDA_COMPILER "${_NVCC}" CACHE FILEPATH "" FORCE)
10
+ endif()
11
+ endif()
12
+ endif()
13
+
14
+
15
+ project(morphottention LANGUAGES CXX CUDA)
16
+
17
+ # 23 host / 20 device
18
+ set(CMAKE_CXX_STANDARD 23)
19
+ set(CMAKE_CXX_STANDARD_REQUIRED ON)
20
+ set(CMAKE_CUDA_STANDARD 20)
21
+ set(CMAKE_CUDA_STANDARD_REQUIRED ON)
22
+
23
+ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
24
+ set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF)
25
+
26
+ set(CMAKE_CUDA_ARCHITECTURES 90 100 120)
27
+ set(TORCH_CUDA_ARCH_LIST "9.0;10.0;12.0")
28
+
29
+ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
30
+
31
+ set(Python_FIND_VIRTUALENV FIRST)
32
+
33
+ find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)
34
+
35
+ execute_process( # ask the selected Python interpreter where torch ships its CMake package files
36
+ COMMAND ${Python_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
37
+ OUTPUT_VARIABLE TORCH_CMAKE_PREFIX
38
+ OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY # stop configuring here if torch cannot be imported
39
+ )
40
+ list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PREFIX}")
41
+ find_package(Torch REQUIRED CONFIG)
42
+
43
+
44
+ if(WIN32)
45
+ set(_TORCH_PYTHON_LIB_NAME "torch_python.lib")
46
+ else()
47
+ set(_TORCH_PYTHON_LIB_NAME "libtorch_python.so")
48
+ endif()
49
+
50
+ execute_process(
51
+ COMMAND ${Python_EXECUTABLE} -c "import pathlib, torch; print(pathlib.Path(torch.__file__).resolve().parent / 'lib' / '${_TORCH_PYTHON_LIB_NAME}')"
52
+ OUTPUT_VARIABLE TORCH_PYTHON_LIBRARY
53
+ OUTPUT_STRIP_TRAILING_WHITESPACE
54
+ )
55
+ if(NOT EXISTS "${TORCH_PYTHON_LIBRARY}")
56
+ message(FATAL_ERROR "Could not locate ${_TORCH_PYTHON_LIB_NAME} at: ${TORCH_PYTHON_LIBRARY}")
57
+ endif()
58
+ get_filename_component(TORCH_LIB_DIR "${TORCH_PYTHON_LIBRARY}" DIRECTORY)
59
+
60
+ find_package(CUDAToolkit REQUIRED)
61
+
62
+ Python_add_library(_C MODULE WITH_SOABI
63
+ csrc/cuda/binder.cpp
64
+ csrc/cuda/dispatch.cpp
65
+
66
+ csrc/cuda/attention/attention.cpp
67
+ csrc/cuda/attention/attention.cu
68
+ )
69
+
70
+ # -lineinfo on Debug; full device debug (-G) only when explicitly sanitizing.
71
+ option(CUDA_SANITIZE "Build device code with -G for compute-sanitizer" OFF)
72
+ target_compile_options(_C PRIVATE
73
+ $<$<COMPILE_LANG_AND_ID:CUDA,NVIDIA>:--threads=0>
74
+ $<$<AND:$<COMPILE_LANG_AND_ID:CUDA,NVIDIA>,$<CONFIG:Debug>>:-lineinfo>
75
+ $<$<AND:$<COMPILE_LANG_AND_ID:CUDA,NVIDIA>,$<CONFIG:Debug>,$<BOOL:${CUDA_SANITIZE}>>:-G>
76
+ )
77
+
78
+
79
+ if(MSVC)
80
+ target_compile_options(_C PRIVATE
81
+ $<$<COMPILE_LANGUAGE:CXX>:/Zc:preprocessor>
82
+ $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=/Zc:preprocessor>
83
+ )
84
+ endif()
85
+
86
+ target_include_directories(_C PRIVATE
87
+ ${CMAKE_CURRENT_SOURCE_DIR}/csrc
88
+ ${TORCH_INCLUDE_DIRS}
89
+ ${Python_INCLUDE_DIRS}
90
+ ${CUDAToolkit_INCLUDE_DIRS}
91
+ )
92
+
93
+ target_compile_definitions(_C PRIVATE
94
+ TORCH_EXTENSION_NAME=_C
95
+ )
96
+
97
+ target_link_libraries(_C PRIVATE
98
+ ${TORCH_LIBRARIES}
99
+ ${TORCH_PYTHON_LIBRARY}
100
+ Python::Module
101
+ CUDA::cudart
102
+ CUDA::cublas
103
+ )
104
+
105
+ if(NOT WIN32)
106
+ target_link_options(_C PRIVATE "LINKER:--no-as-needed") # LINKER: lets CMake emit the driver-specific prefix (-Wl, / -Xlinker)
107
+ set_target_properties(_C PROPERTIES
108
+ BUILD_RPATH "${TORCH_LIB_DIR}"
109
+ INSTALL_RPATH "$ORIGIN/../torch/lib"
110
+ )
111
+ endif()
112
+
113
+ install(TARGETS _C LIBRARY DESTINATION morphottention)
@@ -0,0 +1,62 @@
1
+ Metadata-Version: 2.2
2
+ Name: morphottention
3
+ Version: 0.1.0
4
+ Summary: CUDA attention & morphology kernels for PyTorch (sm_90/100/120)
5
+ Keywords: attention,cuda,pytorch,transformer,morphology,flash-attention,ViT
6
+ Author-Email: Vedran Hrabar <vedran.hrabar@outlook.com>
7
+ License: MIT
8
+ Classifier: Development Status :: 2 - Pre-Alpha
9
+ Classifier: Environment :: GPU
10
+ Classifier: Environment :: GPU :: NVIDIA CUDA
11
+ Classifier: Environment :: GPU :: NVIDIA CUDA :: 13
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: OSI Approved :: MIT License
15
+ Classifier: Operating System :: POSIX :: Linux
16
+ Classifier: Operating System :: Microsoft :: Windows
17
+ Classifier: Programming Language :: C++
18
+ Classifier: Programming Language :: Python :: 3
19
+ Classifier: Programming Language :: Python :: 3 :: Only
20
+ Classifier: Programming Language :: Python :: 3.14
21
+ Classifier: Programming Language :: Python :: Implementation :: CPython
22
+ Classifier: Topic :: Scientific/Engineering
23
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
24
+ Classifier: Topic :: Scientific/Engineering :: Image Recognition
25
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
26
+ Classifier: Topic :: Software Development :: Libraries
27
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
28
+ Classifier: Typing :: Typed
29
+ Project-URL: repository, https://github.com/vhrabar/morphottention
30
+ Project-URL: documentation, https://github.com/vhrabar/morphottention/wiki
31
+ Project-URL: Bug Tracker, https://github.com/vhrabar/morphottention/issues
32
+ Requires-Python: ==3.14.*
33
+ Requires-Dist: torch>=2.12
34
+ Description-Content-Type: text/markdown
35
+
36
+ # Morphottention
37
+ Mathematical Morphology-based self-attention module for PyTorch using Flash-style kernel fusion.
38
+
39
+ ## Install
40
+
41
+ Prebuilt wheels are published for CPython 3.14 on Linux (x86_64, aarch64) and
42
+ Windows (x86_64). A working CUDA-enabled PyTorch (`torch >= 2.12`) must already
43
+ be installed in the environment.
44
+
45
+ ```bash
46
+ pip install morphottention
47
+ ```
48
+
49
+ ## Building from source
50
+
51
+ Requires the CUDA 13.X toolkit (`nvcc`) and a matching `torch` build:
52
+
53
+ ```bash
54
+ uv sync --package morphottention --no-dev --group build
55
+ uv build --package morphottention --wheel --no-build-isolation
56
+ ```
57
+
58
+ ## License
59
+
60
+ MIT
61
+
62
+ Copyright © 2026 Vedran Hrabar.
@@ -0,0 +1,27 @@
1
+ # Morphottention
2
+ Mathematical Morphology-based self-attention module for PyTorch using Flash-style kernel fusion.
3
+
4
+ ## Install
5
+
6
+ Prebuilt wheels are published for CPython 3.14 on Linux (x86_64, aarch64) and
7
+ Windows (x86_64). A working CUDA-enabled PyTorch (`torch >= 2.12`) must already
8
+ be installed in the environment.
9
+
10
+ ```bash
11
+ pip install morphottention
12
+ ```
13
+
14
+ ## Building from source
15
+
16
+ Requires the CUDA 13.X toolkit (`nvcc`) and a matching `torch` build:
17
+
18
+ ```bash
19
+ uv sync --package morphottention --no-dev --group build
20
+ uv build --package morphottention --wheel --no-build-isolation
21
+ ```
22
+
23
+ ## License
24
+
25
+ MIT
26
+
27
+ Copyright © 2026 Vedran Hrabar.
@@ -0,0 +1,71 @@
1
+ #include "attention.cuh"
2
+
3
+ #include <cuda/dispatch.h>
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <torch/extension.h>
8
+
9
+ auto check = [](const torch::Tensor& t, const char* name) {
10
+ TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
11
+ TORCH_CHECK(t.scalar_type() == torch::kHalf, name, " must be float16");
12
+ TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
13
+ };
14
+
15
+ std::vector<torch::Tensor> morpho_forward(const torch::Tensor& X, const torch::Tensor& W_phi,
16
+ const torch::Tensor& gate_q, const torch::Tensor& gate_k,
17
+ const torch::Tensor& W_V, const int64_t H, const int64_t cube_m,
18
+ const double scale, const bool causal) {
19
+ // checkers
20
+ check(X, "X");
21
+ check(W_phi, "W_phi");
22
+ check(gate_q, "gate_q");
23
+ check(gate_k, "gate_k");
24
+ check(W_V, "W_V");
25
+
26
+ // shape management
27
+ TORCH_CHECK(X.dim() == 3, "X must be [B, N, D]");
28
+ const int64_t B = X.size(0);
29
+ const int64_t N = X.size(1);
30
+ const int64_t D = X.size(2);
31
+
32
+ TORCH_CHECK(H > 0 && D % H == 0, "D must be divisible by H");
33
+ const int64_t head_dim_v = D / H;
34
+
35
+ TORCH_CHECK(W_phi.dim() == 2 && W_phi.size(0) == D && W_phi.size(1) == H * cube_m, "W_phi must be [D, H*cube_m]");
36
+ TORCH_CHECK(W_V.dim() == 2 && W_V.size(0) == D && W_V.size(1) == H * head_dim_v, "W_V must be [D, H*head_dim_v]");
37
+ TORCH_CHECK(gate_q.dim() == 2 && gate_q.size(0) == H && gate_q.size(1) == cube_m, "gate_q must be [H, cube_m]");
38
+ TORCH_CHECK(gate_k.dim() == 2 && gate_k.size(0) == H && gate_k.size(1) == cube_m, "gate_k must be [H, cube_m]");
39
+
40
+ auto out = torch::empty_like(X);
41
+ auto lse = torch::empty({B * H, N}, X.options().dtype(torch::kFloat32));
42
+
43
+ // launcher
44
+ const c10::cuda::CUDAGuard device_guard(X.device()); // make X's device current so the stream queried below belongs to it
45
+
46
+ attention_forward_kernel_launcher(reinterpret_cast<const __half*>(X.data_ptr<at::Half>()),
47
+ reinterpret_cast<const __half*>(W_phi.data_ptr<at::Half>()),
48
+ reinterpret_cast<const __half*>(gate_q.data_ptr<at::Half>()),
49
+ reinterpret_cast<const __half*>(gate_k.data_ptr<at::Half>()),
50
+ reinterpret_cast<const __half*>(W_V.data_ptr<at::Half>()),
51
+ reinterpret_cast<__half*>(out.data_ptr<at::Half>()), lse.data_ptr<float>(),
52
+ static_cast<int>(B), static_cast<int>(N), static_cast<int>(D),
53
+ static_cast<int>(H), static_cast<int>(cube_m), static_cast<int>(head_dim_v),
54
+ static_cast<float>(scale), c10::cuda::getCurrentCUDAStream());
55
+
56
+ return {out, lse};
57
+ }
58
+
59
+ std::vector<torch::Tensor> morpho_backward(const torch::Tensor& X, const torch::Tensor& grad_out) {
60
+ TORCH_CHECK(X.is_cuda(), "Input tensor must be a CUDA tensor");
61
+ TORCH_CHECK(grad_out.is_cuda(), "Gradient output tensor must be a CUDA tensor");
62
+
63
+ auto X_contig = X.contiguous();
64
+ auto grad_out_contig = grad_out.contiguous();
65
+
66
+ const int B = static_cast<int>(X_contig.size(0));
67
+ const int N = static_cast<int>(X_contig.size(1));
68
+ const int D = static_cast<int>(X_contig.size(2));
69
+
70
+ return {X};
71
+ }
@@ -0,0 +1,235 @@
1
+ #include "attention.cuh"
2
+
3
+ #include <cuda_pipeline.h>
4
+
5
+ #include <cfloat>
6
+
7
+ template <int HEAD_DIM_V, int CUBE_M, int BR, int BC>
8
+ __global__ void morpho_attention_forward_kernel(const __half* __restrict__ X, const __half* __restrict__ W_phi,
9
+ const __half* __restrict__ gate_q, const __half* __restrict__ gate_k,
10
+ const __half* __restrict__ W_V, __half* __restrict__ out, float* lse,
11
+ int B, int N, int D, int H, float scale) {
12
+
13
+ static_assert(HEAD_DIM_V <= BR, "HEAD_DIM_V must be <= BR");
14
+ static_assert(HEAD_DIM_V <= BC, "HEAD_DIM_V must be <= BC");
15
+
16
+ // NOLINTNEXTLINE(readability-suspicious-call-argument)
17
+ auto [t, b, tid] = get_coords();
18
+ const unsigned int warp = tid / 32;
19
+ const unsigned int lane = tid & 31;
20
+
21
+ const unsigned int bh = t;
22
+ const unsigned int batch = bh / static_cast<unsigned int>(H);
23
+ const unsigned int head = bh % static_cast<unsigned int>(H);
24
+
25
+ const int q_row_start = static_cast<int>(b) * BR;
26
+ const int q_row_end = min(q_row_start + BR, N);
27
+ const int q_rows = q_row_end - q_row_start;
28
+ if (q_rows <= 0) {
29
+ return;
30
+ }
31
+
32
+ const int ldphi = H * CUBE_M;
33
+ const int ldv = H * HEAD_DIM_V;
34
+
35
+ const __half* X_b = X + static_cast<size_t>(batch) * N * D; // 64-bit offset: batch * N * D can exceed 32 bits
36
+ const __half* W_phi_h = W_phi + head * CUBE_M;
37
+ const __half* W_V_h = W_V + head * HEAD_DIM_V;
38
+ const __half* gate_q_h = gate_q + head * CUBE_M;
39
+ const __half* gate_k_h = gate_k + head * CUBE_M;
40
+ __half* out_s = out + static_cast<size_t>(batch) * N * D + head * HEAD_DIM_V; // same 64-bit batch offset as X_b
41
+ float* lse_s = lse + static_cast<size_t>(bh) * N; // 64-bit offset: bh * N can exceed 32 bits
42
+
43
+ // smem carve
44
+ constexpr int XT_ROWS = (BR > BC) ? BR : BC;
45
+ extern __shared__ __align__(16) unsigned char smem[];
46
+ auto* q_mem = reinterpret_cast<__half*>(smem); // q codes [BR, CUBE_M]
47
+ auto* k_mem = q_mem + BR * CUBE_M; // k codes [BC, CUBE_M]
48
+ auto* v_mem = k_mem + BC * CUBE_M; // V tile [BC, HEAD_DIM_V]
49
+ auto* xt_mem = v_mem + BC * HEAD_DIM_V; // token staging [max(BR,BC), D]
50
+
51
+ auto* s_mem = reinterpret_cast<float*>(xt_mem + XT_ROWS * D); // S/P fp32 tile [BR, BC]
52
+ auto* p_h_mem = reinterpret_cast<__half*>(s_mem + BR * BC); // P fp16 tile [BR, BC]
53
+ auto* state_mem = reinterpret_cast<float*>(p_h_mem + BR * BC);
54
+
55
+ float* max_mem = state_mem; // [BR]
56
+ float* lse_mem = state_mem + BR; // [BR]
57
+ float* corr_mem = state_mem + 2 * BR; // [BR]
58
+ float* qbias_mem = state_mem + 3 * BR; // [BR]
59
+ float* cbias_mem = state_mem + 4 * BR; // [BC]
60
+ float* o_acc = cbias_mem + BC; // [BR, HEAD_DIM_V]
61
+ float* pv_mem = s_mem; // PV scratch
62
+
63
+ // running state in SMEM
64
+ for (unsigned int i = tid; i < BR * HEAD_DIM_V; i += blockDim.x) {
65
+ o_acc[i] = 0.0f;
66
+ }
67
+ for (unsigned int row = tid; row < BR; row += blockDim.x) {
68
+ max_mem[row] = -FLT_MAX;
69
+ lse_mem[row] = 0.0f;
70
+ }
71
+
72
+ // move X to SRAM
73
+ arch::impl::smem_load(xt_mem, X_b + static_cast<size_t>(q_row_start) * D, static_cast<unsigned int>(q_rows),
74
+ static_cast<unsigned int>(D), tid, BR);
75
+ __syncthreads();
76
+ // project Q to the gated unit-hypercube
77
+ // q = gate_q @ sigma(W_phi.t, x_q) -> q_mem, q_bias = gate_k x q_i
78
+ arch::impl::project_phi<BR, CUBE_M, WARPS>(xt_mem, W_phi_h, gate_q_h, gate_k_h, q_mem, qbias_mem, s_mem, D, ldphi,
79
+ warp, lane);
80
+
81
+ const float scale_log2 = scale * LOG2E;
82
+
83
+ // loop over KV blocks
84
+ unsigned int n_tiles = (static_cast<unsigned int>(N) + BC - 1) / BC;
85
+ const int q_global_max = q_row_start + q_rows - 1;
86
+ if (CAUSAL) {
87
+ // full mask after last row-q
88
+ n_tiles = min(n_tiles, static_cast<unsigned int>(q_global_max) / BC + 1u);
89
+ }
90
+
91
+ for (unsigned int tile = 0; tile < n_tiles; tile++) {
92
+ const int kv_block_start = static_cast<int>(tile) * BC;
93
+ const int kv_valid = min(BC, N - kv_block_start);
94
+
95
+ // stage K/V tokens
96
+ arch::impl::smem_load(xt_mem, X_b + static_cast<size_t>(kv_block_start) * D, static_cast<unsigned int>(kv_valid),
97
+ static_cast<unsigned int>(D), tid, BC);
98
+ __syncthreads();
99
+
100
+ // project K to the gated hypercube
101
+ // k = gate_k @ sigma(W_phi.t, x_j) -> k_mem, c_bias = gate_q x k_j
102
+ arch::impl::project_phi<BC, CUBE_M, WARPS>(xt_mem, W_phi_h, gate_k_h, gate_q_h, k_mem, cbias_mem, s_mem, D,
103
+ ldphi, warp, lane);
104
+
105
+ // v = W_V x x_j -> s_mem -> v_mem
106
+ arch::impl::matmul<BC, HEAD_DIM_V, WARPS, false>(s_mem, xt_mem, W_V_h, D, D, ldv, HEAD_DIM_V, 1.0f);
107
+ __syncthreads();
108
+ for (unsigned int i = tid; i < BC * HEAD_DIM_V; i += blockDim.x) {
109
+ v_mem[i] = __float2half(s_mem[i]);
110
+ }
111
+ __syncthreads();
112
+
113
+ // S_raw = QK.T -> s_mem
114
+ arch::impl::matmul<BR, BC, WARPS, true>(s_mem, q_mem, k_mem, CUBE_M, CUBE_M, CUBE_M, BC, 1.0f);
115
+ __syncthreads();
116
+
117
+ // S = (2S_raw − qbias[i] − cbias[j])·scale_log2e
118
+ // symmetry -> -FLT_MAX mask
119
+ for (unsigned int row = 0; row < BR / WARPS; row++) {
120
+ const unsigned int row_cor = row + warp * (BR / WARPS);
121
+ const int q_global = q_row_start + static_cast<int>(row_cor);
122
+
123
+ float tile_max = -FLT_MAX;
124
+ for (unsigned int c = lane; c < BC; c += 32) {
125
+ float s = (2.0f * s_mem[row_cor * BC + c] - qbias_mem[row_cor] - cbias_mem[c]) * scale_log2;
126
+ const int k_global = kv_block_start + static_cast<int>(c);
127
+
128
+ bool drop = (static_cast<int>(c) >= kv_valid) || (static_cast<int>(row_cor) >= q_rows);
129
+ if (CAUSAL && k_global > q_global)
130
+ drop = true;
131
+ if (MASK_DIAG && k_global == q_global)
132
+ drop = true;
133
+
134
+ s = drop ? -FLT_MAX : s;
135
+ s_mem[row_cor * BC + c] = s;
136
+ tile_max = fmaxf(tile_max, s);
137
+ }
138
+ tile_max = warpMax(tile_max);
139
+
140
+ const float max_prev = max_mem[row_cor];
141
+ const float lse_prev = lse_mem[row_cor];
142
+ const float max_new = fmaxf(max_prev, tile_max);
143
+ const float corr = exp2f(max_prev - max_new);
144
+
145
+ float partial = 0.0f;
146
+ for (unsigned int c = lane; c < BC; c += 32) {
147
+ const float s = s_mem[row_cor * BC + c];
148
+ const float p = (s == -FLT_MAX) ? 0.0f : exp2f(s - max_new);
149
+ p_h_mem[row_cor * BC + c] = __float2half(p);
150
+ partial += p;
151
+ }
152
+ const float lse_new = lse_prev * corr + warpAllReduceSum(partial);
153
+
154
+ if (lane == 0) {
155
+ corr_mem[row_cor] = corr;
156
+ lse_mem[row_cor] = lse_new;
157
+ max_mem[row_cor] = max_new;
158
+ }
159
+ }
160
+ __syncthreads();
161
+
162
+ // PV -> s_mem -> pv_mem
163
+ arch::impl::matmul<BR, HEAD_DIM_V, WARPS, false>(pv_mem, p_h_mem, v_mem, BC, BC, HEAD_DIM_V, HEAD_DIM_V, 1.0f);
164
+ __syncthreads();
165
+
166
+ // O = alpha * O + PV
167
+ for (unsigned int i = tid; i < BR * HEAD_DIM_V; i += blockDim.x) {
168
+ o_acc[i] = o_acc[i] * corr_mem[i / HEAD_DIM_V] + pv_mem[i];
169
+ }
170
+ __syncthreads();
171
+ }
172
+
173
+ // store to GMEM
174
+ for (unsigned int row = 0; row < BR / WARPS; row++) {
175
+ const unsigned int row_cor = row + warp * (BR / WARPS);
176
+ const unsigned int g_row = q_row_start + row_cor;
177
+ if (row_cor >= static_cast<unsigned int>(q_rows) || g_row >= static_cast<unsigned int>(N)) {
178
+ continue;
179
+ }
180
+ const float denom = lse_mem[row_cor];
181
+ const float inv = (denom > 0.0f) ? 1.0f / denom : 0.0f;
182
+ for (unsigned int d = lane; d < HEAD_DIM_V; d += 32) {
183
+ out_s[static_cast<size_t>(g_row) * D + d] = __float2half(o_acc[row_cor * HEAD_DIM_V + d] * inv);
184
+ }
185
+ if (lane == 0) {
186
+ lse_s[g_row] = (denom > 0.0f) ? max_mem[row_cor] + log2f(denom) : -FLT_MAX;
187
+ }
188
+ }
189
+ }
190
+
191
+ void attention_forward_kernel_launcher(const __half* X, const __half* W_phi, const __half* gate_q, const __half* gate_k,
192
+ const __half* W_V, __half* out, float* lse, const int B, const int N,
193
+ const int D, const int H, const int cube_m, const int head_dim_v,
194
+ const float scale, cudaStream_t stream) {
195
+
196
+ // checkers
197
+ TORCH_CHECK(B > 0 && N > 0 && D > 0 && H > 0, "Invalid dimensions");
198
+ TORCH_CHECK(D % H == 0, "D must be divisible by H");
199
+ TORCH_CHECK(H * head_dim_v == D, "H * head_dim_v must equal D");
200
+ TORCH_CHECK(cube_m == CUBE_M_FWD && head_dim_v == HEAD_DIM_V_FWD, "kernel built for fixed (cube_m, head_dim_v)");
201
+ TORCH_CHECK(X && W_phi && gate_q && gate_k && W_V && out && lse, "Null pointer");
202
+
203
+ // smem carve
204
+ const size_t smem = sizeof(__half) * (BR_FWD * CUBE_M_FWD + // q codes
205
+ BC_FWD * CUBE_M_FWD + // k codes
206
+ BC_FWD * HEAD_DIM_V_FWD + // V tile
207
+ BR_FWD * BC_FWD) // P fp16 tile
208
+ + sizeof(__half) * (XT_ROWS * D) // shared raw-token staging tile
209
+ + sizeof(float) * (BR_FWD * BC_FWD + // S/raw fp32 tile
210
+ 4 * BR_FWD + // max + lse + corr + qbias
211
+ BC_FWD + // column bias
212
+ BR_FWD * HEAD_DIM_V_FWD); // O accumulator
213
+
214
+ // kernel instance
215
+ const auto kernel = morpho_attention_forward_kernel<HEAD_DIM_V_FWD, CUBE_M_FWD, BR_FWD, BC_FWD>;
216
+
217
+ // smem alloc check
218
+ static int fwd_cached = -1;
219
+ configure_kernel_smem(kernel, smem, fwd_cached, "forward_kernel");
220
+
221
+ // launch
222
+ dim3 grid(B * H, (N + BR_FWD - 1) / BR_FWD);
223
+ dim3 block(BLOCK_SIZE);
224
+
225
+ kernel<<<grid, block, smem, stream>>>(X, W_phi, gate_q, gate_k, W_V, out, lse, B, N, D, H, scale);
226
+
227
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
228
+ }
229
+
230
+ void attention_backward_kernel_launcher(const __half* grad_out, const __half* X, const __half* dX, const __half* d_se,
231
+ int B, int N, int D, cudaStream_t stream) {
232
+ TORCH_CHECK(B > 0 && N > 0 && D > 0, "Invalid dimensions");
233
+
234
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
235
+ }
@@ -0,0 +1,36 @@
1
+ #ifndef MORPHOTTENTION_ATTENTION_CUH
2
+ #define MORPHOTTENTION_ATTENTION_CUH
3
+
4
+ #include <cuda_runtime.h>
5
+
6
+ #include <cuda_fp16.h>
7
+
8
+ #ifdef __CUDACC__
9
+ #include <cuda/sm120/matmul.cuh>
10
+ #include <cuda/sm120/project.cuh>
11
+ #include <cuda/sm120/smem.cuh>
12
+ #include <cuda/utils/declarations.cuh>
13
+ #include <cuda/utils/smem.cuh>
14
+ #include <cuda/utils/utils.cuh>
15
+
16
+ #include <c10/cuda/CUDAException.h>
17
+ #include <c10/util/Exception.h>
18
+
19
+ namespace arch {
20
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 120)
21
+ namespace impl = sm120; // Consumer Blackwell
22
+ #else
23
+ namespace impl = sm120; // fallback
24
+ #endif
25
+ } // namespace arch
26
+ #endif // __CUDACC__
27
+
28
+ void attention_forward_kernel_launcher(const __half* X, const __half* W_phi, const __half* gate_q, const __half* gate_k,
29
+ const __half* W_V, __half* out, float* lse, const int B, const int N,
30
+ const int D, const int H, const int cube_m, const int head_dim_v,
31
+ const float scale, cudaStream_t stream);
32
+
33
+ void attention_backward_kernel_launcher(const __half* grad_out, const __half* X, const __half* dX, const __half* d_se,
34
+ int B, int N, int D, cudaStream_t stream);
35
+
36
+ #endif // MORPHOTTENTION_ATTENTION_CUH
@@ -0,0 +1,12 @@
1
+ #include "dispatch.h"
2
+
3
+ #include <torch/extension.h>
4
+
5
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6
+ m.doc() = "Morphottention CUDA attention kernels";
7
+
8
+ m.def("forward", &forward, "Attention forward dispatcher", py::arg("X"), py::arg("W_phi"), py::arg("gate_q"),
9
+ py::arg("gate_k"), py::arg("W_V"), py::arg("H"), py::arg("cube_m"), py::arg("scale"), py::arg("causal"));
10
+
11
+ m.def("backward", &backward, "Attention backward dispatcher", py::arg("grad_out"), py::arg("X"));
12
+ }
@@ -0,0 +1,13 @@
1
+ #include "dispatch.h"
2
+
3
+ std::vector<torch::Tensor> forward(const torch::Tensor& X, const torch::Tensor& W_phi, const torch::Tensor& gate_q,
4
+ const torch::Tensor& gate_k, const torch::Tensor& W_V, const int64_t H,
5
+ const int64_t cube_m, const double scale, const bool causal) {
6
+ return morpho_forward(X, W_phi, gate_q, gate_k, W_V, H, cube_m, scale, causal);
7
+ }
8
+
9
+ std::vector<torch::Tensor> backward(const torch::Tensor& grad_out, const torch::Tensor& X) {
10
+ TORCH_CHECK(grad_out.is_cuda() && X.is_cuda(), "Gradient output and X must be a CUDA tensors");
11
+
12
+ return morpho_backward(X, grad_out);
13
+ }
@@ -0,0 +1,22 @@
1
+ #ifndef MORPHOTTENTION_DISPATCH_H
2
+ #define MORPHOTTENTION_DISPATCH_H
3
+
4
+ #include <torch/extension.h>
5
+
6
+ // py-facing dispatchers
7
+ std::vector<torch::Tensor> forward(const torch::Tensor& X, const torch::Tensor& W_phi, const torch::Tensor& gate_q,
8
+ const torch::Tensor& gate_k, const torch::Tensor& W_V, int64_t H, int64_t cube_m,
9
+ double scale, bool causal);
10
+
11
+ std::vector<torch::Tensor> backward(const torch::Tensor& grad_out, const torch::Tensor& X);
12
+
13
+ // CUDA-facing dispatchers
14
+
15
+ std::vector<torch::Tensor> morpho_forward(const torch::Tensor& X, const torch::Tensor& W_phi,
16
+ const torch::Tensor& gate_q, const torch::Tensor& gate_k,
17
+ const torch::Tensor& W_V, int64_t H, int64_t cube_m, double scale,
18
+ bool causal);
19
+
20
+ std::vector<torch::Tensor> morpho_backward(const torch::Tensor& X, const torch::Tensor& grad_out);
21
+
22
+ #endif // MORPHOTTENTION_DISPATCH_H
@@ -0,0 +1,81 @@
1
+ #ifndef MORPHOTTENTION_CUBE_CUH
2
+ #define MORPHOTTENTION_CUBE_CUH
3
+
4
+ #include <cuda_runtime.h>
5
+
6
+ #include <math.h>
7
+
8
+ namespace morph {
9
+
10
+ enum class CubeProjection { Sigmoid, HardTanh01, Identity };
11
+
12
+ __device__ __forceinline__ float cube_sigmoid(const float x) {
13
+ return 1.0f / (1.0f + __expf(-x));
14
+ }
15
+ __device__ __forceinline__ float cube_hardtanh01(const float x) {
16
+ return fminf(1.0f, fmaxf(0.0f, 0.5f * (x + 1.0f)));
17
+ }
18
+ __device__ __forceinline__ float cube_project(const float x, CubeProjection p) {
19
+ switch (p) {
20
+ case CubeProjection::Sigmoid:
21
+ return cube_sigmoid(x);
22
+ case CubeProjection::HardTanh01:
23
+ return cube_hardtanh01(x);
24
+ default:
25
+ return x;
26
+ }
27
+ }
28
+
29
+ __device__ __forceinline__ float cube_sigmoid_grad(float x) {
30
+ const float s = cube_sigmoid(x);
31
+ return s * (1.0f - s);
32
+ }
33
+ __device__ __forceinline__ float cube_hardtanh01_grad(float x) {
34
+ return (x > -1.0f && x < 1.0f) ? 0.5f : 0.0f;
35
+ }
36
+ __device__ __forceinline__ float cube_project_grad(float x, CubeProjection p) {
37
+ switch (p) {
38
+ case CubeProjection::Sigmoid:
39
+ return cube_sigmoid_grad(x);
40
+ case CubeProjection::HardTanh01:
41
+ return cube_hardtanh01_grad(x);
42
+ default:
43
+ return 1.0f;
44
+ }
45
+ }
46
+
47
+ template <CubeProjection P>
48
+ struct Cube;
49
+ template <>
50
+ struct Cube<CubeProjection::Sigmoid> {
51
+ static __device__ __forceinline__ float project(const float x) {
52
+ return cube_sigmoid(x);
53
+ }
54
+ static __device__ __forceinline__ float grad(const float x) {
55
+ return cube_sigmoid_grad(x);
56
+ }
57
+ };
58
+
59
+ template <>
60
+ struct Cube<CubeProjection::HardTanh01> {
61
+ static __device__ __forceinline__ float project(const float x) {
62
+ return cube_hardtanh01(x);
63
+ }
64
+ static __device__ __forceinline__ float grad(const float x) {
65
+ return cube_hardtanh01_grad(x);
66
+ }
67
+ };
68
+
69
+ template <>
70
+ struct Cube<CubeProjection::Identity> {
71
+ static __device__ __forceinline__ float project(const float x) {
72
+ return x;
73
+ }
74
+ static __device__ __forceinline__ float grad(float) {
75
+ return 1.0f;
76
+ }
77
+ };
78
+
79
+ } // namespace morph
80
+
81
+ #endif // MORPHOTTENTION_CUBE_CUH