ninetoothed 0.17.0__tar.gz → 0.18.0__tar.gz
This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/PKG-INFO +1 -1
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/pyproject.toml +1 -1
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/aot.py +61 -10
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/generation.py +12 -9
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/tensor.py +17 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_addmm.py +15 -3
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_aot.py +87 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_attention.py +53 -37
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/.gitattributes +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/.github/ISSUE_TEMPLATE/bug-report.yml +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/.github/ISSUE_TEMPLATE/feature-request.yml +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/.github/pull_request_template.md +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/.github/workflows/publish-to-pypi.yml +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/.github/workflows/pytest.yml +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/.github/workflows/ruff.yml +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/.github/workflows/sphinx.yml +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/.gitignore +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/LICENSE +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/README.md +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/Makefile +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/README.zh.md +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/make.bat +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/requirements.txt +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/source/_static/matmul-tiling.png +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/source/_static/ninetoothed-logo.png +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/source/_static/vecadd-tiling.png +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/source/basics.rst +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/source/code_generation.rst +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/source/conf.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/source/index.rst +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/source/installation.rst +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/source/python_api.rst +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/source/symbol.rst +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/source/tensor.rst +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/source/visualization.rst +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/docs/source/visualize.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/requirements.txt +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/__init__.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/cudaifier.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/dtype.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/jit.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/language.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/make.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/naming.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/symbol.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/torchifier.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/utils.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/visualization.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/__init__.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/skippers.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_add.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_conv2d.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_dropout.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_matmul.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_max_pool2d.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_naming.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_pow.py +0 -0
- {ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_softmax.py +0 -0
{ninetoothed-0.17.0 → ninetoothed-0.18.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ninetoothed
-Version: 0.17.0
+Version: 0.18.0
 Summary: A domain-specific language based on Triton but providing higher-level abstraction.
 Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
 Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
{ninetoothed-0.17.0 → ninetoothed-0.18.0}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "ninetoothed"
-version = "0.17.0"
+version = "0.18.0"
 authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
 description = "A domain-specific language based on Triton but providing higher-level abstraction."
 readme = "README.md"
{ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/aot.py

@@ -1,9 +1,11 @@
 import ast
 import pathlib
+import re
 import subprocess
 import tempfile
 import uuid
 
+import ninetoothed.naming as naming
 from ninetoothed.dtype import int64
 from ninetoothed.generation import CACHE_DIR, CodeGenerator
 from ninetoothed.tensor import Tensor
@@ -25,8 +27,10 @@ def aot(
 
 def _aot(func, caller, kernel_name, num_warps, num_stages):
     def _find_tensor_by_source_name(tensors, name):
+        name = naming.remove_prefixes(name)
+
        for tensor in tensors:
-            if tensor.source.name == name:
+            if naming.remove_prefixes(tensor.source.name) == name:
                return tensor
 
     _HEADER_PATH.parent.mkdir(exist_ok=True)
@@ -49,11 +53,15 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
     kernel_func = code_generator.kernel_func
     launch_func = code_generator.launch_func
 
+    param_strings = ["stream"]
     param_types = []
+    constexpr_param_indices = []
 
     for arg in kernel_func.args.args:
         param = arg.arg
 
+        param_strings.append(param)
+
         if match := Tensor.pointer_pattern().fullmatch(param):
             source_name = match.group(0).removesuffix("_pointer")
             tensor = _find_tensor_by_source_name(tensors, source_name)
@@ -64,9 +72,23 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
             param_types.append(int64)
         elif Tensor.stride_pattern().fullmatch(param):
             param_types.append(int64)
+        else:
+            source_name = param
+            tensor = _find_tensor_by_source_name(tensors, source_name)
+            dtype = tensor.source.dtype
+
+            if tensor.constexpr:
+                param_types.append(f"{tensor.value}")
+                constexpr_param_indices.append(len(param_types) - 1)
+            else:
+                param_types.append(dtype)
 
     signature = ", ".join(param_types)
 
+    for index in sorted(set(constexpr_param_indices), reverse=True):
+        param_strings.pop(index + 1)
+        param_types.pop(index)
+
     grid_extractor = _GridExtractor()
     launch_func = grid_extractor.visit(launch_func)
     grid_extractor.visit(code_generator.raw_grid)
@@ -76,21 +98,26 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
         source_file, kernel_name, signature, grid, num_warps, num_stages
     )
 
-    unparser = _Unparser()
+    c_source_file_name = f"{kernel_name}.{signature_hash}.c"
+    c_source_file = output_contents[c_source_file_name]
+
+    c_header_file_name = f"{kernel_name}.{signature_hash}.h"
+    c_header_file = output_contents[c_header_file_name]
+
+    pattern = rf"\({', '.join(rf'(.*) {param}' for param in param_strings)}\)"
+    c_param_type_strings = re.search(pattern, c_header_file).groups()
+
+    unparser = _Unparser(c_param_type_strings)
 
     launch_func_unparsed = unparser.unparse(launch_func)
     launch_func_unparsed = launch_func_unparsed.replace(
         func.__name__, f"{kernel_name}_{signature_hash}"
     )
 
-    c_source_file_name = f"{kernel_name}.{signature_hash}.c"
-    c_source_file = output_contents[c_source_file_name]
     c_source_file = f"{c_source_file}\n{launch_func_unparsed}\n"
     c_source_file = c_source_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
     output_contents[c_source_file_name] = c_source_file
 
-    c_header_file_name = f"{kernel_name}.{signature_hash}.h"
-    c_header_file = output_contents[c_header_file_name]
     c_header_file = f'{c_header_file}\n#ifdef __cplusplus\nextern "C" {unparser.header};\n#else\n{unparser.header};\n#endif\n'
     c_header_file = c_header_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
     output_contents[c_header_file_name] = c_header_file
@@ -120,6 +147,9 @@ _HEADER_PATH = CACHE_DIR / "ninetoothed.h"
 
 
 class _Unparser:
+    def __init__(self, param_types):
+        self._param_types = param_types
+
     def unparse(self, node):
         method_name = "_unparse_" + node.__class__.__name__
 
@@ -137,11 +167,29 @@ class _Unparser:
     def _unparse_Call(self, node):
         call = ast.Call(
             func=node.func,
-            args=[ast.Name(id="stream", ctx=ast.Load())] + node.args,
+            args=[ast.Name(id="stream", ctx=ast.Load())]
+            + [
+                arg
+                for arg in node.args
+                if not isinstance(arg, ast.Name) or not naming.is_constexpr(arg.id)
+            ],
             keywords=[],
         )
 
-        return f"return {self._generic_unparse(call)};"
+        unparsed = f"return {self._generic_unparse(call)};"
+
+        pattern = rf"\((stream), {', '.join(r'([^,]*)' for _ in range(len(self._param_types) - 1))}\)"
+        args = re.search(pattern, unparsed).groups()
+
+        for i, (arg, type) in enumerate(zip(args, self._param_types)):
+            if i != 0 and "." not in arg:
+                new_arg = f"*({type} *){arg}.data"
+            else:
+                new_arg = f"({type}){arg}"
+
+            unparsed = unparsed.replace(arg, new_arg)
+
+        return unparsed
 
     def _unparse_FunctionDef(self, node):
         params = ["NineToothedStream stream"]
@@ -153,13 +201,16 @@ class _Unparser:
         body_lines = []
 
         for stmt in node.body:
+            if isinstance(stmt, ast.Assign):
+                continue
+
             stmt_unparsed = self.unparse(stmt)
 
             if isinstance(stmt, ast.Expr):
                 stmt_unparsed = stmt_unparsed.strip()
 
-                if not stmt_unparsed.endswith(";"):
-                    stmt_unparsed += ";"
+            if not stmt_unparsed.endswith(";"):
+                stmt_unparsed += ";"
 
             body_lines.append(" " + stmt_unparsed)
 
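Taken together, the aot.py changes teach the AOT path about constexpr tensor parameters: their values are baked into the compiled signature as literal strings, and the corresponding arguments are then dropped from the generated C launch wrapper. A minimal sketch of that folding step, using made-up parameter names and type strings purely for illustration (they are not taken from the library):

# Sketch only: "x_pointer", "is_causal", and the type strings are hypothetical.
param_strings = ["stream", "x_pointer", "is_causal"]
param_types = ["int64", "1"]  # "1" is a constexpr tensor's value, not a dtype
constexpr_param_indices = [1]

signature = ", ".join(param_types)  # the literal value becomes part of the signature

for index in sorted(set(constexpr_param_indices), reverse=True):
    param_strings.pop(index + 1)  # +1 skips the leading "stream" entry
    param_types.pop(index)

assert signature == "int64, 1"
assert param_strings == ["stream", "x_pointer"]  # constexpr argument dropped from the wrapper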
{ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/generation.py

@@ -289,12 +289,12 @@ class CodeGenerator(ast.NodeTransformer):
         if isinstance(value, Tensor):
             attr = getattr(value, node.attr)
 
-            if node.attr == "dtype" and attr is None:
-                return Symbol(f"{value.source.pointer_string()}.type.element_ty").node
-
             if isinstance(attr, Tensor):
                 return attr
 
+            if node.attr == "dtype":
+                return Symbol(f"{value.source.pointer_string()}.type.element_ty").node
+
             return Symbol(attr).node
 
         self.generic_visit(node)
@@ -500,16 +500,19 @@ class CodeGenerator(ast.NodeTransformer):
             naming.remove_prefixes(param) for param in next_power_of_2_params
         ]
 
+        arg_names = [naming.remove_prefixes(arg.source.name) for arg in self._args]
+
+        arg_names += [
+            param
+            for param in non_next_power_of_2_constexpr_params_without_prefixes
+            if not Tensor.size_pattern().fullmatch(param) and param not in arg_names
+        ]
+
         launch = ast.FunctionDef(
             name=self.launch_func_name,
             args=ast.arguments(
                 posonlyargs=[],
-                args=[ast.arg(arg=naming.remove_prefixes(arg.source.name)) for arg in self._args]
-                + [
-                    ast.arg(arg=param)
-                    for param in non_next_power_of_2_constexpr_params_without_prefixes
-                    if not Tensor.size_pattern().fullmatch(param)
-                ],
+                args=[ast.arg(arg=name) for name in arg_names],
                 kwonlyargs=[],
                 defaults=[],
             ),
{ninetoothed-0.17.0 → ninetoothed-0.18.0}/src/ninetoothed/tensor.py

@@ -32,6 +32,8 @@ class Tensor:
         strides=None,
         other=None,
         shape_options=None,
+        constexpr=None,
+        value=None,
         name=None,
         source=None,
         source_dims=None,
@@ -74,6 +76,21 @@ class Tensor:
 
         self.other = other
 
+        if constexpr and self.ndim != 0:
+            raise ValueError(
+                "`constexpr` can only be set for zero-dimensional tensors."
+            )
+
+        self.constexpr = constexpr
+
+        if self.constexpr:
+            self.name = naming.make_constexpr(self.name)
+
+        if not constexpr and value is not None:
+            raise ValueError("`value` can only be set for constexpr tensors.")
+
+        self.value = value
+
         if source is not None:
             self.source = source
         else:
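With these checks, `constexpr=True` is only accepted on zero-dimensional tensors, and `value` is only accepted together with `constexpr`. A small sketch of the resulting constructor behavior (the variable names here are just for illustration):

from ninetoothed import Tensor

# Zero-dimensional constexpr tensor with a fixed value, as used for `is_causal` in the tests.
flag = Tensor(0, constexpr=True, value=1)

# Both of the following now fail the new validation in `Tensor.__init__`:
try:
    Tensor(2, constexpr=True)
except ValueError as error:
    print(error)  # `constexpr` can only be set for zero-dimensional tensors.

try:
    Tensor(0, value=1)
except ValueError as error:
    print(error)  # `value` can only be set for constexpr tensors.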
{ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_addmm.py

@@ -9,11 +9,23 @@ from ninetoothed import Tensor
 from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported
 
 
-def arrangement(input, mat1, mat2, beta, alpha, output):
-    _, _, input_arranged = matmul.arrangement(mat1, mat2, input)
+def arrangement(
+    input,
+    mat1,
+    mat2,
+    beta,
+    alpha,
+    output,
+    BLOCK_SIZE_M=matmul.BLOCK_SIZE_M,
+    BLOCK_SIZE_N=matmul.BLOCK_SIZE_N,
+    BLOCK_SIZE_K=matmul.BLOCK_SIZE_K,
+):
+    _, _, input_arranged = matmul.arrangement(
+        mat1, mat2, input, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K
+    )
 
     mat1_arranged, mat2_arranged, output_arranged = matmul.arrangement(
-        mat1, mat2, output
+        mat1, mat2, output, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K
     )
 
     return input_arranged, mat1_arranged, mat2_arranged, beta, alpha, output_arranged
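Because the block sizes are now keyword parameters that default to `matmul`'s symbolic block sizes, callers can pin concrete values with `functools.partial`, which is how the new AOT test below uses this arrangement. A usage sketch (the block-size values are the ones from tests/test_aot.py):

import functools

import tests.test_addmm as addmm

# Fix the block sizes for ahead-of-time compilation.
arrangement = functools.partial(
    addmm.arrangement, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64
)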
{ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_aot.py

@@ -7,6 +7,8 @@ import torch.nn.functional as F
 
 import ninetoothed
 import ninetoothed.generation
+import tests.test_addmm as addmm
+import tests.test_attention as attention
 import tests.test_conv2d as conv2d
 import tests.test_matmul as matmul
 from ninetoothed import Tensor
@@ -19,6 +21,91 @@ class TestCUDA:
     def setup_class(cls):
         torch.manual_seed(0)
 
+    def test_addmm(self):
+        arrangement = functools.partial(
+            addmm.arrangement, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64
+        )
+        application = addmm.application
+        tensors = tuple(
+            Tensor(ndim, dtype=ninetoothed.float16) for ndim in (2, 2, 2, 0, 0, 2)
+        )
+        caller = "cuda"
+        kernel_name = "addmm"
+        output_dir = ninetoothed.generation.CACHE_DIR
+
+        launch_func = _generate_launch_func(
+            arrangement,
+            application,
+            tensors,
+            caller=caller,
+            kernel_name=kernel_name,
+            output_dir=output_dir,
+        )
+
+        shape = (512, 512)
+        dtype = torch.float16
+        device = caller
+
+        input = torch.randn(shape, dtype=dtype, device=device)
+        mat1 = torch.randn(shape, dtype=dtype, device=device)
+        mat2 = torch.randn(shape, dtype=dtype, device=device)
+        beta = torch.randn((), dtype=dtype)
+        alpha = torch.randn((), dtype=dtype)
+        output = torch.empty(
+            (mat1.shape[0], mat2.shape[1]), dtype=mat1.dtype, device=mat1.device
+        )
+
+        _run_launch_func(launch_func, input, mat1, mat2, beta, alpha, output)
+
+        assert torch.allclose(
+            output, torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha), atol=0.075
+        )
+
+    def test_attention(self):
+        emb_dim = 64
+
+        arrangement = functools.partial(
+            attention.arrangement, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64
+        )
+        application = attention.application
+        query_, key_, value_, output_ = tuple(
+            Tensor(4, dtype=ninetoothed.float16) for _ in range(4)
+        )
+        for tensor in (query_, key_, value_, output_):
+            tensor.shape = tensor.shape[:-1] + (emb_dim,)
+        is_causal_ = Tensor(0, constexpr=True, value=1)
+        tensors = (query_, key_, value_, is_causal_, output_)
+        caller = "cuda"
+        kernel_name = "attention"
+        output_dir = ninetoothed.generation.CACHE_DIR
+
+        launch_func = _generate_launch_func(
+            arrangement,
+            application,
+            tensors,
+            caller=caller,
+            kernel_name=kernel_name,
+            output_dir=output_dir,
+        )
+
+        shape = (2, 4, 1024, emb_dim)
+        dtype = torch.float16
+        device = caller
+
+        query = torch.randn(shape, dtype=dtype, device=device)
+        key = torch.randn(shape, dtype=dtype, device=device)
+        value = torch.randn(shape, dtype=dtype, device=device)
+        is_causal = torch.tensor(True)
+        output = torch.empty(shape, dtype=dtype, device=device)
+
+        _run_launch_func(launch_func, query, key, value, is_causal, output)
+
+        assert torch.allclose(
+            output,
+            F.scaled_dot_product_attention(query, key, value, is_causal=True, scale=1),
+            atol=0.01,
+        )
+
     def test_matmul(self):
         arrangement = functools.partial(
             matmul.arrangement, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64
{ninetoothed-0.17.0 → ninetoothed-0.18.0}/tests/test_attention.py

@@ -7,11 +7,13 @@ import ninetoothed.language as ntl
 from ninetoothed import Tensor
 from tests.skippers import skip_if_cuda_not_available
 
+BLOCK_SIZE_M = ninetoothed.block_size(lower_bound=64, upper_bound=128)
+BLOCK_SIZE_N = ninetoothed.block_size(lower_bound=32, upper_bound=64)
 
-def arrangement(q, k, v, o):
-    BLOCK_SIZE_M = ninetoothed.block_size(lower_bound=64, upper_bound=128)
-    BLOCK_SIZE_N = ninetoothed.block_size(lower_bound=32, upper_bound=64)
 
+def arrangement(
+    q, k, v, is_causal, o, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N
+):
     def arrange_q_or_o(input):
         arranged = input.tile((1, 1, BLOCK_SIZE_M, -1))
         arranged.dtype = arranged.dtype.squeeze((0, 1))
@@ -31,10 +33,16 @@ def arrangement(q, k, v, o):
 
     q_arranged = arrange_q_or_o(q)
 
-    return q_arranged, arrange_k_or_v(k), arrange_k_or_v(v), arrange_q_or_o(o)
+    return (
+        q_arranged,
+        arrange_k_or_v(k),
+        arrange_k_or_v(v),
+        is_causal,
+        arrange_q_or_o(o),
+    )
 
 
-def application(q, k, v, o):
+def application(q, k, v, is_causal, o):
     q_loaded = (q * 1.44269504089).to(q.dtype)
 
     acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32)
@@ -45,6 +53,10 @@ def application(q, k, v, o):
         qk = ntl.dot(q_loaded, ntl.trans(k[i]))
         qk = ntl.where(k[i].offsets(-2) < k.source.shape[-2], qk, float("-inf"))
 
+        if is_causal:
+            mask = q.offsets(-2)[:, None] >= k[i].offsets(-2)[None, :]
+            qk = ntl.where(mask, qk, float("-inf"))
+
         m_ij = ntl.maximum(m_i, ntl.max(qk, 1))
         p = ntl.exp2(qk - m_ij[:, None])
         l_ij = ntl.sum(p, 1)
@@ -58,27 +70,28 @@ def application(q, k, v, o):
     o = acc  # noqa: F841
 
 
-def attention(q, k, v):
+def attention(q, k, v, is_causal=False):
     o = torch.empty_like(q, dtype=v.dtype)
 
-    attention_kernel = ninetoothed.make(
-        arrangement,
-        application,
-        (
-            Tensor(
-                4,
-                shape_options=(
-                    None,
-                    None,
-                    {"constexpr": True},
-                    {"constexpr": True, "upper_bound": 128},
-                ),
-            )
-            for _ in range(4)
-        ),
+    q_, k_, v_, o_ = (
+        Tensor(
+            4,
+            shape_options=(
+                None,
+                None,
+                {"constexpr": True},
+                {"constexpr": True, "upper_bound": 128},
+            ),
+        )
+        for _ in range(4)
     )
+    is_causal_ = Tensor(0, constexpr=True)
 
-    attention_kernel(q, k, v, o)
+    tensors = (q_, k_, v_, is_causal_, o_)
+
+    attention_kernel = ninetoothed.make(arrangement, application, tensors)
+
+    attention_kernel(q, k, v, is_causal, o)
 
     return o
 
@@ -87,6 +100,10 @@ def attention(q, k, v):
 class TestCUDA:
     shapes = ((2, 4, 1024, 64), (2, 4, 1, 64))
 
+    dtypes = (torch.float32, torch.float16)
+
+    is_causal_values = (False, True)
+
     @classmethod
     def setup_class(cls):
         torch.manual_seed(0)
@@ -96,23 +113,22 @@ class TestCUDA:
             for shape in cls.shapes
         }
 
+    @pytest.mark.parametrize("is_causal", is_causal_values)
+    @pytest.mark.parametrize("dtype", dtypes)
     @pytest.mark.parametrize("shape", shapes)
-    def test_fp32(self, shape):
-        q, k, v = (arg.to(torch.float32) for arg in type(self).args[shape])
+    def test(self, shape, dtype, is_causal):
+        q, k, v = (arg.to(dtype) for arg in type(self).args[shape])
 
-        assert torch.allclose(
-            attention(q, k, v),
-            F.scaled_dot_product_attention(q, k, v, scale=1),
-            atol=0.025,
-            rtol=0.025,
-        )
-
-    @pytest.mark.parametrize("shape", shapes)
-    def test_fp16(self, shape):
-        q, k, v = (arg.to(torch.float16) for arg in type(self).args[shape])
+        if dtype == torch.float32:
+            atol = 0.025
+            rtol = 0.025
+        elif dtype == torch.float16:
+            atol = 0.01
+            rtol = 0.01
 
         assert torch.allclose(
-            attention(q, k, v),
-            F.scaled_dot_product_attention(q, k, v, scale=1),
-            atol=0.01,
+            attention(q, k, v, is_causal=is_causal),
+            F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=1),
+            atol=atol,
+            rtol=rtol,
         )
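With `is_causal` passed through as a zero-dimensional constexpr tensor, a single kernel definition now covers both the causal and non-causal cases, and the parametrized test exercises both against PyTorch. A usage sketch based on the test's shapes, dtype, and tolerances (assumes a CUDA device):

import torch
import torch.nn.functional as F

from tests.test_attention import attention

shape = (2, 4, 1024, 64)
q, k, v = (torch.randn(shape, dtype=torch.float16, device="cuda") for _ in range(3))

output = attention(q, k, v, is_causal=True)
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=1)

assert torch.allclose(output, reference, atol=0.01, rtol=0.01)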