ninetoothed-0.13.0-py3-none-any.whl → ninetoothed-0.14.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -106,13 +106,18 @@ class JIT:
106
106
  def _get_tree(self):
107
107
  module = ast.parse(inspect.getsource(inspect.getmodule(self.func)))
108
108
 
109
- _AliasRestorer().visit(module)
110
109
  collector = _ImportCollector()
111
110
  collector.visit(module)
111
+
112
112
  finder = _FunctionDefFinder(self.func.__name__)
113
113
  finder.visit(module)
114
+ func_def = finder.result
115
+
116
+ inliner = _Inliner(self.func.__globals__)
117
+ inliner.visit(func_def)
118
+ module.body = collector.imports + inliner.imports + [finder.result]
114
119
 
115
- return ast.Module(body=collector.imports + [finder.result], type_ignores=[])
120
+ return _AliasRestorer().visit(module)
116
121
 
117
122
  def _find_dependencies(self):
118
123
  dependencies = set()
@@ -701,6 +706,211 @@ class Tritonizer(ast.NodeTransformer):
701
706
  return node
702
707
 
703
708
 
709
+ class _Inliner(ast.NodeTransformer):
710
+ def __init__(self, globals, imports=[]):
711
+ self._globals = globals
712
+
713
+ self._count = 0
714
+
715
+ self.imports = imports
716
+
717
+ def visit_Expr(self, node):
718
+ value, stmts = self._inline_expr(node.value)
719
+ node.value = value
720
+ node = self.generic_visit(node)
721
+
722
+ if stmts:
723
+ if isinstance(value, ast.Constant) and value.value is None:
724
+ return stmts
725
+
726
+ return stmts + [node]
727
+
728
+ return node
729
+
730
+ def visit_Assign(self, node):
731
+ value, stmts = self._inline_expr(node.value)
732
+ node.value = value
733
+ node = self.generic_visit(node)
734
+
735
+ if stmts:
736
+ return stmts + [node]
737
+
738
+ return node
739
+
740
+ def visit_Return(self, node):
741
+ if node.value:
742
+ value, stmts = self._inline_expr(node.value)
743
+ node.value = value
744
+
745
+ if stmts:
746
+ return stmts + [node]
747
+
748
+ return node
749
+
750
+ def _inline_expr(self, expr):
751
+ def _inline_list(lst):
752
+ new_list = []
753
+ new_stmts = []
754
+
755
+ for expr in lst:
756
+ expr, stmts = self._inline_expr(expr)
757
+
758
+ new_list.append(expr)
759
+ new_stmts.extend(stmts)
760
+
761
+ return new_list, new_stmts
762
+
763
+ def _inline_field(field):
764
+ if isinstance(field, ast.AST):
765
+ return self._inline_expr(field)
766
+
767
+ return field, []
768
+
769
+ if isinstance(expr, ast.Call):
770
+ new_expr, new_stmts = self._inline_call(expr)
771
+
772
+ if new_expr is not None:
773
+ return new_expr, new_stmts
774
+
775
+ new_stmts = []
776
+
777
+ for field, value in ast.iter_fields(expr):
778
+ if isinstance(value, list):
779
+ new_value, new_stmts = _inline_list(value)
780
+ else:
781
+ new_value, new_stmts = _inline_field(value)
782
+
783
+ setattr(expr, field, new_value)
784
+ new_stmts.extend(new_stmts)
785
+
786
+ return expr, new_stmts
787
+
788
+ def _inline_call(self, node):
789
+ class _ParameterReplacer(ast.NodeTransformer):
790
+ def __init__(self, mapping):
791
+ self._mapping = mapping
792
+
793
+ def visit_Name(self, node):
794
+ return self._mapping.get(node.id, node)
795
+
796
+ class _LocalVariableRenamer(ast.NodeTransformer):
797
+ def __init__(self, prefix, local_vars):
798
+ self._prefix = prefix
799
+
800
+ self._local_vars = local_vars
801
+
802
+ def visit_Name(self, node):
803
+ if node.id in self._local_vars:
804
+ node.id = f"{self._prefix}{node.id}"
805
+
806
+ return node
807
+
808
+ def visit_arg(self, node):
809
+ return node
810
+
811
+ def _resolve_function(node, globals):
812
+ if isinstance(node, ast.Name):
813
+ return globals.get(node.id)
814
+
815
+ if isinstance(node, ast.Attribute):
816
+ obj = _resolve_function(node.value, globals)
817
+
818
+ if obj is not None:
819
+ return getattr(obj, node.attr, None)
820
+
821
+ return None
822
+
823
+ def _get_source(func):
824
+ try:
825
+ return inspect.getsource(func)
826
+ except TypeError:
827
+ return None
828
+
829
+ def _find_function_definition(source):
830
+ finder = _FunctionDefFinder(func.__name__)
831
+ finder.visit(ast.parse(source))
832
+
833
+ return finder.result
834
+
835
+ def _find_assigned_names(stmts):
836
+ class _AssignedNameFinder(ast.NodeVisitor):
837
+ def __init__(self):
838
+ self.result = set()
839
+
840
+ def visit_Name(self, node):
841
+ if isinstance(node.ctx, ast.Store):
842
+ self.result.add(node.id)
843
+
844
+ names = set()
845
+
846
+ for stmt in stmts:
847
+ finder = _AssignedNameFinder()
848
+ finder.visit(stmt)
849
+ names |= finder.result
850
+
851
+ return names
852
+
853
+ def _make_temporary():
854
+ prefix = naming.auto_generate(f"temporary_{self._count}")
855
+ self._count += 1
856
+
857
+ return prefix
858
+
859
+ func = _resolve_function(node.func, self._globals)
860
+
861
+ if func is None:
862
+ return None, []
863
+
864
+ source = _get_source(func)
865
+
866
+ if source is None:
867
+ return None, []
868
+
869
+ func_def = _find_function_definition(source)
870
+
871
+ if func_def is None:
872
+ return None, []
873
+
874
+ collector = _ImportCollector()
875
+ collector.visit(ast.parse(inspect.getsource(inspect.getmodule(func))))
876
+ self.imports.extend(collector.imports)
877
+
878
+ param_names = [arg.arg for arg in func_def.args.args]
879
+
880
+ mapping = {param: arg for param, arg in zip(param_names, node.args)}
881
+ param_replacer = _ParameterReplacer(mapping)
882
+ body = [param_replacer.visit(stmt) for stmt in func_def.body]
883
+
884
+ local_vars = _find_assigned_names(body) - set(param_names)
885
+ prefix = _make_temporary()
886
+ local_var_renamer = _LocalVariableRenamer(prefix, local_vars)
887
+ body = [local_var_renamer.visit(stmt) for stmt in body]
888
+
889
+ inlined_body = []
890
+
891
+ inliner = _Inliner(func.__globals__)
892
+
893
+ for stmt in body:
894
+ inlined_stmt = inliner.visit(stmt)
895
+
896
+ if isinstance(inlined_stmt, list):
897
+ inlined_body.extend(inlined_stmt)
898
+ else:
899
+ inlined_body.append(inlined_stmt)
900
+
901
+ if not inlined_body or not isinstance(inlined_body[-1], ast.Return):
902
+ return ast.Constant(value=None), inlined_body
903
+
904
+ ret = inlined_body.pop()
905
+ temp = _make_temporary()
906
+ assignment = ast.Assign(
907
+ targets=[ast.Name(id=temp, ctx=ast.Store())], value=ret.value
908
+ )
909
+ inlined_body.append(assignment)
910
+
911
+ return ast.Name(id=temp, ctx=ast.Load()), inlined_body
912
+
913
+
704
914
  class _BinOpSimplifier(ast.NodeTransformer):
