ninetoothed 0.14.0__tar.gz → 0.15.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. ninetoothed-0.15.1/.github/ISSUE_TEMPLATE/bug-report.yml +55 -0
  2. ninetoothed-0.15.1/.github/ISSUE_TEMPLATE/feature-request.yml +13 -0
  3. ninetoothed-0.15.1/.github/pull_request_template.md +5 -0
  4. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/PKG-INFO +1 -1
  5. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/pyproject.toml +1 -1
  6. ninetoothed-0.15.1/src/ninetoothed/__init__.py +35 -0
  7. ninetoothed-0.15.1/src/ninetoothed/aot.py +217 -0
  8. ninetoothed-0.15.1/src/ninetoothed/cudaifier.py +36 -0
  9. ninetoothed-0.15.1/src/ninetoothed/dtype.py +13 -0
  10. ninetoothed-0.14.0/src/ninetoothed/jit.py → ninetoothed-0.15.1/src/ninetoothed/generation.py +83 -116
  11. ninetoothed-0.15.1/src/ninetoothed/jit.py +77 -0
  12. ninetoothed-0.15.1/src/ninetoothed/make.py +45 -0
  13. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/src/ninetoothed/tensor.py +1 -1
  14. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/src/ninetoothed/visualization.py +10 -4
  15. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_addmm.py +6 -7
  16. ninetoothed-0.15.1/tests/test_aot.py +153 -0
  17. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_conv2d.py +16 -2
  18. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_matmul.py +13 -6
  19. ninetoothed-0.14.0/src/ninetoothed/__init__.py +0 -5
  20. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/.gitattributes +0 -0
  21. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/.github/workflows/publish-to-pypi.yml +0 -0
  22. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/.github/workflows/pytest.yml +0 -0
  23. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/.github/workflows/ruff.yml +0 -0
  24. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/.github/workflows/sphinx.yml +0 -0
  25. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/.gitignore +0 -0
  26. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/LICENSE +0 -0
  27. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/README.md +0 -0
  28. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/Makefile +0 -0
  29. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/README.zh.md +0 -0
  30. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/make.bat +0 -0
  31. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/requirements.txt +0 -0
  32. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/_static/matmul-tiling.png +0 -0
  33. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/_static/ninetoothed-logo.png +0 -0
  34. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/_static/vecadd-tiling.png +0 -0
  35. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/code_generation.rst +0 -0
  36. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/conf.py +0 -0
  37. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/index.rst +0 -0
  38. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/installation.rst +0 -0
  39. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/python_api.rst +0 -0
  40. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/symbol.rst +0 -0
  41. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/tensor.rst +0 -0
  42. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/docs/source/visualization.rst +0 -0
  43. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/requirements.txt +0 -0
  44. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/src/ninetoothed/language.py +0 -0
  45. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/src/ninetoothed/naming.py +0 -0
  46. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/src/ninetoothed/symbol.py +0 -0
  47. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/src/ninetoothed/torchifier.py +0 -0
  48. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/__init__.py +0 -0
  49. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/skippers.py +0 -0
  50. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_add.py +0 -0
  51. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_attention.py +0 -0
  52. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_max_pool2d.py +0 -0
  53. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_naming.py +0 -0
  54. {ninetoothed-0.14.0 → ninetoothed-0.15.1}/tests/test_softmax.py +0 -0
ninetoothed-0.15.1/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -0,0 +1,55 @@
+name: 🐛 Bug report
+description: Something isn't working as expected 🤔.
+labels: ["bug"]
+
+body:
+  - type: markdown
+    attributes:
+      value: Thanks for taking the time to fill out this bug report!
+
+  - type: checkboxes
+    attributes:
+      label: Is there an existing issue for this?
+      description: >
+        Please search to see if an issue already exists
+        for the bug you encountered.
+      options:
+        - label: I have searched the existing issues.
+          required: true
+
+  - type: textarea
+    attributes:
+      label: "Describe the bug:"
+      description: A clear and concise description of what the bug is.
+    validations:
+      required: false
+
+  - type: textarea
+    attributes:
+      label: "To reproduce:"
+      description: >
+        Steps to reproduce the behavior.
+        If applicable, provide a small, self-contained piece of code
+        that can be run directly to reproduce the issue.
+    validations:
+      required: false
+
+  - type: textarea
+    attributes:
+      label: "Expected behavior:"
+      description: >
+        A clear and concise description of what you expected to happen.
+    validations:
+      required: false
+
+  - type: textarea
+    attributes:
+      label: "Environment details:"
+      description: >
+        Please include your NineToothed version, operating system,
+        hardware platform, and any relevant information.
+        If you are using PyTorch, please run
+        `python -m torch.utils.collect_env` to gather
+        environment information.
+    validations:
+      required: false
ninetoothed-0.15.1/.github/ISSUE_TEMPLATE/feature-request.yml
@@ -0,0 +1,13 @@
+name: 🚀 Feature request
+description: I have a suggestion 🙂!
+
+body:
+  - type: textarea
+    attributes:
+      label: "Description & motivation:"
+      description: >
+        Please describe the feature that you would like to see and
+        explain the problem it would solve or
+        the benefit it would provide.
+    validations:
+      required: true
ninetoothed-0.15.1/.github/pull_request_template.md
@@ -0,0 +1,5 @@
+`pytest` output:
+
+```
+
+```
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ninetoothed
-Version: 0.14.0
+Version: 0.15.1
 Summary: A domain-specific language based on Triton but providing higher-level abstraction.
 Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
 Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "ninetoothed"
-version = "0.14.0"
+version = "0.15.1"
 authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
 description = "A domain-specific language based on Triton but providing higher-level abstraction."
 readme = "README.md"
ninetoothed-0.15.1/src/ninetoothed/__init__.py
@@ -0,0 +1,35 @@
+from ninetoothed.dtype import (
+    float16,
+    float32,
+    float64,
+    int8,
+    int16,
+    int32,
+    int64,
+    uint8,
+    uint16,
+    uint32,
+    uint64,
+)
+from ninetoothed.jit import jit
+from ninetoothed.make import make
+from ninetoothed.symbol import Symbol
+from ninetoothed.tensor import Tensor
+
+__all__ = [
+    "Symbol",
+    "Tensor",
+    "float16",
+    "float32",
+    "float64",
+    "int8",
+    "int16",
+    "int32",
+    "int64",
+    "jit",
+    "make",
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+]
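The new top-level exports let the dtype constants be used directly when declaring kernel parameters. A minimal sketch, mirroring the usage in tests/test_aot.py later in this diff:

    import ninetoothed
    from ninetoothed import Tensor

    # A 2-D half-precision tensor parameter; ninetoothed.float16 is the
    # string "fp16" defined in src/ninetoothed/dtype.py below.
    lhs = Tensor(2, dtype=ninetoothed.float16)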
ninetoothed-0.15.1/src/ninetoothed/aot.py
@@ -0,0 +1,217 @@
+import ast
+import pathlib
+import subprocess
+import tempfile
+import uuid
+
+from ninetoothed.dtype import int64
+from ninetoothed.generation import CACHE_DIR, CodeGenerator
+from ninetoothed.tensor import Tensor
+
+
+def aot(
+    func, caller="cuda", kernel_name=None, output_dir=None, num_warps=4, num_stages=3
+):
+    output_dir = pathlib.Path(output_dir)
+
+    output_contents = _aot(func, caller, kernel_name, num_warps, num_stages)
+
+    for output_name, output_content in output_contents.items():
+        output_path = output_dir / f"{kernel_name}{output_name[-2:]}"
+
+        with open(output_path, "w") as f:
+            f.write(output_content)
+
+
+def _aot(func, caller, kernel_name, num_warps, num_stages):
+    def _find_tensor_by_source_name(tensors, name):
+        for tensor in tensors:
+            if tensor.source.name == name:
+                return tensor
+
+    _HEADER_PATH.parent.mkdir(exist_ok=True)
+
+    if not _HEADER_PATH.exists():
+        _HEADER_PATH.write_text(_HEADER_CONTENT)
+
+    code_generator = CodeGenerator()
+    source_file = code_generator(
+        func, caller=caller, kernel_name=kernel_name, prettify=False
+    )
+
+    tensors = code_generator.tensors
+    kernel_func = code_generator.kernel_func
+    launch_func = code_generator.launch_func
+
+    param_types = []
+
+    for arg in kernel_func.args.args:
+        param = arg.arg
+
+        if match := Tensor.pointer_pattern().fullmatch(param):
+            source_name = match.group(0).removesuffix("_pointer")
+            tensor = _find_tensor_by_source_name(tensors, source_name)
+            dtype = tensor.source.dtype
+
+            param_types.append(f"*{dtype}")
+        elif Tensor.size_pattern().fullmatch(param):
+            param_types.append(int64)
+        elif Tensor.stride_pattern().fullmatch(param):
+            param_types.append(int64)
+
+    signature = ", ".join(param_types)
+
+    grid_extractor = _GridExtractor()
+    launch_func = grid_extractor.visit(launch_func)
+    grid_extractor.visit(code_generator.raw_grid)
+    grid = f"{ast.unparse(grid_extractor.grid[0])}, 1, 1"
+
+    signature_hash, output_contents = _compile(
+        source_file, kernel_name, signature, grid, num_warps, num_stages
+    )
+
+    unparser = _Unparser()
+
+    launch_func_unparsed = unparser.unparse(launch_func)
+    launch_func_unparsed = launch_func_unparsed.replace(
+        func.__name__, f"{kernel_name}_{signature_hash}"
+    )
+
+    c_source_file_name = f"{kernel_name}.{signature_hash}.c"
+    c_source_file = output_contents[c_source_file_name]
+    c_source_file = f"{c_source_file}\n{launch_func_unparsed}\n"
+    c_source_file = c_source_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
+    output_contents[c_source_file_name] = c_source_file
+
+    c_header_file_name = f"{kernel_name}.{signature_hash}.h"
+    c_header_file = output_contents[c_header_file_name]
+    c_header_file = f"{c_header_file}\n{unparser.header};\n"
+    c_header_file = c_header_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
+    output_contents[c_header_file_name] = c_header_file
+
+    return output_contents
+
+
+_HEADER_CONTENT = """#include <stdint.h>
+
+typedef struct {
+    uintptr_t data;
+    uint64_t *shape;
+    int64_t *strides;
+} NineToothedTensor;
+"""
+
+_HEADER_PATH = CACHE_DIR / "ninetoothed.h"
+
+
+class _Unparser:
+    def unparse(self, node):
+        method_name = "_unparse_" + node.__class__.__name__
+
+        if hasattr(self, method_name):
+            return getattr(self, method_name)(node)
+
+        return self._generic_unparse(node)
+
+    def _generic_unparse(self, node):
+        return ast.unparse(node)
+
+    def _unparse_Expr(self, node):
+        return self.unparse(node.value)
+
+    def _unparse_Call(self, node):
+        call = ast.Call(
+            func=node.func,
+            args=[ast.Name(id="stream", ctx=ast.Load())] + node.args,
+            keywords=[],
+        )
+
+        return f"return {self._generic_unparse(call)};"
+
+    def _unparse_FunctionDef(self, node):
+        params = ["CUstream stream"]
+        params += [f"NineToothedTensor {arg.arg}" for arg in node.args.args]
+        header = f"CUresult {node.name}({', '.join(params)})"
+
+        self.header = header
+
+        body_lines = []
+
+        for stmt in node.body:
+            stmt_unparsed = self.unparse(stmt)
+
+            if isinstance(stmt, ast.Expr):
+                stmt_unparsed = stmt_unparsed.strip()
+
+                if not stmt_unparsed.endswith(";"):
+                    stmt_unparsed += ";"
+
+            body_lines.append(" " + stmt_unparsed)
+
+        body = "\n".join(body_lines)
+
+        return f"{header} {{\n{body}\n}}"
+
+
+class _GridExtractor(ast.NodeTransformer):
+    def visit_BinOp(self, node):
+        self.generic_visit(node)
+
+        if isinstance(node.op, ast.FloorDiv):
+            node.op = ast.Div()
+
+        return node
+
+    def visit_Call(self, node):
+        self.generic_visit(node)
+
+        node.func = node.func.value
+
+        return node
+
+    def visit_Lambda(self, node):
+        self.generic_visit(node)
+
+        self.grid = node.body.elts
+
+        return node
+
+
+def _compile(path, name, signature, grid, num_warps, num_stages):
+    with tempfile.TemporaryDirectory() as temp_dir:
+        output_dir = pathlib.Path(temp_dir)
+        output_name = uuid.uuid4().hex
+        output_path = output_dir / output_name
+
+        command = [
+            "python",
+            "-m",
+            "triton.tools.compile",
+            str(path),
+            "--kernel-name",
+            str(name),
+            "--signature",
+            str(signature),
+            "--grid",
+            str(grid),
+            "--num-warps",
+            str(num_warps),
+            "--num-stages",
+            str(num_stages),
+            "--out-path",
+            str(output_path),
+        ]
+
+        subprocess.run(command, check=True)
+
+        matching_files = list(output_dir.glob(f"{output_name}.*"))
+
+        signature_hash = matching_files[0].name.split(".")[1]
+
+        output_contents = {}
+
+        for file in matching_files:
+            with file.open() as f:
+                output_contents[file.name.replace(output_name, name)] = f.read()
+
+        return signature_hash, output_contents
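Taken together, `aot` generates the Triton source, shells out to `python -m triton.tools.compile`, and rewrites the resulting C files so the launch entry point takes `NineToothedTensor` structs. A hedged sketch of driving it directly (`kernel_func` is a placeholder for a function whose parameters have already been annotated, which `make` normally does; "build" is an arbitrary existing directory):

    from ninetoothed.aot import aot

    # Writes build/matmul.c and build/matmul.h; the shared ninetoothed.h
    # header lands in the cache directory (~/.ninetoothed).
    aot(kernel_func, caller="cuda", kernel_name="matmul", output_dir="build")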
ninetoothed-0.15.1/src/ninetoothed/cudaifier.py
@@ -0,0 +1,36 @@
+import ast
+
+import ninetoothed.naming as naming
+from ninetoothed.tensor import Tensor
+
+
+class Cudaifier(ast.NodeTransformer):
+    def visit_Name(self, node):
+        self.generic_visit(node)
+
+        source = node.id
+
+        if naming.is_constexpr(source):
+            return node
+
+        def repl(match):
+            return f"{match.group(1)}.data"
+
+        source = Tensor.pointer_pattern().sub(repl, source)
+
+        def repl(match):
+            return f"{match.group(1)}.shape[{match.group(3)}]"
+
+        source = Tensor.size_pattern().sub(repl, source)
+
+        def repl(match):
+            return f"{match.group(1)}.strides[{match.group(3)}]"
+
+        source = Tensor.stride_pattern().sub(repl, source)
+
+        source = source.removesuffix("_with_auto_tuning")
+
+        if source != node.id:
+            return ast.parse(source, mode="eval").body
+
+        return node
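`Cudaifier` rewrites the generated parameter names into member accesses on the `NineToothedTensor` arguments. A simplified stand-alone sketch of the substitution it performs; the real patterns come from `Tensor.pointer_pattern()`, `Tensor.size_pattern()`, and `Tensor.stride_pattern()`, so the regexes below are illustrative assumptions:

    import re

    source = "matmul(lhs_pointer, lhs_size_0, lhs_stride_1)"
    source = re.sub(r"(\w+)_pointer", r"\1.data", source)
    source = re.sub(r"(\w+)_size_(\d+)", r"\1.shape[\2]", source)
    source = re.sub(r"(\w+)_stride_(\d+)", r"\1.strides[\2]", source)
    print(source)  # matmul(lhs.data, lhs.shape[0], lhs.strides[1])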
ninetoothed-0.15.1/src/ninetoothed/dtype.py
@@ -0,0 +1,13 @@
+int8 = "i8"
+int16 = "i16"
+int32 = "i32"
+int64 = "i64"
+
+uint8 = "u8"
+uint16 = "u16"
+uint32 = "u32"
+uint64 = "u64"
+
+float16 = "fp16"
+float32 = "fp32"
+float64 = "fp64"
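These strings are Triton signature type codes. In `_aot` above, a pointer parameter contributes `f"*{dtype}"` and each size or stride contributes `int64`, so a signature fragment is assembled like this:

    from ninetoothed.dtype import float16, int64

    signature = ", ".join([f"*{float16}", int64, int64])  # "*fp16, i64, i64"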
ninetoothed-0.14.0/src/ninetoothed/jit.py → ninetoothed-0.15.1/src/ninetoothed/generation.py
@@ -2,84 +2,89 @@ import ast
 import collections
 import copy
 import functools
-import importlib.util
+import hashlib
 import inspect
 import itertools
 import math
+import pathlib
 import subprocess
-import sys
-import tempfile
 
 import triton
 
 import ninetoothed.naming as naming
+from ninetoothed.cudaifier import Cudaifier
 from ninetoothed.language import attribute, call
 from ninetoothed.symbol import Symbol
 from ninetoothed.tensor import Tensor
 from ninetoothed.torchifier import Torchifier
 
+CACHE_DIR = pathlib.Path.home() / ".ninetoothed"
 
-def make(arrangement, application, tensors):
-    """Integrate the arrangement and the application of the tensors.
 
-    :param arrangement: The arrangement of the tensors.
-    :param application: The application of the tensors.
-    :param tensors: The tensors.
-    :return: A handle to the compute kernel.
-    """
-    params = inspect.signature(application).parameters
-    types = arrangement(*tensors)
-    annotations = {param: type for param, type in zip(params, types)}
-    application.__annotations__ = annotations
+class CodeGenerator(ast.NodeTransformer):
+    def __init__(self):
+        super().__init__()
+
+        self._POWER_OF_TWOS = tuple(2**n for n in range(5, 11))
+
+        self._MIN_PRODUCT = 2**10
+
+        self._MAX_PRODUCT = 2**20
 
-    return jit(application)
+    def __call__(self, func, caller, kernel_name, prettify):
+        def _get_tree(func):
+            module = ast.parse(inspect.getsource(inspect.getmodule(func)))
 
+            collector = _ImportCollector()
+            collector.visit(module)
 
-def jit(func=None, *, _prettify=False):
-    """A decorator for generating compute kernels.
+            finder = _FunctionDefFinder(func.__name__)
+            finder.visit(module)
+            func_def = finder.result
 
-    :param func: The function to be compiled.
-    :param _prettify: Whether to prettify the generated code.
-    :return: A handle to the compute kernel.
+            inliner = _Inliner(func.__globals__)
+            inliner.visit(func_def)
+            module.body = collector.imports + inliner.imports + [finder.result]
 
-    .. note::
+            return _AliasRestorer().visit(module)
 
-        The ``_prettify`` parameter is experimental, which might break
-        the generated code.
-    """
+        def _find_dependencies(func):
+            dependencies = set()
 
-    def wrapper(func):
-        return JIT(func, _prettify=_prettify)()
+            for obj in func.__globals__.values():
+                if isinstance(obj, triton.runtime.JITFunction):
+                    dependencies.add(obj.src)
 
-    if func is None:
-        return wrapper
+            return "\n".join(
+                f"@triton.jit\n{dependency}" for dependency in dependencies
+            )
 
-    return wrapper(func)
+        self.launch_func_name = f"launch_{kernel_name}"
 
+        self._caller = caller
 
-class JIT:
-    def __init__(self, func, _prettify=False):
-        self.func = func
+        self._context = inspect.get_annotations(func)
 
-        self._prettify = _prettify
+        self._args = list(self._context.values())
 
-    def __call__(self):
-        tree = self._get_tree()
+        tree = _get_tree(func)
 
-        CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
+        self.visit(tree)
         Tritonizer().visit(tree)
         _BinOpSimplifier().visit(tree)
         ast.fix_missing_locations(tree)
 
-        if self._prettify:
+        if prettify:
             name_collector = _SimplifiedNameCollector()
             name_collector.visit(tree)
 
         unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
-        dependencies = self._find_dependencies()
+        dependencies = _find_dependencies(func)
         source = "\n\n".join((unparsed, dependencies)).strip()
+        source = source.replace(func.__name__, kernel_name)
+        source += "\n"
 
-        if self._prettify:
+        if prettify:
             for original, simplified in name_collector.simplified_names.items():
                 if simplified not in name_collector.simplified_names:
                     source = source.replace(original, simplified)
@@ -88,73 +93,29 @@ class JIT:
                 ["ruff", "format", "-"], input=source, encoding="utf-8"
             )
 
-        with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
-            temp_file.write(source.encode("utf-8"))
-            temp_file_name = temp_file.name
+        digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
+        cache_dir = CACHE_DIR
+        cache_dir.mkdir(exist_ok=True)
+        cache_file = cache_dir / f"{digest}.py"
 
-        module = type(self)._import_from_path(temp_file_name, temp_file_name)
-        module_vars = vars(module)
+        if not cache_file.exists():
+            with open(cache_file, "w", encoding="utf-8") as f:
+                f.write(source)
 
-        handle = _Handle(
-            module_vars[self.func.__name__],
-            module_vars[f"launch_{self.func.__name__}"],
-            source,
-        )
-
-        return handle
-
-    def _get_tree(self):
-        module = ast.parse(inspect.getsource(inspect.getmodule(self.func)))
-
-        collector = _ImportCollector()
-        collector.visit(module)
-
-        finder = _FunctionDefFinder(self.func.__name__)
-        finder.visit(module)
-        func_def = finder.result
+        self.tensors = self._args
+        self.kernel_func = self._func_def
+        self.launch_func = self._launch
 
-        inliner = _Inliner(self.func.__globals__)
-        inliner.visit(func_def)
-        module.body = collector.imports + inliner.imports + [finder.result]
-
-        return _AliasRestorer().visit(module)
-
-    def _find_dependencies(self):
-        dependencies = set()
-
-        for obj in self.func.__globals__.values():
-            if isinstance(obj, triton.runtime.JITFunction):
-                dependencies.add(obj.src)
-
-        return "\n".join(f"@triton.jit\n{dependency}" for dependency in dependencies)
-
-    @staticmethod
-    def _import_from_path(module_name, file_path):
-        spec = importlib.util.spec_from_file_location(module_name, file_path)
-        module = importlib.util.module_from_spec(spec)
-        sys.modules[module_name] = module
-        spec.loader.exec_module(module)
-
-        return module
-
-
-class CodeGenerator(ast.NodeTransformer):
-    def __init__(self, context):
-        super().__init__()
-
-        self._context = context
-
-        self._args = list(self._context.values())
-
-        self._POWER_OF_TWOS = tuple(2**n for n in range(5, 11))
-
-        self._MIN_PRODUCT = 2**10
-
-        self._MAX_PRODUCT = 2**20
+        return str(cache_file)
 
     def visit_Module(self, node):
         self.generic_visit(node)
 
+        func_with_auto_tuning = f"{Symbol(self._autotune)}({self._func_def.name})"
+
+        node.body.append(
+            ast.parse(f"{self._func_name_with_auto_tuning} = {func_with_auto_tuning}")
+        )
         node.body.append(self._launch)
 
         return node
@@ -162,6 +123,8 @@ class CodeGenerator(ast.NodeTransformer):
     def visit_FunctionDef(self, node):
         self._func_def = node
 
+        self._func_name_with_auto_tuning = f"{self._func_def.name}_with_auto_tuning"
+
        self._invariants = {}
 
        self.generic_visit(node)
@@ -184,6 +147,9 @@
             if naming.is_constexpr(name)
         }
 
+        non_meta_names = sorted(non_meta_names)
+        meta_names = sorted(meta_names)
+
         node.args = [
             ast.arg(arg=name)
             if not naming.is_constexpr(name)
@@ -194,8 +160,8 @@
             for name in meta_names
         ]
 
-        autotune = self._generate_autotune(non_meta_names, meta_names)
-        self._func_def.decorator_list = [autotune, Symbol("triton.jit").node]
+        self._autotune = self._generate_autotune(non_meta_names, meta_names)
+        self._func_def.decorator_list = [Symbol("triton.jit").node]
 
         self._launch = self._generate_launch(non_meta_names, meta_names)
@@ -354,7 +320,7 @@
         ]
 
         launch = ast.FunctionDef(
-            name=f"launch_{self._func_def.name}",
+            name=self.launch_func_name,
             args=ast.arguments(
                 posonlyargs=[],
                 args=[ast.arg(arg=arg.source.name) for arg in self._args]
@@ -392,7 +358,9 @@
                 ast.Expr(
                     ast.Call(
                         func=ast.Subscript(
-                            value=ast.Name(id=self._func_def.name, ctx=ast.Load()),
+                            value=ast.Name(
+                                id=self._func_name_with_auto_tuning, ctx=ast.Load()
+                            ),
                             slice=self._generate_grid(),
                             ctx=ast.Load(),
                         ),
@@ -422,14 +390,23 @@
 
         MetaEncloser(meta).visit(launch)
 
-        Torchifier().visit(launch)
+        if self._caller == "torch":
+            Torchifier().visit(launch)
+        elif self._caller == "cuda":
+            Cudaifier().visit(launch)
+        else:
+            raise ValueError(f"Unsupported caller: `{self._caller}`.")
 
         return launch
 
     def _generate_grid(self):
         num_elements = functools.reduce(lambda x, y: x * y, self._args[0].shape)
 
-        return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
+        grid = ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
+
+        self.raw_grid = copy.deepcopy(grid)
+
+        return grid
 
     def _generate_load(self, tensor, indices=()):
         if tensor.ndim == 0:
@@ -851,7 +828,7 @@ class _Inliner(ast.NodeTransformer):
            return names
 
        def _make_temporary():
-            prefix = naming.auto_generate(f"temporary_{self._count}")
+            prefix = f"{naming.auto_generate(f'temporary_{self._count}')}_"
            self._count += 1
 
            return prefix
@@ -941,16 +918,6 @@ class _SimplifiedNameCollector(ast.NodeVisitor):
         self.simplified_names[node.id] = naming.remove_prefixes(node.id)
 
 
-class _Handle:
-    def __init__(self, kernel, launch, source):
-        self._kernel = kernel
-        self._launch = launch
-        self._source = source
-
-    def __call__(self, *args, **kwargs):
-        return self._launch(*args, **kwargs)
-
-
 class _AliasRestorer(ast.NodeTransformer):
     def __init__(self):
         super().__init__()
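The `tempfile`-based module files are replaced by a content-addressed cache: identical generated sources now map to a single file under `~/.ninetoothed`, so repeated runs reuse it. A minimal sketch of the scheme, as inferred from the hunk above:

    import hashlib
    import pathlib

    source = "# generated Triton source ..."
    digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
    # The file name is stable across runs for identical sources.
    cache_file = pathlib.Path.home() / ".ninetoothed" / f"{digest}.py"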
ninetoothed-0.15.1/src/ninetoothed/jit.py
@@ -0,0 +1,77 @@
+import importlib
+import sys
+
+from ninetoothed.generation import CodeGenerator
+
+
+def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
+    """A decorator for generating compute kernels.
+
+    :param func: The function to be compiled.
+    :param caller: Who will call the compute kernel.
+    :param kernel_name: The name for the generated kernel.
+    :param _prettify: Whether to prettify the generated code.
+    :return: A handle to the compute kernel.
+
+    .. note::
+
+        The ``_prettify`` parameter is experimental, which might break
+        the generated code.
+    """
+
+    def wrapper(func):
+        return JIT(func, caller=caller, kernel_name=kernel_name, _prettify=_prettify)()
+
+    if func is None:
+        return wrapper
+
+    return wrapper(func)
+
+
+class JIT:
+    def __init__(self, func, caller, kernel_name, _prettify=False):
+        self.func = func
+
+        self._caller = caller
+
+        if kernel_name is not None:
+            self._kernel_name = kernel_name
+        else:
+            self._kernel_name = func.__name__
+
+        self._prettify = _prettify
+
+    def __call__(self):
+        code_generator = CodeGenerator()
+        source_file = code_generator(
+            self.func, self._caller, self._kernel_name, self._prettify
+        )
+        module = type(self)._import_from_path(source_file, source_file)
+        module_vars = vars(module)
+
+        handle = _Handle(
+            module_vars[self._kernel_name],
+            module_vars[code_generator.launch_func_name],
+            source_file,
+        )
+
+        return handle
+
+    @staticmethod
+    def _import_from_path(module_name, file_path):
+        spec = importlib.util.spec_from_file_location(module_name, file_path)
+        module = importlib.util.module_from_spec(spec)
+        sys.modules[module_name] = module
+        spec.loader.exec_module(module)
+
+        return module
+
+
+class _Handle:
+    def __init__(self, kernel, launch, source):
+        self._kernel = kernel
+        self._launch = launch
+        self._source = source
+
+    def __call__(self, *args, **kwargs):
+        return self._launch(*args, **kwargs)
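For the default PyTorch caller, usage is unchanged apart from the new keywords. A hedged sketch in the package's annotation style (the kernel body is a placeholder):

    import ninetoothed
    from ninetoothed import Symbol, Tensor

    BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)

    @ninetoothed.jit(kernel_name="add")  # caller defaults to "torch"
    def add_kernel(
        x: Tensor(1).tile((BLOCK_SIZE,)),
        y: Tensor(1).tile((BLOCK_SIZE,)),
        z: Tensor(1).tile((BLOCK_SIZE,)),
    ):
        z = x + y  # noqa: F841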
ninetoothed-0.15.1/src/ninetoothed/make.py
@@ -0,0 +1,45 @@
+import inspect
+
+from ninetoothed.aot import aot
+from ninetoothed.jit import jit
+
+
+def make(
+    arrangement,
+    application,
+    tensors,
+    caller="torch",
+    kernel_name=None,
+    output_dir=None,
+    num_warps=4,
+    num_stages=3,
+):
+    """Integrate the arrangement and the application of the tensors.
+
+    :param arrangement: The arrangement of the tensors.
+    :param application: The application of the tensors.
+    :param tensors: The tensors.
+    :param caller: Who will call the compute kernel.
+    :param kernel_name: The name for the generated kernel.
+    :param output_dir: The directory to store the generated files.
+    :param num_warps: The number of warps to use.
+    :param num_stages: The number of pipeline stages.
+    :return: A handle to the compute kernel.
+    """
+
+    params = inspect.signature(application).parameters
+    types = arrangement(*tensors)
+    annotations = {param: type for param, type in zip(params, types)}
+    application.__annotations__ = annotations
+
+    if caller == "torch":
+        return jit(application, caller=caller, kernel_name=kernel_name)
+
+    return aot(
+        application,
+        caller=caller,
+        kernel_name=kernel_name,
+        output_dir=output_dir,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
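`make` now dispatches on `caller`: the `"torch"` path returns a JIT handle as before, while any other caller goes through `aot`. A sketch mirroring tests/test_aot.py below (`arrangement` and `application` are the user's functions, placeholders here):

    import ninetoothed
    from ninetoothed import Tensor

    tensors = tuple(Tensor(2, dtype=ninetoothed.float16) for _ in range(3))

    # JIT path: returns a callable handle.
    kernel = ninetoothed.make(arrangement, application, tensors)

    # AOT path: writes matmul.c and matmul.h into output_dir.
    ninetoothed.make(
        arrangement,
        application,
        tensors,
        caller="cuda",
        kernel_name="matmul",
        output_dir="build",
    )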
src/ninetoothed/tensor.py
@@ -146,7 +146,7 @@ class Tensor:
             )
             outer_shape.append(new_size)
 
-            new_stride = self_stride * stride // spacing
+            new_stride = self_stride * stride
             outer_strides.append(new_stride)
 
             inner_shape.append(tile_size)
src/ninetoothed/visualization.py
@@ -118,10 +118,16 @@ def _visualize_unit_square(ax, x, y, color):
 
 
 def _visualize_rect(ax, width, height, x, y, color):
-    pos_x, pos_y = zip(*_verts_of_rect(width, height, x, y))
-
-    ax.fill(pos_x, pos_y, color)
-    ax.plot(pos_x + (pos_x[0],), pos_y + (pos_y[0],), "k")
+    ax.add_patch(
+        plt.Rectangle(
+            (x, y),
+            width,
+            height,
+            edgecolor="k",
+            facecolor=color,
+            linewidth=plt.rcParams["lines.linewidth"],
+        )
+    )
 
 
 def _verts_of_rect(width, height, x, y):
tests/test_addmm.py
@@ -3,6 +3,7 @@ import random
 import torch
 
 import ninetoothed
+import ninetoothed.language as ntl
 import tests.test_matmul as matmul
 from ninetoothed import Tensor
 from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported
@@ -19,8 +20,9 @@ def arrangement(input, mat1, mat2, beta, alpha, output):
 
 
 def application(input, mat1, mat2, beta, alpha, output):
-    matmul.application(mat1, mat2, output)
-    output = beta * input + alpha * output
+    matmul_output = ntl.zeros(output.shape, dtype=ntl.float32)
+    matmul.application(mat1, mat2, matmul_output)
+    output = beta * input + alpha * matmul_output
 
 
 def addmm(input, mat1, mat2, beta=1, alpha=1):
@@ -43,6 +45,7 @@ def addmm(input, mat1, mat2, beta=1, alpha=1):
 class TestCUDA:
     @classmethod
     def setup_class(cls):
+        random.seed(0)
         torch.manual_seed(0)
 
         shape = (512, 512)
@@ -74,9 +77,6 @@ class TestCUDA:
         beta = type(self).beta
         alpha = type(self).alpha
 
-        # TODO: The current application function inlining feature
-        # causes some precision issues. Consider reducing `atol` and
-        # `rtol` of this test in the future.
         assert torch.allclose(
             addmm(input, mat1, mat2, beta=beta, alpha=alpha),
             torch.addmm(
@@ -86,6 +86,5 @@ class TestCUDA:
                 beta=beta,
                 alpha=alpha,
             ),
-            atol=0.5,
-            rtol=0.5,
+            atol=0.125,
         )
ninetoothed-0.15.1/tests/test_aot.py
@@ -0,0 +1,153 @@
+import ctypes
+import functools
+import subprocess
+
+import torch
+import torch.nn.functional as F
+
+import ninetoothed
+import ninetoothed.generation
+import tests.test_conv2d as conv2d
+import tests.test_matmul as matmul
+from ninetoothed import Tensor
+from tests.skippers import skip_if_cuda_not_available
+
+
+@skip_if_cuda_not_available
+class TestCUDA:
+    @classmethod
+    def setup_class(cls):
+        torch.manual_seed(0)
+
+    def test_matmul(self):
+        arrangement = functools.partial(
+            matmul.arrangement, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64
+        )
+        application = matmul.application
+        tensors = tuple(Tensor(2, dtype=ninetoothed.float16) for _ in range(3))
+        caller = "cuda"
+        kernel_name = "matmul"
+        output_dir = ninetoothed.generation.CACHE_DIR
+
+        launch_func = _generate_launch_func(
+            arrangement,
+            application,
+            tensors,
+            caller=caller,
+            kernel_name=kernel_name,
+            output_dir=output_dir,
+        )
+
+        shape = (512, 512)
+        dtype = torch.float16
+        device = caller
+
+        lhs = torch.randn(shape, dtype=dtype, device=device)
+        rhs = torch.randn(shape, dtype=dtype, device=device)
+        output = torch.empty((lhs.shape[0], rhs.shape[1]), dtype=dtype, device=device)
+
+        _run_launch_func(launch_func, lhs, rhs, output)
+
+        assert torch.allclose(output, torch.matmul(lhs, rhs))
+
+    def test_conv2d(self):
+        arrangement = functools.partial(
+            conv2d.arrangement, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64
+        )
+        application = matmul.application
+        tensors = tuple(Tensor(4, dtype=ninetoothed.float16) for _ in range(3))
+        caller = "cuda"
+        kernel_name = "conv2d"
+        output_dir = ninetoothed.generation.CACHE_DIR
+
+        launch_func = _generate_launch_func(
+            arrangement,
+            application,
+            tensors,
+            caller=caller,
+            kernel_name=kernel_name,
+            output_dir=output_dir,
+        )
+
+        n, c, h, w = 4, 64, 16, 16
+        k, _, r, s = 512, c, 3, 3
+        p = h - r + 1
+        q = w - s + 1
+        dtype = torch.float16
+        device = caller
+
+        input = torch.randn(n, c, h, w, dtype=dtype, device=device)
+        filter = torch.randn(k, c, r, s, dtype=dtype, device=device)
+        output = torch.empty(n, k, p, q, dtype=dtype, device=device)
+
+        _run_launch_func(launch_func, input, filter, output)
+
+        assert torch.allclose(output, F.conv2d(input, filter), atol=0.001, rtol=0.001)
+
+
+class _ArgumentTensor(ctypes.Structure):
+    _fields_ = [
+        ("data", ctypes.c_void_p),
+        ("shape", ctypes.POINTER(ctypes.c_uint64)),
+        ("strides", ctypes.POINTER(ctypes.c_int64)),
+    ]
+
+    @staticmethod
+    def from_torch_tensor(tensor):
+        data = ctypes.c_void_p(tensor.data_ptr())
+        shape = (ctypes.c_uint64 * len(tensor.shape))(*tensor.shape)
+        strides = (ctypes.c_int64 * len(tensor.stride()))(*tensor.stride())
+
+        return _ArgumentTensor(data, shape, strides)
+
+
+def _run_launch_func(launch_func, *tensors):
+    stream = torch.cuda.Stream()
+
+    arg_tensors = tuple(_ArgumentTensor.from_torch_tensor(tensor) for tensor in tensors)
+
+    with torch.cuda.stream(stream):
+        launch_func(ctypes.c_void_p(stream.cuda_stream), *arg_tensors)
+
+    stream.synchronize()
+
+
+def _generate_launch_func(
+    arrangement, application, tensors, caller, kernel_name, output_dir
+):
+    ninetoothed.make(
+        arrangement,
+        application,
+        tensors,
+        caller=caller,
+        kernel_name=kernel_name,
+        output_dir=output_dir,
+    )
+
+    _compile_library(kernel_name, output_dir)
+    library = _load_library(kernel_name, output_dir)
+    launch_func_name = f"launch_{kernel_name}"
+    launch_func = getattr(library, launch_func_name)
+    launch_func.argtypes = (ctypes.c_void_p,) + tuple(_ArgumentTensor for _ in tensors)
+    launch_func.restype = ctypes.c_int
+
+    return launch_func
+
+
+def _compile_library(kernel_name, output_dir):
+    command = [
+        "nvcc",
+        "-shared",
+        "-Xcompiler",
+        "-fPIC",
+        "-lcuda",
+        "-o",
+        output_dir / f"{kernel_name}.so",
+        output_dir / f"{kernel_name}.c",
+    ]
+
+    subprocess.run(command, check=True)
+
+
+def _load_library(kernel_name, kernel_dir):
+    return ctypes.CDLL(kernel_dir / f"{kernel_name}.so")
tests/test_conv2d.py
@@ -1,3 +1,5 @@
+import functools
+
 import torch
 import torch.nn.functional as F
 
@@ -7,7 +9,14 @@ from ninetoothed import Tensor
 from tests.skippers import skip_if_cuda_not_available
 
 
-def arrangement(input, filter, output):
+def arrangement(
+    input,
+    filter,
+    output,
+    BLOCK_SIZE_M=matmul.BLOCK_SIZE_M,
+    BLOCK_SIZE_N=matmul.BLOCK_SIZE_N,
+    BLOCK_SIZE_K=matmul.BLOCK_SIZE_K,
+):
     input_tiled = input.tile((1, *filter.shape[1:]), strides=(-1, -1, 1, 1))
     input_squeezed = input_tiled.squeeze(1)
     input_squeezed.dtype = input_squeezed.dtype.squeeze(0)
@@ -19,7 +28,12 @@ def arrangement(input, filter, output):
 
     output_flattened = output.permute((0, 2, 3, 1)).flatten(end_dim=3)
 
-    return matmul.arrangement(input_flattened, filter_permuted, output_flattened)
+    return functools.partial(
+        matmul.arrangement,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )(input_flattened, filter_permuted, output_flattened)
 
 
 def conv2d(input, filter):
tests/test_matmul.py
@@ -5,12 +5,19 @@ import ninetoothed.language as ntl
 from ninetoothed import Symbol, Tensor
 from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported
 
-
-def arrangement(lhs, rhs, output):
-    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
-    BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
-    BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
-
+BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
+BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
+BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
+
+
+def arrangement(
+    lhs,
+    rhs,
+    output,
+    BLOCK_SIZE_M=BLOCK_SIZE_M,
+    BLOCK_SIZE_N=BLOCK_SIZE_N,
+    BLOCK_SIZE_K=BLOCK_SIZE_K,
+):
     output_tiled = output.tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
 
     lhs_tiled = (
ninetoothed-0.14.0/src/ninetoothed/__init__.py
@@ -1,5 +0,0 @@
-from ninetoothed.jit import jit, make
-from ninetoothed.symbol import Symbol
-from ninetoothed.tensor import Tensor
-
-__all__ = ["Symbol", "Tensor", "jit", "make"]