ninetoothed 0.16.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/aot.py CHANGED
@@ -1,9 +1,11 @@
1
1
  import ast
2
2
  import pathlib
3
+ import re
3
4
  import subprocess
4
5
  import tempfile
5
6
  import uuid
6
7
 
8
+ import ninetoothed.naming as naming
7
9
  from ninetoothed.dtype import int64
8
10
  from ninetoothed.generation import CACHE_DIR, CodeGenerator
9
11
  from ninetoothed.tensor import Tensor
@@ -25,13 +27,15 @@ def aot(
25
27
 
26
28
  def _aot(func, caller, kernel_name, num_warps, num_stages):
27
29
  def _find_tensor_by_source_name(tensors, name):
30
+ name = naming.remove_prefixes(name)
31
+
28
32
  for tensor in tensors:
29
- if tensor.source.name == name:
33
+ if naming.remove_prefixes(tensor.source.name) == name:
30
34
  return tensor
31
35
 
32
36
  _HEADER_PATH.parent.mkdir(exist_ok=True)
33
37
 
34
- if not _HEADER_PATH.exists():
38
+ if not _HEADER_PATH.exists() or _HEADER_PATH.read_text() != _HEADER_CONTENT:
35
39
  _HEADER_PATH.write_text(_HEADER_CONTENT)
36
40
 
37
41
  code_generator = CodeGenerator()
@@ -49,11 +53,15 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
49
53
  kernel_func = code_generator.kernel_func
50
54
  launch_func = code_generator.launch_func
51
55
 
56
+ param_strings = ["stream"]
52
57
  param_types = []
58
+ constexpr_param_indices = []
53
59
 
54
60
  for arg in kernel_func.args.args:
55
61
  param = arg.arg
56
62
 
63
+ param_strings.append(param)
64
+
57
65
  if match := Tensor.pointer_pattern().fullmatch(param):
58
66
  source_name = match.group(0).removesuffix("_pointer")
59
67
  tensor = _find_tensor_by_source_name(tensors, source_name)
@@ -64,9 +72,23 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
64
72
  param_types.append(int64)
65
73
  elif Tensor.stride_pattern().fullmatch(param):
66
74
  param_types.append(int64)
75
+ else:
76
+ source_name = param
77
+ tensor = _find_tensor_by_source_name(tensors, source_name)
78
+ dtype = tensor.source.dtype
79
+
80
+ if tensor.constexpr:
81
+ param_types.append(f"{tensor.value}")
82
+ constexpr_param_indices.append(len(param_types) - 1)
83
+ else:
84
+ param_types.append(dtype)
67
85
 
68
86
  signature = ", ".join(param_types)
69
87
 
88
+ for index in sorted(set(constexpr_param_indices), reverse=True):
89
+ param_strings.pop(index + 1)
90
+ param_types.pop(index)
91
+
70
92
  grid_extractor = _GridExtractor()
71
93
  launch_func = grid_extractor.visit(launch_func)
72
94
  grid_extractor.visit(code_generator.raw_grid)
@@ -76,41 +98,58 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
76
98
  source_file, kernel_name, signature, grid, num_warps, num_stages
77
99
  )
78
100
 
79
- unparser = _Unparser()
101
+ c_source_file_name = f"{kernel_name}.{signature_hash}.c"
102
+ c_source_file = output_contents[c_source_file_name]
103
+
104
+ c_header_file_name = f"{kernel_name}.{signature_hash}.h"
105
+ c_header_file = output_contents[c_header_file_name]
106
+
107
+ pattern = rf"\({', '.join(rf'(.*) {param}' for param in param_strings)}\)"
108
+ c_param_type_strings = re.search(pattern, c_header_file).groups()
109
+
110
+ unparser = _Unparser(c_param_type_strings)
80
111
 
81
112
  launch_func_unparsed = unparser.unparse(launch_func)
82
113
  launch_func_unparsed = launch_func_unparsed.replace(
83
114
  func.__name__, f"{kernel_name}_{signature_hash}"
84
115
  )
85
116
 
86
- c_source_file_name = f"{kernel_name}.{signature_hash}.c"
87
- c_source_file = output_contents[c_source_file_name]
88
117
  c_source_file = f"{c_source_file}\n{launch_func_unparsed}\n"
