tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
  import functools
3
3
  from math import gcd
4
4
  from tinygrad.helpers import partition
5
- from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set
5
+ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set, Mapping
6
6
 
7
7
  # NOTE: Python has different behavior for negative mod and floor div than c
8
8
  # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
@@ -10,29 +10,27 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set
10
10
  class Node:
11
11
  b: Union[Node, int]
12
12
  min: int
13
- max: int
13
+ max: sint
14
14
  def render(self, ops=None, ctx=None) -> Any:
15
15
  if ops is None: ops = render_python
16
16
  assert self.__class__ in (Variable, NumNode) or self.min != self.max
17
17
  return ops[type(self)](self, ops, ctx)
18
18
  def vars(self) -> Set[Variable]: return set()
19
19
  # substitute Variables with the values in var_vals
20
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node: raise RuntimeError(self.__class__.__name__)
20
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: raise RuntimeError(self.__class__.__name__)
21
21
  def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
22
22
 
23
23
  @functools.cached_property
24
24
  def key(self) -> str: return self.render(ctx="DEBUG")
25
- @functools.cached_property
26
- def hash(self) -> int: return hash(self.key)
27
25
  def __repr__(self): return self.render(ctx="REPR")
28
26
  def __str__(self): return "<"+self.key+">"
29
- def __hash__(self): return self.hash
27
+ def __hash__(self): return hash(self.key)
30
28
  def __bool__(self): return not (self.max == self.min == 0)
31
29
  def __eq__(self, other:object) -> bool:
32
30
  if not isinstance(other, Node): return NotImplemented
33
31
  return self.key == other.key
34
32
  def __neg__(self): return self*-1
35
- def __add__(self, b:Union[Node,int]): return Node.sum([self, b if isinstance(b, Node) else NumNode(b)])
33
+ def __add__(self, b:Union[Node,int]): return Node.sum([self, NumNode(b) if isinstance(b, int) else b])
36
34
  def __radd__(self, b:int): return self+b
37
35
  def __sub__(self, b:Union[Node,int]): return self+-b
38
36
  def __rsub__(self, b:int): return -self+b
@@ -43,24 +41,20 @@ class Node:
43
41
  def __mul__(self, b:Union[Node, int]):
44
42
  if b == 0: return NumNode(0)
45
43
  if b == 1: return self
46
- if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b
47
44
  return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
48
45
  def __rmul__(self, b:int): return self*b
49
46
 
50
47
  # *** complex ops ***
51
48
 
