ninetoothed 0.13.0__tar.gz → 0.15.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. ninetoothed-0.15.0/.github/ISSUE_TEMPLATE/bug-report.yml +55 -0
  2. ninetoothed-0.15.0/.github/ISSUE_TEMPLATE/feature-request.yml +13 -0
  3. ninetoothed-0.15.0/.github/pull_request_template.md +5 -0
  4. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/PKG-INFO +1 -1
  5. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/pyproject.toml +1 -1
  6. ninetoothed-0.15.0/src/ninetoothed/__init__.py +35 -0
  7. ninetoothed-0.15.0/src/ninetoothed/aot.py +217 -0
  8. ninetoothed-0.15.0/src/ninetoothed/cudaifier.py +36 -0
  9. ninetoothed-0.15.0/src/ninetoothed/dtype.py +13 -0
  10. ninetoothed-0.13.0/src/ninetoothed/jit.py → ninetoothed-0.15.0/src/ninetoothed/generation.py +286 -110
  11. ninetoothed-0.15.0/src/ninetoothed/jit.py +77 -0
  12. ninetoothed-0.15.0/src/ninetoothed/make.py +45 -0
  13. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/src/ninetoothed/tensor.py +15 -3
  14. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_addmm.py +21 -35
  15. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_max_pool2d.py +34 -12
  16. ninetoothed-0.13.0/src/ninetoothed/__init__.py +0 -5
  17. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/.gitattributes +0 -0
  18. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/.github/workflows/publish-to-pypi.yml +0 -0
  19. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/.github/workflows/pytest.yml +0 -0
  20. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/.github/workflows/ruff.yml +0 -0
  21. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/.github/workflows/sphinx.yml +0 -0
  22. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/.gitignore +0 -0
  23. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/LICENSE +0 -0
  24. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/README.md +0 -0
  25. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/Makefile +0 -0
  26. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/README.zh.md +0 -0
  27. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/make.bat +0 -0
  28. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/requirements.txt +0 -0
  29. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/_static/matmul-tiling.png +0 -0
  30. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/_static/ninetoothed-logo.png +0 -0
  31. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/_static/vecadd-tiling.png +0 -0
  32. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/code_generation.rst +0 -0
  33. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/conf.py +0 -0
  34. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/index.rst +0 -0
  35. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/installation.rst +0 -0
  36. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/python_api.rst +0 -0
  37. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/symbol.rst +0 -0
  38. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/tensor.rst +0 -0
  39. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/docs/source/visualization.rst +0 -0
  40. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/requirements.txt +0 -0
  41. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/src/ninetoothed/language.py +0 -0
  42. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/src/ninetoothed/naming.py +0 -0
  43. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/src/ninetoothed/symbol.py +0 -0
  44. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/src/ninetoothed/torchifier.py +0 -0
  45. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/src/ninetoothed/visualization.py +0 -0
  46. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/__init__.py +0 -0
  47. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/skippers.py +0 -0
  48. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_add.py +0 -0
  49. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_attention.py +0 -0
  50. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_conv2d.py +0 -0
  51. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_matmul.py +0 -0
  52. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_naming.py +0 -0
  53. {ninetoothed-0.13.0 → ninetoothed-0.15.0}/tests/test_softmax.py +0 -0