89
118
  c_source_file = c_source_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
90
119
  output_contents[c_source_file_name] = c_source_file
91
120
 
92
- c_header_file_name = f"{kernel_name}.{signature_hash}.h"
93
- c_header_file = output_contents[c_header_file_name]
94
- c_header_file = f"{c_header_file}\n{unparser.header};\n"
121
+ c_header_file = f'{c_header_file}\n#ifdef __cplusplus\nextern "C" {unparser.header};\n#else\n{unparser.header};\n#endif\n'
95
122
  c_header_file = c_header_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
96
123
  output_contents[c_header_file_name] = c_header_file
97
124
 
98
125
  return output_contents
99
126
 
100
127
 
101
- _HEADER_CONTENT = """#include <stdint.h>
128
+ _HEADER_CONTENT = """#ifndef NINETOOTHED_H
129
+ #define NINETOOTHED_H
130
+
131
+ #include <stdint.h>
102
132
 
103
133
  typedef struct {
104
- uintptr_t data;
134
+ void *data;
105
135
  uint64_t *shape;
106
136
  int64_t *strides;
107
137
  } NineToothedTensor;
138
+
139
+ typedef void *NineToothedStream;
140
+
141
+ typedef int NineToothedResult;
142
+
143
+ #endif // NINETOOTHED_H
108
144
  """
109
145
 
110
146
  _HEADER_PATH = CACHE_DIR / "ninetoothed.h"
111
147
 
112
148
 
113
149
  class _Unparser:
150
+ def __init__(self, param_types):
151
+ self._param_types = param_types
152
+
114
153
  def unparse(self, node):
115
154
  method_name = "_unparse_" + node.__class__.__name__
116
155
 
@@ -128,29 +167,50 @@ class _Unparser:
128
167
  def _unparse_Call(self, node):
129
168
  call = ast.Call(
130
169
  func=node.func,
131
- args=[ast.Name(id="stream", ctx=ast.Load())] + node.args,
170
+ args=[ast.Name(id="stream", ctx=ast.Load())]
171
+ + [
172
+ arg
173
+ for arg in node.args
174
+ if not isinstance(arg, ast.Name) or not naming.is_constexpr(arg.id)
175
+ ],
132
176
  keywords=[],
133
177
  )
134
178
 
135
- return f"return {self._generic_unparse(call)};"
179
+ unparsed = f"return {self._generic_unparse(call)};"
180
+
181
+ pattern = rf"\((stream), {', '.join(r'([^,]*)' for _ in range(len(self._param_types) - 1))}\)"
182
+ args = re.search(pattern, unparsed).groups()
183
+
184
+ for i, (arg, type) in enumerate(zip(args, self._param_types)):
185
+ if i != 0 and "." not in arg:
186
+ new_arg = f"*({type} *){arg}.data"
187
+ else:
188
+ new_arg = f"({type}){arg}"
189
+
190
+ unparsed = unparsed.replace(arg, new_arg)
191
+
192
+ return unparsed
136
193
 
137
194
  def _unparse_FunctionDef(self, node):
138
- params = ["CUstream stream"]
195
+ params = ["NineToothedStream stream"]
139
196
  params += [f"NineToothedTensor {arg.arg}" for arg in node.args.args]
140
- header = f"CUresult {node.name}({', '.join(params)})"
197
+ header = f"NineToothedResult {node.name}({', '.join(params)})"
141
198
 
142
199
  self.header = header
143
200
 
144
201
  body_lines = []
145
202
 
146
203
  for stmt in node.body:
204
+ if isinstance(stmt, ast.Assign):
205
+ continue
206
+
147
207
  stmt_unparsed = self.unparse(stmt)
148
208
 
149
209
  if isinstance(stmt, ast.Expr):
150
210
  stmt_unparsed = stmt_unparsed.strip()
151
211
 
152
- if not stmt_unparsed.endswith(";"):
153
- stmt_unparsed += ";"
212
+ if not stmt_unparsed.endswith(";"):
213
+ stmt_unparsed += ";"
154
214
 
155
215
  body_lines.append(" " + stmt_unparsed)
156
216
 
ninetoothed/generation.py CHANGED
@@ -19,6 +19,7 @@ import uuid
19
19
  import sympy
20
20
  import triton
21
21
  import triton.language as tl
22
+ from triton.language.extra import libdevice
22
23
 
23
24
  import ninetoothed.naming as naming
24
25
  from ninetoothed.cudaifier import Cudaifier
@@ -225,6 +226,41 @@ class CodeGenerator(ast.NodeTransformer):
225
226
 
226
227
  return node
227
228
 
229
+ def visit_Call(self, node):
230
+ def _offsets(tensor, dim=None):
231
+ if dim is None:
232
+ return tensor._last_generated_overall_offsets.node
233
+
234
+ offsets = tensor._last_generated_offsets
235
+
236
+ if dim < 0:
237
+ dim += tensor.source.ndim
238
+
239
+ return sum(
240
+ offsets[dim][target_dim] for target_dim in range(tensor.target.ndim)
241
+ ).node
242
+
243
+ func = node.func
244
+ args = node.args
245
+
246
+ if isinstance(func, ast.Attribute):
247
+ if func.attr == "offsets":
248
+ value = func.value
249
+
250
+ if self._in_context(value):
251
+ tensor = self._context[value.id]
252
+ elif isinstance(value, ast.Subscript) and self._in_context(value.value):
253
+ tensor = self._context[value.value.id]
254
+
255
+ self.visit(value)
256
+
257
+ # TODO: Add error handling.
258
+ return _offsets(tensor, ast.literal_eval(args[0]) if args else None)
259
+
260
+ self.generic_visit(node)
261
+
262
+ return node
263
+
228
264
  def visit_Subscript(self, node):
229
265
  if self._in_context(node.value) and isinstance(node.ctx, ast.Load):
230
266
  value = self._context[node.value.id]
@@ -242,13 +278,24 @@ class CodeGenerator(ast.NodeTransformer):
242
278
  return node
243
279
 
244
280
  def visit_Attribute(self, node):
245
- if self._in_context(node.value):
246
- value = self._context[node.value.id]
281
+ value = node.value
247
282
 
248
- if isinstance(value, Tensor):
249
- inner = value.dtype
283
+ if isinstance(value, ast.Attribute):
284
+ value = self.visit_Attribute(value)
285
+
286
+ if self._in_context(value):
287
+ value = self._context[value.id].dtype
288
+
289
+ if isinstance(value, Tensor):
290
+ attr = getattr(value, node.attr)
291
+
292
+ if isinstance(attr, Tensor):
293
+ return attr
250
294
 
251
- return Symbol(getattr(inner, node.attr)).node
295
+ if node.attr == "dtype":
296
+ return Symbol(f"{value.source.pointer_string()}.type.element_ty").node
297
+
298
+ return Symbol(attr).node
252
299
 
253
300
  self.generic_visit(node)
254
301
 
@@ -453,16 +500,19 @@ class CodeGenerator(ast.NodeTransformer):
453
500
  naming.remove_prefixes(param) for param in next_power_of_2_params
454
501
  ]
455
502
 
503
+ arg_names = [naming.remove_prefixes(arg.source.name) for arg in self._args]
504
+
505
+ arg_names += [
506
+ param
507
+ for param in non_next_power_of_2_constexpr_params_without_prefixes
508
+ if not Tensor.size_pattern().fullmatch(param) and param not in arg_names
509
+ ]
510
+
456
511
  launch = ast.FunctionDef(
457
512
  name=self.launch_func_name,
458
513
  args=ast.arguments(
459
514
  posonlyargs=[],
460
- args=[ast.arg(arg=arg.source.name) for arg in self._args]
461
- + [
462
- ast.arg(arg=param)
463
- for param in non_next_power_of_2_constexpr_params_without_prefixes
464
- if not Tensor.size_pattern().fullmatch(param)
465
- ],
515
+ args=[ast.arg(arg=name) for name in arg_names],
466
516
  kwonlyargs=[],
467
517
  defaults=[],
468
518
  ),
@@ -560,6 +610,8 @@ class CodeGenerator(ast.NodeTransformer):
560
610
  indices = self._complete_indices(tensor, indices)
561
611
  offsets = type(self)._generate_offsets(tensor, indices)
562
612
 
613
+ tensor._last_generated_offsets = offsets
614
+
563
615
  for source_dim in range(tensor.source.ndim):
564
616
  for target_dim in range(tensor.target.ndim):
565
617
  if target_dim not in invariant_target_dims:
@@ -584,7 +636,7 @@ class CodeGenerator(ast.NodeTransformer):
584
636
  * tensor.source.strides[source_dim]
585
637
  )
586
638
 
587
- pointers = name_for_pointers + sum(
639
+ overall_offsets = sum(
588
640
  offsets[source_dim][target_dim][
589
641
  type(self)._generate_slices(tensor, target_dim)
590
642
  ]
@@ -594,6 +646,10 @@ class CodeGenerator(ast.NodeTransformer):
594
646
  if target_dim not in invariant_target_dims
595
647
  and offsets[source_dim][target_dim] != 0
596
648
  )
649
+
650
+ tensor._last_generated_overall_offsets = overall_offsets
651
+
652
+ pointers = name_for_pointers + overall_offsets
597
653
  mask = functools.reduce(
598
654
  lambda x, y: x & y,
599
655
  (
@@ -980,6 +1036,9 @@ class _Inliner(ast.NodeTransformer):
980
1036
  if func_def is None:
981
1037
  return None, []
982
1038
 
1039
+ if inspect.getmodule(func) is libdevice:
1040
+ return None, []
1041
+
983
1042
  collector = _ImportCollector()
984
1043
  collector.visit(ast.parse(inspect.getsource(inspect.getmodule(func))))
985
1044
  self.imports.extend(collector.imports)
ninetoothed/language.py CHANGED
@@ -1,7 +1,11 @@
1
1
  import ast
2
2
 
3
+ from triton.language.extra import libdevice
4
+
3
5
  from ninetoothed.symbol import Symbol
4
6
 
7
+ __all__ = ["libdevice"]
8
+
5
9
  LANGUAGE = "ninetoothed.language"
6
10
 
7
11
 
ninetoothed/tensor.py CHANGED
@@ -32,6 +32,8 @@ class Tensor:
32
32
  strides=None,
33
33
  other=None,
34
34
  shape_options=None,
35
+ constexpr=None,
36
+ value=None,
35
37
  name=None,
36
38
  source=None,
37
39
  source_dims=None,
@@ -74,6 +76,21 @@ class Tensor:
74
76
 
75
77
  self.other = other
76
78
 
79
+ if constexpr and self.ndim != 0:
80
+ raise ValueError(
81
+ "`constexpr` can only be set for zero-dimensional tensors."
82
+ )
83
+
84
+ self.constexpr = constexpr
85
+
86
+ if self.constexpr:
87
+ self.name = naming.make_constexpr(self.name)
88
+
89
+ if not constexpr and value is not None:
90
+ raise ValueError("`value` can only be set for constexpr tensors.")
91
+
92
+ self.value = value
93
+
77
94
  if source is not None:
78
95
  self.source = source
79
96
  else:
@@ -10,8 +10,6 @@ def visualize(tensor, color=None, save_path=None):
10
10
  :param color: The color to be used for visualization.
11
11
  :param save_path: The path where the visualization should be saved.
12
12
  """
