ninetoothed 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/__init__.py +2 -2
- ninetoothed/jit.py +187 -90
- ninetoothed/naming.py +4 -0
- ninetoothed/tensor.py +140 -74
- ninetoothed/torchifier.py +4 -0
- {ninetoothed-0.7.0.dist-info → ninetoothed-0.9.0.dist-info}/METADATA +3 -2
- ninetoothed-0.9.0.dist-info/RECORD +11 -0
- {ninetoothed-0.7.0.dist-info → ninetoothed-0.9.0.dist-info}/WHEEL +1 -1
- ninetoothed-0.7.0.dist-info/RECORD +0 -11
- {ninetoothed-0.7.0.dist-info → ninetoothed-0.9.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/__init__.py
CHANGED
ninetoothed/jit.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import ast
|
2
2
|
import collections
|
3
|
+
import copy
|
3
4
|
import functools
|
4
5
|
import importlib.util
|
5
6
|
import inspect
|
@@ -18,6 +19,15 @@ from ninetoothed.tensor import Tensor
|
|
18
19
|
from ninetoothed.torchifier import Torchifier
|
19
20
|
|
20
21
|
|
22
|
+
def make(arrangement, application, tensors):
|
23
|
+
params = inspect.signature(application).parameters
|
24
|
+
types = arrangement(*tensors)
|
25
|
+
annotations = {param: type for param, type in zip(params, types)}
|
26
|
+
application.__annotations__ = annotations
|
27
|
+
|
28
|
+
return jit(application)
|
29
|
+
|
30
|
+
|
21
31
|
def jit(_func=None, *, _prettify=False):
|
22
32
|
def wrapper(func):
|
23
33
|
return JIT(func, _prettify=_prettify)()
|
@@ -29,23 +39,12 @@ def jit(_func=None, *, _prettify=False):
|
|
29
39
|
|
30
40
|
|
31
41
|
class JIT:
|
32
|
-
handles = collections.defaultdict(dict)
|
33
|
-
|
34
42
|
def __init__(self, func, _prettify=False):
|
35
43
|
self.func = func
|
36
44
|
|
37
45
|
self._prettify = _prettify
|
38
46
|
|
39
47
|
def __call__(self):
|
40
|
-
source_file = inspect.getsourcefile(self.func)
|
41
|
-
source_line = inspect.getsourcelines(self.func)[1]
|
42
|
-
|
43
|
-
if (
|
44
|
-
source_file in type(self).handles
|
45
|
-
and source_line in type(self).handles[source_file]
|
46
|
-
):
|
47
|
-
return type(self).handles[source_file][source_line]
|
48
|
-
|
49
48
|
tree = self._get_tree()
|
50
49
|
|
51
50
|
CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
|
@@ -83,8 +82,6 @@ class JIT:
|
|
83
82
|
source,
|
84
83
|
)
|
85
84
|
|
86
|
-
type(self).handles[source_file][source_line] = handle
|
87
|
-
|
88
85
|
return handle
|
89
86
|
|
90
87
|
def _get_tree(self):
|
@@ -141,30 +138,12 @@ class CodeGenerator(ast.NodeTransformer):
|
|
141
138
|
def visit_FunctionDef(self, node):
|
142
139
|
self._func_def = node
|
143
140
|
|
144
|
-
self.
|
145
|
-
|
146
|
-
for arg in self._args:
|
147
|
-
if not isinstance(arg, Tensor) or arg.ndim == 0:
|
148
|
-
continue
|
149
|
-
|
150
|
-
offsets = arg.offsets()
|
141
|
+
self._invariants = {}
|
151
142
|
|
152
|
-
|
153
|
-
type(self)._name_for_offsets(arg, dim): offs
|
154
|
-
for dim, offs in enumerate(offsets)
|
155
|
-
} | {
|
156
|
-
type(self)._name_for_pointers(arg): arg.original.pointer_string()
|
157
|
-
+ sum(
|
158
|
-
type(self)._name_for_offsets(arg, dim)[
|
159
|
-
type(self)._generate_slices(arg, dim)
|
160
|
-
]
|
161
|
-
* stride
|
162
|
-
for dim, stride in enumerate(arg.original.strides)
|
163
|
-
)
|
164
|
-
}
|
143
|
+
self.generic_visit(node)
|
165
144
|
|
166
|
-
|
167
|
-
|
145
|
+
for target, value in reversed(self._invariants.items()):
|
146
|
+
node.body.insert(0, ast.Assign(targets=[target.node], value=value.node))
|
168
147
|
|
169
148
|
return node
|
170
149
|
|
@@ -192,24 +171,20 @@ class CodeGenerator(ast.NodeTransformer):
|
|
192
171
|
]
|
193
172
|
|
194
173
|
autotune = self._generate_autotune(non_meta_names, meta_names)
|
195
|
-
self._func_def.decorator_list
|
174
|
+
self._func_def.decorator_list = [autotune, Symbol("triton.jit").node]
|
196
175
|
|
197
176
|
self._launch = self._generate_launch(non_meta_names, meta_names)
|
198
177
|
|
199
178
|
return node
|
200
179
|
|
201
180
|
def visit_Subscript(self, node):
|
202
|
-
if (
|
203
|
-
isinstance(node.value, ast.Name)
|
204
|
-
and node.value.id in self._context
|
205
|
-
and isinstance(node.ctx, ast.Load)
|
206
|
-
):
|
181
|
+
if self._in_context(node.value) and isinstance(node.ctx, ast.Load):
|
207
182
|
value = self._context[node.value.id]
|
208
183
|
|
209
184
|
if isinstance(value, Tensor):
|
210
|
-
return
|
185
|
+
return self._generate_load(
|
211
186
|
value,
|
212
|
-
|
187
|
+
indices=node.slice.elts
|
213
188
|
if isinstance(node.slice, ast.Tuple)
|
214
189
|
else (node.slice,),
|
215
190
|
)
|
@@ -219,7 +194,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
219
194
|
return node
|
220
195
|
|
221
196
|
def visit_Attribute(self, node):
|
222
|
-
if
|
197
|
+
if self._in_context(node.value):
|
223
198
|
value = self._context[node.value.id]
|
224
199
|
|
225
200
|
if isinstance(value, Tensor):
|
@@ -234,8 +209,8 @@ class CodeGenerator(ast.NodeTransformer):
|
|
234
209
|
def visit_Name(self, node):
|
235
210
|
self.generic_visit(node)
|
236
211
|
|
237
|
-
if
|
238
|
-
return
|
212
|
+
if self._in_context(node) and isinstance(node.ctx, ast.Load):
|
213
|
+
return self._generate_load(self._context[node.id])
|
239
214
|
|
240
215
|
return node
|
241
216
|
|
@@ -243,16 +218,15 @@ class CodeGenerator(ast.NodeTransformer):
|
|
243
218
|
if len(node.targets) == 1:
|
244
219
|
target = node.targets[0]
|
245
220
|
|
246
|
-
if
|
221
|
+
if self._in_context(target):
|
247
222
|
self.generic_visit(node)
|
248
223
|
|
249
224
|
return ast.Expr(
|
250
|
-
|
225
|
+
self._generate_store(self._context[target.id], node.value)
|
251
226
|
)
|
252
227
|
elif (
|
253
228
|
isinstance(target, ast.Subscript)
|
254
|
-
and
|
255
|
-
and target.value.id in self._context
|
229
|
+
and self._in_context(target.value)
|
256
230
|
and isinstance(target.ctx, ast.Store)
|
257
231
|
):
|
258
232
|
value = self._context[target.value.id]
|
@@ -261,10 +235,10 @@ class CodeGenerator(ast.NodeTransformer):
|
|
261
235
|
self.generic_visit(node)
|
262
236
|
|
263
237
|
return ast.Expr(
|
264
|
-
|
238
|
+
self._generate_store(
|
265
239
|
value,
|
266
240
|
node.value,
|
267
|
-
|
241
|
+
indices=target.slice.elts
|
268
242
|
if isinstance(target.slice, ast.Tuple)
|
269
243
|
else (target.slice,),
|
270
244
|
)
|
@@ -274,6 +248,11 @@ class CodeGenerator(ast.NodeTransformer):
|
|
274
248
|
|
275
249
|
return node
|
276
250
|
|
251
|
+
_NAME_FOR_PID = Symbol("ninetoothed_pid")
|
252
|
+
|
253
|
+
def _in_context(self, node):
|
254
|
+
return isinstance(node, ast.Name) and node.id in self._context
|
255
|
+
|
277
256
|
def _generate_autotune(self, params, meta):
|
278
257
|
device = triton.runtime.driver.active.get_current_device()
|
279
258
|
properties = triton.runtime.driver.active.utils.get_device_properties(device)
|
@@ -354,10 +333,11 @@ class CodeGenerator(ast.NodeTransformer):
|
|
354
333
|
name=f"launch_{self._func_def.name}",
|
355
334
|
args=ast.arguments(
|
356
335
|
posonlyargs=[],
|
357
|
-
args=[ast.arg(arg=arg.
|
336
|
+
args=[ast.arg(arg=arg.source.name) for arg in self._args]
|
358
337
|
+ [
|
359
338
|
ast.arg(arg=param)
|
360
339
|
for param in non_next_power_of_2_constexpr_params_without_prefixes
|
340
|
+
if not Tensor.size_pattern().fullmatch(param)
|
361
341
|
],
|
362
342
|
kwonlyargs=[],
|
363
343
|
defaults=[],
|
@@ -427,51 +407,105 @@ class CodeGenerator(ast.NodeTransformer):
|
|
427
407
|
|
428
408
|
return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
|
429
409
|
|
430
|
-
|
431
|
-
def _generate_load(tensor, intermediate_indices=()):
|
410
|
+
def _generate_load(self, tensor, indices=()):
|
432
411
|
if tensor.ndim == 0:
|
433
|
-
return Symbol(tensor.
|
412
|
+
return Symbol(tensor.source.name).node
|
434
413
|
|
435
|
-
pointers, mask =
|
436
|
-
|
437
|
-
)
|
438
|
-
other = CodeGenerator._generate_other(tensor)
|
414
|
+
pointers, mask = self._generate_pointers_and_mask(tensor, indices)
|
415
|
+
other = type(self)._generate_other(tensor)
|
439
416
|
|
440
417
|
return call("load", pointers, mask=mask, other=other).node
|
441
418
|
|
442
|
-
|
443
|
-
|
444
|
-
pointers, mask = CodeGenerator._generate_pointers_and_mask(
|
445
|
-
tensor, intermediate_indices
|
446
|
-
)
|
419
|
+
def _generate_store(self, tensor, value, indices=()):
|
420
|
+
pointers, mask = self._generate_pointers_and_mask(tensor, indices)
|
447
421
|
|
448
422
|
return call("store", pointers, value, mask=mask).node
|
449
423
|
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
)
|
455
|
-
|
456
|
-
|
457
|
-
for
|
458
|
-
|
459
|
-
|
460
|
-
|
424
|
+
def _generate_pointers_and_mask(self, tensor, indices):
|
425
|
+
invariant_target_dims = type(self)._find_invariant_target_dims(tensor)
|
426
|
+
|
427
|
+
indices = self._complete_indices(tensor, indices)
|
428
|
+
offsets = type(self)._generate_offsets(tensor, indices)
|
429
|
+
|
430
|
+
for source_dim in range(tensor.source.ndim):
|
431
|
+
for target_dim in range(tensor.target.ndim):
|
432
|
+
if target_dim not in invariant_target_dims:
|
433
|
+
continue
|
434
|
+
|
435
|
+
name = type(self)._name_for_offsets(tensor, source_dim, target_dim)
|
436
|
+
self._invariants[name] = offsets[source_dim][target_dim]
|
437
|
+
offsets[source_dim][target_dim] = name
|
438
|
+
|
439
|
+
name_for_pointers = type(self)._name_for_pointers(tensor)
|
440
|
+
self._invariants[name_for_pointers] = Symbol(tensor.source.pointer_string())
|
441
|
+
|
442
|
+
for source_dim in range(tensor.source.ndim):
|
443
|
+
for target_dim in range(tensor.target.ndim):
|
444
|
+
if target_dim not in invariant_target_dims:
|
445
|
+
continue
|
446
|
+
|
447
|
+
self._invariants[name_for_pointers] += (
|
448
|
+
offsets[source_dim][target_dim][
|
449
|
+
type(self)._generate_slices(tensor, target_dim)
|
450
|
+
]
|
451
|
+
* tensor.source.strides[source_dim]
|
452
|
+
)
|
453
|
+
|
454
|
+
pointers = name_for_pointers + sum(
|
455
|
+
offsets[source_dim][target_dim][
|
456
|
+
type(self)._generate_slices(tensor, target_dim)
|
457
|
+
]
|
458
|
+
* tensor.source.strides[source_dim]
|
459
|
+
for source_dim in range(tensor.source.ndim)
|
460
|
+
for target_dim in range(tensor.target.ndim)
|
461
|
+
if target_dim not in invariant_target_dims
|
462
|
+
and offsets[source_dim][target_dim] != 0
|
461
463
|
)
|
462
464
|
mask = functools.reduce(
|
463
465
|
lambda x, y: x & y,
|
464
466
|
(
|
465
|
-
|
466
|
-
|
467
|
+
offsets[source_dim][target_dim][
|
468
|
+
type(self)._generate_slices(tensor, target_dim)
|
469
|
+
]
|
470
|
+
< tensor.source.shape[source_dim]
|
471
|
+
for source_dim in range(tensor.source.ndim)
|
472
|
+
for target_dim in range(tensor.target.ndim)
|
473
|
+
if offsets[source_dim][target_dim] != 0
|
467
474
|
),
|
468
475
|
)
|
469
476
|
|
470
477
|
return pointers, mask
|
471
478
|
|
479
|
+
def _complete_indices(self, tensor, indices):
|
480
|
+
indices = list(self._generate_pid_indices(tensor) + tuple(indices))
|
481
|
+
|
482
|
+
for size in tensor.inmost().shape:
|
483
|
+
if Symbol.is_name(size):
|
484
|
+
name = size.node.id
|
485
|
+
if not naming.is_meta(name):
|
486
|
+
size = naming.make_next_power_of_2(name)
|
487
|
+
|
488
|
+
indices.append(call("arange", 0, size))
|
489
|
+
|
490
|
+
return tuple(indices)
|
491
|
+
|
492
|
+
def _generate_pid_indices(self, tensor):
|
493
|
+
self._invariants[type(self)._NAME_FOR_PID] = call("program_id", 0)
|
494
|
+
|
495
|
+
indices = list(
|
496
|
+
type(self)._unravel_index(type(self)._NAME_FOR_PID, tensor.shape)
|
497
|
+
)
|
498
|
+
|
499
|
+
for dim, index in enumerate(indices):
|
500
|
+
name = type(self)._name_for_index(tensor, dim)
|
501
|
+
self._invariants[name] = index
|
502
|
+
indices[dim] = name
|
503
|
+
|
504
|
+
return tuple(indices)
|
505
|
+
|
472
506
|
@staticmethod
|
473
507
|
def _generate_other(tensor):
|
474
|
-
other = tensor.
|
508
|
+
other = tensor.source.other
|
475
509
|
|
476
510
|
if isinstance(other, float) and not math.isfinite(other):
|
477
511
|
return f"float('{other}')"
|
@@ -483,23 +517,86 @@ class CodeGenerator(ast.NodeTransformer):
|
|
483
517
|
return tuple(slice(None) if i == dim else None for i in range(tensor.ndim))
|
484
518
|
|
485
519
|
@staticmethod
|
486
|
-
def
|
487
|
-
|
488
|
-
|
489
|
-
for offs in tensor.offsets(
|
490
|
-
[0 for _ in range(tensor.ndim)]
|
491
|
-
+ list(intermediate_indices)
|
492
|
-
+ [0 for _ in range(tensor.inmost().ndim)]
|
493
|
-
)
|
520
|
+
def _generate_offsets(tensor, indices):
|
521
|
+
offsets = collections.defaultdict(
|
522
|
+
lambda: collections.defaultdict(lambda: Symbol(0))
|
494
523
|
)
|
495
524
|
|
525
|
+
curr = tensor
|
526
|
+
start = 0
|
527
|
+
|
528
|
+
while isinstance(curr, type(tensor)):
|
529
|
+
stop = start + curr.ndim
|
530
|
+
curr_indices = indices[start:stop]
|
531
|
+
|
532
|
+
for index, stride, source_dim, target_dim in zip(
|
533
|
+
curr_indices, curr.strides, curr.source_dims, curr.target_dims
|
534
|
+
):
|
535
|
+
offsets[source_dim][target_dim] += index * stride
|
536
|
+
|
537
|
+
start = stop
|
538
|
+
curr = curr.dtype
|
539
|
+
|
540
|
+
for source_dim in tuple(offsets):
|
541
|
+
for target_dim in tuple(offsets[source_dim]):
|
542
|
+
if not isinstance(source_dim, tuple):
|
543
|
+
continue
|
544
|
+
|
545
|
+
unraveled = CodeGenerator._unravel_index(
|
546
|
+
offsets[source_dim][target_dim],
|
547
|
+
tuple(tensor.source.shape[dim] for dim in source_dim),
|
548
|
+
)
|
549
|
+
|
550
|
+
for offs, dim in zip(unraveled, source_dim):
|
551
|
+
offsets[dim][target_dim] = offs
|
552
|
+
|
553
|
+
for source_dim in range(tensor.source.ndim):
|
554
|
+
for target_dim in range(tensor.target.ndim):
|
555
|
+
offsets[source_dim][target_dim] = copy.deepcopy(
|
556
|
+
offsets[source_dim][target_dim]
|
557
|
+
)
|
558
|
+
offsets[source_dim][target_dim].find_and_replace(
|
559
|
+
Symbol(tensor.source.strides[source_dim]), Symbol(1)
|
560
|
+
)
|
561
|
+
|
562
|
+
return offsets
|
563
|
+
|
564
|
+
@staticmethod
|
565
|
+
def _find_invariant_target_dims(tensor):
|
566
|
+
invariant_target_dims = set()
|
567
|
+
|
568
|
+
curr = tensor.dtype
|
569
|
+
|
570
|
+
while isinstance(curr.dtype, Tensor):
|
571
|
+
for target_dim in range(curr.target.ndim):
|
572
|
+
if target_dim not in curr.target_dims:
|
573
|
+
invariant_target_dims.add(target_dim)
|
574
|
+
|
575
|
+
curr = curr.dtype
|
576
|
+
|
577
|
+
return invariant_target_dims
|
578
|
+
|
496
579
|
@staticmethod
|
497
580
|
def _name_for_pointers(tensor):
|
498
|
-
return Symbol(f"{tensor.
|
581
|
+
return Symbol(f"{tensor.source.name}_pointers")
|
582
|
+
|
583
|
+
@staticmethod
|
584
|
+
def _name_for_offsets(tensor, source_dim, target_dim):
|
585
|
+
return Symbol(f"{tensor.source.name}_offsets_{source_dim}_{target_dim}")
|
499
586
|
|
500
587
|
@staticmethod
|
501
|
-
def
|
502
|
-
return Symbol(f"{tensor.
|
588
|
+
def _name_for_index(tensor, dim):
|
589
|
+
return Symbol(f"{tensor.source.name}_index_{dim}")
|
590
|
+
|
591
|
+
@staticmethod
|
592
|
+
def _unravel_index(index, shape):
|
593
|
+
indices = []
|
594
|
+
|
595
|
+
for stride in Tensor(shape=shape).strides:
|
596
|
+
indices.append(index // stride)
|
597
|
+
index %= stride
|
598
|
+
|
599
|
+
return tuple(indices)
|
503
600
|
|
504
601
|
|
505
602
|
class Tritonizer(ast.NodeTransformer):
|
ninetoothed/naming.py
CHANGED
ninetoothed/tensor.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
import itertools
|
2
|
+
import math
|
2
3
|
import re
|
3
4
|
|
4
5
|
import ninetoothed.naming as naming
|
@@ -16,18 +17,25 @@ class Tensor:
|
|
16
17
|
dtype=None,
|
17
18
|
strides=None,
|
18
19
|
other=None,
|
20
|
+
constexpr_shape=None,
|
19
21
|
name=None,
|
20
|
-
|
22
|
+
source=None,
|
23
|
+
source_dims=None,
|
24
|
+
target=None,
|
25
|
+
target_dims=None,
|
21
26
|
):
|
22
27
|
self.dtype = dtype
|
23
28
|
|
24
29
|
if name is not None:
|
25
30
|
self.name = name
|
26
31
|
else:
|
27
|
-
self.name = f"
|
32
|
+
self.name = naming.auto_generate(f"tensor_{type(self).num_instances}")
|
28
33
|
|
29
34
|
if ndim is not None:
|
30
|
-
self.shape = (
|
35
|
+
self.shape = (
|
36
|
+
Symbol(self.size_string(i), constexpr=constexpr_shape)
|
37
|
+
for i in range(ndim)
|
38
|
+
)
|
31
39
|
self.strides = (Symbol(self.stride_string(i)) for i in range(ndim))
|
32
40
|
else:
|
33
41
|
self.shape = shape
|
@@ -39,10 +47,25 @@ class Tensor:
|
|
39
47
|
|
40
48
|
self.other = other
|
41
49
|
|
42
|
-
if
|
43
|
-
self.
|
50
|
+
if source is not None:
|
51
|
+
self.source = source
|
52
|
+
else:
|
53
|
+
self.source = self
|
54
|
+
|
55
|
+
if source_dims is not None:
|
56
|
+
self.source_dims = source_dims
|
57
|
+
else:
|
58
|
+
self.source_dims = (dim for dim in range(self.source.ndim))
|
59
|
+
|
60
|
+
if target is not None:
|
61
|
+
self.target = target
|
62
|
+
else:
|
63
|
+
self.target = self
|
64
|
+
|
65
|
+
if target_dims is not None:
|
66
|
+
self.target_dims = target_dims
|
44
67
|
else:
|
45
|
-
self.
|
68
|
+
self.target_dims = (dim for dim in range(self.target.ndim))
|
46
69
|
|
47
70
|
type(self).num_instances += 1
|
48
71
|
|
@@ -87,10 +110,16 @@ class Tensor:
|
|
87
110
|
shape=inner_shape,
|
88
111
|
dtype=self.dtype,
|
89
112
|
strides=inner_strides,
|
90
|
-
|
113
|
+
source=self.source,
|
114
|
+
source_dims=self.source_dims,
|
115
|
+
target=self.target,
|
116
|
+
target_dims=self.target_dims,
|
91
117
|
),
|
92
118
|
strides=outer_strides,
|
93
|
-
|
119
|
+
source=self.source,
|
120
|
+
source_dims=self.source_dims,
|
121
|
+
target=self.target,
|
122
|
+
target_dims=self.target_dims,
|
94
123
|
)
|
95
124
|
|
96
125
|
def expand(self, shape):
|
@@ -105,7 +134,10 @@ class Tensor:
|
|
105
134
|
stride if new_size == -1 else 0
|
106
135
|
for new_size, stride in zip(shape, self.strides)
|
107
136
|
],
|
108
|
-
|
137
|
+
source=self.source,
|
138
|
+
source_dims=self.source_dims,
|
139
|
+
target=self.target,
|
140
|
+
target_dims=self.target_dims,
|
109
141
|
)
|
110
142
|
|
111
143
|
def squeeze(self, dim):
|
@@ -114,81 +146,109 @@ class Tensor:
|
|
114
146
|
shape=[size for i, size in enumerate(self.shape) if dim != i],
|
115
147
|
dtype=self.dtype,
|
116
148
|
strides=[stride for i, stride in enumerate(self.strides) if dim != i],
|
117
|
-
|
149
|
+
source=self.source,
|
150
|
+
source_dims=[
|
151
|
+
source_dim for i, source_dim in enumerate(self.source_dims) if dim != i
|
152
|
+
],
|
153
|
+
target=self.target,
|
154
|
+
target_dims=[
|
155
|
+
target_dim for i, target_dim in enumerate(self.target_dims) if dim != i
|
156
|
+
],
|
118
157
|
)
|
119
158
|
|
120
|
-
def
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
{self.original.pointer_string()}
|
126
|
-
| {
|
127
|
-
name
|
128
|
-
for value in itertools.chain(self.shape, self.strides)
|
129
|
-
if isinstance(value, Symbol)
|
130
|
-
for name in value.names()
|
131
|
-
}
|
132
|
-
| (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
|
133
|
-
| (self.original.names() if self.original is not self else set())
|
134
|
-
)
|
159
|
+
def permute(self, dims):
|
160
|
+
# TODO: Add error handling.
|
161
|
+
new_shape = [None for _ in range(self.ndim)]
|
162
|
+
new_strides = [None for _ in range(self.ndim)]
|
163
|
+
new_source_dims = [None for _ in range(self.ndim)]
|
135
164
|
|
136
|
-
|
137
|
-
|
138
|
-
|
165
|
+
for original_dim, permuted_dim in enumerate(dims):
|
166
|
+
new_shape[original_dim] = self.shape[permuted_dim]
|
167
|
+
new_strides[original_dim] = self.strides[permuted_dim]
|
168
|
+
new_source_dims[original_dim] = self.source_dims[permuted_dim]
|
139
169
|
|
140
|
-
|
170
|
+
return type(self)(
|
171
|
+
shape=new_shape,
|
172
|
+
dtype=self.dtype,
|
173
|
+
strides=new_strides,
|
174
|
+
source=self.source,
|
175
|
+
source_dims=new_source_dims,
|
176
|
+
target=self.target,
|
177
|
+
target_dims=self.target_dims,
|
178
|
+
)
|
141
179
|
|
142
|
-
|
143
|
-
|
180
|
+
def flatten(self, start_dim=None, end_dim=None):
|
181
|
+
# TODO: Add error handling.
|
182
|
+
if start_dim is None:
|
183
|
+
start_dim = 0
|
184
|
+
if end_dim is None:
|
185
|
+
end_dim = self.ndim
|
144
186
|
|
145
|
-
|
146
|
-
|
147
|
-
|
187
|
+
leading_sizes = self.shape[:start_dim]
|
188
|
+
flattening_sizes = self.shape[start_dim:end_dim]
|
189
|
+
trailing_sizes = self.shape[end_dim:]
|
148
190
|
|
149
|
-
|
150
|
-
for dim in self._dims_of(stride):
|
151
|
-
offsets[dim].append(index * stride)
|
191
|
+
new_shape = leading_sizes + (math.prod(flattening_sizes),) + trailing_sizes
|
152
192
|
|
153
|
-
|
154
|
-
|
193
|
+
leading_strides = self.strides[:start_dim]
|
194
|
+
flattening_strides = self.strides[start_dim:end_dim]
|
195
|
+
trailing_strides = self.strides[end_dim:]
|
155
196
|
|
156
|
-
|
157
|
-
offsets[dim] = sum(offsets[dim])
|
158
|
-
offsets[dim].find_and_replace(Symbol(self.original.strides[dim]), Symbol(1))
|
197
|
+
new_strides = leading_strides + (flattening_strides[-1],) + trailing_strides
|
159
198
|
|
160
|
-
|
199
|
+
leading_source_dims = self.source_dims[:start_dim]
|
200
|
+
flattening_source_dims = self.source_dims[start_dim:end_dim]
|
201
|
+
trailing_source_dims = self.source_dims[end_dim:]
|
161
202
|
|
162
|
-
|
163
|
-
|
164
|
-
|
203
|
+
new_source_dims = (
|
204
|
+
leading_source_dims + (flattening_source_dims,) + trailing_source_dims
|
205
|
+
)
|
165
206
|
|
166
|
-
|
207
|
+
return type(self)(
|
208
|
+
shape=new_shape,
|
209
|
+
dtype=self.dtype,
|
210
|
+
strides=new_strides,
|
211
|
+
source=self.source,
|
212
|
+
source_dims=new_source_dims,
|
213
|
+
target=self.target,
|
214
|
+
target_dims=self.target_dims,
|
215
|
+
)
|
167
216
|
|
168
|
-
|
169
|
-
|
170
|
-
|
217
|
+
def ravel(self):
|
218
|
+
# TODO: Add error handling.
|
219
|
+
new_shape = []
|
220
|
+
new_strides = []
|
171
221
|
|
172
|
-
curr = self
|
222
|
+
curr = self
|
173
223
|
|
174
|
-
while isinstance(curr
|
175
|
-
|
176
|
-
|
224
|
+
while isinstance(curr, type(self)):
|
225
|
+
new_shape.extend(curr.shape)
|
226
|
+
new_strides.extend(curr.strides)
|
177
227
|
|
178
228
|
curr = curr.dtype
|
179
229
|
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
if not naming.is_meta(name):
|
187
|
-
size = naming.make_next_power_of_2(name)
|
230
|
+
return type(self)(
|
231
|
+
shape=new_shape,
|
232
|
+
strides=new_strides,
|
233
|
+
other=self.source.other,
|
234
|
+
name=self.source.name,
|
235
|
+
)
|
188
236
|
|
189
|
-
|
237
|
+
def names(self):
|
238
|
+
if self.ndim == 0:
|
239
|
+
return {self.source.name}
|
190
240
|
|
191
|
-
return
|
241
|
+
return (
|
242
|
+
{self.source.pointer_string()}
|
243
|
+
| {
|
244
|
+
name
|
245
|
+
for value in itertools.chain(self.shape, self.strides)
|
246
|
+
if isinstance(value, Symbol)
|
247
|
+
for name in value.names()
|
248
|
+
}
|
249
|
+
| (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
|
250
|
+
| (self.source.names() if self.source is not self else set())
|
251
|
+
)
|
192
252
|
|
193
253
|
def inmost(self):
|
194
254
|
if not isinstance(self.dtype, type(self)):
|
@@ -237,6 +297,22 @@ class Tensor:
|
|
237
297
|
def ndim(self):
|
238
298
|
return len(self.shape)
|
239
299
|
|
300
|
+
@property
|
301
|
+
def source_dims(self):
|
302
|
+
return self._source_dims
|
303
|
+
|
304
|
+
@source_dims.setter
|
305
|
+
def source_dims(self, value):
|
306
|
+
self._source_dims = tuple(value)
|
307
|
+
|
308
|
+
@property
|
309
|
+
def target_dims(self):
|
310
|
+
return self._target_dims
|
311
|
+
|
312
|
+
@target_dims.setter
|
313
|
+
def target_dims(self, value):
|
314
|
+
self._target_dims = tuple(value)
|
315
|
+
|
240
316
|
@staticmethod
|
241
317
|
def pointer_pattern():
|
242
318
|
return re.compile(rf"({_identifier_pattern_raw_string()})_(pointer)")
|
@@ -249,16 +325,6 @@ class Tensor:
|
|
249
325
|
def stride_pattern():
|
250
326
|
return re.compile(rf"({_identifier_pattern_raw_string()})_(stride)_(.+)")
|
251
327
|
|
252
|
-
def _dims_of(self, stride):
|
253
|
-
dims = set()
|
254
|
-
names = stride.names() if isinstance(stride, Symbol) else {stride}
|
255
|
-
|
256
|
-
for dim, original_stride in enumerate(self.original.strides):
|
257
|
-
if str(original_stride) in names:
|
258
|
-
dims.add(dim)
|
259
|
-
|
260
|
-
return dims
|
261
|
-
|
262
328
|
@staticmethod
|
263
329
|
def _calculate_default_strides(shape):
|
264
330
|
strides = [1]
|
ninetoothed/torchifier.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import ast
|
2
2
|
|
3
|
+
import ninetoothed.naming as naming
|
3
4
|
from ninetoothed.tensor import Tensor
|
4
5
|
|
5
6
|
|
@@ -9,6 +10,9 @@ class Torchifier(ast.NodeTransformer):
|
|
9
10
|
|
10
11
|
source = node.id
|
11
12
|
|
13
|
+
if naming.is_constexpr(source):
|
14
|
+
return node
|
15
|
+
|
12
16
|
def repl(match):
|
13
17
|
return f"{match.group(1)}"
|
14
18
|
|
@@ -1,10 +1,11 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.9.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
7
7
|
Author-email: Jiacheng Huang <huangjiacheng0709@outlook.com>
|
8
|
+
License-File: LICENSE
|
8
9
|
Classifier: License :: OSI Approved :: Apache Software License
|
9
10
|
Classifier: Operating System :: OS Independent
|
10
11
|
Classifier: Programming Language :: Python :: 3
|
@@ -0,0 +1,11 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
|
2
|
+
ninetoothed/jit.py,sha256=B3q32ksKTRr7I4jLcoDXjwEx7A_Awz9DXGEmIkrtoBc,23393
|
3
|
+
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
+
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
5
|
+
ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
|
6
|
+
ninetoothed/tensor.py,sha256=OU6lVjzKU614mk3EN1AAgTDradbvJkyl22AXwdhxcfs,9577
|
7
|
+
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
8
|
+
ninetoothed-0.9.0.dist-info/METADATA,sha256=LKtSbgc_mWKJ4L_8BYiRKWjoV-JKjr4hefhrkdPmrHs,7054
|
9
|
+
ninetoothed-0.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
10
|
+
ninetoothed-0.9.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
11
|
+
ninetoothed-0.9.0.dist-info/RECORD,,
|
@@ -1,11 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
|
2
|
-
ninetoothed/jit.py,sha256=y0y3gfcBNeRLhozIzLuDHLLzBAGrL8tLk8Rcvhw_uec,20068
|
3
|
-
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
-
ninetoothed/naming.py,sha256=3FBnC-S3dAZRcBcob9SrcVpVEYE5IXRacwkCiA3vIGU,891
|
5
|
-
ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
|
6
|
-
ninetoothed/tensor.py,sha256=KL6Iw2nwaRnZsNTxOeghDeXUQUgphnVh_1fsmAObOtI,7391
|
7
|
-
ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
|
8
|
-
ninetoothed-0.7.0.dist-info/METADATA,sha256=gRLAzPSdxWYff8SqPuwJ3ji3Z_kPq_vtEupnSe7dIQ0,7032
|
9
|
-
ninetoothed-0.7.0.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
|
10
|
-
ninetoothed-0.7.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
11
|
-
ninetoothed-0.7.0.dist-info/RECORD,,
|
File without changes
|