@@ -0,0 +1,55 @@
1
+ name: 🐛 Bug report
2
+ description: Something isn't working as expected 🤔.
3
+ labels: ["bug"]
4
+
5
+ body:
6
+ - type: markdown
7
+ attributes:
8
+ value: Thanks for taking the time to fill out this bug report!
9
+
10
+ - type: checkboxes
11
+ attributes:
12
+ label: Is there an existing issue for this?
13
+ description: >
14
+ Please search to see if an issue already exists
15
+ for the bug you encountered.
16
+ options:
17
+ - label: I have searched the existing issues.
18
+ required: true
19
+
20
+ - type: textarea
21
+ attributes:
22
+ label: "Describe the bug:"
23
+ description: A clear and concise description of what the bug is.
24
+ validations:
25
+ required: false
26
+
27
+ - type: textarea
28
+ attributes:
29
+ label: "To reproduce:"
30
+ description: >
31
+ Steps to reproduce the behavior.
32
+ If applicable, provide a small, self-contained piece of code
33
+ that can be run directly to reproduce the issue.
34
+ validations:
35
+ required: false
36
+
37
+ - type: textarea
38
+ attributes:
39
+ label: "Expected behavior:"
40
+ description: >
41
+ A clear and concise description of what you expected to happen.
42
+ validations:
43
+ required: false
44
+
45
+ - type: textarea
46
+ attributes:
47
+ label: "Environment details:"
48
+ description: >
49
+ Please include your NineToothed version, operating system,
50
+ hardware platform, and any relevant information.
51
+ If you are using PyTorch, please run
52
+ `python -m torch.utils.collect_env` to gather
53
+ environment information.
54
+ validations:
55
+ required: false
@@ -0,0 +1,13 @@
1
+ name: 🚀 Feature request
2
+ description: I have a suggestion 🙂!
3
+
4
+ body:
5
+ - type: textarea
6
+ attributes:
7
+ label: "Description & motivation:"
8
+ description: >
9
+ Please describe the feature that you would like to see and
10
+ explain the problem it would solve or
11
+ the benefit it would provide.
12
+ validations:
13
+ required: true
@@ -0,0 +1,5 @@
1
+ `pytest` output:
2
+
3
+ ```
4
+
5
+ ```
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.13.0
3
+ Version: 0.15.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "ninetoothed"
7
- version = "0.13.0"
7
+ version = "0.15.0"
8
8
  authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
9
9
  description = "A domain-specific language based on Triton but providing higher-level abstraction."
10
10
  readme = "README.md"
@@ -0,0 +1,35 @@
1
+ from ninetoothed.dtype import (
2
+ float16,
3
+ float32,
4
+ float64,
5
+ int8,
6
+ int16,
7
+ int32,
8
+ int64,
9
+ uint8,
10
+ uint16,
11
+ uint32,
12
+ uint64,
13
+ )
14
+ from ninetoothed.jit import jit
15
+ from ninetoothed.make import make
16
+ from ninetoothed.symbol import Symbol
17
+ from ninetoothed.tensor import Tensor
18
+
19
# Names re-exported at package level (``from ninetoothed import *``): the
# `Symbol` and `Tensor` classes, `jit` and `make`, and the dtype constants
# imported from `ninetoothed.dtype`.
__all__ = [
    "Symbol",
    "Tensor",
    "float16",
    "float32",
    "float64",
    "int8",
    "int16",
    "int32",
    "int64",
    "jit",
    "make",
    "uint8",
    "uint16",
    "uint32",
    "uint64",
]
@@ -0,0 +1,217 @@
1
+ import ast
2
+ import pathlib
3
+ import subprocess
4
+ import tempfile
5
+ import uuid
6
+
7
+ from ninetoothed.dtype import int64, uint64
8
+ from ninetoothed.generation import CACHE_DIR, CodeGenerator
9
+ from ninetoothed.tensor import Tensor
10
+
11
+
12
def aot(
    func, caller="cuda", kernel_name=None, output_dir=None, num_warps=4, num_stages=3
):
    """Compile ``func`` ahead of time into a C source file and a C header.

    Writes ``<kernel_name>.c`` and ``<kernel_name>.h`` into ``output_dir``.

    :param func: The function to compile.
    :param caller: Caller target forwarded to the code generator.
    :param kernel_name: Name forwarded to the code generator and Triton's AOT
        compiler; also the base name of the written files.
    :param output_dir: Directory to write the files into. Defaults to the
        current working directory; created (with parents) if missing.
    :param num_warps: Forwarded to Triton's AOT compiler as ``--num-warps``.
    :param num_stages: Forwarded to Triton's AOT compiler as ``--num-stages``.
    """
    # NOTE(review): `kernel_name=None` is passed through unchanged and would
    # surface as the literal text "None" in kernel and file names; confirm
    # whether callers are expected to always supply it.

    # `pathlib.Path(None)` raises TypeError, so the documented default of
    # `None` previously could not work; fall back to the current directory.
    output_dir = pathlib.Path.cwd() if output_dir is None else pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    output_contents = _aot(func, caller, kernel_name, num_warps, num_stages)

    for output_name, output_content in output_contents.items():
        # Generated names look like "<kernel_name>.<signature hash>.<ext>";
        # keep only the final extension so files land as "<kernel_name>.<ext>".
        suffix = pathlib.Path(output_name).suffix
        output_path = output_dir / f"{kernel_name}{suffix}"

        output_path.write_text(output_content)
