ninetoothed 0.14.0__tar.gz → 0.15.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed-0.15.0/.github/ISSUE_TEMPLATE/bug-report.yml +55 -0
- ninetoothed-0.15.0/.github/ISSUE_TEMPLATE/feature-request.yml +13 -0
- ninetoothed-0.15.0/.github/pull_request_template.md +5 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/PKG-INFO +1 -1
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/pyproject.toml +1 -1
- ninetoothed-0.15.0/src/ninetoothed/__init__.py +35 -0
- ninetoothed-0.15.0/src/ninetoothed/aot.py +217 -0
- ninetoothed-0.15.0/src/ninetoothed/cudaifier.py +36 -0
- ninetoothed-0.15.0/src/ninetoothed/dtype.py +13 -0
- ninetoothed-0.14.0/src/ninetoothed/jit.py → ninetoothed-0.15.0/src/ninetoothed/generation.py +82 -116
- ninetoothed-0.15.0/src/ninetoothed/jit.py +77 -0
- ninetoothed-0.15.0/src/ninetoothed/make.py +45 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_addmm.py +6 -7
- ninetoothed-0.14.0/src/ninetoothed/__init__.py +0 -5
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/.gitattributes +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/.github/workflows/publish-to-pypi.yml +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/.github/workflows/pytest.yml +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/.github/workflows/ruff.yml +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/.github/workflows/sphinx.yml +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/.gitignore +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/LICENSE +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/README.md +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/Makefile +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/README.zh.md +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/make.bat +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/requirements.txt +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/_static/matmul-tiling.png +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/_static/ninetoothed-logo.png +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/_static/vecadd-tiling.png +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/code_generation.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/conf.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/index.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/installation.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/python_api.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/symbol.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/tensor.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/visualization.rst +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/requirements.txt +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/src/ninetoothed/language.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/src/ninetoothed/naming.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/src/ninetoothed/symbol.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/src/ninetoothed/tensor.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/src/ninetoothed/torchifier.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/src/ninetoothed/visualization.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/__init__.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/skippers.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_add.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_attention.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_conv2d.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_matmul.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_max_pool2d.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_naming.py +0 -0
- {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_softmax.py +0 -0
ninetoothed-0.15.0/.github/ISSUE_TEMPLATE/bug-report.yml
ADDED
@@ -0,0 +1,55 @@
+name: 🐛 Bug report
+description: Something isn't working as expected 🤔.
+labels: ["bug"]
+
+body:
+  - type: markdown
+    attributes:
+      value: Thanks for taking the time to fill out this bug report!
+
+  - type: checkboxes
+    attributes:
+      label: Is there an existing issue for this?
+      description: >
+        Please search to see if an issue already exists
+        for the bug you encountered.
+      options:
+        - label: I have searched the existing issues.
+          required: true
+
+  - type: textarea
+    attributes:
+      label: "Describe the bug:"
+      description: A clear and concise description of what the bug is.
+    validations:
+      required: false
+
+  - type: textarea
+    attributes:
+      label: "To reproduce:"
+      description: >
+        Steps to reproduce the behavior.
+        If applicable, provide a small, self-contained piece of code
+        that can be run directly to reproduce the issue.
+    validations:
+      required: false
+
+  - type: textarea
+    attributes:
+      label: "Expected behavior:"
+      description: >
+        A clear and concise description of what you expected to happen.
+    validations:
+      required: false
+
+  - type: textarea
+    attributes:
+      label: "Environment details:"
+      description: >
+        Please include your NineToothed version, operating system,
+        hardware platform, and any relevant information.
+        If you are using PyTorch, please run
+        `python -m torch.utils.collect_env` to gather
+        environment information.
+    validations:
+      required: false
ninetoothed-0.15.0/.github/ISSUE_TEMPLATE/feature-request.yml
ADDED
@@ -0,0 +1,13 @@
+name: 🚀 Feature request
+description: I have a suggestion 🙂!
+
+body:
+  - type: textarea
+    attributes:
+      label: "Description & motivation:"
+      description: >
+        Please describe the feature that you would like to see and
+        explain the problem it would solve or
+        the benefit it would provide.
+    validations:
+      required: true
{ninetoothed-0.14.0 → ninetoothed-0.15.0}/PKG-INFO
MODIFIED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ninetoothed
-Version: 0.14.0
+Version: 0.15.0
 Summary: A domain-specific language based on Triton but providing higher-level abstraction.
 Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
 Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
{ninetoothed-0.14.0 → ninetoothed-0.15.0}/pyproject.toml
MODIFIED
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "ninetoothed"
-version = "0.14.0"
+version = "0.15.0"
 authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
 description = "A domain-specific language based on Triton but providing higher-level abstraction."
 readme = "README.md"
ninetoothed-0.15.0/src/ninetoothed/__init__.py
ADDED
@@ -0,0 +1,35 @@
+from ninetoothed.dtype import (
+    float16,
+    float32,
+    float64,
+    int8,
+    int16,
+    int32,
+    int64,
+    uint8,
+    uint16,
+    uint32,
+    uint64,
+)
+from ninetoothed.jit import jit
+from ninetoothed.make import make
+from ninetoothed.symbol import Symbol
+from ninetoothed.tensor import Tensor
+
+__all__ = [
+    "Symbol",
+    "Tensor",
+    "float16",
+    "float32",
+    "float64",
+    "int8",
+    "int16",
+    "int32",
+    "int64",
+    "jit",
+    "make",
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+]
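With the hunk above, the dtype names, the `jit` decorator, and the new `make` helper all become importable from the package root. A minimal import sketch (assuming ninetoothed 0.15.0 is installed; it uses only the names listed in the new `__all__`):

    # Sketch only: imports mirror the __all__ list added in 0.15.0's __init__.py.
    import ninetoothed
    from ninetoothed import Symbol, Tensor, float32, int64, jit, make

    print(ninetoothed.float16, float32, int64)  # dtype names re-exported at the top level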
ninetoothed-0.15.0/src/ninetoothed/aot.py
ADDED
@@ -0,0 +1,217 @@
+import ast
+import pathlib
+import subprocess
+import tempfile
+import uuid
+
+from ninetoothed.dtype import int64, uint64
+from ninetoothed.generation import CACHE_DIR, CodeGenerator
+from ninetoothed.tensor import Tensor
+
+
+def aot(
+    func, caller="cuda", kernel_name=None, output_dir=None, num_warps=4, num_stages=3
+):
+    output_dir = pathlib.Path(output_dir)
+
+    output_contents = _aot(func, caller, kernel_name, num_warps, num_stages)
+
+    for output_name, output_content in output_contents.items():
+        output_path = output_dir / f"{kernel_name}{output_name[-2:]}"
+
+        with open(output_path, "w") as f:
+            f.write(output_content)
+
+
+def _aot(func, caller, kernel_name, num_warps, num_stages):
+    def _find_tensor_by_source_name(tensors, name):
+        for tensor in tensors:
+            if tensor.source.name == name:
+                return tensor
+
+    _HEADER_PATH.parent.mkdir(exist_ok=True)
+
+    if not _HEADER_PATH.exists():
+        _HEADER_PATH.write_text(_HEADER_CONTENT)
+
+    code_generator = CodeGenerator()
+    source_file = code_generator(
+        func, caller=caller, kernel_name=kernel_name, prettify=False
+    )
+
+    tensors = code_generator.tensors
+    kernel_func = code_generator.kernel_func
+    launch_func = code_generator.launch_func
+
+    param_types = []
+
+    for arg in kernel_func.args.args:
+        param = arg.arg
+
+        if match := Tensor.pointer_pattern().fullmatch(param):
+            source_name = match.group(0).removesuffix("_pointer")
+            tensor = _find_tensor_by_source_name(tensors, source_name)
+            dtype = tensor.source.dtype
+
+            param_types.append(f"*{dtype}")
+        elif Tensor.size_pattern().fullmatch(param):
+            param_types.append(uint64)
+        elif Tensor.stride_pattern().fullmatch(param):
+            param_types.append(int64)
+
+    signature = ", ".join(param_types)
+
+    grid_extractor = _GridExtractor()
+    launch_func = grid_extractor.visit(launch_func)
+    grid_extractor.visit(code_generator.raw_grid)
+    grid = f"{ast.unparse(grid_extractor.grid[0])}, 1, 1"
+
+    signature_hash, output_contents = _compile(
+        source_file, kernel_name, signature, grid, num_warps, num_stages
+    )
+
+    unparser = _Unparser()
+
+    launch_func_unparsed = unparser.unparse(launch_func)
+    launch_func_unparsed = launch_func_unparsed.replace(
+        func.__name__, f"{kernel_name}_{signature_hash}"
+    )
+
+    c_source_file_name = f"{kernel_name}.{signature_hash}.c"
+    c_source_file = output_contents[c_source_file_name]
+    c_source_file = f"{c_source_file}\n{launch_func_unparsed}\n"
+    c_source_file = c_source_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
+    output_contents[c_source_file_name] = c_source_file
+
+    c_header_file_name = f"{kernel_name}.{signature_hash}.h"
+    c_header_file = output_contents[c_header_file_name]
+    c_header_file = f"{c_header_file}\n{unparser.header};\n"
+    c_header_file = c_header_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
+    output_contents[c_header_file_name] = c_header_file
+
+    return output_contents
+
+
+_HEADER_CONTENT = """#include <stdint.h>
+
+typedef struct {
+    uintptr_t data;
+    uint64_t *shape;
+    int64_t *strides;
+} NineToothedTensor;
+"""
+
+_HEADER_PATH = CACHE_DIR / "ninetoothed.h"
+
+
+class _Unparser:
+    def unparse(self, node):
+        method_name = "_unparse_" + node.__class__.__name__
+
+        if hasattr(self, method_name):
+            return getattr(self, method_name)(node)
+
+        return self._generic_unparse(node)
+
+    def _generic_unparse(self, node):
+        return ast.unparse(node)
+
+    def _unparse_Expr(self, node):
+        return self.unparse(node.value)
+
+    def _unparse_Call(self, node):
+        call = ast.Call(
+            func=node.func,
+            args=[ast.Name(id="stream", ctx=ast.Load())] + node.args,
+            keywords=[],
+        )
+
+        return f"return {self._generic_unparse(call)};"
+
+    def _unparse_FunctionDef(self, node):
+        params = ["CUstream stream"]
+        params += [f"NineToothedTensor {arg.arg}" for arg in node.args.args]
+        header = f"CUresult {node.name}({', '.join(params)})"
+
+        self.header = header
+
+        body_lines = []
+
+        for stmt in node.body:
+            stmt_unparsed = self.unparse(stmt)
+
+            if isinstance(stmt, ast.Expr):
+                stmt_unparsed = stmt_unparsed.strip()
+
+                if not stmt_unparsed.endswith(";"):
+                    stmt_unparsed += ";"
+
+            body_lines.append("    " + stmt_unparsed)
+
+        body = "\n".join(body_lines)
+
+        return f"{header} {{\n{body}\n}}"
+
+
+class _GridExtractor(ast.NodeTransformer):
+    def visit_BinOp(self, node):
+        self.generic_visit(node)
+
+        if isinstance(node.op, ast.FloorDiv):
+            node.op = ast.Div()
+
+        return node
+
+    def visit_Call(self, node):
+        self.generic_visit(node)
+
+        node.func = node.func.value
+
+        return node
+
+    def visit_Lambda(self, node):
+        self.generic_visit(node)
+
+        self.grid = node.body.elts
+
+        return node
+
+
+def _compile(path, name, signature, grid, num_warps, num_stages):
+    with tempfile.TemporaryDirectory() as temp_dir:
+        output_dir = pathlib.Path(temp_dir)
+        output_name = uuid.uuid4().hex
+        output_path = output_dir / output_name
+
+        command = [
+            "python",
+            "-m",
+            "triton.tools.compile",
+            str(path),
+            "--kernel-name",
+            str(name),
+            "--signature",
+            str(signature),
+            "--grid",
+            str(grid),
+            "--num-warps",
+            str(num_warps),
+            "--num-stages",
+            str(num_stages),
+            "--out-path",
+            str(output_path),
+        ]
+
+        subprocess.run(command, check=True)
+
+        matching_files = list(output_dir.glob(f"{output_name}.*"))
+
+        signature_hash = matching_files[0].name.split(".")[1]
+
+        output_contents = {}
+
+        for file in matching_files:
+            with file.open() as f:
+                output_contents[file.name.replace(output_name, name)] = f.read()
+
+        return signature_hash, output_contents
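The new ahead-of-time path wraps `python -m triton.tools.compile`: it derives a Triton signature from the kernel's pointer, size, and stride parameters, compiles with the given grid, warps, and stages, and then appends a C launch wrapper (a `CUresult launch_<kernel_name>(CUstream stream, NineToothedTensor ...)` function plus the `NineToothedTensor` struct from `ninetoothed.h`) to the generated `.c`/`.h` pair. A hedged sketch of invoking it directly; the annotation style and the `dtype=` argument to `Tensor` are assumptions based on the library's documented usage, not on this diff:

    # Hedged sketch: requires ninetoothed 0.15.0 and a Triton installation.
    import pathlib

    import ninetoothed
    from ninetoothed import Tensor
    from ninetoothed.aot import aot

    def add_kernel(
        # Fixed tile size keeps the AOT signature concrete; the dtype kwarg is assumed.
        x: Tensor(1, dtype=ninetoothed.float32).tile((1024,)),
        y: Tensor(1, dtype=ninetoothed.float32).tile((1024,)),
        z: Tensor(1, dtype=ninetoothed.float32).tile((1024,)),
    ):
        z = x + y  # noqa: F841

    # Writes add_kernel.c / add_kernel.h (including the NineToothedTensor struct
    # reference and the launch_add_kernel wrapper) into ./build.
    pathlib.Path("build").mkdir(exist_ok=True)
    aot(add_kernel, caller="cuda", kernel_name="add_kernel", output_dir="build")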
ninetoothed-0.15.0/src/ninetoothed/cudaifier.py
ADDED
@@ -0,0 +1,36 @@
+import ast
+
+import ninetoothed.naming as naming
+from ninetoothed.tensor import Tensor
+
+
+class Cudaifier(ast.NodeTransformer):
+    def visit_Name(self, node):
+        self.generic_visit(node)
+
+        source = node.id
+
+        if naming.is_constexpr(source):
+            return node
+
+        def repl(match):
+            return f"{match.group(1)}.data"
+
+        source = Tensor.pointer_pattern().sub(repl, source)
+
+        def repl(match):
+            return f"{match.group(1)}.shape[{match.group(3)}]"
+
+        source = Tensor.size_pattern().sub(repl, source)
+
+        def repl(match):
+            return f"{match.group(1)}.strides[{match.group(3)}]"
+
+        source = Tensor.stride_pattern().sub(repl, source)
+
+        source = source.removesuffix("_with_auto_tuning")
+
+        if source != node.id:
+            return ast.parse(source, mode="eval").body
+
+        return node
ninetoothed-0.14.0/src/ninetoothed/jit.py → ninetoothed-0.15.0/src/ninetoothed/generation.py
RENAMED
@@ -2,84 +2,88 @@ import ast
 import collections
 import copy
 import functools
-import
+import hashlib
 import inspect
 import itertools
 import math
+import pathlib
 import subprocess
-import sys
-import tempfile

 import triton

 import ninetoothed.naming as naming
+from ninetoothed.cudaifier import Cudaifier
 from ninetoothed.language import attribute, call
 from ninetoothed.symbol import Symbol
 from ninetoothed.tensor import Tensor
 from ninetoothed.torchifier import Torchifier

+CACHE_DIR = pathlib.Path.home() / ".ninetoothed"

-def make(arrangement, application, tensors):
-    """Integrate the arrangement and the application of the tensors.

-
-
-
-
-
-
-
-
-
+class CodeGenerator(ast.NodeTransformer):
+    def __init__(self):
+        super().__init__()
+
+        self._POWER_OF_TWOS = tuple(2**n for n in range(5, 11))
+
+        self._MIN_PRODUCT = 2**10
+
+        self._MAX_PRODUCT = 2**20

-
+    def __call__(self, func, caller, kernel_name, prettify):
+        def _get_tree(func):
+            module = ast.parse(inspect.getsource(inspect.getmodule(func)))

+            collector = _ImportCollector()
+            collector.visit(module)

-
-
+            finder = _FunctionDefFinder(func.__name__)
+            finder.visit(module)
+            func_def = finder.result

-
-
-
+            inliner = _Inliner(func.__globals__)
+            inliner.visit(func_def)
+            module.body = collector.imports + inliner.imports + [finder.result]

-
+            return _AliasRestorer().visit(module)

-
-
-    """
+        def _find_dependencies(func):
+            dependencies = set()

-
-
+            for obj in func.__globals__.values():
+                if isinstance(obj, triton.runtime.JITFunction):
+                    dependencies.add(obj.src)

-
-
+            return "\n".join(
+                f"@triton.jit\n{dependency}" for dependency in dependencies
+            )

-
+        self.launch_func_name = f"launch_{kernel_name}"

+        self._caller = caller

-
-    def __init__(self, func, _prettify=False):
-        self.func = func
+        self._context = inspect.get_annotations(func)

-        self.
+        self._args = list(self._context.values())

-
-        tree = self._get_tree()
+        tree = _get_tree(func)

-
+        self.visit(tree)
         Tritonizer().visit(tree)
         _BinOpSimplifier().visit(tree)
         ast.fix_missing_locations(tree)

-        if
+        if prettify:
             name_collector = _SimplifiedNameCollector()
             name_collector.visit(tree)

         unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
-        dependencies =
+        dependencies = _find_dependencies(func)
         source = "\n\n".join((unparsed, dependencies)).strip()
+        source = source.replace(func.__name__, kernel_name)

-        if
+        if prettify:
             for original, simplified in name_collector.simplified_names.items():
                 if simplified not in name_collector.simplified_names:
                     source = source.replace(original, simplified)
@@ -88,73 +92,29 @@ class JIT:
             ["ruff", "format", "-"], input=source, encoding="utf-8"
         )

-
-
-
+        digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
+        cache_dir = CACHE_DIR
+        cache_dir.mkdir(exist_ok=True)
+        cache_file = cache_dir / f"{digest}.py"

-
-
+        if not cache_file.exists():
+            with open(cache_file, "w", encoding="utf-8") as f:
+                f.write(source)

-
-
-
-            source,
-        )
-
-        return handle
-
-    def _get_tree(self):
-        module = ast.parse(inspect.getsource(inspect.getmodule(self.func)))
-
-        collector = _ImportCollector()
-        collector.visit(module)
-
-        finder = _FunctionDefFinder(self.func.__name__)
-        finder.visit(module)
-        func_def = finder.result
+        self.tensors = self._args
+        self.kernel_func = self._func_def
+        self.launch_func = self._launch

-
-        inliner.visit(func_def)
-        module.body = collector.imports + inliner.imports + [finder.result]
-
-        return _AliasRestorer().visit(module)
-
-    def _find_dependencies(self):
-        dependencies = set()
-
-        for obj in self.func.__globals__.values():
-            if isinstance(obj, triton.runtime.JITFunction):
-                dependencies.add(obj.src)
-
-        return "\n".join(f"@triton.jit\n{dependency}" for dependency in dependencies)
-
-    @staticmethod
-    def _import_from_path(module_name, file_path):
-        spec = importlib.util.spec_from_file_location(module_name, file_path)
-        module = importlib.util.module_from_spec(spec)
-        sys.modules[module_name] = module
-        spec.loader.exec_module(module)
-
-        return module
-
-
-class CodeGenerator(ast.NodeTransformer):
-    def __init__(self, context):
-        super().__init__()
-
-        self._context = context
-
-        self._args = list(self._context.values())
-
-        self._POWER_OF_TWOS = tuple(2**n for n in range(5, 11))
-
-        self._MIN_PRODUCT = 2**10
-
-        self._MAX_PRODUCT = 2**20
+        return str(cache_file)

     def visit_Module(self, node):
         self.generic_visit(node)

+        func_with_auto_tuning = f"{Symbol(self._autotune)}({self._func_def.name})"
+
+        node.body.append(
+            ast.parse(f"{self._func_name_with_auto_tuning} = {func_with_auto_tuning}")
+        )
         node.body.append(self._launch)

         return node
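The digest-based cache introduced above means repeated compilations of identical generated source reuse one file under the new `CACHE_DIR`. A small standard-library sketch of where a given source string would land, mirroring the logic in the hunk (the source string here is a stand-in):

    import hashlib
    import pathlib

    source = "def kernel(): ...\n"  # stand-in for the generated Triton source
    digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
    cache_file = pathlib.Path.home() / ".ninetoothed" / f"{digest}.py"
    print(cache_file)  # e.g. ~/.ninetoothed/<64-hex-digit-digest>.py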
@@ -162,6 +122,8 @@ class CodeGenerator(ast.NodeTransformer):
     def visit_FunctionDef(self, node):
         self._func_def = node

+        self._func_name_with_auto_tuning = f"{self._func_def.name}_with_auto_tuning"
+
         self._invariants = {}

         self.generic_visit(node)
@@ -184,6 +146,9 @@ class CodeGenerator(ast.NodeTransformer):
             if naming.is_constexpr(name)
         }

+        non_meta_names = sorted(non_meta_names)
+        meta_names = sorted(meta_names)
+
         node.args = [
             ast.arg(arg=name)
             if not naming.is_constexpr(name)
@@ -194,8 +159,8 @@ class CodeGenerator(ast.NodeTransformer):
             for name in meta_names
         ]

-
-        self._func_def.decorator_list = [
+        self._autotune = self._generate_autotune(non_meta_names, meta_names)
+        self._func_def.decorator_list = [Symbol("triton.jit").node]

         self._launch = self._generate_launch(non_meta_names, meta_names)

@@ -354,7 +319,7 @@ class CodeGenerator(ast.NodeTransformer):
         ]

         launch = ast.FunctionDef(
-            name=
+            name=self.launch_func_name,
             args=ast.arguments(
                 posonlyargs=[],
                 args=[ast.arg(arg=arg.source.name) for arg in self._args]
@@ -392,7 +357,9 @@ class CodeGenerator(ast.NodeTransformer):
                 ast.Expr(
                     ast.Call(
                         func=ast.Subscript(
-                            value=ast.Name(
+                            value=ast.Name(
+                                id=self._func_name_with_auto_tuning, ctx=ast.Load()
+                            ),
                             slice=self._generate_grid(),
                             ctx=ast.Load(),
                         ),
@@ -422,14 +389,23 @@ class CodeGenerator(ast.NodeTransformer):

         MetaEncloser(meta).visit(launch)

-
+        if self._caller == "torch":
+            Torchifier().visit(launch)
+        elif self._caller == "cuda":
+            Cudaifier().visit(launch)
+        else:
+            raise ValueError(f"Unsupported caller: `{self._caller}`.")

         return launch

     def _generate_grid(self):
         num_elements = functools.reduce(lambda x, y: x * y, self._args[0].shape)

-
+        grid = ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
+
+        self.raw_grid = copy.deepcopy(grid)
+
+        return grid

     def _generate_load(self, tensor, indices=()):
         if tensor.ndim == 0:
@@ -851,7 +827,7 @@ class _Inliner(ast.NodeTransformer):
         return names

     def _make_temporary():
-        prefix = naming.auto_generate(f
+        prefix = f"{naming.auto_generate(f'temporary_{self._count}')}_"
         self._count += 1

         return prefix
@@ -941,16 +917,6 @@ class _SimplifiedNameCollector(ast.NodeVisitor):
         self.simplified_names[node.id] = naming.remove_prefixes(node.id)


-class _Handle:
-    def __init__(self, kernel, launch, source):
-        self._kernel = kernel
-        self._launch = launch
-        self._source = source
-
-    def __call__(self, *args, **kwargs):
-        return self._launch(*args, **kwargs)
-
-
 class _AliasRestorer(ast.NodeTransformer):
     def __init__(self):
         super().__init__()
ninetoothed-0.15.0/src/ninetoothed/jit.py
ADDED
@@ -0,0 +1,77 @@
+import importlib
+import sys
+
+from ninetoothed.generation import CodeGenerator
+
+
+def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
+    """A decorator for generating compute kernels.
+
+    :param func: The function to be compiled.
+    :param caller: Who will call the compute kernel.
+    :param kernel_name: The name for the generated kernel.
+    :param _prettify: Whether to prettify the generated code.
+    :return: A handle to the compute kernel.
+
+    .. note::
+
+        The ``_prettify`` parameter is experimental, which might break
+        the generated code.
+    """
+
+    def wrapper(func):
+        return JIT(func, caller=caller, kernel_name=kernel_name, _prettify=_prettify)()
+
+    if func is None:
+        return wrapper
+
+    return wrapper(func)
+
+
+class JIT:
+    def __init__(self, func, caller, kernel_name, _prettify=False):
+        self.func = func
+
+        self._caller = caller
+
+        if kernel_name is not None:
+            self._kernel_name = kernel_name
+        else:
+            self._kernel_name = func.__name__
+
+        self._prettify = _prettify
+
+    def __call__(self):
+        code_generator = CodeGenerator()
+        source_file = code_generator(
+            self.func, self._caller, self._kernel_name, self._prettify
+        )
+        module = type(self)._import_from_path(source_file, source_file)
+        module_vars = vars(module)
+
+        handle = _Handle(
+            module_vars[self._kernel_name],
+            module_vars[code_generator.launch_func_name],
+            source_file,
+        )
+
+        return handle
+
+    @staticmethod
+    def _import_from_path(module_name, file_path):
+        spec = importlib.util.spec_from_file_location(module_name, file_path)
+        module = importlib.util.module_from_spec(spec)
+        sys.modules[module_name] = module
+        spec.loader.exec_module(module)
+
+        return module
+
+
+class _Handle:
+    def __init__(self, kernel, launch, source):
+        self._kernel = kernel
+        self._launch = launch
+        self._source = source
+
+    def __call__(self, *args, **kwargs):
+        return self._launch(*args, **kwargs)
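Usage of the reworked decorator stays decorator-shaped, but the generated Triton source is now written to the `~/.ninetoothed` cache and imported back from there, and a kernel name can be pinned explicitly. A hedged sketch following the project's README-style vector addition; the `Symbol(..., meta=True)` and `Tensor(1).tile(...)` calls are assumptions taken from that documentation, not from this diff:

    import torch

    import ninetoothed
    from ninetoothed import Symbol, Tensor

    BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)  # assumed meta-parameter style

    @ninetoothed.jit(caller="torch", kernel_name="add_kernel")
    def add_kernel(
        x: Tensor(1).tile((BLOCK_SIZE,)),
        y: Tensor(1).tile((BLOCK_SIZE,)),
        z: Tensor(1).tile((BLOCK_SIZE,)),
    ):
        z = x + y  # noqa: F841

    x = torch.rand(8192, device="cuda")
    y = torch.rand(8192, device="cuda")
    z = torch.empty_like(x)
    add_kernel(x, y, z)  # dispatches through the _Handle wrapper shown above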
ninetoothed-0.15.0/src/ninetoothed/make.py
ADDED
@@ -0,0 +1,45 @@
+import inspect
+
+from ninetoothed.aot import aot
+from ninetoothed.jit import jit
+
+
+def make(
+    arrangement,
+    application,
+    tensors,
+    caller="torch",
+    kernel_name=None,
+    output_dir=None,
+    num_warps=4,
+    num_stages=3,
+):
+    """Integrate the arrangement and the application of the tensors.
+
+    :param arrangement: The arrangement of the tensors.
+    :param application: The application of the tensors.
+    :param tensors: The tensors.
+    :param caller: Who will call the compute kernel.
+    :param kernel_name: The name for the generated kernel.
+    :param output_dir: The directory to store the generated files.
+    :param num_warps: The number of warps to use.
+    :param num_stages: The number of pipeline stages.
+    :return: A handle to the compute kernel.
+    """
+
+    params = inspect.signature(application).parameters
+    types = arrangement(*tensors)
+    annotations = {param: type for param, type in zip(params, types)}
+    application.__annotations__ = annotations
+
+    if caller == "torch":
+        return jit(application, caller=caller, kernel_name=kernel_name)
+
+    return aot(
+        application,
+        caller=caller,
+        kernel_name=kernel_name,
+        output_dir=output_dir,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
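`make` is the new arrangement/application entry point: it copies the arranged tensor types onto the application's parameters as annotations and then dispatches to `jit` for `caller="torch"` or to `aot` for anything else. A hedged sketch mirroring the arrangement/application split used in tests/test_addmm.py; the tiling details are assumptions:

    import ninetoothed
    from ninetoothed import Symbol, Tensor

    BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)  # assumed meta-parameter style

    def arrangement(x, y, z):
        return x.tile((BLOCK_SIZE,)), y.tile((BLOCK_SIZE,)), z.tile((BLOCK_SIZE,))

    def application(x, y, z):
        z = x + y  # noqa: F841

    # Default caller ("torch") returns a callable handle backed by the JIT path;
    # it is then invoked with torch tensors as in the jit sketch above.
    add = ninetoothed.make(arrangement, application, (Tensor(1), Tensor(1), Tensor(1)))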
{ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_addmm.py
MODIFIED
@@ -3,6 +3,7 @@ import random
 import torch

 import ninetoothed
+import ninetoothed.language as ntl
 import tests.test_matmul as matmul
 from ninetoothed import Tensor
 from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported
@@ -19,8 +20,9 @@ def arrangement(input, mat1, mat2, beta, alpha, output):


 def application(input, mat1, mat2, beta, alpha, output):
-
-
+    matmul_output = ntl.zeros(output.shape, dtype=ntl.float32)
+    matmul.application(mat1, mat2, matmul_output)
+    output = beta * input + alpha * matmul_output


 def addmm(input, mat1, mat2, beta=1, alpha=1):
@@ -43,6 +45,7 @@ def addmm(input, mat1, mat2, beta=1, alpha=1):
 class TestCUDA:
     @classmethod
     def setup_class(cls):
+        random.seed(0)
         torch.manual_seed(0)

         shape = (512, 512)
@@ -74,9 +77,6 @@ class TestCUDA:
         beta = type(self).beta
         alpha = type(self).alpha

-        # TODO: The current application function inlining feature
-        # causes some precision issues. Consider reducing `atol` and
-        # `rtol` of this test in the future.
         assert torch.allclose(
             addmm(input, mat1, mat2, beta=beta, alpha=alpha),
             torch.addmm(
@@ -86,6 +86,5 @@ class TestCUDA:
                 beta=beta,
                 alpha=alpha,
             ),
-            atol=0.
-            rtol=0.5,
+            atol=0.125,
         )