ninetoothed 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/aot.py +15 -6
- ninetoothed/generation.py +62 -6
- ninetoothed/language.py +4 -0
- ninetoothed/visualization.py +19 -12
- {ninetoothed-0.16.0.dist-info → ninetoothed-0.17.0.dist-info}/METADATA +1 -1
- {ninetoothed-0.16.0.dist-info → ninetoothed-0.17.0.dist-info}/RECORD +8 -8
- {ninetoothed-0.16.0.dist-info → ninetoothed-0.17.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.16.0.dist-info → ninetoothed-0.17.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/aot.py
CHANGED
@@ -31,7 +31,7 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
|
|
31
31
|
|
32
32
|
_HEADER_PATH.parent.mkdir(exist_ok=True)
|
33
33
|
|
34
|
-
if not _HEADER_PATH.exists():
|
34
|
+
if not _HEADER_PATH.exists() or _HEADER_PATH.read_text() != _HEADER_CONTENT:
|
35
35
|
_HEADER_PATH.write_text(_HEADER_CONTENT)
|
36
36
|
|
37
37
|
code_generator = CodeGenerator()
|
@@ -91,20 +91,29 @@ def _aot(func, caller, kernel_name, num_warps, num_stages):
|
|
91
91
|
|
92
92
|
c_header_file_name = f"{kernel_name}.{signature_hash}.h"
|
93
93
|
c_header_file = output_contents[c_header_file_name]
|
94
|
-
c_header_file = f
|
94
|
+
c_header_file = f'{c_header_file}\n#ifdef __cplusplus\nextern "C" {unparser.header};\n#else\n{unparser.header};\n#endif\n'
|
95
95
|
c_header_file = c_header_file.replace("<stdint.h>", f'"{_HEADER_PATH}"')
|
96
96
|
output_contents[c_header_file_name] = c_header_file
|
97
97
|
|
98
98
|
return output_contents
|
99
99
|
|
100
100
|
|
101
|
-
_HEADER_CONTENT = """#include <stdint.h>
|
101
|
+
_HEADER_CONTENT = """#ifndef NINETOOTHED_H
|
102
|
+
#define NINETOOTHED_H
|
103
|
+
|
104
|
+
#include <stdint.h>
|
102
105
|
|
103
106
|
typedef struct {
|
104
|
-
|
107
|
+
void *data;
|
105
108
|
uint64_t *shape;
|
106
109
|
int64_t *strides;
|
107
110
|
} NineToothedTensor;
|
111
|
+
|
112
|
+
typedef void *NineToothedStream;
|
113
|
+
|
114
|
+
typedef int NineToothedResult;
|
115
|
+
|
116
|
+
#endif // NINETOOTHED_H
|
108
117
|
"""
|
109
118
|
|
110
119
|
_HEADER_PATH = CACHE_DIR / "ninetoothed.h"
|
@@ -135,9 +144,9 @@ class _Unparser:
|
|
135
144
|
return f"return {self._generic_unparse(call)};"
|
136
145
|
|
137
146
|
def _unparse_FunctionDef(self, node):
|
138
|
-
params = ["
|
147
|
+
params = ["NineToothedStream stream"]
|
139
148
|
params += [f"NineToothedTensor {arg.arg}" for arg in node.args.args]
|
140
|
-
header = f"
|
149
|
+
header = f"NineToothedResult {node.name}({', '.join(params)})"
|
141
150
|
|
142
151
|
self.header = header
|
143
152
|
|
ninetoothed/generation.py
CHANGED
@@ -19,6 +19,7 @@ import uuid
|
|
19
19
|
import sympy
|
20
20
|
import triton
|
21
21
|
import triton.language as tl
|
22
|
+
from triton.language.extra import libdevice
|
22
23
|
|
23
24
|
import ninetoothed.naming as naming
|
24
25
|
from ninetoothed.cudaifier import Cudaifier
|
@@ -225,6 +226,41 @@ class CodeGenerator(ast.NodeTransformer):
|
|
225
226
|
|
226
227
|
return node
|
227
228
|
|
229
|
+
def visit_Call(self, node):
|
230
|
+
def _offsets(tensor, dim=None):
|
231
|
+
if dim is None:
|
232
|
+
return tensor._last_generated_overall_offsets.node
|
233
|
+
|
234
|
+
offsets = tensor._last_generated_offsets
|
235
|
+
|
236
|
+
if dim < 0:
|
237
|
+
dim += tensor.source.ndim
|
238
|
+
|
239
|
+
return sum(
|
240
|
+
offsets[dim][target_dim] for target_dim in range(tensor.target.ndim)
|
241
|
+
).node
|
242
|
+
|
243
|
+
func = node.func
|
244
|
+
args = node.args
|
245
|
+
|
246
|
+
if isinstance(func, ast.Attribute):
|
247
|
+
if func.attr == "offsets":
|
248
|
+
value = func.value
|
249
|
+
|
250
|
+
if self._in_context(value):
|
251
|
+
tensor = self._context[value.id]
|
252
|
+
elif isinstance(value, ast.Subscript) and self._in_context(value.value):
|
253
|
+
tensor = self._context[value.value.id]
|
254
|
+
|
255
|
+
self.visit(value)
|
256
|
+
|
257
|
+
# TODO: Add error handling.
|
258
|
+
return _offsets(tensor, ast.literal_eval(args[0]) if args else None)
|
259
|
+
|
260
|
+
self.generic_visit(node)
|
261
|
+
|
262
|
+
return node
|
263
|
+
|
228
264
|
def visit_Subscript(self, node):
|
229
265
|
if self._in_context(node.value) and isinstance(node.ctx, ast.Load):
|
230
266
|
value = self._context[node.value.id]
|
@@ -242,13 +278,24 @@ class CodeGenerator(ast.NodeTransformer):
|
|
242
278
|
return node
|
243
279
|
|
244
280
|
def visit_Attribute(self, node):
|
245
|
-
|
246
|
-
value = self._context[node.value.id]
|
281
|
+
value = node.value
|
247
282
|
|
248
|
-
|
249
|
-
|
283
|
+
if isinstance(value, ast.Attribute):
|
284
|
+
value = self.visit_Attribute(value)
|
285
|
+
|
286
|
+
if self._in_context(value):
|
287
|
+
value = self._context[value.id].dtype
|
288
|
+
|
289
|
+
if isinstance(value, Tensor):
|
290
|
+
attr = getattr(value, node.attr)
|
291
|
+
|
292
|
+
if node.attr == "dtype" and attr is None:
|
293
|
+
return Symbol(f"{value.source.pointer_string()}.type.element_ty").node
|
250
294
|
|
251
|
-
|
295
|
+
if isinstance(attr, Tensor):
|
296
|
+
return attr
|
297
|
+
|
298
|
+
return Symbol(attr).node
|
252
299
|
|
253
300
|
self.generic_visit(node)
|
254
301
|
|
@@ -560,6 +607,8 @@ class CodeGenerator(ast.NodeTransformer):
|
|
560
607
|
indices = self._complete_indices(tensor, indices)
|
561
608
|
offsets = type(self)._generate_offsets(tensor, indices)
|
562
609
|
|
610
|
+
tensor._last_generated_offsets = offsets
|
611
|
+
|
563
612
|
for source_dim in range(tensor.source.ndim):
|
564
613
|
for target_dim in range(tensor.target.ndim):
|
565
614
|
if target_dim not in invariant_target_dims:
|
@@ -584,7 +633,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
584
633
|
* tensor.source.strides[source_dim]
|
585
634
|
)
|
586
635
|
|
587
|
-
|
636
|
+
overall_offsets = sum(
|
588
637
|
offsets[source_dim][target_dim][
|
589
638
|
type(self)._generate_slices(tensor, target_dim)
|
590
639
|
]
|
@@ -594,6 +643,10 @@ class CodeGenerator(ast.NodeTransformer):
|
|
594
643
|
if target_dim not in invariant_target_dims
|
595
644
|
and offsets[source_dim][target_dim] != 0
|
596
645
|
)
|
646
|
+
|
647
|
+
tensor._last_generated_overall_offsets = overall_offsets
|
648
|
+
|
649
|
+
pointers = name_for_pointers + overall_offsets
|
597
650
|
mask = functools.reduce(
|
598
651
|
lambda x, y: x & y,
|
599
652
|
(
|
@@ -980,6 +1033,9 @@ class _Inliner(ast.NodeTransformer):
|
|
980
1033
|
if func_def is None:
|
981
1034
|
return None, []
|
982
1035
|
|
1036
|
+
if inspect.getmodule(func) is libdevice:
|
1037
|
+
return None, []
|
1038
|
+
|
983
1039
|
collector = _ImportCollector()
|
984
1040
|
collector.visit(ast.parse(inspect.getsource(inspect.getmodule(func))))
|
985
1041
|
self.imports.extend(collector.imports)
|
ninetoothed/language.py
CHANGED
ninetoothed/visualization.py
CHANGED
@@ -10,8 +10,6 @@ def visualize(tensor, color=None, save_path=None):
|
|
10
10
|
:param color: The color to be used for visualization.
|
11
11
|
:param save_path: The path where the visualization should be saved.
|
12
12
|
"""
|
13
|
-
outline_width = 0.1
|
14
|
-
plt.rcParams["lines.linewidth"] = 72 * outline_width
|
15
13
|
|
16
14
|
if color is None:
|
17
15
|
color = f"C{visualize.count}"
|
@@ -21,6 +19,24 @@ def visualize(tensor, color=None, save_path=None):
|
|
21
19
|
width = max_pos_y + 1
|
22
20
|
height = max_pos_x + 1
|
23
21
|
|
22
|
+
_, ax = _prepare_figure_and_axes(width, height)
|
23
|
+
|
24
|
+
_visualize_tensor(ax, tensor, 0, 0, color)
|
25
|
+
|
26
|
+
plt.savefig(save_path, transparent=True, bbox_inches="tight", pad_inches=0)
|
27
|
+
|
28
|
+
plt.close()
|
29
|
+
|
30
|
+
visualize.count += 1
|
31
|
+
|
32
|
+
|
33
|
+
visualize.count = 0
|
34
|
+
|
35
|
+
|
36
|
+
def _prepare_figure_and_axes(width, height):
|
37
|
+
outline_width = 0.1
|
38
|
+
plt.rcParams["lines.linewidth"] = 72 * outline_width
|
39
|
+
|
24
40
|
fig = plt.figure(figsize=(width + outline_width, height + outline_width))
|
25
41
|
|
26
42
|
h = (Size.Fixed(0), Size.Fixed(width + outline_width))
|
@@ -41,16 +57,7 @@ def visualize(tensor, color=None, save_path=None):
|
|
41
57
|
plt.xlim((-half_outline_width, width + half_outline_width))
|
42
58
|
plt.ylim((-half_outline_width, height + half_outline_width))
|
43
59
|
|
44
|
-
|
45
|
-
|
46
|
-
plt.savefig(save_path, transparent=True, bbox_inches="tight", pad_inches=0)
|
47
|
-
|
48
|
-
plt.close()
|
49
|
-
|
50
|
-
visualize.count += 1
|
51
|
-
|
52
|
-
|
53
|
-
visualize.count = 0
|
60
|
+
return fig, ax
|
54
61
|
|
55
62
|
|
56
63
|
def _visualize_tensor(ax, tensor, x, y, color, level_spacing=4):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.16.0
|
3
|
+
Version: 0.17.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -1,18 +1,18 @@
|
|
1
1
|
ninetoothed/__init__.py,sha256=F2bxRNhzcGdtADA8RehTuf-QK0xnxno8kxvr6H2L5Tg,552
|
2
|
-
ninetoothed/aot.py,sha256=
|
2
|
+
ninetoothed/aot.py,sha256=b7ykTC5roe_xg3NkZv6VyInBrEiNRwjpixCULUPRuEg,6506
|
3
3
|
ninetoothed/cudaifier.py,sha256=5ylMr1q0B9NwbeXkpCu3o2nMGpDfh65nAQ0Az_qMQuI,877
|
4
4
|
ninetoothed/dtype.py,sha256=-0iBleay5gYA4wtT3l17QjCesr7g26M6CSfhNJdI3k4,165
|
5
|
-
ninetoothed/generation.py,sha256=
|
5
|
+
ninetoothed/generation.py,sha256=wf8BL-x0PR6rG-9OSpgIZi8LtsIdFbqRUFiQFE5FIno,38107
|
6
6
|
ninetoothed/jit.py,sha256=CpeSkO_zUe9DwtTJ2K2H7Bwpx-FvIHfrgzOcEosfpek,2946
|
7
|
-
ninetoothed/language.py,sha256=
|
7
|
+
ninetoothed/language.py,sha256=ERiA4dpwiow2AT2xFeFWYg1KqlnBo6xxPGp8VZrP0Lk,574
|
8
8
|
ninetoothed/make.py,sha256=fQKuRJL7HC2iGTAN323mlIWXz9Z3jotIoN68ur29Qlw,1834
|
9
9
|
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
10
10
|
ninetoothed/symbol.py,sha256=lJo3NL2-T7tKbKjb6MCRLMemN94mqS3bIiG943P0Mbo,7454
|
11
11
|
ninetoothed/tensor.py,sha256=gQEzHTcXqZVBFLc2YRfXTKxjxPWMxWN7fNl2BCfJwMs,14782
|
12
12
|
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
13
13
|
ninetoothed/utils.py,sha256=mtRXABBVPnlgd2n1REh9oB3s_5bUsKhd3iwu3oJ5DSQ,338
|
14
|
-
ninetoothed/visualization.py,sha256=
|
15
|
-
ninetoothed-0.
|
16
|
-
ninetoothed-0.
|
17
|
-
ninetoothed-0.
|
18
|
-
ninetoothed-0.
|
14
|
+
ninetoothed/visualization.py,sha256=oc3cA5qqT66_RoAs5D681SCxR5E5wgFwk95ZefdSfZU,3794
|
15
|
+
ninetoothed-0.17.0.dist-info/METADATA,sha256=_V2M45nT4Yin-zs7hq5-yHlN6KwV5_zcA8afwXP8S-Q,7340
|
16
|
+
ninetoothed-0.17.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
17
|
+
ninetoothed-0.17.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
18
|
+
ninetoothed-0.17.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|