morphottention 0.1.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {morphottention-0.1.0 → morphottention-0.2.1}/CMakeLists.txt +5 -18
- morphottention-0.2.1/PKG-INFO +130 -0
- morphottention-0.2.1/README.md +93 -0
- morphottention-0.2.1/build.toml +37 -0
- morphottention-0.2.1/csrc/compat/registration.h +14 -0
- morphottention-0.2.1/csrc/cuda/attention/attention.cpp +118 -0
- morphottention-0.2.1/csrc/cuda/attention/attention.cu +576 -0
- {morphottention-0.1.0 → morphottention-0.2.1}/csrc/cuda/attention/attention.cuh +7 -3
- morphottention-0.2.1/csrc/cuda/binder.cpp +17 -0
- {morphottention-0.1.0 → morphottention-0.2.1}/csrc/cuda/dispatch.cpp +5 -2
- {morphottention-0.1.0 → morphottention-0.2.1}/csrc/cuda/dispatch.h +10 -3
- morphottention-0.2.1/csrc/cuda/sm120/matmul.cuh +116 -0
- morphottention-0.2.1/csrc/cuda/sm120/project.cuh +105 -0
- morphottention-0.2.1/flake.lock +117 -0
- morphottention-0.2.1/flake.nix +17 -0
- {morphottention-0.1.0 → morphottention-0.2.1}/pyproject.toml +8 -6
- morphottention-0.2.1/torch-ext/morphottention/_cmake_ops.py +34 -0
- {morphottention-0.1.0/src → morphottention-0.2.1/torch-ext}/morphottention/autograd.py +44 -5
- morphottention-0.1.0/PKG-INFO +0 -62
- morphottention-0.1.0/README.md +0 -27
- morphottention-0.1.0/csrc/cuda/attention/attention.cpp +0 -71
- morphottention-0.1.0/csrc/cuda/attention/attention.cu +0 -235
- morphottention-0.1.0/csrc/cuda/binder.cpp +0 -12
- morphottention-0.1.0/csrc/cuda/sm120/matmul.cuh +0 -58
- morphottention-0.1.0/csrc/cuda/sm120/project.cuh +0 -51
- morphottention-0.1.0/src/morphottention/_C.pyi +0 -14
- {morphottention-0.1.0 → morphottention-0.2.1}/csrc/cuda/morfology/cube.cuh +0 -0
- {morphottention-0.1.0 → morphottention-0.2.1}/csrc/cuda/morfology/soft_morph.cuh +0 -0
- {morphottention-0.1.0 → morphottention-0.2.1}/csrc/cuda/sm120/smem.cuh +0 -0
- {morphottention-0.1.0 → morphottention-0.2.1}/csrc/cuda/utils/declarations.cuh +0 -0
- {morphottention-0.1.0 → morphottention-0.2.1}/csrc/cuda/utils/reductions.cuh +0 -0
- {morphottention-0.1.0 → morphottention-0.2.1}/csrc/cuda/utils/smem.cuh +0 -0
- {morphottention-0.1.0 → morphottention-0.2.1}/csrc/cuda/utils/utils.cuh +0 -0
- {morphottention-0.1.0/src → morphottention-0.2.1/torch-ext}/morphottention/__init__.py +0 -0
- {morphottention-0.1.0/src → morphottention-0.2.1/torch-ext}/morphottention/py.typed +0 -0
|
@@ -41,21 +41,14 @@ list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PREFIX}")
|
|
|
41
41
|
find_package(Torch REQUIRED CONFIG)
|
|
42
42
|
|
|
43
43
|
|
|
44
|
-
if(WIN32)
|
|
45
|
-
set(_TORCH_PYTHON_LIB_NAME "torch_python.lib")
|
|
46
|
-
else()
|
|
47
|
-
set(_TORCH_PYTHON_LIB_NAME "libtorch_python.so")
|
|
48
|
-
endif()
|
|
49
|
-
|
|
50
44
|
execute_process(
|
|
51
|
-
COMMAND ${Python_EXECUTABLE} -c "import pathlib, torch; print(pathlib.Path(torch.__file__).resolve().parent / 'lib'
|
|
52
|
-
OUTPUT_VARIABLE
|
|
45
|
+
COMMAND ${Python_EXECUTABLE} -c "import pathlib, torch; print(pathlib.Path(torch.__file__).resolve().parent / 'lib')"
|
|
46
|
+
OUTPUT_VARIABLE TORCH_LIB_DIR
|
|
53
47
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
54
48
|
)
|
|
55
|
-
if(NOT EXISTS "${
|
|
56
|
-
message(FATAL_ERROR "Could not locate
|
|
49
|
+
if(NOT EXISTS "${TORCH_LIB_DIR}")
|
|
50
|
+
message(FATAL_ERROR "Could not locate torch lib dir at: ${TORCH_LIB_DIR}")
|
|
57
51
|
endif()
|
|
58
|
-
get_filename_component(TORCH_LIB_DIR "${TORCH_PYTHON_LIBRARY}" DIRECTORY)
|
|
59
52
|
|
|
60
53
|
find_package(CUDAToolkit REQUIRED)
|
|
61
54
|
|
|
@@ -85,25 +78,19 @@ endif()
|
|
|
85
78
|
|
|
86
79
|
target_include_directories(_C PRIVATE
|
|
87
80
|
${CMAKE_CURRENT_SOURCE_DIR}/csrc
|
|
81
|
+
${CMAKE_CURRENT_SOURCE_DIR}/csrc/compat
|
|
88
82
|
${TORCH_INCLUDE_DIRS}
|
|
89
83
|
${Python_INCLUDE_DIRS}
|
|
90
84
|
${CUDAToolkit_INCLUDE_DIRS}
|
|
91
85
|
)
|
|
92
86
|
|
|
93
|
-
target_compile_definitions(_C PRIVATE
|
|
94
|
-
TORCH_EXTENSION_NAME=_C
|
|
95
|
-
)
|
|
96
|
-
|
|
97
87
|
target_link_libraries(_C PRIVATE
|
|
98
88
|
${TORCH_LIBRARIES}
|
|
99
|
-
${TORCH_PYTHON_LIBRARY}
|
|
100
|
-
Python::Module
|
|
101
89
|
CUDA::cudart
|
|
102
90
|
CUDA::cublas
|
|
103
91
|
)
|
|
104
92
|
|
|
105
93
|
if(NOT WIN32)
|
|
106
|
-
target_link_options(_C PRIVATE "-Wl,--no-as-needed")
|
|
107
94
|
set_target_properties(_C PROPERTIES
|
|
108
95
|
BUILD_RPATH "${TORCH_LIB_DIR}"
|
|
109
96
|
INSTALL_RPATH "$ORIGIN/../torch/lib"
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: morphottention
|
|
3
|
+
Version: 0.2.1
|
|
4
|
+
Summary: Mathematical Morphology-based self-attention module for PyTorch (CUDA) using Flash-style kernel fusion.
|
|
5
|
+
Keywords: attention,cuda,pytorch,transformer,morphology,flash-attention,ViT
|
|
6
|
+
Author-Email: Vedran Hrabar <vedran.hrabar@outlook.com>
|
|
7
|
+
License: MIT
|
|
8
|
+
Classifier: Development Status :: 3 - Alpha
|
|
9
|
+
Classifier: Environment :: GPU
|
|
10
|
+
Classifier: Environment :: GPU :: NVIDIA CUDA
|
|
11
|
+
Classifier: Environment :: GPU :: NVIDIA CUDA :: 13
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Operating System :: POSIX :: Linux
|
|
16
|
+
Classifier: Operating System :: Microsoft :: Windows
|
|
17
|
+
Classifier: Programming Language :: C++
|
|
18
|
+
Classifier: Programming Language :: Python :: 3
|
|
19
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
22
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
23
|
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
24
|
+
Classifier: Topic :: Scientific/Engineering
|
|
25
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
26
|
+
Classifier: Topic :: Scientific/Engineering :: Image Recognition
|
|
27
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
28
|
+
Classifier: Topic :: Software Development :: Libraries
|
|
29
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
30
|
+
Classifier: Typing :: Typed
|
|
31
|
+
Project-URL: repository, https://github.com/vhrabar/morphottention
|
|
32
|
+
Project-URL: documentation, https://github.com/vhrabar/morphottention/wiki
|
|
33
|
+
Project-URL: Bug Tracker, https://github.com/vhrabar/morphottention/issues
|
|
34
|
+
Requires-Python: >=3.12
|
|
35
|
+
Requires-Dist: torch>=2.12
|
|
36
|
+
Description-Content-Type: text/markdown
|
|
37
|
+
|
|
38
|
+
# Morphottention
|
|
39
|
+
Mathematical Morphology-based self-attention module for PyTorch using Flash-style kernel fusion.
|
|
40
|
+
|
|
41
|
+
## Install
|
|
42
|
+
|
|
43
|
+
Prebuilt wheels are published for CPython 3.14 on Linux (x86_64, aarch64) and
|
|
44
|
+
Windows (x86_64). A working CUDA-enabled PyTorch (`torch >= 2.12`) must already
|
|
45
|
+
be installed in the environment.
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
pip install morphottention
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
## Usage
|
|
52
|
+
|
|
53
|
+
The package exposes an `nn.Module` (`MorphoAttention`), a functional entry point
|
|
54
|
+
(`morpho_attention`), and the raw autograd bridge (`MorphoAttentionFunction`).
|
|
55
|
+
All inputs must be CUDA tensors; the module defaults to `float16`.
|
|
56
|
+
|
|
57
|
+
### As an `nn.Module`
|
|
58
|
+
|
|
59
|
+
```python
|
|
60
|
+
import torch
|
|
61
|
+
from morphottention import MorphoAttention
|
|
62
|
+
|
|
63
|
+
attn = MorphoAttention(
|
|
64
|
+
dim=256, # model dimension D
|
|
65
|
+
num_heads=8, # number of attention heads H
|
|
66
|
+
cube_m=16, # hypercube width per head
|
|
67
|
+
scale=1.0, # softmax temperature
|
|
68
|
+
causal=False, # causal masking flag
|
|
69
|
+
device="cuda",
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
x = torch.randn(2, 128, 256, dtype=torch.float16, device="cuda") # (B, N, D)
|
|
73
|
+
out = attn(x) # (B, N, D)
|
|
74
|
+
out.sum().backward()
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
### Functional form
|
|
78
|
+
|
|
79
|
+
```python
|
|
80
|
+
from morphottention import morpho_attention
|
|
81
|
+
|
|
82
|
+
out = morpho_attention(
|
|
83
|
+
x,
|
|
84
|
+
W_phi,
|
|
85
|
+
gate_q,
|
|
86
|
+
gate_k,
|
|
87
|
+
W_V,
|
|
88
|
+
num_heads=8, cube_m=16, scale=1.0,
|
|
89
|
+
causal=False,
|
|
90
|
+
)
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
### Raw autograd bridge
|
|
94
|
+
|
|
95
|
+
```python
|
|
96
|
+
import torch
|
|
97
|
+
from morphottention import MorphoAttentionFunction
|
|
98
|
+
|
|
99
|
+
B, N, D, H, cube_m = 2, 128, 256, 8, 16
|
|
100
|
+
|
|
101
|
+
x = torch.randn(B, N, D, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
102
|
+
W_phi = torch.randn(D, H * cube_m, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
103
|
+
gate_q = torch.ones(H, cube_m, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
104
|
+
gate_k = torch.ones(H, cube_m, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
105
|
+
W_V = torch.randn(D, D, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
106
|
+
|
|
107
|
+
out = MorphoAttentionFunction.apply(
|
|
108
|
+
x, W_phi, gate_q, gate_k, W_V,
|
|
109
|
+
H, cube_m, 1.0, False, # num_heads, cube_m, scale, causal
|
|
110
|
+
) # (B, N, D)
|
|
111
|
+
out.sum().backward()
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
`W_phi` has shape `(D, H * cube_m)`, `W_V` has shape `(D, D)`, and `gate_q` /
|
|
115
|
+
`gate_k` each have shape `(H, cube_m)`.
|
|
116
|
+
|
|
117
|
+
## Building from source
|
|
118
|
+
|
|
119
|
+
Requires the CUDA 13.X toolkit (`nvcc`) and a matching `torch` build:
|
|
120
|
+
|
|
121
|
+
```bash
|
|
122
|
+
uv sync --package morphottention --no-dev --group build
|
|
123
|
+
uv build --package morphottention --wheel --no-build-isolation
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
## License
|
|
127
|
+
|
|
128
|
+
MIT
|
|
129
|
+
|
|
130
|
+
Copyright © 2026 Vedran Hrabar.
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# Morphottention
|
|
2
|
+
Mathematical Morphology-based self-attention module for PyTorch using Flash-style kernel fusion.
|
|
3
|
+
|
|
4
|
+
## Install
|
|
5
|
+
|
|
6
|
+
Prebuilt wheels are published for CPython 3.14 on Linux (x86_64, aarch64) and
|
|
7
|
+
Windows (x86_64). A working CUDA-enabled PyTorch (`torch >= 2.12`) must already
|
|
8
|
+
be installed in the environment.
|
|
9
|
+
|
|
10
|
+
```bash
|
|
11
|
+
pip install morphottention
|
|
12
|
+
```
|
|
13
|
+
|
|
14
|
+
## Usage
|
|
15
|
+
|
|
16
|
+
The package exposes an `nn.Module` (`MorphoAttention`), a functional entry point
|
|
17
|
+
(`morpho_attention`), and the raw autograd bridge (`MorphoAttentionFunction`).
|
|
18
|
+
All inputs must be CUDA tensors; the module defaults to `float16`.
|
|
19
|
+
|
|
20
|
+
### As an `nn.Module`
|
|
21
|
+
|
|
22
|
+
```python
|
|
23
|
+
import torch
|
|
24
|
+
from morphottention import MorphoAttention
|
|
25
|
+
|
|
26
|
+
attn = MorphoAttention(
|
|
27
|
+
dim=256, # model dimension D
|
|
28
|
+
num_heads=8, # number of attention heads H
|
|
29
|
+
cube_m=16, # hypercube width per head
|
|
30
|
+
scale=1.0, # softmax temperature
|
|
31
|
+
causal=False, # causal masking flag
|
|
32
|
+
device="cuda",
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
x = torch.randn(2, 128, 256, dtype=torch.float16, device="cuda") # (B, N, D)
|
|
36
|
+
out = attn(x) # (B, N, D)
|
|
37
|
+
out.sum().backward()
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
### Functional form
|
|
41
|
+
|
|
42
|
+
```python
|
|
43
|
+
from morphottention import morpho_attention
|
|
44
|
+
|
|
45
|
+
out = morpho_attention(
|
|
46
|
+
x,
|
|
47
|
+
W_phi,
|
|
48
|
+
gate_q,
|
|
49
|
+
gate_k,
|
|
50
|
+
W_V,
|
|
51
|
+
num_heads=8, cube_m=16, scale=1.0,
|
|
52
|
+
causal=False,
|
|
53
|
+
)
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
### Raw autograd bridge
|
|
57
|
+
|
|
58
|
+
```python
|
|
59
|
+
import torch
|
|
60
|
+
from morphottention import MorphoAttentionFunction
|
|
61
|
+
|
|
62
|
+
B, N, D, H, cube_m = 2, 128, 256, 8, 16
|
|
63
|
+
|
|
64
|
+
x = torch.randn(B, N, D, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
65
|
+
W_phi = torch.randn(D, H * cube_m, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
66
|
+
gate_q = torch.ones(H, cube_m, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
67
|
+
gate_k = torch.ones(H, cube_m, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
68
|
+
W_V = torch.randn(D, D, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
69
|
+
|
|
70
|
+
out = MorphoAttentionFunction.apply(
|
|
71
|
+
x, W_phi, gate_q, gate_k, W_V,
|
|
72
|
+
H, cube_m, 1.0, False, # num_heads, cube_m, scale, causal
|
|
73
|
+
) # (B, N, D)
|
|
74
|
+
out.sum().backward()
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
`W_phi` has shape `(D, H * cube_m)`, `W_V` has shape `(D, D)`, and `gate_q` /
|
|
78
|
+
`gate_k` each have shape `(H, cube_m)`.
|
|
79
|
+
|
|
80
|
+
## Building from source
|
|
81
|
+
|
|
82
|
+
Requires the CUDA 13.X toolkit (`nvcc`) and a matching `torch` build:
|
|
83
|
+
|
|
84
|
+
```bash
|
|
85
|
+
uv sync --package morphottention --no-dev --group build
|
|
86
|
+
uv build --package morphottention --wheel --no-build-isolation
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
## License
|
|
90
|
+
|
|
91
|
+
MIT
|
|
92
|
+
|
|
93
|
+
Copyright © 2026 Vedran Hrabar.
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
[general]
|
|
2
|
+
name = "morphottention"
|
|
3
|
+
backends = ["cuda"]
|
|
4
|
+
license = "MIT"
|
|
5
|
+
version = 0
|
|
6
|
+
|
|
7
|
+
[general.hub]
|
|
8
|
+
repo-id = "vhrabar/morphottention"
|
|
9
|
+
|
|
10
|
+
[torch]
|
|
11
|
+
include = ["csrc", "torch-ext"]
|
|
12
|
+
src = [
|
|
13
|
+
"csrc/cuda/binder.cpp",
|
|
14
|
+
"csrc/cuda/dispatch.cpp",
|
|
15
|
+
"csrc/cuda/dispatch.h",
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
[kernel.morpho]
|
|
19
|
+
backend = "cuda"
|
|
20
|
+
cuda-capabilities = ["8.9", "9.0", "10.0", "12.0", "12.0a"]
|
|
21
|
+
depends = ["torch"]
|
|
22
|
+
include = ["csrc"]
|
|
23
|
+
src = [
|
|
24
|
+
"csrc/cuda/attention/attention.cpp",
|
|
25
|
+
"csrc/cuda/attention/attention.cu",
|
|
26
|
+
"csrc/cuda/attention/attention.cuh",
|
|
27
|
+
"csrc/cuda/dispatch.h",
|
|
28
|
+
"csrc/cuda/morfology/cube.cuh",
|
|
29
|
+
"csrc/cuda/morfology/soft_morph.cuh",
|
|
30
|
+
"csrc/cuda/sm120/matmul.cuh",
|
|
31
|
+
"csrc/cuda/sm120/project.cuh",
|
|
32
|
+
"csrc/cuda/sm120/smem.cuh",
|
|
33
|
+
"csrc/cuda/utils/declarations.cuh",
|
|
34
|
+
"csrc/cuda/utils/reductions.cuh",
|
|
35
|
+
"csrc/cuda/utils/smem.cuh",
|
|
36
|
+
"csrc/cuda/utils/utils.cuh",
|
|
37
|
+
]
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
#ifndef MORPHOTTENTION_COMPAT_REGISTRATION_H
|
|
2
|
+
#define MORPHOTTENTION_COMPAT_REGISTRATION_H
|
|
3
|
+
|
|
4
|
+
// Compatibility shim for the local CMake / PyPI build.
|
|
5
|
+
|
|
6
|
+
#include <torch/library.h>
|
|
7
|
+
|
|
8
|
+
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
|
9
|
+
|
|
10
|
+
#define TORCH_EXTENSION_NAME morphottention
|
|
11
|
+
|
|
12
|
+
#define REGISTER_EXTENSION(NAME)
|
|
13
|
+
|
|
14
|
+
#endif // MORPHOTTENTION_COMPAT_REGISTRATION_H
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
#include "attention.cuh"
|
|
2
|
+
|
|
3
|
+
#include <cuda/dispatch.h>
|
|
4
|
+
|
|
5
|
+
#include <ATen/cuda/CUDAContext.h>
|
|
6
|
+
#include <c10/cuda/CUDAGuard.h>
|
|
7
|
+
#include <torch/torch.h>
|
|
8
|
+
|
|
9
|
+
auto check = [](const torch::Tensor& t, const char* name) {
|
|
10
|
+
TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
|
|
11
|
+
TORCH_CHECK(t.scalar_type() == torch::kHalf, name, " must be float16");
|
|
12
|
+
TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
|
|
13
|
+
};
|
|
14
|
+
|
|
15
|
+
std::vector<torch::Tensor> morpho_forward(const torch::Tensor& X, const torch::Tensor& W_phi,
|
|
16
|
+
const torch::Tensor& gate_q, const torch::Tensor& gate_k,
|
|
17
|
+
const torch::Tensor& W_V, const int64_t H, const int64_t cube_m,
|
|
18
|
+
const double scale, const bool causal) {
|
|
19
|
+
// checkers
|
|
20
|
+
check(X, "X");
|
|
21
|
+
check(W_phi, "W_phi");
|
|
22
|
+
check(gate_q, "gate_q");
|
|
23
|
+
check(gate_k, "gate_k");
|
|
24
|
+
check(W_V, "W_V");
|
|
25
|
+
|
|
26
|
+
// shape management
|
|
27
|
+
TORCH_CHECK(X.dim() == 3, "X must be [B, N, D]");
|
|
28
|
+
const int64_t B = X.size(0);
|
|
29
|
+
const int64_t N = X.size(1);
|
|
30
|
+
const int64_t D = X.size(2);
|
|
31
|
+
|
|
32
|
+
TORCH_CHECK(H > 0 && D % H == 0, "D must be divisible by H");
|
|
33
|
+
const int64_t head_dim_v = D / H;
|
|
34
|
+
|
|
35
|
+
TORCH_CHECK(W_phi.dim() == 2 && W_phi.size(0) == D && W_phi.size(1) == H * cube_m, "W_phi must be [D, H*cube_m]");
|
|
36
|
+
TORCH_CHECK(W_V.dim() == 2 && W_V.size(0) == D && W_V.size(1) == H * head_dim_v, "W_V must be [D, H*head_dim_v]");
|
|
37
|
+
TORCH_CHECK(gate_q.dim() == 2 && gate_q.size(0) == H && gate_q.size(1) == cube_m, "gate_q must be [H, cube_m]");
|
|
38
|
+
TORCH_CHECK(gate_k.dim() == 2 && gate_k.size(0) == H && gate_k.size(1) == cube_m, "gate_k must be [H, cube_m]");
|
|
39
|
+
|
|
40
|
+
auto out = torch::empty_like(X);
|
|
41
|
+
auto lse = torch::empty({B * H, N}, X.options().dtype(torch::kFloat32));
|
|
42
|
+
|
|
43
|
+
// launcher
|
|
44
|
+
const c10::cuda::CUDAStreamGuard guard(c10::cuda::getCurrentCUDAStream());
|
|
45
|
+
|
|
46
|
+
attention_forward_kernel_launcher(reinterpret_cast<const __half*>(X.data_ptr<at::Half>()),
|
|
47
|
+
reinterpret_cast<const __half*>(W_phi.data_ptr<at::Half>()),
|
|
48
|
+
reinterpret_cast<const __half*>(gate_q.data_ptr<at::Half>()),
|
|
49
|
+
reinterpret_cast<const __half*>(gate_k.data_ptr<at::Half>()),
|
|
50
|
+
reinterpret_cast<const __half*>(W_V.data_ptr<at::Half>()),
|
|
51
|
+
reinterpret_cast<__half*>(out.data_ptr<at::Half>()), lse.data_ptr<float>(),
|
|
52
|
+
static_cast<int>(B), static_cast<int>(N), static_cast<int>(D),
|
|
53
|
+
static_cast<int>(H), static_cast<int>(cube_m), static_cast<int>(head_dim_v),
|
|
54
|
+
static_cast<float>(scale), c10::cuda::getCurrentCUDAStream());
|
|
55
|
+
|
|
56
|
+
return {out, lse};
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
std::vector<torch::Tensor> morpho_backward(const torch::Tensor& grad_out, const torch::Tensor& X,
|
|
60
|
+
const torch::Tensor& W_phi, const torch::Tensor& gate_q,
|
|
61
|
+
const torch::Tensor& gate_k, const torch::Tensor& W_V,
|
|
62
|
+
const torch::Tensor& out, const torch::Tensor& lse, const int64_t H,
|
|
63
|
+
const int64_t cube_m, const double scale, const bool causal) {
|
|
64
|
+
// checkers
|
|
65
|
+
check(grad_out, "grad_out");
|
|
66
|
+
check(X, "X");
|
|
67
|
+
check(W_phi, "W_phi");
|
|
68
|
+
check(gate_q, "gate_q");
|
|
69
|
+
check(gate_k, "gate_k");
|
|
70
|
+
check(W_V, "W_V");
|
|
71
|
+
check(out, "out");
|
|
72
|
+
TORCH_CHECK(lse.is_cuda(), "lse must be a CUDA tensor");
|
|
73
|
+
TORCH_CHECK(lse.scalar_type() == torch::kFloat32, "lse must be float32");
|
|
74
|
+
TORCH_CHECK(lse.is_contiguous(), "lse must be contiguous");
|
|
75
|
+
|
|
76
|
+
// shape management
|
|
77
|
+
TORCH_CHECK(X.dim() == 3, "X must be [B, N, D]");
|
|
78
|
+
const int64_t B = X.size(0);
|
|
79
|
+
const int64_t N = X.size(1);
|
|
80
|
+
const int64_t D = X.size(2);
|
|
81
|
+
|
|
82
|
+
TORCH_CHECK(H > 0 && D % H == 0, "D must be divisible by H");
|
|
83
|
+
const int64_t head_dim_v = D / H;
|
|
84
|
+
|
|
85
|
+
TORCH_CHECK(grad_out.sizes() == X.sizes(), "grad_out must match X shape [B, N, D]");
|
|
86
|
+
TORCH_CHECK(out.sizes() == X.sizes(), "out must match X shape [B, N, D]");
|
|
87
|
+
TORCH_CHECK(W_phi.dim() == 2 && W_phi.size(0) == D && W_phi.size(1) == H * cube_m, "W_phi must be [D, H*cube_m]");
|
|
88
|
+
TORCH_CHECK(W_V.dim() == 2 && W_V.size(0) == D && W_V.size(1) == H * head_dim_v, "W_V must be [D, H*head_dim_v]");
|
|
89
|
+
TORCH_CHECK(gate_q.dim() == 2 && gate_q.size(0) == H && gate_q.size(1) == cube_m, "gate_q must be [H, cube_m]");
|
|
90
|
+
TORCH_CHECK(gate_k.dim() == 2 && gate_k.size(0) == H && gate_k.size(1) == cube_m, "gate_k must be [H, cube_m]");
|
|
91
|
+
TORCH_CHECK(lse.dim() == 2 && lse.size(0) == B * H && lse.size(1) == N, "lse must be [B*H, N]");
|
|
92
|
+
|
|
93
|
+
// grad outputs (zero-initialised: the kernel accumulates into these via atomics)
|
|
94
|
+
auto dX = torch::zeros_like(X);
|
|
95
|
+
auto dW_phi = torch::zeros_like(W_phi);
|
|
96
|
+
auto d_gate_q = torch::zeros_like(gate_q);
|
|
97
|
+
auto d_gate_k = torch::zeros_like(gate_k);
|
|
98
|
+
auto dW_V = torch::zeros_like(W_V);
|
|
99
|
+
|
|
100
|
+
// launcher
|
|
101
|
+
const c10::cuda::CUDAStreamGuard guard(c10::cuda::getCurrentCUDAStream());
|
|
102
|
+
|
|
103
|
+
attention_backward_kernel_launcher(
|
|
104
|
+
reinterpret_cast<const __half*>(grad_out.data_ptr<at::Half>()),
|
|
105
|
+
reinterpret_cast<const __half*>(X.data_ptr<at::Half>()),
|
|
106
|
+
reinterpret_cast<const __half*>(W_phi.data_ptr<at::Half>()),
|
|
107
|
+
reinterpret_cast<const __half*>(gate_q.data_ptr<at::Half>()),
|
|
108
|
+
reinterpret_cast<const __half*>(gate_k.data_ptr<at::Half>()),
|
|
109
|
+
reinterpret_cast<const __half*>(W_V.data_ptr<at::Half>()),
|
|
110
|
+
reinterpret_cast<const __half*>(out.data_ptr<at::Half>()), lse.data_ptr<float>(),
|
|
111
|
+
reinterpret_cast<__half*>(dX.data_ptr<at::Half>()), reinterpret_cast<__half*>(dW_phi.data_ptr<at::Half>()),
|
|
112
|
+
reinterpret_cast<__half*>(d_gate_q.data_ptr<at::Half>()),
|
|
113
|
+
reinterpret_cast<__half*>(d_gate_k.data_ptr<at::Half>()), reinterpret_cast<__half*>(dW_V.data_ptr<at::Half>()),
|
|
114
|
+
static_cast<int>(B), static_cast<int>(N), static_cast<int>(D), static_cast<int>(H), static_cast<int>(cube_m),
|
|
115
|
+
static_cast<int>(head_dim_v), static_cast<float>(scale), c10::cuda::getCurrentCUDAStream());
|
|
116
|
+
|
|
117
|
+
return {dX, dW_phi, d_gate_q, d_gate_k, dW_V};
|
|
118
|
+
}
|