24
+
25
+
26
def _aot(func, caller, kernel_name, num_warps, num_stages):
    """Generate, AOT-compile, and post-process the C sources for ``func``.

    :param func: The function to compile.
    :param caller: Caller target forwarded to :class:`CodeGenerator`.
    :param kernel_name: Forwarded to the code generator and to
        ``triton.tools.compile`` as ``--kernel-name``.
    :param num_warps: Forwarded to ``triton.tools.compile``.
    :param num_stages: Forwarded to ``triton.tools.compile``.
    :return: A dict mapping ``"<kernel_name>.<signature hash>.<ext>"`` names
        to file contents. The ``.c`` entry has the C launch function appended,
        the ``.h`` entry has its prototype appended, and both include the
        shared NineToothed header in place of ``<stdint.h>``.
    :raises ValueError: If a pointer parameter of the generated kernel does
        not correspond to any tensor known to the code generator.
    """

    def _find_tensor_by_source_name(tensors, name):
        for tensor in tensors:
            if tensor.source.name == name:
                return tensor

        # Fail here with a clear message instead of returning `None` and
        # crashing later on `None.source`.
        raise ValueError(f"No tensor with source name `{name}` was found.")

    # `parents=True` so a cache directory whose parents do not exist yet can
    # still be created.
    _HEADER_PATH.parent.mkdir(parents=True, exist_ok=True)

    # Also rewrite a header left behind with different content (e.g. by
    # another version of this module), so generated sources never include a
    # stale `NineToothedTensor` definition.
    if not _HEADER_PATH.exists() or _HEADER_PATH.read_text() != _HEADER_CONTENT:
        _HEADER_PATH.write_text(_HEADER_CONTENT)

    code_generator = CodeGenerator()
    source_file = code_generator(
        func, caller=caller, kernel_name=kernel_name, prettify=False
    )

    tensors = code_generator.tensors
    kernel_func = code_generator.kernel_func
    launch_func = code_generator.launch_func

    # Build the Triton AOT signature from the kernel's parameter names:
    # pointers are typed by their tensor's dtype, sizes as `u64` and strides
    # as `i64` (matching `NineToothedTensor.shape` / `.strides`). Parameters
    # matching none of these patterns are not added to the signature.
    param_types = []

    for arg in kernel_func.args.args:
        param = arg.arg

        if match := Tensor.pointer_pattern().fullmatch(param):
            source_name = match.group(0).removesuffix("_pointer")
            tensor = _find_tensor_by_source_name(tensors, source_name)
            dtype = tensor.source.dtype

            param_types.append(f"*{dtype}")
        elif Tensor.size_pattern().fullmatch(param):
            param_types.append(uint64)
        elif Tensor.stride_pattern().fullmatch(param):
            param_types.append(int64)

    signature = ", ".join(param_types)

    # `_GridExtractor` rewrites the launch function (replacing each call's
    # callee by its `.value`, presumably turning `kernel[grid](...)` into
    # `kernel(...)`) and records the grid lambda's tuple elements. Only the
    # first grid dimension is used; the other two are fixed at 1.
    grid_extractor = _GridExtractor()
    launch_func = grid_extractor.visit(launch_func)
    grid_extractor.visit(code_generator.raw_grid)
    grid = f"{ast.unparse(grid_extractor.grid[0])}, 1, 1"

    signature_hash, output_contents = _compile(
        source_file, kernel_name, signature, grid, num_warps, num_stages
    )

    unparser = _Unparser()

    # Every occurrence of `func.__name__` in the C launch function is replaced
    # with the hash-suffixed name; presumably this retargets the kernel call
    # at the function emitted by the Triton compiler.
    launch_func_unparsed = unparser.unparse(launch_func)
    launch_func_unparsed = launch_func_unparsed.replace(
        func.__name__, f"{kernel_name}_{signature_hash}"
    )

    c_source_file_name = f"{kernel_name}.{signature_hash}.c"
    c_source_file = output_contents[c_source_file_name]
    c_source_file = f"{c_source_file}\n{launch_func_unparsed}\n"
    c_source_file = c_source_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
    output_contents[c_source_file_name] = c_source_file

    c_header_file_name = f"{kernel_name}.{signature_hash}.h"
    c_header_file = output_contents[c_header_file_name]
    c_header_file = f"{c_header_file}\n{unparser.header};\n"
    c_header_file = c_header_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
    output_contents[c_header_file_name] = c_header_file

    return output_contents