13
- outline_width = 0.1
14
- plt.rcParams["lines.linewidth"] = 72 * outline_width
15
13
 
16
14
  if color is None:
17
15
  color = f"C{visualize.count}"
@@ -21,6 +19,24 @@ def visualize(tensor, color=None, save_path=None):
21
19
  width = max_pos_y + 1
22
20
  height = max_pos_x + 1
23
21
 
22
+ _, ax = _prepare_figure_and_axes(width, height)
23
+
24
+ _visualize_tensor(ax, tensor, 0, 0, color)
25
+
26
+ plt.savefig(save_path, transparent=True, bbox_inches="tight", pad_inches=0)
27
+
28
+ plt.close()
29
+
30
+ visualize.count += 1
31
+
32
+
33
+ visualize.count = 0
34
+
35
+
36
+ def _prepare_figure_and_axes(width, height):
37
+ outline_width = 0.1
38
+ plt.rcParams["lines.linewidth"] = 72 * outline_width
39
+
24
40
  fig = plt.figure(figsize=(width + outline_width, height + outline_width))
25
41
 
26
42
  h = (Size.Fixed(0), Size.Fixed(width + outline_width))
@@ -41,16 +57,7 @@ def visualize(tensor, color=None, save_path=None):
41
57
  plt.xlim((-half_outline_width, width + half_outline_width))
