ninetoothed 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/aot.py CHANGED
@@ -31,7 +31,7 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
31
31
 
32
32
  _HEADER_PATH.parent.mkdir(exist_ok=True)
33
33
 
34
- if not _HEADER_PATH.exists():
34
+ if not _HEADER_PATH.exists() or _HEADER_PATH.read_text() != _HEADER_CONTENT:
35
35
  _HEADER_PATH.write_text(_HEADER_CONTENT)
36
36
 
37
37
  code_generator = CodeGenerator()
@@ -91,20 +91,29 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
91
91
 
92
92
  c_header_file_name = f"{kernel_name}.{signature_hash}.h"
93
93
  c_header_file = output_contents[c_header_file_name]
94
- c_header_file = f"{c_header_file}\n{unparser.header};\n"
94
+ c_header_file = f'{c_header_file}\n#ifdef __cplusplus\nextern "C" {unparser.header};\n#else\n{unparser.header};\n#endif\n'
95
95
  c_header_file = c_header_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
96
96
  output_contents[c_header_file_name] = c_header_file
97
97
 
98
98
  return output_contents
99
99
 
100
100
 
101
- _HEADER_CONTENT = """#include <stdint.h>
101
+ _HEADER_CONTENT = """#ifndef NINETOOTHED_H
102
+ #define NINETOOTHED_H
103
+
104
+ #include <stdint.h>
102
105
 
103
106
  typedef struct {
104
- uintptr_t data;
107
+ void *data;
105
108
  uint64_t *shape;
106
109
  int64_t *strides;
107
110
  } NineToothedTensor;
