ninetoothed 0.12.0__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/jit.py +280 -32
- ninetoothed/symbol.py +13 -6
- ninetoothed/tensor.py +67 -3
- ninetoothed/visualization.py +6 -0
- {ninetoothed-0.12.0.dist-info → ninetoothed-0.14.0.dist-info}/METADATA +4 -1
- ninetoothed-0.14.0.dist-info/RECORD +12 -0
- ninetoothed-0.12.0.dist-info/RECORD +0 -12
- {ninetoothed-0.12.0.dist-info → ninetoothed-0.14.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.12.0.dist-info → ninetoothed-0.14.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/jit.py
CHANGED
@@ -106,13 +106,18 @@ class JIT:
|
|
106
106
|
def _get_tree(self):
    """Build the module AST holding the inlined function and its imports.

    The source of the module that defines ``self.func`` is parsed, its
    imports are collected, and the definition named after ``self.func``
    is located and run through ``_Inliner``. The module body is then
    reduced to the collected imports, the imports gathered while
    inlining, and that definition.

    :return: The result of visiting the rebuilt module with
        ``_AliasRestorer``.
    """
    module_source = inspect.getsource(inspect.getmodule(self.func))
    module = ast.parse(module_source)

    import_collector = _ImportCollector()
    import_collector.visit(module)

    def_finder = _FunctionDefFinder(self.func.__name__)
    def_finder.visit(module)
    func_def = def_finder.result

    # The definition is visited for its side effects; the node that ends
    # up in the module body is still ``def_finder.result``.
    inliner = _Inliner(self.func.__globals__)
    inliner.visit(func_def)

    new_body = import_collector.imports + inliner.imports
    new_body = new_body + [def_finder.result]
    module.body = new_body

    restorer = _AliasRestorer()

    return restorer.visit(module)
|
116
121
|
|
117
122
|
def _find_dependencies(self):
|
118
123
|
dependencies = set()
|
@@ -483,13 +488,15 @@ class CodeGenerator(ast.NodeTransformer):
|
|
483
488
|
mask = functools.reduce(
|
484
489
|
lambda x, y: x & y,
|
485
490
|
(
|
486
|
-
|
487
|
-
|
488
|
-
|
491
|
+
sum(
|
492
|
+
offsets[source_dim][target_dim][
|
493
|
+
type(self)._generate_slices(tensor, target_dim)
|
494
|
+
]
|
495
|
+
for target_dim in range(tensor.target.ndim)
|
496
|
+
if offsets[source_dim][target_dim] != 0
|
497
|
+
)
|
489
498
|
< tensor.source.shape[source_dim]
|
490
499
|
for source_dim in range(tensor.source.ndim)
|
491
|
-
for target_dim in range(tensor.target.ndim)
|
492
|
-
if offsets[source_dim][target_dim] != 0
|
493
500
|
),
|
494
501
|
) & functools.reduce(
|
495
502
|
lambda x, y: x & y,
|
@@ -505,13 +512,21 @@ class CodeGenerator(ast.NodeTransformer):
|
|
505
512
|
return pointers, mask
|
506
513
|
|
507
514
|
def _complete_indices(self, tensor, indices):
|
515
|
+
class _NextPowerOfTwoMaker(ast.NodeTransformer):
|
516
|
+
def visit_Name(self, node):
|
517
|
+
name = node.id
|
518
|
+
|
519
|
+
if not naming.is_meta(name):
|
520
|
+
next_power_of_2_name = naming.make_next_power_of_2(name)
|
521
|
+
|
522
|
+
return ast.Name(id=next_power_of_2_name, ctx=ast.Load())
|
523
|
+
|
524
|
+
return self.generic_visit(node)
|
525
|
+
|
508
526
|
indices = list(self._generate_pid_indices(tensor) + tuple(indices))
|
509
527
|
|
510
528
|
for size in tensor.innermost().shape:
|
511
|
-
|
512
|
-
name = size.node.id
|
513
|
-
if not naming.is_meta(name):
|
514
|
-
size = naming.make_next_power_of_2(name)
|
529
|
+
size = _NextPowerOfTwoMaker().visit(Symbol(copy.deepcopy(size)).node)
|
515
530
|
|
516
531
|
indices.append(call("arange", 0, size))
|
517
532
|
|
@@ -549,8 +564,10 @@ class CodeGenerator(ast.NodeTransformer):
|
|
549
564
|
|
550
565
|
@staticmethod
|
551
566
|
def _generate_offsets(tensor, indices):
|
552
|
-
|
553
|
-
lambda: collections.defaultdict(
|
567
|
+
raw_offsets = collections.defaultdict(
|
568
|
+
lambda: collections.defaultdict(
|
569
|
+
lambda: collections.defaultdict(lambda: Symbol(0))
|
570
|
+
)
|
554
571
|
)
|
555
572
|
|
556
573
|
curr = tensor
|
@@ -560,36 +577,62 @@ class CodeGenerator(ast.NodeTransformer):
|
|
560
577
|
stop = start + curr.ndim
|
561
578
|
curr_indices = indices[start:stop]
|
562
579
|
|
563
|
-
for index, stride, source_dim, target_dim in zip(
|
564
|
-
curr_indices,
|
580
|
+
for index, stride, source_dim, target_dim, unflattened_dim in zip(
|
581
|
+
curr_indices,
|
582
|
+
curr.strides,
|
583
|
+
curr.source_dims,
|
584
|
+
curr.target_dims,
|
585
|
+
curr.unflattened_dims,
|
565
586
|
):
|
566
|
-
|
587
|
+
raw_offsets[source_dim][target_dim][unflattened_dim] += index * stride
|
567
588
|
|
568
589
|
start = stop
|
569
590
|
curr = curr.dtype
|
570
591
|
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
continue
|
592
|
+
offsets = collections.defaultdict(
|
593
|
+
lambda: collections.defaultdict(lambda: Symbol(0))
|
594
|
+
)
|
575
595
|
|
576
|
-
|
577
|
-
offsets[source_dim][target_dim],
|
578
|
-
tuple(tensor.source.shape[dim] for dim in source_dim),
|
579
|
-
)
|
596
|
+
source_strides = tuple(Symbol(stride) for stride in tensor.source.strides)
|
580
597
|
|
581
|
-
|
582
|
-
|
598
|
+
unflattened_strides = tuple(
|
599
|
+
Symbol(stride) for stride in tensor.unflattened.strides
|
600
|
+
)
|
583
601
|
|
584
|
-
|
585
|
-
|
586
|
-
offsets[source_dim][target_dim]
|
587
|
-
|
602
|
+
def _add_unraveled_offsets(raw_offs, source_dim, target_dim, unflattened_dim):
|
603
|
+
if not isinstance(unflattened_dim, tuple):
|
604
|
+
offsets[source_dim][target_dim] += copy.deepcopy(
|
605
|
+
raw_offs
|
606
|
+
).find_and_replace(
|
607
|
+
unflattened_strides, Symbol(1)
|
608
|
+
) * unflattened_strides[unflattened_dim].find_and_replace(
|
609
|
+
source_strides, Symbol(1)
|
588
610
|
)
|
589
|
-
|
590
|
-
|
611
|
+
|
612
|
+
return
|
613
|
+
|
614
|
+
unraveled_offs = CodeGenerator._unravel_index(
|
615
|
+
raw_offs,
|
616
|
+
tuple(tensor.unflattened.shape[dim] for dim in unflattened_dim),
|
617
|
+
)
|
618
|
+
|
619
|
+
for raw_offs, source_dim, unflattened_dim in zip(
|
620
|
+
unraveled_offs, source_dim, unflattened_dim
|
621
|
+
):
|
622
|
+
_add_unraveled_offsets(
|
623
|
+
raw_offs, source_dim, target_dim, unflattened_dim
|
591
624
|
)
|
592
625
|
|
626
|
+
for source_dim in tuple(raw_offsets):
|
627
|
+
for target_dim in tuple(raw_offsets[source_dim]):
|
628
|
+
for unflattened_dim in tuple(raw_offsets[source_dim][target_dim]):
|
629
|
+
_add_unraveled_offsets(
|
630
|
+
raw_offsets[source_dim][target_dim][unflattened_dim],
|
631
|
+
source_dim,
|
632
|
+
target_dim,
|
633
|
+
unflattened_dim,
|
634
|
+
)
|
635
|
+
|
593
636
|
return offsets
|
594
637
|
|
595
638
|
@staticmethod
|
@@ -663,6 +706,211 @@ class Tritonizer(ast.NodeTransformer):
|
|
663
706
|
return node
|
664
707
|
|
665
708
|
|
709
|
+
class _Inliner(ast.NodeTransformer):
    """Splice the bodies of called functions into the visited statements.

    Only the value of an expression statement, the value of an assignment,
    and the value of a ``return`` statement are searched for calls. A call
    is inlined when its callee can be resolved through ``globals`` and its
    source can be retrieved; any other call is left untouched.

    NOTE: Only positional arguments are bound to parameters, and only a
    ``return`` that ends the callee's body is turned into a result value.

    :param globals: The namespace used to resolve the called names.
    :param imports: An optional list to which the imports collected from
        the callees' modules are appended. A new list is created when it
        is omitted.
    """

    class _ParameterReplacer(ast.NodeTransformer):
        """Replace parameter names with the argument nodes bound to them."""

        def __init__(self, mapping):
            self._mapping = mapping

        def visit_Name(self, node):
            return self._mapping.get(node.id, node)

    class _LocalVariableRenamer(ast.NodeTransformer):
        """Prefix the names of the given local variables."""

        def __init__(self, prefix, local_vars):
            self._prefix = prefix

            self._local_vars = local_vars

        def visit_Name(self, node):
            if node.id in self._local_vars:
                node.id = f"{self._prefix}{node.id}"

            return node

        def visit_arg(self, node):
            return node

    def __init__(self, globals, imports=None):
        self._globals = globals

        # Number of temporary names handed out so far.
        self._count = 0

        # Not ``imports=[]``: a default list is created once and would be
        # shared by every instance that omits the argument.
        self.imports = [] if imports is None else imports

    def visit_Expr(self, node):
        value, stmts = self._inline_expr(node.value)
        node.value = value
        node = self.generic_visit(node)

        if not stmts:
            return node

        # A callee without a result leaves a bare ``None`` behind, so the
        # expression statement itself carries nothing and is dropped.
        if isinstance(value, ast.Constant) and value.value is None:
            return stmts

        return stmts + [node]

    def visit_Assign(self, node):
        value, stmts = self._inline_expr(node.value)
        node.value = value
        node = self.generic_visit(node)

        if stmts:
            return stmts + [node]

        return node

    def visit_Return(self, node):
        if node.value:
            value, stmts = self._inline_expr(node.value)
            node.value = value

            if stmts:
                return stmts + [node]

        return node

    def _inline_expr(self, expr):
        """Inline the calls found in ``expr``.

        :param expr: The AST node to process. It is updated in place.
        :return: A pair of the node that replaces ``expr`` and the list
            of statements that must run before it.
        """

        if isinstance(expr, ast.Call):
            new_expr, new_stmts = self._inline_call(expr)

            if new_expr is not None:
                return new_expr, new_stmts

        # Accumulated under its own name: rebinding one list per field
        # would drop the statements produced by earlier fields.
        collected_stmts = []

        for field, value in ast.iter_fields(expr):
            if isinstance(value, list):
                new_value, field_stmts = self._inline_list(value)
            elif isinstance(value, ast.AST):
                new_value, field_stmts = self._inline_expr(value)
            else:
                new_value, field_stmts = value, []

            setattr(expr, field, new_value)
            collected_stmts.extend(field_stmts)

        return expr, collected_stmts

    def _inline_list(self, items):
        """Apply ``_inline_expr`` to every AST node of a list-valued field."""

        new_items = []
        new_stmts = []

        for item in items:
            # Some list fields hold non-node entries, e.g. ``None`` in
            # ``Dict.keys`` for ``**`` unpacking; these pass through.
            if isinstance(item, ast.AST):
                item, item_stmts = self._inline_expr(item)
                new_stmts.extend(item_stmts)

            new_items.append(item)

        return new_items, new_stmts

    def _inline_call(self, node):
        """Try to inline a single call.

        :param node: The ``ast.Call`` node to inline.
        :return: ``(None, [])`` when the call cannot be inlined; otherwise
            a pair of the expression that replaces the call and the list
            of statements that must run before it.
        """

        func = self._resolve_function(node.func, self._globals)

        if func is None:
            return None, []

        source = self._get_source(func)

        if source is None:
            return None, []

        func_def = self._find_function_definition(source, func.__name__)

        if func_def is None:
            return None, []

        collector = _ImportCollector()
        collector.visit(ast.parse(inspect.getsource(inspect.getmodule(func))))
        self.imports.extend(collector.imports)

        param_names = [arg.arg for arg in func_def.args.args]

        # Locals are renamed before the arguments are substituted, so a
        # caller variable that shares its name with a callee local is not
        # renamed along with it.
        local_vars = self._find_assigned_names(func_def.body) - set(param_names)
        prefix = self._make_temporary()
        local_var_renamer = self._LocalVariableRenamer(prefix, local_vars)
        body = [local_var_renamer.visit(stmt) for stmt in func_def.body]

        mapping = dict(zip(param_names, node.args))
        param_replacer = self._ParameterReplacer(mapping)
        body = [param_replacer.visit(stmt) for stmt in body]

        # The nested inliner appends to the same import list and continues
        # the same counter; restarting at zero would hand out prefixes and
        # temporaries that are already in use at this level.
        inliner = _Inliner(func.__globals__, self.imports)
        inliner._count = self._count

        inlined_body = []

        for stmt in body:
            inlined_stmt = inliner.visit(stmt)

            if isinstance(inlined_stmt, list):
                inlined_body.extend(inlined_stmt)
            else:
                inlined_body.append(inlined_stmt)

        self._count = inliner._count

        if not inlined_body or not isinstance(inlined_body[-1], ast.Return):
            return ast.Constant(value=None), inlined_body

        ret = inlined_body.pop()

        # A trailing bare ``return`` has no value to store in a temporary.
        if ret.value is None:
            return ast.Constant(value=None), inlined_body

        temp = self._make_temporary()
        assignment = ast.Assign(
            targets=[ast.Name(id=temp, ctx=ast.Store())], value=ret.value
        )
        inlined_body.append(assignment)

        return ast.Name(id=temp, ctx=ast.Load()), inlined_body

    def _make_temporary(self):
        """Return a new auto-generated name and advance the counter."""

        name = naming.auto_generate(f"temporary_{self._count}")
        self._count += 1

        return name

    @staticmethod
    def _resolve_function(node, globals):
        """Look up the object named by a ``Name`` or ``Attribute`` chain.

        :return: The resolved object, or ``None`` when it is not found.
        """

        if isinstance(node, ast.Name):
            return globals.get(node.id)

        if isinstance(node, ast.Attribute):
            obj = _Inliner._resolve_function(node.value, globals)

            if obj is not None:
                return getattr(obj, node.attr, None)

        return None

    @staticmethod
    def _get_source(func):
        """Return the source of ``func``, or ``None`` when unavailable."""

        try:
            return inspect.getsource(func)
        except (TypeError, OSError):
            # ``TypeError`` is raised for built-in objects and ``OSError``
            # when the source cannot be retrieved.
            return None

    @staticmethod
    def _find_function_definition(source, name):
        """Return what ``_FunctionDefFinder`` finds for ``name`` in ``source``."""

        finder = _FunctionDefFinder(name)
        finder.visit(ast.parse(source))

        return finder.result

    @staticmethod
    def _find_assigned_names(stmts):
        """Return the set of names that appear in a store context."""

        names = set()

        for stmt in stmts:
            for node in ast.walk(stmt):
                if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
                    names.add(node.id)

        return names
|
912
|
+
|
913
|
+
|
666
914
|
class _BinOpSimplifier(ast.NodeTransformer):
|
667
915
|
def visit_BinOp(self, node):
|
668
916
|
self.generic_visit(node)
|
ninetoothed/symbol.py
CHANGED
@@ -140,7 +140,12 @@ class Symbol:
|
|
140
140
|
return ast.unparse(self._node)
|
141
141
|
|
142
142
|
def find_and_replace(self, target, replacement):
    """Replace occurrences of ``target`` in this symbol's expression.

    :param target: A symbol, or a tuple of symbols, to look for.
    :param replacement: The symbol whose node is inserted for each match.
    :return: A symbol wrapping the result of running
        ``_FindAndReplacer`` over this symbol's node.
    """
    # A single symbol is handled as a one-element tuple.
    candidates = target if isinstance(target, tuple) else (target,)
    target_nodes = tuple(candidate.node for candidate in candidates)

    # NOTE(review): the transformer is given this symbol's own node, not a
    # copy -- callers that need the original untouched should copy first.
    replacer = _FindAndReplacer(target_nodes, replacement.node)

    return Symbol(replacer.visit(self._node))
|
144
149
|
|
145
150
|
def names(self):
|
146
151
|
class NameCollector(ast.NodeVisitor):
|
@@ -175,12 +180,14 @@ class Symbol:
|
|
175
180
|
|
176
181
|
|
177
182
|
class _FindAndReplacer(ast.NodeTransformer):
    """Replace every node whose unparsed text matches one of the targets.

    :param targets: The AST nodes to look for. They are compared by the
        text ``ast.unparse`` produces for them.
    :param replacement: The node returned in place of each match.
    """

    def __init__(self, targets, replacement):
        # The unparsed targets are only ever membership-tested, so they
        # are kept in a frozenset instead of a length-sorted tuple.
        self._targets_unparsed = frozenset(ast.unparse(target) for target in targets)
        self._replacement = replacement

    def visit(self, node):
        # Overriding ``visit`` itself applies the comparison to every node
        # type before the regular dispatch takes place.
        if ast.unparse(node) in self._targets_unparsed:
            return self._replacement

        return super().visit(node)
|
ninetoothed/tensor.py
CHANGED
@@ -37,6 +37,8 @@ class Tensor:
|
|
37
37
|
source_dims=None,
|
38
38
|
target=None,
|
39
39
|
target_dims=None,
|
40
|
+
unflattened=None,
|
41
|
+
unflattened_dims=None,
|
40
42
|
):
|
41
43
|
self.dtype = dtype
|
42
44
|
|
@@ -81,14 +83,26 @@ class Tensor:
|
|
81
83
|
else:
|
82
84
|
self.target_dims = (dim for dim in range(self.target.ndim))
|
83
85
|
|
86
|
+
if unflattened is not None:
|
87
|
+
self.unflattened = unflattened
|
88
|
+
else:
|
89
|
+
self.unflattened = self
|
90
|
+
|
91
|
+
if unflattened_dims is not None:
|
92
|
+
self.unflattened_dims = unflattened_dims
|
93
|
+
else:
|
94
|
+
self.unflattened_dims = (dim for dim in range(self.unflattened.ndim))
|
95
|
+
|
84
96
|
type(self).num_instances += 1
|
85
97
|
|
86
|
-
def tile(self, tile_shape, strides=None, dilation=None):
|
98
|
+
def tile(self, tile_shape, strides=None, dilation=None, floor_mode=False):
|
87
99
|
"""Tiles the tensor into a hierarchical tensor.
|
88
100
|
|
89
101
|
:param tile_shape: The shape of a tile.
|
90
102
|
:param strides: The interval at which each tile is generated.
|
91
103
|
:param dilation: The spacing between tiles.
|
104
|
+
:param floor_mode: If ``True``, will use floor division to
|
105
|
+
compute the outer shape.
|
92
106
|
:return: A hierarchical tensor.
|
93
107
|
"""
|
94
108
|
|
@@ -112,11 +126,21 @@ class Tensor:
|
|
112
126
|
if stride == -1:
|
113
127
|
stride = tile_size
|
114
128
|
|
115
|
-
def
|
129
|
+
def _div(x, y, floor_mode=False):
|
130
|
+
if floor_mode:
|
131
|
+
return x // y
|
132
|
+
|
116
133
|
return (x + y - 1) // y
|
117
134
|
|
118
135
|
new_size = (
|
119
|
-
(
|
136
|
+
(
|
137
|
+
_div(
|
138
|
+
self_size - spacing * (tile_size - 1) - 1,
|
139
|
+
stride,
|
140
|
+
floor_mode=floor_mode,
|
141
|
+
)
|
142
|
+
+ 1
|
143
|
+
)
|
120
144
|
if stride != 0
|
121
145
|
else -1
|
122
146
|
)
|
@@ -137,10 +161,14 @@ class Tensor:
|
|
137
161
|
strides=inner_strides,
|
138
162
|
source=self.source,
|
139
163
|
source_dims=self.source_dims,
|
164
|
+
unflattened=self.unflattened,
|
165
|
+
unflattened_dims=self.unflattened_dims,
|
140
166
|
),
|
141
167
|
strides=outer_strides,
|
142
168
|
source=self.source,
|
143
169
|
source_dims=self.source_dims,
|
170
|
+
unflattened=self.unflattened,
|
171
|
+
unflattened_dims=self.unflattened_dims,
|
144
172
|
)
|
145
173
|
|
146
174
|
def expand(self, shape):
|
@@ -164,6 +192,8 @@ class Tensor:
|
|
164
192
|
source=self.source,
|
165
193
|
source_dims=self.source_dims,
|
166
194
|
target_dims=self.target_dims,
|
195
|
+
unflattened=self.unflattened,
|
196
|
+
unflattened_dims=self.unflattened_dims,
|
167
197
|
)
|
168
198
|
|
169
199
|
def squeeze(self, dim):
|
@@ -192,6 +222,12 @@ class Tensor:
|
|
192
222
|
for i, target_dim in enumerate(self.target_dims)
|
193
223
|
if i not in dim
|
194
224
|
],
|
225
|
+
unflattened=self.unflattened,
|
226
|
+
unflattened_dims=[
|
227
|
+
unflattened_dim
|
228
|
+
for i, unflattened_dim in enumerate(self.unflattened_dims)
|
229
|
+
if i not in dim
|
230
|
+
],
|
195
231
|
)
|
196
232
|
|
197
233
|
def permute(self, dims):
|
@@ -205,11 +241,13 @@ class Tensor:
|
|
205
241
|
new_shape = [None for _ in range(self.ndim)]
|
206
242
|
new_strides = [None for _ in range(self.ndim)]
|
207
243
|
new_source_dims = [None for _ in range(self.ndim)]
|
244
|
+
new_unflattened_dims = [None for _ in range(self.ndim)]
|
208
245
|
|
209
246
|
for original_dim, permuted_dim in enumerate(dims):
|
210
247
|
new_shape[original_dim] = self.shape[permuted_dim]
|
211
248
|
new_strides[original_dim] = self.strides[permuted_dim]
|
212
249
|
new_source_dims[original_dim] = self.source_dims[permuted_dim]
|
250
|
+
new_unflattened_dims[original_dim] = self.unflattened_dims[permuted_dim]
|
213
251
|
|
214
252
|
return type(self)(
|
215
253
|
shape=new_shape,
|
@@ -218,6 +256,8 @@ class Tensor:
|
|
218
256
|
source=self.source,
|
219
257
|
source_dims=new_source_dims,
|
220
258
|
target_dims=self.target_dims,
|
259
|
+
unflattened=self.unflattened,
|
260
|
+
unflattened_dims=new_unflattened_dims,
|
221
261
|
)
|
222
262
|
|
223
263
|
def flatten(self, start_dim=None, end_dim=None):
|
@@ -265,6 +305,16 @@ class Tensor:
|
|
265
305
|
leading_target_dims + (flattening_target_dims[-1],) + trailing_target_dims
|
266
306
|
)
|
267
307
|
|
308
|
+
leading_unflattened_dims = self.unflattened_dims[:start_dim]
|
309
|
+
flattening_unflattened_dims = self.unflattened_dims[start_dim:end_dim]
|
310
|
+
trailing_unflattened_dims = self.unflattened_dims[end_dim:]
|
311
|
+
|
312
|
+
new_unflattened_dims = (
|
313
|
+
leading_unflattened_dims
|
314
|
+
+ (flattening_unflattened_dims,)
|
315
|
+
+ trailing_unflattened_dims
|
316
|
+
)
|
317
|
+
|
268
318
|
return type(self)(
|
269
319
|
shape=new_shape,
|
270
320
|
dtype=self.dtype,
|
@@ -272,6 +322,8 @@ class Tensor:
|
|
272
322
|
source=self.source,
|
273
323
|
source_dims=new_source_dims,
|
274
324
|
target_dims=new_target_dims,
|
325
|
+
unflattened=self.unflattened,
|
326
|
+
unflattened_dims=new_unflattened_dims,
|
275
327
|
)
|
276
328
|
|
277
329
|
def ravel(self):
|
@@ -290,12 +342,14 @@ class Tensor:
|
|
290
342
|
# TODO: Add error handling.
|
291
343
|
new_shape = []
|
292
344
|
new_strides = []
|
345
|
+
new_source_dims = []
|
293
346
|
|
294
347
|
curr = self
|
295
348
|
|
296
349
|
while isinstance(curr, type(self)):
|
297
350
|
new_shape.extend(curr.shape)
|
298
351
|
new_strides.extend(curr.strides)
|
352
|
+
new_source_dims.extend(curr.source_dims)
|
299
353
|
|
300
354
|
curr = curr.dtype
|
301
355
|
|
@@ -304,6 +358,8 @@ class Tensor:
|
|
304
358
|
strides=new_strides,
|
305
359
|
other=self.source.other,
|
306
360
|
name=self.source.name,
|
361
|
+
source=self.source,
|
362
|
+
source_dims=new_source_dims,
|
307
363
|
)
|
308
364
|
|
309
365
|
def names(self):
|
@@ -385,6 +441,14 @@ class Tensor:
|
|
385
441
|
def target_dims(self, value):
|
386
442
|
self._target_dims = tuple(value)
|
387
443
|
|
444
|
+
@property
def unflattened_dims(self):
    """Return the stored unflattened dimensions as a tuple."""
    return self._unflattened_dims