93
+
94
+
95
# C header shared by all AOT-generated sources. It defines `NineToothedTensor`,
# the struct type of every tensor parameter of the generated C launch
# functions: a data address plus pointers to `uint64_t` sizes and `int64_t`
# strides. `_aot` writes it to `_HEADER_PATH` and makes generated files include
# it in place of `<stdint.h>`; the header itself includes `<stdint.h>`.
_HEADER_CONTENT = """#include <stdint.h>

typedef struct {
uintptr_t data;
uint64_t *shape;
int64_t *strides;
} NineToothedTensor;
"""

# NOTE(review): generated files include this header by its absolute path, so
# they only compile on machines where this cache location exists; confirm
# whether that is intended for distributed sources.
_HEADER_PATH = CACHE_DIR / "ninetoothed.h"
105
+
106
+
107
class _Unparser:
    """Render a Python launch-function AST as C source text.

    :meth:`unparse` dispatches on the node's class name to a matching
    ``_unparse_<ClassName>`` method and falls back to :func:`ast.unparse`.
    Unparsing a ``FunctionDef`` also stores its C prototype in ``self.header``.
    """

    def unparse(self, node):
        handler = getattr(self, f"_unparse_{type(node).__name__}", None)

        if handler is None:
            return self._generic_unparse(node)

        return handler(node)

    def _generic_unparse(self, node):
        # Fallback: emit the node with Python syntax unchanged.
        return ast.unparse(node)

    def _unparse_Expr(self, node):
        # An expression statement renders as the expression it wraps.
        return self.unparse(node.value)

    def _unparse_Call(self, node):
        # Every call becomes a `return` of the same callee with `stream`
        # prepended to its positional arguments; keyword arguments are dropped.
        stream = ast.Name(id="stream", ctx=ast.Load())
        forwarded = ast.Call(func=node.func, args=[stream, *node.args], keywords=[])

        return "return " + self._generic_unparse(forwarded) + ";"

    def _unparse_FunctionDef(self, node):
        # Prototype: `CUresult <name>(CUstream stream, NineToothedTensor <arg>, ...)`.
        params = ", ".join(
            ["CUstream stream"]
            + [f"NineToothedTensor {arg.arg}" for arg in node.args.args]
        )
        self.header = f"CUresult {node.name}({params})"

        statements = []

        for stmt in node.body:
            text = self.unparse(stmt)

            if isinstance(stmt, ast.Expr):
                text = text.strip()

            # Terminate every statement C-style unless it already is.
            terminated = text if text.endswith(";") else text + ";"
            statements.append(" " + terminated)

        return self.header + " {\n" + "\n".join(statements) + "\n}"
154
+
155
+
156
class _GridExtractor(ast.NodeTransformer):
    """Adapt launch-related ASTs for C output and capture the launch grid.

    Children are transformed before their parents. Along the way, ``//``
    becomes ``/``, and every call's callee is replaced by its ``.value`` (for
    example, ``kernel[grid](...)`` becomes ``kernel(...)``). Visiting a
    ``lambda`` whose body is a tuple stores that tuple's elements in
    ``self.grid``.
    """

    def visit_BinOp(self, node):
        node = self.generic_visit(node)

        # NOTE(review): `/` matches `//` only for non-negative integer operands
        # in the generated C; presumably true for grid sizes, so confirm.
        if isinstance(node.op, ast.FloorDiv):
            node.op = ast.Div()

        return node

    def visit_Call(self, node):
        node = self.generic_visit(node)

        callee = node.func
        node.func = callee.value

        return node

    def visit_Lambda(self, node):
        node = self.generic_visit(node)

        grid_tuple = node.body
        self.grid = grid_tuple.elts

        return node
