ninetoothed 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/jit.py +116 -26
- ninetoothed/symbol.py +30 -7
- ninetoothed/tensor.py +16 -2
- {ninetoothed-0.1.1.dist-info → ninetoothed-0.2.0.dist-info}/METADATA +1 -1
- ninetoothed-0.2.0.dist-info/RECORD +10 -0
- ninetoothed-0.1.1.dist-info/RECORD +0 -10
- {ninetoothed-0.1.1.dist-info → ninetoothed-0.2.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.1.1.dist-info → ninetoothed-0.2.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/jit.py
CHANGED
@@ -5,7 +5,6 @@ import inspect
|
|
5
5
|
import itertools
|
6
6
|
import math
|
7
7
|
import tempfile
|
8
|
-
import textwrap
|
9
8
|
|
10
9
|
from ninetoothed.language import attribute, call
|
11
10
|
from ninetoothed.symbol import Symbol
|
@@ -33,8 +32,7 @@ class JIT:
|
|
33
32
|
):
|
34
33
|
return type(self).handles[source_file][source_line]
|
35
34
|
|
36
|
-
|
37
|
-
tree = ast.parse(source)
|
35
|
+
tree = self._get_tree()
|
38
36
|
|
39
37
|
CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
|
40
38
|
Tritonizer().visit(tree)
|
@@ -56,15 +54,7 @@ class JIT:
|
|
56
54
|
namespace = {}
|
57
55
|
exec(code, namespace)
|
58
56
|
|
59
|
-
|
60
|
-
def __init__(self, kernel, launch):
|
61
|
-
self._kernel = kernel
|
62
|
-
self._launch = launch
|
63
|
-
|
64
|
-
def __call__(self, *args, **kwargs):
|
65
|
-
return self._launch(*args, **kwargs)
|
66
|
-
|
67
|
-
handle = Handle(
|
57
|
+
handle = _Handle(
|
68
58
|
namespace[self.func.__name__],
|
69
59
|
namespace[f"launch_{self.func.__name__}"],
|
70
60
|
)
|
@@ -73,6 +63,15 @@ class JIT:
|
|
73
63
|
|
74
64
|
return handle
|
75
65
|
|
66
|
+
def _get_tree(self):
|
67
|
+
module = ast.parse(inspect.getsource(inspect.getmodule(self.func)))
|
68
|
+
|
69
|
+
_AliasRestorer().visit(module)
|
70
|
+
finder = _FunctionDefFinder(self.func.__name__)
|
71
|
+
finder.visit(module)
|
72
|
+
|
73
|
+
return ast.Module(body=[finder.result], type_ignores=[])
|
74
|
+
|
76
75
|
|
77
76
|
class CodeGenerator(ast.NodeTransformer):
|
78
77
|
def __init__(self, context):
|
@@ -100,6 +99,18 @@ class CodeGenerator(ast.NodeTransformer):
|
|
100
99
|
|
101
100
|
self.generic_visit(node)
|
102
101
|
|
102
|
+
for arg in self._args:
|
103
|
+
if not isinstance(arg, Tensor):
|
104
|
+
continue
|
105
|
+
|
106
|
+
node.body.insert(
|
107
|
+
0,
|
108
|
+
ast.Assign(
|
109
|
+
targets=[Symbol(f"{arg.name}_ptrs").node],
|
110
|
+
value=arg.pointers().node,
|
111
|
+
),
|
112
|
+
)
|
113
|
+
|
103
114
|
return node
|
104
115
|
|
105
116
|
def visit_arguments(self, node):
|
@@ -136,12 +147,12 @@ class CodeGenerator(ast.NodeTransformer):
|
|
136
147
|
value = self._context[node.value.id]
|
137
148
|
|
138
149
|
if isinstance(value, Tensor):
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
150
|
+
pointers = type(self)._create_pointers(
|
151
|
+
value,
|
152
|
+
node.slice.elts
|
153
|
+
if isinstance(node.slice, ast.Tuple)
|
154
|
+
else (node.slice,),
|
155
|
+
)
|
145
156
|
|
146
157
|
return call("load", pointers).node
|
147
158
|
|
@@ -166,7 +177,9 @@ class CodeGenerator(ast.NodeTransformer):
|
|
166
177
|
self.generic_visit(node)
|
167
178
|
|
168
179
|
if node.id in self._context and isinstance(node.ctx, ast.Load):
|
169
|
-
return call(
|
180
|
+
return call(
|
181
|
+
"load", type(self)._create_pointers(self._context[node.id], ()).node
|
182
|
+
).node
|
170
183
|
|
171
184
|
return node
|
172
185
|
|
@@ -180,7 +193,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
180
193
|
return ast.Expr(
|
181
194
|
call(
|
182
195
|
"store",
|
183
|
-
self._context[target.id]
|
196
|
+
type(self)._create_pointers(self._context[target.id], ()).node,
|
184
197
|
node.value,
|
185
198
|
).node
|
186
199
|
)
|
@@ -195,13 +208,12 @@ class CodeGenerator(ast.NodeTransformer):
|
|
195
208
|
if isinstance(value, Tensor):
|
196
209
|
self.generic_visit(node)
|
197
210
|
|
198
|
-
|
211
|
+
pointers = type(self)._create_pointers(
|
212
|
+
value,
|
199
213
|
target.slice.elts
|
200
214
|
if isinstance(target.slice, ast.Tuple)
|
201
|
-
else target.slice
|
215
|
+
else (target.slice,),
|
202
216
|
)
|
203
|
-
offsets = value.offsets(indices)
|
204
|
-
pointers = value.pointers(offsets)
|
205
217
|
|
206
218
|
return ast.Expr(
|
207
219
|
call(
|
@@ -316,6 +328,14 @@ class CodeGenerator(ast.NodeTransformer):
|
|
316
328
|
|
317
329
|
return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
|
318
330
|
|
331
|
+
@staticmethod
|
332
|
+
def _create_pointers(tensor, indices):
|
333
|
+
return Symbol(f"{tensor.name}_ptrs") + tensor.offsets(
|
334
|
+
[0 for _ in range(tensor.ndim())]
|
335
|
+
+ list(indices)
|
336
|
+
+ [0 for _ in range(tensor.inmost().ndim())]
|
337
|
+
)
|
338
|
+
|
319
339
|
|
320
340
|
class Tritonizer(ast.NodeTransformer):
|
321
341
|
def visit_Module(self, node):
|
@@ -329,8 +349,8 @@ class Tritonizer(ast.NodeTransformer):
|
|
329
349
|
def visit_Name(self, node):
|
330
350
|
self.generic_visit(node)
|
331
351
|
|
332
|
-
if node.id == "ninetoothed":
|
333
|
-
node.id = "triton"
|
352
|
+
if node.id == "ninetoothed" or "ninetoothed." in node.id:
|
353
|
+
node.id = node.id.replace("ninetoothed", "triton")
|
334
354
|
|
335
355
|
return node
|
336
356
|
|
@@ -348,3 +368,73 @@ class Tritonizer(ast.NodeTransformer):
|
|
348
368
|
)
|
349
369
|
|
350
370
|
return node
|
371
|
+
|
372
|
+
|
373
|
+
class _Handle:
|
374
|
+
def __init__(self, kernel, launch):
|
375
|
+
self._kernel = kernel
|
376
|
+
self._launch = launch
|
377
|
+
|
378
|
+
def __call__(self, *args, **kwargs):
|
379
|
+
return self._launch(*args, **kwargs)
|
380
|
+
|
381
|
+
|
382
|
+
class _AliasRestorer(ast.NodeTransformer):
|
383
|
+
def __init__(self):
|
384
|
+
super().__init__()
|
385
|
+
|
386
|
+
self._aliases = {}
|
387
|
+
self._redefined = set()
|
388
|
+
|
389
|
+
def visit_Import(self, node):
|
390
|
+
for alias in node.names:
|
391
|
+
if alias.asname:
|
392
|
+
self._aliases[alias.asname] = alias.name
|
393
|
+
|
394
|
+
return node
|
395
|
+
|
396
|
+
def visit_ImportFrom(self, node):
|
397
|
+
for alias in node.names:
|
398
|
+
full_name = f"{node.module}.{alias.name}"
|
399
|
+
if alias.asname:
|
400
|
+
self._aliases[alias.asname] = full_name
|
401
|
+
|
402
|
+
return node
|
403
|
+
|
404
|
+
def visit_Assign(self, node):
|
405
|
+
for target in node.targets:
|
406
|
+
if isinstance(target, ast.Name):
|
407
|
+
self._redefined.add(target.id)
|
408
|
+
|
409
|
+
return self.generic_visit(node)
|
410
|
+
|
411
|
+
def visit_FunctionDef(self, node):
|
412
|
+
original_redefined = self._redefined.copy()
|
413
|
+
|
414
|
+
self.generic_visit(node)
|
415
|
+
|
416
|
+
self._redefined = original_redefined
|
417
|
+
|
418
|
+
return node
|
419
|
+
|
420
|
+
def visit_Name(self, node):
|
421
|
+
if node.id in self._redefined:
|
422
|
+
return node
|
423
|
+
|
424
|
+
if node.id in self._aliases:
|
425
|
+
return ast.Name(id=self._aliases[node.id], ctx=node.ctx)
|
426
|
+
|
427
|
+
return node
|
428
|
+
|
429
|
+
|
430
|
+
class _FunctionDefFinder(ast.NodeVisitor):
|
431
|
+
def __init__(self, name):
|
432
|
+
self._name = name
|
433
|
+
|
434
|
+
self.result = None
|
435
|
+
|
436
|
+
def visit_FunctionDef(self, node):
|
437
|
+
if node.name == self._name:
|
438
|
+
self.result = node
|
439
|
+
|
440
|
+
self.generic_visit(node)
|
ninetoothed/symbol.py
CHANGED
@@ -34,24 +34,47 @@ class Symbol:
|
|
34
34
|
self._node.id = type(self)._create_constexpr(self._node.id)
|
35
35
|
|
36
36
|
def __add__(self, other):
|
37
|
-
|
38
|
-
|
39
|
-
)
|
37
|
+
other = type(self)(other)
|
38
|
+
|
39
|
+
if isinstance(self._node, ast.Constant) and self._node.value == 0:
|
40
|
+
return other
|
41
|
+
|
42
|
+
if isinstance(other._node, ast.Constant) and other._node.value == 0:
|
43
|
+
return self
|
44
|
+
|
45
|
+
return type(self)(ast.BinOp(left=self._node, op=ast.Add(), right=other._node))
|
40
46
|
|
41
47
|
def __radd__(self, other):
|
42
48
|
return self.__add__(other)
|
43
49
|
|
44
50
|
def __mul__(self, other):
|
45
|
-
|
46
|
-
|
47
|
-
)
|
51
|
+
other = type(self)(other)
|
52
|
+
|
53
|
+
if isinstance(self._node, ast.Constant) and self._node.value == 0:
|
54
|
+
return type(self)(0)
|
55
|
+
|
56
|
+
if isinstance(other._node, ast.Constant) and other._node.value == 0:
|
57
|
+
return type(self)(0)
|
58
|
+
|
59
|
+
if isinstance(self._node, ast.Constant) and self._node.value == 1:
|
60
|
+
return other
|
61
|
+
|
62
|
+
if isinstance(other._node, ast.Constant) and other._node.value == 1:
|
63
|
+
return self
|
64
|
+
|
65
|
+
return type(self)(ast.BinOp(left=self._node, op=ast.Mult(), right=other._node))
|
48
66
|
|
49
67
|
def __rmul__(self, other):
|
50
68
|
return self.__mul__(other)
|
51
69
|
|
52
70
|
def __floordiv__(self, other):
|
71
|
+
other = type(self)(other)
|
72
|
+
|
73
|
+
if isinstance(other._node, ast.Constant) and other._node.value == 1:
|
74
|
+
return self
|
75
|
+
|
53
76
|
return type(self)(
|
54
|
-
ast.BinOp(left=self._node, op=ast.FloorDiv(), right=
|
77
|
+
ast.BinOp(left=self._node, op=ast.FloorDiv(), right=other._node)
|
55
78
|
)
|
56
79
|
|
57
80
|
def __mod__(self, other):
|
ninetoothed/tensor.py
CHANGED
@@ -103,11 +103,12 @@ class Tensor:
|
|
103
103
|
indices = self.indices()
|
104
104
|
|
105
105
|
if not isinstance(self.dtype, type(self)):
|
106
|
-
if indices:
|
106
|
+
if len(indices) != self.ndim():
|
107
107
|
raise IndexError("Incorrect number of indices.")
|
108
108
|
|
109
109
|
return sum(
|
110
|
-
|
110
|
+
indices[idx]
|
111
|
+
* self.stride(idx)
|
111
112
|
* call("arange", 0, self.size(idx))[
|
112
113
|
tuple(slice(None) if i == idx else None for i in range(self.ndim()))
|
113
114
|
]
|
@@ -131,8 +132,21 @@ class Tensor:
|
|
131
132
|
indices.append(index // stride)
|
132
133
|
index %= stride
|
133
134
|
|
135
|
+
curr = self.dtype
|
136
|
+
while isinstance(curr, type(self)):
|
137
|
+
indices.extend(
|
138
|
+
0 if curr is not self.inmost() else 1 for _ in range(curr.ndim())
|
139
|
+
)
|
140
|
+
curr = curr.dtype
|
141
|
+
|
134
142
|
return tuple(indices)
|
135
143
|
|
144
|
+
def inmost(self):
|
145
|
+
if not isinstance(self.dtype, type(self)):
|
146
|
+
return self
|
147
|
+
|
148
|
+
return self.dtype.inmost()
|
149
|
+
|
136
150
|
def ndim(self):
|
137
151
|
return len(self.shape)
|
138
152
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.1.1
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -0,0 +1,10 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
|
2
|
+
ninetoothed/jit.py,sha256=hmUzkFZzsiKLgOHbsN0MAr1G1JCiyQ22cFPtmyZ1OyE,12725
|
3
|
+
ninetoothed/language.py,sha256=cSuTgi5OwmLFy-dy_AHGZzRm18wz01ByHQ2vioP1vTg,437
|
4
|
+
ninetoothed/symbol.py,sha256=I2Mc9D1w7AYAIQtyAXyDQ-FBqowVZrd-PK-JOt_SpgA,3787
|
5
|
+
ninetoothed/tensor.py,sha256=RfwYzdYASkr6usJklESm1n8RoxvYjWnPtCjIfipa2fg,5000
|
6
|
+
ninetoothed/torchifier.py,sha256=JmIVQE8r0zr_RLExsRDOGNsMu0F7v6J_o22aWqlw81k,841
|
7
|
+
ninetoothed-0.2.0.dist-info/METADATA,sha256=w6qkc2riniG0N4nDUCUkZWF8Eve3j5brBQHIWIEqLXQ,5422
|
8
|
+
ninetoothed-0.2.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
9
|
+
ninetoothed-0.2.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
10
|
+
ninetoothed-0.2.0.dist-info/RECORD,,
|
@@ -1,10 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
|
2
|
-
ninetoothed/jit.py,sha256=DdRdZ7DhfZwJeS7AcO_RhD9TZcCebKI55V4_6UHs3bo,10523
|
3
|
-
ninetoothed/language.py,sha256=cSuTgi5OwmLFy-dy_AHGZzRm18wz01ByHQ2vioP1vTg,437
|
4
|
-
ninetoothed/symbol.py,sha256=8BI4ekeLuUdHTEREvMMlAzwrJ93pqiCdSHGc38clBFA,3034
|
5
|
-
ninetoothed/tensor.py,sha256=o_HLEuaBzojmbMLnbPGLcw4iqBI34TNdES3YLTagztE,4590
|
6
|
-
ninetoothed/torchifier.py,sha256=JmIVQE8r0zr_RLExsRDOGNsMu0F7v6J_o22aWqlw81k,841
|
7
|
-
ninetoothed-0.1.1.dist-info/METADATA,sha256=1Nv6Xcz7CrpEUrzAYH93bYVX8GfPtHwzj4yofeaoJro,5422
|
8
|
-
ninetoothed-0.1.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
9
|
-
ninetoothed-0.1.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
10
|
-
ninetoothed-0.1.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|