ninetoothed 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/aot.py CHANGED
@@ -1,9 +1,11 @@
1
1
  import ast
2
2
  import pathlib
3
+ import re
3
4
  import subprocess
4
5
  import tempfile
5
6
  import uuid
6
7
 
8
+ import ninetoothed.naming as naming
7
9
  from ninetoothed.dtype import int64
8
10
  from ninetoothed.generation import CACHE_DIR, CodeGenerator
9
11
  from ninetoothed.tensor import Tensor
@@ -25,8 +27,10 @@ def aot(
25
27
 
26
28
  def _aot(func, caller, kernel_name, num_warps, num_stages):
27
29
  def _find_tensor_by_source_name(tensors, name):
30
+ name = naming.remove_prefixes(name)
31
+
28
32
  for tensor in tensors:
29
- if tensor.source.name == name:
33
+ if naming.remove_prefixes(tensor.source.name) == name:
30
34
  return tensor
31
35
 
32
36
  _HEADER_PATH.parent.mkdir(exist_ok=True)
@@ -49,11 +53,15 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
49
53
  kernel_func = code_generator.kernel_func
50
54
  launch_func = code_generator.launch_func
51
55
 
56
+ param_strings = ["stream"]
52
57
  param_types = []
58
+ constexpr_param_indices = []
53
59
 
54
60
  for arg in kernel_func.args.args:
55
61
  param = arg.arg
56
62
 
63
+ param_strings.append(param)
64
+
57
65
  if match := Tensor.pointer_pattern().fullmatch(param):
58
66
  source_name = match.group(0).removesuffix("_pointer")
59
67
  tensor = _find_tensor_by_source_name(tensors, source_name)
@@ -64,9 +72,23 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
64
72
  param_types.append(int64)
65
73
  elif Tensor.stride_pattern().fullmatch(param):
66
74
  param_types.append(int64)
75
+ else:
76
+ source_name = param
77
+ tensor = _find_tensor_by_source_name(tensors, source_name)
78
+ dtype = tensor.source.dtype
79
+
80
+ if tensor.constexpr:
81
+ param_types.append(f"{tensor.value}")
82
+ constexpr_param_indices.append(len(param_types) - 1)
83
+ else:
84
+ param_types.append(dtype)
67
85
 
68
86
  signature = ", ".join(param_types)
69
87
 
88
+ for index in sorted(set(constexpr_param_indices), reverse=True):
89
+ param_strings.pop(index + 1)
90
+ param_types.pop(index)
91
+
70
92
  grid_extractor = _GridExtractor()
71
93
  launch_func = grid_extractor.visit(launch_func)
72
94
  grid_extractor.visit(code_generator.raw_grid)
@@ -76,21 +98,26 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
76
98
  source_file, kernel_name, signature, grid, num_warps, num_stages
77
99
  )
78
100
 
79
- unparser = _Unparser()
101
+ c_source_file_name = f"{kernel_name}.{signature_hash}.c"
102
+ c_source_file = output_contents[c_source_file_name]
103
+
104
+ c_header_file_name = f"{kernel_name}.{signature_hash}.h"
105
+ c_header_file = output_contents[c_header_file_name]
106
+
107
+ pattern = rf"\({', '.join(rf'(.*) {param}' for param in param_strings)}\)"
108
+ c_param_type_strings = re.search(pattern, c_header_file).groups()
109
+
110
+ unparser = _Unparser(c_param_type_strings)
80
111
 
81
112
  launch_func_unparsed = unparser.unparse(launch_func)
82
113
  launch_func_unparsed = launch_func_unparsed.replace(
83
114
  func.__name__, f"{kernel_name}_{signature_hash}"
84
115
  )
85
116
 
86
- c_source_file_name = f"{kernel_name}.{signature_hash}.c"
87
- c_source_file = output_contents[c_source_file_name]
88
117
  c_source_file = f"{c_source_file}\n{launch_func_unparsed}\n"
89
118
  c_source_file = c_source_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
90
119
  output_contents[c_source_file_name] = c_source_file
91
120
 
92
- c_header_file_name = f"{kernel_name}.{signature_hash}.h"
93
- c_header_file = output_contents[c_header_file_name]
94
121
  c_header_file = f'{c_header_file}\n#ifdef __cplusplus\nextern "C" {unparser.header};\n#else\n{unparser.header};\n#endif\n'
