ninetoothed 0.13.0__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/jit.py +212 -2
- ninetoothed/tensor.py +15 -3
- {ninetoothed-0.13.0.dist-info → ninetoothed-0.14.0.dist-info}/METADATA +1 -1
- {ninetoothed-0.13.0.dist-info → ninetoothed-0.14.0.dist-info}/RECORD +6 -6
- {ninetoothed-0.13.0.dist-info → ninetoothed-0.14.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.13.0.dist-info → ninetoothed-0.14.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/jit.py
CHANGED
@@ -106,13 +106,18 @@ class JIT:
|
|
106
106
|
def _get_tree(self):
    """Return a self-contained module AST for ``self.func``.

    The module that defines ``self.func`` is parsed; the import statements
    gathered from it by ``_ImportCollector`` are kept, and the function's
    definition is located with ``_FunctionDefFinder``. Calls in that
    definition are expanded by ``_Inliner`` (whose gathered imports are
    placed after the module's own), and the resulting module is passed
    through ``_AliasRestorer`` before being returned.
    """
    defining_module = inspect.getmodule(self.func)
    module = ast.parse(inspect.getsource(defining_module))

    import_collector = _ImportCollector()
    import_collector.visit(module)

    def_finder = _FunctionDefFinder(self.func.__name__)
    def_finder.visit(module)
    func_def = def_finder.result

    # The inliner rewrites ``func_def`` in place; only the imports it
    # gathered are needed afterwards.
    inliner = _Inliner(self.func.__globals__)
    inliner.visit(func_def)

    # Keep only the imports and the target function; every other top-level
    # statement of the defining module is dropped.
    module.body = import_collector.imports + inliner.imports + [func_def]

    return _AliasRestorer().visit(module)
|
116
121
|
|
117
122
|
def _find_dependencies(self):
|
118
123
|
dependencies = set()
|
@@ -701,6 +706,211 @@ class Tritonizer(ast.NodeTransformer):
|
|
701
706
|
return node
|
702
707
|
|
703
708
|
|
709
|
+
class _Inliner(ast.NodeTransformer):
    """Inline calls to plain Python functions into the visited AST.

    Calls found in expression statements, assignments and ``return``
    statements are resolved against ``globals``. When the callee is an
    undecorated plain Python function whose source is available and whose
    call can be bound statically, its body is spliced in before the
    statement. In that body, local variables are renamed to fresh names and
    parameters are replaced by copies of the call's argument expressions.
    The result, if any, is stored in a fresh temporary that replaces the
    call. Calls inside the callee are inlined recursively. Every other call
    is left untouched.

    Import statements collected from the modules of inlined functions are
    appended to ``self.imports``.
    """

    # Expressions that open their own scope: hoisting a call out of them
    # could reference names that only exist inside that scope.
    _SCOPED_EXPRS = (
        ast.Lambda,
        ast.ListComp,
        ast.SetComp,
        ast.DictComp,
        ast.GeneratorExp,
    )

    def __init__(self, globals, imports=None):
        """Create an inliner.

        :param globals: Namespace used to resolve called names, typically
            the ``__globals__`` of the function being visited.
        :param imports: Optional list to which collected import statements
            are appended. A fresh list is created when omitted.
        """
        self._globals = globals

        # Counter used to generate unique temporary names.
        self._count = 0

        # A new list per instance: a shared mutable default would leak the
        # imports of one compilation into every later one.
        self.imports = [] if imports is None else imports

    def visit_Expr(self, node):
        value, stmts = self._inline_expr(node.value)
        node.value = value
        node = self.generic_visit(node)

        if not stmts:
            return node

        # A call to a function without a result inlines to a bare ``None``;
        # that no-op expression statement is dropped.
        if isinstance(value, ast.Constant) and value.value is None:
            return stmts

        return stmts + [node]

    def visit_Assign(self, node):
        value, stmts = self._inline_expr(node.value)
        node.value = value
        node = self.generic_visit(node)

        return stmts + [node] if stmts else node

    def visit_Return(self, node):
        if node.value is None:
            return node

        value, stmts = self._inline_expr(node.value)
        node.value = value

        return stmts + [node] if stmts else node

    def _inline_expr(self, expr):
        """Inline the calls in ``expr``.

        :param expr: The expression (or other AST node) to process; it may
            be modified in place.
        :return: A pair ``(new_expr, stmts)``; ``stmts`` must be executed,
            in order, before ``new_expr`` is evaluated.
        """
        if isinstance(expr, self._SCOPED_EXPRS):
            return expr, []

        # Children (including a call's arguments) are processed first, with
        # the caller's globals, so calls in arguments are inlined once and
        # ahead of the callee's body, matching Python's evaluation order.
        stmts = self._inline_fields(expr)

        if isinstance(expr, ast.Call):
            new_expr, call_stmts = self._inline_call(expr)

            if new_expr is not None:
                return new_expr, stmts + call_stmts

        return expr, stmts

    def _inline_fields(self, node):
        """Inline calls in every child of ``node``, updating it in place.

        :return: The statements to run before ``node`` is evaluated, in
            field order.
        """
        stmts = []

        for field, value in ast.iter_fields(node):
            if isinstance(value, list):
                new_value = []

                for item in value:
                    # Lists may hold non-AST entries, e.g. ``None`` keys of
                    # ``**`` unpacking in a dict display.
                    if isinstance(item, ast.AST):
                        item, item_stmts = self._inline_expr(item)
                        stmts.extend(item_stmts)

                    new_value.append(item)
            elif isinstance(value, ast.AST):
                new_value, value_stmts = self._inline_expr(value)
                stmts.extend(value_stmts)
            else:
                new_value = value

            setattr(node, field, new_value)

        return stmts

    def _inline_call(self, node):
        """Try to inline the call ``node``.

        :return: ``(expr, stmts)``, where ``stmts`` is the inlined body and
            ``expr`` evaluates to the call's result, or ``(None, [])`` when
            the call cannot be inlined safely.
        """
        func = self._resolve_function(node.func, self._globals)

        # Only undecorated plain Python functions are inlined. Classes,
        # methods, builtins and callable objects have no body to splice in.
        # For ``functools.wraps``-decorated functions ``inspect`` returns
        # the undecorated source, so inlining would drop the decorator.
        if not inspect.isfunction(func) or hasattr(func, "__wrapped__"):
            return None, []

        source = self._get_source(func)

        if source is None:
            return None, []

        func_def = self._parse_function_def(source, func.__name__)

        if func_def is None:
            return None, []

        mapping = self._bind_arguments(func_def.args, node)

        if mapping is None:
            return None, []

        local_vars = self._find_assigned_names(func_def.body)

        # Parameters are substituted by the argument expressions, so a
        # parameter rebound in the body would assign to the caller's
        # expression.
        if local_vars.intersection(mapping):
            return None, []

        # A ``return`` anywhere but at the very end would return from the
        # caller once the body is spliced in.
        if not self._returns_only_at_end(func_def.body):
            return None, []

        module_source = self._get_source(inspect.getmodule(func))

        if module_source is None:
            return None, []

        collector = _ImportCollector()
        collector.visit(ast.parse(module_source))
        self.imports.extend(collector.imports)

        # Rename the callee's locals *before* substituting the arguments so
        # that names inside the caller's argument expressions are never
        # renamed.
        renamer = self._LocalVariableRenamer(self._make_temporary(), local_vars)
        replacer = self._ParameterReplacer(mapping)
        body = [replacer.visit(renamer.visit(stmt)) for stmt in func_def.body]

        # Share the import list and continue the temporary-name counter so
        # that names generated for nested calls never clash with ours.
        inliner = _Inliner(func.__globals__, imports=self.imports)
        inliner._count = self._count

        inlined_body = []

        for stmt in body:
            inlined_stmt = inliner.visit(stmt)

            if isinstance(inlined_stmt, list):
                inlined_body.extend(inlined_stmt)
            else:
                inlined_body.append(inlined_stmt)

        self._count = inliner._count

        if not inlined_body or not isinstance(inlined_body[-1], ast.Return):
            return ast.Constant(value=None), inlined_body

        ret = inlined_body.pop()

        # A bare ``return`` yields ``None``; there is nothing to assign.
        if ret.value is None:
            return ast.Constant(value=None), inlined_body

        temp = self._make_temporary()
        inlined_body.append(
            ast.Assign(targets=[ast.Name(id=temp, ctx=ast.Store())], value=ret.value)
        )

        return ast.Name(id=temp, ctx=ast.Load()), inlined_body

    def _make_temporary(self):
        """Return a fresh generated name and advance the counter."""
        name = naming.auto_generate(f"temporary_{self._count}")
        self._count += 1

        return name

    @staticmethod
    def _resolve_function(node, globals):
        """Resolve a ``Name`` or dotted ``Attribute`` chain in ``globals``.

        :return: The referenced object, or ``None`` if it cannot be resolved.
        """
        attrs = []

        while isinstance(node, ast.Attribute):
            attrs.append(node.attr)
            node = node.value

        if not isinstance(node, ast.Name):
            return None

        obj = globals.get(node.id)

        for attr in reversed(attrs):
            if obj is None:
                return None

            obj = getattr(obj, attr, None)

        return obj

    @staticmethod
    def _get_source(obj):
        """Return the source of ``obj``, or ``None`` when unavailable."""
        try:
            return inspect.getsource(obj)
        except (OSError, TypeError):
            # ``TypeError`` for builtins and non-source objects (including
            # ``None``); ``OSError`` when the source file cannot be read.
            return None

    @staticmethod
    def _parse_function_def(source, name):
        """Parse ``source`` and return the definition named ``name``, or ``None``."""
        import textwrap

        try:
            # Dedent so functions defined in an indented block still parse.
            tree = ast.parse(textwrap.dedent(source))
        except SyntaxError:
            return None

        finder = _FunctionDefFinder(name)
        finder.visit(tree)

        return finder.result

    @staticmethod
    def _bind_arguments(arguments, call):
        """Map parameter names of ``arguments`` to argument expressions of ``call``.

        Positional and keyword arguments are supported, as are omitted
        parameters whose default is a literal constant.

        :return: The mapping, or ``None`` if the call cannot be bound
            statically. That is the case for starred or ``**`` arguments,
            variadic or keyword-only parameters, unknown or duplicate
            keywords, too many arguments, or a missing argument without a
            constant default.
        """
        if arguments.vararg or arguments.kwarg or arguments.kwonlyargs:
            return None

        positional = [arg.arg for arg in (*arguments.posonlyargs, *arguments.args)]
        keyword_allowed = {arg.arg for arg in arguments.args}

        if len(call.args) > len(positional):
            return None

        if any(isinstance(arg, ast.Starred) for arg in call.args):
            return None

        mapping = dict(zip(positional, call.args))

        for keyword in call.keywords:
            # ``keyword.arg`` is ``None`` for ``**kwargs`` unpacking.
            if keyword.arg not in keyword_allowed or keyword.arg in mapping:
                return None

            mapping[keyword.arg] = keyword.value

        # Defaults belong to the last ``len(defaults)`` positional params.
        num_defaults = len(arguments.defaults)
        defaults = dict(
            zip(positional[len(positional) - num_defaults :], arguments.defaults)
        )

        for name in positional:
            if name in mapping:
                continue

            default = defaults.get(name)

            # Non-literal defaults are evaluated in the callee's module at
            # definition time and cannot be transplanted as an expression.
            if not isinstance(default, ast.Constant):
                return None

            mapping[name] = default

        return mapping

    @staticmethod
    def _find_assigned_names(stmts):
        """Return the names bound (``Store`` context) anywhere in ``stmts``."""
        return {
            node.id
            for stmt in stmts
            for node in ast.walk(stmt)
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store)
        }

    @staticmethod
    def _returns_only_at_end(body):
        """Return whether ``body`` has no ``return`` except a final top-level one."""

        class _ReturnFinder(ast.NodeVisitor):
            def __init__(self):
                self.found = False

            def visit_Return(self, node):
                self.found = True

            def visit_FunctionDef(self, node):
                # A nested function's returns belong to that function.
                pass

            visit_AsyncFunctionDef = visit_FunctionDef

        if body and isinstance(body[-1], ast.Return):
            body = body[:-1]

        finder = _ReturnFinder()

        for stmt in body:
            finder.visit(stmt)

        return not finder.found

    class _ParameterReplacer(ast.NodeTransformer):
        """Replace parameter names with copies of the bound argument expressions."""

        def __init__(self, mapping):
            self._mapping = mapping

        def visit_Name(self, node):
            import copy

            if node.id in self._mapping:
                # Copy so each use is a distinct node. Otherwise an in-place
                # transformation of one occurrence would also change the
                # others.
                return copy.deepcopy(self._mapping[node.id])

            return node

    class _LocalVariableRenamer(ast.NodeTransformer):
        """Prefix every occurrence of the given local variable names."""

        def __init__(self, prefix, local_vars):
            self._prefix = prefix

            self._local_vars = local_vars

        def visit_Name(self, node):
            if node.id in self._local_vars:
                node.id = f"{self._prefix}{node.id}"

            return node

        def visit_arg(self, node):
            return node
