ninetoothed 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/__init__.py +2 -2
- ninetoothed/jit.py +186 -77
- ninetoothed/tensor.py +134 -73
- {ninetoothed-0.7.0.dist-info → ninetoothed-0.8.0.dist-info}/METADATA +3 -2
- ninetoothed-0.8.0.dist-info/RECORD +11 -0
- {ninetoothed-0.7.0.dist-info → ninetoothed-0.8.0.dist-info}/WHEEL +1 -1
- ninetoothed-0.7.0.dist-info/RECORD +0 -11
- {ninetoothed-0.7.0.dist-info → ninetoothed-0.8.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/__init__.py
CHANGED
ninetoothed/jit.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import ast
|
2
2
|
import collections
|
3
|
+
import copy
|
3
4
|
import functools
|
4
5
|
import importlib.util
|
5
6
|
import inspect
|
@@ -18,6 +19,15 @@ from ninetoothed.tensor import Tensor
|
|
18
19
|
from ninetoothed.torchifier import Torchifier
|
19
20
|
|
20
21
|
|
22
|
+
def make(arrangement, application, tensors):
    """Build a JIT-compiled kernel from an arrangement and an application.

    ``arrangement`` is called with ``tensors`` unpacked. The values it
    returns are paired positionally with the parameters of ``application``
    and installed as that function's ``__annotations__`` before the
    function is handed to ``jit``.

    Note that ``application.__annotations__`` is overwritten in place, and
    that ``zip`` stops at the shorter of the two sequences, so surplus
    parameters or surplus arranged values are silently ignored.

    :param arrangement: Callable receiving ``*tensors`` and returning an
        iterable with one entry per parameter of ``application``.
    :param application: Function whose parameters are to be annotated.
    :param tensors: Iterable of arguments forwarded to ``arrangement``.
    :return: Whatever ``jit(application)`` returns.
    """
    params = inspect.signature(application).parameters
    arranged = arrangement(*tensors)

    # Iterating ``params`` yields the parameter names, which gives a
    # name -> annotation mapping without shadowing the builtin ``type``.
    application.__annotations__ = dict(zip(params, arranged))

    return jit(application)
|
29
|
+
|
30
|
+
|
21
31
|
def jit(_func=None, *, _prettify=False):
|
22
32
|
def wrapper(func):
|
23
33
|
return JIT(func, _prettify=_prettify)()
|
@@ -141,30 +151,12 @@ class CodeGenerator(ast.NodeTransformer):
|
|
141
151
|
def visit_FunctionDef(self, node):
|
142
152
|
self._func_def = node
|
143
153
|
|
144
|
-
self.
|
145
|
-
|
146
|
-
for arg in self._args:
|
147
|
-
if not isinstance(arg, Tensor) or arg.ndim == 0:
|
148
|
-
continue
|
154
|
+
self._invariants = {}
|
149
155
|
|
150
|
-
|
151
|
-
|
152
|
-
initializations = {
|
153
|
-
type(self)._name_for_offsets(arg, dim): offs
|
154
|
-
for dim, offs in enumerate(offsets)
|
155
|
-
} | {
|
156
|
-
type(self)._name_for_pointers(arg): arg.original.pointer_string()
|
157
|
-
+ sum(
|
158
|
-
type(self)._name_for_offsets(arg, dim)[
|
159
|
-
type(self)._generate_slices(arg, dim)
|
160
|
-
]
|
161
|
-
* stride
|
162
|
-
for dim, stride in enumerate(arg.original.strides)
|
163
|
-
)
|
164
|
-
}
|
156
|
+
self.generic_visit(node)
|
165
157
|
|
166
|
-
|
167
|
-
|
158
|
+
for target, value in reversed(self._invariants.items()):
|
159
|
+
node.body.insert(0, ast.Assign(targets=[target.node], value=value.node))
|
168
160
|
|
169
161
|
return node
|
170
162
|
|
@@ -192,24 +184,20 @@ class CodeGenerator(ast.NodeTransformer):
|
|
192
184
|
]
|
193
185
|
|
194
186
|
autotune = self._generate_autotune(non_meta_names, meta_names)
|
195
|
-
self._func_def.decorator_list
|
187
|
+
self._func_def.decorator_list = [autotune, Symbol("triton.jit").node]
|
196
188
|
|
197
189
|
self._launch = self._generate_launch(non_meta_names, meta_names)
|
198
190
|
|
199
191
|
return node
|
200
192
|
|
201
193
|
def visit_Subscript(self, node):
|
202
|
-
if (
|
203
|
-
isinstance(node.value, ast.Name)
|
204
|
-
and node.value.id in self._context
|
205
|
-
and isinstance(node.ctx, ast.Load)
|
206
|
-
):
|
194
|
+
if self._in_context(node.value) and isinstance(node.ctx, ast.Load):
|
207
195
|
value = self._context[node.value.id]
|
208
196
|
|
209
197
|
if isinstance(value, Tensor):
|
210
|
-
return
|
198
|
+
return self._generate_load(
|
211
199
|
value,
|
212
|
-
|
200
|
+
indices=node.slice.elts
|
213
201
|
if isinstance(node.slice, ast.Tuple)
|
214
202
|
else (node.slice,),
|
215
203
|
)
|
@@ -219,7 +207,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
219
207
|
return node
|
220
208
|
|
221
209
|
def visit_Attribute(self, node):
|
222
|
-
if
|
210
|
+
if self._in_context(node.value):
|
223
211
|
value = self._context[node.value.id]
|
224
212
|
|
225
213
|
if isinstance(value, Tensor):
|
@@ -234,8 +222,8 @@ class CodeGenerator(ast.NodeTransformer):
|
|
234
222
|
def visit_Name(self, node):
|
235
223
|
self.generic_visit(node)
|
236
224
|
|
237
|
-
if
|
238
|
-
return
|
225
|
+
if self._in_context(node) and isinstance(node.ctx, ast.Load):
|
226
|
+
return self._generate_load(self._context[node.id])
|
239
227
|
|
240
228
|
return node
|
241
229
|
|
@@ -243,16 +231,15 @@ class CodeGenerator(ast.NodeTransformer):
|
|
243
231
|
if len(node.targets) == 1:
|
244
232
|
target = node.targets[0]
|
245
233
|
|
246
|
-
if
|
234
|
+
if self._in_context(target):
|
247
235
|
self.generic_visit(node)
|
248
236
|
|
249
237
|
return ast.Expr(
|
250
|
-
|
238
|
+
self._generate_store(self._context[target.id], node.value)
|
251
239
|
)
|
252
240
|
elif (
|
253
241
|
isinstance(target, ast.Subscript)
|
254
|
-
and
|
255
|
-
and target.value.id in self._context
|
242
|
+
and self._in_context(target.value)
|
256
243
|
and isinstance(target.ctx, ast.Store)
|
257
244
|
):
|
258
245
|
value = self._context[target.value.id]
|
@@ -261,10 +248,10 @@ class CodeGenerator(ast.NodeTransformer):
|
|
261
248
|
self.generic_visit(node)
|
262
249
|
|
263
250
|
return ast.Expr(
|
264
|
-
|
251
|
+
self._generate_store(
|
265
252
|
value,
|
266
253
|
node.value,
|
267
|
-
|
254
|
+
indices=target.slice.elts
|
268
255
|
if isinstance(target.slice, ast.Tuple)
|
269
256
|
else (target.slice,),
|
270
257
|
)
|
@@ -274,6 +261,11 @@ class CodeGenerator(ast.NodeTransformer):
|
|
274
261
|
|
275
262
|
return node
|
276
263
|
|
264
|
+
_NAME_FOR_PID = Symbol("ninetoothed_pid")
|
265
|
+
|
266
|
+
def _in_context(self, node):
|
267
|
+
return isinstance(node, ast.Name) and node.id in self._context
|
268
|
+
|
277
269
|
def _generate_autotune(self, params, meta):
|
278
270
|
device = triton.runtime.driver.active.get_current_device()
|
279
271
|
properties = triton.runtime.driver.active.utils.get_device_properties(device)
|
@@ -354,7 +346,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
354
346
|
name=f"launch_{self._func_def.name}",
|
355
347
|
args=ast.arguments(
|
356
348
|
posonlyargs=[],
|
357
|
-
args=[ast.arg(arg=arg.
|
349
|
+
args=[ast.arg(arg=arg.source.name) for arg in self._args]
|
358
350
|
+ [
|
359
351
|
ast.arg(arg=param)
|
360
352
|
for param in non_next_power_of_2_constexpr_params_without_prefixes
|
@@ -427,51 +419,105 @@ class CodeGenerator(ast.NodeTransformer):
|
|
427
419
|
|
428
420
|
return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
|
429
421
|
|
430
|
-
|
431
|
-
def _generate_load(tensor, intermediate_indices=()):
|
422
|
+
def _generate_load(self, tensor, indices=()):
|
432
423
|
if tensor.ndim == 0:
|
433
|
-
return Symbol(tensor.
|
424
|
+
return Symbol(tensor.source.name).node
|
434
425
|
|
435
|
-
pointers, mask =
|
436
|
-
|
437
|
-
)
|
438
|
-
other = CodeGenerator._generate_other(tensor)
|
426
|
+
pointers, mask = self._generate_pointers_and_mask(tensor, indices)
|
427
|
+
other = type(self)._generate_other(tensor)
|
439
428
|
|
440
429
|
return call("load", pointers, mask=mask, other=other).node
|
441
430
|
|
442
|
-
|
443
|
-
|
444
|
-
pointers, mask = CodeGenerator._generate_pointers_and_mask(
|
445
|
-
tensor, intermediate_indices
|
446
|
-
)
|
431
|
+
def _generate_store(self, tensor, value, indices=()):
    """Return the AST node of a masked ``store`` call writing ``value`` to ``tensor``.

    The pointers and the mask are obtained from
    ``_generate_pointers_and_mask`` for the given ``indices``.
    """
    pointers, mask = self._generate_pointers_and_mask(tensor, indices)
    store = call("store", pointers, value, mask=mask)

    return store.node
|
449
435
|
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
)
|
455
|
-
|
456
|
-
|
457
|
-
for
|
458
|
-
|
459
|
-
|
460
|
-
|
436
|
+
def _generate_pointers_and_mask(self, tensor, indices):
|
437
|
+
invariant_target_dims = type(self)._find_invariant_target_dims(tensor)
|
438
|
+
|
439
|
+
indices = self._complete_indices(tensor, indices)
|
440
|
+
offsets = type(self)._generate_offsets(tensor, indices)
|
441
|
+
|
442
|
+
for source_dim in range(tensor.source.ndim):
|
443
|
+
for target_dim in range(tensor.target.ndim):
|
444
|
+
if target_dim not in invariant_target_dims:
|
445
|
+
continue
|
446
|
+
|
447
|
+
name = type(self)._name_for_offsets(tensor, source_dim, target_dim)
|
448
|
+
self._invariants[name] = offsets[source_dim][target_dim]
|
449
|
+
offsets[source_dim][target_dim] = name
|
450
|
+
|
451
|
+
name_for_pointers = type(self)._name_for_pointers(tensor)
|
452
|
+
self._invariants[name_for_pointers] = Symbol(tensor.source.pointer_string())
|
453
|
+
|
454
|
+
for source_dim in range(tensor.source.ndim):
|
455
|
+
for target_dim in range(tensor.target.ndim):
|
456
|
+
if target_dim not in invariant_target_dims:
|
457
|
+
continue
|
458
|
+
|
459
|
+
self._invariants[name_for_pointers] += (
|
460
|
+
offsets[source_dim][target_dim][
|
461
|
+
type(self)._generate_slices(tensor, target_dim)
|
462
|
+
]
|
463
|
+
* tensor.source.strides[source_dim]
|
464
|
+
)
|
465
|
+
|
466
|
+
pointers = name_for_pointers + sum(
|
467
|
+
offsets[source_dim][target_dim][
|
468
|
+
type(self)._generate_slices(tensor, target_dim)
|
469
|
+
]
|
470
|
+
* tensor.source.strides[source_dim]
|
471
|
+
for source_dim in range(tensor.source.ndim)
|
472
|
+
for target_dim in range(tensor.target.ndim)
|
473
|
+
if target_dim not in invariant_target_dims
|
474
|
+
and offsets[source_dim][target_dim] != 0
|
461
475
|
)
|
462
476
|
mask = functools.reduce(
|
463
477
|
lambda x, y: x & y,
|
464
478
|
(
|
465
|
-
|
466
|
-
|
479
|
+
offsets[source_dim][target_dim][
|
480
|
+
type(self)._generate_slices(tensor, target_dim)
|
481
|
+
]
|
482
|
+
< tensor.source.shape[source_dim]
|
483
|
+
for source_dim in range(tensor.source.ndim)
|
484
|
+
for target_dim in range(tensor.target.ndim)
|
485
|
+
if offsets[source_dim][target_dim] != 0
|
467
486
|
),
|
468
487
|
)
|
469
488
|
|
470
489
|
return pointers, mask
|
471
490
|
|
491
|
+
def _complete_indices(self, tensor, indices):
|
492
|
+
indices = list(self._generate_pid_indices(tensor) + indices)
|
493
|
+
|
494
|
+
for size in tensor.inmost().shape:
|
495
|
+
if Symbol.is_name(size):
|
496
|
+
name = size.node.id
|
497
|
+
if not naming.is_meta(name):
|
498
|
+
size = naming.make_next_power_of_2(name)
|
499
|
+
|
500
|
+
indices.append(call("arange", 0, size))
|
501
|
+
|
502
|
+
return tuple(indices)
|
503
|
+
|
504
|
+
def _generate_pid_indices(self, tensor):
    """Return named per-dimension indices derived from the program id.

    Registers ``program_id(0)`` under ``_NAME_FOR_PID`` in
    ``self._invariants``, unravels that name over ``tensor.shape``, and
    registers each resulting expression under the name given by
    ``_name_for_index``. The returned tuple holds those names, one per
    dimension, rather than the expressions themselves.
    """
    cls = type(self)
    pid = cls._NAME_FOR_PID
    self._invariants[pid] = call("program_id", 0)

    names = []

    for dim, index in enumerate(cls._unravel_index(pid, tensor.shape)):
        name = cls._name_for_index(tensor, dim)
        self._invariants[name] = index
        names.append(name)

    return tuple(names)
|
517
|
+
|
472
518
|
@staticmethod
|
473
519
|
def _generate_other(tensor):
|
474
|
-
other = tensor.
|
520
|
+
other = tensor.source.other
|
475
521
|
|
476
522
|
if isinstance(other, float) and not math.isfinite(other):
|
477
523
|
return f"float('{other}')"
|
@@ -483,23 +529,86 @@ class CodeGenerator(ast.NodeTransformer):
|
|
483
529
|
return tuple(slice(None) if i == dim else None for i in range(tensor.ndim))
|
484
530
|
|
485
531
|
@staticmethod
|
486
|
-
def
|
487
|
-
|
488
|
-
|
489
|
-
for offs in tensor.offsets(
|
490
|
-
[0 for _ in range(tensor.ndim)]
|
491
|
-
+ list(intermediate_indices)
|
492
|
-
+ [0 for _ in range(tensor.inmost().ndim)]
|
493
|
-
)
|
532
|
+
def _generate_offsets(tensor, indices):
    """Compute offset expressions per (source dimension, target dimension).

    The work happens in three passes:

    1. Walk the ``dtype`` chain of ``tensor``. Each level consumes
       ``level.ndim`` entries of ``indices`` and adds ``index * stride``
       to the entry for that dimension's source and target dimension.
    2. For every source-dimension key that is a tuple, unravel the
       accumulated offset over the source sizes of those dimensions and
       assign the pieces to the individual source dimensions.
    3. For every source/target pair, replace the entry with a deep copy
       in which the source stride of that dimension is replaced by ``1``.

    :param tensor: The outermost tensor of the hierarchy.
    :param indices: Sequence of indices covering every level.
    :return: A nested ``defaultdict`` indexed as
        ``offsets[source_dim][target_dim]``; missing entries default to
        ``Symbol(0)``.
    """
    offsets = collections.defaultdict(
        lambda: collections.defaultdict(lambda: Symbol(0))
    )

    level = tensor
    begin = 0

    while isinstance(level, type(tensor)):
        end = begin + level.ndim

        for index, stride, source_dim, target_dim in zip(
            indices[begin:end], level.strides, level.source_dims, level.target_dims
        ):
            offsets[source_dim][target_dim] += index * stride

        begin = end
        level = level.dtype

    # Iterate over snapshots because the loop body adds new keys.
    for source_dim in tuple(offsets):
        if not isinstance(source_dim, tuple):
            continue

        sizes = tuple(tensor.source.shape[dim] for dim in source_dim)

        for target_dim in tuple(offsets[source_dim]):
            unraveled = CodeGenerator._unravel_index(
                offsets[source_dim][target_dim], sizes
            )

            for offs, dim in zip(unraveled, source_dim):
                offsets[dim][target_dim] = offs

    for source_dim in range(tensor.source.ndim):
        for target_dim in range(tensor.target.ndim):
            # Copy first so the in-place replacement cannot affect an
            # expression object shared with another entry.
            detached = copy.deepcopy(offsets[source_dim][target_dim])
            detached.find_and_replace(
                Symbol(tensor.source.strides[source_dim]), Symbol(1)
            )
            offsets[source_dim][target_dim] = detached

    return offsets
|
575
|
+
|
576
|
+
@staticmethod
def _find_invariant_target_dims(tensor):
    """Collect target dimensions that an intermediate level does not map to.

    Walks the ``dtype`` chain starting at ``tensor.dtype`` and stops
    before the innermost level (the one whose own ``dtype`` is not a
    ``Tensor``). For every level visited, each dimension index of
    ``level.target`` that is absent from ``level.target_dims`` is
    recorded.

    :param tensor: The outermost tensor of the hierarchy.
    :return: A set of target dimension indices; empty when there is no
        intermediate level.
    """
    invariant_target_dims = set()

    level = tensor.dtype

    # Check ``level`` itself as well, so that a tensor whose ``dtype`` is
    # not a ``Tensor`` yields an empty set instead of an ``AttributeError``.
    while isinstance(level, Tensor) and isinstance(level.dtype, Tensor):
        mapped = level.target_dims
        invariant_target_dims.update(
            dim for dim in range(level.target.ndim) if dim not in mapped
        )

        level = level.dtype

    return invariant_target_dims
|
590
|
+
|
496
591
|
@staticmethod
|
497
592
|
def _name_for_pointers(tensor):
|
498
|
-
return Symbol(f"{tensor.
|
593
|
+
return Symbol(f"{tensor.source.name}_pointers")
|
594
|
+
|
595
|
+
@staticmethod
def _name_for_offsets(tensor, source_dim, target_dim):
    """Return the symbol naming ``tensor``'s offsets for a source/target dimension pair."""
    base = tensor.source.name

    return Symbol(f"{base}_offsets_{source_dim}_{target_dim}")
|
499
598
|
|
500
599
|
@staticmethod
|
501
|
-
def
|
502
|
-
return Symbol(f"{tensor.
|
600
|
+
def _name_for_index(tensor, dim):
|
601
|
+
return Symbol(f"{tensor.source.name}_index_{dim}")
|
602
|
+
|
603
|
+
@staticmethod
def _unravel_index(index, shape):
    """Split a flat ``index`` into one index per dimension of ``shape``.

    The strides used are those of a ``Tensor`` constructed with ``shape``
    and no explicit strides (presumably the default contiguous strides;
    confirm against ``Tensor``). For each stride in order, the quotient
    becomes the next index and the remainder is carried forward.

    :return: A tuple with one entry per stride.
    """
    unraveled = []
    remainder = index

    for stride in Tensor(shape=shape).strides:
        unraveled.append(remainder // stride)
        remainder %= stride

    return tuple(unraveled)
|
503
612
|
|
504
613
|
|
505
614
|
class Tritonizer(ast.NodeTransformer):
|
ninetoothed/tensor.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import itertools
|
2
|
+
import math
|
2
3
|
import re
|
3
4
|
|
4
|
-
import ninetoothed.naming as naming
|
5
5
|
from ninetoothed.language import call
|
6
6
|
from ninetoothed.symbol import Symbol
|
7
7
|
|
@@ -17,7 +17,10 @@ class Tensor:
|
|
17
17
|
strides=None,
|
18
18
|
other=None,
|
19
19
|
name=None,
|
20
|
-
|
20
|
+
source=None,
|
21
|
+
source_dims=None,
|
22
|
+
target=None,
|
23
|
+
target_dims=None,
|
21
24
|
):
|
22
25
|
self.dtype = dtype
|
23
26
|
|
@@ -39,10 +42,25 @@ class Tensor:
|
|
39
42
|
|
40
43
|
self.other = other
|
41
44
|
|
42
|
-
if
|
43
|
-
self.
|
45
|
+
if source is not None:
|
46
|
+
self.source = source
|
44
47
|
else:
|
45
|
-
self.
|
48
|
+
self.source = self
|
49
|
+
|
50
|
+
if source_dims is not None:
|
51
|
+
self.source_dims = source_dims
|
52
|
+
else:
|
53
|
+
self.source_dims = (dim for dim in range(self.source.ndim))
|
54
|
+
|
55
|
+
if target is not None:
|
56
|
+
self.target = target
|
57
|
+
else:
|
58
|
+
self.target = self
|
59
|
+
|
60
|
+
if target_dims is not None:
|
61
|
+
self.target_dims = target_dims
|
62
|
+
else:
|
63
|
+
self.target_dims = (dim for dim in range(self.target.ndim))
|
46
64
|
|
47
65
|
type(self).num_instances += 1
|
48
66
|
|
@@ -87,10 +105,16 @@ class Tensor:
|
|
87
105
|
shape=inner_shape,
|
88
106
|
dtype=self.dtype,
|
89
107
|
strides=inner_strides,
|
90
|
-
|
108
|
+
source=self.source,
|
109
|
+
source_dims=self.source_dims,
|
110
|
+
target=self.target,
|
111
|
+
target_dims=self.target_dims,
|
91
112
|
),
|
92
113
|
strides=outer_strides,
|
93
|
-
|
114
|
+
source=self.source,
|
115
|
+
source_dims=self.source_dims,
|
116
|
+
target=self.target,
|
117
|
+
target_dims=self.target_dims,
|
94
118
|
)
|
95
119
|
|
96
120
|
def expand(self, shape):
|
@@ -105,7 +129,10 @@ class Tensor:
|
|
105
129
|
stride if new_size == -1 else 0
|
106
130
|
for new_size, stride in zip(shape, self.strides)
|
107
131
|
],
|
108
|
-
|
132
|
+
source=self.source,
|
133
|
+
source_dims=self.source_dims,
|
134
|
+
target=self.target,
|
135
|
+
target_dims=self.target_dims,
|
109
136
|
)
|
110
137
|
|
111
138
|
def squeeze(self, dim):
|
@@ -114,81 +141,109 @@ class Tensor:
|
|
114
141
|
shape=[size for i, size in enumerate(self.shape) if dim != i],
|
115
142
|
dtype=self.dtype,
|
116
143
|
strides=[stride for i, stride in enumerate(self.strides) if dim != i],
|
117
|
-
|
144
|
+
source=self.source,
|
145
|
+
source_dims=[
|
146
|
+
source_dim for i, source_dim in enumerate(self.source_dims) if dim != i
|
147
|
+
],
|
148
|
+
target=self.target,
|
149
|
+
target_dims=[
|
150
|
+
target_dim for i, target_dim in enumerate(self.target_dims) if dim != i
|
151
|
+
],
|
118
152
|
)
|
119
153
|
|
120
|
-
def
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
{self.original.pointer_string()}
|
126
|
-
| {
|
127
|
-
name
|
128
|
-
for value in itertools.chain(self.shape, self.strides)
|
129
|
-
if isinstance(value, Symbol)
|
130
|
-
for name in value.names()
|
131
|
-
}
|
132
|
-
| (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
|
133
|
-
| (self.original.names() if self.original is not self else set())
|
134
|
-
)
|
154
|
+
def permute(self, dims):
    """Return a tensor with the dimensions of this one reordered.

    Dimension ``i`` of the result takes its size, stride, and source
    dimension from dimension ``dims[i]`` of this tensor. ``dtype``,
    ``source``, ``target``, and ``target_dims`` are passed through
    unchanged.

    :param dims: A permutation of ``range(self.ndim)``. Negative entries
        count from the end.
    :return: A new tensor of the same type.
    :raises ValueError: If ``dims`` is not a permutation of the
        dimensions of this tensor.
    """
    ndim = self.ndim
    dims = tuple(dims)
    normalized = tuple(dim + ndim if dim < 0 else dim for dim in dims)

    # A short or repeated ``dims`` would otherwise produce a tensor with
    # missing or duplicated dimensions, so reject it up front.
    if sorted(normalized) != list(range(ndim)):
        raise ValueError(
            f"dims {dims} is not a permutation of the {ndim} dimensions"
        )

    return type(self)(
        shape=[self.shape[dim] for dim in normalized],
        dtype=self.dtype,
        strides=[self.strides[dim] for dim in normalized],
        source=self.source,
        source_dims=[self.source_dims[dim] for dim in normalized],
        target=self.target,
        target_dims=self.target_dims,
    )
|
141
174
|
|
142
|
-
|
143
|
-
|
175
|
+
def flatten(self, start_dim=None, end_dim=None):
    """Merge a contiguous range of dimensions into a single dimension.

    The merged range is ``[start_dim, end_dim)``: ``end_dim`` is
    exclusive, following Python slice semantics. The merged dimension's
    size is the product of the merged sizes, its stride is the stride of
    the last merged dimension, and its source dimension is the tuple of
    the merged source dimensions. ``dtype``, ``source``, ``target``, and
    ``target_dims`` are passed through unchanged.

    :param start_dim: First dimension to merge; defaults to ``0``.
    :param end_dim: One past the last dimension to merge; defaults to
        ``self.ndim``.
    :return: A new tensor of the same type.
    """
    # TODO: Add error handling.
    if start_dim is None:
        start_dim = 0
    if end_dim is None:
        end_dim = self.ndim

    # Coerce to tuples so the concatenations below also work when the
    # shape or strides are held as lists.
    shape = tuple(self.shape)
    strides = tuple(self.strides)
    source_dims = tuple(self.source_dims)

    new_shape = (
        shape[:start_dim]
        + (math.prod(shape[start_dim:end_dim]),)
        + shape[end_dim:]
    )

    # NOTE: an empty range raises ``IndexError`` on the ``[-1]`` lookup.
    new_strides = (
        strides[:start_dim]
        + (strides[start_dim:end_dim][-1],)
        + strides[end_dim:]
    )

    new_source_dims = (
        source_dims[:start_dim]
        + (source_dims[start_dim:end_dim],)
        + source_dims[end_dim:]
    )

    return type(self)(
        shape=new_shape,
        dtype=self.dtype,
        strides=new_strides,
        source=self.source,
        source_dims=new_source_dims,
        target=self.target,
        target_dims=self.target_dims,
    )
|
167
211
|
|
168
|
-
|
169
|
-
|
170
|
-
|
212
|
+
def ravel(self):
    """Collapse every level of the ``dtype`` chain into one flat tensor.

    The shapes and strides of this tensor and of each nested ``dtype``
    of the same type are concatenated, outermost first. The result is
    constructed with only ``shape``, ``strides``, and the ``other`` and
    ``name`` of ``self.source``.
    """
    # TODO: Add error handling.
    cls = type(self)

    levels = []
    level = self

    while isinstance(level, cls):
        levels.append(level)
        level = level.dtype

    return cls(
        shape=[size for level in levels for size in level.shape],
        strides=[stride for level in levels for stride in level.strides],
        other=self.source.other,
        name=self.source.name,
    )
|
188
231
|
|
189
|
-
|
232
|
+
def names(self):
    """Return the set of names this tensor refers to.

    A zero-dimensional tensor contributes only the name of its source.
    Otherwise the set holds the source's pointer string, every name
    occurring in a ``Symbol`` among the shape and strides, the names of
    a nested ``dtype`` of the same type, and the names of the source when
    the source is a different object.
    """
    if self.ndim == 0:
        return {self.source.name}

    collected = {self.source.pointer_string()}

    for value in itertools.chain(self.shape, self.strides):
        if isinstance(value, Symbol):
            collected.update(value.names())

    if isinstance(self.dtype, type(self)):
        collected |= self.dtype.names()

    if self.source is not self:
        collected |= self.source.names()

    return collected
|
192
247
|
|
193
248
|
def inmost(self):
|
194
249
|
if not isinstance(self.dtype, type(self)):
|
@@ -237,6 +292,22 @@ class Tensor:
|
|
237
292
|
def ndim(self):
|
238
293
|
return len(self.shape)
|
239
294
|
|
295
|
+
@property
def source_dims(self):
    """Tuple of source-dimension entries held by this tensor."""
    return self._source_dims

@source_dims.setter
def source_dims(self, value):
    # Materialize any iterable (list, generator, ...) into a tuple so the
    # stored value can be indexed and sliced repeatedly.
    dims = tuple(value)
    self._source_dims = dims
|
302
|
+
|
303
|
+
@property
def target_dims(self):
    """Tuple of target-dimension entries held by this tensor."""
    return self._target_dims

@target_dims.setter
def target_dims(self, value):
    # Materialize any iterable (list, generator, ...) into a tuple so the
    # stored value can be indexed and tested for membership repeatedly.
    dims = tuple(value)
    self._target_dims = dims
|
310
|
+
|
240
311
|
@staticmethod
|
241
312
|
def pointer_pattern():
|
242
313
|
return re.compile(rf"({_identifier_pattern_raw_string()})_(pointer)")
|
@@ -249,16 +320,6 @@ class Tensor:
|
|
249
320
|
def stride_pattern():
|
250
321
|
return re.compile(rf"({_identifier_pattern_raw_string()})_(stride)_(.+)")
|
251
322
|
|
252
|
-
def _dims_of(self, stride):
|
253
|
-
dims = set()
|
254
|
-
names = stride.names() if isinstance(stride, Symbol) else {stride}
|
255
|
-
|
256
|
-
for dim, original_stride in enumerate(self.original.strides):
|
257
|
-
if str(original_stride) in names:
|
258
|
-
dims.add(dim)
|
259
|
-
|
260
|
-
return dims
|
261
|
-
|
262
323
|
@staticmethod
|
263
324
|
def _calculate_default_strides(shape):
|
264
325
|
strides = [1]
|
@@ -1,10 +1,11 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.8.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
7
7
|
Author-email: Jiacheng Huang <huangjiacheng0709@outlook.com>
|
8
|
+
License-File: LICENSE
|
8
9
|
Classifier: License :: OSI Approved :: Apache Software License
|
9
10
|
Classifier: Operating System :: OS Independent
|
10
11
|
Classifier: Programming Language :: Python :: 3
|
@@ -0,0 +1,11 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
|
2
|
+
ninetoothed/jit.py,sha256=z70hQEsogfQu0cLxq5m3cOsWsVANcMRJaVv5di9vk1c,23741
|
3
|
+
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
+
ninetoothed/naming.py,sha256=3FBnC-S3dAZRcBcob9SrcVpVEYE5IXRacwkCiA3vIGU,891
|
5
|
+
ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
|
6
|
+
ninetoothed/tensor.py,sha256=_jM0tVgqIwZd3MJJsGVTaLCsSxpPO8JfF4qkMShhQvQ,9429
|
7
|
+
ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
|
8
|
+
ninetoothed-0.8.0.dist-info/METADATA,sha256=gPWYhTBH5EdeOyGnArZIEw82aFmoQchD6pxtLi6LGMA,7054
|
9
|
+
ninetoothed-0.8.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
10
|
+
ninetoothed-0.8.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
11
|
+
ninetoothed-0.8.0.dist-info/RECORD,,
|
@@ -1,11 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
|
2
|
-
ninetoothed/jit.py,sha256=y0y3gfcBNeRLhozIzLuDHLLzBAGrL8tLk8Rcvhw_uec,20068
|
3
|
-
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
-
ninetoothed/naming.py,sha256=3FBnC-S3dAZRcBcob9SrcVpVEYE5IXRacwkCiA3vIGU,891
|
5
|
-
ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
|
6
|
-
ninetoothed/tensor.py,sha256=KL6Iw2nwaRnZsNTxOeghDeXUQUgphnVh_1fsmAObOtI,7391
|
7
|
-
ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
|
8
|
-
ninetoothed-0.7.0.dist-info/METADATA,sha256=gRLAzPSdxWYff8SqPuwJ3ji3Z_kPq_vtEupnSe7dIQ0,7032
|
9
|
-
ninetoothed-0.7.0.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
|
10
|
-
ninetoothed-0.7.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
11
|
-
ninetoothed-0.7.0.dist-info/RECORD,,
|
File without changes
|