ninetoothed 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/jit.py +118 -13
- ninetoothed/naming.py +50 -0
- ninetoothed/symbol.py +39 -44
- ninetoothed/tensor.py +37 -14
- {ninetoothed-0.6.0.dist-info → ninetoothed-0.7.0.dist-info}/METADATA +9 -2
- ninetoothed-0.7.0.dist-info/RECORD +11 -0
- {ninetoothed-0.6.0.dist-info → ninetoothed-0.7.0.dist-info}/WHEEL +1 -1
- ninetoothed-0.6.0.dist-info/RECORD +0 -10
- {ninetoothed-0.6.0.dist-info → ninetoothed-0.7.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/jit.py
CHANGED
@@ -5,27 +5,37 @@ import importlib.util
|
|
5
5
|
import inspect
|
6
6
|
import itertools
|
7
7
|
import math
|
8
|
+
import subprocess
|
8
9
|
import sys
|
9
10
|
import tempfile
|
10
11
|
|
11
12
|
import triton
|
12
13
|
|
14
|
+
import ninetoothed.naming as naming
|
13
15
|
from ninetoothed.language import attribute, call
|
14
16
|
from ninetoothed.symbol import Symbol
|
15
17
|
from ninetoothed.tensor import Tensor
|
16
18
|
from ninetoothed.torchifier import Torchifier
|
17
19
|
|
18
20
|
|
19
|
-
def jit(
|
20
|
-
|
21
|
+
def jit(_func=None, *, _prettify=False):
|
22
|
+
def wrapper(func):
|
23
|
+
return JIT(func, _prettify=_prettify)()
|
24
|
+
|
25
|
+
if _func is None:
|
26
|
+
return wrapper
|
27
|
+
|
28
|
+
return wrapper(_func)
|
21
29
|
|
22
30
|
|
23
31
|
class JIT:
|
24
32
|
handles = collections.defaultdict(dict)
|
25
33
|
|
26
|
-
def __init__(self, func):
|
34
|
+
def __init__(self, func, _prettify=False):
|
27
35
|
self.func = func
|
28
36
|
|
37
|
+
self._prettify = _prettify
|
38
|
+
|
29
39
|
def __call__(self):
|
30
40
|
source_file = inspect.getsourcefile(self.func)
|
31
41
|
source_line = inspect.getsourcelines(self.func)[1]
|
@@ -40,12 +50,26 @@ class JIT:
|
|
40
50
|
|
41
51
|
CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
|
42
52
|
Tritonizer().visit(tree)
|
53
|
+
_BinOpSimplifier().visit(tree)
|
43
54
|
ast.fix_missing_locations(tree)
|
44
55
|
|
56
|
+
if self._prettify:
|
57
|
+
name_collector = _SimplifiedNameCollector()
|
58
|
+
name_collector.visit(tree)
|
59
|
+
|
45
60
|
unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
|
46
61
|
dependencies = self._find_dependencies()
|
47
62
|
source = "\n\n".join((unparsed, dependencies)).strip()
|
48
63
|
|
64
|
+
if self._prettify:
|
65
|
+
for original, simplified in name_collector.simplified_names.items():
|
66
|
+
if simplified not in name_collector.simplified_names:
|
67
|
+
source = source.replace(original, simplified)
|
68
|
+
|
69
|
+
source = subprocess.check_output(
|
70
|
+
["ruff", "format", "-"], input=source, encoding="utf-8"
|
71
|
+
)
|
72
|
+
|
49
73
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
|
50
74
|
temp_file.write(source.encode("utf-8"))
|
51
75
|
temp_file_name = temp_file.name
|
@@ -67,10 +91,12 @@ class JIT:
|
|
67
91
|
module = ast.parse(inspect.getsource(inspect.getmodule(self.func)))
|
68
92
|
|
69
93
|
_AliasRestorer().visit(module)
|
94
|
+
collector = _ImportCollector()
|
95
|
+
collector.visit(module)
|
70
96
|
finder = _FunctionDefFinder(self.func.__name__)
|
71
97
|
finder.visit(module)
|
72
98
|
|
73
|
-
return ast.Module(body=[finder.result], type_ignores=[])
|
99
|
+
return ast.Module(body=collector.imports + [finder.result], type_ignores=[])
|
74
100
|
|
75
101
|
def _find_dependencies(self):
|
76
102
|
dependencies = set()
|
@@ -147,12 +173,17 @@ class CodeGenerator(ast.NodeTransformer):
|
|
147
173
|
|
148
174
|
names_of_args = [arg.names() - {"ninetoothed"} for arg in self._args]
|
149
175
|
names = functools.reduce(lambda x, y: x | y, names_of_args)
|
150
|
-
meta_names = {name for name in names if
|
176
|
+
meta_names = {name for name in names if naming.is_meta(name)}
|
151
177
|
non_meta_names = {name for name in names if name not in meta_names}
|
178
|
+
non_meta_names |= {
|
179
|
+
naming.make_next_power_of_2(name)
|
180
|
+
for name in non_meta_names
|
181
|
+
if naming.is_constexpr(name)
|
182
|
+
}
|
152
183
|
|
153
184
|
node.args = [
|
154
185
|
ast.arg(arg=name)
|
155
|
-
if not
|
186
|
+
if not naming.is_constexpr(name)
|
156
187
|
else ast.arg(arg=name, annotation=attribute("constexpr").node)
|
157
188
|
for name in non_meta_names
|
158
189
|
] + [
|
@@ -303,9 +334,20 @@ class CodeGenerator(ast.NodeTransformer):
|
|
303
334
|
)
|
304
335
|
|
305
336
|
def _generate_launch(self, params, meta):
|
306
|
-
|
307
|
-
|
308
|
-
|
337
|
+
non_next_power_of_2_constexpr_params = [
|
338
|
+
param
|
339
|
+
for param in params
|
340
|
+
if naming.is_constexpr(param) and not naming.is_next_power_of_2(param)
|
341
|
+
]
|
342
|
+
non_next_power_of_2_constexpr_params_without_prefixes = [
|
343
|
+
naming.remove_prefixes(param)
|
344
|
+
for param in non_next_power_of_2_constexpr_params
|
345
|
+
]
|
346
|
+
next_power_of_2_params = [
|
347
|
+
param for param in params if naming.is_next_power_of_2(param)
|
348
|
+
]
|
349
|
+
next_power_of_2_params_without_prefixes = [
|
350
|
+
naming.remove_prefixes(param) for param in next_power_of_2_params
|
309
351
|
]
|
310
352
|
|
311
353
|
launch = ast.FunctionDef(
|
@@ -313,17 +355,33 @@ class CodeGenerator(ast.NodeTransformer):
|
|
313
355
|
args=ast.arguments(
|
314
356
|
posonlyargs=[],
|
315
357
|
args=[ast.arg(arg=arg.original.name) for arg in self._args]
|
316
|
-
+ [
|
358
|
+
+ [
|
359
|
+
ast.arg(arg=param)
|
360
|
+
for param in non_next_power_of_2_constexpr_params_without_prefixes
|
361
|
+
],
|
317
362
|
kwonlyargs=[],
|
318
363
|
defaults=[],
|
319
364
|
),
|
320
365
|
body=[
|
321
366
|
ast.Assign(
|
322
367
|
targets=[ast.Name(id=param, ctx=ast.Store())],
|
323
|
-
value=ast.Name(id=
|
368
|
+
value=ast.Name(id=param_without_prefixes, ctx=ast.Load()),
|
324
369
|
)
|
325
|
-
for param,
|
326
|
-
|
370
|
+
for param, param_without_prefixes in zip(
|
371
|
+
non_next_power_of_2_constexpr_params,
|
372
|
+
non_next_power_of_2_constexpr_params_without_prefixes,
|
373
|
+
)
|
374
|
+
]
|
375
|
+
+ [
|
376
|
+
ast.Assign(
|
377
|
+
targets=[ast.Name(id=param, ctx=ast.Store())],
|
378
|
+
value=Symbol(
|
379
|
+
f"triton.next_power_of_2({param_without_prefixes})"
|
380
|
+
).node,
|
381
|
+
)
|
382
|
+
for param, param_without_prefixes in zip(
|
383
|
+
next_power_of_2_params,
|
384
|
+
next_power_of_2_params_without_prefixes,
|
327
385
|
)
|
328
386
|
]
|
329
387
|
+ [
|
@@ -477,6 +535,36 @@ class Tritonizer(ast.NodeTransformer):
|
|
477
535
|
return node
|
478
536
|
|
479
537
|
|
538
|
+
class _BinOpSimplifier(ast.NodeTransformer):
|
539
|
+
def visit_BinOp(self, node):
|
540
|
+
self.generic_visit(node)
|
541
|
+
|
542
|
+
if isinstance(node.op, ast.Mult):
|
543
|
+
left = Symbol(node.left)
|
544
|
+
right = Symbol(node.right)
|
545
|
+
|
546
|
+
if left == 0 or right == 0:
|
547
|
+
return Symbol(0).node
|
548
|
+
|
549
|
+
if left == 1:
|
550
|
+
return node.right
|
551
|
+
|
552
|
+
if right == 1:
|
553
|
+
return node.left
|
554
|
+
|
555
|
+
return node
|
556
|
+
|
557
|
+
|
558
|
+
class _SimplifiedNameCollector(ast.NodeVisitor):
|
559
|
+
def __init__(self):
|
560
|
+
self.simplified_names = {}
|
561
|
+
|
562
|
+
def visit_Name(self, node):
|
563
|
+
self.generic_visit(node)
|
564
|
+
|
565
|
+
self.simplified_names[node.id] = naming.remove_prefixes(node.id)
|
566
|
+
|
567
|
+
|
480
568
|
class _Handle:
|
481
569
|
def __init__(self, kernel, launch, source):
|
482
570
|
self._kernel = kernel
|
@@ -535,6 +623,23 @@ class _AliasRestorer(ast.NodeTransformer):
|
|
535
623
|
return node
|
536
624
|
|
537
625
|
|
626
|
+
class _ImportCollector(ast.NodeVisitor):
|
627
|
+
def __init__(self):
|
628
|
+
super().__init__()
|
629
|
+
|
630
|
+
self.imports = []
|
631
|
+
|
632
|
+
def visit_Import(self, node):
|
633
|
+
self.imports.append(node)
|
634
|
+
|
635
|
+
self.generic_visit(node)
|
636
|
+
|
637
|
+
def visit_ImportFrom(self, node):
|
638
|
+
self.imports.append(node)
|
639
|
+
|
640
|
+
self.generic_visit(node)
|
641
|
+
|
642
|
+
|
538
643
|
class _FunctionDefFinder(ast.NodeVisitor):
|
539
644
|
def __init__(self, name):
|
540
645
|
self._name = name
|
ninetoothed/naming.py
ADDED
@@ -0,0 +1,50 @@
|
|
1
|
+
import re
|
2
|
+
|
3
|
+
|
4
|
+
def make_constexpr(name):
|
5
|
+
return _add_prefix(name, _CONSTEXPR)
|
6
|
+
|
7
|
+
|
8
|
+
def make_meta(name):
|
9
|
+
return _add_prefix(name, _META)
|
10
|
+
|
11
|
+
|
12
|
+
def make_next_power_of_2(name):
|
13
|
+
return _add_prefix(name, _NEXT_POWER_OF_2)
|
14
|
+
|
15
|
+
|
16
|
+
def is_constexpr(name):
|
17
|
+
return _CONSTEXPR in _find_prefixes(name) or is_meta(name)
|
18
|
+
|
19
|
+
|
20
|
+
def is_meta(name):
|
21
|
+
return _META in _find_prefixes(name)
|
22
|
+
|
23
|
+
|
24
|
+
def is_next_power_of_2(name):
|
25
|
+
return _NEXT_POWER_OF_2 in _find_prefixes(name)
|
26
|
+
|
27
|
+
|
28
|
+
def remove_prefixes(name):
|
29
|
+
return _PREFIX_PATTERN.sub("", name)
|
30
|
+
|
31
|
+
|
32
|
+
_CONSTEXPR = "constexpr"
|
33
|
+
|
34
|
+
_META = "meta"
|
35
|
+
|
36
|
+
_NEXT_POWER_OF_2 = "next_power_of_2"
|
37
|
+
|
38
|
+
_PREFIX_PATTERN = re.compile(r"ninetoothed_((?!_).*?)_prefix_")
|
39
|
+
|
40
|
+
|
41
|
+
def _add_prefix(name, string):
|
42
|
+
return f"{_make_prefix(string)}{name}"
|
43
|
+
|
44
|
+
|
45
|
+
def _make_prefix(string):
|
46
|
+
return f"ninetoothed_{string}_prefix_"
|
47
|
+
|
48
|
+
|
49
|
+
def _find_prefixes(name):
|
50
|
+
return set(_PREFIX_PATTERN.findall(name))
|
ninetoothed/symbol.py
CHANGED
@@ -1,7 +1,10 @@
|
|
1
1
|
import ast
|
2
2
|
import inspect
|
3
|
+
import numbers
|
3
4
|
import types
|
4
5
|
|
6
|
+
import ninetoothed.naming as naming
|
7
|
+
|
5
8
|
|
6
9
|
class Symbol:
|
7
10
|
def __init__(self, expr, constexpr=None, meta=None):
|
@@ -28,18 +31,31 @@ class Symbol:
|
|
28
31
|
if constexpr is False:
|
29
32
|
raise ValueError("Non-constexpr meta symbol is not supported.")
|
30
33
|
|
31
|
-
self._node.id =
|
34
|
+
self._node.id = naming.make_meta(self._node.id)
|
32
35
|
|
33
36
|
if constexpr:
|
34
|
-
self._node.id =
|
37
|
+
self._node.id = naming.make_constexpr(self._node.id)
|
38
|
+
|
39
|
+
def __eq__(self, other):
|
40
|
+
if isinstance(self._node, ast.Constant):
|
41
|
+
if isinstance(other, Symbol) and isinstance(other._node, ast.Constant):
|
42
|
+
return self._node.value == other._node.value
|
43
|
+
|
44
|
+
if isinstance(other, numbers.Number):
|
45
|
+
return self._node.value == other
|
46
|
+
|
47
|
+
return False
|
48
|
+
|
49
|
+
def __hash__(self):
|
50
|
+
return id(self)
|
35
51
|
|
36
52
|
def __add__(self, other):
|
37
53
|
other = type(self)(other)
|
38
54
|
|
39
|
-
if
|
55
|
+
if self == 0:
|
40
56
|
return other
|
41
57
|
|
42
|
-
if
|
58
|
+
if other == 0:
|
43
59
|
return self
|
44
60
|
|
45
61
|
return type(self)(ast.BinOp(left=self._node, op=ast.Add(), right=other._node))
|
@@ -47,19 +63,30 @@ class Symbol:
|
|
47
63
|
def __radd__(self, other):
|
48
64
|
return self.__add__(other)
|
49
65
|
|
50
|
-
def
|
66
|
+
def __sub__(self, other):
|
51
67
|
other = type(self)(other)
|
52
68
|
|
53
|
-
if
|
54
|
-
return
|
69
|
+
if self == 0:
|
70
|
+
return -other
|
55
71
|
|
56
|
-
if
|
72
|
+
if other == 0:
|
73
|
+
return self
|
74
|
+
|
75
|
+
return type(self)(ast.BinOp(left=self._node, op=ast.Sub(), right=other._node))
|
76
|
+
|
77
|
+
def __rsub__(self, other):
|
78
|
+
return type(self)(other).__sub__(self)
|
79
|
+
|
80
|
+
def __mul__(self, other):
|
81
|
+
other = type(self)(other)
|
82
|
+
|
83
|
+
if self == 0 or other == 0:
|
57
84
|
return type(self)(0)
|
58
85
|
|
59
|
-
if
|
86
|
+
if self == 1:
|
60
87
|
return other
|
61
88
|
|
62
|
-
if
|
89
|
+
if other == 1:
|
63
90
|
return self
|
64
91
|
|
65
92
|
return type(self)(ast.BinOp(left=self._node, op=ast.Mult(), right=other._node))
|
@@ -136,40 +163,8 @@ class Symbol:
|
|
136
163
|
return SliceSimplifier().visit(self._node)
|
137
164
|
|
138
165
|
@staticmethod
|
139
|
-
def
|
140
|
-
return
|
141
|
-
|
142
|
-
@staticmethod
|
143
|
-
def is_meta(name):
|
144
|
-
return name.startswith(Symbol._meta_prefix())
|
145
|
-
|
146
|
-
@staticmethod
|
147
|
-
def remove_prefix(name):
|
148
|
-
if name.startswith(Symbol._constexpr_prefix()):
|
149
|
-
return name.removeprefix(Symbol._constexpr_prefix())
|
150
|
-
|
151
|
-
if name.startswith(Symbol._meta_prefix()):
|
152
|
-
return name.removeprefix(Symbol._meta_prefix())
|
153
|
-
|
154
|
-
@staticmethod
|
155
|
-
def _create_constexpr(name):
|
156
|
-
return f"{Symbol._constexpr_prefix()}{name}"
|
157
|
-
|
158
|
-
@staticmethod
|
159
|
-
def _create_meta(name):
|
160
|
-
return f"{Symbol._meta_prefix()}{name}"
|
161
|
-
|
162
|
-
@staticmethod
|
163
|
-
def _constexpr_prefix():
|
164
|
-
return f"{Symbol._ninetoothed_prefix()}constexpr_"
|
165
|
-
|
166
|
-
@staticmethod
|
167
|
-
def _meta_prefix():
|
168
|
-
return f"{Symbol._ninetoothed_prefix()}meta_"
|
169
|
-
|
170
|
-
@staticmethod
|
171
|
-
def _ninetoothed_prefix():
|
172
|
-
return "_ninetoothed_"
|
166
|
+
def is_name(object):
|
167
|
+
return isinstance(object, Symbol) and isinstance(object.node, ast.Name)
|
173
168
|
|
174
169
|
|
175
170
|
class _FindAndReplacer(ast.NodeTransformer):
|
ninetoothed/tensor.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import itertools
|
2
2
|
import re
|
3
3
|
|
4
|
+
import ninetoothed.naming as naming
|
4
5
|
from ninetoothed.language import call
|
5
6
|
from ninetoothed.symbol import Symbol
|
6
7
|
|
@@ -15,13 +16,15 @@ class Tensor:
|
|
15
16
|
dtype=None,
|
16
17
|
strides=None,
|
17
18
|
other=None,
|
19
|
+
name=None,
|
18
20
|
original=None,
|
19
21
|
):
|
20
|
-
type(self).num_instances += 1
|
21
|
-
|
22
22
|
self.dtype = dtype
|
23
23
|
|
24
|
-
|
24
|
+
if name is not None:
|
25
|
+
self.name = name
|
26
|
+
else:
|
27
|
+
self.name = f"_ninetoothed_tensor_{type(self).num_instances}"
|
25
28
|
|
26
29
|
if ndim is not None:
|
27
30
|
self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
|
@@ -41,29 +44,41 @@ class Tensor:
|
|
41
44
|
else:
|
42
45
|
self.original = self
|
43
46
|
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
+
type(self).num_instances += 1
|
48
|
+
|
49
|
+
def tile(self, tile_shape, strides=None, dilation=None):
|
50
|
+
if strides is None:
|
51
|
+
strides = [-1 for _ in tile_shape]
|
52
|
+
|
53
|
+
if dilation is None:
|
54
|
+
dilation = [1 for _ in tile_shape]
|
47
55
|
|
48
56
|
outer_shape = []
|
49
57
|
outer_strides = []
|
50
58
|
inner_shape = []
|
51
59
|
inner_strides = []
|
52
60
|
|
53
|
-
for
|
54
|
-
self.shape, self.strides, tile_shape,
|
61
|
+
for self_size, self_stride, tile_size, stride, spacing in zip(
|
62
|
+
self.shape, self.strides, tile_shape, strides, dilation
|
55
63
|
):
|
56
64
|
if tile_size == -1:
|
57
|
-
tile_size =
|
65
|
+
tile_size = self_size
|
66
|
+
|
67
|
+
if stride == -1:
|
68
|
+
stride = tile_size
|
58
69
|
|
59
|
-
new_size =
|
70
|
+
new_size = (
|
71
|
+
call("cdiv", self_size - spacing * (tile_size - 1) - 1, stride) + 1
|
72
|
+
if stride != 0
|
73
|
+
else -1
|
74
|
+
)
|
60
75
|
outer_shape.append(new_size)
|
61
76
|
|
62
|
-
new_stride =
|
77
|
+
new_stride = self_stride * stride // spacing
|
63
78
|
outer_strides.append(new_stride)
|
64
79
|
|
65
80
|
inner_shape.append(tile_size)
|
66
|
-
next_stride =
|
81
|
+
next_stride = self_stride * spacing
|
67
82
|
inner_strides.append(next_stride)
|
68
83
|
|
69
84
|
return type(self)(
|
@@ -115,6 +130,7 @@ class Tensor:
|
|
115
130
|
for name in value.names()
|
116
131
|
}
|
117
132
|
| (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
|
133
|
+
| (self.original.names() if self.original is not self else set())
|
118
134
|
)
|
119
135
|
|
120
136
|
def offsets(self, indices=None):
|
@@ -163,7 +179,14 @@ class Tensor:
|
|
163
179
|
|
164
180
|
if isinstance(curr, type(self)):
|
165
181
|
for dim in range(curr.ndim):
|
166
|
-
|
182
|
+
size = curr.shape[dim]
|
183
|
+
|
184
|
+
if Symbol.is_name(size):
|
185
|
+
name = size.node.id
|
186
|
+
if not naming.is_meta(name):
|
187
|
+
size = naming.make_next_power_of_2(name)
|
188
|
+
|
189
|
+
indices.append(call("arange", 0, size))
|
167
190
|
|
168
191
|
return tuple(indices)
|
169
192
|
|
@@ -240,7 +263,7 @@ class Tensor:
|
|
240
263
|
def _calculate_default_strides(shape):
|
241
264
|
strides = [1]
|
242
265
|
|
243
|
-
for size in shape[1:]:
|
266
|
+
for size in reversed(shape[1:]):
|
244
267
|
strides.append(size * strides[-1])
|
245
268
|
|
246
269
|
return reversed(strides)
|
@@ -1,11 +1,10 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.7.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
7
7
|
Author-email: Jiacheng Huang <huangjiacheng0709@outlook.com>
|
8
|
-
License-File: LICENSE
|
9
8
|
Classifier: License :: OSI Approved :: Apache Software License
|
10
9
|
Classifier: Operating System :: OS Independent
|
11
10
|
Classifier: Programming Language :: Python :: 3
|
@@ -51,6 +50,8 @@ def add_kernel(
|
|
51
50
|
|
52
51
|
In this code, we first define `BLOCK_SIZE`, which is a `Symbol`. You can think of `"BLOCK_SIZE"` as its name. We see that `meta` is set to `True`, indicating to the compiler that it is a meta-parameter and its value can be determined by the compiler. `Tensor(1)` constructs a one-dimensional tensor (vector), and `Tensor(1).tile((BLOCK_SIZE,))` means we want to create a vector and divide it into blocks of size `BLOCK_SIZE`. For example, if the size of this vector is `8192` and `BLOCK_SIZE` is `1024`, then the vector will be divided into `8` blocks, each of size `1024`.
|
53
52
|
|
53
|
+

|
54
|
+
|
54
55
|
By using type annotations, we tell the compiler that we will have three tensor parameters, which will be divided into blocks, and `x`, `y`, and `z` are these blocks. It's important to understand that `x`, `y`, and `z` are the blocks, not the tensors themselves. In the function body, `x`, `y`, and `z` are also the blocks. The rest is straightforward (only one line `z = x + y` left, haha): we add each block of `x` and `y` and store it in `z`. Since each block of the parameter tensors undergoes this operation, the addition is completed for the entire tensors as well.
|
55
56
|
|
56
57
|
### Matrix Multiplication
|
@@ -82,4 +83,10 @@ def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):
|
|
82
83
|
|
83
84
|
For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$. In fact, our meta-operations up to this point have already enabled us to write kernel functions. However, we notice that the levels where the row blocks and column blocks reside, which we mentioned earlier, are two-dimensional, and their sizes are of the forms `(1, ...)` and `(..., 1)`. This means that if no other operations are performed, the way we access row blocks and column blocks would have to be `a[0, k]` and `b[k, 0]`. If we want to use `a` to find the range of `k`, we would need to use `a.shape[1]`, but we know that dimensions of size `1` can actually be removed completely. This is why we added two lines of `squeeze`. 
Here, `dtype` refers to the data type, which in PyTorch can generally be some integer or floating-point type, such as `torch.float32`. However, since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, there is a concept of "tensors that store tensors" in NineToothed. In summary, these two lines perform operations on the tensors stored in the outermost tensor, removing the dimensions of size `1`. This way, when we access the row and column blocks, we can use `a[k]` and `b[k]`, and when finding the range of `k`, we can use `a.shape[0]`.
|
84
85
|
|
86
|
+

|
87
|
+
|
85
88
|
With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of $B$, multiplying them and accumulating the results in `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since each block of the parameter tensors undergoes this operation, the multiplication is completed for the whole tensors as well.
|
89
|
+
|
90
|
+
## License
|
91
|
+
|
92
|
+
This project is distributed under the Apache-2.0 license. See the included [LICENSE](LICENSE) file for details.
|
@@ -0,0 +1,11 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
|
2
|
+
ninetoothed/jit.py,sha256=y0y3gfcBNeRLhozIzLuDHLLzBAGrL8tLk8Rcvhw_uec,20068
|
3
|
+
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
+
ninetoothed/naming.py,sha256=3FBnC-S3dAZRcBcob9SrcVpVEYE5IXRacwkCiA3vIGU,891
|
5
|
+
ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
|
6
|
+
ninetoothed/tensor.py,sha256=KL6Iw2nwaRnZsNTxOeghDeXUQUgphnVh_1fsmAObOtI,7391
|
7
|
+
ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
|
8
|
+
ninetoothed-0.7.0.dist-info/METADATA,sha256=gRLAzPSdxWYff8SqPuwJ3ji3Z_kPq_vtEupnSe7dIQ0,7032
|
9
|
+
ninetoothed-0.7.0.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
|
10
|
+
ninetoothed-0.7.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
11
|
+
ninetoothed-0.7.0.dist-info/RECORD,,
|
@@ -1,10 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
|
2
|
-
ninetoothed/jit.py,sha256=5gNp4HixCkural_Ns3DxwT4LL3OUcG0ECj4NLjb-EYk,16959
|
3
|
-
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
-
ninetoothed/symbol.py,sha256=8Wg-JQPkVv9mMIxB1Rj4SHzOytHXPgHLkuK0BEFPDkc,5243
|
5
|
-
ninetoothed/tensor.py,sha256=L-9LhwnM4uRtRvj3tqrzerUijEfKeTQvFBcmS1hQilI,6656
|
6
|
-
ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
|
7
|
-
ninetoothed-0.6.0.dist-info/METADATA,sha256=zvY4nvKt7R8kWDYrGnApem_C07trLgOj1-7zXPfqD9U,6785
|
8
|
-
ninetoothed-0.6.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
9
|
-
ninetoothed-0.6.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
10
|
-
ninetoothed-0.6.0.dist-info/RECORD,,
|
File without changes
|