tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/shape/symbolic.py
CHANGED
@@ -1,80 +1,62 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from abc import abstractmethod
|
3
2
|
import functools
|
4
3
|
from math import gcd
|
5
4
|
from tinygrad.helpers import partition
|
6
|
-
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
|
5
|
+
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set, Mapping
|
7
6
|
|
8
7
|
# NOTE: Python has different behavior for negative mod and floor div than c
|
9
8
|
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
|
10
9
|
|
11
|
-
def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node))
|
12
|
-
|
13
10
|
class Node:
|
14
11
|
b: Union[Node, int]
|
15
12
|
min: int
|
16
|
-
max:
|
17
|
-
def render(self, ops=None, ctx=None
|
13
|
+
max: sint
|
14
|
+
def render(self, ops=None, ctx=None) -> Any:
|
18
15
|
if ops is None: ops = render_python
|
19
16
|
assert self.__class__ in (Variable, NumNode) or self.min != self.max
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
def
|
17
|
+
return ops[type(self)](self, ops, ctx)
|
18
|
+
def vars(self) -> Set[Variable]: return set()
|
19
|
+
# substitute Variables with the values in var_vals
|
20
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: raise RuntimeError(self.__class__.__name__)
|
21
|
+
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
|
22
|
+
|
24
23
|
@functools.cached_property
|
25
24
|
def key(self) -> str: return self.render(ctx="DEBUG")
|
26
25
|
@functools.cached_property
|
27
26
|
def hash(self) -> int: return hash(self.key)
|
28
|
-
def __repr__(self): return
|
27
|
+
def __repr__(self): return self.render(ctx="REPR")
|
28
|
+
def __str__(self): return "<"+self.key+">"
|
29
29
|
def __hash__(self): return self.hash
|
30
30
|
def __bool__(self): return not (self.max == self.min == 0)
|
31
31
|
def __eq__(self, other:object) -> bool:
|
32
32
|
if not isinstance(other, Node): return NotImplemented
|
33
33
|
return self.key == other.key
|
34
34
|
def __neg__(self): return self*-1
|
35
|
-
def __add__(self, b:Union[Node,int]): return
|
35
|
+
def __add__(self, b:Union[Node,int]): return Node.sum([self, NumNode(b) if isinstance(b, int) else b])
|
36
36
|
def __radd__(self, b:int): return self+b
|
37
37
|
def __sub__(self, b:Union[Node,int]): return self+-b
|
38
38
|
def __rsub__(self, b:int): return -self+b
|
39
39
|
def __le__(self, b:Union[Node,int]): return self < (b+1)
|
40
40
|
def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
|
41
41
|
def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
|
42
|
-
def __lt__(self, b:Union[Node,int]):
|
43
|
-
lhs = self
|
44
|
-
if isinstance(lhs, SumNode) and isinstance(b, int):
|
45
|
-
muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
|
46
|
-
if muls:
|
47
|
-
# NOTE: gcd in python 3.8 takes exactly 2 args
|
48
|
-
mul_gcd = muls[0].b
|
49
|
-
for x in muls[1:]: mul_gcd = gcd(mul_gcd, x.b)
|
50
|
-
if b%mul_gcd == 0:
|
51
|
-
all_others = Variable.sum(others)
|
52
|
-
#print(mul_gcd, muls, all_others)
|
53
|
-
if all_others.min >= 0 and all_others.max < mul_gcd:
|
54
|
-
# TODO: should we divide both by mul_gcd here?
|
55
|
-
lhs = Variable.sum(muls)
|
56
|
-
return create_node(LtNode(lhs, b))
|
42
|
+
def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b))
|
57
43
|
def __mul__(self, b:Union[Node, int]):
|
58
44
|
if b == 0: return NumNode(0)
|
59
45
|
if b == 1: return self
|
60
|
-
if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b
|
61
46
|
return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
|
62
47
|
def __rmul__(self, b:int): return self*b
|
63
48
|
|
64
49
|
# *** complex ops ***
|
65
50
|
|
66
|
-
def __rfloordiv__(self, b:int):
|
67
|
-
if self.min > b >= 0: return NumNode(0)
|
68
|
-
if isinstance(self, NumNode): return NumNode(b // self.b)
|
69
|
-
raise RuntimeError(f"not supported: {b} // {self}")
|
51
|
+
def __rfloordiv__(self, b:int): return NumNode(b) // self
|
70
52
|
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
|
71
53
|
if isinstance(b, Node):
|
72
|
-
if b.__class__ is NumNode: return self
|
54
|
+
if b.__class__ is NumNode: return self.__floordiv__(b.b, factoring_allowed)
|
73
55
|
if self == b: return NumNode(1)
|
74
56
|
if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
|
75
57
|
raise RuntimeError(f"not supported: {self} // {b}")
|
76
58
|
assert b != 0
|
77
|
-
if b < 0: return (self
|
59
|
+
if b < 0: return (self*-1).__floordiv__(-b, factoring_allowed)
|
78
60
|
if b == 1: return self
|
79
61
|
|
80
62
|
# the numerator of div is not allowed to be negative
|
@@ -84,10 +66,7 @@ class Node:
|
|
84
66
|
return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
|
85
67
|
return create_node(DivNode(self, b))
|
86
68
|
|
87
|
-
def __rmod__(self, b:int):
|
88
|
-
if self.min > b >= 0: return NumNode(b)
|
89
|
-
if isinstance(self, NumNode): return NumNode(b % self.b)
|
90
|
-
raise RuntimeError(f"not supported: {b} % {self}")
|
69
|
+
def __rmod__(self, b:int): return NumNode(b) % self
|
91
70
|
def __mod__(self, b:Union[Node,int]):
|
92
71
|
if isinstance(b, Node):
|
93
72
|
if b.__class__ is NumNode: return self % b.b
|
@@ -96,37 +75,27 @@ class Node:
|
|
96
75
|
raise RuntimeError(f"not supported: {self} % {b}")
|
97
76
|
assert b > 0
|
98
77
|
if b == 1: return NumNode(0)
|
99
|
-
if self.
|
100
|
-
|
78
|
+
if isinstance(self.max, int) and isinstance(self.min, int):
|
79
|
+
if self.min >= 0 and self.max < b: return self
|
80
|
+
if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
|
81
|
+
if self.min < 0: return (self - ((self.min//b)*b)) % b
|
101
82
|
return create_node(ModNode(self, b))
|
102
83
|
|
103
|
-
@staticmethod
|
104
|
-
def num(num:int) -> NumNode: return NumNode(num)
|
105
|
-
|
106
|
-
@staticmethod
|
107
|
-
def factorize(nodes:List[Node]) -> List[Node]:
|
108
|
-
mul_groups: Dict[Node, int] = {}
|
109
|
-
for x in nodes:
|
110
|
-
a,b = (x.a,x.b) if isinstance(x, MulNode) else (x,1)
|
111
|
-
mul_groups[a] = mul_groups.get(a, 0) + b
|
112
|
-
return [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
|
113
|
-
|
114
84
|
@staticmethod
|
115
85
|
def sum(nodes:List[Node]) -> Node:
|
116
86
|
nodes = [x for x in nodes if x.max or x.min]
|
117
87
|
if not nodes: return NumNode(0)
|
118
88
|
if len(nodes) == 1: return nodes[0]
|
119
89
|
|
120
|
-
|
90
|
+
mul_groups: Dict[Node, int] = {}
|
121
91
|
num_node_sum = 0
|
122
92
|
for node in SumNode(nodes).flat_components:
|
123
93
|
if node.__class__ is NumNode: num_node_sum += node.b
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
new_nodes = Node.factorize(new_nodes)
|
94
|
+
elif node.__class__ is MulNode: mul_groups[node.a] = mul_groups.get(node.a, 0) + node.b
|
95
|
+
else: mul_groups[node] = mul_groups.get(node, 0) + 1
|
96
|
+
new_nodes = [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
|
128
97
|
if num_node_sum: new_nodes.append(NumNode(num_node_sum))
|
129
|
-
return
|
98
|
+
return create_node(SumNode(new_nodes)) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
|
130
99
|
|
131
100
|
@staticmethod
|
132
101
|
def ands(nodes:List[Node]) -> Node:
|
@@ -136,48 +105,96 @@ class Node:
|
|
136
105
|
|
137
106
|
# filter 1s
|
138
107
|
nodes = [x for x in nodes if x.min != x.max]
|
139
|
-
return
|
108
|
+
return create_node(AndNode(nodes)) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
|
140
109
|
|
141
110
|
# 4 basic node types
|
142
111
|
|
143
112
|
class Variable(Node):
|
144
|
-
def __new__(cls,
|
145
|
-
|
113
|
+
def __new__(cls, *args):
|
114
|
+
expr, nmin, nmax = args
|
115
|
+
assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
|
146
116
|
if nmin == nmax: return NumNode(nmin)
|
147
117
|
return super().__new__(cls)
|
148
118
|
|
149
|
-
def
|
119
|
+
def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling
|
120
|
+
|
121
|
+
def __init__(self, expr:str, nmin:int, nmax:sint):
|
150
122
|
self.expr, self.min, self.max = expr, nmin, nmax
|
151
|
-
|
123
|
+
self._val: Optional[int] = None
|
124
|
+
@property
|
125
|
+
def val(self):
|
126
|
+
assert self._val is not None, f"Variable isn't bound, can't access val of {self}"
|
127
|
+
return self._val
|
128
|
+
def bind(self, val):
|
129
|
+
assert self._val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
|
130
|
+
self._val = val
|
131
|
+
return self
|
132
|
+
def unbind(self) -> Tuple[Variable, int]:
|
133
|
+
assert self.val is not None, f"cannot unbind {self}"
|
134
|
+
return Variable(self.expr, self.min, self.max), self.val
|
135
|
+
def vars(self): return {self}
|
136
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return var_vals.get(self, self)
|
152
137
|
|
153
138
|
class NumNode(Node):
|
154
139
|
def __init__(self, num:int):
|
140
|
+
assert isinstance(num, int), f"{num} is not an int"
|
155
141
|
self.b:int = num
|
156
142
|
self.min, self.max = num, num
|
157
|
-
def
|
158
|
-
|
143
|
+
def bind(self, val):
|
144
|
+
assert self.b == val, f"cannot bind {val} to {self}"
|
145
|
+
return self
|
146
|
+
def __mul__(self, b:Union[Node,int]): return NumNode(self.b*b) if isinstance(b, int) else b*self.b
|
159
147
|
def __eq__(self, other): return self.b == other
|
160
|
-
def __hash__(self): return self.
|
148
|
+
def __hash__(self): return hash(self.b) # needed with __eq__ override
|
149
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self
|
161
150
|
|
162
151
|
def create_node(ret:Node):
|
163
152
|
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
|
164
153
|
if ret.min == ret.max: return NumNode(ret.min)
|
165
154
|
return ret
|
166
155
|
|
156
|
+
def create_lt_node(lhs:Node, b:Union[Node, int]):
|
157
|
+
if isinstance(lhs, SumNode):
|
158
|
+
if isinstance(b, int):
|
159
|
+
new_sum = []
|
160
|
+
for x in lhs.nodes:
|
161
|
+
# TODO: should we just force the last one to always be the number
|
162
|
+
if isinstance(x, NumNode): b -= x.b
|
163
|
+
else: new_sum.append(x)
|
164
|
+
lhs = Node.sum(new_sum)
|
165
|
+
nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
|
166
|
+
assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
|
167
|
+
muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
|
168
|
+
if muls:
|
169
|
+
# NOTE: gcd in python 3.8 takes exactly 2 args
|
170
|
+
mul_gcd = b
|
171
|
+
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
|
172
|
+
all_others = Node.sum(others)
|
173
|
+
if all_others.min >= 0 and all_others.max < mul_gcd:
|
174
|
+
lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
|
175
|
+
return create_node(LtNode(lhs, b)) if isinstance(lhs, SumNode) else create_lt_node(lhs, b)
|
176
|
+
if isinstance(lhs, MulNode):
|
177
|
+
if isinstance(b, Node) or isinstance(lhs.b, Node) or lhs.b == -1: return create_node(LtNode(lhs, b))
|
178
|
+
sgn = 1 if lhs.b > 0 else -1
|
179
|
+
return create_node(LtNode(lhs.a*sgn, (b + abs(lhs.b) - 1)//abs(lhs.b)))
|
180
|
+
return create_node(LtNode(lhs, b))
|
181
|
+
|
182
|
+
def create_ge_node(lhs:Node, b:Union[Node, int]): return create_lt_node(-lhs, -b+1)
|
183
|
+
|
167
184
|
class OpNode(Node):
|
168
185
|
def __init__(self, a:Node, b:Union[Node, int]):
|
169
186
|
self.a, self.b = a, b
|
170
187
|
self.min, self.max = self.get_bounds()
|
171
|
-
def vars(self): return self.a.vars()
|
172
|
-
|
173
|
-
def get_bounds(self) -> Tuple[int, int]: pass
|
188
|
+
def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())
|
189
|
+
def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
|
174
190
|
|
175
191
|
class LtNode(OpNode):
|
176
|
-
def __mul__(self, b: Union[Node, int]): return (self.a*b) < (self.b*b)
|
177
|
-
def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
|
178
192
|
def get_bounds(self) -> Tuple[int, int]:
|
179
|
-
if
|
180
|
-
return (1, 1) if self.a.max < self.b
|
193
|
+
if self.a == self.b: return (0, 0)
|
194
|
+
if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
|
195
|
+
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
|
196
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
|
197
|
+
return create_lt_node(self.a.substitute(var_vals), (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)))
|
181
198
|
|
182
199
|
class MulNode(OpNode):
|
183
200
|
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
|
@@ -185,57 +202,72 @@ class MulNode(OpNode):
|
|
185
202
|
if self.b % b == 0: return self.a*(self.b//b)
|
186
203
|
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
|
187
204
|
return Node.__floordiv__(self, b, factoring_allowed)
|
188
|
-
def __mod__(self, b: Union[Node, int]):
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
|
205
|
+
def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
|
206
|
+
def get_bounds(self) -> Tuple[int, sint]:
|
207
|
+
assert self.a.min >= 0
|
208
|
+
if isinstance(self.b, int): return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
|
209
|
+
return (self.a.min*self.b.min, self.a.max*self.b.max) if self.b.min >= 0 else (self.a.max*self.b.min, self.a.min*self.b.max)
|
210
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
|
211
|
+
return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
|
193
212
|
|
194
213
|
class DivNode(OpNode):
|
195
214
|
def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
|
196
|
-
def get_bounds(self) -> Tuple[int,
|
215
|
+
def get_bounds(self) -> Tuple[int, sint]:
|
197
216
|
assert self.a.min >= 0 and isinstance(self.b, int)
|
198
217
|
return self.a.min//self.b, self.a.max//self.b
|
218
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) // self.b
|
199
219
|
|
200
220
|
class ModNode(OpNode):
|
221
|
+
def __mod__(self, b: Union[Node, int]):
|
222
|
+
if isinstance(b, int) and isinstance(self.b, int) and self.b % b == 0: return self.a % b
|
223
|
+
return Node.__mod__(self, b)
|
201
224
|
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
|
202
|
-
|
203
|
-
|
204
|
-
def get_bounds(self) -> Tuple[int, int]:
|
225
|
+
return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
|
226
|
+
def get_bounds(self) -> Tuple[int, sint]:
|
205
227
|
assert self.a.min >= 0 and isinstance(self.b, int)
|
206
|
-
|
228
|
+
if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b): return (0, self.b-1)
|
229
|
+
return (self.a.min%self.b, self.a.max%self.b)
|
230
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) % self.b
|
207
231
|
|
208
232
|
class RedNode(Node):
|
209
|
-
def __init__(self, nodes:List[Node]):
|
210
|
-
|
233
|
+
def __init__(self, nodes:List[Node]):
|
234
|
+
self.nodes = nodes
|
235
|
+
self.min, self.max = self.get_bounds()
|
236
|
+
def vars(self) -> Set[Variable]: return set.union(*[x.vars() for x in self.nodes], set())
|
237
|
+
def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
|
211
238
|
|
212
239
|
class SumNode(RedNode):
|
240
|
+
def get_bounds(self) -> Tuple[int, sint]: return sum([x.min for x in self.nodes]), sum([x.max for x in self.nodes])
|
241
|
+
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
213
242
|
def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
|
214
|
-
|
243
|
+
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
244
|
+
def __floordiv__(self, b: Union[Node, sint], factoring_allowed=True):
|
245
|
+
if self == b: return NumNode(1)
|
215
246
|
fully_divided: List[Node] = []
|
216
247
|
rest: List[Node] = []
|
217
|
-
if isinstance(b, SumNode):
|
218
|
-
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
219
|
-
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
220
|
-
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b
|
221
248
|
if isinstance(b, Node):
|
222
249
|
for x in self.flat_components:
|
223
250
|
if x % b == 0: fully_divided.append(x // b)
|
224
251
|
else: rest.append(x)
|
225
|
-
if (sum_fully_divided:=
|
252
|
+
if (sum_fully_divided:=create_node(SumNode(fully_divided))) != 0: return sum_fully_divided + create_node(SumNode(rest)) // b
|
226
253
|
return Node.__floordiv__(self, b, False)
|
227
254
|
if b == 1: return self
|
228
255
|
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
|
229
|
-
fully_divided, rest = [], []
|
230
256
|
_gcd = b
|
231
257
|
divisor = 1
|
232
258
|
for x in self.flat_components:
|
233
259
|
if x.__class__ in (NumNode, MulNode):
|
234
|
-
if x.b%b == 0: fully_divided.append(x//b)
|
260
|
+
if x.b % b == 0: fully_divided.append(x // b)
|
235
261
|
else:
|
262
|
+
if x.__class__ is NumNode and (div := x.b // b):
|
263
|
+
fully_divided.append(NumNode(div))
|
264
|
+
x = NumNode(x.b - b * div)
|
236
265
|
rest.append(x)
|
237
|
-
|
238
|
-
|
266
|
+
if isinstance(x.b, int):
|
267
|
+
_gcd = gcd(_gcd, x.b)
|
268
|
+
if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
|
269
|
+
else:
|
270
|
+
_gcd = 1
|
239
271
|
else:
|
240
272
|
rest.append(x)
|
241
273
|
_gcd = 1
|
@@ -243,62 +275,55 @@ class SumNode(RedNode):
|
|
243
275
|
if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
|
244
276
|
return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
|
245
277
|
|
278
|
+
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
246
279
|
def __mod__(self, b: Union[Node, int]):
|
247
|
-
if
|
248
|
-
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
249
|
-
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
250
|
-
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b
|
280
|
+
if self == b: return NumNode(0)
|
251
281
|
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
return Node.__mod__(Node.sum(new_nodes), b)
|
258
|
-
|
259
|
-
def __lt__(self, b:Union[Node,int]):
|
260
|
-
if isinstance(b, int):
|
261
|
-
new_sum = []
|
262
|
-
for x in self.nodes:
|
263
|
-
# TODO: should we just force the last one to always be the number
|
264
|
-
if isinstance(x, NumNode): b -= x.b
|
265
|
-
else: new_sum.append(x)
|
266
|
-
return Node.__lt__(Node.sum(new_sum), b)
|
267
|
-
return Node.__lt__(self, b)
|
282
|
+
new_sum = Node.sum([node%b if node.__class__ in (NumNode, MulNode) else node for node in self.nodes])
|
283
|
+
return Node.__mod__(new_sum, b)
|
284
|
+
|
285
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
|
286
|
+
return Node.sum([node.substitute(var_vals) for node in self.nodes])
|
268
287
|
|
288
|
+
# recursively expand sumnode components
|
289
|
+
# TODO: can remove this if there's no SumNode inside SumNode
|
269
290
|
@property
|
270
|
-
def flat_components(self):
|
271
|
-
new_nodes = []
|
272
|
-
for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x])
|
273
|
-
return new_nodes
|
291
|
+
def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]
|
274
292
|
|
275
293
|
class AndNode(RedNode):
|
276
|
-
def
|
277
|
-
def
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
return create_node(ret)
|
284
|
-
|
285
|
-
def sym_infer(n:Union[Node,int], var_vals: Dict[Variable, int]) -> int:
|
286
|
-
if isinstance(n, (int, NumNode)): return int(n)
|
287
|
-
if isinstance(n, Variable): return var_vals[n]
|
288
|
-
if isinstance(n, MulNode): return sym_infer(n.a, var_vals) * sym_infer(n.b, var_vals)
|
289
|
-
if isinstance(n, SumNode): return sum(sym_infer(s, var_vals) for s in n.nodes)
|
290
|
-
raise NotImplementedError(n)
|
291
|
-
@functools.lru_cache(maxsize=None)
|
292
|
-
def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
|
293
|
-
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
|
294
|
+
def get_bounds(self) -> Tuple[int, sint]: return min([x.min for x in self.nodes]), max([x.max for x in self.nodes])
|
295
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
|
296
|
+
subed = []
|
297
|
+
for node in self.nodes:
|
298
|
+
if not (sub:=node.substitute(var_vals)): return NumNode(0)
|
299
|
+
subed.append(sub)
|
300
|
+
return Node.ands(subed)
|
294
301
|
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
302
|
+
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
|
303
|
+
def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
|
304
|
+
if isinstance(a, (int, float)): return a
|
305
|
+
ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()}) if var_vals is not None else a
|
306
|
+
assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
|
307
|
+
return ret.b
|
308
|
+
|
309
|
+
# symbolic int, these are allowed in a Tensor shape
|
310
|
+
sint = Union[int, Variable, MulNode, SumNode]
|
311
|
+
|
312
|
+
def render_mulnode(node:MulNode, ops, ctx):
|
313
|
+
# TODO: add ProdNode and remove this case
|
314
|
+
if isinstance(node.a,Variable) and isinstance(node.b,Variable) and node.a.expr and node.b.expr and node.b.expr < node.a.expr:
|
315
|
+
return f"({sym_render(node.b,ops,ctx)}*{node.a.render(ops,ctx)})"
|
316
|
+
return f"({node.a.render(ops,ctx)}*{sym_render(node.b,ops,ctx)})"
|
317
|
+
|
318
|
+
render_python: Dict[Type, Callable[..., str]] = {
|
319
|
+
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" \
|
320
|
+
else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" \
|
321
|
+
else f"{self.expr}"),
|
322
|
+
NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
|
323
|
+
MulNode: render_mulnode,
|
299
324
|
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
|
300
325
|
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
|
301
326
|
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
|
302
327
|
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
|
303
|
-
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
|
328
|
+
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
|
304
329
|
}
|