52
- def __rfloordiv__(self, b:int):
53
- if self.min > b >= 0: return NumNode(0)
54
- if isinstance(self, NumNode): return NumNode(b // self.b)
55
- raise RuntimeError(f"not supported: {b} // {self}")
49
+ def __rfloordiv__(self, b:int): return NumNode(b) // self
56
50
  def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
57
51
  if isinstance(b, Node):
58
- if b.__class__ is NumNode: return self // b.b
52
+ if b.__class__ is NumNode: return self.__floordiv__(b.b, factoring_allowed)
59
53
  if self == b: return NumNode(1)
60
54
  if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
61
55
  raise RuntimeError(f"not supported: {self} // {b}")
62
56
  assert b != 0
63
- if b < 0: return (self//-b)*-1
57
+ if b < 0: return (self*-1).__floordiv__(-b, factoring_allowed)
64
58
  if b == 1: return self
65
59
 
66
60
  # the numerator of div is not allowed to be negative
@@ -70,10 +64,7 @@ class Node:
70
64
  return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
71
65
  return create_node(DivNode(self, b))
72
66
 
73
- def __rmod__(self, b:int):
74
- if self.min > b >= 0: return NumNode(b)
75
- if isinstance(self, NumNode): return NumNode(b % self.b)
76
- raise RuntimeError(f"not supported: {b} % {self}")
67
+ def __rmod__(self, b:int): return NumNode(b) % self
77
68
  def __mod__(self, b:Union[Node,int]):
78
69
  if isinstance(b, Node):
79
70
  if b.__class__ is NumNode: return self % b.b
@@ -102,7 +93,7 @@ class Node:
102
93
  else: mul_groups[node] = mul_groups.get(node, 0) + 1
103
94
  new_nodes = [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
104
95
  if num_node_sum: new_nodes.append(NumNode(num_node_sum))
105
- return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
96
+ return create_node(SumNode(new_nodes)) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
106
97
 
107
98
  @staticmethod
108
99
  def ands(nodes:List[Node]) -> Node:
@@ -112,19 +103,20 @@ class Node:
112
103
 
113
104
  # filter 1s
114
105
  nodes = [x for x in nodes if x.min != x.max]
115
- return create_rednode(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
106
+ return create_node(AndNode(nodes)) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
116
107
 
117
108
  # 4 basic node types
118
109
 
119
110
  class Variable(Node):
120
111
  def __new__(cls, *args):
121
- if len(args) == 0: return super().__new__(cls) # fix pickle
122
112
  expr, nmin, nmax = args
123
113
  assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
124
114
  if nmin == nmax: return NumNode(nmin)
125
115
  return super().__new__(cls)
126
116
 
127
- def __init__(self, expr:Optional[str], nmin:int, nmax:int):
117
+ def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling
118
+
119
+ def __init__(self, expr:str, nmin:int, nmax:sint):
128
120
  self.expr, self.min, self.max = expr, nmin, nmax
129
121
  self._val: Optional[int] = None
130
122
  @property
@@ -139,7 +131,7 @@ class Variable(Node):
139
131
  assert self.val is not None, f"cannot unbind {self}"
140
132
  return Variable(self.expr, self.min, self.max), self.val
141
133
  def vars(self): return {self}
142
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node: return var_vals.get(self, self)
134
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return var_vals.get(self, self)
143
135
 
144
136
  class NumNode(Node):
145
137
  def __init__(self, num:int):
@@ -149,97 +141,129 @@ class NumNode(Node):
149
141
  def bind(self, val):
150
142
  assert self.b == val, f"cannot bind {val} to {self}"
151
143
  return self
144
+ def __mul__(self, b:Union[Node,int]): return NumNode(self.b*b) if isinstance(b, int) else b*self.b
152
145
  def __eq__(self, other): return self.b == other
153
- def __hash__(self): return self.hash # needed with __eq__ override
154
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node: return self
146
+ def __hash__(self): return hash(self.b) # needed with __eq__ override
147
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self
155
148
 
156
149
  def create_node(ret:Node):
157
150
  assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
158
151
  if ret.min == ret.max: return NumNode(ret.min)
159
152
  return ret
160
153
 
154
+ def create_lt_node(lhs:Node, b:Union[Node, int]):
155
+ if isinstance(lhs, SumNode):
156
+ if isinstance(b, int):
157
+ new_sum = []
158
+ for x in lhs.nodes:
159
+ # TODO: should we just force the last one to always be the number
160
+ if isinstance(x, NumNode): b -= x.b
161
+ else: new_sum.append(x)
162
+ lhs = Node.sum(new_sum)
163
+ nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
164
+ assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
165
+ muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
166
+ if muls:
167
+ # NOTE: gcd in python 3.8 takes exactly 2 args
168
+ mul_gcd = b
169
+ for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
170
+ all_others = Node.sum(others)
171
+ if all_others.min >= 0 and all_others.max < mul_gcd:
172
+ lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
173
+ return create_node(LtNode(lhs, b)) if isinstance(lhs, SumNode) else create_lt_node(lhs, b)
174
+ if isinstance(lhs, MulNode):
175
+ if isinstance(b, Node) or isinstance(lhs.b, Node) or lhs.b == -1: return create_node(LtNode(lhs, b))
176
+ sgn = 1 if lhs.b > 0 else -1
177
+ return create_node(LtNode(lhs.a*sgn, (b + abs(lhs.b) - 1)//abs(lhs.b)))
178
+ return create_node(LtNode(lhs, b))
179
+
180
+ def create_ge_node(lhs:Node, b:Union[Node, int]): return create_lt_node(-lhs, -b+1)
181
+
161
182
  class OpNode(Node):
162
183
  def __init__(self, a:Node, b:Union[Node, int]):
163
184
  self.a, self.b = a, b
164
185
  self.min, self.max = self.get_bounds()
165
186
  def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())
166
- def get_bounds(self) -> Tuple[int, int]: raise NotImplementedError("must be implemented")
187
+ def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
167
188
 
168
189
  class LtNode(OpNode):
169
- def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
170
190
  def get_bounds(self) -> Tuple[int, int]:
191
+ if self.a == self.b: return (0, 0)
171
192
  if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
172
193
  return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
173
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node:
174
- return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
194
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
195
+ return create_lt_node(self.a.substitute(var_vals), (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)))
175
196
 
176
197
  class MulNode(OpNode):
177
- def __lt__(self, b: Union[Node, int]):
178
- if isinstance(b, Node) or isinstance(self.b, Node) or self.b == -1: return Node.__lt__(self, b)
179
- sgn = 1 if self.b > 0 else -1
180
- return Node.__lt__(self.a*sgn, (b + abs(self.b) - 1)//abs(self.b))
181
198
  def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
182
199
  def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
183
200
  if self.b % b == 0: return self.a*(self.b//b)
184
201
  if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
185
202
  return Node.__floordiv__(self, b, factoring_allowed)
186
203
  def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
187
- def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
188
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node:
204
+ def get_bounds(self) -> Tuple[int, sint]:
205
+ assert self.a.min >= 0
206
+ if isinstance(self.b, int): return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
207
+ return (self.a.min*self.b.min, self.a.max*self.b.max) if self.b.min >= 0 else (self.a.max*self.b.min, self.a.min*self.b.max)
208
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
189
209
  return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
190
210
 
191
211
  class DivNode(OpNode):
192
212
  def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
193
- def get_bounds(self) -> Tuple[int, int]:
213
+ def get_bounds(self) -> Tuple[int, sint]:
194
214
  assert self.a.min >= 0 and isinstance(self.b, int)
195
215
  return self.a.min//self.b, self.a.max//self.b
196
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node: return self.a.substitute(var_vals) // self.b
216
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) // self.b
197
217
 
198
218
  class ModNode(OpNode):
199
219
  def __mod__(self, b: Union[Node, int]):
200
- if isinstance(b, Node) or isinstance(self.b, Node): return Node.__mod__(self, b)
201
- return self.a % b if self.b % b == 0 else Node.__mod__(self, b)
220
+ if isinstance(b, int) and isinstance(self.b, int) and self.b % b == 0: return self.a % b
221
+ return Node.__mod__(self, b)
202
222
  def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
203
223
  return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
204
- def get_bounds(self) -> Tuple[int, int]:
224
+ def get_bounds(self) -> Tuple[int, sint]:
205
225
  assert self.a.min >= 0 and isinstance(self.b, int)
206
- return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b) # noqa: E501
207
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node: return self.a.substitute(var_vals) % self.b
226
+ if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b): return (0, self.b-1)
227
+ return (self.a.min%self.b, self.a.max%self.b)
228
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) % self.b
208
229
 
209
230
  class RedNode(Node):
210
- def __init__(self, nodes:List[Node]): self.nodes = nodes
231
+ def __init__(self, nodes:List[Node]):
232
+ self.nodes = nodes
233
+ self.min, self.max = self.get_bounds()
211
234
  def vars(self) -> Set[Variable]: return set.union(*[x.vars() for x in self.nodes], set())
235
+ def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
212
236
 
213
237
  class SumNode(RedNode):
238
+ def get_bounds(self) -> Tuple[int, sint]: return sum([x.min for x in self.nodes]), sum([x.max for x in self.nodes])
214
239
  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
215
240
  def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
216
241
  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
217
- def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
242
+ def __floordiv__(self, b: Union[Node, sint], factoring_allowed=True):
243
+ if self == b: return NumNode(1)
218
244
  fully_divided: List[Node] = []
219
245
  rest: List[Node] = []
220
- if isinstance(b, SumNode):
221
- nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
222
- de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
223
- if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b
224
246
  if isinstance(b, Node):
225
247
  for x in self.flat_components:
226
248
  if x % b == 0: fully_divided.append(x // b)
227
249
  else: rest.append(x)
228
- if (sum_fully_divided:=create_rednode(SumNode, fully_divided)) != 0: return sum_fully_divided + create_rednode(SumNode, rest) // b
250
+ if (sum_fully_divided:=create_node(SumNode(fully_divided))) != 0: return sum_fully_divided + create_node(SumNode(rest)) // b
229
251
  return Node.__floordiv__(self, b, False)
230
252
  if b == 1: return self
231
253
  if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
232
- fully_divided, rest = [], []
233
254
  _gcd = b
234
255
  divisor = 1
235
256
  for x in self.flat_components:
236
257
  if x.__class__ in (NumNode, MulNode):
237
- if x.b%b == 0: fully_divided.append(x//b)
258
+ if x.b % b == 0: fully_divided.append(x // b)
238
259
  else:
260
+ if x.__class__ is NumNode and (div := x.b // b):
261
+ fully_divided.append(NumNode(div))
262
+ x = NumNode(x.b - b * div)
239
263
  rest.append(x)
240
264
  if isinstance(x.b, int):
241
265
  _gcd = gcd(_gcd, x.b)
242
- if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
266
+ if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
243
267
  else:
244
268
  _gcd = 1
245
269
  else:
@@ -251,39 +275,13 @@ class SumNode(RedNode):
251
275
 
252
276
  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
253
277
  def __mod__(self, b: Union[Node, int]):
254
- if isinstance(b, SumNode):
255
- nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
256
- de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
257
- if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b
278
+ if self == b: return NumNode(0)
258
279
  if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
259
- new_nodes: List[Node] = []
260
- for x in self.nodes:
261
- if x.__class__ in (NumNode, MulNode): new_nodes.append(x%b) # might simplify
262
- else: new_nodes.append(x)
263
- return Node.__mod__(Node.sum(new_nodes), b)
264
-
265
- def __lt__(self, b:Union[Node,int]):
266
- lhs: Node = self
267
- if isinstance(b, int):
268
- new_sum = []
269
- for x in self.nodes:
270
- # TODO: should we just force the last one to always be the number
271
- if isinstance(x, NumNode): b -= x.b
272
- else: new_sum.append(x)
273
- lhs = Node.sum(new_sum)
274
- nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
275
- assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
276
- muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
277
- if muls:
278
- # NOTE: gcd in python 3.8 takes exactly 2 args
279
- mul_gcd = b
280
- for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
281
- all_others = Node.sum(others)
282
- if all_others.min >= 0 and all_others.max < mul_gcd:
283
- lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
284
- return Node.__lt__(lhs, b) if isinstance(lhs, SumNode) else lhs < b
280
+ new_sum = Node.sum([node%b if node.__class__ in (NumNode, MulNode) else node for node in self.nodes])
281
+ return Node.__mod__(new_sum, b)
285
282
 
286
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node: return Node.sum([node.substitute(var_vals) for node in self.nodes])
283
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
284
+ return Node.sum([node.substitute(var_vals) for node in self.nodes])
287
285
 
288
286
  # recursively expand sumnode components
289
287
  # TODO: can remove this if there's no SumNode inside SumNode
@@ -291,36 +289,39 @@ class SumNode(RedNode):
291
289
  def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]
292
290
 
293
291
  class AndNode(RedNode):
294
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node:
292
+ def get_bounds(self) -> Tuple[int, sint]: return min([x.min for x in self.nodes]), max([x.max for x in self.nodes])
293
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
295
294
  subed = []
296
295
  for node in self.nodes:
297
296
  if not (sub:=node.substitute(var_vals)): return NumNode(0)
298
297
  subed.append(sub)
299
298
  return Node.ands(subed)
300
299
 
301
- def create_rednode(typ:Type[RedNode], nodes:List[Node]):
302
- ret = typ(nodes)
303
- if typ == SumNode: ret.min, ret.max = (sum([x.min for x in nodes]), sum([x.max for x in nodes]))
304
- elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
305
- return create_node(ret)
306
-
307
300
  def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
308
- def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
301
+ def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
309
302
  if isinstance(a, (int, float)): return a
310
- ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()})
303
+ ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()}) if var_vals is not None else a
311
304
  assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
312
305
  return ret.b
313
306
 
314
- # symbolic int
315
- sint = Union[Node, int]
307
+ # symbolic int, these are allowed in a Tensor shape
308
+ sint = Union[int, Variable, MulNode, SumNode]
309
+
310
+ def render_mulnode(node:MulNode, ops, ctx):
311
+ # TODO: add ProdNode and remove this case
312
+ if isinstance(node.a,Variable) and isinstance(node.b,Variable) and node.a.expr and node.b.expr and node.b.expr < node.a.expr:
313
+ return f"({sym_render(node.b,ops,ctx)}*{node.a.render(ops,ctx)})"
314
+ return f"({node.a.render(ops,ctx)}*{sym_render(node.b,ops,ctx)})"
316
315
 
317
- render_python: Dict[Type, Callable] = {
318
- Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" else f"{self.expr}"), # noqa: E501
316
+ render_python: Dict[Type, Callable[..., str]] = {
317
+ Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" \
318
+ else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" \
319
+ else f"{self.expr}"),
319
320
  NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
320
- MulNode: lambda self,ops,ctx: f"({sym_render(self.b,ops,ctx)}*{self.a.render(ops,ctx)})" if isinstance(self.a,Variable) and isinstance(self.b,Variable) and self.a.expr and self.b.expr and self.b.expr < self.a.expr else f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})", # noqa: E501
321
+ MulNode: render_mulnode,
321
322
  DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
322
323
  ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
323
324
  LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
324
325
  SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
325
- AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
326
+ AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
326
327
  }