tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
@@ -1,80 +1,62 @@
1
1
  from __future__ import annotations
2
- from abc import abstractmethod
3
2
  import functools
4
3
  from math import gcd
5
4
  from tinygrad.helpers import partition
6
- from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
5
+ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set, Mapping
7
6
 
8
7
  # NOTE: Python has different behavior for negative mod and floor div than c
9
8
  # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
10
9
 
11
- def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node))
12
-
13
10
  class Node:
14
11
  b: Union[Node, int]
15
12
  min: int
16
- max: int
17
- def render(self, ops=None, ctx=None, strip_parens=False) -> str:
13
+ max: sint
14
+ def render(self, ops=None, ctx=None) -> Any:
18
15
  if ops is None: ops = render_python
19
16
  assert self.__class__ in (Variable, NumNode) or self.min != self.max
20
- ret = ops[type(self)](self, ops, ctx)
21
- if strip_parens and ret[0] == '(' and ret[-1] == ')': ret = ret[1:-1]
22
- return ret
23
- def vars(self): return []
17
+ return ops[type(self)](self, ops, ctx)
18
+ def vars(self) -> Set[Variable]: return set()
19
+ # substitute Variables with the values in var_vals
20
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: raise RuntimeError(self.__class__.__name__)
21
+ def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
22
+
24
23
  @functools.cached_property
25
24
  def key(self) -> str: return self.render(ctx="DEBUG")
26
25
  @functools.cached_property
27
26
  def hash(self) -> int: return hash(self.key)
28
- def __repr__(self): return "<"+self.key+">"
27
+ def __repr__(self): return self.render(ctx="REPR")
28
+ def __str__(self): return "<"+self.key+">"
29
29
  def __hash__(self): return self.hash
30
30
  def __bool__(self): return not (self.max == self.min == 0)
31
31
  def __eq__(self, other:object) -> bool:
32
32
  if not isinstance(other, Node): return NotImplemented
33
33
  return self.key == other.key
34
34
  def __neg__(self): return self*-1
35
- def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
35
+ def __add__(self, b:Union[Node,int]): return Node.sum([self, NumNode(b) if isinstance(b, int) else b])
36
36
  def __radd__(self, b:int): return self+b
37
37
  def __sub__(self, b:Union[Node,int]): return self+-b
38
38
  def __rsub__(self, b:int): return -self+b
39
39
  def __le__(self, b:Union[Node,int]): return self < (b+1)
40
40
  def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
41
41
  def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
42
- def __lt__(self, b:Union[Node,int]):
43
- lhs = self
44
- if isinstance(lhs, SumNode) and isinstance(b, int):
45
- muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
46
- if muls:
47
- # NOTE: gcd in python 3.8 takes exactly 2 args
48
- mul_gcd = muls[0].b
49
- for x in muls[1:]: mul_gcd = gcd(mul_gcd, x.b)
50
- if b%mul_gcd == 0:
51
- all_others = Variable.sum(others)
52
- #print(mul_gcd, muls, all_others)
53
- if all_others.min >= 0 and all_others.max < mul_gcd:
54
- # TODO: should we divide both by mul_gcd here?
55
- lhs = Variable.sum(muls)
56
- return create_node(LtNode(lhs, b))
42
+ def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b))
57
43
  def __mul__(self, b:Union[Node, int]):
58
44
  if b == 0: return NumNode(0)
59
45
  if b == 1: return self
60
- if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b
61
46
  return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
62
47
  def __rmul__(self, b:int): return self*b
63
48
 
64
49
  # *** complex ops ***
65
50
 
