ninetoothed 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/aot.py +61 -10
- ninetoothed/generation.py +12 -9
- ninetoothed/tensor.py +17 -0
- {ninetoothed-0.17.0.dist-info → ninetoothed-0.18.0.dist-info}/METADATA +1 -1
- {ninetoothed-0.17.0.dist-info → ninetoothed-0.18.0.dist-info}/RECORD +7 -7
- {ninetoothed-0.17.0.dist-info → ninetoothed-0.18.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.17.0.dist-info → ninetoothed-0.18.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/aot.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1
1
|
import ast
|
2
2
|
import pathlib
|
3
|
+
import re
|
3
4
|
import subprocess
|
4
5
|
import tempfile
|
5
6
|
import uuid
|
6
7
|
|
8
|
+
import ninetoothed.naming as naming
|
7
9
|
from ninetoothed.dtype import int64
|
8
10
|
from ninetoothed.generation import CACHE_DIR, CodeGenerator
|
9
11
|
from ninetoothed.tensor import Tensor
|
@@ -25,8 +27,10 @@ def aot(
|
|
25
27
|
|
26
28
|
def _aot(func, caller, kernel_name, num_warps, num_stages):
|
27
29
|
def _find_tensor_by_source_name(tensors, name):
|
30
|
+
name = naming.remove_prefixes(name)
|
31
|
+
|
28
32
|
for tensor in tensors:
|
29
|
-
if tensor.source.name == name:
|
33
|
+
if naming.remove_prefixes(tensor.source.name) == name:
|
30
34
|
return tensor
|
31
35
|
|
32
36
|
_HEADER_PATH.parent.mkdir(exist_ok=True)
|
@@ -49,11 +53,15 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
|
|
49
53
|
kernel_func = code_generator.kernel_func
|
50
54
|
launch_func = code_generator.launch_func
|
51
55
|
|
56
|
+
param_strings = ["stream"]
|
52
57
|
param_types = []
|
58
|
+
constexpr_param_indices = []
|
53
59
|
|
54
60
|
for arg in kernel_func.args.args:
|
55
61
|
param = arg.arg
|
56
62
|
|
63
|
+
param_strings.append(param)
|
64
|
+
|
57
65
|
if match := Tensor.pointer_pattern().fullmatch(param):
|
58
66
|
source_name = match.group(0).removesuffix("_pointer")
|
59
67
|
tensor = _find_tensor_by_source_name(tensors, source_name)
|
@@ -64,9 +72,23 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
|
|
64
72
|
param_types.append(int64)
|
65
73
|
elif Tensor.stride_pattern().fullmatch(param):
|
66
74
|
param_types.append(int64)
|
75
|
+
else:
|
76
|
+
source_name = param
|
77
|
+
tensor = _find_tensor_by_source_name(tensors, source_name)
|
78
|
+
dtype = tensor.source.dtype
|
79
|
+
|
80
|
+
if tensor.constexpr:
|
81
|
+
param_types.append(f"{tensor.value}")
|
82
|
+
constexpr_param_indices.append(len(param_types) - 1)
|
83
|
+
else:
|
84
|
+
param_types.append(dtype)
|
67
85
|
|
68
86
|
signature = ", ".join(param_types)
|
69
87
|
|
88
|
+
for index in sorted(set(constexpr_param_indices), reverse=True):
|
89
|
+
param_strings.pop(index + 1)
|
90
|
+
param_types.pop(index)
|
91
|
+
|
70
92
|
grid_extractor = _GridExtractor()
|
71
93
|
launch_func = grid_extractor.visit(launch_func)
|
72
94
|
grid_extractor.visit(code_generator.raw_grid)
|
@@ -76,21 +98,26 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
|
|
76
98
|
source_file, kernel_name, signature, grid, num_warps, num_stages
|
77
99
|
)
|
78
100
|
|
79
|
-
|
101
|
+
c_source_file_name = f"{kernel_name}.{signature_hash}.c"
|
102
|
+
c_source_file = output_contents[c_source_file_name]
|
103
|
+
|
104
|
+
c_header_file_name = f"{kernel_name}.{signature_hash}.h"
|
105
|
+
c_header_file = output_contents[c_header_file_name]
|
106
|
+
|
107
|
+
pattern = rf"\({', '.join(rf'(.*) {param}' for param in param_strings)}\)"
|
108
|
+
c_param_type_strings = re.search(pattern, c_header_file).groups()
|
109
|
+
|
110
|
+
unparser = _Unparser(c_param_type_strings)
|
80
111
|
|
81
112
|
launch_func_unparsed = unparser.unparse(launch_func)
|
82
113
|
launch_func_unparsed = launch_func_unparsed.replace(
|
83
114
|
func.__name__, f"{kernel_name}_{signature_hash}"
|
84
115
|
)
|
85
116
|
|
86
|
-
c_source_file_name = f"{kernel_name}.{signature_hash}.c"
|
87
|
-
c_source_file = output_contents[c_source_file_name]
|
88
117
|
c_source_file = f"{c_source_file}\n{launch_func_unparsed}\n"
|
89
118
|
c_source_file = c_source_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
|
90
119
|
output_contents[c_source_file_name] = c_source_file
|
91
120
|
|
92
|
-
c_header_file_name = f"{kernel_name}.{signature_hash}.h"
|
93
|
-
c_header_file = output_contents[c_header_file_name]
|
94
121
|
c_header_file = f'{c_header_file}\n#ifdef __cplusplus\nextern "C" {unparser.header};\n#else\n{unparser.header};\n#endif\n'
|
95
122
|
c_header_file = c_header_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
|
96
123
|
output_contents[c_header_file_name] = c_header_file
|
@@ -120,6 +147,9 @@ _HEADER_PATH = CACHE_DIR / "ninetoothed.h"
|
|
120
147
|
|
121
148
|
|
122
149
|
class _Unparser:
|
150
|
+
def __init__(self, param_types):
|
151
|
+
self._param_types = param_types
|
152
|
+
|
123
153
|
def unparse(self, node):
|
124
154
|
method_name = "_unparse_" + node.__class__.__name__
|
125
155
|
|
@@ -137,11 +167,29 @@ class _Unparser:
|
|
137
167
|
def _unparse_Call(self, node):
|
138
168
|
call = ast.Call(
|
139
169
|
func=node.func,
|
140
|
-
args=[ast.Name(id="stream", ctx=ast.Load())]
|
170
|
+
args=[ast.Name(id="stream", ctx=ast.Load())]
|
171
|
+
+ [
|
172
|
+
arg
|
173
|
+
for arg in node.args
|
174
|
+
if not isinstance(arg, ast.Name) or not naming.is_constexpr(arg.id)
|
175
|
+
],
|
141
176
|
keywords=[],
|
142
177
|
)
|
143
178
|
|
144
|
-
|
179
|
+
unparsed = f"return {self._generic_unparse(call)};"
|
180
|
+
|
181
|
+
pattern = rf"\((stream), {', '.join(r'([^,]*)' for _ in range(len(self._param_types) - 1))}\)"
|
182
|
+
args = re.search(pattern, unparsed).groups()
|
183
|
+
|
184
|
+
for i, (arg, type) in enumerate(zip(args, self._param_types)):
|
185
|
+
if i != 0 and "." not in arg:
|
186
|
+
new_arg = f"*({type} *){arg}.data"
|
187
|
+
else:
|
188
|
+
new_arg = f"({type}){arg}"
|
189
|
+
|
190
|
+
unparsed = unparsed.replace(arg, new_arg)
|
191
|
+
|
192
|
+
return unparsed
|
145
193
|
|
146
194
|
def _unparse_FunctionDef(self, node):
|
147
195
|
params = ["NineToothedStream stream"]
|
@@ -153,13 +201,16 @@ class _Unparser:
|
|
153
201
|
body_lines = []
|
154
202
|
|
155
203
|
for stmt in node.body:
|
204
|
+
if isinstance(stmt, ast.Assign):
|
205
|
+
continue
|
206
|
+
|
156
207
|
stmt_unparsed = self.unparse(stmt)
|
157
208
|
|
158
209
|
if isinstance(stmt, ast.Expr):
|
159
210
|
stmt_unparsed = stmt_unparsed.strip()
|
160
211
|
|
161
|
-
|
162
|
-
|
212
|
+
if not stmt_unparsed.endswith(";"):
|
213
|
+
stmt_unparsed += ";"
|
163
214
|
|
164
215
|
body_lines.append(" " + stmt_unparsed)
|
165
216
|
|
ninetoothed/generation.py
CHANGED
@@ -289,12 +289,12 @@ class CodeGenerator(ast.NodeTransformer):
|
|
289
289
|
if isinstance(value, Tensor):
|
290
290
|
attr = getattr(value, node.attr)
|
291
291
|
|
292
|
-
if node.attr == "dtype" and attr is None:
|
293
|
-
return Symbol(f"{value.source.pointer_string()}.type.element_ty").node
|
294
|
-
|
295
292
|
if isinstance(attr, Tensor):
|
296
293
|
return attr
|
297
294
|
|
295
|
+
if node.attr == "dtype":
|
296
|
+
return Symbol(f"{value.source.pointer_string()}.type.element_ty").node
|
297
|
+
|
298
298
|
return Symbol(attr).node
|
299
299
|
|
300
300
|
self.generic_visit(node)
|
@@ -500,16 +500,19 @@ class CodeGenerator(ast.NodeTransformer):
|
|
500
500
|
naming.remove_prefixes(param) for param in next_power_of_2_params
|
501
501
|
]
|
502
502
|
|
503
|
+
arg_names = [naming.remove_prefixes(arg.source.name) for arg in self._args]
|
504
|
+
|
505
|
+
arg_names += [
|
506
|
+
param
|
507
|
+
for param in non_next_power_of_2_constexpr_params_without_prefixes
|
508
|
+
if not Tensor.size_pattern().fullmatch(param) and param not in arg_names
|
509
|
+
]
|
510
|
+
|
503
511
|
launch = ast.FunctionDef(
|
504
512
|
name=self.launch_func_name,
|
505
513
|
args=ast.arguments(
|
506
514
|
posonlyargs=[],
|
507
|
-
args=[ast.arg(arg=
|
508
|
-
+ [
|
509
|
-
ast.arg(arg=param)
|
510
|
-
for param in non_next_power_of_2_constexpr_params_without_prefixes
|
511
|
-
if not Tensor.size_pattern().fullmatch(param)
|
512
|
-
],
|
515
|
+
args=[ast.arg(arg=name) for name in arg_names],
|
513
516
|
kwonlyargs=[],
|
514
517
|
defaults=[],
|
515
518
|
),
|
ninetoothed/tensor.py
CHANGED
@@ -32,6 +32,8 @@ class Tensor:
|
|
32
32
|
strides=None,
|
33
33
|
other=None,
|
34
34
|
shape_options=None,
|
35
|
+
constexpr=None,
|
36
|
+
value=None,
|
35
37
|
name=None,
|
36
38
|
source=None,
|
37
39
|
source_dims=None,
|
@@ -74,6 +76,21 @@ class Tensor:
|
|
74
76
|
|
75
77
|
self.other = other
|
76
78
|
|
79
|
+
if constexpr and self.ndim != 0:
|
80
|
+
raise ValueError(
|
81
|
+
"`constexpr` can only be set for zero-dimensional tensors."
|
82
|
+
)
|
83
|
+
|
84
|
+
self.constexpr = constexpr
|
85
|
+
|
86
|
+
if self.constexpr:
|
87
|
+
self.name = naming.make_constexpr(self.name)
|
88
|
+
|
89
|
+
if not constexpr and value is not None:
|
90
|
+
raise ValueError("`value` can only be set for constexpr tensors.")
|
91
|
+
|
92
|
+
self.value = value
|
93
|
+
|
77
94
|
if source is not None:
|
78
95
|
self.source = source
|
79
96
|
else:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.17.0
|
3
|
+
Version: 0.18.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -1,18 +1,18 @@
|
|
1
1
|
ninetoothed/__init__.py,sha256=F2bxRNhzcGdtADA8RehTuf-QK0xnxno8kxvr6H2L5Tg,552
|
2
|
-
ninetoothed/aot.py,sha256=
|
2
|
+
ninetoothed/aot.py,sha256=VLPFRNZgq82DumuVMi36_qptM5nkORzmhbP4uPa559Q,8173
|
3
3
|
ninetoothed/cudaifier.py,sha256=5ylMr1q0B9NwbeXkpCu3o2nMGpDfh65nAQ0Az_qMQuI,877
|
4
4
|
ninetoothed/dtype.py,sha256=-0iBleay5gYA4wtT3l17QjCesr7g26M6CSfhNJdI3k4,165
|
5
|
-
ninetoothed/generation.py,sha256=
|
5
|
+
ninetoothed/generation.py,sha256=zbqRWvpa-1q44WuZV9S13DDAxvi4dai2AJ47ihjODsM,38150
|
6
6
|
ninetoothed/jit.py,sha256=CpeSkO_zUe9DwtTJ2K2H7Bwpx-FvIHfrgzOcEosfpek,2946
|
7
7
|
ninetoothed/language.py,sha256=ERiA4dpwiow2AT2xFeFWYg1KqlnBo6xxPGp8VZrP0Lk,574
|
8
8
|
ninetoothed/make.py,sha256=fQKuRJL7HC2iGTAN323mlIWXz9Z3jotIoN68ur29Qlw,1834
|
9
9
|
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
10
10
|
ninetoothed/symbol.py,sha256=lJo3NL2-T7tKbKjb6MCRLMemN94mqS3bIiG943P0Mbo,7454
|
11
|
-
ninetoothed/tensor.py,sha256=
|
11
|
+
ninetoothed/tensor.py,sha256=lK8s5-l5cqhM9FCWXMjTle9vA1Nass_92tvuHY8H3OM,15265
|
12
12
|
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
13
13
|
ninetoothed/utils.py,sha256=mtRXABBVPnlgd2n1REh9oB3s_5bUsKhd3iwu3oJ5DSQ,338
|
14
14
|
ninetoothed/visualization.py,sha256=oc3cA5qqT66_RoAs5D681SCxR5E5wgFwk95ZefdSfZU,3794
|
15
|
-
ninetoothed-0.
|
16
|
-
ninetoothed-0.17.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
17
|
-
ninetoothed-0.17.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
18
|
-
ninetoothed-0.17.0.dist-info/RECORD,,
|
15
|
+
ninetoothed-0.18.0.dist-info/METADATA,sha256=X4TvwcjVuB40X4jmCsRMee8Auj2mGyORYsbk81fd-G0,7340
|
16
|
+
ninetoothed-0.18.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
17
|
+
ninetoothed-0.18.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
18
|
+
ninetoothed-0.18.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|