178
+
179
+
180
def _compile(path, name, signature, grid, num_warps, num_stages):
    """Run Triton's AOT compiler on ``path`` and collect the generated files.

    :param path: The Python source file containing the kernel.
    :param name: Kernel name, passed as ``--kernel-name``; it also replaces the
        temporary output stem in the returned file names.
    :param signature: Value for ``--signature``.
    :param grid: Value for ``--grid``.
    :param num_warps: Value for ``--num-warps``.
    :param num_stages: Value for ``--num-stages``.
    :return: ``(signature_hash, output_contents)``, where ``output_contents``
        maps ``"<name>.<signature hash>.<ext>"`` to each file's text.
    :raises subprocess.CalledProcessError: If the compiler exits non-zero.
    :raises RuntimeError: If the compiler produced no output files.
    """
    import sys

    with tempfile.TemporaryDirectory() as temp_dir:
        output_dir = pathlib.Path(temp_dir)
        # A random stem makes the generated files easy to find; it is swapped
        # for `name` in the returned file names.
        output_name = uuid.uuid4().hex
        output_path = output_dir / output_name

        command = [
            # Run the compiler with the current interpreter so it sees the same
            # environment (and Triton install) as the caller; a bare "python"
            # on PATH may be a different interpreter or may not exist at all.
            sys.executable or "python",
            "-m",
            "triton.tools.compile",
            str(path),
            "--kernel-name",
            str(name),
            "--signature",
            str(signature),
            "--grid",
            str(grid),
            "--num-warps",
            str(num_warps),
            "--num-stages",
            str(num_stages),
            "--out-path",
            str(output_path),
        ]

        subprocess.run(command, check=True)

        matching_files = sorted(output_dir.glob(f"{output_name}.*"))

        if not matching_files:
            raise RuntimeError(
                f"`triton.tools.compile` produced no output files for `{name}`."
            )

        # Generated files are named "<output_name>.<signature hash>.<ext>".
        signature_hash = matching_files[0].name.split(".")[1]

        output_contents = {
            file.name.replace(output_name, name): file.read_text()
            for file in matching_files
        }

    return signature_hash, output_contents
@@ -0,0 +1,36 @@
1
+ import ast
2
+
3
+ import ninetoothed.naming as naming
4
+ from ninetoothed.tensor import Tensor
5
+
6
+
7
class Cudaifier(ast.NodeTransformer):
    """Rewrite generated parameter names into ``NineToothedTensor`` field accesses.

    For each ``Name`` that is not a constexpr (per ``naming.is_constexpr``):
    pointer-pattern matches become ``<group 1>.data``, size-pattern matches
    become ``<group 1>.shape[<group 3>]``, and stride-pattern matches become
    ``<group 1>.strides[<group 3>]``. Group 1 is presumably the tensor name and
    group 3 the dimension index. A trailing ``_with_auto_tuning`` is removed.
    Names that end up unchanged are returned as-is; otherwise the rewritten
    text is parsed back into an expression node.
    """

    def visit_Name(self, node):
        self.generic_visit(node)

        original = node.id

        if naming.is_constexpr(original):
            return node

        rewritten = Tensor.pointer_pattern().sub(
            lambda m: f"{m.group(1)}.data", original
        )
        rewritten = Tensor.size_pattern().sub(
            lambda m: f"{m.group(1)}.shape[{m.group(3)}]", rewritten
        )
        rewritten = Tensor.stride_pattern().sub(
            lambda m: f"{m.group(1)}.strides[{m.group(3)}]", rewritten
        )
        rewritten = rewritten.removesuffix("_with_auto_tuning")

        if rewritten == original:
            return node

        return ast.parse(rewritten, mode="eval").body
@@ -0,0 +1,13 @@
1
# Dtype names as plain strings. `aot` builds Triton AOT signature entries from
# tensor dtypes (e.g. "*fp16" for a pointer) and uses `uint64`/`int64` for
# size and stride parameters.
int8, int16, int32, int64 = "i8", "i16", "i32", "i64"

uint8, uint16, uint32, uint64 = "u8", "u16", "u32", "u64"

float16, float32, float64 = "fp16", "fp32", "fp64"