705
915
  def visit_BinOp(self, node):
706
916
  self.generic_visit(node)
ninetoothed/tensor.py CHANGED
@@ -95,12 +95,14 @@ class Tensor:
95
95
 
96
96
  type(self).num_instances += 1
97
97
 
98
- def tile(self, tile_shape, strides=None, dilation=None):
98
+ def tile(self, tile_shape, strides=None, dilation=None, floor_mode=False):
99
99
  """Tiles the tensor into a hierarchical tensor.
100
100
 
101
101
  :param tile_shape: The shape of a tile.
102
102
  :param strides: The interval at which each tile is generated.
103
103
  :param dilation: The spacing between tiles.
104
+ :param floor_mode: If ``True``, will use floor division to
105
+ compute the outer shape.
104
106
  :return: A hierarchical tensor.
105
107
  """
106
108
 
@@ -124,11 +126,21 @@ class Tensor:
124
126
  if stride == -1:
125
127
  stride = tile_size
126
128
 
127
- def cdiv(x, y):
129
+ def _div(x, y, floor_mode=False):
130
+ if floor_mode:
131
+ return x // y
132
+
128
133
  return (x + y - 1) // y
129
134
 
130
135
  new_size = (
131
- (cdiv(self_size - spacing * (tile_size - 1) - 1, stride) + 1)
136
+ (
137
+ _div(
138
+ self_size - spacing * (tile_size - 1) - 1,
139
+ stride,
140
+ floor_mode=floor_mode,
141
+ )
142
+ + 1
143
+ )
132
144
  if stride != 0
133
145
  else -1
134
146
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.13.0
3
+ Version: 0.14.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -1,12 +1,12 @@
1
1
  ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
- ninetoothed/jit.py,sha256=fuMfb4gEvz78n-7--GS1Ud14g33SLT-vI1kqt9BPAIQ,25759
2
+ ninetoothed/jit.py,sha256=wsj9RmaQ1dKPJXtLL4zcWpinwh-7km7-wA7pexy7vq4,31675
3
3
  ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
4
  ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
5
5
  ninetoothed/symbol.py,sha256=UpGmx_jvaDtowADnp1DwYC3fvBXSiaMiYpU-ewkVo50,5261
6
- ninetoothed/tensor.py,sha256=Q7WPigNyGOgP5lYYac39pF6zlFbCyXYrICPgGOuuyr4,13976
6
+ ninetoothed/tensor.py,sha256=W1XY8_vaYmszX4lIWuas-ZKGbbdEZU7Z5h1A4FBXDXg,14358
7
7
  ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
8
8
  ninetoothed/visualization.py,sha256=IZ7iTT4dl5_JFbO-WfSWPFWpgkyPr4nylwhSZVy8gss,3601
9
- ninetoothed-0.13.0.dist-info/METADATA,sha256=Vp_huL1YVgNAa9q6C0J5hZnCR7SRyMFRALGD21ikrjc,7311
10
- ninetoothed-0.13.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
11
- ninetoothed-0.13.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
12
- ninetoothed-0.13.0.dist-info/RECORD,,
9
+ ninetoothed-0.14.0.dist-info/METADATA,sha256=XAI4ibJdwJDCPCpoOeFOT0XWEDQGUZi6OvJwyjgcEC8,7311
10
+ ninetoothed-0.14.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
11
+ ninetoothed-0.14.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
12
+ ninetoothed-0.14.0.dist-info/RECORD,,