ninetoothed-0.6.0-py3-none-any.whl → ninetoothed-0.8.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/__init__.py +2 -2
- ninetoothed/jit.py +304 -90
- ninetoothed/naming.py +50 -0
- ninetoothed/symbol.py +39 -44
- ninetoothed/tensor.py +162 -78
- {ninetoothed-0.6.0.dist-info → ninetoothed-0.8.0.dist-info}/METADATA +10 -2
- ninetoothed-0.8.0.dist-info/RECORD +11 -0
- {ninetoothed-0.6.0.dist-info → ninetoothed-0.8.0.dist-info}/WHEEL +1 -1
- ninetoothed-0.6.0.dist-info/RECORD +0 -10
- {ninetoothed-0.6.0.dist-info → ninetoothed-0.8.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/__init__.py
CHANGED
ninetoothed/jit.py
CHANGED
@@ -1,31 +1,51 @@
 import ast
 import collections
+import copy
 import functools
 import importlib.util
 import inspect
 import itertools
 import math
+import subprocess
 import sys
 import tempfile

 import triton

+import ninetoothed.naming as naming
 from ninetoothed.language import attribute, call
 from ninetoothed.symbol import Symbol
 from ninetoothed.tensor import Tensor
 from ninetoothed.torchifier import Torchifier


-def
-
+def make(arrangement, application, tensors):
+    params = inspect.signature(application).parameters
+    types = arrangement(*tensors)
+    annotations = {param: type for param, type in zip(params, types)}
+    application.__annotations__ = annotations
+
+    return jit(application)
+
+
+def jit(_func=None, *, _prettify=False):
+    def wrapper(func):
+        return JIT(func, _prettify=_prettify)()
+
+    if _func is None:
+        return wrapper
+
+    return wrapper(_func)


 class JIT:
     handles = collections.defaultdict(dict)

-    def __init__(self, func):
+    def __init__(self, func, _prettify=False):
         self.func = func

+        self._prettify = _prettify
+
     def __call__(self):
         source_file = inspect.getsourcefile(self.func)
         source_line = inspect.getsourcelines(self.func)[1]
@@ -40,12 +60,26 @@ class JIT:

         CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
         Tritonizer().visit(tree)
+        _BinOpSimplifier().visit(tree)
         ast.fix_missing_locations(tree)

+        if self._prettify:
+            name_collector = _SimplifiedNameCollector()
+            name_collector.visit(tree)
+
         unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
         dependencies = self._find_dependencies()
         source = "\n\n".join((unparsed, dependencies)).strip()

+        if self._prettify:
+            for original, simplified in name_collector.simplified_names.items():
+                if simplified not in name_collector.simplified_names:
+                    source = source.replace(original, simplified)
+
+            source = subprocess.check_output(
+                ["ruff", "format", "-"], input=source, encoding="utf-8"
+            )
+
         with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
             temp_file.write(source.encode("utf-8"))
             temp_file_name = temp_file.name
@@ -67,10 +101,12 @@ class JIT:
         module = ast.parse(inspect.getsource(inspect.getmodule(self.func)))

         _AliasRestorer().visit(module)
+        collector = _ImportCollector()
+        collector.visit(module)
         finder = _FunctionDefFinder(self.func.__name__)
         finder.visit(module)

-        return ast.Module(body=[finder.result], type_ignores=[])
+        return ast.Module(body=collector.imports + [finder.result], type_ignores=[])

     def _find_dependencies(self):
         dependencies = set()
@@ -115,30 +151,12 @@ class CodeGenerator(ast.NodeTransformer):
     def visit_FunctionDef(self, node):
         self._func_def = node

-        self.
+        self._invariants = {}

-
-        if not isinstance(arg, Tensor) or arg.ndim == 0:
-            continue
-
-        offsets = arg.offsets()
-
-        initializations = {
-            type(self)._name_for_offsets(arg, dim): offs
-            for dim, offs in enumerate(offsets)
-        } | {
-            type(self)._name_for_pointers(arg): arg.original.pointer_string()
-            + sum(
-                type(self)._name_for_offsets(arg, dim)[
-                    type(self)._generate_slices(arg, dim)
-                ]
-                * stride
-                for dim, stride in enumerate(arg.original.strides)
-            )
-        }
+        self.generic_visit(node)

-
-
+        for target, value in reversed(self._invariants.items()):
+            node.body.insert(0, ast.Assign(targets=[target.node], value=value.node))

         return node

@@ -147,12 +165,17 @@ class CodeGenerator(ast.NodeTransformer):

         names_of_args = [arg.names() - {"ninetoothed"} for arg in self._args]
         names = functools.reduce(lambda x, y: x | y, names_of_args)
-        meta_names = {name for name in names if
+        meta_names = {name for name in names if naming.is_meta(name)}
         non_meta_names = {name for name in names if name not in meta_names}
+        non_meta_names |= {
+            naming.make_next_power_of_2(name)
+            for name in non_meta_names
+            if naming.is_constexpr(name)
+        }

         node.args = [
             ast.arg(arg=name)
-            if not
+            if not naming.is_constexpr(name)
             else ast.arg(arg=name, annotation=attribute("constexpr").node)
             for name in non_meta_names
         ] + [
@@ -161,24 +184,20 @@ class CodeGenerator(ast.NodeTransformer):
         ]

         autotune = self._generate_autotune(non_meta_names, meta_names)
-        self._func_def.decorator_list
+        self._func_def.decorator_list = [autotune, Symbol("triton.jit").node]

         self._launch = self._generate_launch(non_meta_names, meta_names)

         return node

     def visit_Subscript(self, node):
-        if (
-            isinstance(node.value, ast.Name)
-            and node.value.id in self._context
-            and isinstance(node.ctx, ast.Load)
-        ):
+        if self._in_context(node.value) and isinstance(node.ctx, ast.Load):
             value = self._context[node.value.id]

             if isinstance(value, Tensor):
-                return
+                return self._generate_load(
                     value,
-
+                    indices=node.slice.elts
                     if isinstance(node.slice, ast.Tuple)
                     else (node.slice,),
                 )
@@ -188,7 +207,7 @@ class CodeGenerator(ast.NodeTransformer):
         return node

     def visit_Attribute(self, node):
-        if
+        if self._in_context(node.value):
             value = self._context[node.value.id]

             if isinstance(value, Tensor):
@@ -203,8 +222,8 @@ class CodeGenerator(ast.NodeTransformer):
     def visit_Name(self, node):
         self.generic_visit(node)

-        if
-            return
+        if self._in_context(node) and isinstance(node.ctx, ast.Load):
+            return self._generate_load(self._context[node.id])

         return node

@@ -212,16 +231,15 @@ class CodeGenerator(ast.NodeTransformer):
         if len(node.targets) == 1:
             target = node.targets[0]

-            if
+            if self._in_context(target):
                 self.generic_visit(node)

                 return ast.Expr(
-
+                    self._generate_store(self._context[target.id], node.value)
                 )
             elif (
                 isinstance(target, ast.Subscript)
-                and
-                and target.value.id in self._context
+                and self._in_context(target.value)
                 and isinstance(target.ctx, ast.Store)
             ):
                 value = self._context[target.value.id]
@@ -230,10 +248,10 @@ class CodeGenerator(ast.NodeTransformer):
                 self.generic_visit(node)

                 return ast.Expr(
-
+                    self._generate_store(
                         value,
                         node.value,
-
+                        indices=target.slice.elts
                         if isinstance(target.slice, ast.Tuple)
                         else (target.slice,),
                     )
@@ -243,6 +261,11 @@ class CodeGenerator(ast.NodeTransformer):

         return node

+    _NAME_FOR_PID = Symbol("ninetoothed_pid")
+
+    def _in_context(self, node):
+        return isinstance(node, ast.Name) and node.id in self._context
+
     def _generate_autotune(self, params, meta):
         device = triton.runtime.driver.active.get_current_device()
         properties = triton.runtime.driver.active.utils.get_device_properties(device)
@@ -303,27 +326,54 @@ class CodeGenerator(ast.NodeTransformer):
         )

     def _generate_launch(self, params, meta):
-
-
-
+        non_next_power_of_2_constexpr_params = [
+            param
+            for param in params
+            if naming.is_constexpr(param) and not naming.is_next_power_of_2(param)
+        ]
+        non_next_power_of_2_constexpr_params_without_prefixes = [
+            naming.remove_prefixes(param)
+            for param in non_next_power_of_2_constexpr_params
+        ]
+        next_power_of_2_params = [
+            param for param in params if naming.is_next_power_of_2(param)
+        ]
+        next_power_of_2_params_without_prefixes = [
+            naming.remove_prefixes(param) for param in next_power_of_2_params
         ]

         launch = ast.FunctionDef(
             name=f"launch_{self._func_def.name}",
             args=ast.arguments(
                 posonlyargs=[],
-                args=[ast.arg(arg=arg.
-                + [
+                args=[ast.arg(arg=arg.source.name) for arg in self._args]
+                + [
+                    ast.arg(arg=param)
+                    for param in non_next_power_of_2_constexpr_params_without_prefixes
+                ],
                 kwonlyargs=[],
                 defaults=[],
             ),
             body=[
                 ast.Assign(
                     targets=[ast.Name(id=param, ctx=ast.Store())],
-                    value=ast.Name(id=
+                    value=ast.Name(id=param_without_prefixes, ctx=ast.Load()),
+                )
+                for param, param_without_prefixes in zip(
+                    non_next_power_of_2_constexpr_params,
+                    non_next_power_of_2_constexpr_params_without_prefixes,
+                )
+            ]
+            + [
+                ast.Assign(
+                    targets=[ast.Name(id=param, ctx=ast.Store())],
+                    value=Symbol(
+                        f"triton.next_power_of_2({param_without_prefixes})"
+                    ).node,
                 )
-                for param,
-
+                for param, param_without_prefixes in zip(
+                    next_power_of_2_params,
+                    next_power_of_2_params_without_prefixes,
                 )
             ]
             + [
@@ -369,51 +419,105 @@ class CodeGenerator(ast.NodeTransformer):

         return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body

-
-    def _generate_load(tensor, intermediate_indices=()):
+    def _generate_load(self, tensor, indices=()):
         if tensor.ndim == 0:
-            return Symbol(tensor.
+            return Symbol(tensor.source.name).node

-        pointers, mask =
-
-        )
-        other = CodeGenerator._generate_other(tensor)
+        pointers, mask = self._generate_pointers_and_mask(tensor, indices)
+        other = type(self)._generate_other(tensor)

         return call("load", pointers, mask=mask, other=other).node

-
-
-        pointers, mask = CodeGenerator._generate_pointers_and_mask(
-            tensor, intermediate_indices
-        )
+    def _generate_store(self, tensor, value, indices=()):
+        pointers, mask = self._generate_pointers_and_mask(tensor, indices)

         return call("store", pointers, value, mask=mask).node

-
-
-
-
-        )
-
-
-        for
-
-
-
+    def _generate_pointers_and_mask(self, tensor, indices):
+        invariant_target_dims = type(self)._find_invariant_target_dims(tensor)
+
+        indices = self._complete_indices(tensor, indices)
+        offsets = type(self)._generate_offsets(tensor, indices)
+
+        for source_dim in range(tensor.source.ndim):
+            for target_dim in range(tensor.target.ndim):
+                if target_dim not in invariant_target_dims:
+                    continue
+
+                name = type(self)._name_for_offsets(tensor, source_dim, target_dim)
+                self._invariants[name] = offsets[source_dim][target_dim]
+                offsets[source_dim][target_dim] = name
+
+        name_for_pointers = type(self)._name_for_pointers(tensor)
+        self._invariants[name_for_pointers] = Symbol(tensor.source.pointer_string())
+
+        for source_dim in range(tensor.source.ndim):
+            for target_dim in range(tensor.target.ndim):
+                if target_dim not in invariant_target_dims:
+                    continue
+
+                self._invariants[name_for_pointers] += (
+                    offsets[source_dim][target_dim][
+                        type(self)._generate_slices(tensor, target_dim)
+                    ]
+                    * tensor.source.strides[source_dim]
+                )
+
+        pointers = name_for_pointers + sum(
+            offsets[source_dim][target_dim][
+                type(self)._generate_slices(tensor, target_dim)
+            ]
+            * tensor.source.strides[source_dim]
+            for source_dim in range(tensor.source.ndim)
+            for target_dim in range(tensor.target.ndim)
+            if target_dim not in invariant_target_dims
+            and offsets[source_dim][target_dim] != 0
        )
        mask = functools.reduce(
            lambda x, y: x & y,
            (
-
-
+                offsets[source_dim][target_dim][
+                    type(self)._generate_slices(tensor, target_dim)
+                ]
+                < tensor.source.shape[source_dim]
+                for source_dim in range(tensor.source.ndim)
+                for target_dim in range(tensor.target.ndim)
+                if offsets[source_dim][target_dim] != 0
            ),
        )

        return pointers, mask

+    def _complete_indices(self, tensor, indices):
+        indices = list(self._generate_pid_indices(tensor) + indices)
+
+        for size in tensor.inmost().shape:
+            if Symbol.is_name(size):
+                name = size.node.id
+                if not naming.is_meta(name):
+                    size = naming.make_next_power_of_2(name)
+
+            indices.append(call("arange", 0, size))
+
+        return tuple(indices)
+
+    def _generate_pid_indices(self, tensor):
+        self._invariants[type(self)._NAME_FOR_PID] = call("program_id", 0)
+
+        indices = list(
+            type(self)._unravel_index(type(self)._NAME_FOR_PID, tensor.shape)
+        )
+
+        for dim, index in enumerate(indices):
+            name = type(self)._name_for_index(tensor, dim)
+            self._invariants[name] = index
+            indices[dim] = name
+
+        return tuple(indices)
+
     @staticmethod
     def _generate_other(tensor):
-        other = tensor.
+        other = tensor.source.other

         if isinstance(other, float) and not math.isfinite(other):
             return f"float('{other}')"
@@ -425,23 +529,86 @@ class CodeGenerator(ast.NodeTransformer):
         return tuple(slice(None) if i == dim else None for i in range(tensor.ndim))

     @staticmethod
-    def
-
-        for offs in tensor.offsets(
-            [0 for _ in range(tensor.ndim)]
-            + list(intermediate_indices)
-            + [0 for _ in range(tensor.inmost().ndim)]
-        )
+    def _generate_offsets(tensor, indices):
+        offsets = collections.defaultdict(
+            lambda: collections.defaultdict(lambda: Symbol(0))
         )

+        curr = tensor
+        start = 0
+
+        while isinstance(curr, type(tensor)):
+            stop = start + curr.ndim
+            curr_indices = indices[start:stop]
+
+            for index, stride, source_dim, target_dim in zip(
+                curr_indices, curr.strides, curr.source_dims, curr.target_dims
+            ):
+                offsets[source_dim][target_dim] += index * stride
+
+            start = stop
+            curr = curr.dtype
+
+        for source_dim in tuple(offsets):
+            for target_dim in tuple(offsets[source_dim]):
+                if not isinstance(source_dim, tuple):
+                    continue
+
+                unraveled = CodeGenerator._unravel_index(
+                    offsets[source_dim][target_dim],
+                    tuple(tensor.source.shape[dim] for dim in source_dim),
+                )
+
+                for offs, dim in zip(unraveled, source_dim):
+                    offsets[dim][target_dim] = offs
+
+        for source_dim in range(tensor.source.ndim):
+            for target_dim in range(tensor.target.ndim):
+                offsets[source_dim][target_dim] = copy.deepcopy(
+                    offsets[source_dim][target_dim]
+                )
+                offsets[source_dim][target_dim].find_and_replace(
+                    Symbol(tensor.source.strides[source_dim]), Symbol(1)
+                )
+
+        return offsets
+
+    @staticmethod
+    def _find_invariant_target_dims(tensor):
+        invariant_target_dims = set()
+
+        curr = tensor.dtype
+
+        while isinstance(curr.dtype, Tensor):
+            for target_dim in range(curr.target.ndim):
+                if target_dim not in curr.target_dims:
+                    invariant_target_dims.add(target_dim)
+
+            curr = curr.dtype
+
+        return invariant_target_dims
+
     @staticmethod
     def _name_for_pointers(tensor):
-        return Symbol(f"{tensor.
+        return Symbol(f"{tensor.source.name}_pointers")

     @staticmethod
-    def _name_for_offsets(tensor,
-        return Symbol(f"{tensor.
+    def _name_for_offsets(tensor, source_dim, target_dim):
+        return Symbol(f"{tensor.source.name}_offsets_{source_dim}_{target_dim}")
+
+    @staticmethod
+    def _name_for_index(tensor, dim):
+        return Symbol(f"{tensor.source.name}_index_{dim}")
+
+    @staticmethod
+    def _unravel_index(index, shape):
+        indices = []
+
+        for stride in Tensor(shape=shape).strides:
+            indices.append(index // stride)
+            index %= stride
+
+        return tuple(indices)


 class Tritonizer(ast.NodeTransformer):
@@ -477,6 +644,36 @@ class Tritonizer(ast.NodeTransformer):
         return node


+class _BinOpSimplifier(ast.NodeTransformer):
+    def visit_BinOp(self, node):
+        self.generic_visit(node)
+
+        if isinstance(node.op, ast.Mult):
+            left = Symbol(node.left)
+            right = Symbol(node.right)
+
+            if left == 0 or right == 0:
+                return Symbol(0).node
+
+            if left == 1:
+                return node.right
+
+            if right == 1:
+                return node.left
+
+        return node
+
+
+class _SimplifiedNameCollector(ast.NodeVisitor):
+    def __init__(self):
+        self.simplified_names = {}
+
+    def visit_Name(self, node):
+        self.generic_visit(node)
+
+        self.simplified_names[node.id] = naming.remove_prefixes(node.id)
+
+
 class _Handle:
     def __init__(self, kernel, launch, source):
         self._kernel = kernel
@@ -535,6 +732,23 @@ class _AliasRestorer(ast.NodeTransformer):
         return node


+class _ImportCollector(ast.NodeVisitor):
+    def __init__(self):
+        super().__init__()
+
+        self.imports = []
+
+    def visit_Import(self, node):
+        self.imports.append(node)
+
+        self.generic_visit(node)
+
+    def visit_ImportFrom(self, node):
+        self.imports.append(node)
+
+        self.generic_visit(node)
+
+
 class _FunctionDefFinder(ast.NodeVisitor):
     def __init__(self, name):
         self._name = name
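The rewritten jit.py above introduces two entry points: `jit`, now usable with or without arguments, and `make`, which zips the types produced by an arrangement function onto an application function's annotations before JIT-compiling it. A minimal sketch of how they compose, reusing the vector-addition example from the README excerpt further down; it assumes `make`, `Symbol`, and `Tensor` are re-exported at package level (the `__init__.py` hunk is collapsed above, so that part is an assumption):

```python
import ninetoothed
from ninetoothed import Symbol, Tensor

# A compile-time meta-parameter, as in the README's add_kernel example.
BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)


def arrangement(x, y, z):
    # Tile each one-dimensional tensor into blocks of BLOCK_SIZE elements.
    return tuple(tensor.tile((BLOCK_SIZE,)) for tensor in (x, y, z))


def application(x, y, z):
    z = x + y  # Each program instance adds one block.


# make() copies the arranged types into application.__annotations__ and then
# calls jit(application), exactly as the new code above shows.
add_kernel = ninetoothed.make(
    arrangement, application, (Tensor(1), Tensor(1), Tensor(1))
)
```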
ninetoothed/naming.py
ADDED
@@ -0,0 +1,50 @@
+import re
+
+
+def make_constexpr(name):
+    return _add_prefix(name, _CONSTEXPR)
+
+
+def make_meta(name):
+    return _add_prefix(name, _META)
+
+
+def make_next_power_of_2(name):
+    return _add_prefix(name, _NEXT_POWER_OF_2)
+
+
+def is_constexpr(name):
+    return _CONSTEXPR in _find_prefixes(name) or is_meta(name)
+
+
+def is_meta(name):
+    return _META in _find_prefixes(name)
+
+
+def is_next_power_of_2(name):
+    return _NEXT_POWER_OF_2 in _find_prefixes(name)
+
+
+def remove_prefixes(name):
+    return _PREFIX_PATTERN.sub("", name)
+
+
+_CONSTEXPR = "constexpr"
+
+_META = "meta"
+
+_NEXT_POWER_OF_2 = "next_power_of_2"
+
+_PREFIX_PATTERN = re.compile(r"ninetoothed_((?!_).*?)_prefix_")
+
+
+def _add_prefix(name, string):
+    return f"{_make_prefix(string)}{name}"
+
+
+def _make_prefix(string):
+    return f"ninetoothed_{string}_prefix_"
+
+
+def _find_prefixes(name):
+    return set(_PREFIX_PATTERN.findall(name))
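The new naming module centralizes the name tagging that symbol.py previously implemented with ad-hoc `_ninetoothed_` prefixes, and adds a third tag for next-power-of-2 parameters. A quick illustration of the round trip, derived from the source above (the assertions are assumed behavior, not tests shipped with the package):

```python
import ninetoothed.naming as naming

# Stack two tags on one name; each tag becomes a "ninetoothed_<tag>_prefix_".
name = naming.make_next_power_of_2(naming.make_constexpr("BLOCK_SIZE"))

assert naming.is_constexpr(name)        # the constexpr tag is present
assert naming.is_next_power_of_2(name)  # so is the next-power-of-2 tag
assert not naming.is_meta(name)         # but not the meta tag
assert naming.remove_prefixes(name) == "BLOCK_SIZE"  # tags strip cleanly
```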
ninetoothed/symbol.py
CHANGED
@@ -1,7 +1,10 @@
 import ast
 import inspect
+import numbers
 import types

+import ninetoothed.naming as naming
+

 class Symbol:
     def __init__(self, expr, constexpr=None, meta=None):
@@ -28,18 +31,31 @@ class Symbol:
             if constexpr is False:
                 raise ValueError("Non-constexpr meta symbol is not supported.")

-            self._node.id =
+            self._node.id = naming.make_meta(self._node.id)

         if constexpr:
-            self._node.id =
+            self._node.id = naming.make_constexpr(self._node.id)
+
+    def __eq__(self, other):
+        if isinstance(self._node, ast.Constant):
+            if isinstance(other, Symbol) and isinstance(other._node, ast.Constant):
+                return self._node.value == other._node.value
+
+            if isinstance(other, numbers.Number):
+                return self._node.value == other
+
+        return False
+
+    def __hash__(self):
+        return id(self)

     def __add__(self, other):
         other = type(self)(other)

-        if
+        if self == 0:
             return other

-        if
+        if other == 0:
             return self

         return type(self)(ast.BinOp(left=self._node, op=ast.Add(), right=other._node))
@@ -47,19 +63,30 @@ class Symbol:
     def __radd__(self, other):
         return self.__add__(other)

-    def
+    def __sub__(self, other):
         other = type(self)(other)

-        if
-            return
+        if self == 0:
+            return -other

-        if
+        if other == 0:
+            return self
+
+        return type(self)(ast.BinOp(left=self._node, op=ast.Sub(), right=other._node))
+
+    def __rsub__(self, other):
+        return type(self)(other).__sub__(self)
+
+    def __mul__(self, other):
+        other = type(self)(other)
+
+        if self == 0 or other == 0:
             return type(self)(0)

-        if
+        if self == 1:
             return other

-        if
+        if other == 1:
             return self

         return type(self)(ast.BinOp(left=self._node, op=ast.Mult(), right=other._node))
@@ -136,40 +163,8 @@ class Symbol:
         return SliceSimplifier().visit(self._node)

     @staticmethod
-    def
-        return
-
-    @staticmethod
-    def is_meta(name):
-        return name.startswith(Symbol._meta_prefix())
-
-    @staticmethod
-    def remove_prefix(name):
-        if name.startswith(Symbol._constexpr_prefix()):
-            return name.removeprefix(Symbol._constexpr_prefix())
-
-        if name.startswith(Symbol._meta_prefix()):
-            return name.removeprefix(Symbol._meta_prefix())
-
-    @staticmethod
-    def _create_constexpr(name):
-        return f"{Symbol._constexpr_prefix()}{name}"
-
-    @staticmethod
-    def _create_meta(name):
-        return f"{Symbol._meta_prefix()}{name}"
-
-    @staticmethod
-    def _constexpr_prefix():
-        return f"{Symbol._ninetoothed_prefix()}constexpr_"
-
-    @staticmethod
-    def _meta_prefix():
-        return f"{Symbol._ninetoothed_prefix()}meta_"
-
-    @staticmethod
-    def _ninetoothed_prefix():
-        return "_ninetoothed_"
+    def is_name(object):
+        return isinstance(object, Symbol) and isinstance(object.node, ast.Name)


 class _FindAndReplacer(ast.NodeTransformer):
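With the new `__eq__`/`__hash__` and the rewritten arithmetic, `Symbol` now folds identities with the constants 0 and 1 instead of growing the expression tree, which is what `_BinOpSimplifier` in jit.py relies on. A small sketch of the assumed behavior:

```python
import ast

from ninetoothed.symbol import Symbol

x = Symbol("x")

assert (x + 0) is x  # adding 0 returns the symbol unchanged
assert (x * 1) is x  # multiplying by 1 does too
assert (x * 0) == 0  # multiplying by 0 folds to the constant 0
print(ast.unparse((x * 2).node))  # a real product still builds an AST: "x * 2"
```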
ninetoothed/tensor.py
CHANGED
@@ -1,4 +1,5 @@
 import itertools
+import math
 import re

 from ninetoothed.language import call
@@ -15,13 +16,18 @@ class Tensor:
         dtype=None,
         strides=None,
         other=None,
-
+        name=None,
+        source=None,
+        source_dims=None,
+        target=None,
+        target_dims=None,
     ):
-        type(self).num_instances += 1
-
         self.dtype = dtype

-
+        if name is not None:
+            self.name = name
+        else:
+            self.name = f"_ninetoothed_tensor_{type(self).num_instances}"

         if ndim is not None:
             self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
@@ -36,34 +42,61 @@ class Tensor:

         self.other = other

-        if
-            self.
+        if source is not None:
+            self.source = source
         else:
-            self.
+            self.source = self

-
-
-
+        if source_dims is not None:
+            self.source_dims = source_dims
+        else:
+            self.source_dims = (dim for dim in range(self.source.ndim))
+
+        if target is not None:
+            self.target = target
+        else:
+            self.target = self
+
+        if target_dims is not None:
+            self.target_dims = target_dims
+        else:
+            self.target_dims = (dim for dim in range(self.target.ndim))
+
+        type(self).num_instances += 1
+
+    def tile(self, tile_shape, strides=None, dilation=None):
+        if strides is None:
+            strides = [-1 for _ in tile_shape]
+
+        if dilation is None:
+            dilation = [1 for _ in tile_shape]

         outer_shape = []
         outer_strides = []
         inner_shape = []
         inner_strides = []

-        for
-            self.shape, self.strides, tile_shape,
+        for self_size, self_stride, tile_size, stride, spacing in zip(
+            self.shape, self.strides, tile_shape, strides, dilation
         ):
             if tile_size == -1:
-                tile_size =
+                tile_size = self_size
+
+            if stride == -1:
+                stride = tile_size

-            new_size =
+            new_size = (
+                call("cdiv", self_size - spacing * (tile_size - 1) - 1, stride) + 1
+                if stride != 0
+                else -1
+            )
             outer_shape.append(new_size)

-            new_stride =
+            new_stride = self_stride * stride // spacing
             outer_strides.append(new_stride)

             inner_shape.append(tile_size)
-            next_stride =
+            next_stride = self_stride * spacing
             inner_strides.append(next_stride)

         return type(self)(
@@ -72,10 +105,16 @@ class Tensor:
                 shape=inner_shape,
                 dtype=self.dtype,
                 strides=inner_strides,
-
+                source=self.source,
+                source_dims=self.source_dims,
+                target=self.target,
+                target_dims=self.target_dims,
             ),
             strides=outer_strides,
-
+            source=self.source,
+            source_dims=self.source_dims,
+            target=self.target,
+            target_dims=self.target_dims,
         )

     def expand(self, shape):
@@ -90,7 +129,10 @@ class Tensor:
                 stride if new_size == -1 else 0
                 for new_size, stride in zip(shape, self.strides)
             ],
-
+            source=self.source,
+            source_dims=self.source_dims,
+            target=self.target,
+            target_dims=self.target_dims,
         )

     def squeeze(self, dim):
@@ -99,73 +141,109 @@ class Tensor:
             shape=[size for i, size in enumerate(self.shape) if dim != i],
             dtype=self.dtype,
             strides=[stride for i, stride in enumerate(self.strides) if dim != i],
-
+            source=self.source,
+            source_dims=[
+                source_dim for i, source_dim in enumerate(self.source_dims) if dim != i
+            ],
+            target=self.target,
+            target_dims=[
+                target_dim for i, target_dim in enumerate(self.target_dims) if dim != i
+            ],
         )

-    def
-
-
+    def permute(self, dims):
+        # TODO: Add error handling.
+        new_shape = [None for _ in range(self.ndim)]
+        new_strides = [None for _ in range(self.ndim)]
+        new_source_dims = [None for _ in range(self.ndim)]

-
-
-
-
-
-
-
-
+        for original_dim, permuted_dim in enumerate(dims):
+            new_shape[original_dim] = self.shape[permuted_dim]
+            new_strides[original_dim] = self.strides[permuted_dim]
+            new_source_dims[original_dim] = self.source_dims[permuted_dim]
+
+        return type(self)(
+            shape=new_shape,
+            dtype=self.dtype,
+            strides=new_strides,
+            source=self.source,
+            source_dims=new_source_dims,
+            target=self.target,
+            target_dims=self.target_dims,
         )

-    def
-
-
+    def flatten(self, start_dim=None, end_dim=None):
+        # TODO: Add error handling.
+        if start_dim is None:
+            start_dim = 0
+        if end_dim is None:
+            end_dim = self.ndim

-
+        leading_sizes = self.shape[:start_dim]
+        flattening_sizes = self.shape[start_dim:end_dim]
+        trailing_sizes = self.shape[end_dim:]

-
-        start = 0
+        new_shape = leading_sizes + (math.prod(flattening_sizes),) + trailing_sizes

-
-
-
+        leading_strides = self.strides[:start_dim]
+        flattening_strides = self.strides[start_dim:end_dim]
+        trailing_strides = self.strides[end_dim:]

-
-        for dim in self._dims_of(stride):
-            offsets[dim].append(index * stride)
+        new_strides = leading_strides + (flattening_strides[-1],) + trailing_strides

-
-
+        leading_source_dims = self.source_dims[:start_dim]
+        flattening_source_dims = self.source_dims[start_dim:end_dim]
+        trailing_source_dims = self.source_dims[end_dim:]

-
-
-
-
-        return offsets
-
-    def indices(self, index=None):
-        if index is None:
-            index = call("program_id", 0)
+        new_source_dims = (
+            leading_source_dims + (flattening_source_dims,) + trailing_source_dims
+        )

-
+        return type(self)(
+            shape=new_shape,
+            dtype=self.dtype,
+            strides=new_strides,
+            source=self.source,
+            source_dims=new_source_dims,
+            target=self.target,
+            target_dims=self.target_dims,
+        )

-
-
-
+    def ravel(self):
+        # TODO: Add error handling.
+        new_shape = []
+        new_strides = []

-        curr = self
+        curr = self

-        while isinstance(curr
-
-
+        while isinstance(curr, type(self)):
+            new_shape.extend(curr.shape)
+            new_strides.extend(curr.strides)

             curr = curr.dtype

-
-
-
+        return type(self)(
+            shape=new_shape,
+            strides=new_strides,
+            other=self.source.other,
+            name=self.source.name,
+        )
+
+    def names(self):
+        if self.ndim == 0:
+            return {self.source.name}

-        return
+        return (
+            {self.source.pointer_string()}
+            | {
+                name
+                for value in itertools.chain(self.shape, self.strides)
+                if isinstance(value, Symbol)
+                for name in value.names()
+            }
+            | (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
+            | (self.source.names() if self.source is not self else set())
+        )

     def inmost(self):
         if not isinstance(self.dtype, type(self)):
@@ -214,6 +292,22 @@ class Tensor:
     def ndim(self):
         return len(self.shape)

+    @property
+    def source_dims(self):
+        return self._source_dims
+
+    @source_dims.setter
+    def source_dims(self, value):
+        self._source_dims = tuple(value)
+
+    @property
+    def target_dims(self):
+        return self._target_dims
+
+    @target_dims.setter
+    def target_dims(self, value):
+        self._target_dims = tuple(value)
+
     @staticmethod
     def pointer_pattern():
         return re.compile(rf"({_identifier_pattern_raw_string()})_(pointer)")
@@ -226,21 +320,11 @@ class Tensor:
     def stride_pattern():
         return re.compile(rf"({_identifier_pattern_raw_string()})_(stride)_(.+)")

-    def _dims_of(self, stride):
-        dims = set()
-        names = stride.names() if isinstance(stride, Symbol) else {stride}
-
-        for dim, original_stride in enumerate(self.original.strides):
-            if str(original_stride) in names:
-                dims.add(dim)
-
-        return dims
-
     @staticmethod
     def _calculate_default_strides(shape):
         strides = [1]

-        for size in shape[1:]:
+        for size in reversed(shape[1:]):
             strides.append(size * strides[-1])

         return reversed(strides)
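Two behavioral changes above are easy to sanity-check in plain Python: the generalized `tile` arithmetic, which now accepts `strides` (tile-to-tile step) and `dilation` (element spacing within a tile), and the corrected row-major `_calculate_default_strides`. A worked sketch under those assumed semantics:

```python
def cdiv(a, b):
    return -(-a // b)  # ceiling division, as triton.cdiv computes it


def num_tiles(size, tile_size, stride=None, spacing=1):
    # Mirrors: call("cdiv", self_size - spacing * (tile_size - 1) - 1, stride) + 1
    stride = tile_size if stride is None else stride
    return cdiv(size - spacing * (tile_size - 1) - 1, stride) + 1


assert num_tiles(8192, 1024) == 8        # the README's vector-addition numbers
assert num_tiles(10, 4, stride=2) == 4   # overlapping tiles at offsets 0, 2, 4, 6
assert num_tiles(10, 3, spacing=2) == 3  # a dilated tile spans 2 * (3 - 1) + 1 = 5 elements


def default_strides(shape):
    # Mirrors the fixed _calculate_default_strides: contiguous row-major strides.
    strides = [1]
    for size in reversed(shape[1:]):
        strides.append(size * strides[-1])
    return list(reversed(strides))


assert default_strides((2, 3, 4)) == [12, 4, 1]  # the old forward loop gave [12, 3, 1]
```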
{ninetoothed-0.6.0.dist-info → ninetoothed-0.8.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: ninetoothed
-Version: 0.6.0
+Version: 0.8.0
 Summary: A domain-specific language based on Triton but providing higher-level abstraction.
 Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
 Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -51,6 +51,8 @@ def add_kernel(

 In this code, we first define `BLOCK_SIZE`, which is a `Symbol`. You can think of `"BLOCK_SIZE"` as its name. We see that `meta` is set to `True`, indicating to the compiler that it is a meta-parameter and its value can be determined by the compiler. The `Tensor(1)` constructs a one-dimensional tensor (vector), and `Tensor(1).tile((BLOCK_SIZE,))` means we want to create a vector and divide it into blocks of size `BLOCK_SIZE`. Suppose the size of this vector is `8192` and `BLOCK_SIZE` is `1024`, then the vector will be divided into `8` blocks, each of size `1024`.

+
+
 By using type annotations, we tell the compiler that we will have three tensor parameters, which will be divided into blocks, and `x`, `y`, and `z` are these blocks. It's important to understand that `x`, `y`, and `z` are the blocks, not the tensors themselves. In the function body, `x`, `y`, and `z` are also the blocks. The rest is straightforward (only one line `z = x + y` left, haha), we add each block of `x` and `y` and store it in `z`. Since each block of the parameter tensors undergoes this operation, the addition is completed for the whole tensors as well.

 ### Matrix Multiplication
@@ -82,4 +84,10 @@ def matmul_kernel(a: a_tiled, b: b_tiled, c: c_tiled):

 For matrix multiplication, we also have three tensor parameters, but the tiling method is more complex than vector addition. We denote the three matrices as $A$, $B$, and $C$, where $A$ and $B$ are inputs, and $C$ is the output. Tiling $C$ is simple; we just need to divide it into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_N)` by rows and columns. Once each block computes its result, the entire $C$ is computed. However, how should we tile $A$ and $B$? The answer is to introduce another meta-parameter `BLOCK_SIZE_K`. This way, we can divide $A$ into blocks of size `(BLOCK_SIZE_M, BLOCK_SIZE_K)` and $B$ into blocks of size `(BLOCK_SIZE_K, BLOCK_SIZE_N)`. However, for matrix multiplication, $A$ and $B$ do not correspond block by block; each row of $A$ needs to correspond to each column of $B$. Therefore, we need to further `tile` $A$ and $B$ by rows and columns, respectively. Up to this point, we have a set of row blocks of $A$ and column blocks of $B$. However, each row block of $A$ must correspond to every column block of $B$. This is where `expand` comes in. We `expand` the row blocks of $A$ along the columns to the number of columns of $C$ and the column blocks of $B$ along the rows to the number of rows of $C$. This way, we successfully tile $A$, $B$, and $C$. In fact, our meta-operations up to this point have already enabled us to write kernel functions. However, we notice that the levels where the row blocks and column blocks reside, which we mentioned earlier, are two-dimensional, and their sizes are of the forms `(1, ...)` and `(..., 1)`. This means that if no other operations are performed, the way we access row blocks and column blocks would have to be `a[0, k]` and `b[k, 0]`. If we want to use `a` to find the range of `k`, we would need to use `a.shape[1]`, but we know that dimensions of size `1` can actually be removed completely. This is why we added two lines of `squeeze`. The `dtype` refers to the data type, which in PyTorch can generally be some integer or floating-point type, such as `torch.float32`. However, since meta-operations like `tile` can be performed in NineToothed, `dtype` can also be a `Tensor`. In other words, there is a concept of "tensors that store tensors" in NineToothed. In summary, these two lines perform operations on the tensors stored in the outmost tensor, removing the dimensions of size `1`. This way, when we access the row and column blocks, we can use `a[k]` and `b[k]`, and when finding the range of `k`, we can use `a.shape[0]`.

+
+
 With tiling done, the rest is simple. In the function body, we define an `accumulator` to accumulate intermediate results. We then iterate through the corresponding row blocks of $A$ and column blocks of $B$, multiplying them and accumulating the results in `accumulator`. Finally, we place the `accumulator` in the corresponding block of $C$. Since each block of the parameter tensors undergoes this operation, the multiplication is completed for the whole tensors as well.
+
+## License
+
+This project is distributed under the Apache-2.0 license. See the included [LICENSE](LICENSE) file for details.
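The matmul walkthrough in the README excerpt above maps onto a short tiling pipeline. The following reconstruction is inferred from that prose rather than copied from the diff, so treat the exact calls as illustrative:

```python
from ninetoothed import Symbol, Tensor

BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)

# Tile C by rows and columns; one program instance computes one block of C.
c_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))

# Tile A and B with the shared BLOCK_SIZE_K, then group the blocks into row
# blocks of A and column blocks of B (-1 means "the full extent").
a_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_K)).tile((1, -1))
b_tiled = Tensor(2).tile((BLOCK_SIZE_K, BLOCK_SIZE_N)).tile((-1, 1))

# Expand the row blocks of A along columns and the column blocks of B along
# rows, so every row block meets every column block.
a_tiled = a_tiled.expand((-1, c_tiled.shape[1]))
b_tiled = b_tiled.expand((c_tiled.shape[0], -1))

# Squeeze away the size-1 dimensions so the kernel can index a[k] and b[k]
# and use a.shape[0] as the range of k.
a_tiled.dtype = a_tiled.dtype.squeeze(0)
b_tiled.dtype = b_tiled.dtype.squeeze(1)
```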
ninetoothed-0.8.0.dist-info/RECORD
ADDED
@@ -0,0 +1,11 @@
+ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
+ninetoothed/jit.py,sha256=z70hQEsogfQu0cLxq5m3cOsWsVANcMRJaVv5di9vk1c,23741
+ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
+ninetoothed/naming.py,sha256=3FBnC-S3dAZRcBcob9SrcVpVEYE5IXRacwkCiA3vIGU,891
+ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
+ninetoothed/tensor.py,sha256=_jM0tVgqIwZd3MJJsGVTaLCsSxpPO8JfF4qkMShhQvQ,9429
+ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
+ninetoothed-0.8.0.dist-info/METADATA,sha256=gPWYhTBH5EdeOyGnArZIEw82aFmoQchD6pxtLi6LGMA,7054
+ninetoothed-0.8.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ninetoothed-0.8.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ninetoothed-0.8.0.dist-info/RECORD,,
ninetoothed-0.6.0.dist-info/RECORD
DELETED
@@ -1,10 +0,0 @@
-ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
-ninetoothed/jit.py,sha256=5gNp4HixCkural_Ns3DxwT4LL3OUcG0ECj4NLjb-EYk,16959
-ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
-ninetoothed/symbol.py,sha256=8Wg-JQPkVv9mMIxB1Rj4SHzOytHXPgHLkuK0BEFPDkc,5243
-ninetoothed/tensor.py,sha256=L-9LhwnM4uRtRvj3tqrzerUijEfKeTQvFBcmS1hQilI,6656
-ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
-ninetoothed-0.6.0.dist-info/METADATA,sha256=zvY4nvKt7R8kWDYrGnApem_C07trLgOj1-7zXPfqD9U,6785
-ninetoothed-0.6.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
-ninetoothed-0.6.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-ninetoothed-0.6.0.dist-info/RECORD,,
{ninetoothed-0.6.0.dist-info → ninetoothed-0.8.0.dist-info}/licenses/LICENSE
File without changes