95
122
  c_header_file = c_header_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
96
123
  output_contents[c_header_file_name] = c_header_file
@@ -120,6 +147,9 @@ _HEADER_PATH = CACHE_DIR / "ninetoothed.h"
120
147
 
121
148
 
122
149
  class _Unparser:
150
+ def __init__(self, param_types):
151
+ self._param_types = param_types
152
+
123
153
  def unparse(self, node):
124
154
  method_name = "_unparse_" + node.__class__.__name__
125
155
 
@@ -137,11 +167,29 @@ class _Unparser:
137
167
  def _unparse_Call(self, node):
138
168
  call = ast.Call(
139
169
  func=node.func,
140
- args=[ast.Name(id="stream", ctx=ast.Load())] + node.args,
170
+ args=[ast.Name(id="stream", ctx=ast.Load())]
171
+ + [
172
+ arg
173
+ for arg in node.args
174
+ if not isinstance(arg, ast.Name) or not naming.is_constexpr(arg.id)
175
+ ],
141
176
  keywords=[],
142
177
  )
143
178
 
144
- return f"return {self._generic_unparse(call)};"
179
+ unparsed = f"return {self._generic_unparse(call)};"
180
+
181
+ pattern = rf"\((stream), {', '.join(r'([^,]*)' for _ in range(len(self._param_types) - 1))}\)"
182
+ args = re.search(pattern, unparsed).groups()
183
+
184
+ for i, (arg, type) in enumerate(zip(args, self._param_types)):
185
+ if i != 0 and "." not in arg:
186
+ new_arg = f"*({type} *){arg}.data"
187
+ else:
188
+ new_arg = f"({type}){arg}"
189
+
190
+ unparsed = unparsed.replace(arg, new_arg)
191
+
192
+ return unparsed
145
193
 
146
194
  def _unparse_FunctionDef(self, node):
147
195
  params = ["NineToothedStream stream"]
@@ -153,13 +201,16 @@ class _Unparser:
153
201
  body_lines = []
154
202
 
155
203
  for stmt in node.body:
204
+ if isinstance(stmt, ast.Assign):
205
+ continue
206
+
156
207
  stmt_unparsed = self.unparse(stmt)
157
208
 
158
209
  if isinstance(stmt, ast.Expr):
159
210
  stmt_unparsed = stmt_unparsed.strip()
160
211
 
161
- if not stmt_unparsed.endswith(";"):
162
- stmt_unparsed += ";"
212
+ if not stmt_unparsed.endswith(";"):
213
+ stmt_unparsed += ";"
163
214
 
164
215
  body_lines.append(" " + stmt_unparsed)
165
216
 
ninetoothed/generation.py CHANGED
@@ -289,12 +289,12 @@ class CodeGenerator(ast.NodeTransformer):
289
289
  if isinstance(value, Tensor):
290
290
  attr = getattr(value, node.attr)
291
291
 
292
- if node.attr == "dtype" and attr is None:
293
- return Symbol(f"{value.source.pointer_string()}.type.element_ty").node
294
-
295
292
  if isinstance(attr, Tensor):
296
293
  return attr
297
294
 
295
+ if node.attr == "dtype":
296
+ return Symbol(f"{value.source.pointer_string()}.type.element_ty").node
297
+
298
298
  return Symbol(attr).node
299
299
 
300
300
  self.generic_visit(node)
@@ -500,16 +500,19 @@ class CodeGenerator(ast.NodeTransformer):
500
500
  naming.remove_prefixes(param) for param in next_power_of_2_params
501
501
  ]
502
502
 
