morphottention 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- morphottention-0.2.0/PKG-INFO +130 -0
- morphottention-0.2.0/README.md +93 -0
- morphottention-0.2.0/csrc/cuda/attention/attention.cpp +118 -0
- morphottention-0.2.0/csrc/cuda/attention/attention.cu +576 -0
- {morphottention-0.1.0 → morphottention-0.2.0}/csrc/cuda/attention/attention.cuh +7 -3
- {morphottention-0.1.0 → morphottention-0.2.0}/csrc/cuda/binder.cpp +3 -1
- {morphottention-0.1.0 → morphottention-0.2.0}/csrc/cuda/dispatch.cpp +5 -2
- {morphottention-0.1.0 → morphottention-0.2.0}/csrc/cuda/dispatch.h +9 -2
- morphottention-0.2.0/csrc/cuda/sm120/matmul.cuh +116 -0
- morphottention-0.2.0/csrc/cuda/sm120/project.cuh +105 -0
- {morphottention-0.1.0 → morphottention-0.2.0}/pyproject.toml +7 -5
- morphottention-0.2.0/src/morphottention/_C.pyi +27 -0
- {morphottention-0.1.0 → morphottention-0.2.0}/src/morphottention/autograd.py +20 -2
- morphottention-0.1.0/PKG-INFO +0 -62
- morphottention-0.1.0/README.md +0 -27
- morphottention-0.1.0/csrc/cuda/attention/attention.cpp +0 -71
- morphottention-0.1.0/csrc/cuda/attention/attention.cu +0 -235
- morphottention-0.1.0/csrc/cuda/sm120/matmul.cuh +0 -58
- morphottention-0.1.0/csrc/cuda/sm120/project.cuh +0 -51
- morphottention-0.1.0/src/morphottention/_C.pyi +0 -14
- {morphottention-0.1.0 → morphottention-0.2.0}/CMakeLists.txt +0 -0
- {morphottention-0.1.0 → morphottention-0.2.0}/csrc/cuda/morfology/cube.cuh +0 -0
- {morphottention-0.1.0 → morphottention-0.2.0}/csrc/cuda/morfology/soft_morph.cuh +0 -0
- {morphottention-0.1.0 → morphottention-0.2.0}/csrc/cuda/sm120/smem.cuh +0 -0
- {morphottention-0.1.0 → morphottention-0.2.0}/csrc/cuda/utils/declarations.cuh +0 -0
- {morphottention-0.1.0 → morphottention-0.2.0}/csrc/cuda/utils/reductions.cuh +0 -0
- {morphottention-0.1.0 → morphottention-0.2.0}/csrc/cuda/utils/smem.cuh +0 -0
- {morphottention-0.1.0 → morphottention-0.2.0}/csrc/cuda/utils/utils.cuh +0 -0
- {morphottention-0.1.0 → morphottention-0.2.0}/src/morphottention/__init__.py +0 -0
- {morphottention-0.1.0 → morphottention-0.2.0}/src/morphottention/py.typed +0 -0
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: morphottention
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Mathematical Morphology-based self-attention module for PyTorch (CUDA) using Flash-style kernel fusion.
|
|
5
|
+
Keywords: attention,cuda,pytorch,transformer,morphology,flash-attention,ViT
|
|
6
|
+
Author-Email: Vedran Hrabar <vedran.hrabar@outlook.com>
|
|
7
|
+
License: MIT
|
|
8
|
+
Classifier: Development Status :: 3 - Alpha
|
|
9
|
+
Classifier: Environment :: GPU
|
|
10
|
+
Classifier: Environment :: GPU :: NVIDIA CUDA
|
|
11
|
+
Classifier: Environment :: GPU :: NVIDIA CUDA :: 13
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Operating System :: POSIX :: Linux
|
|
16
|
+
Classifier: Operating System :: Microsoft :: Windows
|
|
17
|
+
Classifier: Programming Language :: C++
|
|
18
|
+
Classifier: Programming Language :: Python :: 3
|
|
19
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
22
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
23
|
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
24
|
+
Classifier: Topic :: Scientific/Engineering
|
|
25
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
26
|
+
Classifier: Topic :: Scientific/Engineering :: Image Recognition
|
|
27
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
28
|
+
Classifier: Topic :: Software Development :: Libraries
|
|
29
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
30
|
+
Classifier: Typing :: Typed
|
|
31
|
+
Project-URL: repository, https://github.com/vhrabar/morphottention
|
|
32
|
+
Project-URL: documentation, https://github.com/vhrabar/morphottention/wiki
|
|
33
|
+
Project-URL: Bug Tracker, https://github.com/vhrabar/morphottention/issues
|
|
34
|
+
Requires-Python: >=3.12
|
|
35
|
+
Requires-Dist: torch>=2.12
|
|
36
|
+
Description-Content-Type: text/markdown
|
|
37
|
+
|
|
38
|
+
# Morphottention
|
|
39
|
+
Mathematical Morphology-based self-attention module for PyTorch using Flash-style kernel fusion.
|
|
40
|
+
|
|
41
|
+
## Install
|
|
42
|
+
|
|
43
|
+
Prebuilt wheels are published for CPython 3.14 on Linux (x86_64, aarch64) and
|
|
44
|
+
Windows (x86_64). A working CUDA-enabled PyTorch (`torch >= 2.12`) must already
|
|
45
|
+
be installed in the environment.
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
pip install morphottention
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
## Usage
|
|
52
|
+
|
|
53
|
+
The package exposes an `nn.Module` (`MorphoAttention`), a functional entry point
|
|
54
|
+
(`morpho_attention`), and the raw autograd bridge (`MorphoAttentionFunction`).
|
|
55
|
+
All inputs must be CUDA tensors; the module defaults to `float16`.
|
|
56
|
+
|
|
57
|
+
### As an `nn.Module`
|
|
58
|
+
|
|
59
|
+
```python
|
|
60
|
+
import torch
|
|
61
|
+
from morphottention import MorphoAttention
|
|
62
|
+
|
|
63
|
+
attn = MorphoAttention(
|
|
64
|
+
dim=256, # model dimension D
|
|
65
|
+
num_heads=8, # number of attention heads H
|
|
66
|
+
cube_m=16, # hypercube width per head
|
|
67
|
+
scale=1.0, # softmax temperature
|
|
68
|
+
causal=False, # causal masking flag
|
|
69
|
+
device="cuda",
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
x = torch.randn(2, 128, 256, dtype=torch.float16, device="cuda") # (B, N, D)
|
|
73
|
+
out = attn(x) # (B, N, D)
|
|
74
|
+
out.sum().backward()
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
### Functional form
|
|
78
|
+
|
|
79
|
+
```python
|
|
80
|
+
from morphottention import morpho_attention
|
|
81
|
+
|
|
82
|
+
out = morpho_attention(
|
|
83
|
+
x,
|
|
84
|
+
W_phi,
|
|
85
|
+
gate_q,
|
|
86
|
+
gate_k,
|
|
87
|
+
W_V,
|
|
88
|
+
num_heads=8, cube_m=16, scale=1.0,
|
|
89
|
+
causal=False,
|
|
90
|
+
)
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
### Raw autograd bridge
|
|
94
|
+
|
|
95
|
+
```python
|
|
96
|
+
import torch
|
|
97
|
+
from morphottention import MorphoAttentionFunction
|
|
98
|
+
|
|
99
|
+
B, N, D, H, cube_m = 2, 128, 256, 8, 16
|
|
100
|
+
|
|
101
|
+
x = torch.randn(B, N, D, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
102
|
+
W_phi = torch.randn(D, H * cube_m, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
103
|
+
gate_q = torch.ones(H, cube_m, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
104
|
+
gate_k = torch.ones(H, cube_m, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
105
|
+
W_V = torch.randn(D, D, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
106
|
+
|
|
107
|
+
out = MorphoAttentionFunction.apply(
|
|
108
|
+
x, W_phi, gate_q, gate_k, W_V,
|
|
109
|
+
H, cube_m, 1.0, False, # num_heads, cube_m, scale, causal
|
|
110
|
+
) # (B, N, D)
|
|
111
|
+
out.sum().backward()
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
`W_phi` has shape `(D, H * cube_m)`, `W_V` has shape `(D, D)`, and `gate_q` /
|
|
115
|
+
`gate_k` each have shape `(H, cube_m)`.
|
|
116
|
+
|
|
117
|
+
## Building from source
|
|
118
|
+
|
|
119
|
+
Requires the CUDA 13.X toolkit (`nvcc`) and a matching `torch` build:
|
|
120
|
+
|
|
121
|
+
```bash
|
|
122
|
+
uv sync --package morphottention --no-dev --group build
|
|
123
|
+
uv build --package morphottention --wheel --no-build-isolation
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
## License
|
|
127
|
+
|
|
128
|
+
MIT
|
|
129
|
+
|
|
130
|
+
Copyright © 2026 Vedran Hrabar.
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# Morphottention
|
|
2
|
+
Mathematical Morphology-based self-attention module for PyTorch using Flash-style kernel fusion.
|
|
3
|
+
|
|
4
|
+
## Install
|
|
5
|
+
|
|
6
|
+
Prebuilt wheels are published for CPython 3.14 on Linux (x86_64, aarch64) and
|
|
7
|
+
Windows (x86_64). A working CUDA-enabled PyTorch (`torch >= 2.12`) must already
|
|
8
|
+
be installed in the environment.
|
|
9
|
+
|
|
10
|
+
```bash
|
|
11
|
+
pip install morphottention
|
|
12
|
+
```
|
|
13
|
+
|
|
14
|
+
## Usage
|
|
15
|
+
|
|
16
|
+
The package exposes an `nn.Module` (`MorphoAttention`), a functional entry point
|
|
17
|
+
(`morpho_attention`), and the raw autograd bridge (`MorphoAttentionFunction`).
|
|
18
|
+
All inputs must be CUDA tensors; the module defaults to `float16`.
|
|
19
|
+
|
|
20
|
+
### As an `nn.Module`
|
|
21
|
+
|
|
22
|
+
```python
|
|
23
|
+
import torch
|
|
24
|
+
from morphottention import MorphoAttention
|
|
25
|
+
|
|
26
|
+
attn = MorphoAttention(
|
|
27
|
+
dim=256, # model dimension D
|
|
28
|
+
num_heads=8, # number of attention heads H
|
|
29
|
+
cube_m=16, # hypercube width per head
|
|
30
|
+
scale=1.0, # softmax temperature
|
|
31
|
+
causal=False, # causal masking flag
|
|
32
|
+
device="cuda",
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
x = torch.randn(2, 128, 256, dtype=torch.float16, device="cuda") # (B, N, D)
|
|
36
|
+
out = attn(x) # (B, N, D)
|
|
37
|
+
out.sum().backward()
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
### Functional form
|
|
41
|
+
|
|
42
|
+
```python
|
|
43
|
+
from morphottention import morpho_attention
|
|
44
|
+
|
|
45
|
+
out = morpho_attention(
|
|
46
|
+
x,
|
|
47
|
+
W_phi,
|
|
48
|
+
gate_q,
|
|
49
|
+
gate_k,
|
|
50
|
+
W_V,
|
|
51
|
+
num_heads=8, cube_m=16, scale=1.0,
|
|
52
|
+
causal=False,
|
|
53
|
+
)
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
### Raw autograd bridge
|
|
57
|
+
|
|
58
|
+
```python
|
|
59
|
+
import torch
|
|
60
|
+
from morphottention import MorphoAttentionFunction
|
|
61
|
+
|
|
62
|
+
B, N, D, H, cube_m = 2, 128, 256, 8, 16
|
|
63
|
+
|
|
64
|
+
x = torch.randn(B, N, D, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
65
|
+
W_phi = torch.randn(D, H * cube_m, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
66
|
+
gate_q = torch.ones(H, cube_m, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
67
|
+
gate_k = torch.ones(H, cube_m, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
68
|
+
W_V = torch.randn(D, D, dtype=torch.float16, device="cuda", requires_grad=True)
|
|
69
|
+
|
|
70
|
+
out = MorphoAttentionFunction.apply(
|
|
71
|
+
x, W_phi, gate_q, gate_k, W_V,
|
|
72
|
+
H, cube_m, 1.0, False, # num_heads, cube_m, scale, causal
|
|
73
|
+
) # (B, N, D)
|
|
74
|
+
out.sum().backward()
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
`W_phi` has shape `(D, H * cube_m)`, `W_V` has shape `(D, D)`, and `gate_q` /
|
|
78
|
+
`gate_k` each have shape `(H, cube_m)`.
|
|
79
|
+
|
|
80
|
+
## Building from source
|
|
81
|
+
|
|
82
|
+
Requires the CUDA 13.X toolkit (`nvcc`) and a matching `torch` build:
|
|
83
|
+
|
|
84
|
+
```bash
|
|
85
|
+
uv sync --package morphottention --no-dev --group build
|
|
86
|
+
uv build --package morphottention --wheel --no-build-isolation
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
## License
|
|
90
|
+
|
|
91
|
+
MIT
|
|
92
|
+
|
|
93
|
+
Copyright © 2026 Vedran Hrabar.
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
#include "attention.cuh"
|
|
2
|
+
|
|
3
|
+
#include <cuda/dispatch.h>
|
|
4
|
+
|
|
5
|
+
#include <ATen/cuda/CUDAContext.h>
|
|
6
|
+
#include <c10/cuda/CUDAGuard.h>
|
|
7
|
+
#include <torch/extension.h>
|
|
8
|
+
|
|
9
|
+
auto check = [](const torch::Tensor& t, const char* name) {
|
|
10
|
+
TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
|
|
11
|
+
TORCH_CHECK(t.scalar_type() == torch::kHalf, name, " must be float16");
|
|
12
|
+
TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
|
|
13
|
+
};
|
|
14
|
+
|
|
15
|
+
/// Forward pass of morphological attention.
///
/// Validates all inputs, allocates the output and a float32 log-sum-exp buffer,
/// and launches `attention_forward_kernel_launcher` on X's device and its
/// current CUDA stream.
///
/// @param X       float16 CUDA tensor of shape [B, N, D].
/// @param W_phi   float16 CUDA tensor of shape [D, H * cube_m].
/// @param gate_q  float16 CUDA tensor of shape [H, cube_m].
/// @param gate_k  float16 CUDA tensor of shape [H, cube_m].
/// @param W_V     float16 CUDA tensor of shape [D, H * head_dim_v] with head_dim_v = D / H.
/// @param H       Number of heads; must be positive and divide D.
/// @param cube_m  Hypercube width per head; must be positive.
/// @param scale   Passed to the kernel as a float.
/// @param causal  Causal-masking flag. The launcher takes no causal argument, so
///                only `false` is accepted; `true` raises instead of silently
///                computing unmasked attention.
/// @return {out, lse}: `out` has X's shape and dtype; `lse` is float32 [B * H, N].
/// @throws c10::Error (raised by TORCH_CHECK) on any validation failure.
std::vector<torch::Tensor> morpho_forward(const torch::Tensor& X, const torch::Tensor& W_phi,
                                          const torch::Tensor& gate_q, const torch::Tensor& gate_k,
                                          const torch::Tensor& W_V, const int64_t H, const int64_t cube_m,
                                          const double scale, const bool causal) {
    // checkers
    check(X, "X");
    check(W_phi, "W_phi");
    check(gate_q, "gate_q");
    check(gate_k, "gate_k");
    check(W_V, "W_V");

    // The kernel dereferences every pointer on a single GPU, so all operands must share X's device.
    TORCH_CHECK(W_phi.device() == X.device(), "W_phi must be on the same device as X");
    TORCH_CHECK(gate_q.device() == X.device(), "gate_q must be on the same device as X");
    TORCH_CHECK(gate_k.device() == X.device(), "gate_k must be on the same device as X");
    TORCH_CHECK(W_V.device() == X.device(), "W_V must be on the same device as X");

    // The launcher below has no causal parameter; reject the flag rather than ignore it silently.
    TORCH_CHECK(!causal, "causal=True is not supported by the CUDA kernel");

    // shape management
    TORCH_CHECK(X.dim() == 3, "X must be [B, N, D]");
    const int64_t B = X.size(0);
    const int64_t N = X.size(1);
    const int64_t D = X.size(2);

    TORCH_CHECK(H > 0 && D % H == 0, "D must be divisible by H");
    TORCH_CHECK(cube_m > 0, "cube_m must be positive");
    const int64_t head_dim_v = D / H;

    TORCH_CHECK(W_phi.dim() == 2 && W_phi.size(0) == D && W_phi.size(1) == H * cube_m, "W_phi must be [D, H*cube_m]");
    TORCH_CHECK(W_V.dim() == 2 && W_V.size(0) == D && W_V.size(1) == H * head_dim_v, "W_V must be [D, H*head_dim_v]");
    TORCH_CHECK(gate_q.dim() == 2 && gate_q.size(0) == H && gate_q.size(1) == cube_m, "gate_q must be [H, cube_m]");
    TORCH_CHECK(gate_k.dim() == 2 && gate_k.size(0) == H && gate_k.size(1) == cube_m, "gate_k must be [H, cube_m]");

    auto out = torch::empty_like(X);
    auto lse = torch::empty({B * H, N}, X.options().dtype(torch::kFloat32));

    // Empty batch or sequence: both outputs are already empty, and a zero-sized launch would be invalid.
    if (B == 0 || N == 0) {
        return {out, lse};
    }

    // launcher: make X's device current so both the kernel and the stream target the right GPU.
    const c10::cuda::CUDAGuard device_guard(X.device());
    const auto stream = c10::cuda::getCurrentCUDAStream();

    attention_forward_kernel_launcher(reinterpret_cast<const __half*>(X.data_ptr<at::Half>()),
                                      reinterpret_cast<const __half*>(W_phi.data_ptr<at::Half>()),
                                      reinterpret_cast<const __half*>(gate_q.data_ptr<at::Half>()),
                                      reinterpret_cast<const __half*>(gate_k.data_ptr<at::Half>()),
                                      reinterpret_cast<const __half*>(W_V.data_ptr<at::Half>()),
                                      reinterpret_cast<__half*>(out.data_ptr<at::Half>()), lse.data_ptr<float>(),
                                      static_cast<int>(B), static_cast<int>(N), static_cast<int>(D),
                                      static_cast<int>(H), static_cast<int>(cube_m), static_cast<int>(head_dim_v),
                                      static_cast<float>(scale), stream);

    return {out, lse};
}
|
|
58
|
+
|
|
59
|
+
/// Backward pass of morphological attention.
///
/// Validates all inputs, allocates zero-initialised gradient buffers, and
/// launches `attention_backward_kernel_launcher` on X's device and its current
/// CUDA stream.
///
/// @param grad_out float16 CUDA tensor with the same shape as X ([B, N, D]).
/// @param X        float16 CUDA tensor of shape [B, N, D].
/// @param W_phi    float16 CUDA tensor of shape [D, H * cube_m].
/// @param gate_q   float16 CUDA tensor of shape [H, cube_m].
/// @param gate_k   float16 CUDA tensor of shape [H, cube_m].
/// @param W_V      float16 CUDA tensor of shape [D, H * head_dim_v] with head_dim_v = D / H.
/// @param out      float16 CUDA tensor with the same shape as X (presumably the
///                 forward output — confirm against the autograd caller).
/// @param lse      float32 CUDA tensor of shape [B * H, N] (presumably the
///                 forward's log-sum-exp buffer — confirm against the caller).
/// @param H        Number of heads; must be positive and divide D.
/// @param cube_m   Hypercube width per head; must be positive.
/// @param scale    Passed to the kernel as a float.
/// @param causal   Causal-masking flag. The launcher takes no causal argument, so
///                 only `false` is accepted; `true` raises instead of silently
///                 returning gradients of unmasked attention.
/// @return {dX, dW_phi, d_gate_q, d_gate_k, dW_V}, each shaped like its input.
/// @throws c10::Error (raised by TORCH_CHECK) on any validation failure.
std::vector<torch::Tensor> morpho_backward(const torch::Tensor& grad_out, const torch::Tensor& X,
                                           const torch::Tensor& W_phi, const torch::Tensor& gate_q,
                                           const torch::Tensor& gate_k, const torch::Tensor& W_V,
                                           const torch::Tensor& out, const torch::Tensor& lse, const int64_t H,
                                           const int64_t cube_m, const double scale, const bool causal) {
    // checkers
    check(grad_out, "grad_out");
    check(X, "X");
    check(W_phi, "W_phi");
    check(gate_q, "gate_q");
    check(gate_k, "gate_k");
    check(W_V, "W_V");
    check(out, "out");
    TORCH_CHECK(lse.defined(), "lse must be a defined tensor");
    TORCH_CHECK(lse.is_cuda(), "lse must be a CUDA tensor");
    TORCH_CHECK(lse.scalar_type() == torch::kFloat32, "lse must be float32");
    TORCH_CHECK(lse.is_contiguous(), "lse must be contiguous");

    // The kernel dereferences every pointer on a single GPU, so all operands must share X's device.
    TORCH_CHECK(grad_out.device() == X.device(), "grad_out must be on the same device as X");
    TORCH_CHECK(W_phi.device() == X.device(), "W_phi must be on the same device as X");
    TORCH_CHECK(gate_q.device() == X.device(), "gate_q must be on the same device as X");
    TORCH_CHECK(gate_k.device() == X.device(), "gate_k must be on the same device as X");
    TORCH_CHECK(W_V.device() == X.device(), "W_V must be on the same device as X");
    TORCH_CHECK(out.device() == X.device(), "out must be on the same device as X");
    TORCH_CHECK(lse.device() == X.device(), "lse must be on the same device as X");

    // The launcher below has no causal parameter; reject the flag rather than ignore it silently.
    TORCH_CHECK(!causal, "causal=True is not supported by the CUDA kernel");

    // shape management
    TORCH_CHECK(X.dim() == 3, "X must be [B, N, D]");
    const int64_t B = X.size(0);
    const int64_t N = X.size(1);
    const int64_t D = X.size(2);

    TORCH_CHECK(H > 0 && D % H == 0, "D must be divisible by H");
    TORCH_CHECK(cube_m > 0, "cube_m must be positive");
    const int64_t head_dim_v = D / H;

    TORCH_CHECK(grad_out.sizes() == X.sizes(), "grad_out must match X shape [B, N, D]");
    TORCH_CHECK(out.sizes() == X.sizes(), "out must match X shape [B, N, D]");
    TORCH_CHECK(W_phi.dim() == 2 && W_phi.size(0) == D && W_phi.size(1) == H * cube_m, "W_phi must be [D, H*cube_m]");
    TORCH_CHECK(W_V.dim() == 2 && W_V.size(0) == D && W_V.size(1) == H * head_dim_v, "W_V must be [D, H*head_dim_v]");
    TORCH_CHECK(gate_q.dim() == 2 && gate_q.size(0) == H && gate_q.size(1) == cube_m, "gate_q must be [H, cube_m]");
    TORCH_CHECK(gate_k.dim() == 2 && gate_k.size(0) == H && gate_k.size(1) == cube_m, "gate_k must be [H, cube_m]");
    TORCH_CHECK(lse.dim() == 2 && lse.size(0) == B * H && lse.size(1) == N, "lse must be [B*H, N]");

    // grad outputs (zero-initialised: the kernel accumulates into these via atomics)
    auto dX = torch::zeros_like(X);
    auto dW_phi = torch::zeros_like(W_phi);
    auto d_gate_q = torch::zeros_like(gate_q);
    auto d_gate_k = torch::zeros_like(gate_k);
    auto dW_V = torch::zeros_like(W_V);

    // Empty batch or sequence: no token contributes, so the zero gradients are final
    // and a zero-sized launch would be invalid.
    if (B == 0 || N == 0) {
        return {dX, dW_phi, d_gate_q, d_gate_k, dW_V};
    }

    // launcher: make X's device current so both the kernel and the stream target the right GPU.
    const c10::cuda::CUDAGuard device_guard(X.device());
    const auto stream = c10::cuda::getCurrentCUDAStream();

    attention_backward_kernel_launcher(
        reinterpret_cast<const __half*>(grad_out.data_ptr<at::Half>()),
        reinterpret_cast<const __half*>(X.data_ptr<at::Half>()),
        reinterpret_cast<const __half*>(W_phi.data_ptr<at::Half>()),
        reinterpret_cast<const __half*>(gate_q.data_ptr<at::Half>()),
        reinterpret_cast<const __half*>(gate_k.data_ptr<at::Half>()),
        reinterpret_cast<const __half*>(W_V.data_ptr<at::Half>()),
        reinterpret_cast<const __half*>(out.data_ptr<at::Half>()), lse.data_ptr<float>(),
        reinterpret_cast<__half*>(dX.data_ptr<at::Half>()), reinterpret_cast<__half*>(dW_phi.data_ptr<at::Half>()),
        reinterpret_cast<__half*>(d_gate_q.data_ptr<at::Half>()),
        reinterpret_cast<__half*>(d_gate_k.data_ptr<at::Half>()), reinterpret_cast<__half*>(dW_V.data_ptr<at::Half>()),
        static_cast<int>(B), static_cast<int>(N), static_cast<int>(D), static_cast<int>(H), static_cast<int>(cube_m),
        static_cast<int>(head_dim_v), static_cast<float>(scale), stream);

    return {dX, dW_phi, d_gate_q, d_gate_k, dW_V};
}
|