111
+
112
+ typedef void *NineToothedStream;
113
+
114
+ typedef int NineToothedResult;
115
+
116
+ #endif // NINETOOTHED_H
108
117
  """
109
118
 
110
119
  _HEADER_PATH = CACHE_DIR / "ninetoothed.h"
@@ -135,9 +144,9 @@ class _Unparser:
135
144
  return f"return {self._generic_unparse(call)};"
136
145
 
137
146
  def _unparse_FunctionDef(self, node):
138
- params = ["CUstream stream"]
147
+ params = ["NineToothedStream stream"]
139
148
  params += [f"NineToothedTensor {arg.arg}" for arg in node.args.args]
140
- header = f"CUresult {node.name}({', '.join(params)})"
149
+ header = f"NineToothedResult {node.name}({', '.join(params)})"
141
150
 
142
151
  self.header = header
143
152
 
ninetoothed/generation.py CHANGED
@@ -19,6 +19,7 @@ import uuid
19
19
  import sympy
20
20
  import triton
21
21
  import triton.language as tl
22
+ from triton.language.extra import libdevice
22
23
 
23
24
  import ninetoothed.naming as naming
24
25
  from ninetoothed.cudaifier import Cudaifier
@@ -225,6 +226,41 @@ class CodeGenerator(ast.NodeTransformer):
225
226
 
226
227
  return node
227
228
 
229
+ def visit_Call(self, node):
230
+ def _offsets(tensor, dim=None):
231
+ if dim is None:
232
+ return tensor._last_generated_overall_offsets.node
233
+
234
+ offsets = tensor._last_generated_offsets
235
+
236
+ if dim < 0:
237
+ dim += tensor.source.ndim
238
+
239
+ return sum(
240
+ offsets[dim][target_dim] for target_dim in range(tensor.target.ndim)
241
+ ).node
242
+
243
+ func = node.func
244
+ args = node.args
245
+
246
+ if isinstance(func, ast.Attribute):
247
+ if func.attr == "offsets":
248
+ value = func.value
249
+
250
+ if self._in_context(value):
251
+ tensor = self._context[value.id]
252
+ elif isinstance(value, ast.Subscript) and self._in_context(value.value):
253
+ tensor = self._context[value.value.id]
254
+
255
+ self.visit(value)
256
+
257
+ # TODO: Add error handling.
258
+ return _offsets(tensor, ast.literal_eval(args[0]) if args else None)
259
+
260
+ self.generic_visit(node)
261
+
262
+ return node
263
+
228
264
  def visit_Subscript(self, node):
229
265
  if self._in_context(node.value) and isinstance(node.ctx, ast.Load):
230
266
  value = self._context[node.value.id]
@@ -242,13 +278,24 @@ class CodeGenerator(ast.NodeTransformer):
242
278
  return node
243
279
 
244
280
  def visit_Attribute(self, node):
245
- if self._in_context(node.value):
246
- value = self._context[node.value.id]
281
+ value = node.value
247
282
 
248
- if isinstance(value, Tensor):
249
- inner = value.dtype
283
+ if isinstance(value, ast.Attribute):
284
+ value = self.visit_Attribute(value)
285
+
286
+ if self._in_context(value):
287
+ value = self._context[value.id].dtype
288
+
289
+ if isinstance(value, Tensor):
290
+ attr = getattr(value, node.attr)
291
+
292
+ if node.attr == "dtype" and attr is None:
293
+ return Symbol(f"{value.source.pointer_string()}.type.element_ty").node
250
294
 
251
- return Symbol(getattr(inner, node.attr)).node
295
+ if isinstance(attr, Tensor):
296
+ return attr
297
+
298
+ return Symbol(attr).node
252
299
 
253
300
  self.generic_visit(node)
254
301
 
@@ -560,6 +607,8 @@ class CodeGenerator(ast.NodeTransformer):
560
607
  indices = self._complete_indices(tensor, indices)
561
608
  offsets = type(self)._generate_offsets(tensor, indices)
562
609
 
610
+ tensor._last_generated_offsets = offsets
611
+
563
612
  for source_dim in range(tensor.source.ndim):
564
613
  for target_dim in range(tensor.target.ndim):
565
614
  if target_dim not in invariant_target_dims:
@@ -584,7 +633,7 @@ class CodeGenerator(ast.NodeTransformer):
584
633
  * tensor.source.strides[source_dim]
585
634
  )
586
635
 
587
- pointers = name_for_pointers + sum(
636
+ overall_offsets = sum(
588
637
  offsets[source_dim][target_dim][
589
638
  type(self)._generate_slices(tensor, target_dim)
590
639
  ]
@@ -594,6 +643,10 @@ class CodeGenerator(ast.NodeTransformer):
594
643
  if target_dim not in invariant_target_dims
595
644
  and offsets[source_dim][target_dim] != 0
596
645
  )
646
+
647
+ tensor._last_generated_overall_offsets = overall_offsets
648
+
649
+ pointers = name_for_pointers + overall_offsets
597
650
  mask = functools.reduce(
598
651
  lambda x, y: x & y,
599
652
  (
@@ -980,6 +1033,9 @@ class _Inliner(ast.NodeTransformer):
980
1033
  if func_def is None:
981
1034
  return None, []
982
1035
 
1036
+ if inspect.getmodule(func) is libdevice:
1037
+ return None, []
1038
+
983
1039
  collector = _ImportCollector()
984
1040
  collector.visit(ast.parse(inspect.getsource(inspect.getmodule(func))))
985
1041
  self.imports.extend(collector.imports)
ninetoothed/language.py CHANGED
@@ -1,7 +1,11 @@
1
1
  import ast
2
2
 
3
+ from triton.language.extra import libdevice
4
+
3
5
  from ninetoothed.symbol import Symbol
4
6
 
7
+ __all__ = ["libdevice"]
8
+
5
9
  LANGUAGE = "ninetoothed.language"
6
10
 
7
11
 
@@ -10,8 +10,6 @@ def visualize(tensor, color=None, save_path=None):
10
10
  :param color: The color to be used for visualization.
11
11
  :param save_path: The path where the visualization should be saved.
12
12
  """
13
- outline_width = 0.1
14
- plt.rcParams["lines.linewidth"] = 72 * outline_width
15
13
 
16
14
  if color is None:
17
15
  color = f"C{visualize.count}"
@@ -21,6 +19,24 @@ def visualize(tensor, color=None, save_path=None):
21
19
  width = max_pos_y + 1
22
20
  height = max_pos_x + 1
23
21
 
22
+ _, ax = _prepare_figure_and_axes(width, height)
23
+
24
+ _visualize_tensor(ax, tensor, 0, 0, color)
25
+
26
+ plt.savefig(save_path, transparent=True, bbox_inches="tight", pad_inches=0)
27
+
28
+ plt.close()
29
+
30
+ visualize.count += 1
31
+
32
+
33
+ visualize.count = 0
34
+
35
+
36
+ def _prepare_figure_and_axes(width, height):
37
+ outline_width = 0.1
38
+ plt.rcParams["lines.linewidth"] = 72 * outline_width
39
+
24
40
  fig = plt.figure(figsize=(width + outline_width, height + outline_width))