42
58
  plt.ylim((-half_outline_width, height + half_outline_width))
43
59
 
44
- _visualize_tensor(ax, tensor, 0, 0, color)
45
-
46
- plt.savefig(save_path, transparent=True, bbox_inches="tight", pad_inches=0)
47
-
48
- plt.close()
49
-
50
- visualize.count += 1
51
-
52
-
53
- visualize.count = 0
60
+ return fig, ax
54
61
 
55
62
 
56
63
  def _visualize_tensor(ax, tensor, x, y, color, level_spacing=4):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.16.0
3
+ Version: 0.18.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -1,18 +1,18 @@
1
1
  ninetoothed/__init__.py,sha256=F2bxRNhzcGdtADA8RehTuf-QK0xnxno8kxvr6H2L5Tg,552
2
- ninetoothed/aot.py,sha256=8ZCLtnsign14YvY7SXX5ASidhuUAhPwppTXUJNkQup4,6243
2
+ ninetoothed/aot.py,sha256=VLPFRNZgq82DumuVMi36_qptM5nkORzmhbP4uPa559Q,8173
3
3
  ninetoothed/cudaifier.py,sha256=5ylMr1q0B9NwbeXkpCu3o2nMGpDfh65nAQ0Az_qMQuI,877
4
4
  ninetoothed/dtype.py,sha256=-0iBleay5gYA4wtT3l17QjCesr7g26M6CSfhNJdI3k4,165
