ninetoothed 0.14.0.tar.gz → 0.15.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. ninetoothed-0.15.0/.github/ISSUE_TEMPLATE/bug-report.yml +55 -0
  2. ninetoothed-0.15.0/.github/ISSUE_TEMPLATE/feature-request.yml +13 -0
  3. ninetoothed-0.15.0/.github/pull_request_template.md +5 -0
  4. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/PKG-INFO +1 -1
  5. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/pyproject.toml +1 -1
  6. ninetoothed-0.15.0/src/ninetoothed/__init__.py +35 -0
  7. ninetoothed-0.15.0/src/ninetoothed/aot.py +217 -0
  8. ninetoothed-0.15.0/src/ninetoothed/cudaifier.py +36 -0
  9. ninetoothed-0.15.0/src/ninetoothed/dtype.py +13 -0
  10. ninetoothed-0.14.0/src/ninetoothed/jit.py → ninetoothed-0.15.0/src/ninetoothed/generation.py +82 -116
  11. ninetoothed-0.15.0/src/ninetoothed/jit.py +77 -0
  12. ninetoothed-0.15.0/src/ninetoothed/make.py +45 -0
  13. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_addmm.py +6 -7
  14. ninetoothed-0.14.0/src/ninetoothed/__init__.py +0 -5
  15. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/.gitattributes +0 -0
  16. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/.github/workflows/publish-to-pypi.yml +0 -0
  17. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/.github/workflows/pytest.yml +0 -0
  18. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/.github/workflows/ruff.yml +0 -0
  19. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/.github/workflows/sphinx.yml +0 -0
  20. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/.gitignore +0 -0
  21. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/LICENSE +0 -0
  22. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/README.md +0 -0
  23. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/Makefile +0 -0
  24. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/README.zh.md +0 -0
  25. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/make.bat +0 -0
  26. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/requirements.txt +0 -0
  27. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/_static/matmul-tiling.png +0 -0
  28. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/_static/ninetoothed-logo.png +0 -0
  29. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/_static/vecadd-tiling.png +0 -0
  30. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/code_generation.rst +0 -0
  31. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/conf.py +0 -0
  32. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/index.rst +0 -0
  33. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/installation.rst +0 -0
  34. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/python_api.rst +0 -0
  35. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/symbol.rst +0 -0
  36. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/tensor.rst +0 -0
  37. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/docs/source/visualization.rst +0 -0
  38. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/requirements.txt +0 -0
  39. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/src/ninetoothed/language.py +0 -0
  40. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/src/ninetoothed/naming.py +0 -0
  41. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/src/ninetoothed/symbol.py +0 -0
  42. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/src/ninetoothed/tensor.py +0 -0
  43. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/src/ninetoothed/torchifier.py +0 -0
  44. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/src/ninetoothed/visualization.py +0 -0
  45. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/__init__.py +0 -0
  46. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/skippers.py +0 -0
  47. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_add.py +0 -0
  48. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_attention.py +0 -0
  49. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_conv2d.py +0 -0
  50. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_matmul.py +0 -0
  51. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_max_pool2d.py +0 -0
  52. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_naming.py +0 -0
  53. {ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_softmax.py +0 -0
ninetoothed-0.15.0/.github/ISSUE_TEMPLATE/bug-report.yml

@@ -0,0 +1,55 @@
+name: 🐛 Bug report
+description: Something isn't working as expected 🤔.
+labels: ["bug"]
+
+body:
+  - type: markdown
+    attributes:
+      value: Thanks for taking the time to fill out this bug report!
+
+  - type: checkboxes
+    attributes:
+      label: Is there an existing issue for this?
+      description: >
+        Please search to see if an issue already exists
+        for the bug you encountered.
+      options:
+        - label: I have searched the existing issues.
+          required: true
+
+  - type: textarea
+    attributes:
+      label: "Describe the bug:"
+      description: A clear and concise description of what the bug is.
+    validations:
+      required: false
+
+  - type: textarea
+    attributes:
+      label: "To reproduce:"
+      description: >
+        Steps to reproduce the behavior.
+        If applicable, provide a small, self-contained piece of code
+        that can be run directly to reproduce the issue.
+    validations:
+      required: false
+
+  - type: textarea
+    attributes:
+      label: "Expected behavior:"
+      description: >
+        A clear and concise description of what you expected to happen.
+    validations:
+      required: false
+
+  - type: textarea
+    attributes:
+      label: "Environment details:"
+      description: >
+        Please include your NineToothed version, operating system,
+        hardware platform, and any relevant information.
+        If you are using PyTorch, please run
+        `python -m torch.utils.collect_env` to gather
+        environment information.
+    validations:
+      required: false

ninetoothed-0.15.0/.github/ISSUE_TEMPLATE/feature-request.yml

@@ -0,0 +1,13 @@
+name: 🚀 Feature request
+description: I have a suggestion 🙂!
+
+body:
+  - type: textarea
+    attributes:
+      label: "Description & motivation:"
+      description: >
+        Please describe the feature that you would like to see and
+        explain the problem it would solve or
+        the benefit it would provide.
+    validations:
+      required: true

ninetoothed-0.15.0/.github/pull_request_template.md

@@ -0,0 +1,5 @@
+`pytest` output:
+
+```
+
+```

{ninetoothed-0.14.0 → ninetoothed-0.15.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ninetoothed
-Version: 0.14.0
+Version: 0.15.0
 Summary: A domain-specific language based on Triton but providing higher-level abstraction.
 Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
 Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues

{ninetoothed-0.14.0 → ninetoothed-0.15.0}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "ninetoothed"
-version = "0.14.0"
+version = "0.15.0"
 authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
 description = "A domain-specific language based on Triton but providing higher-level abstraction."
 readme = "README.md"

ninetoothed-0.15.0/src/ninetoothed/__init__.py

@@ -0,0 +1,35 @@
+from ninetoothed.dtype import (
+    float16,
+    float32,
+    float64,
+    int8,
+    int16,
+    int32,
+    int64,
+    uint8,
+    uint16,
+    uint32,
+    uint64,
+)
+from ninetoothed.jit import jit
+from ninetoothed.make import make
+from ninetoothed.symbol import Symbol
+from ninetoothed.tensor import Tensor
+
+__all__ = [
+    "Symbol",
+    "Tensor",
+    "float16",
+    "float32",
+    "float64",
+    "int8",
+    "int16",
+    "int32",
+    "int64",
+    "jit",
+    "make",
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+]
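
The widened package root is the visible API change here: the dtype strings, `jit`, and the relocated `make` are all importable from `ninetoothed` directly. A quick sketch of the new surface (assuming the package is installed):

```python
import ninetoothed

# The dtype names are plain strings (see dtype.py below), matching
# Triton's AOT signature abbreviations.
print(ninetoothed.float32)  # -> "fp32"

# `jit` and `make` now live in separate modules but share one namespace.
print(callable(ninetoothed.jit), callable(ninetoothed.make))  # -> True True
```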

ninetoothed-0.15.0/src/ninetoothed/aot.py

@@ -0,0 +1,217 @@
+import ast
+import pathlib
+import subprocess
+import tempfile
+import uuid
+
+from ninetoothed.dtype import int64, uint64
+from ninetoothed.generation import CACHE_DIR, CodeGenerator
+from ninetoothed.tensor import Tensor
+
+
+def aot(
+    func, caller="cuda", kernel_name=None, output_dir=None, num_warps=4, num_stages=3
+):
+    output_dir = pathlib.Path(output_dir)
+
+    output_contents = _aot(func, caller, kernel_name, num_warps, num_stages)
+
+    for output_name, output_content in output_contents.items():
+        output_path = output_dir / f"{kernel_name}{output_name[-2:]}"
+
+        with open(output_path, "w") as f:
+            f.write(output_content)
+
+
+def _aot(func, caller, kernel_name, num_warps, num_stages):
+    def _find_tensor_by_source_name(tensors, name):
+        for tensor in tensors:
+            if tensor.source.name == name:
+                return tensor
+
+    _HEADER_PATH.parent.mkdir(exist_ok=True)
+
+    if not _HEADER_PATH.exists():
+        _HEADER_PATH.write_text(_HEADER_CONTENT)
+
+    code_generator = CodeGenerator()
+    source_file = code_generator(
+        func, caller=caller, kernel_name=kernel_name, prettify=False
+    )
+
+    tensors = code_generator.tensors
+    kernel_func = code_generator.kernel_func
+    launch_func = code_generator.launch_func
+
+    param_types = []
+
+    for arg in kernel_func.args.args:
+        param = arg.arg
+
+        if match := Tensor.pointer_pattern().fullmatch(param):
+            source_name = match.group(0).removesuffix("_pointer")
+            tensor = _find_tensor_by_source_name(tensors, source_name)
+            dtype = tensor.source.dtype
+
+            param_types.append(f"*{dtype}")
+        elif Tensor.size_pattern().fullmatch(param):
+            param_types.append(uint64)
+        elif Tensor.stride_pattern().fullmatch(param):
+            param_types.append(int64)
+
+    signature = ", ".join(param_types)
+
+    grid_extractor = _GridExtractor()
+    launch_func = grid_extractor.visit(launch_func)
+    grid_extractor.visit(code_generator.raw_grid)
+    grid = f"{ast.unparse(grid_extractor.grid[0])}, 1, 1"
+
+    signature_hash, output_contents = _compile(
+        source_file, kernel_name, signature, grid, num_warps, num_stages
+    )
+
+    unparser = _Unparser()
+
+    launch_func_unparsed = unparser.unparse(launch_func)
+    launch_func_unparsed = launch_func_unparsed.replace(
+        func.__name__, f"{kernel_name}_{signature_hash}"
+    )
+
+    c_source_file_name = f"{kernel_name}.{signature_hash}.c"
+    c_source_file = output_contents[c_source_file_name]
+    c_source_file = f"{c_source_file}\n{launch_func_unparsed}\n"
+    c_source_file = c_source_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
+    output_contents[c_source_file_name] = c_source_file
+
+    c_header_file_name = f"{kernel_name}.{signature_hash}.h"
+    c_header_file = output_contents[c_header_file_name]
+    c_header_file = f"{c_header_file}\n{unparser.header};\n"
+    c_header_file = c_header_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
+    output_contents[c_header_file_name] = c_header_file
+
+    return output_contents
+
+
+_HEADER_CONTENT = """#include <stdint.h>
+
+typedef struct {
+    uintptr_t data;
+    uint64_t *shape;
+    int64_t *strides;
+} NineToothedTensor;
+"""
+
+_HEADER_PATH = CACHE_DIR / "ninetoothed.h"
+
+
+class _Unparser:
+    def unparse(self, node):
+        method_name = "_unparse_" + node.__class__.__name__
+
+        if hasattr(self, method_name):
+            return getattr(self, method_name)(node)
+
+        return self._generic_unparse(node)
+
+    def _generic_unparse(self, node):
+        return ast.unparse(node)
+
+    def _unparse_Expr(self, node):
+        return self.unparse(node.value)
+
+    def _unparse_Call(self, node):
+        call = ast.Call(
+            func=node.func,
+            args=[ast.Name(id="stream", ctx=ast.Load())] + node.args,
+            keywords=[],
+        )
+
+        return f"return {self._generic_unparse(call)};"
+
+    def _unparse_FunctionDef(self, node):
+        params = ["CUstream stream"]
+        params += [f"NineToothedTensor {arg.arg}" for arg in node.args.args]
+        header = f"CUresult {node.name}({', '.join(params)})"
+
+        self.header = header
+
+        body_lines = []
+
+        for stmt in node.body:
+            stmt_unparsed = self.unparse(stmt)
+
+            if isinstance(stmt, ast.Expr):
+                stmt_unparsed = stmt_unparsed.strip()
+
+                if not stmt_unparsed.endswith(";"):
+                    stmt_unparsed += ";"
+
+            body_lines.append("    " + stmt_unparsed)
+
+        body = "\n".join(body_lines)
+
+        return f"{header} {{\n{body}\n}}"
+
+
+class _GridExtractor(ast.NodeTransformer):
+    def visit_BinOp(self, node):
+        self.generic_visit(node)
+
+        if isinstance(node.op, ast.FloorDiv):
+            node.op = ast.Div()
+
+        return node
+
+    def visit_Call(self, node):
+        self.generic_visit(node)
+
+        node.func = node.func.value
+
+        return node
+
+    def visit_Lambda(self, node):
+        self.generic_visit(node)
+
+        self.grid = node.body.elts
+
+        return node
+
+
+def _compile(path, name, signature, grid, num_warps, num_stages):
+    with tempfile.TemporaryDirectory() as temp_dir:
+        output_dir = pathlib.Path(temp_dir)
+        output_name = uuid.uuid4().hex
+        output_path = output_dir / output_name
+
+        command = [
+            "python",
+            "-m",
+            "triton.tools.compile",
+            str(path),
+            "--kernel-name",
+            str(name),
+            "--signature",
+            str(signature),
+            "--grid",
+            str(grid),
+            "--num-warps",
+            str(num_warps),
+            "--num-stages",
+            str(num_stages),
+            "--out-path",
+            str(output_path),
+        ]
+
+        subprocess.run(command, check=True)
+
+        matching_files = list(output_dir.glob(f"{output_name}.*"))
+
+        signature_hash = matching_files[0].name.split(".")[1]
+
+        output_contents = {}
+
+        for file in matching_files:
+            with file.open() as f:
+                output_contents[file.name.replace(output_name, name)] = f.read()
+
+    return signature_hash, output_contents
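
The new AOT pipeline generates the Triton source, shells out to `python -m triton.tools.compile`, and rewrites the resulting C source/header pair so callers include the cached `ninetoothed.h` and get a `CUresult launch_<name>(CUstream, NineToothedTensor, ...)` entry point. A hedged sketch of driving it through `make` (the arrangement, application, `Tensor(..., dtype=...)` call, and paths are illustrative assumptions, not taken from this diff):

```python
import ninetoothed
from ninetoothed import Tensor

def arrangement(x, y, z):
    # Fixed tile size so the AOT signature contains no meta-parameters.
    return x.tile((1024,)), y.tile((1024,)), z.tile((1024,))

def application(x, y, z):
    z = x + y

# caller="cuda" routes `make` through `aot`, which writes `add.c` and
# `add.h` (named via `output_name[-2:]`) into the output directory.
ninetoothed.make(
    arrangement,
    application,
    tuple(Tensor(1, dtype=ninetoothed.float32) for _ in range(3)),
    caller="cuda",
    kernel_name="add",
    output_dir="build",
)
```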

ninetoothed-0.15.0/src/ninetoothed/cudaifier.py

@@ -0,0 +1,36 @@
+import ast
+
+import ninetoothed.naming as naming
+from ninetoothed.tensor import Tensor
+
+
+class Cudaifier(ast.NodeTransformer):
+    def visit_Name(self, node):
+        self.generic_visit(node)
+
+        source = node.id
+
+        if naming.is_constexpr(source):
+            return node
+
+        def repl(match):
+            return f"{match.group(1)}.data"
+
+        source = Tensor.pointer_pattern().sub(repl, source)
+
+        def repl(match):
+            return f"{match.group(1)}.shape[{match.group(3)}]"
+
+        source = Tensor.size_pattern().sub(repl, source)
+
+        def repl(match):
+            return f"{match.group(1)}.strides[{match.group(3)}]"
+
+        source = Tensor.stride_pattern().sub(repl, source)
+
+        source = source.removesuffix("_with_auto_tuning")
+
+        if source != node.id:
+            return ast.parse(source, mode="eval").body
+
+        return node
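
`Cudaifier` rewrites the generated launch function so Triton's flattened parameter names resolve against the `NineToothedTensor` struct above: `x_pointer` becomes `x.data`, sizes become `x.shape[i]`, and strides become `x.strides[i]`. The regexes below are illustrative stand-ins for `Tensor.pointer_pattern()` and friends, whose definitions are not part of this diff:

```python
import re

pointer = re.compile(r"(\w+)_pointer")    # stand-in pattern
size = re.compile(r"((\w+))_size_(\d+)")  # group(3) is the axis index

print(pointer.sub(lambda m: f"{m.group(1)}.data", "input_pointer"))
# -> input.data
print(size.sub(lambda m: f"{m.group(1)}.shape[{m.group(3)}]", "input_size_0"))
# -> input.shape[0]
```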

ninetoothed-0.15.0/src/ninetoothed/dtype.py

@@ -0,0 +1,13 @@
+int8 = "i8"
+int16 = "i16"
+int32 = "i32"
+int64 = "i64"
+
+uint8 = "u8"
+uint16 = "u16"
+uint32 = "u32"
+uint64 = "u64"
+
+float16 = "fp16"
+float32 = "fp32"
+float64 = "fp64"
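
These constants are the abbreviated type strings that `triton.tools.compile` expects in its `--signature` argument, which is why `_aot` can splice them in directly, prefixing pointers with `*`:

```python
from ninetoothed.dtype import float32, int64, uint64

# A pointer parameter, a size, and a stride, as assembled in aot.py.
print(", ".join([f"*{float32}", uint64, int64]))  # -> *fp32, u64, i64
```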

ninetoothed-0.14.0/src/ninetoothed/jit.py → ninetoothed-0.15.0/src/ninetoothed/generation.py

@@ -2,84 +2,88 @@ import ast
 import collections
 import copy
 import functools
-import importlib.util
+import hashlib
 import inspect
 import itertools
 import math
+import pathlib
 import subprocess
-import sys
-import tempfile
 
 import triton
 
 import ninetoothed.naming as naming
+from ninetoothed.cudaifier import Cudaifier
 from ninetoothed.language import attribute, call
 from ninetoothed.symbol import Symbol
 from ninetoothed.tensor import Tensor
 from ninetoothed.torchifier import Torchifier
 
+CACHE_DIR = pathlib.Path.home() / ".ninetoothed"
 
-def make(arrangement, application, tensors):
-    """Integrate the arrangement and the application of the tensors.
 
-    :param arrangement: The arrangement of the tensors.
-    :param application: The application of the tensors.
-    :param tensors: The tensors.
-    :return: A handle to the compute kernel.
-    """
-    params = inspect.signature(application).parameters
-    types = arrangement(*tensors)
-    annotations = {param: type for param, type in zip(params, types)}
-    application.__annotations__ = annotations
+class CodeGenerator(ast.NodeTransformer):
+    def __init__(self):
+        super().__init__()
+
+        self._POWER_OF_TWOS = tuple(2**n for n in range(5, 11))
+
+        self._MIN_PRODUCT = 2**10
+
+        self._MAX_PRODUCT = 2**20
 
-    return jit(application)
+    def __call__(self, func, caller, kernel_name, prettify):
+        def _get_tree(func):
+            module = ast.parse(inspect.getsource(inspect.getmodule(func)))
 
+            collector = _ImportCollector()
+            collector.visit(module)
 
-def jit(func=None, *, _prettify=False):
-    """A decorator for generating compute kernels.
+            finder = _FunctionDefFinder(func.__name__)
+            finder.visit(module)
+            func_def = finder.result
 
-    :param func: The function to be compiled.
-    :param _prettify: Whether to prettify the generated code.
-    :return: A handle to the compute kernel.
+            inliner = _Inliner(func.__globals__)
+            inliner.visit(func_def)
+            module.body = collector.imports + inliner.imports + [finder.result]
 
-    .. note::
+            return _AliasRestorer().visit(module)
 
-        The ``_prettify`` parameter is experimental, which might break
-        the generated code.
-    """
+        def _find_dependencies(func):
+            dependencies = set()
 
-    def wrapper(func):
-        return JIT(func, _prettify=_prettify)()
+            for obj in func.__globals__.values():
+                if isinstance(obj, triton.runtime.JITFunction):
+                    dependencies.add(obj.src)
 
-    if func is None:
-        return wrapper
+            return "\n".join(
+                f"@triton.jit\n{dependency}" for dependency in dependencies
+            )
 
-    return wrapper(func)
+        self.launch_func_name = f"launch_{kernel_name}"
 
+        self._caller = caller
 
-class JIT:
-    def __init__(self, func, _prettify=False):
-        self.func = func
+        self._context = inspect.get_annotations(func)
 
-        self._prettify = _prettify
+        self._args = list(self._context.values())
 
-    def __call__(self):
-        tree = self._get_tree()
+        tree = _get_tree(func)
 
-        CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
+        self.visit(tree)
         Tritonizer().visit(tree)
         _BinOpSimplifier().visit(tree)
        ast.fix_missing_locations(tree)
 
-        if self._prettify:
+        if prettify:
             name_collector = _SimplifiedNameCollector()
             name_collector.visit(tree)
 
         unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
-        dependencies = self._find_dependencies()
+        dependencies = _find_dependencies(func)
         source = "\n\n".join((unparsed, dependencies)).strip()
+        source = source.replace(func.__name__, kernel_name)
 
-        if self._prettify:
+        if prettify:
             for original, simplified in name_collector.simplified_names.items():
                 if simplified not in name_collector.simplified_names:
                     source = source.replace(original, simplified)
@@ -88,73 +92,29 @@ class JIT:
                 ["ruff", "format", "-"], input=source, encoding="utf-8"
             )
 
-        with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
-            temp_file.write(source.encode("utf-8"))
-            temp_file_name = temp_file.name
+        digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
+        cache_dir = CACHE_DIR
+        cache_dir.mkdir(exist_ok=True)
+        cache_file = cache_dir / f"{digest}.py"
 
-        module = type(self)._import_from_path(temp_file_name, temp_file_name)
-        module_vars = vars(module)
+        if not cache_file.exists():
+            with open(cache_file, "w", encoding="utf-8") as f:
+                f.write(source)
 
-        handle = _Handle(
-            module_vars[self.func.__name__],
-            module_vars[f"launch_{self.func.__name__}"],
-            source,
-        )
-
-        return handle
-
-    def _get_tree(self):
-        module = ast.parse(inspect.getsource(inspect.getmodule(self.func)))
-
-        collector = _ImportCollector()
-        collector.visit(module)
-
-        finder = _FunctionDefFinder(self.func.__name__)
-        finder.visit(module)
-        func_def = finder.result
+        self.tensors = self._args
+        self.kernel_func = self._func_def
+        self.launch_func = self._launch
 
-        inliner = _Inliner(self.func.__globals__)
-        inliner.visit(func_def)
-        module.body = collector.imports + inliner.imports + [finder.result]
-
-        return _AliasRestorer().visit(module)
-
-    def _find_dependencies(self):
-        dependencies = set()
-
-        for obj in self.func.__globals__.values():
-            if isinstance(obj, triton.runtime.JITFunction):
-                dependencies.add(obj.src)
-
-        return "\n".join(f"@triton.jit\n{dependency}" for dependency in dependencies)
-
-    @staticmethod
-    def _import_from_path(module_name, file_path):
-        spec = importlib.util.spec_from_file_location(module_name, file_path)
-        module = importlib.util.module_from_spec(spec)
-        sys.modules[module_name] = module
-        spec.loader.exec_module(module)
-
-        return module
-
-
-class CodeGenerator(ast.NodeTransformer):
-    def __init__(self, context):
-        super().__init__()
-
-        self._context = context
-
-        self._args = list(self._context.values())
-
-        self._POWER_OF_TWOS = tuple(2**n for n in range(5, 11))
-
-        self._MIN_PRODUCT = 2**10
-
-        self._MAX_PRODUCT = 2**20
+        return str(cache_file)
 
     def visit_Module(self, node):
         self.generic_visit(node)
 
+        func_with_auto_tuning = f"{Symbol(self._autotune)}({self._func_def.name})"
+
+        node.body.append(
+            ast.parse(f"{self._func_name_with_auto_tuning} = {func_with_auto_tuning}")
+        )
         node.body.append(self._launch)
 
         return node
@@ -162,6 +122,8 @@ class CodeGenerator(ast.NodeTransformer):
     def visit_FunctionDef(self, node):
         self._func_def = node
 
+        self._func_name_with_auto_tuning = f"{self._func_def.name}_with_auto_tuning"
+
         self._invariants = {}
 
         self.generic_visit(node)
@@ -184,6 +146,9 @@
             if naming.is_constexpr(name)
         }
 
+        non_meta_names = sorted(non_meta_names)
+        meta_names = sorted(meta_names)
+
         node.args = [
             ast.arg(arg=name)
             if not naming.is_constexpr(name)
@@ -194,8 +159,8 @@
             for name in meta_names
         ]
 
-        autotune = self._generate_autotune(non_meta_names, meta_names)
-        self._func_def.decorator_list = [autotune, Symbol("triton.jit").node]
+        self._autotune = self._generate_autotune(non_meta_names, meta_names)
+        self._func_def.decorator_list = [Symbol("triton.jit").node]
 
         self._launch = self._generate_launch(non_meta_names, meta_names)
@@ -354,7 +319,7 @@
         ]
 
         launch = ast.FunctionDef(
-            name=f"launch_{self._func_def.name}",
+            name=self.launch_func_name,
             args=ast.arguments(
                 posonlyargs=[],
                 args=[ast.arg(arg=arg.source.name) for arg in self._args]
@@ -392,7 +357,9 @@
             ast.Expr(
                 ast.Call(
                     func=ast.Subscript(
-                        value=ast.Name(id=self._func_def.name, ctx=ast.Load()),
+                        value=ast.Name(
+                            id=self._func_name_with_auto_tuning, ctx=ast.Load()
+                        ),
                         slice=self._generate_grid(),
                         ctx=ast.Load(),
                     ),
@@ -422,14 +389,23 @@
 
         MetaEncloser(meta).visit(launch)
 
-        Torchifier().visit(launch)
+        if self._caller == "torch":
+            Torchifier().visit(launch)
+        elif self._caller == "cuda":
+            Cudaifier().visit(launch)
+        else:
+            raise ValueError(f"Unsupported caller: `{self._caller}`.")
 
         return launch
 
     def _generate_grid(self):
         num_elements = functools.reduce(lambda x, y: x * y, self._args[0].shape)
 
-        return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
+        grid = ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
+
+        self.raw_grid = copy.deepcopy(grid)
+
+        return grid
 
     def _generate_load(self, tensor, indices=()):
         if tensor.ndim == 0:
@@ -851,7 +827,7 @@ class _Inliner(ast.NodeTransformer):
             return names
 
        def _make_temporary():
-            prefix = naming.auto_generate(f"temporary_{self._count}")
+            prefix = f"{naming.auto_generate(f'temporary_{self._count}')}_"
             self._count += 1
 
             return prefix
@@ -941,16 +917,6 @@ class _SimplifiedNameCollector(ast.NodeVisitor):
             self.simplified_names[node.id] = naming.remove_prefixes(node.id)
 
 
-class _Handle:
-    def __init__(self, kernel, launch, source):
-        self._kernel = kernel
-        self._launch = launch
-        self._source = source
-
-    def __call__(self, *args, **kwargs):
-        return self._launch(*args, **kwargs)
-
-
 class _AliasRestorer(ast.NodeTransformer):
     def __init__(self):
         super().__init__()
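
Beyond the rename, the key behavioral change in this file is that generated sources are now content-addressed: instead of a throwaway `NamedTemporaryFile`, the source is hashed with SHA-256 and cached under `~/.ninetoothed`, so recompiling an identical kernel reuses the file. A sketch of the naming scheme (the source string is a placeholder):

```python
import hashlib
import pathlib

CACHE_DIR = pathlib.Path.home() / ".ninetoothed"

source = "def kernel(): ..."  # placeholder for the generated Triton source
digest = hashlib.sha256(source.encode("utf-8")).hexdigest()
cache_file = CACHE_DIR / f"{digest}.py"

# generation.py only writes the file when it does not already exist.
print(cache_file)  # e.g. ~/.ninetoothed/<64-hex-digit digest>.py
```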

ninetoothed-0.15.0/src/ninetoothed/jit.py

@@ -0,0 +1,77 @@
+import importlib
+import sys
+
+from ninetoothed.generation import CodeGenerator
+
+
+def jit(func=None, *, caller="torch", kernel_name=None, _prettify=False):
+    """A decorator for generating compute kernels.
+
+    :param func: The function to be compiled.
+    :param caller: Who will call the compute kernel.
+    :param kernel_name: The name for the generated kernel.
+    :param _prettify: Whether to prettify the generated code.
+    :return: A handle to the compute kernel.
+
+    .. note::
+
+        The ``_prettify`` parameter is experimental, which might break
+        the generated code.
+    """
+
+    def wrapper(func):
+        return JIT(func, caller=caller, kernel_name=kernel_name, _prettify=_prettify)()
+
+    if func is None:
+        return wrapper
+
+    return wrapper(func)
+
+
+class JIT:
+    def __init__(self, func, caller, kernel_name, _prettify=False):
+        self.func = func
+
+        self._caller = caller
+
+        if kernel_name is not None:
+            self._kernel_name = kernel_name
+        else:
+            self._kernel_name = func.__name__
+
+        self._prettify = _prettify
+
+    def __call__(self):
+        code_generator = CodeGenerator()
+        source_file = code_generator(
+            self.func, self._caller, self._kernel_name, self._prettify
+        )
+        module = type(self)._import_from_path(source_file, source_file)
+        module_vars = vars(module)
+
+        handle = _Handle(
+            module_vars[self._kernel_name],
+            module_vars[code_generator.launch_func_name],
+            source_file,
+        )
+
+        return handle
+
+    @staticmethod
+    def _import_from_path(module_name, file_path):
+        spec = importlib.util.spec_from_file_location(module_name, file_path)
+        module = importlib.util.module_from_spec(spec)
+        sys.modules[module_name] = module
+        spec.loader.exec_module(module)
+
+        return module
+
+
+class _Handle:
+    def __init__(self, kernel, launch, source):
+        self._kernel = kernel
+        self._launch = launch
+        self._source = source
+
+    def __call__(self, *args, **kwargs):
+        return self._launch(*args, **kwargs)
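
For the default `caller="torch"` path, usage is unchanged apart from the new optional keywords. A sketch following the project's README-style vector-add example (the `meta=True` Symbol and tile shapes are carried over from the existing API, not shown in this diff):

```python
import torch

import ninetoothed
from ninetoothed import Symbol, Tensor

BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)

@ninetoothed.jit(kernel_name="add_kernel")
def add_kernel(
    x: Tensor(1).tile((BLOCK_SIZE,)),
    y: Tensor(1).tile((BLOCK_SIZE,)),
    z: Tensor(1).tile((BLOCK_SIZE,)),
):
    z = x + y

x = torch.rand(1024, device="cuda")
y = torch.rand(1024, device="cuda")
z = torch.empty_like(x)
add_kernel(x, y, z)  # calls the generated launch function via _Handle
```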

ninetoothed-0.15.0/src/ninetoothed/make.py

@@ -0,0 +1,45 @@
+import inspect
+
+from ninetoothed.aot import aot
+from ninetoothed.jit import jit
+
+
+def make(
+    arrangement,
+    application,
+    tensors,
+    caller="torch",
+    kernel_name=None,
+    output_dir=None,
+    num_warps=4,
+    num_stages=3,
+):
+    """Integrate the arrangement and the application of the tensors.
+
+    :param arrangement: The arrangement of the tensors.
+    :param application: The application of the tensors.
+    :param tensors: The tensors.
+    :param caller: Who will call the compute kernel.
+    :param kernel_name: The name for the generated kernel.
+    :param output_dir: The directory to store the generated files.
+    :param num_warps: The number of warps to use.
+    :param num_stages: The number of pipeline stages.
+    :return: A handle to the compute kernel.
+    """
+
+    params = inspect.signature(application).parameters
+    types = arrangement(*tensors)
+    annotations = {param: type for param, type in zip(params, types)}
+    application.__annotations__ = annotations
+
+    if caller == "torch":
+        return jit(application, caller=caller, kernel_name=kernel_name)
+
+    return aot(
+        application,
+        caller=caller,
+        kernel_name=kernel_name,
+        output_dir=output_dir,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
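
With the default `caller="torch"`, `make` still returns a callable handle via `jit`; only other callers take the new `aot` branch. A self-contained sketch of the torch path (the arrangement/application pair is hypothetical, as in the AOT sketch above):

```python
import torch

import ninetoothed
from ninetoothed import Tensor

def arrangement(x, y, z):  # hypothetical, same as the AOT sketch above
    return x.tile((1024,)), y.tile((1024,)), z.tile((1024,))

def application(x, y, z):
    z = x + y

kernel = ninetoothed.make(arrangement, application, tuple(Tensor(1) for _ in range(3)))

x = torch.rand(1024, device="cuda")
y = torch.rand(1024, device="cuda")
z = torch.empty_like(x)
kernel(x, y, z)  # the handle dispatches to the generated launch function
```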

{ninetoothed-0.14.0 → ninetoothed-0.15.0}/tests/test_addmm.py

@@ -3,6 +3,7 @@ import random
 import torch
 
 import ninetoothed
+import ninetoothed.language as ntl
 import tests.test_matmul as matmul
 from ninetoothed import Tensor
 from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported
@@ -19,8 +20,9 @@ def arrangement(input, mat1, mat2, beta, alpha, output):
 
 
 def application(input, mat1, mat2, beta, alpha, output):
-    matmul.application(mat1, mat2, output)
-    output = beta * input + alpha * output
+    matmul_output = ntl.zeros(output.shape, dtype=ntl.float32)
+    matmul.application(mat1, mat2, matmul_output)
+    output = beta * input + alpha * matmul_output
 
 
 def addmm(input, mat1, mat2, beta=1, alpha=1):
@@ -43,6 +45,7 @@ def addmm(input, mat1, mat2, beta=1, alpha=1):
 class TestCUDA:
     @classmethod
     def setup_class(cls):
+        random.seed(0)
         torch.manual_seed(0)
 
         shape = (512, 512)
@@ -74,9 +77,6 @@ class TestCUDA:
         beta = type(self).beta
         alpha = type(self).alpha
 
-        # TODO: The current application function inlining feature
-        # causes some precision issues. Consider reducing `atol` and
-        # `rtol` of this test in the future.
         assert torch.allclose(
             addmm(input, mat1, mat2, beta=beta, alpha=alpha),
             torch.addmm(
@@ -86,6 +86,5 @@ class TestCUDA:
                 beta=beta,
                 alpha=alpha,
             ),
-            atol=0.5,
-            rtol=0.5,
+            atol=0.125,
         )
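
The test change is a precision fix: the matmul result is now accumulated into a dedicated `float32` buffer rather than written through the lower-precision output, which is what lets the tolerance tighten from `atol=0.5, rtol=0.5` to `atol=0.125`. The same idea in plain PyTorch terms (illustrative only):

```python
import torch

a = torch.rand(512, 512, dtype=torch.float16)
b = torch.rand(512, 512, dtype=torch.float16)

# Accumulate the product in float32 before scaling, mirroring the fixed
# `application` above; this keeps rounding error out of the fp16 output.
acc = a.float() @ b.float()
out = 1.0 * torch.rand(512, 512) + 1.0 * acc
```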

ninetoothed-0.14.0/src/ninetoothed/__init__.py

@@ -1,5 +0,0 @@
-from ninetoothed.jit import jit, make
-from ninetoothed.symbol import Symbol
-from ninetoothed.tensor import Tensor
-
-__all__ = ["Symbol", "Tensor", "jit", "make"]