503
+ arg_names = [naming.remove_prefixes(arg.source.name) for arg in self._args]
504
+
505
+ arg_names += [
506
+ param
507
+ for param in non_next_power_of_2_constexpr_params_without_prefixes
508
+ if not Tensor.size_pattern().fullmatch(param) and param not in arg_names
509
+ ]
510
+
503
511
  launch = ast.FunctionDef(
504
512
  name=self.launch_func_name,
505
513
  args=ast.arguments(
506
514
  posonlyargs=[],
507
- args=[ast.arg(arg=arg.source.name) for arg in self._args]
508
- + [
509
- ast.arg(arg=param)
510
- for param in non_next_power_of_2_constexpr_params_without_prefixes
511
- if not Tensor.size_pattern().fullmatch(param)
512
- ],
515
+ args=[ast.arg(arg=name) for name in arg_names],
513
516
  kwonlyargs=[],
514
517
  defaults=[],
515
518
  ),
ninetoothed/tensor.py CHANGED
@@ -32,6 +32,8 @@ class Tensor:
32
32
  strides=None,
33
33
  other=None,
34
34
  shape_options=None,
35
+ constexpr=None,
36
+ value=None,
35
37
  name=None,
36
38
  source=None,
37
39
  source_dims=None,
@@ -74,6 +76,21 @@ class Tensor:
74
76
 
75
77
  self.other = other
76
78
 
79
+ if constexpr and self.ndim != 0:
80
+ raise ValueError(
81
+ "`constexpr` can only be set for zero-dimensional tensors."
82
+ )
83
+
84
+ self.constexpr = constexpr
85
+
86
+ if self.constexpr:
87
+ self.name = naming.make_constexpr(self.name)
88
+
89
+ if not constexpr and value is not None:
90
+ raise ValueError("`value` can only be set for constexpr tensors.")
91
+
92
+ self.value = value
93
+
77
94
  if source is not None:
78
95
  self.source = source
79
96
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.17.0
3
+ Version: 0.18.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -1,18 +1,18 @@
1
1
  ninetoothed/__init__.py,sha256=F2bxRNhzcGdtADA8RehTuf-QK0xnxno8kxvr6H2L5Tg,552
2
- ninetoothed/aot.py,sha256=b7ykTC5roe_xg3NkZv6VyInBrEiNRwjpixCULUPRuEg,6506
2
+ ninetoothed/aot.py,sha256=VLPFRNZgq82DumuVMi36_qptM5nkORzmhbP4uPa559Q,8173
3
3
  ninetoothed/cudaifier.py,sha256=5ylMr1q0B9NwbeXkpCu3o2nMGpDfh65nAQ0Az_qMQuI,877
4
4
  ninetoothed/dtype.py,sha256=-0iBleay5gYA4wtT3l17QjCesr7g26M6CSfhNJdI3k4,165
5
- ninetoothed/generation.py,sha256=wf8BL-x0PR6rG-9OSpgIZi8LtsIdFbqRUFiQFE5FIno,38107
5
+ ninetoothed/generation.py,sha256=zbqRWvpa-1q44WuZV9S13DDAxvi4dai2AJ47ihjODsM,38150
6
6
  ninetoothed/jit.py,sha256=CpeSkO_zUe9DwtTJ2K2H7Bwpx-FvIHfrgzOcEosfpek,2946
7
7
  ninetoothed/language.py,sha256=ERiA4dpwiow2AT2xFeFWYg1KqlnBo6xxPGp8VZrP0Lk,574
8
8
  ninetoothed/make.py,sha256=fQKuRJL7HC2iGTAN323mlIWXz9Z3jotIoN68ur29Qlw,1834
9
9
  ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
10
10
  ninetoothed/symbol.py,sha256=lJo3NL2-T7tKbKjb6MCRLMemN94mqS3bIiG943P0Mbo,7454
11
- ninetoothed/tensor.py,sha256=gQEzHTcXqZVBFLc2YRfXTKxjxPWMxWN7fNl2BCfJwMs,14782
11
+ ninetoothed/tensor.py,sha256=lK8s5-l5cqhM9FCWXMjTle9vA1Nass_92tvuHY8H3OM,15265
12
12
  ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
13
13
  ninetoothed/utils.py,sha256=mtRXABBVPnlgd2n1REh9oB3s_5bUsKhd3iwu3oJ5DSQ,338
14
14
  ninetoothed/visualization.py,sha256=oc3cA5qqT66_RoAs5D681SCxR5E5wgFwk95ZefdSfZU,3794
15
- ninetoothed-0.17.0.dist-info/METADATA,sha256=_V2M45nT4Yin-zs7hq5-yHlN6KwV5_zcA8afwXP8S-Q,7340
16
- ninetoothed-0.17.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
- ninetoothed-0.17.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
18
- ninetoothed-0.17.0.dist-info/RECORD,,
15
+ ninetoothed-0.18.0.dist-info/METADATA,sha256=X4TvwcjVuB40X4jmCsRMee8Auj2mGyORYsbk81fd-G0,7340
16
+ ninetoothed-0.18.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
+ ninetoothed-0.18.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
18
+ ninetoothed-0.18.0.dist-info/RECORD,,