ninetoothed 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/jit.py +206 -43
- ninetoothed/language.py +4 -1
- ninetoothed/symbol.py +63 -8
- ninetoothed/tensor.py +105 -39
- ninetoothed/torchifier.py +15 -11
- {ninetoothed-0.1.1.dist-info → ninetoothed-0.3.0.dist-info}/METADATA +9 -5
- ninetoothed-0.3.0.dist-info/RECORD +10 -0
- ninetoothed-0.1.1.dist-info/RECORD +0 -10
- {ninetoothed-0.1.1.dist-info → ninetoothed-0.3.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.1.1.dist-info → ninetoothed-0.3.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/jit.py
CHANGED
@@ -5,7 +5,8 @@ import inspect
|
|
5
5
|
import itertools
|
6
6
|
import math
|
7
7
|
import tempfile
|
8
|
-
|
8
|
+
|
9
|
+
import triton
|
9
10
|
|
10
11
|
from ninetoothed.language import attribute, call
|
11
12
|
from ninetoothed.symbol import Symbol
|
@@ -33,8 +34,7 @@ class JIT:
|
|
33
34
|
):
|
34
35
|
return type(self).handles[source_file][source_line]
|
35
36
|
|
36
|
-
|
37
|
-
tree = ast.parse(source)
|
37
|
+
tree = self._get_tree()
|
38
38
|
|
39
39
|
CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
|
40
40
|
Tritonizer().visit(tree)
|
@@ -56,15 +56,7 @@ class JIT:
|
|
56
56
|
namespace = {}
|
57
57
|
exec(code, namespace)
|
58
58
|
|
59
|
-
|
60
|
-
def __init__(self, kernel, launch):
|
61
|
-
self._kernel = kernel
|
62
|
-
self._launch = launch
|
63
|
-
|
64
|
-
def __call__(self, *args, **kwargs):
|
65
|
-
return self._launch(*args, **kwargs)
|
66
|
-
|
67
|
-
handle = Handle(
|
59
|
+
handle = _Handle(
|
68
60
|
namespace[self.func.__name__],
|
69
61
|
namespace[f"launch_{self.func.__name__}"],
|
70
62
|
)
|
@@ -73,6 +65,15 @@ class JIT:
|
|
73
65
|
|
74
66
|
return handle
|
75
67
|
|
68
|
+
def _get_tree(self):
|
69
|
+
module = ast.parse(inspect.getsource(inspect.getmodule(self.func)))
|
70
|
+
|
71
|
+
_AliasRestorer().visit(module)
|
72
|
+
finder = _FunctionDefFinder(self.func.__name__)
|
73
|
+
finder.visit(module)
|
74
|
+
|
75
|
+
return ast.Module(body=[finder.result], type_ignores=[])
|
76
|
+
|
76
77
|
|
77
78
|
class CodeGenerator(ast.NodeTransformer):
|
78
79
|
def __init__(self, context):
|
@@ -100,6 +101,29 @@ class CodeGenerator(ast.NodeTransformer):
|
|
100
101
|
|
101
102
|
self.generic_visit(node)
|
102
103
|
|
104
|
+
for arg in self._args:
|
105
|
+
if not isinstance(arg, Tensor):
|
106
|
+
continue
|
107
|
+
|
108
|
+
offsets = arg.offsets()
|
109
|
+
|
110
|
+
initializations = {
|
111
|
+
type(self)._name_for_offsets(arg, dim): offs
|
112
|
+
for dim, offs in enumerate(offsets)
|
113
|
+
} | {
|
114
|
+
type(self)._name_for_pointers(arg): arg.original.pointer_string()
|
115
|
+
+ sum(
|
116
|
+
type(self)._name_for_offsets(arg, dim)[
|
117
|
+
type(self)._generate_slices(arg, dim)
|
118
|
+
]
|
119
|
+
* stride
|
120
|
+
for dim, stride in enumerate(arg.original.strides)
|
121
|
+
)
|
122
|
+
}
|
123
|
+
|
124
|
+
for target, value in reversed(initializations.items()):
|
125
|
+
node.body.insert(0, ast.Assign(targets=[target.node], value=value.node))
|
126
|
+
|
103
127
|
return node
|
104
128
|
|
105
129
|
def visit_arguments(self, node):
|
@@ -136,14 +160,12 @@ class CodeGenerator(ast.NodeTransformer):
|
|
136
160
|
value = self._context[node.value.id]
|
137
161
|
|
138
162
|
if isinstance(value, Tensor):
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
return call("load", pointers).node
|
163
|
+
return type(self)._generate_load(
|
164
|
+
value,
|
165
|
+
intermediate_indices=node.slice.elts
|
166
|
+
if isinstance(node.slice, ast.Tuple)
|
167
|
+
else (node.slice,),
|
168
|
+
)
|
147
169
|
|
148
170
|
self.generic_visit(node)
|
149
171
|
|
@@ -166,7 +188,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
166
188
|
self.generic_visit(node)
|
167
189
|
|
168
190
|
if node.id in self._context and isinstance(node.ctx, ast.Load):
|
169
|
-
return
|
191
|
+
return type(self)._generate_load(self._context[node.id])
|
170
192
|
|
171
193
|
return node
|
172
194
|
|
@@ -178,11 +200,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
178
200
|
self.generic_visit(node)
|
179
201
|
|
180
202
|
return ast.Expr(
|
181
|
-
|
182
|
-
"store",
|
183
|
-
self._context[target.id].pointers().node,
|
184
|
-
node.value,
|
185
|
-
).node
|
203
|
+
type(self)._generate_store(self._context[target.id], node.value)
|
186
204
|
)
|
187
205
|
elif (
|
188
206
|
isinstance(target, ast.Subscript)
|
@@ -195,20 +213,14 @@ class CodeGenerator(ast.NodeTransformer):
|
|
195
213
|
if isinstance(value, Tensor):
|
196
214
|
self.generic_visit(node)
|
197
215
|
|
198
|
-
indices = value.indices() + tuple(
|
199
|
-
target.slice.elts
|
200
|
-
if isinstance(target.slice, ast.Tuple)
|
201
|
-
else target.slice
|
202
|
-
)
|
203
|
-
offsets = value.offsets(indices)
|
204
|
-
pointers = value.pointers(offsets)
|
205
|
-
|
206
216
|
return ast.Expr(
|
207
|
-
|
208
|
-
|
209
|
-
pointers.node,
|
217
|
+
type(self)._generate_store(
|
218
|
+
value,
|
210
219
|
node.value,
|
211
|
-
|
220
|
+
intermediate_indices=target.slice.elts
|
221
|
+
if isinstance(target.slice, ast.Tuple)
|
222
|
+
else (target.slice,),
|
223
|
+
)
|
212
224
|
)
|
213
225
|
|
214
226
|
self.generic_visit(node)
|
@@ -216,6 +228,13 @@ class CodeGenerator(ast.NodeTransformer):
|
|
216
228
|
return node
|
217
229
|
|
218
230
|
def _generate_autotune(self, params, meta):
|
231
|
+
device = triton.runtime.driver.active.get_current_device()
|
232
|
+
properties = triton.runtime.driver.active.utils.get_device_properties(device)
|
233
|
+
max_shared_mem = properties["max_shared_mem"]
|
234
|
+
|
235
|
+
num_warps = 8
|
236
|
+
num_stages = max_shared_mem // 2**15
|
237
|
+
|
219
238
|
configs = [
|
220
239
|
ast.Call(
|
221
240
|
func=ast.Attribute(
|
@@ -229,7 +248,10 @@ class CodeGenerator(ast.NodeTransformer):
|
|
229
248
|
values=[ast.Constant(value=value) for value in values],
|
230
249
|
)
|
231
250
|
],
|
232
|
-
keywords=[
|
251
|
+
keywords=[
|
252
|
+
ast.keyword(arg="num_warps", value=ast.Constant(value=num_warps)),
|
253
|
+
ast.keyword(arg="num_stages", value=ast.Constant(value=num_stages)),
|
254
|
+
],
|
233
255
|
)
|
234
256
|
for values in itertools.product(self._POWER_OF_TWOS, repeat=len(meta))
|
235
257
|
if self._MIN_PRODUCT <= math.prod(values) <= self._MAX_PRODUCT
|
@@ -256,7 +278,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
256
278
|
elts=[
|
257
279
|
ast.Constant(value=param)
|
258
280
|
for param in params
|
259
|
-
if not Tensor.
|
281
|
+
if not Tensor.pointer_pattern().fullmatch(param)
|
260
282
|
],
|
261
283
|
ctx=ast.Load(),
|
262
284
|
),
|
@@ -269,7 +291,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
269
291
|
name=f"launch_{self._func_def.name}",
|
270
292
|
args=ast.arguments(
|
271
293
|
posonlyargs=[],
|
272
|
-
args=[ast.arg(arg.name) for arg in self._args],
|
294
|
+
args=[ast.arg(arg.original.name) for arg in self._args],
|
273
295
|
kwonlyargs=[],
|
274
296
|
defaults=[],
|
275
297
|
),
|
@@ -316,6 +338,77 @@ class CodeGenerator(ast.NodeTransformer):
|
|
316
338
|
|
317
339
|
return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
|
318
340
|
|
341
|
+
@staticmethod
|
342
|
+
def _generate_load(tensor, intermediate_indices=()):
|
343
|
+
pointers, mask = CodeGenerator._generate_pointers_and_mask(
|
344
|
+
tensor, intermediate_indices
|
345
|
+
)
|
346
|
+
other = CodeGenerator._generate_other(tensor)
|
347
|
+
|
348
|
+
return call("load", pointers, mask=mask, other=other).node
|
349
|
+
|
350
|
+
@staticmethod
|
351
|
+
def _generate_store(tensor, value, intermediate_indices=()):
|
352
|
+
pointers, mask = CodeGenerator._generate_pointers_and_mask(
|
353
|
+
tensor, intermediate_indices
|
354
|
+
)
|
355
|
+
|
356
|
+
return call("store", pointers, value, mask=mask).node
|
357
|
+
|
358
|
+
@staticmethod
|
359
|
+
def _generate_pointers_and_mask(tensor, intermediate_indices):
|
360
|
+
intermediate_offsets = CodeGenerator._generate_intermediate_offsets(
|
361
|
+
tensor, intermediate_indices
|
362
|
+
)
|
363
|
+
offsets = [
|
364
|
+
CodeGenerator._name_for_offsets(tensor, dim) + intermediate_offsets[dim]
|
365
|
+
for dim in range(tensor.original.ndim)
|
366
|
+
]
|
367
|
+
pointers = CodeGenerator._name_for_pointers(tensor) + sum(
|
368
|
+
map(lambda x, y: x * y, intermediate_offsets, tensor.original.strides)
|
369
|
+
)
|
370
|
+
mask = functools.reduce(
|
371
|
+
lambda x, y: x & y,
|
372
|
+
(
|
373
|
+
offs[CodeGenerator._generate_slices(tensor, dim)] < size
|
374
|
+
for dim, (offs, size) in enumerate(zip(offsets, tensor.original.shape))
|
375
|
+
),
|
376
|
+
)
|
377
|
+
|
378
|
+
return pointers, mask
|
379
|
+
|
380
|
+
@staticmethod
|
381
|
+
def _generate_other(tensor):
|
382
|
+
other = tensor.original.other
|
383
|
+
|
384
|
+
if isinstance(other, float) and not math.isfinite(other):
|
385
|
+
return f"float('{other}')"
|
386
|
+
|
387
|
+
return other
|
388
|
+
|
389
|
+
@staticmethod
|
390
|
+
def _generate_slices(tensor, dim):
|
391
|
+
return tuple(slice(None) if i == dim else None for i in range(tensor.ndim))
|
392
|
+
|
393
|
+
@staticmethod
|
394
|
+
def _generate_intermediate_offsets(tensor, intermediate_indices):
|
395
|
+
return tuple(
|
396
|
+
offs
|
397
|
+
for offs in tensor.offsets(
|
398
|
+
[0 for _ in range(tensor.ndim)]
|
399
|
+
+ list(intermediate_indices)
|
400
|
+
+ [0 for _ in range(tensor.inmost().ndim)]
|
401
|
+
)
|
402
|
+
)
|
403
|
+
|
404
|
+
@staticmethod
|
405
|
+
def _name_for_pointers(tensor):
|
406
|
+
return Symbol(f"{tensor.original.name}_pointers")
|
407
|
+
|
408
|
+
@staticmethod
|
409
|
+
def _name_for_offsets(tensor, dim):
|
410
|
+
return Symbol(f"{tensor.original.name}_offsets_{dim}")
|
411
|
+
|
319
412
|
|
320
413
|
class Tritonizer(ast.NodeTransformer):
|
321
414
|
def visit_Module(self, node):
|
@@ -329,8 +422,8 @@ class Tritonizer(ast.NodeTransformer):
|
|
329
422
|
def visit_Name(self, node):
|
330
423
|
self.generic_visit(node)
|
331
424
|
|
332
|
-
if node.id == "ninetoothed":
|
333
|
-
node.id = "triton"
|
425
|
+
if node.id == "ninetoothed" or "ninetoothed." in node.id:
|
426
|
+
node.id = node.id.replace("ninetoothed", "triton")
|
334
427
|
|
335
428
|
return node
|
336
429
|
|
@@ -348,3 +441,73 @@ class Tritonizer(ast.NodeTransformer):
|
|
348
441
|
)
|
349
442
|
|
350
443
|
return node
|
444
|
+
|
445
|
+
|
446
|
+
class _Handle:
|
447
|
+
def __init__(self, kernel, launch):
|
448
|
+
self._kernel = kernel
|
449
|
+
self._launch = launch
|
450
|
+
|
451
|
+
def __call__(self, *args, **kwargs):
|
452
|
+
return self._launch(*args, **kwargs)
|
453
|
+
|
454
|
+
|
455
|
+
class _AliasRestorer(ast.NodeTransformer):
|
456
|
+
def __init__(self):
|
457
|
+
super().__init__()
|
458
|
+
|
459
|
+
self._aliases = {}
|
460
|
+
self._redefined = set()
|
461
|
+
|
462
|
+
def visit_Import(self, node):
|
463
|
+
for alias in node.names:
|
464
|
+
if alias.asname:
|
465
|
+
self._aliases[alias.asname] = alias.name
|
466
|
+
|
467
|
+
return node
|
468
|
+
|
469
|
+
def visit_ImportFrom(self, node):
|
470
|
+
for alias in node.names:
|
471
|
+
full_name = f"{node.module}.{alias.name}"
|
472
|
+
if alias.asname:
|
473
|
+
self._aliases[alias.asname] = full_name
|
474
|
+
|
475
|
+
return node
|
476
|
+
|
477
|
+
def visit_Assign(self, node):
|
478
|
+
for target in node.targets:
|
479
|
+
if isinstance(target, ast.Name):
|
480
|
+
self._redefined.add(target.id)
|
481
|
+
|
482
|
+
return self.generic_visit(node)
|
483
|
+
|
484
|
+
def visit_FunctionDef(self, node):
|
485
|
+
original_redefined = self._redefined.copy()
|
486
|
+
|
487
|
+
self.generic_visit(node)
|
488
|
+
|
489
|
+
self._redefined = original_redefined
|
490
|
+
|
491
|
+
return node
|
492
|
+
|
493
|
+
def visit_Name(self, node):
|
494
|
+
if node.id in self._redefined:
|
495
|
+
return node
|
496
|
+
|
497
|
+
if node.id in self._aliases:
|
498
|
+
return ast.Name(id=self._aliases[node.id], ctx=node.ctx)
|
499
|
+
|
500
|
+
return node
|
501
|
+
|
502
|
+
|
503
|
+
class _FunctionDefFinder(ast.NodeVisitor):
|
504
|
+
def __init__(self, name):
|
505
|
+
self._name = name
|
506
|
+
|
507
|
+
self.result = None
|
508
|
+
|
509
|
+
def visit_FunctionDef(self, node):
|
510
|
+
if node.name == self._name:
|
511
|
+
self.result = node
|
512
|
+
|
513
|
+
self.generic_visit(node)
|
ninetoothed/language.py
CHANGED
@@ -10,7 +10,10 @@ def call(func, *args, **kwargs):
|
|
10
10
|
ast.Call(
|
11
11
|
func=attribute(func).node,
|
12
12
|
args=[Symbol(arg).node for arg in args],
|
13
|
-
keywords=[
|
13
|
+
keywords=[
|
14
|
+
ast.keyword(arg=kwarg, value=Symbol(kwargs[kwarg]).node)
|
15
|
+
for kwarg in kwargs
|
16
|
+
],
|
14
17
|
)
|
15
18
|
)
|
16
19
|
|
ninetoothed/symbol.py
CHANGED
@@ -34,37 +34,80 @@ class Symbol:
|
|
34
34
|
self._node.id = type(self)._create_constexpr(self._node.id)
|
35
35
|
|
36
36
|
def __add__(self, other):
|
37
|
-
|
38
|
-
|
39
|
-
)
|
37
|
+
other = type(self)(other)
|
38
|
+
|
39
|
+
if isinstance(self._node, ast.Constant) and self._node.value == 0:
|
40
|
+
return other
|
41
|
+
|
42
|
+
if isinstance(other._node, ast.Constant) and other._node.value == 0:
|
43
|
+
return self
|
44
|
+
|
45
|
+
return type(self)(ast.BinOp(left=self._node, op=ast.Add(), right=other._node))
|
40
46
|
|
41
47
|
def __radd__(self, other):
|
42
48
|
return self.__add__(other)
|
43
49
|
|
44
50
|
def __mul__(self, other):
|
45
|
-
|
46
|
-
|
47
|
-
)
|
51
|
+
other = type(self)(other)
|
52
|
+
|
53
|
+
if isinstance(self._node, ast.Constant) and self._node.value == 0:
|
54
|
+
return type(self)(0)
|
55
|
+
|
56
|
+
if isinstance(other._node, ast.Constant) and other._node.value == 0:
|
57
|
+
return type(self)(0)
|
58
|
+
|
59
|
+
if isinstance(self._node, ast.Constant) and self._node.value == 1:
|
60
|
+
return other
|
61
|
+
|
62
|
+
if isinstance(other._node, ast.Constant) and other._node.value == 1:
|
63
|
+
return self
|
64
|
+
|
65
|
+
return type(self)(ast.BinOp(left=self._node, op=ast.Mult(), right=other._node))
|
48
66
|
|
49
67
|
def __rmul__(self, other):
|
50
68
|
return self.__mul__(other)
|
51
69
|
|
52
70
|
def __floordiv__(self, other):
|
71
|
+
other = type(self)(other)
|
72
|
+
|
73
|
+
if isinstance(other._node, ast.Constant) and other._node.value == 1:
|
74
|
+
return self
|
75
|
+
|
53
76
|
return type(self)(
|
54
|
-
ast.BinOp(left=self._node, op=ast.FloorDiv(), right=
|
77
|
+
ast.BinOp(left=self._node, op=ast.FloorDiv(), right=other._node)
|
55
78
|
)
|
56
79
|
|
57
80
|
def __mod__(self, other):
|
81
|
+
other = type(self)(other)
|
82
|
+
|
83
|
+
return type(self)(ast.BinOp(left=self._node, op=ast.Mod(), right=other._node))
|
84
|
+
|
85
|
+
def __lt__(self, other):
|
86
|
+
other = type(self)(other)
|
87
|
+
|
88
|
+
return type(self)(
|
89
|
+
ast.Compare(left=self._node, ops=[ast.Lt()], comparators=[other._node])
|
90
|
+
)
|
91
|
+
|
92
|
+
def __and__(self, other):
|
93
|
+
other = type(self)(other)
|
94
|
+
|
58
95
|
return type(self)(
|
59
|
-
ast.BinOp(left=self._node, op=ast.
|
96
|
+
ast.BinOp(left=self._node, op=ast.BitAnd(), right=other._node)
|
60
97
|
)
|
61
98
|
|
99
|
+
def __rand__(self, other):
|
100
|
+
return self.__and__(other)
|
101
|
+
|
62
102
|
def __getitem__(self, key):
|
63
103
|
return type(self)(ast.Subscript(value=self._node, slice=type(self)(key)._node))
|
64
104
|
|
65
105
|
def __repr__(self):
|
66
106
|
return ast.unparse(self._node)
|
67
107
|
|
108
|
+
def find_and_replace(self, target, replacement):
|
109
|
+
_FindAndReplacer(target.node, replacement.node).visit(self._node)
|
110
|
+
|
68
111
|
def names(self):
|
69
112
|
class NameCollector(ast.NodeVisitor):
|
70
113
|
def __init__(self):
|
@@ -107,3 +150,15 @@ class Symbol:
|
|
107
150
|
@staticmethod
|
108
151
|
def _create_meta(name):
|
109
152
|
return f"_ninetoothed_meta_{name}"
|
153
|
+
|
154
|
+
|
155
|
+
class _FindAndReplacer(ast.NodeTransformer):
|
156
|
+
def __init__(self, target, replacement):
|
157
|
+
self._target_id = target.id
|
158
|
+
self._replacement = replacement
|
159
|
+
|
160
|
+
def visit_Name(self, node):
|
161
|
+
if node.id == self._target_id:
|
162
|
+
return self._replacement
|
163
|
+
|
164
|
+
return self.generic_visit(node)
|
ninetoothed/tensor.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
import itertools
|
2
|
+
import re
|
2
3
|
|
3
4
|
from ninetoothed.language import call
|
4
5
|
from ninetoothed.symbol import Symbol
|
@@ -7,19 +8,24 @@ from ninetoothed.symbol import Symbol
|
|
7
8
|
class Tensor:
|
8
9
|
num_instances = 0
|
9
10
|
|
10
|
-
def __init__(
|
11
|
+
def __init__(
|
12
|
+
self,
|
13
|
+
ndim=None,
|
14
|
+
shape=None,
|
15
|
+
dtype=None,
|
16
|
+
strides=None,
|
17
|
+
other=None,
|
18
|
+
original=None,
|
19
|
+
):
|
11
20
|
type(self).num_instances += 1
|
12
21
|
|
13
22
|
self.dtype = dtype
|
14
23
|
|
15
|
-
|
16
|
-
self.name = name
|
17
|
-
else:
|
18
|
-
self.name = f"tensor_{type(self).num_instances}"
|
24
|
+
self.name = f"tensor_{type(self).num_instances}"
|
19
25
|
|
20
26
|
if ndim is not None:
|
21
|
-
self.shape = [Symbol(
|
22
|
-
self.strides = [Symbol(
|
27
|
+
self.shape = [Symbol(self.size_string(i)) for i in range(ndim)]
|
28
|
+
self.strides = [Symbol(self.stride_string(i)) for i in range(ndim)]
|
23
29
|
else:
|
24
30
|
self.shape = shape
|
25
31
|
|
@@ -28,6 +34,13 @@ class Tensor:
|
|
28
34
|
else:
|
29
35
|
self.strides = self._calculate_default_strides(shape)
|
30
36
|
|
37
|
+
self.other = other
|
38
|
+
|
39
|
+
if original is not None:
|
40
|
+
self.original = original
|
41
|
+
else:
|
42
|
+
self.original = self
|
43
|
+
|
31
44
|
def tile(self, tile_shape, tile_strides=None):
|
32
45
|
if tile_strides is None:
|
33
46
|
tile_strides = [1 for _ in tile_shape]
|
@@ -59,10 +72,10 @@ class Tensor:
|
|
59
72
|
shape=inner_shape,
|
60
73
|
dtype=self.dtype,
|
61
74
|
strides=inner_strides,
|
62
|
-
|
75
|
+
original=self.original,
|
63
76
|
),
|
64
77
|
strides=outer_strides,
|
65
|
-
|
78
|
+
original=self.original,
|
66
79
|
)
|
67
80
|
|
68
81
|
def expand(self, shape):
|
@@ -77,12 +90,21 @@ class Tensor:
|
|
77
90
|
stride if new_size == -1 else 0
|
78
91
|
for new_size, stride in zip(shape, self.strides)
|
79
92
|
],
|
80
|
-
|
93
|
+
original=self.original,
|
94
|
+
)
|
95
|
+
|
96
|
+
def squeeze(self, dim):
|
97
|
+
# TODO: Add error handling.
|
98
|
+
return type(self)(
|
99
|
+
shape=[size for i, size in enumerate(self.shape) if dim != i],
|
100
|
+
dtype=self.dtype,
|
101
|
+
strides=[stride for i, stride in enumerate(self.strides) if dim != i],
|
102
|
+
original=self.original,
|
81
103
|
)
|
82
104
|
|
83
105
|
def names(self):
|
84
106
|
return (
|
85
|
-
{self.
|
107
|
+
{self.original.pointer_string()}
|
86
108
|
| {
|
87
109
|
name
|
88
110
|
for value in itertools.chain(self.shape, self.strides)
|
@@ -92,34 +114,31 @@ class Tensor:
|
|
92
114
|
| (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
|
93
115
|
)
|
94
116
|
|
95
|
-
def pointers(self, offsets=None):
|
96
|
-
if offsets is None:
|
97
|
-
offsets = self.offsets()
|
98
|
-
|
99
|
-
return self._pointer() + offsets
|
100
|
-
|
101
117
|
def offsets(self, indices=None):
|
102
118
|
if indices is None:
|
103
119
|
indices = self.indices()
|
104
120
|
|
105
|
-
|
106
|
-
|
107
|
-
|
121
|
+
offsets = [[] for _ in range(self.original.ndim)]
|
122
|
+
|
123
|
+
curr = self
|
124
|
+
start = 0
|
125
|
+
|
126
|
+
while isinstance(curr, type(self)):
|
127
|
+
stop = start + curr.ndim
|
128
|
+
curr_indices = indices[start:stop]
|
129
|
+
|
130
|
+
for index, stride in zip(curr_indices, curr.strides):
|
131
|
+
for dim in self._dims_of(stride):
|
132
|
+
offsets[dim].append(index * stride)
|
108
133
|
|
109
|
-
|
110
|
-
|
111
|
-
* call("arange", 0, self.size(idx))[
|
112
|
-
tuple(slice(None) if i == idx else None for i in range(self.ndim()))
|
113
|
-
]
|
114
|
-
for idx in range(self.ndim())
|
115
|
-
)
|
134
|
+
start = stop
|
135
|
+
curr = curr.dtype
|
116
136
|
|
117
|
-
|
118
|
-
|
137
|
+
for dim in range(self.original.ndim):
|
138
|
+
offsets[dim] = sum(offsets[dim])
|
139
|
+
offsets[dim].find_and_replace(Symbol(self.original.strides[dim]), Symbol(1))
|
119
140
|
|
120
|
-
return
|
121
|
-
index * stride for index, stride in zip(outer_indices, self.strides)
|
122
|
-
) + self.dtype.offsets(inner_indices)
|
141
|
+
return offsets
|
123
142
|
|
124
143
|
def indices(self, index=None):
|
125
144
|
if index is None:
|
@@ -127,14 +146,38 @@ class Tensor:
|
|
127
146
|
|
128
147
|
indices = []
|
129
148
|
|
130
|
-
for stride in type(self)(shape=self.shape,
|
149
|
+
for stride in type(self)(shape=self.shape, original=self.original).strides:
|
131
150
|
indices.append(index // stride)
|
132
151
|
index %= stride
|
133
152
|
|
153
|
+
curr = self.dtype
|
154
|
+
|
155
|
+
while isinstance(curr.dtype, type(self)):
|
156
|
+
for _ in range(curr.ndim):
|
157
|
+
indices.append(0)
|
158
|
+
|
159
|
+
curr = curr.dtype
|
160
|
+
|
161
|
+
if isinstance(curr, type(self)):
|
162
|
+
for dim in range(curr.ndim):
|
163
|
+
indices.append(call("arange", 0, curr.shape[dim]))
|
164
|
+
|
134
165
|
return tuple(indices)
|
135
166
|
|
136
|
-
def
|
137
|
-
|
167
|
+
def inmost(self):
|
168
|
+
if not isinstance(self.dtype, type(self)):
|
169
|
+
return self
|
170
|
+
|
171
|
+
return self.dtype.inmost()
|
172
|
+
|
173
|
+
def pointer_string(self):
|
174
|
+
return f"{self.name}_pointer"
|
175
|
+
|
176
|
+
def size_string(self, dim):
|
177
|
+
return f"{self.name}_size_{dim}"
|
178
|
+
|
179
|
+
def stride_string(self, dim):
|
180
|
+
return f"{self.name}_stride_{dim}"
|
138
181
|
|
139
182
|
def size(self, dim=None):
|
140
183
|
if dim is None:
|
@@ -148,12 +191,31 @@ class Tensor:
|
|
148
191
|
|
149
192
|
return self.strides[dim]
|
150
193
|
|
194
|
+
@property
|
195
|
+
def ndim(self):
|
196
|
+
return len(self.shape)
|
197
|
+
|
151
198
|
@staticmethod
|
152
|
-
def
|
153
|
-
return
|
199
|
+
def pointer_pattern():
|
200
|
+
return re.compile(rf"({_identifier_pattern_raw_string()})_(pointer)")
|
154
201
|
|
155
|
-
|
156
|
-
|
202
|
+
@staticmethod
|
203
|
+
def size_pattern():
|
204
|
+
return re.compile(rf"({_identifier_pattern_raw_string()})_(size)_(.+)")
|
205
|
+
|
206
|
+
@staticmethod
|
207
|
+
def stride_pattern():
|
208
|
+
return re.compile(rf"({_identifier_pattern_raw_string()})_(stride)_(.+)")
|
209
|
+
|
210
|
+
def _dims_of(self, stride):
|
211
|
+
dims = set()
|
212
|
+
names = stride.names() if isinstance(stride, Symbol) else {stride}
|
213
|
+
|
214
|
+
for dim, original_stride in enumerate(self.original.strides):
|
215
|
+
if str(original_stride) in names:
|
216
|
+
dims.add(dim)
|
217
|
+
|
218
|
+
return dims
|
157
219
|
|
158
220
|
@staticmethod
|
159
221
|
def _calculate_default_strides(shape):
|
@@ -163,3 +225,7 @@ class Tensor:
|
|
163
225
|
strides.append(size * strides[-1])
|
164
226
|
|
165
227
|
return reversed(strides)
|
228
|
+
|
229
|
+
|
230
|
+
def _identifier_pattern_raw_string():
|
231
|
+
return r"[a-zA-Z_][a-zA-Z0-9_]*"
|
ninetoothed/torchifier.py
CHANGED
@@ -1,23 +1,27 @@
|
|
1
1
|
import ast
|
2
|
-
|
2
|
+
|
3
|
+
from ninetoothed.tensor import Tensor
|
3
4
|
|
4
5
|
|
5
6
|
class Torchifier(ast.NodeTransformer):
|
6
7
|
def visit_Name(self, node):
|
7
8
|
self.generic_visit(node)
|
8
9
|
|
9
|
-
|
10
|
+
source = node.id
|
11
|
+
|
12
|
+
def repl(match):
|
13
|
+
return f"{match.group(1)}"
|
14
|
+
|
15
|
+
source = Tensor.pointer_pattern().sub(repl, source)
|
16
|
+
|
17
|
+
def repl(match):
|
18
|
+
return f"{match.group(1)}.{match.group(2)}({match.group(3)})"
|
10
19
|
|
11
|
-
|
20
|
+
source = Tensor.size_pattern().sub(repl, source)
|
21
|
+
source = Tensor.stride_pattern().sub(repl, source)
|
12
22
|
|
13
|
-
if
|
14
|
-
return ast.parse(
|
15
|
-
pattern.sub(
|
16
|
-
lambda match: f"{match.group(1)}.{match.group(2)}({match.group(3)})",
|
17
|
-
node.id,
|
18
|
-
),
|
19
|
-
mode="eval",
|
20
|
-
).body
|
23
|
+
if source != node.id:
|
24
|
+
return ast.parse(source, mode="eval").body
|
21
25
|
|
22
26
|
return node
|
23
27
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.3.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -10,6 +10,7 @@ Classifier: License :: OSI Approved :: Apache Software License
|
|
10
10
|
Classifier: Operating System :: OS Independent
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
12
12
|
Requires-Python: >=3.10
|
13
|
+
Requires-Dist: triton>=3.0.0
|
13
14
|
Description-Content-Type: text/markdown
|
14
15
|
|
15
16
|
# NineToothed
|
@@ -26,7 +27,7 @@ We can use `pip` to install `ninetoothed`.
|
|
26
27
|
pip install ninetoothed
|
27
28
|
```
|
28
29
|
|
29
|
-
After successfully running the above command, `ninetoothed` will be installed. However, to fully utilize its capabilities, you also need to install
|
30
|
+
After successfully running the above command, `ninetoothed` will be installed. However, to fully utilize its capabilities, you also need to install a deep learning framework supported by `ninetoothed`. For trial purposes, we recommend installing `torch`.
|
30
31
|
|
31
32
|
## Usage
|
32
33
|
|
@@ -64,16 +65,19 @@ c_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
|
|
64
65
|
a_tiled = a_tiled.expand((-1, c_tiled.shape[1]))
|
65
66
|
b_tiled = b_tiled.expand((c_tiled.shape[0], -1))
|
66
67
|
|
68
|
+
a_tiled.dtype = a_tiled.dtype.squeeze(0)
|
69
|
+
b_tiled.dtype = b_tiled.dtype.squeeze(1)
|
70
|
+
|
67
71
|
@ninetoothed.jit
|
68
72
|
def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
|
69
73
|
accumulator = ninetoothed.language.zeros(
|
70
74
|
c.shape, dtype=ninetoothed.language.float32
|
71
75
|
)
|
72
|
-
for k in range(a.shape[
|
73
|
-
accumulator
|
76
|
+
for k in range(a.shape[0]):
|
77
|
+
accumulator += ninetoothed.language.dot(a[k], b[k])
|
74
78
|
c = accumulator.to(ninetoothed.language.float16)
|
75
79
|
```
|
76
80
|
|
77
|
-
For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$.
|
81
|
+
For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$. In fact, our meta-operations up to this point have already enabled us to write kernel functions. However, we notice that the levels where the row blocks and column blocks reside, which we mentioned earlier, are two-dimensional, and their sizes are of the forms `(1, ...)` and `(..., 1)`. This means that if no other operations are performed, the way we access row blocks and column blocks would have to be `a[0, k]` and `b[k, 0]`. If we want to use `a` to find the range of `k`, we would need to use `a.shape[1]`, but we know that dimensions of size `1` can actually be removed completely. This is why we added two lines of `squeeze`. 
The `dtype` refers to the data type, which in PyTorch can generally be some integer or floating-point type, such as `torch.float32`. However, since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, there is a concept of "tensors that store tensors" in NineToothed. In summary, these two lines operate on the tensors stored in the outermost tensor, removing their dimensions of size `1`. This way, when we access the row and column blocks, we can use `a[k]` and `b[k]`, and when finding the range of `k`, we can use `a.shape[0]`.
|
78
82
|
|
79
83
|
With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of $B$, multiplying them and accumulating the results in `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since each block of the parameter tensors undergoes this operation, the multiplication is completed for the entire tensors as well.
|
@@ -0,0 +1,10 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
|
2
|
+
ninetoothed/jit.py,sha256=nhjZRi8_kcjWZX0eOrnxLlzJfVg5vn12f9oi0Er2ABE,15515
|
3
|
+
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
+
ninetoothed/symbol.py,sha256=Bd54qcI8KQAX0JRE_wPXycswtdSofhZ6Rr5MtZcv9fo,4665
|
5
|
+
ninetoothed/tensor.py,sha256=_DrjOJ-pBvEbSNUvUoYJduLQXmuKgNcqhe4xUDMVoZw,6275
|
6
|
+
ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
|
7
|
+
ninetoothed-0.3.0.dist-info/METADATA,sha256=CqdtfdV0eHzSwxJmFpD2IG5d4WTc6RDlpqMZue4Ml2Q,6720
|
8
|
+
ninetoothed-0.3.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
9
|
+
ninetoothed-0.3.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
10
|
+
ninetoothed-0.3.0.dist-info/RECORD,,
|
@@ -1,10 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
|
2
|
-
ninetoothed/jit.py,sha256=DdRdZ7DhfZwJeS7AcO_RhD9TZcCebKI55V4_6UHs3bo,10523
|
3
|
-
ninetoothed/language.py,sha256=cSuTgi5OwmLFy-dy_AHGZzRm18wz01ByHQ2vioP1vTg,437
|
4
|
-
ninetoothed/symbol.py,sha256=8BI4ekeLuUdHTEREvMMlAzwrJ93pqiCdSHGc38clBFA,3034
|
5
|
-
ninetoothed/tensor.py,sha256=o_HLEuaBzojmbMLnbPGLcw4iqBI34TNdES3YLTagztE,4590
|
6
|
-
ninetoothed/torchifier.py,sha256=JmIVQE8r0zr_RLExsRDOGNsMu0F7v6J_o22aWqlw81k,841
|
7
|
-
ninetoothed-0.1.1.dist-info/METADATA,sha256=1Nv6Xcz7CrpEUrzAYH93bYVX8GfPtHwzj4yofeaoJro,5422
|
8
|
-
ninetoothed-0.1.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
9
|
-
ninetoothed-0.1.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
10
|
-
ninetoothed-0.1.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|