|
912
|
+
|
913
|
+
|
704
914
|
class _BinOpSimplifier(ast.NodeTransformer):
|
705
915
|
def visit_BinOp(self, node):
|
706
916
|
self.generic_visit(node)
|
ninetoothed/tensor.py
CHANGED
@@ -95,12 +95,14 @@ class Tensor:
|
|
95
95
|
|
96
96
|
type(self).num_instances += 1
|
97
97
|
|
98
|
-
def tile(self, tile_shape, strides=None, dilation=None):
|
98
|
+
def tile(self, tile_shape, strides=None, dilation=None, floor_mode=False):
|
99
99
|
"""Tiles the tensor into a hierarchical tensor.
|
100
100
|
|
101
101
|
:param tile_shape: The shape of a tile.
|
102
102
|
:param strides: The interval at which each tile is generated.
|
103
103
|
:param dilation: The spacing between tiles.
|
104
|
+
:param floor_mode: If ``True``, will use floor division to
|
105
|
+
compute the outer shape.
|
104
106
|
:return: A hierarchical tensor.
|
105
107
|
"""
|
106
108
|
|
@@ -124,11 +126,21 @@ class Tensor:
|
|
124
126
|
if stride == -1:
|
125
127
|
stride = tile_size
|
126
128
|
|
127
|
-
def
|
129
|
+
def _div(x, y, floor_mode=False):
|
130
|
+
if floor_mode:
|
131
|
+
return x // y
|
132
|
+
|
128
133
|
return (x + y - 1) // y
|
129
134
|
|
130
135
|
new_size = (
|
131
|
-
(
|
136
|
+
(
|
137
|
+
_div(
|
138
|
+
self_size - spacing * (tile_size - 1) - 1,
|
139
|
+
stride,
|
140
|
+
floor_mode=floor_mode,
|
141
|
+
)
|
142
|
+
+ 1
|
143
|
+
)
|
132
144
|
if stride != 0
|
133
145
|
else -1
|
134
146
|
)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.13.0
|
3
|
+
Version: 0.14.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -1,12 +1,12 @@
|
|
1
1
|
ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
|
2
|
-
ninetoothed/jit.py,sha256=
|
2
|
+
ninetoothed/jit.py,sha256=wsj9RmaQ1dKPJXtLL4zcWpinwh-7km7-wA7pexy7vq4,31675
|
3
3
|
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
4
|
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
5
5
|
ninetoothed/symbol.py,sha256=UpGmx_jvaDtowADnp1DwYC3fvBXSiaMiYpU-ewkVo50,5261
|
6
|
-
ninetoothed/tensor.py,sha256=
|
6
|
+
ninetoothed/tensor.py,sha256=W1XY8_vaYmszX4lIWuas-ZKGbbdEZU7Z5h1A4FBXDXg,14358
|
7
7
|
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
8
8
|
ninetoothed/visualization.py,sha256=IZ7iTT4dl5_JFbO-WfSWPFWpgkyPr4nylwhSZVy8gss,3601
|
9
|
-
ninetoothed-0.
|
10
|
-
ninetoothed-0.
|
11
|
-
ninetoothed-0.
|
12
|
-
ninetoothed-0.
|
9
|
+
ninetoothed-0.14.0.dist-info/METADATA,sha256=XAI4ibJdwJDCPCpoOeFOT0XWEDQGUZi6OvJwyjgcEC8,7311
|
10
|
+
ninetoothed-0.14.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
11
|
+
ninetoothed-0.14.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
12
|
+
ninetoothed-0.14.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|