tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +253 -225
- tinygrad/codegen/linearizer.py +398 -436
- tinygrad/codegen/uops.py +451 -0
- tinygrad/device.py +268 -274
- tinygrad/dtype.py +56 -40
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +198 -0
- tinygrad/engine/realize.py +192 -0
- tinygrad/engine/schedule.py +370 -0
- tinygrad/engine/search.py +199 -0
- tinygrad/{mlops.py → function.py} +40 -32
- tinygrad/helpers.py +144 -46
- tinygrad/lazy.py +143 -242
- tinygrad/multi.py +173 -0
- tinygrad/nn/__init__.py +180 -9
- tinygrad/nn/datasets.py +8 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +87 -19
- tinygrad/ops.py +104 -45
- tinygrad/renderer/__init__.py +65 -0
- tinygrad/renderer/assembly.py +269 -0
- tinygrad/renderer/cstyle.py +308 -210
- tinygrad/renderer/llvmir.py +119 -124
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +13403 -0
- tinygrad/runtime/autogen/comgr.py +891 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5893 -0
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33597 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +56 -0
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +39 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +187 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +550 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +129 -37
- tinygrad/runtime/ops_disk.py +111 -43
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +41 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +625 -0
- tinygrad/runtime/ops_python.py +208 -0
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +46 -107
- tinygrad/shape/symbolic.py +99 -98
- tinygrad/shape/view.py +162 -45
- tinygrad/tensor.py +2492 -483
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/shape/symbolic.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
2
2
|
import functools
|
3
3
|
from math import gcd
|
4
4
|
from tinygrad.helpers import partition
|
5
|
-
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set
|
5
|
+
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set, Mapping
|
6
6
|
|
7
7
|
# NOTE: Python has different behavior for negative mod and floor div than c
|
8
8
|
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
|
@@ -10,29 +10,27 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set
|
|
10
10
|
class Node:
|
11
11
|
b: Union[Node, int]
|
12
12
|
min: int
|
13
|
-
max:
|
13
|
+
max: sint
|
14
14
|
def render(self, ops=None, ctx=None) -> Any:
|
15
15
|
if ops is None: ops = render_python
|
16
16
|
assert self.__class__ in (Variable, NumNode) or self.min != self.max
|
17
17
|
return ops[type(self)](self, ops, ctx)
|
18
18
|
def vars(self) -> Set[Variable]: return set()
|
19
19
|
# substitute Variables with the values in var_vals
|
20
|
-
def substitute(self, var_vals:
|
20
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: raise RuntimeError(self.__class__.__name__)
|
21
21
|
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
|
22
22
|
|
23
23
|
@functools.cached_property
|
24
24
|
def key(self) -> str: return self.render(ctx="DEBUG")
|
25
|
-
@functools.cached_property
|
26
|
-
def hash(self) -> int: return hash(self.key)
|
27
25
|
def __repr__(self): return self.render(ctx="REPR")
|
28
26
|
def __str__(self): return "<"+self.key+">"
|
29
|
-
def __hash__(self): return self.
|
27
|
+
def __hash__(self): return hash(self.key)
|
30
28
|
def __bool__(self): return not (self.max == self.min == 0)
|
31
29
|
def __eq__(self, other:object) -> bool:
|
32
30
|
if not isinstance(other, Node): return NotImplemented
|
33
31
|
return self.key == other.key
|
34
32
|
def __neg__(self): return self*-1
|
35
|
-
def __add__(self, b:Union[Node,int]): return Node.sum([self, b if isinstance(b,
|
33
|
+
def __add__(self, b:Union[Node,int]): return Node.sum([self, NumNode(b) if isinstance(b, int) else b])
|
36
34
|
def __radd__(self, b:int): return self+b
|
37
35
|
def __sub__(self, b:Union[Node,int]): return self+-b
|
38
36
|
def __rsub__(self, b:int): return -self+b
|
@@ -43,24 +41,20 @@ class Node:
|
|
43
41
|
def __mul__(self, b:Union[Node, int]):
|
44
42
|
if b == 0: return NumNode(0)
|
45
43
|
if b == 1: return self
|
46
|
-
if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b
|
47
44
|
return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
|
48
45
|
def __rmul__(self, b:int): return self*b
|
49
46
|
|
50
47
|
# *** complex ops ***
|
51
48
|
|
52
|
-
def __rfloordiv__(self, b:int):
|
53
|
-
if self.min > b >= 0: return NumNode(0)
|
54
|
-
if isinstance(self, NumNode): return NumNode(b // self.b)
|
55
|
-
raise RuntimeError(f"not supported: {b} // {self}")
|
49
|
+
def __rfloordiv__(self, b:int): return NumNode(b) // self
|
56
50
|
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
|
57
51
|
if isinstance(b, Node):
|
58
|
-
if b.__class__ is NumNode: return self
|
52
|
+
if b.__class__ is NumNode: return self.__floordiv__(b.b, factoring_allowed)
|
59
53
|
if self == b: return NumNode(1)
|
60
54
|
if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
|
61
55
|
raise RuntimeError(f"not supported: {self} // {b}")
|
62
56
|
assert b != 0
|
63
|
-
if b < 0: return (self
|
57
|
+
if b < 0: return (self*-1).__floordiv__(-b, factoring_allowed)
|
64
58
|
if b == 1: return self
|
65
59
|
|
66
60
|
# the numerator of div is not allowed to be negative
|
@@ -70,10 +64,7 @@ class Node:
|
|
70
64
|
return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
|
71
65
|
return create_node(DivNode(self, b))
|
72
66
|
|
73
|
-
def __rmod__(self, b:int):
|
74
|
-
if self.min > b >= 0: return NumNode(b)
|
75
|
-
if isinstance(self, NumNode): return NumNode(b % self.b)
|
76
|
-
raise RuntimeError(f"not supported: {b} % {self}")
|
67
|
+
def __rmod__(self, b:int): return NumNode(b) % self
|
77
68
|
def __mod__(self, b:Union[Node,int]):
|
78
69
|
if isinstance(b, Node):
|
79
70
|
if b.__class__ is NumNode: return self % b.b
|
@@ -102,7 +93,7 @@ class Node:
|
|
102
93
|
else: mul_groups[node] = mul_groups.get(node, 0) + 1
|
103
94
|
new_nodes = [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
|
104
95
|
if num_node_sum: new_nodes.append(NumNode(num_node_sum))
|
105
|
-
return
|
96
|
+
return create_node(SumNode(new_nodes)) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
|
106
97
|
|
107
98
|
@staticmethod
|
108
99
|
def ands(nodes:List[Node]) -> Node:
|
@@ -112,19 +103,20 @@ class Node:
|
|
112
103
|
|
113
104
|
# filter 1s
|
114
105
|
nodes = [x for x in nodes if x.min != x.max]
|
115
|
-
return
|
106
|
+
return create_node(AndNode(nodes)) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
|
116
107
|
|
117
108
|
# 4 basic node types
|
118
109
|
|
119
110
|
class Variable(Node):
|
120
111
|
def __new__(cls, *args):
|
121
|
-
if len(args) == 0: return super().__new__(cls) # fix pickle
|
122
112
|
expr, nmin, nmax = args
|
123
113
|
assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
|
124
114
|
if nmin == nmax: return NumNode(nmin)
|
125
115
|
return super().__new__(cls)
|
126
116
|
|
127
|
-
def
|
117
|
+
def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling
|
118
|
+
|
119
|
+
def __init__(self, expr:str, nmin:int, nmax:sint):
|
128
120
|
self.expr, self.min, self.max = expr, nmin, nmax
|
129
121
|
self._val: Optional[int] = None
|
130
122
|
@property
|
@@ -139,7 +131,7 @@ class Variable(Node):
|
|
139
131
|
assert self.val is not None, f"cannot unbind {self}"
|
140
132
|
return Variable(self.expr, self.min, self.max), self.val
|
141
133
|
def vars(self): return {self}
|
142
|
-
def substitute(self, var_vals:
|
134
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return var_vals.get(self, self)
|
143
135
|
|
144
136
|
class NumNode(Node):
|
145
137
|
def __init__(self, num:int):
|
@@ -149,97 +141,129 @@ class NumNode(Node):
|
|
149
141
|
def bind(self, val):
|
150
142
|
assert self.b == val, f"cannot bind {val} to {self}"
|
151
143
|
return self
|
144
|
+
def __mul__(self, b:Union[Node,int]): return NumNode(self.b*b) if isinstance(b, int) else b*self.b
|
152
145
|
def __eq__(self, other): return self.b == other
|
153
|
-
def __hash__(self): return self.
|
154
|
-
def substitute(self, var_vals:
|
146
|
+
def __hash__(self): return hash(self.b) # needed with __eq__ override
|
147
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self
|
155
148
|
|
156
149
|
def create_node(ret:Node):
|
157
150
|
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
|
158
151
|
if ret.min == ret.max: return NumNode(ret.min)
|
159
152
|
return ret
|
160
153
|
|
154
|
+
def create_lt_node(lhs:Node, b:Union[Node, int]):
|
155
|
+
if isinstance(lhs, SumNode):
|
156
|
+
if isinstance(b, int):
|
157
|
+
new_sum = []
|
158
|
+
for x in lhs.nodes:
|
159
|
+
# TODO: should we just force the last one to always be the number
|
160
|
+
if isinstance(x, NumNode): b -= x.b
|
161
|
+
else: new_sum.append(x)
|
162
|
+
lhs = Node.sum(new_sum)
|
163
|
+
nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
|
164
|
+
assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
|
165
|
+
muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
|
166
|
+
if muls:
|
167
|
+
# NOTE: gcd in python 3.8 takes exactly 2 args
|
168
|
+
mul_gcd = b
|
169
|
+
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
|
170
|
+
all_others = Node.sum(others)
|
171
|
+
if all_others.min >= 0 and all_others.max < mul_gcd:
|
172
|
+
lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
|
173
|
+
return create_node(LtNode(lhs, b)) if isinstance(lhs, SumNode) else create_lt_node(lhs, b)
|
174
|
+
if isinstance(lhs, MulNode):
|
175
|
+
if isinstance(b, Node) or isinstance(lhs.b, Node) or lhs.b == -1: return create_node(LtNode(lhs, b))
|
176
|
+
sgn = 1 if lhs.b > 0 else -1
|
177
|
+
return create_node(LtNode(lhs.a*sgn, (b + abs(lhs.b) - 1)//abs(lhs.b)))
|
178
|
+
return create_node(LtNode(lhs, b))
|
179
|
+
|
180
|
+
def create_ge_node(lhs:Node, b:Union[Node, int]): return create_lt_node(-lhs, -b+1)
|
181
|
+
|
161
182
|
class OpNode(Node):
|
162
183
|
def __init__(self, a:Node, b:Union[Node, int]):
|
163
184
|
self.a, self.b = a, b
|
164
185
|
self.min, self.max = self.get_bounds()
|
165
186
|
def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())
|
166
|
-
def get_bounds(self) -> Tuple[int,
|
187
|
+
def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
|
167
188
|
|
168
189
|
class LtNode(OpNode):
|
169
|
-
def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
|
170
190
|
def get_bounds(self) -> Tuple[int, int]:
|
191
|
+
if self.a == self.b: return (0, 0)
|
171
192
|
if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
|
172
193
|
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
|
173
|
-
def substitute(self, var_vals:
|
174
|
-
return self.a.substitute(var_vals)
|
194
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
|
195
|
+
return create_lt_node(self.a.substitute(var_vals), (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)))
|
175
196
|
|
176
197
|
class MulNode(OpNode):
|
177
|
-
def __lt__(self, b: Union[Node, int]):
|
178
|
-
if isinstance(b, Node) or isinstance(self.b, Node) or self.b == -1: return Node.__lt__(self, b)
|
179
|
-
sgn = 1 if self.b > 0 else -1
|
180
|
-
return Node.__lt__(self.a*sgn, (b + abs(self.b) - 1)//abs(self.b))
|
181
198
|
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
|
182
199
|
def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
|
183
200
|
if self.b % b == 0: return self.a*(self.b//b)
|
184
201
|
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
|
185
202
|
return Node.__floordiv__(self, b, factoring_allowed)
|
186
203
|
def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
|
187
|
-
def get_bounds(self) -> Tuple[int,
|
188
|
-
|
204
|
+
def get_bounds(self) -> Tuple[int, sint]:
|
205
|
+
assert self.a.min >= 0
|
206
|
+
if isinstance(self.b, int): return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
|
207
|
+
return (self.a.min*self.b.min, self.a.max*self.b.max) if self.b.min >= 0 else (self.a.max*self.b.min, self.a.min*self.b.max)
|
208
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
|
189
209
|
return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
|
190
210
|
|
191
211
|
class DivNode(OpNode):
|
192
212
|
def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
|
193
|
-
def get_bounds(self) -> Tuple[int,
|
213
|
+
def get_bounds(self) -> Tuple[int, sint]:
|
194
214
|
assert self.a.min >= 0 and isinstance(self.b, int)
|
195
215
|
return self.a.min//self.b, self.a.max//self.b
|
196
|
-
def substitute(self, var_vals:
|
216
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) // self.b
|
197
217
|
|
198
218
|
class ModNode(OpNode):
|
199
219
|
def __mod__(self, b: Union[Node, int]):
|
200
|
-
if isinstance(b,
|
201
|
-
return
|
220
|
+
if isinstance(b, int) and isinstance(self.b, int) and self.b % b == 0: return self.a % b
|
221
|
+
return Node.__mod__(self, b)
|
202
222
|
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
|
203
223
|
return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
|
204
|
-
def get_bounds(self) -> Tuple[int,
|
224
|
+
def get_bounds(self) -> Tuple[int, sint]:
|
205
225
|
assert self.a.min >= 0 and isinstance(self.b, int)
|
206
|
-
|
207
|
-
|
226
|
+
if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b): return (0, self.b-1)
|
227
|
+
return (self.a.min%self.b, self.a.max%self.b)
|
228
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) % self.b
|
208
229
|
|
209
230
|
class RedNode(Node):
|
210
|
-
def __init__(self, nodes:List[Node]):
|
231
|
+
def __init__(self, nodes:List[Node]):
|
232
|
+
self.nodes = nodes
|
233
|
+
self.min, self.max = self.get_bounds()
|
211
234
|
def vars(self) -> Set[Variable]: return set.union(*[x.vars() for x in self.nodes], set())
|
235
|
+
def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
|
212
236
|
|
213
237
|
class SumNode(RedNode):
|
238
|
+
def get_bounds(self) -> Tuple[int, sint]: return sum([x.min for x in self.nodes]), sum([x.max for x in self.nodes])
|
214
239
|
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
215
240
|
def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
|
216
241
|
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
217
|
-
def __floordiv__(self, b: Union[Node,
|
242
|
+
def __floordiv__(self, b: Union[Node, sint], factoring_allowed=True):
|
243
|
+
if self == b: return NumNode(1)
|
218
244
|
fully_divided: List[Node] = []
|
219
245
|
rest: List[Node] = []
|
220
|
-
if isinstance(b, SumNode):
|
221
|
-
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
222
|
-
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
223
|
-
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b
|
224
246
|
if isinstance(b, Node):
|
225
247
|
for x in self.flat_components:
|
226
248
|
if x % b == 0: fully_divided.append(x // b)
|
227
249
|
else: rest.append(x)
|
228
|
-
if (sum_fully_divided:=
|
250
|
+
if (sum_fully_divided:=create_node(SumNode(fully_divided))) != 0: return sum_fully_divided + create_node(SumNode(rest)) // b
|
229
251
|
return Node.__floordiv__(self, b, False)
|
230
252
|
if b == 1: return self
|
231
253
|
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
|
232
|
-
fully_divided, rest = [], []
|
233
254
|
_gcd = b
|
234
255
|
divisor = 1
|
235
256
|
for x in self.flat_components:
|
236
257
|
if x.__class__ in (NumNode, MulNode):
|
237
|
-
if x.b%b == 0: fully_divided.append(x//b)
|
258
|
+
if x.b % b == 0: fully_divided.append(x // b)
|
238
259
|
else:
|
260
|
+
if x.__class__ is NumNode and (div := x.b // b):
|
261
|
+
fully_divided.append(NumNode(div))
|
262
|
+
x = NumNode(x.b - b * div)
|
239
263
|
rest.append(x)
|
240
264
|
if isinstance(x.b, int):
|
241
265
|
_gcd = gcd(_gcd, x.b)
|
242
|
-
if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
|
266
|
+
if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
|
243
267
|
else:
|
244
268
|
_gcd = 1
|
245
269
|
else:
|
@@ -251,39 +275,13 @@ class SumNode(RedNode):
|
|
251
275
|
|
252
276
|
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
253
277
|
def __mod__(self, b: Union[Node, int]):
|
254
|
-
if
|
255
|
-
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
256
|
-
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
257
|
-
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b
|
278
|
+
if self == b: return NumNode(0)
|
258
279
|
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
|
259
|
-
|
260
|
-
|
261
|
-
if x.__class__ in (NumNode, MulNode): new_nodes.append(x%b) # might simplify
|
262
|
-
else: new_nodes.append(x)
|
263
|
-
return Node.__mod__(Node.sum(new_nodes), b)
|
264
|
-
|
265
|
-
def __lt__(self, b:Union[Node,int]):
|
266
|
-
lhs: Node = self
|
267
|
-
if isinstance(b, int):
|
268
|
-
new_sum = []
|
269
|
-
for x in self.nodes:
|
270
|
-
# TODO: should we just force the last one to always be the number
|
271
|
-
if isinstance(x, NumNode): b -= x.b
|
272
|
-
else: new_sum.append(x)
|
273
|
-
lhs = Node.sum(new_sum)
|
274
|
-
nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
|
275
|
-
assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
|
276
|
-
muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
|
277
|
-
if muls:
|
278
|
-
# NOTE: gcd in python 3.8 takes exactly 2 args
|
279
|
-
mul_gcd = b
|
280
|
-
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
|
281
|
-
all_others = Node.sum(others)
|
282
|
-
if all_others.min >= 0 and all_others.max < mul_gcd:
|
283
|
-
lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
|
284
|
-
return Node.__lt__(lhs, b) if isinstance(lhs, SumNode) else lhs < b
|
280
|
+
new_sum = Node.sum([node%b if node.__class__ in (NumNode, MulNode) else node for node in self.nodes])
|
281
|
+
return Node.__mod__(new_sum, b)
|
285
282
|
|
286
|
-
def substitute(self, var_vals:
|
283
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
|
284
|
+
return Node.sum([node.substitute(var_vals) for node in self.nodes])
|
287
285
|
|
288
286
|
# recursively expand sumnode components
|
289
287
|
# TODO: can remove this if there's no SumNode inside SumNode
|
@@ -291,36 +289,39 @@ class SumNode(RedNode):
|
|
291
289
|
def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]
|
292
290
|
|
293
291
|
class AndNode(RedNode):
|
294
|
-
def
|
292
|
+
def get_bounds(self) -> Tuple[int, sint]: return min([x.min for x in self.nodes]), max([x.max for x in self.nodes])
|
293
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
|
295
294
|
subed = []
|
296
295
|
for node in self.nodes:
|
297
296
|
if not (sub:=node.substitute(var_vals)): return NumNode(0)
|
298
297
|
subed.append(sub)
|
299
298
|
return Node.ands(subed)
|
300
299
|
|
301
|
-
def create_rednode(typ:Type[RedNode], nodes:List[Node]):
|
302
|
-
ret = typ(nodes)
|
303
|
-
if typ == SumNode: ret.min, ret.max = (sum([x.min for x in nodes]), sum([x.max for x in nodes]))
|
304
|
-
elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
|
305
|
-
return create_node(ret)
|
306
|
-
|
307
300
|
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
|
308
|
-
def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
|
301
|
+
def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
|
309
302
|
if isinstance(a, (int, float)): return a
|
310
|
-
ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()})
|
303
|
+
ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()}) if var_vals is not None else a
|
311
304
|
assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
|
312
305
|
return ret.b
|
313
306
|
|
314
|
-
# symbolic int
|
315
|
-
sint = Union[
|
307
|
+
# symbolic int, these are allowed in a Tensor shape
|
308
|
+
sint = Union[int, Variable, MulNode, SumNode]
|
309
|
+
|
310
|
+
def render_mulnode(node:MulNode, ops, ctx):
|
311
|
+
# TODO: add ProdNode and remove this case
|
312
|
+
if isinstance(node.a,Variable) and isinstance(node.b,Variable) and node.a.expr and node.b.expr and node.b.expr < node.a.expr:
|
313
|
+
return f"({sym_render(node.b,ops,ctx)}*{node.a.render(ops,ctx)})"
|
314
|
+
return f"({node.a.render(ops,ctx)}*{sym_render(node.b,ops,ctx)})"
|
316
315
|
|
317
|
-
render_python: Dict[Type, Callable] = {
|
318
|
-
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG"
|
316
|
+
render_python: Dict[Type, Callable[..., str]] = {
|
317
|
+
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" \
|
318
|
+
else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" \
|
319
|
+
else f"{self.expr}"),
|
319
320
|
NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
|
320
|
-
MulNode:
|
321
|
+
MulNode: render_mulnode,
|
321
322
|
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
|
322
323
|
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
|
323
324
|
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
|
324
325
|
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
|
325
|
-
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
|
326
|
+
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
|
326
327
|
}
|