@unflattened_dims.setter
def unflattened_dims(self, value):
    """Store ``value`` as a tuple, so any iterable may be assigned."""
    dims = tuple(value)
    self._unflattened_dims = dims
|
451
|
+
|
388
452
|
@staticmethod
|
389
453
|
def pointer_pattern():
|
390
454
|
return re.compile(rf"({_identifier_pattern_raw_string()})_(pointer)")
|
ninetoothed/visualization.py
CHANGED
@@ -4,6 +4,12 @@ from mpl_toolkits.axes_grid1 import Divider, Size
|
|
4
4
|
|
5
5
|
|
6
6
|
def visualize(tensor, color=None, save_path=None):
|
7
|
+
"""Visualize a tensor as a structured grid representation.
|
8
|
+
|
9
|
+
:param tensor: The tensor to be visualized.
|
10
|
+
:param color: The color to be used for visualization.
|
11
|
+
:param save_path: The path where the visualization should be saved.
|
12
|
+
"""
|
7
13
|
outline_width = 0.1
|
8
14
|
plt.rcParams["lines.linewidth"] = 72 * outline_width
|
9
15
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.12.0
|
3
|
+
Version: 0.14.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -11,6 +11,9 @@ Classifier: Operating System :: OS Independent
|
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
12
12
|
Requires-Python: >=3.10
|
13
13
|
Requires-Dist: triton>=3.0.0
|
14
|
+
Provides-Extra: all
|
15
|
+
Requires-Dist: matplotlib>=3.9.0; extra == 'all'
|
16
|
+
Requires-Dist: numpy>=2.1.0; extra == 'all'
|
14
17
|
Provides-Extra: visualization
|
15
18
|
Requires-Dist: matplotlib>=3.9.0; extra == 'visualization'
|
16
19
|
Requires-Dist: numpy>=2.1.0; extra == 'visualization'
|
@@ -0,0 +1,12 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
|
2
|
+
ninetoothed/jit.py,sha256=wsj9RmaQ1dKPJXtLL4zcWpinwh-7km7-wA7pexy7vq4,31675
|
3
|
+
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
+
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
5
|
+
ninetoothed/symbol.py,sha256=UpGmx_jvaDtowADnp1DwYC3fvBXSiaMiYpU-ewkVo50,5261
|
6
|
+
ninetoothed/tensor.py,sha256=W1XY8_vaYmszX4lIWuas-ZKGbbdEZU7Z5h1A4FBXDXg,14358
|
7
|
+
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
8
|
+
ninetoothed/visualization.py,sha256=IZ7iTT4dl5_JFbO-WfSWPFWpgkyPr4nylwhSZVy8gss,3601
|
9
|
+
ninetoothed-0.14.0.dist-info/METADATA,sha256=XAI4ibJdwJDCPCpoOeFOT0XWEDQGUZi6OvJwyjgcEC8,7311
|
10
|
+
ninetoothed-0.14.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
11
|
+
ninetoothed-0.14.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
12
|
+
ninetoothed-0.14.0.dist-info/RECORD,,
|
@@ -1,12 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
|
2
|
-
ninetoothed/jit.py,sha256=U3Nen5vyx69ulW7_hnRuATW86Ag9NgVgd3U02NVB20c,24430
|
3
|
-
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
-
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
5
|
-
ninetoothed/symbol.py,sha256=mN96tp-2eUxbiNfxuxtKWNSxOSdYqlcmpY2MYQ-FiEg,4993
|
6
|
-
ninetoothed/tensor.py,sha256=E63sq3jh7ZLiLwFYTtavztKEZx7kRX-UVa2ZXSP2X0s,12008
|
7
|
-
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
8
|
-
ninetoothed/visualization.py,sha256=VPPh__Bral_Z9hKj9D4UOo8HvRFQidCWSe9cS-D5QfY,3351
|
9
|
-
ninetoothed-0.12.0.dist-info/METADATA,sha256=gb5zeAxwYQRm-dFNmdX3uDIn_U-abEZ-VU9bSBBxioA,7198
|
10
|
-
ninetoothed-0.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
11
|
-
ninetoothed-0.12.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
12
|
-
ninetoothed-0.12.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|