morphottention 0.2.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {morphottention-0.2.0 → morphottention-0.2.1}/CMakeLists.txt +5 -18
- {morphottention-0.2.0 → morphottention-0.2.1}/PKG-INFO +1 -1
- morphottention-0.2.1/build.toml +37 -0
- morphottention-0.2.1/csrc/compat/registration.h +14 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/attention/attention.cpp +1 -1
- morphottention-0.2.1/csrc/cuda/binder.cpp +17 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/dispatch.h +1 -1
- morphottention-0.2.1/flake.lock +117 -0
- morphottention-0.2.1/flake.nix +17 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/pyproject.toml +2 -2
- morphottention-0.2.1/torch-ext/morphottention/_cmake_ops.py +34 -0
- {morphottention-0.2.0/src → morphottention-0.2.1/torch-ext}/morphottention/autograd.py +25 -4
- morphottention-0.2.0/csrc/cuda/binder.cpp +0 -14
- morphottention-0.2.0/src/morphottention/_C.pyi +0 -27
- {morphottention-0.2.0 → morphottention-0.2.1}/README.md +0 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/attention/attention.cu +0 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/attention/attention.cuh +0 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/dispatch.cpp +0 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/morfology/cube.cuh +0 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/morfology/soft_morph.cuh +0 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/sm120/matmul.cuh +0 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/sm120/project.cuh +0 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/sm120/smem.cuh +0 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/utils/declarations.cuh +0 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/utils/reductions.cuh +0 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/utils/smem.cuh +0 -0
- {morphottention-0.2.0 → morphottention-0.2.1}/csrc/cuda/utils/utils.cuh +0 -0
- {morphottention-0.2.0/src → morphottention-0.2.1/torch-ext}/morphottention/__init__.py +0 -0
- {morphottention-0.2.0/src → morphottention-0.2.1/torch-ext}/morphottention/py.typed +0 -0
|
@@ -41,21 +41,14 @@ list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PREFIX}")
|
|
|
41
41
|
find_package(Torch REQUIRED CONFIG)
|
|
42
42
|
|
|
43
43
|
|
|
44
|
-
if(WIN32)
|
|
45
|
-
set(_TORCH_PYTHON_LIB_NAME "torch_python.lib")
|
|
46
|
-
else()
|
|
47
|
-
set(_TORCH_PYTHON_LIB_NAME "libtorch_python.so")
|
|
48
|
-
endif()
|
|
49
|
-
|
|
50
44
|
execute_process(
|
|
51
|
-
COMMAND ${Python_EXECUTABLE} -c "import pathlib, torch; print(pathlib.Path(torch.__file__).resolve().parent / 'lib' / '${_TORCH_PYTHON_LIB_NAME}')"
|
|
52
|
-
OUTPUT_VARIABLE TORCH_PYTHON_LIBRARY
|
|
45
|
+
COMMAND ${Python_EXECUTABLE} -c "import pathlib, torch; print(pathlib.Path(torch.__file__).resolve().parent / 'lib')"
|
|
46
|
+
OUTPUT_VARIABLE TORCH_LIB_DIR
|
|
53
47
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
|
54
48
|
)
|
|
55
|
-
if(NOT EXISTS "${TORCH_PYTHON_LIBRARY}")
|
|
56
|
-
message(FATAL_ERROR "Could not locate ${_TORCH_PYTHON_LIB_NAME} at: ${TORCH_PYTHON_LIBRARY}")
|
|
49
|
+
if(NOT EXISTS "${TORCH_LIB_DIR}")
|
|
50
|
+
message(FATAL_ERROR "Could not locate torch lib dir at: ${TORCH_LIB_DIR}")
|
|
57
51
|
endif()
|
|
58
|
-
get_filename_component(TORCH_LIB_DIR "${TORCH_PYTHON_LIBRARY}" DIRECTORY)
|
|
59
52
|
|
|
60
53
|
find_package(CUDAToolkit REQUIRED)
|
|
61
54
|
|
|
@@ -85,25 +78,19 @@ endif()
|
|
|
85
78
|
|
|
86
79
|
target_include_directories(_C PRIVATE
|
|
87
80
|
${CMAKE_CURRENT_SOURCE_DIR}/csrc
|
|
81
|
+
${CMAKE_CURRENT_SOURCE_DIR}/csrc/compat
|
|
88
82
|
${TORCH_INCLUDE_DIRS}
|
|
89
83
|
${Python_INCLUDE_DIRS}
|
|
90
84
|
${CUDAToolkit_INCLUDE_DIRS}
|
|
91
85
|
)
|
|
92
86
|
|
|
93
|
-
target_compile_definitions(_C PRIVATE
|
|
94
|
-
TORCH_EXTENSION_NAME=_C
|
|
95
|
-
)
|
|
96
|
-
|
|
97
87
|
target_link_libraries(_C PRIVATE
|
|
98
88
|
${TORCH_LIBRARIES}
|
|
99
|
-
${TORCH_PYTHON_LIBRARY}
|
|
100
|
-
Python::Module
|
|
101
89
|
CUDA::cudart
|
|
102
90
|
CUDA::cublas
|
|
103
91
|
)
|
|
104
92
|
|
|
105
93
|
if(NOT WIN32)
|
|
106
|
-
target_link_options(_C PRIVATE "-Wl,--no-as-needed")
|
|
107
94
|
set_target_properties(_C PROPERTIES
|
|
108
95
|
BUILD_RPATH "${TORCH_LIB_DIR}"
|
|
109
96
|
INSTALL_RPATH "$ORIGIN/../torch/lib"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: morphottention
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Summary: Mathematical Morphology-based self-attention module for PyTorch (CUDA) using Flash-style kernel fusion.
|
|
5
5
|
Keywords: attention,cuda,pytorch,transformer,morphology,flash-attention,ViT
|
|
6
6
|
Author-Email: Vedran Hrabar <vedran.hrabar@outlook.com>
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
[general]
|
|
2
|
+
name = "morphottention"
|
|
3
|
+
backends = ["cuda"]
|
|
4
|
+
license = "MIT"
|
|
5
|
+
version = 0
|
|
6
|
+
|
|
7
|
+
[general.hub]
|
|
8
|
+
repo-id = "vhrabar/morphottention"
|
|
9
|
+
|
|
10
|
+
[torch]
|
|
11
|
+
include = ["csrc", "torch-ext"]
|
|
12
|
+
src = [
|
|
13
|
+
"csrc/cuda/binder.cpp",
|
|
14
|
+
"csrc/cuda/dispatch.cpp",
|
|
15
|
+
"csrc/cuda/dispatch.h",
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
[kernel.morpho]
|
|
19
|
+
backend = "cuda"
|
|
20
|
+
cuda-capabilities = ["8.9", "9.0", "10.0", "12.0", "12.0a"]
|
|
21
|
+
depends = ["torch"]
|
|
22
|
+
include = ["csrc"]
|
|
23
|
+
src = [
|
|
24
|
+
"csrc/cuda/attention/attention.cpp",
|
|
25
|
+
"csrc/cuda/attention/attention.cu",
|
|
26
|
+
"csrc/cuda/attention/attention.cuh",
|
|
27
|
+
"csrc/cuda/dispatch.h",
|
|
28
|
+
"csrc/cuda/morfology/cube.cuh",
|
|
29
|
+
"csrc/cuda/morfology/soft_morph.cuh",
|
|
30
|
+
"csrc/cuda/sm120/matmul.cuh",
|
|
31
|
+
"csrc/cuda/sm120/project.cuh",
|
|
32
|
+
"csrc/cuda/sm120/smem.cuh",
|
|
33
|
+
"csrc/cuda/utils/declarations.cuh",
|
|
34
|
+
"csrc/cuda/utils/reductions.cuh",
|
|
35
|
+
"csrc/cuda/utils/smem.cuh",
|
|
36
|
+
"csrc/cuda/utils/utils.cuh",
|
|
37
|
+
]
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
#ifndef MORPHOTTENTION_COMPAT_REGISTRATION_H
|
|
2
|
+
#define MORPHOTTENTION_COMPAT_REGISTRATION_H
|
|
3
|
+
|
|
4
|
+
// Compatibility shim for the local CMake / PyPI build.
|
|
5
|
+
|
|
6
|
+
#include <torch/library.h>
|
|
7
|
+
|
|
8
|
+
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
|
9
|
+
|
|
10
|
+
#define TORCH_EXTENSION_NAME morphottention
|
|
11
|
+
|
|
12
|
+
#define REGISTER_EXTENSION(NAME)
|
|
13
|
+
|
|
14
|
+
#endif // MORPHOTTENTION_COMPAT_REGISTRATION_H
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
#include "dispatch.h"
|
|
2
|
+
#include "registration.h"
|
|
3
|
+
|
|
4
|
+
#include <torch/library.h>
|
|
5
|
+
|
|
6
|
+
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
|
7
|
+
m.def("forward(Tensor X, Tensor W_phi, Tensor gate_q, Tensor gate_k, Tensor W_V, "
|
|
8
|
+
"int H, int cube_m, float scale, bool causal) -> Tensor[]");
|
|
9
|
+
|
|
10
|
+
m.def("backward(Tensor grad_out, Tensor X, Tensor W_phi, Tensor gate_q, Tensor gate_k, "
|
|
11
|
+
"Tensor W_V, Tensor out, Tensor lse, int H, int cube_m, float scale, bool causal) -> Tensor[]");
|
|
12
|
+
|
|
13
|
+
m.impl("forward", torch::kCUDA, &forward);
|
|
14
|
+
m.impl("backward", torch::kCUDA, &backward);
|
|
15
|
+
}
|
|
16
|
+
|
|
17
|
+
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
#ifndef MORPHOTTENTION_DISPATCH_H
|
|
2
2
|
#define MORPHOTTENTION_DISPATCH_H
|
|
3
3
|
|
|
4
|
-
#include <torch/extension.h>
|
|
4
|
+
#include <torch/torch.h>
|
|
5
5
|
|
|
6
6
|
// py-facing dispatchers
|
|
7
7
|
std::vector<torch::Tensor> forward(const torch::Tensor& X, const torch::Tensor& W_phi, const torch::Tensor& gate_q,
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
{
|
|
2
|
+
"nodes": {
|
|
3
|
+
"flake-compat": {
|
|
4
|
+
"locked": {
|
|
5
|
+
"lastModified": 1767039857,
|
|
6
|
+
"narHash": "sha256-vNpUSpF5Nuw8xvDLj2KCwwksIbjua2LZCqhV1LNRDns=",
|
|
7
|
+
"owner": "edolstra",
|
|
8
|
+
"repo": "flake-compat",
|
|
9
|
+
"rev": "5edf11c44bc78a0d334f6334cdaf7d60d732daab",
|
|
10
|
+
"type": "github"
|
|
11
|
+
},
|
|
12
|
+
"original": {
|
|
13
|
+
"owner": "edolstra",
|
|
14
|
+
"repo": "flake-compat",
|
|
15
|
+
"type": "github"
|
|
16
|
+
}
|
|
17
|
+
},
|
|
18
|
+
"flake-utils": {
|
|
19
|
+
"inputs": {
|
|
20
|
+
"systems": "systems"
|
|
21
|
+
},
|
|
22
|
+
"locked": {
|
|
23
|
+
"lastModified": 1731533236,
|
|
24
|
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
|
25
|
+
"owner": "numtide",
|
|
26
|
+
"repo": "flake-utils",
|
|
27
|
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
|
28
|
+
"type": "github"
|
|
29
|
+
},
|
|
30
|
+
"original": {
|
|
31
|
+
"owner": "numtide",
|
|
32
|
+
"repo": "flake-utils",
|
|
33
|
+
"type": "github"
|
|
34
|
+
}
|
|
35
|
+
},
|
|
36
|
+
"kernel-builder": {
|
|
37
|
+
"inputs": {
|
|
38
|
+
"flake-compat": "flake-compat",
|
|
39
|
+
"flake-utils": "flake-utils",
|
|
40
|
+
"nixpkgs": "nixpkgs",
|
|
41
|
+
"rust-overlay": "rust-overlay"
|
|
42
|
+
},
|
|
43
|
+
"locked": {
|
|
44
|
+
"lastModified": 1783078249,
|
|
45
|
+
"narHash": "sha256-WjIYeFTWncfrj6w2t9PxYrCpqCQa2MIyQIaQncH7XvA=",
|
|
46
|
+
"owner": "huggingface",
|
|
47
|
+
"repo": "kernels",
|
|
48
|
+
"rev": "b9710edf6436d3949e50085938a3c49e626ee885",
|
|
49
|
+
"type": "github"
|
|
50
|
+
},
|
|
51
|
+
"original": {
|
|
52
|
+
"owner": "huggingface",
|
|
53
|
+
"repo": "kernels",
|
|
54
|
+
"type": "github"
|
|
55
|
+
}
|
|
56
|
+
},
|
|
57
|
+
"nixpkgs": {
|
|
58
|
+
"locked": {
|
|
59
|
+
"lastModified": 1776927958,
|
|
60
|
+
"narHash": "sha256-XOzEtft7E0P6TgQViLUOQeGHlEYiQ0+FY24BPEksj6s=",
|
|
61
|
+
"owner": "NixOS",
|
|
62
|
+
"repo": "nixpkgs",
|
|
63
|
+
"rev": "fec2c46cca5bf9767486a290abae51200b656d69",
|
|
64
|
+
"type": "github"
|
|
65
|
+
},
|
|
66
|
+
"original": {
|
|
67
|
+
"owner": "NixOS",
|
|
68
|
+
"ref": "nixos-unstable-small",
|
|
69
|
+
"repo": "nixpkgs",
|
|
70
|
+
"type": "github"
|
|
71
|
+
}
|
|
72
|
+
},
|
|
73
|
+
"root": {
|
|
74
|
+
"inputs": {
|
|
75
|
+
"kernel-builder": "kernel-builder"
|
|
76
|
+
}
|
|
77
|
+
},
|
|
78
|
+
"rust-overlay": {
|
|
79
|
+
"inputs": {
|
|
80
|
+
"nixpkgs": [
|
|
81
|
+
"kernel-builder",
|
|
82
|
+
"nixpkgs"
|
|
83
|
+
]
|
|
84
|
+
},
|
|
85
|
+
"locked": {
|
|
86
|
+
"lastModified": 1776914043,
|
|
87
|
+
"narHash": "sha256-qug5r56yW1qOsjSI99l3Jm15JNT9CvS2otkXNRNtrPI=",
|
|
88
|
+
"owner": "oxalica",
|
|
89
|
+
"repo": "rust-overlay",
|
|
90
|
+
"rev": "2d35c4358d7de3a0e606a6e8b27925d981c01cc3",
|
|
91
|
+
"type": "github"
|
|
92
|
+
},
|
|
93
|
+
"original": {
|
|
94
|
+
"owner": "oxalica",
|
|
95
|
+
"repo": "rust-overlay",
|
|
96
|
+
"type": "github"
|
|
97
|
+
}
|
|
98
|
+
},
|
|
99
|
+
"systems": {
|
|
100
|
+
"locked": {
|
|
101
|
+
"lastModified": 1681028828,
|
|
102
|
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
|
103
|
+
"owner": "nix-systems",
|
|
104
|
+
"repo": "default",
|
|
105
|
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
|
106
|
+
"type": "github"
|
|
107
|
+
},
|
|
108
|
+
"original": {
|
|
109
|
+
"owner": "nix-systems",
|
|
110
|
+
"repo": "default",
|
|
111
|
+
"type": "github"
|
|
112
|
+
}
|
|
113
|
+
}
|
|
114
|
+
},
|
|
115
|
+
"root": "root",
|
|
116
|
+
"version": 7
|
|
117
|
+
}
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
{
|
|
2
|
+
description = "Flake for the Morphottention CUDA kernel";
|
|
3
|
+
|
|
4
|
+
inputs = {
|
|
5
|
+
kernel-builder.url = "github:huggingface/kernels";
|
|
6
|
+
};
|
|
7
|
+
|
|
8
|
+
outputs =
|
|
9
|
+
{
|
|
10
|
+
self,
|
|
11
|
+
kernel-builder,
|
|
12
|
+
}:
|
|
13
|
+
kernel-builder.lib.genKernelFlakeOutputs {
|
|
14
|
+
inherit self;
|
|
15
|
+
path = ./.;
|
|
16
|
+
};
|
|
17
|
+
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "morphottention"
|
|
3
|
-
version = "0.2.0"
|
|
3
|
+
version = "0.2.1"
|
|
4
4
|
description = "Mathematical Morphology-based self-attention module for PyTorch (CUDA) using Flash-style kernel fusion."
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.12"
|
|
@@ -78,5 +78,5 @@ minimum-version = "build-system.requires"
|
|
|
78
78
|
cmake.version = ">=3.24"
|
|
79
79
|
build-dir = "build/{wheel_tag}"
|
|
80
80
|
|
|
81
|
-
wheel.packages = ["src/morphottention"]
|
|
81
|
+
wheel.packages = ["torch-ext/morphottention"]
|
|
82
82
|
editable.rebuild = true
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Op-namespace loader for the local CMake / PyPI build.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import glob
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
_NAMESPACE = "morphottention"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _load_c_extension() -> None:
|
|
16
|
+
pkg_dir = os.path.dirname(__file__)
|
|
17
|
+
for pattern in ("_C*.so", "_C*.pyd", "_C*.dll"):
|
|
18
|
+
matches = glob.glob(os.path.join(pkg_dir, pattern))
|
|
19
|
+
if matches:
|
|
20
|
+
torch.ops.load_library(matches[0]) # type: ignore[no-untyped-call]
|
|
21
|
+
return
|
|
22
|
+
raise ImportError(
|
|
23
|
+
f"Could not find the compiled morphottention '_C' extension in {pkg_dir}. "
|
|
24
|
+
"Reinstall the package so the CUDA kernels are built."
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
_load_c_extension()
|
|
29
|
+
|
|
30
|
+
ops = getattr(torch.ops, _NAMESPACE)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def add_op_namespace_prefix(op_name: str) -> str:
|
|
34
|
+
return f"{_NAMESPACE}::{op_name}"
|
|
@@ -7,7 +7,28 @@ from __future__ import annotations
|
|
|
7
7
|
import torch
|
|
8
8
|
from torch import nn
|
|
9
9
|
|
|
10
|
-
|
|
10
|
+
try:
|
|
11
|
+
from ._ops import add_op_namespace_prefix, ops # type: ignore[import-not-found]
|
|
12
|
+
except ImportError:
|
|
13
|
+
from ._cmake_ops import add_op_namespace_prefix, ops
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@torch.library.register_fake(add_op_namespace_prefix("forward")) # type: ignore[untyped-decorator]
|
|
17
|
+
def _morpho_forward_fake(
|
|
18
|
+
X: torch.Tensor,
|
|
19
|
+
W_phi: torch.Tensor,
|
|
20
|
+
gate_q: torch.Tensor,
|
|
21
|
+
gate_k: torch.Tensor,
|
|
22
|
+
W_V: torch.Tensor,
|
|
23
|
+
H: int,
|
|
24
|
+
cube_m: int,
|
|
25
|
+
scale: float,
|
|
26
|
+
causal: bool,
|
|
27
|
+
) -> list[torch.Tensor]:
|
|
28
|
+
B, N, _D = X.shape
|
|
29
|
+
out = torch.empty_like(X)
|
|
30
|
+
lse = X.new_empty((B * H, N), dtype=torch.float32)
|
|
31
|
+
return [out, lse]
|
|
11
32
|
|
|
12
33
|
|
|
13
34
|
class MorphoAttentionFunction(torch.autograd.Function):
|
|
@@ -32,14 +53,14 @@ class MorphoAttentionFunction(torch.autograd.Function):
|
|
|
32
53
|
raise ValueError("MorphoAttention expects a CUDA tensor")
|
|
33
54
|
|
|
34
55
|
x = x.contiguous()
|
|
35
|
-
out, lse = _C.forward(x, W_phi, gate_q, gate_k, W_V, H, cube_m, scale, causal)
|
|
56
|
+
out, lse = ops.forward(x, W_phi, gate_q, gate_k, W_V, H, cube_m, scale, causal)
|
|
36
57
|
|
|
37
58
|
ctx.save_for_backward(x, W_phi, gate_q, gate_k, W_V, out, lse)
|
|
38
59
|
ctx.H = H # type: ignore[attr-defined]
|
|
39
60
|
ctx.cube_m = cube_m # type: ignore[attr-defined]
|
|
40
61
|
ctx.scale = scale # type: ignore[attr-defined]
|
|
41
62
|
ctx.causal = causal # type: ignore[attr-defined]
|
|
42
|
-
return out
|
|
63
|
+
return out # type: ignore[no-any-return]
|
|
43
64
|
|
|
44
65
|
@staticmethod
|
|
45
66
|
def backward(
|
|
@@ -49,7 +70,7 @@ class MorphoAttentionFunction(torch.autograd.Function):
|
|
|
49
70
|
x, W_phi, gate_q, gate_k, W_V, out, lse = ctx.saved_tensors # type: ignore[attr-defined]
|
|
50
71
|
|
|
51
72
|
grad_out = grad_out.contiguous()
|
|
52
|
-
dX, dW_phi, d_gate_q, d_gate_k, dW_V = _C.backward(
|
|
73
|
+
dX, dW_phi, d_gate_q, d_gate_k, dW_V = ops.backward(
|
|
53
74
|
grad_out,
|
|
54
75
|
x,
|
|
55
76
|
W_phi,
|
|
@@ -1,14 +0,0 @@
|
|
|
1
|
-
#include "dispatch.h"
|
|
2
|
-
|
|
3
|
-
#include <torch/extension.h>
|
|
4
|
-
|
|
5
|
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
6
|
-
m.doc() = "Morphottention CUDA attention kernels";
|
|
7
|
-
|
|
8
|
-
m.def("forward", &forward, "Attention forward dispatcher", py::arg("X"), py::arg("W_phi"), py::arg("gate_q"),
|
|
9
|
-
py::arg("gate_k"), py::arg("W_V"), py::arg("H"), py::arg("cube_m"), py::arg("scale"), py::arg("causal"));
|
|
10
|
-
|
|
11
|
-
m.def("backward", &backward, "Attention backward dispatcher", py::arg("grad_out"), py::arg("X"), py::arg("W_phi"),
|
|
12
|
-
py::arg("gate_q"), py::arg("gate_k"), py::arg("W_V"), py::arg("out"), py::arg("lse"), py::arg("H"),
|
|
13
|
-
py::arg("cube_m"), py::arg("scale"), py::arg("causal"));
|
|
14
|
-
}
|
|
@@ -1,27 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
|
|
3
|
-
def forward(
|
|
4
|
-
X: torch.Tensor,
|
|
5
|
-
W_phi: torch.Tensor,
|
|
6
|
-
gate_q: torch.Tensor,
|
|
7
|
-
gate_k: torch.Tensor,
|
|
8
|
-
W_V: torch.Tensor,
|
|
9
|
-
H: int,
|
|
10
|
-
cube_m: int,
|
|
11
|
-
scale: float,
|
|
12
|
-
causal: bool,
|
|
13
|
-
) -> list[torch.Tensor]: ...
|
|
14
|
-
def backward(
|
|
15
|
-
grad_out: torch.Tensor,
|
|
16
|
-
X: torch.Tensor,
|
|
17
|
-
W_phi: torch.Tensor,
|
|
18
|
-
gate_q: torch.Tensor,
|
|
19
|
-
gate_k: torch.Tensor,
|
|
20
|
-
W_V: torch.Tensor,
|
|
21
|
-
out: torch.Tensor,
|
|
22
|
-
lse: torch.Tensor,
|
|
23
|
-
H: int,
|
|
24
|
-
cube_m: int,
|
|
25
|
-
scale: float,
|
|
26
|
-
causal: bool,
|
|
27
|
-
) -> list[torch.Tensor]: ...
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|