25
41
 
26
42
  h = (Size.Fixed(0), Size.Fixed(width + outline_width))
@@ -41,16 +57,7 @@ def visualize(tensor, color=None, save_path=None):
41
57
  plt.xlim((-half_outline_width, width + half_outline_width))
42
58
  plt.ylim((-half_outline_width, height + half_outline_width))
43
59
 
44
- _visualize_tensor(ax, tensor, 0, 0, color)
45
-
46
- plt.savefig(save_path, transparent=True, bbox_inches="tight", pad_inches=0)
47
-
48
- plt.close()
49
-
50
- visualize.count += 1
51
-
52
-
53
- visualize.count = 0
60
+ return fig, ax
54
61
 
55
62
 
56
63
  def _visualize_tensor(ax, tensor, x, y, color, level_spacing=4):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.16.0
3
+ Version: 0.17.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -1,18 +1,18 @@
1
1
  ninetoothed/__init__.py,sha256=F2bxRNhzcGdtADA8RehTuf-QK0xnxno8kxvr6H2L5Tg,552
2
- ninetoothed/aot.py,sha256=8ZCLtnsign14YvY7SXX5ASidhuUAhPwppTXUJNkQup4,6243
2
+ ninetoothed/aot.py,sha256=b7ykTC5roe_xg3NkZv6VyInBrEiNRwjpixCULUPRuEg,6506
3
3
  ninetoothed/cudaifier.py,sha256=5ylMr1q0B9NwbeXkpCu3o2nMGpDfh65nAQ0Az_qMQuI,877
4
4
  ninetoothed/dtype.py,sha256=-0iBleay5gYA4wtT3l17QjCesr7g26M6CSfhNJdI3k4,165
5
- ninetoothed/generation.py,sha256=VIqSyZT4yHxY_a2QPmWW6jjALv3e1mohDqdRQBRYsAo,36462
5
+ ninetoothed/generation.py,sha256=wf8BL-x0PR6rG-9OSpgIZi8LtsIdFbqRUFiQFE5FIno,38107
6
6
  ninetoothed/jit.py,sha256=CpeSkO_zUe9DwtTJ2K2H7Bwpx-FvIHfrgzOcEosfpek,2946
7
- ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
7
+ ninetoothed/language.py,sha256=ERiA4dpwiow2AT2xFeFWYg1KqlnBo6xxPGp8VZrP0Lk,574
8
8
  ninetoothed/make.py,sha256=fQKuRJL7HC2iGTAN323mlIWXz9Z3jotIoN68ur29Qlw,1834
9
9
  ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
10
10
  ninetoothed/symbol.py,sha256=lJo3NL2-T7tKbKjb6MCRLMemN94mqS3bIiG943P0Mbo,7454
11
11
  ninetoothed/tensor.py,sha256=gQEzHTcXqZVBFLc2YRfXTKxjxPWMxWN7fNl2BCfJwMs,14782
12
12
  ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
13
13
  ninetoothed/utils.py,sha256=mtRXABBVPnlgd2n1REh9oB3s_5bUsKhd3iwu3oJ5DSQ,338
14
- ninetoothed/visualization.py,sha256=zlMH-0WplaboePGzcbpcj4UovpX0k2r4SysSPsNS4r4,3674
15
- ninetoothed-0.16.0.dist-info/METADATA,sha256=nkq3iImebtmcEs-bZq2zfF2_QxrZD9IWky1S86OnUMA,7340
16
- ninetoothed-0.16.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
- ninetoothed-0.16.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
18
- ninetoothed-0.16.0.dist-info/RECORD,,
14
+ ninetoothed/visualization.py,sha256=oc3cA5qqT66_RoAs5D681SCxR5E5wgFwk95ZefdSfZU,3794
15
+ ninetoothed-0.17.0.dist-info/METADATA,sha256=_V2M45nT4Yin-zs7hq5-yHlN6KwV5_zcA8afwXP8S-Q,7340
16
+ ninetoothed-0.17.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
+ ninetoothed-0.17.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
18
+ ninetoothed-0.17.0.dist-info/RECORD,,