ninetoothed 0.13.0__tar.gz → 0.15.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed-0.15.0/.github/ISSUE_TEMPLATE/bug-report.yml +55 -0
- ninetoothed-0.15.0/.github/ISSUE_TEMPLATE/feature-request.yml +13 -0
- ninetoothed-0.15.0/.github/pull_request_template.md +5 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/PKG-INFO +1 -1
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/pyproject.toml +1 -1
- ninetoothed-0.15.0/src/ninetoothed/__init__.py +35 -0
- ninetoothed-0.15.0/src/ninetoothed/aot.py +217 -0
- ninetoothed-0.15.0/src/ninetoothed/cudaifier.py +36 -0
- ninetoothed-0.15.0/src/ninetoothed/dtype.py +13 -0
- ninetoothed-0.13.0/src/ninetoothed/jit.py → ninetoothed-0.15.0/src/ninetoothed/generation.py +286 -110
- ninetoothed-0.15.0/src/ninetoothed/jit.py +77 -0
- ninetoothed-0.15.0/src/ninetoothed/make.py +45 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/src/ninetoothed/tensor.py +15 -3
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_addmm.py +21 -35
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_max_pool2d.py +34 -12
- ninetoothed-0.13.0/src/ninetoothed/__init__.py +0 -5
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/.gitattributes +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/.github/workflows/publish-to-pypi.yml +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/.github/workflows/pytest.yml +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/.github/workflows/ruff.yml +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/.github/workflows/sphinx.yml +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/.gitignore +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/LICENSE +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/README.md +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/Makefile +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/README.zh.md +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/make.bat +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/requirements.txt +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/_static/matmul-tiling.png +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/_static/ninetoothed-logo.png +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/_static/vecadd-tiling.png +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/code_generation.rst +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/conf.py +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/index.rst +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/installation.rst +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/python_api.rst +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/symbol.rst +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/tensor.rst +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/visualization.rst +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/requirements.txt +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/src/ninetoothed/language.py +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/src/ninetoothed/naming.py +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/src/ninetoothed/symbol.py +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/src/ninetoothed/torchifier.py +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/src/ninetoothed/visualization.py +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/__init__.py +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/skippers.py +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_add.py +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_attention.py +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_conv2d.py +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_matmul.py +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_naming.py +0 -0
- {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_softmax.py +0 -0
@@ -0,0 +1,55 @@
|
|
1
|
+
name: 🐛 Bug report
|
2
|
+
description: Something isn't working as expected 🤔.
|
3
|
+
labels: ["bug"]
|
4
|
+
|
5
|
+
body:
|
6
|
+
- type: markdown
|
7
|
+
attributes:
|
8
|
+
value: Thanks for taking the time to fill out this bug report!
|
9
|
+
|
10
|
+
- type: checkboxes
|
11
|
+
attributes:
|
12
|
+
label: Is there an existing issue for this?
|
13
|
+
description: >
|
14
|
+
Please search to see if an issue already exists
|
15
|
+
for the bug you encountered.
|
16
|
+
options:
|
17
|
+
- label: I have searched the existing issues.
|
18
|
+
required: true
|
19
|
+
|
20
|
+
- type: textarea
|
21
|
+
attributes:
|
22
|
+
label: "Describe the bug:"
|
23
|
+
description: A clear and concise description of what the bug is.
|
24
|
+
validations:
|
25
|
+
required: false
|
26
|
+
|
27
|
+
- type: textarea
|
28
|
+
attributes:
|
29
|
+
label: "To reproduce:"
|
30
|
+
description: >
|
31
|
+
Steps to reproduce the behavior.
|
32
|
+
If applicable, provide a small, self-contained piece of code
|
33
|
+
that can be run directly to reproduce the issue.
|
34
|
+
validations:
|
35
|
+
required: false
|
36
|
+
|
37
|
+
- type: textarea
|
38
|
+
attributes:
|
39
|
+
label: "Expected behavior:"
|
40
|
+
description: >
|
41
|
+
A clear and concise description of what you expected to happen.
|
42
|
+
validations:
|
43
|
+
required: false
|
44
|
+
|
45
|
+
- type: textarea
|
46
|
+
attributes:
|
47
|
+
label: "Environment details:"
|
48
|
+
description: >
|
49
|
+
Please include your NineToothed version, operating system,
|
50
|
+
hardware platform, and any relevant information.
|
51
|
+
If you are using PyTorch, please run
|
52
|
+
`python -m torch.utils.collect_env` to gather
|
53
|
+
environment information.
|
54
|
+
validations:
|
55
|
+
required: false
|
@@ -0,0 +1,13 @@
|
|
1
|
+
name: 🚀 Feature request
|
2
|
+
description: I have a suggestion 🙂!
|
3
|
+
|
4
|
+
body:
|
5
|
+
- type: textarea
|
6
|
+
attributes:
|
7
|
+
label: "Description & motivation:"
|
8
|
+
description: >
|
9
|
+
Please describe the feature that you would like to see and
|
10
|
+
explain the problem it would solve or
|
11
|
+
the benefit it would provide.
|
12
|
+
validations:
|
13
|
+
required: true
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.13.0
|
3
|
+
Version: 0.15.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "ninetoothed"
|
7
|
-
version = "0.13.0"
|
7
|
+
version = "0.15.0"
|
8
8
|
authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
|
9
9
|
description = "A domain-specific language based on Triton but providing higher-level abstraction."
|
10
10
|
readme = "README.md"
|
@@ -0,0 +1,35 @@
|
|
1
|
+
from ninetoothed.dtype import (
|
2
|
+
float16,
|
3
|
+
float32,
|
4
|
+
float64,
|
5
|
+
int8,
|
6
|
+
int16,
|
7
|
+
int32,
|
8
|
+
int64,
|
9
|
+
uint8,
|
10
|
+
uint16,
|
11
|
+
uint32,
|
12
|
+
uint64,
|
13
|
+
)
|
14
|
+
from ninetoothed.jit import jit
|
15
|
+
from ninetoothed.make import make
|
16
|
+
from ninetoothed.symbol import Symbol
|
17
|
+
from ninetoothed.tensor import Tensor
|
18
|
+
|
19
|
+
# Public package API. The order is kept exactly as published: the two core
# classes first, then dtype names and entry points in natural-sort order.
__all__ = [
    # Core classes.
    "Symbol",
    "Tensor",
    # Floating-point dtypes.
    "float16",
    "float32",
    "float64",
    # Signed integer dtypes.
    "int8",
    "int16",
    "int32",
    "int64",
    # Entry points.
    "jit",
    "make",
    # Unsigned integer dtypes.
    "uint8",
    "uint16",
    "uint32",
    "uint64",
]
|
@@ -0,0 +1,217 @@
|
|
1
|
+
import ast
|
2
|
+
import pathlib
|
3
|
+
import subprocess
|
4
|
+
import tempfile
|
5
|
+
import uuid
|
6
|
+
|
7
|
+
from ninetoothed.dtype import int64, uint64
|
8
|
+
from ninetoothed.generation import CACHE_DIR, CodeGenerator
|
9
|
+
from ninetoothed.tensor import Tensor
|
10
|
+
|
11
|
+
|
12
|
+
def aot(
    func, caller="cuda", kernel_name=None, output_dir=None, num_warps=4, num_stages=3
):
    """Compile ``func`` ahead of time and write the generated C files.

    Args:
        func: The function to compile; passed to the code generator.
        caller: Forwarded to the code generator (default ``"cuda"``).
        kernel_name: Base name used for the kernel and for the written
            files (``<kernel_name>.c`` / ``<kernel_name>.h``).
            NOTE(review): ``None`` is forwarded as-is and would yield files
            named ``None.*`` — callers presumably always pass a name.
        output_dir: Directory the files are written into. ``None`` (the
            default) means the current working directory; previously it
            raised ``TypeError`` from ``pathlib.Path(None)``.
        num_warps: Forwarded to ``triton.tools.compile``.
        num_stages: Forwarded to ``triton.tools.compile``.
    """
    if output_dir is None:
        output_dir = pathlib.Path.cwd()
    else:
        output_dir = pathlib.Path(output_dir)

    output_contents = _aot(func, caller, kernel_name, num_warps, num_stages)

    for output_name, output_content in output_contents.items():
        # Keys look like ``<kernel_name>.<hash>.c`` / ``.h``; keeping only the
        # last two characters drops the hash from the written file name.
        output_path = output_dir / f"{kernel_name}{output_name[-2:]}"
        output_path.write_text(output_content)
|
24
|
+
|
25
|
+
|
26
|
+
def _aot(func, caller, kernel_name, num_warps, num_stages):
    """Generate, compile, and patch the AOT C artifacts for ``func``.

    Returns:
        A dict mapping ``<kernel_name>.<signature_hash>.c`` and ``.h`` (plus
        any other files the Triton compiler emitted) to their text. The C
        source gets the unparsed launch wrapper appended, the header gets the
        wrapper's prototype, and both have ``<stdint.h>`` replaced by an
        include of the NineToothed header.

    Raises:
        ValueError: If a pointer parameter of the generated kernel does not
            correspond to any tensor known to the code generator.
    """

    def _find_tensor_by_source_name(tensors, name):
        for tensor in tensors:
            if tensor.source.name == name:
                return tensor

        # Falling through used to return ``None`` and crash later with an
        # opaque ``AttributeError``; report the offending name instead.
        raise ValueError(f"no tensor found with source name {name!r}")

    # ``parents=True`` so a nested cache directory can be created from scratch.
    _HEADER_PATH.parent.mkdir(parents=True, exist_ok=True)

    if not _HEADER_PATH.exists():
        _HEADER_PATH.write_text(_HEADER_CONTENT)

    code_generator = CodeGenerator()
    source_file = code_generator(
        func, caller=caller, kernel_name=kernel_name, prettify=False
    )

    tensors = code_generator.tensors
    kernel_func = code_generator.kernel_func
    launch_func = code_generator.launch_func

    # Build the Triton signature, one type per recognised kernel parameter.
    param_types = []

    for arg in kernel_func.args.args:
        param = arg.arg

        if match := Tensor.pointer_pattern().fullmatch(param):
            source_name = match.group(0).removesuffix("_pointer")
            tensor = _find_tensor_by_source_name(tensors, source_name)
            param_types.append(f"*{tensor.source.dtype}")
        elif Tensor.size_pattern().fullmatch(param):
            param_types.append(uint64)
        elif Tensor.stride_pattern().fullmatch(param):
            param_types.append(int64)
        # NOTE(review): a parameter matching none of the patterns adds no
        # entry, so the signature would be shorter than the kernel's argument
        # list — confirm such parameters cannot occur here.

    signature = ", ".join(param_types)

    # Rewrite floor divisions and subscripted/attribute calls in the launch
    # function, then read the first grid dimension from the raw grid lambda;
    # the remaining two grid dimensions are fixed at 1.
    grid_extractor = _GridExtractor()
    launch_func = grid_extractor.visit(launch_func)
    grid_extractor.visit(code_generator.raw_grid)
    grid = f"{ast.unparse(grid_extractor.grid[0])}, 1, 1"

    signature_hash, output_contents = _compile(
        source_file, kernel_name, signature, grid, num_warps, num_stages
    )

    unparser = _Unparser()

    launch_func_unparsed = unparser.unparse(launch_func)
    launch_func_unparsed = launch_func_unparsed.replace(
        func.__name__, f"{kernel_name}_{signature_hash}"
    )

    header_include = f'"{_HEADER_PATH}"'

    c_source_file_name = f"{kernel_name}.{signature_hash}.c"
    c_source_file = f"{output_contents[c_source_file_name]}\n{launch_func_unparsed}\n"
    output_contents[c_source_file_name] = c_source_file.replace(
        "<stdint.h>", header_include
    )

    c_header_file_name = f"{kernel_name}.{signature_hash}.h"
    c_header_file = f"{output_contents[c_header_file_name]}\n{unparser.header};\n"
    output_contents[c_header_file_name] = c_header_file.replace(
        "<stdint.h>", header_include
    )

    return output_contents
|
93
|
+
|
94
|
+
|
95
|
+
_HEADER_CONTENT = """#include <stdint.h>
|
96
|
+
|
97
|
+
typedef struct {
|
98
|
+
uintptr_t data;
|
99
|
+
uint64_t *shape;
|
100
|
+
int64_t *strides;
|
101
|
+
} NineToothedTensor;
|
102
|
+
"""
|
103
|
+
|
104
|
+
# Location of the shared NineToothed C header inside the cache directory.
_HEADER_PATH = CACHE_DIR.joinpath("ninetoothed.h")
|
105
|
+
|
106
|
+
|
107
|
+
class _Unparser:
|
108
|
+
def unparse(self, node):
|
109
|
+
method_name = "_unparse_" + node.__class__.__name__
|
110
|
+
|
111
|
+
if hasattr(self, method_name):
|
112
|
+
return getattr(self, method_name)(node)
|
113
|
+
|
114
|
+
return self._generic_unparse(node)
|
115
|
+
|
116
|
+
def _generic_unparse(self, node):
|
117
|
+
return ast.unparse(node)
|
118
|
+
|
119
|
+
def _unparse_Expr(self, node):
|
120
|
+
return self.unparse(node.value)
|
121
|
+
|
122
|
+
def _unparse_Call(self, node):
|
123
|
+
call = ast.Call(
|
124
|
+
func=node.func,
|
125
|
+
args=[ast.Name(id="stream", ctx=ast.Load())] + node.args,
|
126
|
+
keywords=[],
|
127
|
+
)
|
128
|
+
|
129
|
+
return f"return {self._generic_unparse(call)};"
|
130
|
+
|
131
|
+
def _unparse_FunctionDef(self, node):
|
132
|
+
params = ["CUstream stream"]
|
133
|
+
params += [f"NineToothedTensor {arg.arg}" for arg in node.args.args]
|
134
|
+
header = f"CUresult {node.name}({', '.join(params)})"
|
135
|
+
|
136
|
+
self.header = header
|
137
|
+
|
138
|
+
body_lines = []
|
139
|
+
|
140
|
+
for stmt in node.body:
|
141
|
+
stmt_unparsed = self.unparse(stmt)
|
142
|
+
|
143
|
+
if isinstance(stmt, ast.Expr):
|
144
|
+
stmt_unparsed = stmt_unparsed.strip()
|
145
|
+
|
146
|
+
if not stmt_unparsed.endswith(";"):
|
147
|
+
stmt_unparsed += ";"
|
148
|
+
|
149
|
+
body_lines.append(" " + stmt_unparsed)
|
150
|
+
|
151
|
+
body = "\n".join(body_lines)
|
152
|
+
|
153
|
+
return f"{header} {{\n{body}\n}}"
|
154
|
+
|
155
|
+
|
156
|
+
class _GridExtractor(ast.NodeTransformer):
|
157
|
+
def visit_BinOp(self, node):
|
158
|
+
self.generic_visit(node)
|
159
|
+
|
160
|
+
if isinstance(node.op, ast.FloorDiv):
|
161
|
+
node.op = ast.Div()
|
162
|
+
|
163
|
+
return node
|
164
|
+
|
165
|
+
def visit_Call(self, node):
|
166
|
+
self.generic_visit(node)
|
167
|
+
|
168
|
+
node.func = node.func.value
|
169
|
+
|
170
|
+
return node
|
171
|
+
|
172
|
+
def visit_Lambda(self, node):
|
173
|
+
self.generic_visit(node)
|
174
|
+
|
175
|
+
self.grid = node.body.elts
|
176
|
+
|
177
|
+
return node
|
178
|
+
|
179
|
+
|
180
|
+
def _compile(path, name, signature, grid, num_warps, num_stages):
    """Run ``triton.tools.compile`` on ``path`` and collect its outputs.

    Args:
        path: Python source file containing the kernel.
        name: Kernel name; also substituted for the temporary file prefix
            in the returned file names.
        signature: Triton signature string (comma-separated types).
        grid: Grid string, e.g. ``"<expr>, 1, 1"``.
        num_warps: Passed as ``--num-warps``.
        num_stages: Passed as ``--num-stages``.

    Returns:
        ``(signature_hash, output_contents)``. ``output_contents`` maps each
        generated file name (temporary prefix replaced by ``name``) to its
        text.

    Raises:
        subprocess.CalledProcessError: If the Triton compiler exits non-zero.
        RuntimeError: If the compiler produced no output files.
    """
    import sys

    with tempfile.TemporaryDirectory() as temp_dir:
        output_dir = pathlib.Path(temp_dir)
        # A random prefix lets the outputs be found by glob without knowing
        # the signature hash the compiler inserts into their names.
        output_name = uuid.uuid4().hex
        output_path = output_dir / output_name

        command = [
            # Use the running interpreter, not whatever ``python`` is on
            # PATH (which may lack Triton, be a different version, or be
            # absent entirely).
            sys.executable,
            "-m",
            "triton.tools.compile",
            str(path),
            "--kernel-name",
            str(name),
            "--signature",
            str(signature),
            "--grid",
            str(grid),
            "--num-warps",
            str(num_warps),
            "--num-stages",
            str(num_stages),
            "--out-path",
            str(output_path),
        ]

        subprocess.run(command, check=True)

        matching_files = sorted(output_dir.glob(f"{output_name}.*"))

        if not matching_files:
            raise RuntimeError(
                f"triton.tools.compile produced no output files for {path}"
            )

        # Outputs are named ``<output_name>.<signature_hash>.<ext>``.
        signature_hash = matching_files[0].name.split(".")[1]

        output_contents = {
            file.name.replace(output_name, name): file.read_text()
            for file in matching_files
        }

    return signature_hash, output_contents
|
@@ -0,0 +1,36 @@
|
|
1
|
+
import ast
|
2
|
+
|
3
|
+
import ninetoothed.naming as naming
|
4
|
+
from ninetoothed.tensor import Tensor
|
5
|
+
|
6
|
+
|
7
|
+
class Cudaifier(ast.NodeTransformer):
    """Rewrite generated Python names into field accesses on C tensors.

    For each ``Name`` node (other than those ``naming.is_constexpr``
    accepts), the following rewrites are applied in order:

    * a pointer-pattern match becomes ``<g1>.data``;
    * a size-pattern match becomes ``<g1>.shape[<g3>]``;
    * a stride-pattern match becomes ``<g1>.strides[<g3>]``;
    * a trailing ``_with_auto_tuning`` is removed.

    Here ``<g1>``/``<g3>`` are the capture groups of the corresponding
    ``Tensor`` patterns. If anything changed, the node is replaced by the
    parsed result; otherwise it is returned unchanged.
    """

    def visit_Name(self, node):
        self.generic_visit(node)

        original = node.id

        if naming.is_constexpr(original):
            return node

        rewritten = Tensor.pointer_pattern().sub(
            lambda m: f"{m.group(1)}.data", original
        )
        rewritten = Tensor.size_pattern().sub(
            lambda m: f"{m.group(1)}.shape[{m.group(3)}]", rewritten
        )
        rewritten = Tensor.stride_pattern().sub(
            lambda m: f"{m.group(1)}.strides[{m.group(3)}]", rewritten
        )
        rewritten = rewritten.removesuffix("_with_auto_tuning")

        if rewritten == original:
            return node

        return ast.parse(rewritten, mode="eval").body
|