5
- ninetoothed/generation.py,sha256=VIqSyZT4yHxY_a2QPmWW6jjALv3e1mohDqdRQBRYsAo,36462
5
+ ninetoothed/generation.py,sha256=zbqRWvpa-1q44WuZV9S13DDAxvi4dai2AJ47ihjODsM,38150
6
6
  ninetoothed/jit.py,sha256=CpeSkO_zUe9DwtTJ2K2H7Bwpx-FvIHfrgzOcEosfpek,2946
7
- ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
7
+ ninetoothed/language.py,sha256=ERiA4dpwiow2AT2xFeFWYg1KqlnBo6xxPGp8VZrP0Lk,574
8
8
  ninetoothed/make.py,sha256=fQKuRJL7HC2iGTAN323mlIWXz9Z3jotIoN68ur29Qlw,1834
9
9
  ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
10
10
  ninetoothed/symbol.py,sha256=lJo3NL2-T7tKbKjb6MCRLMemN94mqS3bIiG943P0Mbo,7454
11
- ninetoothed/tensor.py,sha256=gQEzHTcXqZVBFLc2YRfXTKxjxPWMxWN7fNl2BCfJwMs,14782
11
+ ninetoothed/tensor.py,sha256=lK8s5-l5cqhM9FCWXMjTle9vA1Nass_92tvuHY8H3OM,15265
12
12
  ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
13
13
  ninetoothed/utils.py,sha256=mtRXABBVPnlgd2n1REh9oB3s_5bUsKhd3iwu3oJ5DSQ,338
14
- ninetoothed/visualization.py,sha256=zlMH-0WplaboePGzcbpcj4UovpX0k2r4SysSPsNS4r4,3674
15
- ninetoothed-0.16.0.dist-info/METADATA,sha256=nkq3iImebtmcEs-bZq2zfF2_QxrZD9IWky1S86OnUMA,7340
16
- ninetoothed-0.16.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
- ninetoothed-0.16.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
18
- ninetoothed-0.16.0.dist-info/RECORD,,
14
+ ninetoothed/visualization.py,sha256=oc3cA5qqT66_RoAs5D681SCxR5E5wgFwk95ZefdSfZU,3794
15
+ ninetoothed-0.18.0.dist-info/METADATA,sha256=X4TvwcjVuB40X4jmCsRMee8Auj2mGyORYsbk81fd-G0,7340
16
+ ninetoothed-0.18.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
+ ninetoothed-0.18.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
18
+ ninetoothed-0.18.0.dist-info/RECORD,,