ninetoothed 0.14.0__tar.gz → 0.15.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
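In short: 0.15.1 splits the old src/ninetoothed/jit.py into a reusable code generator (generation.py) and a thin JIT wrapper (jit.py), adds an ahead-of-time compilation path (aot.py, cudaifier.py, dtype.py) driven by the new make() entry point (make.py), re-exports the dtype names from the package root, and adds GitHub issue templates.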
- ninetoothed-0.15.1/.github/ISSUE_TEMPLATE/bug-report.yml +55 -0
- ninetoothed-0.15.1/.github/ISSUE_TEMPLATE/feature-request.yml +13 -0
- ninetoothed-0.15.1/.github/pull_request_template.md +5 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/PKG-INFO +1 -1
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/pyproject.toml +1 -1
- ninetoothed-0.15.1/src/ninetoothed/__init__.py +35 -0
- ninetoothed-0.15.1/src/ninetoothed/aot.py +217 -0
- ninetoothed-0.15.1/src/ninetoothed/cudaifier.py +36 -0
- ninetoothed-0.15.1/src/ninetoothed/dtype.py +13 -0
- ninetoothed-0.14.0/src/ninetoothed/jit.py → ninetoothed-0.15.1/src/ninetoothed/generation.py +83 -116
- ninetoothed-0.15.1/src/ninetoothed/jit.py +77 -0
- ninetoothed-0.15.1/src/ninetoothed/make.py +45 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/src/ninetoothed/tensor.py +1 -1
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/src/ninetoothed/visualization.py +10 -4
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_addmm.py +6 -7
- ninetoothed-0.15.1/tests/test_aot.py +153 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_conv2d.py +16 -2
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_matmul.py +13 -6
- ninetoothed-0.14.0/src/ninetoothed/__init__.py +0 -5
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/.gitattributes +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/.github/workflows/publish-to-pypi.yml +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/.github/workflows/pytest.yml +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/.github/workflows/ruff.yml +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/.github/workflows/sphinx.yml +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/.gitignore +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/LICENSE +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/README.md +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/Makefile +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/README.zh.md +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/make.bat +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/requirements.txt +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/_static/matmul-tiling.png +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/_static/ninetoothed-logo.png +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/_static/vecadd-tiling.png +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/code_generation.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/conf.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/index.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/installation.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/python_api.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/symbol.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/tensor.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/visualization.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/requirements.txt +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/src/ninetoothed/language.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/src/ninetoothed/naming.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/src/ninetoothed/symbol.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/src/ninetoothed/torchifier.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/__init__.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/skippers.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_add.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_attention.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_max_pool2d.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_naming.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_softmax.py +0 -0
ninetoothed-0.15.1/.github/ISSUE_TEMPLATE/bug-report.yml
ADDED
@@ -0,0 +1,55 @@
+name: 🐛 Bug report
+description: Something isn't working as expected 🤔.
+labels: ["bug"]
+
+body:
+  - type: markdown
+    attributes:
+      value: Thanks for taking the time to fill out this bug report!
+
+  - type: checkboxes
+    attributes:
+      label: Is there an existing issue for this?
+      description: >
+        Please search to see if an issue already exists
+        for the bug you encountered.
+      options:
+        - label: I have searched the existing issues.
+          required: true
+
+  - type: textarea
+    attributes:
+      label: "Describe the bug:"
+      description: A clear and concise description of what the bug is.
+    validations:
+      required: false
+
+  - type: textarea
+    attributes:
+      label: "To reproduce:"
+      description: >
+        Steps to reproduce the behavior.
+        If applicable, provide a small, self-contained piece of code
+        that can be run directly to reproduce the issue.
+    validations:
+      required: false
+
+  - type: textarea
+    attributes:
+      label: "Expected behavior:"
+      description: >
+        A clear and concise description of what you expected to happen.
+    validations:
+      required: false
+
+  - type: textarea
+    attributes:
+      label: "Environment details:"
+      description: >
+        Please include your NineToothed version, operating system,
+        hardware platform, and any relevant information.
+        If you are using PyTorch, please run
+        `python -m torch.utils.collect_env` to gather
+        environment information.
+    validations:
+      required: false
ninetoothed-0.15.1/.github/ISSUE_TEMPLATE/feature-request.yml
ADDED
@@ -0,0 +1,13 @@
+name: 🚀 Feature request
+description: I have a suggestion 🙂!
+
+body:
+  - type: textarea
+    attributes:
+      label: "Description & motivation:"
+      description: >
+        Please describe the feature that you would like to see and
+        explain the problem it would solve or
+        the benefit it would provide.
+    validations:
+      required: true
{ninetoothed-0.14.0 → ninetoothed-0.15.1}/PKG-INFO
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ninetoothed
-Version: 0.14.0
+Version: 0.15.1
 Summary: A domain-specific language based on Triton but providing higher-level abstraction.
 Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
 Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
{ninetoothed-0.14.0 → ninetoothed-0.15.1}/pyproject.toml
CHANGED
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "ninetoothed"
-version = "0.14.0"
+version = "0.15.1"
 authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
 description = "A domain-specific language based on Triton but providing higher-level abstraction."
 readme = "README.md"
ninetoothed-0.15.1/src/ninetoothed/__init__.py
ADDED
@@ -0,0 +1,35 @@
+from ninetoothed.dtype import (
+    float16,
+    float32,
+    float64,
+    int8,
+    int16,
+    int32,
+    int64,
+    uint8,
+    uint16,
+    uint32,
+    uint64,
+)
+from ninetoothed.jit import jit
+from ninetoothed.make import make
+from ninetoothed.symbol import Symbol
+from ninetoothed.tensor import Tensor
+
+__all__ = [
+    "Symbol",
+    "Tensor",
+    "float16",
+    "float32",
+    "float64",
+    "int8",
+    "int16",
+    "int32",
+    "int64",
+    "jit",
+    "make",
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+]
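The package root now re-exports the dtype names alongside jit and make, so callers no longer need to reach into submodules. A minimal sketch of the new surface, mirroring how the tests below construct their tensors:

import ninetoothed
from ninetoothed import Tensor

# dtype names such as ninetoothed.float16 can be passed straight to Tensor,
# as tests/test_aot.py does for its matmul inputs.
tensors = tuple(Tensor(2, dtype=ninetoothed.float16) for _ in range(3))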
ninetoothed-0.15.1/src/ninetoothed/aot.py
ADDED
@@ -0,0 +1,217 @@
+import ast
+import pathlib
+import subprocess
+import tempfile
+import uuid
+
+from ninetoothed.dtype import int64
+from ninetoothed.generation import CACHE_DIR, CodeGenerator
+from ninetoothed.tensor import Tensor
+
+
+def aot(
+    func, caller="cuda", kernel_name=None, output_dir=None, num_warps=4, num_stages=3
+):
+    output_dir = pathlib.Path(output_dir)
+
+    output_contents = _aot(func, caller, kernel_name, num_warps, num_stages)
+
+    for output_name, output_content in output_contents.items():
+        output_path = output_dir / f"{kernel_name}{output_name[-2:]}"
+
+        with open(output_path, "w") as f:
+            f.write(output_content)
+
+
+def _aot(func, caller, kernel_name, num_warps, num_stages):
+    def _find_tensor_by_source_name(tensors, name):
+        for tensor in tensors:
+            if tensor.source.name == name:
+                return tensor
+
+    _HEADER_PATH.parent.mkdir(exist_ok=True)
+
+    if not _HEADER_PATH.exists():
+        _HEADER_PATH.write_text(_HEADER_CONTENT)
+
+    code_generator = CodeGenerator()
+    source_file = code_generator(
+        func, caller=caller, kernel_name=kernel_name, prettify=False
+    )
+
+    tensors = code_generator.tensors
+    kernel_func = code_generator.kernel_func
+    launch_func = code_generator.launch_func
+
+    param_types = []
+
+    for arg in kernel_func.args.args:
+        param = arg.arg
+
+        if match := Tensor.pointer_pattern().fullmatch(param):
+            source_name = match.group(0).removesuffix("_pointer")
+            tensor = _find_tensor_by_source_name(tensors, source_name)
+            dtype = tensor.source.dtype
+
+            param_types.append(f"*{dtype}")
+        elif Tensor.size_pattern().fullmatch(param):
+            param_types.append(int64)
+        elif Tensor.stride_pattern().fullmatch(param):
+            param_types.append(int64)
+
+    signature = ", ".join(param_types)
+
+    grid_extractor = _GridExtractor()
+    launch_func = grid_extractor.visit(launch_func)
+    grid_extractor.visit(code_generator.raw_grid)
+    grid = f"{ast.unparse(grid_extractor.grid[0])}, 1, 1"
+
+    signature_hash, output_contents = _compile(
+        source_file, kernel_name, signature, grid, num_warps, num_stages
+    )
+
+    unparser = _Unparser()
+
+    launch_func_unparsed = unparser.unparse(launch_func)
+    launch_func_unparsed = launch_func_unparsed.replace(
+        func.__name__, f"{kernel_name}_{signature_hash}"
+    )
+
+    c_source_file_name = f"{kernel_name}.{signature_hash}.c"
+    c_source_file = output_contents[c_source_file_name]
+    c_source_file = f"{c_source_file}\n{launch_func_unparsed}\n"
+    c_source_file = c_source_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
+    output_contents[c_source_file_name] = c_source_file
+
+    c_header_file_name = f"{kernel_name}.{signature_hash}.h"
+    c_header_file = output_contents[c_header_file_name]
+    c_header_file = f"{c_header_file}\n{unparser.header};\n"
+    c_header_file = c_header_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
+    output_contents[c_header_file_name] = c_header_file
+
+    return output_contents
+
+
+_HEADER_CONTENT = """#include <stdint.h>
+
+typedef struct {
+    uintptr_t data;
+    uint64_t *shape;
+    int64_t *strides;
+} NineToothedTensor;
+"""
+
+_HEADER_PATH = CACHE_DIR / "ninetoothed.h"
+
+
+class _Unparser:
+    def unparse(self, node):
+        method_name = "_unparse_" + node.__class__.__name__
+
+        if hasattr(self, method_name):
+            return getattr(self, method_name)(node)
+
+        return self._generic_unparse(node)
+
+    def _generic_unparse(self, node):
+        return ast.unparse(node)
+
+    def _unparse_Expr(self, node):
+        return self.unparse(node.value)
+
+    def _unparse_Call(self, node):
+        call = ast.Call(
+            func=node.func,
+            args=[ast.Name(id="stream", ctx=ast.Load())] + node.args,
+            keywords=[],
+        )
+
+        return f"return {self._generic_unparse(call)};"
+
+    def _unparse_FunctionDef(self, node):
+        params = ["CUstream stream"]
+        params += [f"NineToothedTensor {arg.arg}" for arg in node.args.args]
+        header = f"CUresult {node.name}({', '.join(params)})"
+
+        self.header = header
+
+        body_lines = []
+
+        for stmt in node.body:
+            stmt_unparsed = self.unparse(stmt)
+
+            if isinstance(stmt, ast.Expr):
+                stmt_unparsed = stmt_unparsed.strip()
+
+                if not stmt_unparsed.endswith(";"):
+                    stmt_unparsed += ";"
+
+            body_lines.append("    " + stmt_unparsed)
+
+        body = "\n".join(body_lines)
+
+        return f"{header} {{\n{body}\n}}"
+
+
+class _GridExtractor(ast.NodeTransformer):
+    def visit_BinOp(self, node):
+        self.generic_visit(node)
+
+        if isinstance(node.op, ast.FloorDiv):
+            node.op = ast.Div()
+
+        return node
+
+    def visit_Call(self, node):
+        self.generic_visit(node)
+
+        node.func = node.func.value
+
+        return node
+
+    def visit_Lambda(self, node):
+        self.generic_visit(node)
+
+        self.grid = node.body.elts
+
+        return node
+
+
+def _compile(path, name, signature, grid, num_warps, num_stages):
+    with tempfile.TemporaryDirectory() as temp_dir:
+        output_dir = pathlib.Path(temp_dir)
+        output_name = uuid.uuid4().hex
+        output_path = output_dir / output_name
+
+        command = [
+            "python",
+            "-m",
+            "triton.tools.compile",
+            str(path),
+            "--kernel-name",
+            str(name),
+            "--signature",
+            str(signature),
+            "--grid",
+            str(grid),
+            "--num-warps",
+            str(num_warps),
+            "--num-stages",
+            str(num_stages),
+            "--out-path",
+            str(output_path),
+        ]
+
+        subprocess.run(command, check=True)
+
+        matching_files = list(output_dir.glob(f"{output_name}.*"))
+
+        signature_hash = matching_files[0].name.split(".")[1]
+
+        output_contents = {}
+
+        for file in matching_files:
+            with file.open() as f:
+                output_contents[file.name.replace(output_name, name)] = f.read()
+
+        return signature_hash, output_contents
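Read together with the header it emits, aot() compiles an annotated compute function through triton.tools.compile and writes a {kernel_name}.c/{kernel_name}.h pair into output_dir, with the generated C launcher taking a CUstream plus one NineToothedTensor per argument. A hedged sketch of driving it directly (the annotated function `add` is hypothetical; aot is not re-exported from the package root):

from ninetoothed.aot import aot

# Assumes `add` is a compute function whose annotations carry the tensor
# arrangement, as elsewhere in this diff, and that `build` exists.
aot(add, caller="cuda", kernel_name="add", output_dir="build")
# Per the code above, this should leave build/add.c and build/add.h, the
# latter declaring `CUresult launch_add(CUstream stream, NineToothedTensor ...)`.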
ninetoothed-0.15.1/src/ninetoothed/cudaifier.py
ADDED
@@ -0,0 +1,36 @@
+import ast
+
+import ninetoothed.naming as naming
+from ninetoothed.tensor import Tensor
+
+
+class Cudaifier(ast.NodeTransformer):
+    def visit_Name(self, node):
+        self.generic_visit(node)
+
+        source = node.id
+
+        if naming.is_constexpr(source):
+            return node
+
+        def repl(match):
+            return f"{match.group(1)}.data"
+
+        source = Tensor.pointer_pattern().sub(repl, source)
+
+        def repl(match):
+            return f"{match.group(1)}.shape[{match.group(3)}]"
+
+        source = Tensor.size_pattern().sub(repl, source)
+
+        def repl(match):
+            return f"{match.group(1)}.strides[{match.group(3)}]"
+
+        source = Tensor.stride_pattern().sub(repl, source)
+
+        source = source.removesuffix("_with_auto_tuning")
+
+        if source != node.id:
+            return ast.parse(source, mode="eval").body
+
+        return node
ninetoothed-0.14.0/src/ninetoothed/jit.py → ninetoothed-0.15.1/src/ninetoothed/generation.py
RENAMED
(Removed lines whose text is not captured in this view are marked "⋯ (not shown)".)
@@ -2,84 +2,89 @@ import ast
 import collections
 import copy
 import functools
-import importlib.util
+import hashlib
 import inspect
 import itertools
 import math
+import pathlib
 import subprocess
-import sys
-import tempfile
 
 import triton
 
 import ninetoothed.naming as naming
+from ninetoothed.cudaifier import Cudaifier
 from ninetoothed.language import attribute, call
 from ninetoothed.symbol import Symbol
 from ninetoothed.tensor import Tensor
 from ninetoothed.torchifier import Torchifier
 
+CACHE_DIR = pathlib.Path.home() / ".ninetoothed"
 
-def make(arrangement, application, tensors):
-    """Integrate the arrangement and the application of the tensors.
 
- ⋯ (not shown)
-    """
+class CodeGenerator(ast.NodeTransformer):
+    def __init__(self):
+        super().__init__()
+
+        self._POWER_OF_TWOS = tuple(2**n for n in range(5, 11))
+
+        self._MIN_PRODUCT = 2**10
+
+        self._MAX_PRODUCT = 2**20
+
+    def __call__(self, func, caller, kernel_name, prettify):
+        def _get_tree(func):
+            module = ast.parse(inspect.getsource(inspect.getmodule(func)))
+
+            collector = _ImportCollector()
+            collector.visit(module)
+
+            finder = _FunctionDefFinder(func.__name__)
+            finder.visit(module)
+            func_def = finder.result
+
+            inliner = _Inliner(func.__globals__)
+            inliner.visit(func_def)
+            module.body = collector.imports + inliner.imports + [finder.result]
+
+            return _AliasRestorer().visit(module)
+
+        def _find_dependencies(func):
+            dependencies = set()
+
+            for obj in func.__globals__.values():
+                if isinstance(obj, triton.runtime.JITFunction):
+                    dependencies.add(obj.src)
+
+            return "\n".join(
+                f"@triton.jit\n{dependency}" for dependency in dependencies
+            )
+
+        self.launch_func_name = f"launch_{kernel_name}"
+
+        self._caller = caller
 
-    def __init__(self, func, _prettify=False):
-        self.func = func
+        self._context = inspect.get_annotations(func)
 
-        self.⋯
+        self._args = list(self._context.values())
 
- ⋯ (not shown)
-        tree = self._get_tree()
+        tree = _get_tree(func)
 
- ⋯ (not shown)
+        self.visit(tree)
         Tritonizer().visit(tree)
         _BinOpSimplifier().visit(tree)
         ast.fix_missing_locations(tree)
 
-        if self._prettify:
+        if prettify:
             name_collector = _SimplifiedNameCollector()
             name_collector.visit(tree)
 
         unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
-        dependencies = self._find_dependencies()
+        dependencies = _find_dependencies(func)
         source = "\n\n".join((unparsed, dependencies)).strip()
+        source = source.replace(func.__name__, kernel_name)
+        source += "\n"
 
-        if self._prettify:
+        if prettify:
             for original, simplified in name_collector.simplified_names.items():
                 if simplified not in name_collector.simplified_names:
                     source = source.replace(original, simplified)
@@ -88,73 +93,29 @@ class JIT:
                 ["ruff", "format", "-"], input=source, encoding="utf-8"
             )
 
- ⋯ (not shown)
+        digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
+        cache_dir = CACHE_DIR
+        cache_dir.mkdir(exist_ok=True)
+        cache_file = cache_dir / f"{digest}.py"
 
- ⋯ (not shown)
+        if not cache_file.exists():
+            with open(cache_file, "w", encoding="utf-8") as f:
+                f.write(source)
 
- ⋯ (not shown)
-            source,
-        )
-
-        return handle
-
-    def _get_tree(self):
-        module = ast.parse(inspect.getsource(inspect.getmodule(self.func)))
-
-        collector = _ImportCollector()
-        collector.visit(module)
-
-        finder = _FunctionDefFinder(self.func.__name__)
-        finder.visit(module)
-        func_def = finder.result
+        self.tensors = self._args
+        self.kernel_func = self._func_def
+        self.launch_func = self._launch
 
- ⋯ (not shown)
-        inliner.visit(func_def)
-        module.body = collector.imports + inliner.imports + [finder.result]
-
-        return _AliasRestorer().visit(module)
-
-    def _find_dependencies(self):
-        dependencies = set()
-
-        for obj in self.func.__globals__.values():
-            if isinstance(obj, triton.runtime.JITFunction):
-                dependencies.add(obj.src)
-
-        return "\n".join(f"@triton.jit\n{dependency}" for dependency in dependencies)
-
-    @staticmethod
-    def _import_from_path(module_name, file_path):
-        spec = importlib.util.spec_from_file_location(module_name, file_path)
-        module = importlib.util.module_from_spec(spec)
-        sys.modules[module_name] = module
-        spec.loader.exec_module(module)
-
-        return module
-
-
-class CodeGenerator(ast.NodeTransformer):
-    def __init__(self, context):
-        super().__init__()
-
-        self._context = context
-
-        self._args = list(self._context.values())
-
-        self._POWER_OF_TWOS = tuple(2**n for n in range(5, 11))
-
-        self._MIN_PRODUCT = 2**10
-
-        self._MAX_PRODUCT = 2**20
+        return str(cache_file)
 
     def visit_Module(self, node):
         self.generic_visit(node)
 
+        func_with_auto_tuning = f"{Symbol(self._autotune)}({self._func_def.name})"
+
+        node.body.append(
+            ast.parse(f"{self._func_name_with_auto_tuning} = {func_with_auto_tuning}")
+        )
         node.body.append(self._launch)
 
         return node
@@ -162,6 +123,8 @@ class CodeGenerator(ast.NodeTransformer):
     def visit_FunctionDef(self, node):
         self._func_def = node
 
+        self._func_name_with_auto_tuning = f"{self._func_def.name}_with_auto_tuning"
+
         self._invariants = {}
 
         self.generic_visit(node)
@@ -184,6 +147,9 @@ class CodeGenerator(ast.NodeTransformer):
             if naming.is_constexpr(name)
         }
 
+        non_meta_names = sorted(non_meta_names)
+        meta_names = sorted(meta_names)
+
         node.args = [
             ast.arg(arg=name)
             if not naming.is_constexpr(name)
@@ -194,8 +160,8 @@ class CodeGenerator(ast.NodeTransformer):
             for name in meta_names
         ]
 
- ⋯ (not shown)
-        self._func_def.decorator_list = [⋯
+        self._autotune = self._generate_autotune(non_meta_names, meta_names)
+        self._func_def.decorator_list = [Symbol("triton.jit").node]
 
         self._launch = self._generate_launch(non_meta_names, meta_names)
 
@@ -354,7 +320,7 @@ class CodeGenerator(ast.NodeTransformer):
         ]
 
         launch = ast.FunctionDef(
-            name=⋯
+            name=self.launch_func_name,
             args=ast.arguments(
                 posonlyargs=[],
                 args=[ast.arg(arg=arg.source.name) for arg in self._args]
@@ -392,7 +358,9 @@ class CodeGenerator(ast.NodeTransformer):
                 ast.Expr(
                     ast.Call(
                         func=ast.Subscript(
-                            value=ast.Name(⋯
+                            value=ast.Name(
+                                id=self._func_name_with_auto_tuning, ctx=ast.Load()
+                            ),
                             slice=self._generate_grid(),
                             ctx=ast.Load(),
                         ),
@@ -422,14 +390,23 @@ class CodeGenerator(ast.NodeTransformer):
 
         MetaEncloser(meta).visit(launch)
 
- ⋯ (not shown)
+        if self._caller == "torch":
+            Torchifier().visit(launch)
+        elif self._caller == "cuda":
+            Cudaifier().visit(launch)
+        else:
+            raise ValueError(f"Unsupported caller: `{self._caller}`.")
 
         return launch
 
     def _generate_grid(self):
        num_elements = functools.reduce(lambda x, y: x * y, self._args[0].shape)
 
- ⋯ (not shown)
+        grid = ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
+
+        self.raw_grid = copy.deepcopy(grid)
+
+        return grid
 
     def _generate_load(self, tensor, indices=()):
         if tensor.ndim == 0:
@@ -851,7 +828,7 @@ class _Inliner(ast.NodeTransformer):
         return names
 
     def _make_temporary():
-        prefix = naming.auto_generate(f⋯
+        prefix = f"{naming.auto_generate(f'temporary_{self._count}')}_"
         self._count += 1
 
         return prefix
@@ -941,16 +918,6 @@ class _SimplifiedNameCollector(ast.NodeVisitor):
         self.simplified_names[node.id] = naming.remove_prefixes(node.id)
 
 
-class _Handle:
-    def __init__(self, kernel, launch, source):
-        self._kernel = kernel
-        self._launch = launch
-        self._source = source
-
-    def __call__(self, *args, **kwargs):
-        return self._launch(*args, **kwargs)
-
-
 class _AliasRestorer(ast.NodeTransformer):
     def __init__(self):
         super().__init__()
ninetoothed-0.15.1/src/ninetoothed/jit.py
ADDED
@@ -0,0 +1,77 @@
+import importlib
+import sys
+
+from ninetoothed.generation import CodeGenerator
+
+
+def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
+    """A decorator for generating compute kernels.
+
+    :param func: The function to be compiled.
+    :param caller: Who will call the compute kernel.
+    :param kernel_name: The name for the generated kernel.
+    :param _prettify: Whether to prettify the generated code.
+    :return: A handle to the compute kernel.
+
+    .. note::
+
+        The ``_prettify`` parameter is experimental, which might break
+        the generated code.
+    """
+
+    def wrapper(func):
+        return JIT(func, caller=caller, kernel_name=kernel_name, _prettify=_prettify)()
+
+    if func is None:
+        return wrapper
+
+    return wrapper(func)
+
+
+class JIT:
+    def __init__(self, func, caller, kernel_name, _prettify=False):
+        self.func = func
+
+        self._caller = caller
+
+        if kernel_name is not None:
+            self._kernel_name = kernel_name
+        else:
+            self._kernel_name = func.__name__
+
+        self._prettify = _prettify
+
+    def __call__(self):
+        code_generator = CodeGenerator()
+        source_file = code_generator(
+            self.func, self._caller, self._kernel_name, self._prettify
+        )
+        module = type(self)._import_from_path(source_file, source_file)
+        module_vars = vars(module)
+
+        handle = _Handle(
+            module_vars[self._kernel_name],
+            module_vars[code_generator.launch_func_name],
+            source_file,
+        )
+
+        return handle
+
+    @staticmethod
+    def _import_from_path(module_name, file_path):
+        spec = importlib.util.spec_from_file_location(module_name, file_path)
+        module = importlib.util.module_from_spec(spec)
+        sys.modules[module_name] = module
+        spec.loader.exec_module(module)
+
+        return module
+
+
+class _Handle:
+    def __init__(self, kernel, launch, source):
+        self._kernel = kernel
+        self._launch = launch
+        self._source = source
+
+    def __call__(self, *args, **kwargs):
+        return self._launch(*args, **kwargs)
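For reference, the decorator keeps its annotation-driven usage and only gains the caller and kernel_name keywords. A minimal sketch in the style of the project's README (the vector-add kernel is illustrative, not part of this diff):

import ninetoothed
from ninetoothed import Symbol, Tensor

BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)

@ninetoothed.jit(caller="torch", kernel_name="add")
def add_kernel(
    x: Tensor(1).tile((BLOCK_SIZE,)),
    y: Tensor(1).tile((BLOCK_SIZE,)),
    z: Tensor(1).tile((BLOCK_SIZE,)),
):
    # Each program instance adds one tile of x and y into z.
    z = x + y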
ninetoothed-0.15.1/src/ninetoothed/make.py
ADDED
@@ -0,0 +1,45 @@
+import inspect
+
+from ninetoothed.aot import aot
+from ninetoothed.jit import jit
+
+
+def make(
+    arrangement,
+    application,
+    tensors,
+    caller="torch",
+    kernel_name=None,
+    output_dir=None,
+    num_warps=4,
+    num_stages=3,
+):
+    """Integrate the arrangement and the application of the tensors.
+
+    :param arrangement: The arrangement of the tensors.
+    :param application: The application of the tensors.
+    :param tensors: The tensors.
+    :param caller: Who will call the compute kernel.
+    :param kernel_name: The name for the generated kernel.
+    :param output_dir: The directory to store the generated files.
+    :param num_warps: The number of warps to use.
+    :param num_stages: The number of pipeline stages.
+    :return: A handle to the compute kernel.
+    """
+
+    params = inspect.signature(application).parameters
+    types = arrangement(*tensors)
+    annotations = {param: type for param, type in zip(params, types)}
+    application.__annotations__ = annotations
+
+    if caller == "torch":
+        return jit(application, caller=caller, kernel_name=kernel_name)
+
+    return aot(
+        application,
+        caller=caller,
+        kernel_name=kernel_name,
+        output_dir=output_dir,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
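make() is now the single entry point covering both back ends: with caller="torch" it degenerates to jit(), while any other caller (currently "cuda" per generation.py above) goes through aot(). A sketch using the matmul pieces from the tests in this diff (assumes the repository's tests package is importable):

import functools

import ninetoothed
from ninetoothed import Tensor
import tests.test_matmul as matmul

arrangement = functools.partial(
    matmul.arrangement, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64
)
tensors = tuple(Tensor(2, dtype=ninetoothed.float16) for _ in range(3))

# JIT path: returns a callable kernel handle.
matmul_kernel = ninetoothed.make(arrangement, matmul.application, tensors)

# AOT path: writes matmul.c and matmul.h into the given directory instead,
# as exercised by tests/test_aot.py (the directory must already exist).
ninetoothed.make(
    arrangement,
    matmul.application,
    tensors,
    caller="cuda",
    kernel_name="matmul",
    output_dir="build",
)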
{ninetoothed-0.14.0 → ninetoothed-0.15.1}/src/ninetoothed/visualization.py
CHANGED
@@ -118,10 +118,16 @@ def _visualize_unit_square(ax, x, y, color):
 
 
 def _visualize_rect(ax, width, height, x, y, color):
- ⋯ (not shown)
+    ax.add_patch(
+        plt.Rectangle(
+            (x, y),
+            width,
+            height,
+            edgecolor="k",
+            facecolor=color,
+            linewidth=plt.rcParams["lines.linewidth"],
+        )
+    )
 
 
 def _verts_of_rect(width, height, x, y):
{ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_addmm.py
CHANGED
@@ -3,6 +3,7 @@ import random
 import torch
 
 import ninetoothed
+import ninetoothed.language as ntl
 import tests.test_matmul as matmul
 from ninetoothed import Tensor
 from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported
@@ -19,8 +20,9 @@ def arrangement(input, mat1, mat2, beta, alpha, output):
 
 
 def application(input, mat1, mat2, beta, alpha, output):
- ⋯ (not shown)
+    matmul_output = ntl.zeros(output.shape, dtype=ntl.float32)
+    matmul.application(mat1, mat2, matmul_output)
+    output = beta * input + alpha * matmul_output
 
 
 def addmm(input, mat1, mat2, beta=1, alpha=1):
@@ -43,6 +45,7 @@ def addmm(input, mat1, mat2, beta=1, alpha=1):
 class TestCUDA:
     @classmethod
     def setup_class(cls):
+        random.seed(0)
         torch.manual_seed(0)
 
         shape = (512, 512)
@@ -74,9 +77,6 @@ class TestCUDA:
         beta = type(self).beta
         alpha = type(self).alpha
 
-        # TODO: The current application function inlining feature
-        # causes some precision issues. Consider reducing `atol` and
-        # `rtol` of this test in the future.
         assert torch.allclose(
             addmm(input, mat1, mat2, beta=beta, alpha=alpha),
             torch.addmm(
@@ -86,6 +86,5 @@ class TestCUDA:
                 beta=beta,
                 alpha=alpha,
             ),
-            atol=0.⋯
-            rtol=0.5,
+            atol=0.125,
         )
ninetoothed-0.15.1/tests/test_aot.py
ADDED
@@ -0,0 +1,153 @@
+import ctypes
+import functools
+import subprocess
+
+import torch
+import torch.nn.functional as F
+
+import ninetoothed
+import ninetoothed.generation
+import tests.test_conv2d as conv2d
+import tests.test_matmul as matmul
+from ninetoothed import Tensor
+from tests.skippers import skip_if_cuda_not_available
+
+
+@skip_if_cuda_not_available
+class TestCUDA:
+    @classmethod
+    def setup_class(cls):
+        torch.manual_seed(0)
+
+    def test_matmul(self):
+        arrangement = functools.partial(
+            matmul.arrangement, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64
+        )
+        application = matmul.application
+        tensors = tuple(Tensor(2, dtype=ninetoothed.float16) for _ in range(3))
+        caller = "cuda"
+        kernel_name = "matmul"
+        output_dir = ninetoothed.generation.CACHE_DIR
+
+        launch_func = _generate_launch_func(
+            arrangement,
+            application,
+            tensors,
+            caller=caller,
+            kernel_name=kernel_name,
+            output_dir=output_dir,
+        )
+
+        shape = (512, 512)
+        dtype = torch.float16
+        device = caller
+
+        lhs = torch.randn(shape, dtype=dtype, device=device)
+        rhs = torch.randn(shape, dtype=dtype, device=device)
+        output = torch.empty((lhs.shape[0], rhs.shape[1]), dtype=dtype, device=device)
+
+        _run_launch_func(launch_func, lhs, rhs, output)
+
+        assert torch.allclose(output, torch.matmul(lhs, rhs))
+
+    def test_conv2d(self):
+        arrangement = functools.partial(
+            conv2d.arrangement, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64
+        )
+        application = matmul.application
+        tensors = tuple(Tensor(4, dtype=ninetoothed.float16) for _ in range(3))
+        caller = "cuda"
+        kernel_name = "conv2d"
+        output_dir = ninetoothed.generation.CACHE_DIR
+
+        launch_func = _generate_launch_func(
+            arrangement,
+            application,
+            tensors,
+            caller=caller,
+            kernel_name=kernel_name,
+            output_dir=output_dir,
+        )
+
+        n, c, h, w = 4, 64, 16, 16
+        k, _, r, s = 512, c, 3, 3
+        p = h - r + 1
+        q = w - s + 1
+        dtype = torch.float16
+        device = caller
+
+        input = torch.randn(n, c, h, w, dtype=dtype, device=device)
+        filter = torch.randn(k, c, r, s, dtype=dtype, device=device)
+        output = torch.empty(n, k, p, q, dtype=dtype, device=device)
+
+        _run_launch_func(launch_func, input, filter, output)
+
+        assert torch.allclose(output, F.conv2d(input, filter), atol=0.001, rtol=0.001)
+
+
+class _ArgumentTensor(ctypes.Structure):
+    _fields_ = [
+        ("data", ctypes.c_void_p),
+        ("shape", ctypes.POINTER(ctypes.c_uint64)),
+        ("strides", ctypes.POINTER(ctypes.c_int64)),
+    ]
+
+    @staticmethod
+    def from_torch_tensor(tensor):
+        data = ctypes.c_void_p(tensor.data_ptr())
+        shape = (ctypes.c_uint64 * len(tensor.shape))(*tensor.shape)
+        strides = (ctypes.c_int64 * len(tensor.stride()))(*tensor.stride())
+
+        return _ArgumentTensor(data, shape, strides)
+
+
+def _run_launch_func(launch_func, *tensors):
+    stream = torch.cuda.Stream()
+
+    arg_tensors = tuple(_ArgumentTensor.from_torch_tensor(tensor) for tensor in tensors)
+
+    with torch.cuda.stream(stream):
+        launch_func(ctypes.c_void_p(stream.cuda_stream), *arg_tensors)
+
+    stream.synchronize()
+
+
+def _generate_launch_func(
+    arrangement, application, tensors, caller, kernel_name, output_dir
+):
+    ninetoothed.make(
+        arrangement,
+        application,
+        tensors,
+        caller=caller,
+        kernel_name=kernel_name,
+        output_dir=output_dir,
+    )
+
+    _compile_library(kernel_name, output_dir)
+    library = _load_library(kernel_name, output_dir)
+    launch_func_name = f"launch_{kernel_name}"
+    launch_func = getattr(library, launch_func_name)
+    launch_func.argtypes = (ctypes.c_void_p,) + tuple(_ArgumentTensor for _ in tensors)
+    launch_func.restype = ctypes.c_int
+
+    return launch_func
+
+
+def _compile_library(kernel_name, output_dir):
+    command = [
+        "nvcc",
+        "-shared",
+        "-Xcompiler",
+        "-fPIC",
+        "-lcuda",
+        "-o",
+        output_dir / f"{kernel_name}.so",
+        output_dir / f"{kernel_name}.c",
+    ]
+
+    subprocess.run(command, check=True)
+
+
+def _load_library(kernel_name, kernel_dir):
+    return ctypes.CDLL(kernel_dir / f"{kernel_name}.so")
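Note that _ArgumentTensor's field order (data, shape, strides) mirrors the NineToothedTensor struct that aot.py writes into ninetoothed.h, which is what lets these tests pass PyTorch tensors through ctypes into the generated launch_<kernel> entry points.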
{ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_conv2d.py
CHANGED
@@ -1,3 +1,5 @@
+import functools
+
 import torch
 import torch.nn.functional as F
 
@@ -7,7 +9,14 @@ from ninetoothed import Tensor
 from tests.skippers import skip_if_cuda_not_available
 
 
-def arrangement(input, filter, output):
+def arrangement(
+    input,
+    filter,
+    output,
+    BLOCK_SIZE_M=matmul.BLOCK_SIZE_M,
+    BLOCK_SIZE_N=matmul.BLOCK_SIZE_N,
+    BLOCK_SIZE_K=matmul.BLOCK_SIZE_K,
+):
     input_tiled = input.tile((1, *filter.shape[1:]), strides=(-1, -1, 1, 1))
     input_squeezed = input_tiled.squeeze(1)
     input_squeezed.dtype = input_squeezed.dtype.squeeze(0)
@@ -19,7 +28,12 @@ def arrangement(input, filter, output):
 
     output_flattened = output.permute((0, 2, 3, 1)).flatten(end_dim=3)
 
-    return matmul.arrangement(input_flattened, filter_permuted, output_flattened)
+    return functools.partial(
+        matmul.arrangement,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )(input_flattened, filter_permuted, output_flattened)
 
 
 def conv2d(input, filter):
{ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_matmul.py
CHANGED
@@ -5,12 +5,19 @@ import ninetoothed.language as ntl
 from ninetoothed import Symbol, Tensor
 from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported
 
- ⋯ (not shown)
+BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
+BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
+BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
+
+
+def arrangement(
+    lhs,
+    rhs,
+    output,
+    BLOCK_SIZE_M=BLOCK_SIZE_M,
+    BLOCK_SIZE_N=BLOCK_SIZE_N,
+    BLOCK_SIZE_K=BLOCK_SIZE_K,
+):
     output_tiled = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
 
     lhs_tiled = (