66
- def __rfloordiv__(self, b:int):
67
- if self.min > b >= 0: return NumNode(0)
68
- if isinstance(self, NumNode): return NumNode(b // self.b)
69
- raise RuntimeError(f"not supported: {b} // {self}")
51
+ def __rfloordiv__(self, b:int): return NumNode(b) // self
70
52
  def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
71
53
  if isinstance(b, Node):
72
- if b.__class__ is NumNode: return self // b.b
54
+ if b.__class__ is NumNode: return self.__floordiv__(b.b, factoring_allowed)
73
55
  if self == b: return NumNode(1)
74
56
  if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
75
57
  raise RuntimeError(f"not supported: {self} // {b}")
76
58
  assert b != 0
77
- if b < 0: return (self//-b)*-1
59
+ if b < 0: return (self*-1).__floordiv__(-b, factoring_allowed)
78
60
  if b == 1: return self
79
61
 
80
62
  # the numerator of div is not allowed to be negative
@@ -84,10 +66,7 @@ class Node:
84
66
  return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
85
67
  return create_node(DivNode(self, b))
86
68
 
87
- def __rmod__(self, b:int):
88
- if self.min > b >= 0: return NumNode(b)
89
- if isinstance(self, NumNode): return NumNode(b % self.b)
90
- raise RuntimeError(f"not supported: {b} % {self}")
69
+ def __rmod__(self, b:int): return NumNode(b) % self
91
70
  def __mod__(self, b:Union[Node,int]):
92
71
  if isinstance(b, Node):
93
72
  if b.__class__ is NumNode: return self % b.b
@@ -96,37 +75,27 @@ class Node:
96
75
  raise RuntimeError(f"not supported: {self} % {b}")
97
76
  assert b > 0
98
77
  if b == 1: return NumNode(0)
99
- if self.min >= 0 and self.max < b: return self
100
- if self.min < 0: return (self - ((self.min//b)*b)) % b
78
+ if isinstance(self.max, int) and isinstance(self.min, int):
79
+ if self.min >= 0 and self.max < b: return self
80
+ if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
81
+ if self.min < 0: return (self - ((self.min//b)*b)) % b
101
82
  return create_node(ModNode(self, b))
102
83
 
103
- @staticmethod
104
- def num(num:int) -> NumNode: return NumNode(num)
105
-
106
- @staticmethod
107
- def factorize(nodes:List[Node]) -> List[Node]:
108
- mul_groups: Dict[Node, int] = {}
109
- for x in nodes:
110
- a,b = (x.a,x.b) if isinstance(x, MulNode) else (x,1)
111
- mul_groups[a] = mul_groups.get(a, 0) + b
112
- return [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
113
-
114
84
  @staticmethod
115
85
  def sum(nodes:List[Node]) -> Node:
116
86
  nodes = [x for x in nodes if x.max or x.min]
117
87
  if not nodes: return NumNode(0)
118
88
  if len(nodes) == 1: return nodes[0]
119
89
 
120
- new_nodes: List[Node] = []
90
+ mul_groups: Dict[Node, int] = {}
121
91
  num_node_sum = 0
122
92
  for node in SumNode(nodes).flat_components:
123
93
  if node.__class__ is NumNode: num_node_sum += node.b
124
- else: new_nodes.append(node)
125
-
126
- if len(new_nodes) > 1 and len(set([x.a if isinstance(x, MulNode) else x for x in new_nodes])) < len(new_nodes):
127
- new_nodes = Node.factorize(new_nodes)
94
+ elif node.__class__ is MulNode: mul_groups[node.a] = mul_groups.get(node.a, 0) + node.b
95
+ else: mul_groups[node] = mul_groups.get(node, 0) + 1
96
+ new_nodes = [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
128
97
  if num_node_sum: new_nodes.append(NumNode(num_node_sum))
129
- return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
98
+ return create_node(SumNode(new_nodes)) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
130
99
 
131
100
  @staticmethod
132
101
  def ands(nodes:List[Node]) -> Node:
@@ -136,48 +105,96 @@ class Node:
136
105
 
137
106
  # filter 1s
138
107
  nodes = [x for x in nodes if x.min != x.max]
139
- return create_rednode(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
108
+ return create_node(AndNode(nodes)) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
140
109
 
141
110
  # 4 basic node types
142
111
 
143
112
  class Variable(Node):
144
- def __new__(cls, expr:Optional[str], nmin:int, nmax:int):
145
- assert nmin >= 0 and nmin <= nmax
113
+ def __new__(cls, *args):
114
+ expr, nmin, nmax = args
115
+ assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
146
116
  if nmin == nmax: return NumNode(nmin)
147
117
  return super().__new__(cls)
148
118
 
149
- def __init__(self, expr:Optional[str], nmin:int, nmax:int):
119
+ def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling
120
+
121
+ def __init__(self, expr:str, nmin:int, nmax:sint):
150
122
  self.expr, self.min, self.max = expr, nmin, nmax
151
- def vars(self): return [self]
123
+ self._val: Optional[int] = None
124
+ @property
125
+ def val(self):
126
+ assert self._val is not None, f"Variable isn't bound, can't access val of {self}"
127
+ return self._val
128
+ def bind(self, val):
129
+ assert self._val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
130
+ self._val = val
131
+ return self
132
+ def unbind(self) -> Tuple[Variable, int]:
133
+ assert self.val is not None, f"cannot unbind {self}"
134
+ return Variable(self.expr, self.min, self.max), self.val
135
+ def vars(self): return {self}
136
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return var_vals.get(self, self)
152
137
 
153
138
  class NumNode(Node):
154
139
  def __init__(self, num:int):
140
+ assert isinstance(num, int), f"{num} is not an int"
155
141
  self.b:int = num
156
142
  self.min, self.max = num, num
157
- def __int__(self): return self.b
158
- def __index__(self): return self.b
143
+ def bind(self, val):
144
+ assert self.b == val, f"cannot bind {val} to {self}"
145
+ return self
146
+ def __mul__(self, b:Union[Node,int]): return NumNode(self.b*b) if isinstance(b, int) else b*self.b
159
147
  def __eq__(self, other): return self.b == other
160
- def __hash__(self): return self.hash # needed with __eq__ override
148
+ def __hash__(self): return hash(self.b) # needed with __eq__ override
149
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self
161
150
 
162
151
  def create_node(ret:Node):
163
152
  assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
164
153
  if ret.min == ret.max: return NumNode(ret.min)
165
154
  return ret
166
155
 
156
+ def create_lt_node(lhs:Node, b:Union[Node, int]):
157
+ if isinstance(lhs, SumNode):
158
+ if isinstance(b, int):
159
+ new_sum = []
160
+ for x in lhs.nodes:
161
+ # TODO: should we just force the last one to always be the number
162
+ if isinstance(x, NumNode): b -= x.b
163
+ else: new_sum.append(x)
164
+ lhs = Node.sum(new_sum)
165
+ nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
166
+ assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
167
+ muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
168
+ if muls:
169
+ # NOTE: gcd in python 3.8 takes exactly 2 args
170
+ mul_gcd = b
171
+ for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
172
+ all_others = Node.sum(others)
173
+ if all_others.min >= 0 and all_others.max < mul_gcd:
174
+ lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
175
+ return create_node(LtNode(lhs, b)) if isinstance(lhs, SumNode) else create_lt_node(lhs, b)
176
+ if isinstance(lhs, MulNode):
177
+ if isinstance(b, Node) or isinstance(lhs.b, Node) or lhs.b == -1: return create_node(LtNode(lhs, b))
178
+ sgn = 1 if lhs.b > 0 else -1
179
+ return create_node(LtNode(lhs.a*sgn, (b + abs(lhs.b) - 1)//abs(lhs.b)))
180
+ return create_node(LtNode(lhs, b))
181
+
182
+ def create_ge_node(lhs:Node, b:Union[Node, int]): return create_lt_node(-lhs, -b+1)
183
+
167
184
  class OpNode(Node):
168
185
  def __init__(self, a:Node, b:Union[Node, int]):
169
186
  self.a, self.b = a, b
170
187
  self.min, self.max = self.get_bounds()
171
- def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else [])
172
- @abstractmethod
173
- def get_bounds(self) -> Tuple[int, int]: pass
188
+ def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())
189
+ def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
174
190
 
175
191
  class LtNode(OpNode):
176
- def __mul__(self, b: Union[Node, int]): return (self.a*b) < (self.b*b)
177
- def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
178
192
  def get_bounds(self) -> Tuple[int, int]:
179
- if isinstance(self.b, int): return int(self.a.max < self.b), int(self.a.min < self.b)
180
- return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min > self.b.max else (0, 1)
193
+ if self.a == self.b: return (0, 0)
194
+ if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
195
+ return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
196
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
197
+ return create_lt_node(self.a.substitute(var_vals), (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)))
181
198
 
182
199
  class MulNode(OpNode):
183
200
  def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
@@ -185,57 +202,72 @@ class MulNode(OpNode):
185
202
  if self.b % b == 0: return self.a*(self.b//b)
186
203
  if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
187
204
  return Node.__floordiv__(self, b, factoring_allowed)
188
- def __mod__(self, b: Union[Node, int]):
189
- a = (self.a * (self.b%b))
190
- return Node.__mod__(a, b)
191
- def get_bounds(self) -> Tuple[int, int]:
192
- return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
205
+ def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
206
+ def get_bounds(self) -> Tuple[int, sint]:
207
+ assert self.a.min >= 0
208
+ if isinstance(self.b, int): return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
209
+ return (self.a.min*self.b.min, self.a.max*self.b.max) if self.b.min >= 0 else (self.a.max*self.b.min, self.a.min*self.b.max)
210
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
211
+ return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
193
212
 
194
213
  class DivNode(OpNode):
195
214
  def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
196
- def get_bounds(self) -> Tuple[int, int]:
215
+ def get_bounds(self) -> Tuple[int, sint]:
197
216
  assert self.a.min >= 0 and isinstance(self.b, int)
198
217
  return self.a.min//self.b, self.a.max//self.b
218
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) // self.b
199
219
 
200
220
  class ModNode(OpNode):
221
+ def __mod__(self, b: Union[Node, int]):
222
+ if isinstance(b, int) and isinstance(self.b, int) and self.b % b == 0: return self.a % b
223
+ return Node.__mod__(self, b)
201
224
  def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
202
- if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
203
- return Node.__floordiv__(self, b, factoring_allowed)
204
- def get_bounds(self) -> Tuple[int, int]:
225
+ return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
226
+ def get_bounds(self) -> Tuple[int, sint]:
205
227
  assert self.a.min >= 0 and isinstance(self.b, int)
206
- return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b)
228
+ if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b): return (0, self.b-1)
229
+ return (self.a.min%self.b, self.a.max%self.b)
230
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) % self.b
207
231
 
208
232
  class RedNode(Node):
209
- def __init__(self, nodes:List[Node]): self.nodes = nodes
210
- def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, [])
233
+ def __init__(self, nodes:List[Node]):
234
+ self.nodes = nodes
235
+ self.min, self.max = self.get_bounds()
236
+ def vars(self) -> Set[Variable]: return set.union(*[x.vars() for x in self.nodes], set())
237
+ def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
211
238
 
212
239
  class SumNode(RedNode):
240
+ def get_bounds(self) -> Tuple[int, sint]: return sum([x.min for x in self.nodes]), sum([x.max for x in self.nodes])
241
+ @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
213
242
  def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
214
- def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
243
+ @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
244
+ def __floordiv__(self, b: Union[Node, sint], factoring_allowed=True):
245
+ if self == b: return NumNode(1)
215
246
  fully_divided: List[Node] = []
216
247
  rest: List[Node] = []
217
- if isinstance(b, SumNode):
218
- nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
219
- de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
220
- if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b
221
248
  if isinstance(b, Node):
222
249
  for x in self.flat_components:
223
250
  if x % b == 0: fully_divided.append(x // b)
224
251
  else: rest.append(x)
225
- if (sum_fully_divided:=create_rednode(SumNode, fully_divided)) != 0: return sum_fully_divided + create_rednode(SumNode, rest) // b
252
+ if (sum_fully_divided:=create_node(SumNode(fully_divided))) != 0: return sum_fully_divided + create_node(SumNode(rest)) // b
226
253
  return Node.__floordiv__(self, b, False)
227
254
  if b == 1: return self
228
255
  if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
229
- fully_divided, rest = [], []
230
256
  _gcd = b
231
257
  divisor = 1
232
258
  for x in self.flat_components:
233
259
  if x.__class__ in (NumNode, MulNode):
234
- if x.b%b == 0: fully_divided.append(x//b)
260
+ if x.b % b == 0: fully_divided.append(x // b)
235
261
  else:
262
+ if x.__class__ is NumNode and (div := x.b // b):
263
+ fully_divided.append(NumNode(div))
264
+ x = NumNode(x.b - b * div)
236
265
  rest.append(x)
237
- _gcd = gcd(_gcd, x.b)
238
- if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
266
+ if isinstance(x.b, int):
267
+ _gcd = gcd(_gcd, x.b)
268
+ if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
269
+ else:
270
+ _gcd = 1
239
271
  else:
240
272
  rest.append(x)
241
273
  _gcd = 1
@@ -243,62 +275,55 @@ class SumNode(RedNode):
243
275
  if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
244
276
  return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
245
277
 
278
+ @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
246
279
  def __mod__(self, b: Union[Node, int]):
247
- if isinstance(b, SumNode):
248
- nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
249
- de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
250
- if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b
280
+ if self == b: return NumNode(0)
251
281
  if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
252
- new_nodes: List[Node] = []
253
- for x in self.nodes:
254
- if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b))
255
- elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
256
- else: new_nodes.append(x)
257
- return Node.__mod__(Node.sum(new_nodes), b)
258
-
259
- def __lt__(self, b:Union[Node,int]):
260
- if isinstance(b, int):
261
- new_sum = []
262
- for x in self.nodes:
263
- # TODO: should we just force the last one to always be the number
264
- if isinstance(x, NumNode): b -= x.b
265
- else: new_sum.append(x)
266
- return Node.__lt__(Node.sum(new_sum), b)
267
- return Node.__lt__(self, b)
282
+ new_sum = Node.sum([node%b if node.__class__ in (NumNode, MulNode) else node for node in self.nodes])
283
+ return Node.__mod__(new_sum, b)
284
+
285
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
286
+ return Node.sum([node.substitute(var_vals) for node in self.nodes])
268
287
 
288
+ # recursively expand sumnode components
289
+ # TODO: can remove this if there's no SumNode inside SumNode
269
290
  @property
270
- def flat_components(self): # recursively expand sumnode components
271
- new_nodes = []
272
- for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x])
273
- return new_nodes
291
+ def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]
274
292
 
275
293
  class AndNode(RedNode):
276
- def __mul__(self, b: Union[Node, int]): Variable.ands([x*b for x in self.nodes])
277
- def __floordiv__(self, b: Union[Node, int], _=True): return Variable.ands([x//b for x in self.nodes])
278
-
279
- def create_rednode(typ:Type[RedNode], nodes:List[Node]):
280
- ret = typ(nodes)
281
- if typ == SumNode: ret.min, ret.max = (sum([x.min for x in nodes]), sum([x.max for x in nodes]))
282
- elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
283
- return create_node(ret)
284
-
285
- def sym_infer(n:Union[Node,int], var_vals: Dict[Variable, int]) -> int:
286
- if isinstance(n, (int, NumNode)): return int(n)
287
- if isinstance(n, Variable): return var_vals[n]
288
- if isinstance(n, MulNode): return sym_infer(n.a, var_vals) * sym_infer(n.b, var_vals)
289
- if isinstance(n, SumNode): return sum(sym_infer(s, var_vals) for s in n.nodes)
290
- raise NotImplementedError(n)
291
- @functools.lru_cache(maxsize=None)
292
- def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
293
- def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
294
+ def get_bounds(self) -> Tuple[int, sint]: return min([x.min for x in self.nodes]), max([x.max for x in self.nodes])
295
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
296
+ subed = []
297
+ for node in self.nodes:
298
+ if not (sub:=node.substitute(var_vals)): return NumNode(0)
299
+ subed.append(sub)
300
+ return Node.ands(subed)
294
301
 
295
- render_python: Dict[Type, Callable] = {
296
- Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}]" if ctx == "DEBUG" else f"{self.expr}",
297
- NumNode: lambda self,ops,ctx: f"{self.b}",
298
- MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
302
+ def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
303
+ def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
304
+ if isinstance(a, (int, float)): return a
305
+ ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()}) if var_vals is not None else a
306
+ assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
307
+ return ret.b
308
+
309
+ # symbolic int, these are allowed in a Tensor shape
310
+ sint = Union[int, Variable, MulNode, SumNode]
311
+
312
+ def render_mulnode(node:MulNode, ops, ctx):
313
+ # TODO: add ProdNode and remove this case
314
+ if isinstance(node.a,Variable) and isinstance(node.b,Variable) and node.a.expr and node.b.expr and node.b.expr < node.a.expr:
315
+ return f"({sym_render(node.b,ops,ctx)}*{node.a.render(ops,ctx)})"
316
+ return f"({node.a.render(ops,ctx)}*{sym_render(node.b,ops,ctx)})"
317
+
318
+ render_python: Dict[Type, Callable[..., str]] = {
319
+ Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" \
320
+ else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" \
321
+ else f"{self.expr}"),
322
+ NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
323
+ MulNode: render_mulnode,
299
324
  DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
300
325
  ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
301
326
  LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
302
327
  SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
303
- AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
328
+ AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
304
329
  }