tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/kernel.py +230 -190
- tinygrad/codegen/linearizer.py +278 -384
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +132 -275
- tinygrad/dtype.py +53 -37
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +28 -14
- tinygrad/helpers.py +72 -43
- tinygrad/lazy.py +141 -240
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +179 -8
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +86 -17
- tinygrad/ops.py +70 -44
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +299 -206
- tinygrad/renderer/llvmir.py +118 -123
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +130 -38
- tinygrad/runtime/ops_disk.py +45 -42
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +42 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +41 -105
- tinygrad/shape/symbolic.py +98 -95
- tinygrad/shape/view.py +137 -35
- tinygrad/tensor.py +2367 -442
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/shape/symbolic.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
2
2
|
import functools
|
3
3
|
from math import gcd
|
4
4
|
from tinygrad.helpers import partition
|
5
|
-
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set
|
5
|
+
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set, Mapping
|
6
6
|
|
7
7
|
# NOTE: Python has different behavior for negative mod and floor div than c
|
8
8
|
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
|
@@ -10,14 +10,14 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set
|
|
10
10
|
class Node:
|
11
11
|
b: Union[Node, int]
|
12
12
|
min: int
|
13
|
-
max:
|
13
|
+
max: sint
|
14
14
|
def render(self, ops=None, ctx=None) -> Any:
|
15
15
|
if ops is None: ops = render_python
|
16
16
|
assert self.__class__ in (Variable, NumNode) or self.min != self.max
|
17
17
|
return ops[type(self)](self, ops, ctx)
|
18
18
|
def vars(self) -> Set[Variable]: return set()
|
19
19
|
# substitute Variables with the values in var_vals
|
20
|
-
def substitute(self, var_vals:
|
20
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: raise RuntimeError(self.__class__.__name__)
|
21
21
|
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
|
22
22
|
|
23
23
|
@functools.cached_property
|
@@ -32,7 +32,7 @@ class Node:
|
|
32
32
|
if not isinstance(other, Node): return NotImplemented
|
33
33
|
return self.key == other.key
|
34
34
|
def __neg__(self): return self*-1
|
35
|
-
def __add__(self, b:Union[Node,int]): return Node.sum([self, b if isinstance(b,
|
35
|
+
def __add__(self, b:Union[Node,int]): return Node.sum([self, NumNode(b) if isinstance(b, int) else b])
|
36
36
|
def __radd__(self, b:int): return self+b
|
37
37
|
def __sub__(self, b:Union[Node,int]): return self+-b
|
38
38
|
def __rsub__(self, b:int): return -self+b
|
@@ -43,24 +43,20 @@ class Node:
|
|
43
43
|
def __mul__(self, b:Union[Node, int]):
|
44
44
|
if b == 0: return NumNode(0)
|
45
45
|
if b == 1: return self
|
46
|
-
if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b
|
47
46
|
return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
|
48
47
|
def __rmul__(self, b:int): return self*b
|
49
48
|
|
50
49
|
# *** complex ops ***
|
51
50
|
|
52
|
-
def __rfloordiv__(self, b:int):
|
53
|
-
if self.min > b >= 0: return NumNode(0)
|
54
|
-
if isinstance(self, NumNode): return NumNode(b // self.b)
|
55
|
-
raise RuntimeError(f"not supported: {b} // {self}")
|
51
|
+
def __rfloordiv__(self, b:int): return NumNode(b) // self
|
56
52
|
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
|
57
53
|
if isinstance(b, Node):
|
58
|
-
if b.__class__ is NumNode: return self
|
54
|
+
if b.__class__ is NumNode: return self.__floordiv__(b.b, factoring_allowed)
|
59
55
|
if self == b: return NumNode(1)
|
60
56
|
if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
|
61
57
|
raise RuntimeError(f"not supported: {self} // {b}")
|
62
58
|
assert b != 0
|
63
|
-
if b < 0: return (self
|
59
|
+
if b < 0: return (self*-1).__floordiv__(-b, factoring_allowed)
|
64
60
|
if b == 1: return self
|
65
61
|
|
66
62
|
# the numerator of div is not allowed to be negative
|
@@ -70,10 +66,7 @@ class Node:
|
|
70
66
|
return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
|
71
67
|
return create_node(DivNode(self, b))
|
72
68
|
|
73
|
-
def __rmod__(self, b:int):
|
74
|
-
if self.min > b >= 0: return NumNode(b)
|
75
|
-
if isinstance(self, NumNode): return NumNode(b % self.b)
|
76
|
-
raise RuntimeError(f"not supported: {b} % {self}")
|
69
|
+
def __rmod__(self, b:int): return NumNode(b) % self
|
77
70
|
def __mod__(self, b:Union[Node,int]):
|
78
71
|
if isinstance(b, Node):
|
79
72
|
if b.__class__ is NumNode: return self % b.b
|
@@ -102,7 +95,7 @@ class Node:
|
|
102
95
|
else: mul_groups[node] = mul_groups.get(node, 0) + 1
|
103
96
|
new_nodes = [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
|
104
97
|
if num_node_sum: new_nodes.append(NumNode(num_node_sum))
|
105
|
-
return
|
98
|
+
return create_node(SumNode(new_nodes)) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
|
106
99
|
|
107
100
|
@staticmethod
|
108
101
|
def ands(nodes:List[Node]) -> Node:
|
@@ -112,19 +105,20 @@ class Node:
|
|
112
105
|
|
113
106
|
# filter 1s
|
114
107
|
nodes = [x for x in nodes if x.min != x.max]
|
115
|
-
return
|
108
|
+
return create_node(AndNode(nodes)) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
|
116
109
|
|
117
110
|
# 4 basic node types
|
118
111
|
|
119
112
|
class Variable(Node):
|
120
113
|
def __new__(cls, *args):
|
121
|
-
if len(args) == 0: return super().__new__(cls) # fix pickle
|
122
114
|
expr, nmin, nmax = args
|
123
115
|
assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
|
124
116
|
if nmin == nmax: return NumNode(nmin)
|
125
117
|
return super().__new__(cls)
|
126
118
|
|
127
|
-
def
|
119
|
+
def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling
|
120
|
+
|
121
|
+
def __init__(self, expr:str, nmin:int, nmax:sint):
|
128
122
|
self.expr, self.min, self.max = expr, nmin, nmax
|
129
123
|
self._val: Optional[int] = None
|
130
124
|
@property
|
@@ -139,7 +133,7 @@ class Variable(Node):
|
|
139
133
|
assert self.val is not None, f"cannot unbind {self}"
|
140
134
|
return Variable(self.expr, self.min, self.max), self.val
|
141
135
|
def vars(self): return {self}
|
142
|
-
def substitute(self, var_vals:
|
136
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return var_vals.get(self, self)
|
143
137
|
|
144
138
|
class NumNode(Node):
|
145
139
|
def __init__(self, num:int):
|
@@ -149,97 +143,129 @@ class NumNode(Node):
|
|
149
143
|
def bind(self, val):
|
150
144
|
assert self.b == val, f"cannot bind {val} to {self}"
|
151
145
|
return self
|
146
|
+
def __mul__(self, b:Union[Node,int]): return NumNode(self.b*b) if isinstance(b, int) else b*self.b
|
152
147
|
def __eq__(self, other): return self.b == other
|
153
|
-
def __hash__(self): return self.
|
154
|
-
def substitute(self, var_vals:
|
148
|
+
def __hash__(self): return hash(self.b) # needed with __eq__ override
|
149
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self
|
155
150
|
|
156
151
|
def create_node(ret:Node):
|
157
152
|
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
|
158
153
|
if ret.min == ret.max: return NumNode(ret.min)
|
159
154
|
return ret
|
160
155
|
|
156
|
+
def create_lt_node(lhs:Node, b:Union[Node, int]):
|
157
|
+
if isinstance(lhs, SumNode):
|
158
|
+
if isinstance(b, int):
|
159
|
+
new_sum = []
|
160
|
+
for x in lhs.nodes:
|
161
|
+
# TODO: should we just force the last one to always be the number
|
162
|
+
if isinstance(x, NumNode): b -= x.b
|
163
|
+
else: new_sum.append(x)
|
164
|
+
lhs = Node.sum(new_sum)
|
165
|
+
nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
|
166
|
+
assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
|
167
|
+
muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
|
168
|
+
if muls:
|
169
|
+
# NOTE: gcd in python 3.8 takes exactly 2 args
|
170
|
+
mul_gcd = b
|
171
|
+
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
|
172
|
+
all_others = Node.sum(others)
|
173
|
+
if all_others.min >= 0 and all_others.max < mul_gcd:
|
174
|
+
lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
|
175
|
+
return create_node(LtNode(lhs, b)) if isinstance(lhs, SumNode) else create_lt_node(lhs, b)
|
176
|
+
if isinstance(lhs, MulNode):
|
177
|
+
if isinstance(b, Node) or isinstance(lhs.b, Node) or lhs.b == -1: return create_node(LtNode(lhs, b))
|
178
|
+
sgn = 1 if lhs.b > 0 else -1
|
179
|
+
return create_node(LtNode(lhs.a*sgn, (b + abs(lhs.b) - 1)//abs(lhs.b)))
|
180
|
+
return create_node(LtNode(lhs, b))
|
181
|
+
|
182
|
+
def create_ge_node(lhs:Node, b:Union[Node, int]): return create_lt_node(-lhs, -b+1)
|
183
|
+
|
161
184
|
class OpNode(Node):
|
162
185
|
def __init__(self, a:Node, b:Union[Node, int]):
|
163
186
|
self.a, self.b = a, b
|
164
187
|
self.min, self.max = self.get_bounds()
|
165
188
|
def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())
|
166
|
-
def get_bounds(self) -> Tuple[int,
|
189
|
+
def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
|
167
190
|
|
168
191
|
class LtNode(OpNode):
|
169
|
-
def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
|
170
192
|
def get_bounds(self) -> Tuple[int, int]:
|
193
|
+
if self.a == self.b: return (0, 0)
|
171
194
|
if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
|
172
195
|
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
|
173
|
-
def substitute(self, var_vals:
|
174
|
-
return self.a.substitute(var_vals)
|
196
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
|
197
|
+
return create_lt_node(self.a.substitute(var_vals), (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)))
|
175
198
|
|
176
199
|
class MulNode(OpNode):
|
177
|
-
def __lt__(self, b: Union[Node, int]):
|
178
|
-
if isinstance(b, Node) or isinstance(self.b, Node) or self.b == -1: return Node.__lt__(self, b)
|
179
|
-
sgn = 1 if self.b > 0 else -1
|
180
|
-
return Node.__lt__(self.a*sgn, (b + abs(self.b) - 1)//abs(self.b))
|
181
200
|
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
|
182
201
|
def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
|
183
202
|
if self.b % b == 0: return self.a*(self.b//b)
|
184
203
|
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
|
185
204
|
return Node.__floordiv__(self, b, factoring_allowed)
|
186
205
|
def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
|
187
|
-
def get_bounds(self) -> Tuple[int,
|
188
|
-
|
206
|
+
def get_bounds(self) -> Tuple[int, sint]:
|
207
|
+
assert self.a.min >= 0
|
208
|
+
if isinstance(self.b, int): return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
|
209
|
+
return (self.a.min*self.b.min, self.a.max*self.b.max) if self.b.min >= 0 else (self.a.max*self.b.min, self.a.min*self.b.max)
|
210
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
|
189
211
|
return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
|
190
212
|
|
191
213
|
class DivNode(OpNode):
|
192
214
|
def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
|
193
|
-
def get_bounds(self) -> Tuple[int,
|
215
|
+
def get_bounds(self) -> Tuple[int, sint]:
|
194
216
|
assert self.a.min >= 0 and isinstance(self.b, int)
|
195
217
|
return self.a.min//self.b, self.a.max//self.b
|
196
|
-
def substitute(self, var_vals:
|
218
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) // self.b
|
197
219
|
|
198
220
|
class ModNode(OpNode):
|
199
221
|
def __mod__(self, b: Union[Node, int]):
|
200
|
-
if isinstance(b,
|
201
|
-
return
|
222
|
+
if isinstance(b, int) and isinstance(self.b, int) and self.b % b == 0: return self.a % b
|
223
|
+
return Node.__mod__(self, b)
|
202
224
|
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
|
203
225
|
return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
|
204
|
-
def get_bounds(self) -> Tuple[int,
|
226
|
+
def get_bounds(self) -> Tuple[int, sint]:
|
205
227
|
assert self.a.min >= 0 and isinstance(self.b, int)
|
206
|
-
|
207
|
-
|
228
|
+
if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b): return (0, self.b-1)
|
229
|
+
return (self.a.min%self.b, self.a.max%self.b)
|
230
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) % self.b
|
208
231
|
|
209
232
|
class RedNode(Node):
|
210
|
-
def __init__(self, nodes:List[Node]):
|
233
|
+
def __init__(self, nodes:List[Node]):
|
234
|
+
self.nodes = nodes
|
235
|
+
self.min, self.max = self.get_bounds()
|
211
236
|
def vars(self) -> Set[Variable]: return set.union(*[x.vars() for x in self.nodes], set())
|
237
|
+
def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
|
212
238
|
|
213
239
|
class SumNode(RedNode):
|
240
|
+
def get_bounds(self) -> Tuple[int, sint]: return sum([x.min for x in self.nodes]), sum([x.max for x in self.nodes])
|
214
241
|
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
215
242
|
def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
|
216
243
|
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
217
|
-
def __floordiv__(self, b: Union[Node,
|
244
|
+
def __floordiv__(self, b: Union[Node, sint], factoring_allowed=True):
|
245
|
+
if self == b: return NumNode(1)
|
218
246
|
fully_divided: List[Node] = []
|
219
247
|
rest: List[Node] = []
|
220
|
-
if isinstance(b, SumNode):
|
221
|
-
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
222
|
-
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
223
|
-
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b
|
224
248
|
if isinstance(b, Node):
|
225
249
|
for x in self.flat_components:
|
226
250
|
if x % b == 0: fully_divided.append(x // b)
|
227
251
|
else: rest.append(x)
|
228
|
-
if (sum_fully_divided:=
|
252
|
+
if (sum_fully_divided:=create_node(SumNode(fully_divided))) != 0: return sum_fully_divided + create_node(SumNode(rest)) // b
|
229
253
|
return Node.__floordiv__(self, b, False)
|
230
254
|
if b == 1: return self
|
231
255
|
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
|
232
|
-
fully_divided, rest = [], []
|
233
256
|
_gcd = b
|
234
257
|
divisor = 1
|
235
258
|
for x in self.flat_components:
|
236
259
|
if x.__class__ in (NumNode, MulNode):
|
237
|
-
if x.b%b == 0: fully_divided.append(x//b)
|
260
|
+
if x.b % b == 0: fully_divided.append(x // b)
|
238
261
|
else:
|
262
|
+
if x.__class__ is NumNode and (div := x.b // b):
|
263
|
+
fully_divided.append(NumNode(div))
|
264
|
+
x = NumNode(x.b - b * div)
|
239
265
|
rest.append(x)
|
240
266
|
if isinstance(x.b, int):
|
241
267
|
_gcd = gcd(_gcd, x.b)
|
242
|
-
if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
|
268
|
+
if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
|
243
269
|
else:
|
244
270
|
_gcd = 1
|
245
271
|
else:
|
@@ -251,39 +277,13 @@ class SumNode(RedNode):
|
|
251
277
|
|
252
278
|
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
253
279
|
def __mod__(self, b: Union[Node, int]):
|
254
|
-
if
|
255
|
-
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
256
|
-
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
257
|
-
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b
|
280
|
+
if self == b: return NumNode(0)
|
258
281
|
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
|
259
|
-
|
260
|
-
|
261
|
-
if x.__class__ in (NumNode, MulNode): new_nodes.append(x%b) # might simplify
|
262
|
-
else: new_nodes.append(x)
|
263
|
-
return Node.__mod__(Node.sum(new_nodes), b)
|
264
|
-
|
265
|
-
def __lt__(self, b:Union[Node,int]):
|
266
|
-
lhs: Node = self
|
267
|
-
if isinstance(b, int):
|
268
|
-
new_sum = []
|
269
|
-
for x in self.nodes:
|
270
|
-
# TODO: should we just force the last one to always be the number
|
271
|
-
if isinstance(x, NumNode): b -= x.b
|
272
|
-
else: new_sum.append(x)
|
273
|
-
lhs = Node.sum(new_sum)
|
274
|
-
nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
|
275
|
-
assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
|
276
|
-
muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
|
277
|
-
if muls:
|
278
|
-
# NOTE: gcd in python 3.8 takes exactly 2 args
|
279
|
-
mul_gcd = b
|
280
|
-
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
|
281
|
-
all_others = Node.sum(others)
|
282
|
-
if all_others.min >= 0 and all_others.max < mul_gcd:
|
283
|
-
lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
|
284
|
-
return Node.__lt__(lhs, b) if isinstance(lhs, SumNode) else lhs < b
|
282
|
+
new_sum = Node.sum([node%b if node.__class__ in (NumNode, MulNode) else node for node in self.nodes])
|
283
|
+
return Node.__mod__(new_sum, b)
|
285
284
|
|
286
|
-
def substitute(self, var_vals:
|
285
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
|
286
|
+
return Node.sum([node.substitute(var_vals) for node in self.nodes])
|
287
287
|
|
288
288
|
# recursively expand sumnode components
|
289
289
|
# TODO: can remove this if there's no SumNode inside SumNode
|
@@ -291,36 +291,39 @@ class SumNode(RedNode):
|
|
291
291
|
def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]
|
292
292
|
|
293
293
|
class AndNode(RedNode):
|
294
|
-
def
|
294
|
+
def get_bounds(self) -> Tuple[int, sint]: return min([x.min for x in self.nodes]), max([x.max for x in self.nodes])
|
295
|
+
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
|
295
296
|
subed = []
|
296
297
|
for node in self.nodes:
|
297
298
|
if not (sub:=node.substitute(var_vals)): return NumNode(0)
|
298
299
|
subed.append(sub)
|
299
300
|
return Node.ands(subed)
|
300
301
|
|
301
|
-
def create_rednode(typ:Type[RedNode], nodes:List[Node]):
|
302
|
-
ret = typ(nodes)
|
303
|
-
if typ == SumNode: ret.min, ret.max = (sum([x.min for x in nodes]), sum([x.max for x in nodes]))
|
304
|
-
elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
|
305
|
-
return create_node(ret)
|
306
|
-
|
307
302
|
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
|
308
|
-
def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
|
303
|
+
def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
|
309
304
|
if isinstance(a, (int, float)): return a
|
310
|
-
ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()})
|
305
|
+
ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()}) if var_vals is not None else a
|
311
306
|
assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
|
312
307
|
return ret.b
|
313
308
|
|
314
|
-
# symbolic int
|
315
|
-
sint = Union[
|
309
|
+
# symbolic int, these are allowed in a Tensor shape
|
310
|
+
sint = Union[int, Variable, MulNode, SumNode]
|
311
|
+
|
312
|
+
def render_mulnode(node:MulNode, ops, ctx):
|
313
|
+
# TODO: add ProdNode and remove this case
|
314
|
+
if isinstance(node.a,Variable) and isinstance(node.b,Variable) and node.a.expr and node.b.expr and node.b.expr < node.a.expr:
|
315
|
+
return f"({sym_render(node.b,ops,ctx)}*{node.a.render(ops,ctx)})"
|
316
|
+
return f"({node.a.render(ops,ctx)}*{sym_render(node.b,ops,ctx)})"
|
316
317
|
|
317
|
-
render_python: Dict[Type, Callable] = {
|
318
|
-
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG"
|
318
|
+
render_python: Dict[Type, Callable[..., str]] = {
|
319
|
+
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" \
|
320
|
+
else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" \
|
321
|
+
else f"{self.expr}"),
|
319
322
|
NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
|
320
|
-
MulNode:
|
323
|
+
MulNode: render_mulnode,
|
321
324
|
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
|
322
325
|
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
|
323
326
|
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
|
324
327
|
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
|
325
|
-
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
|
328
|
+
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
|
326
329
|
}
|
tinygrad/shape/view.py
CHANGED
@@ -1,35 +1,35 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
import functools, operator
|
2
|
+
import functools, operator, itertools, math
|
3
3
|
from dataclasses import dataclass
|
4
|
-
from typing import Tuple, List, Optional, Dict, cast
|
4
|
+
from typing import Tuple, List, Optional, Dict, Set, cast
|
5
5
|
from tinygrad.helpers import prod, all_int, argsort
|
6
|
-
from tinygrad.shape.symbolic import Node, NumNode, Variable,
|
6
|
+
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint
|
7
7
|
|
8
8
|
@functools.lru_cache(maxsize=None)
|
9
|
-
def
|
10
|
-
return tuple(
|
9
|
+
def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
|
10
|
+
return tuple(0 if s == 1 else st for s, st in zip(shape, strides))
|
11
11
|
|
12
12
|
@functools.lru_cache(maxsize=None)
|
13
|
-
def strides_for_shape(shape:Tuple[
|
14
|
-
|
15
|
-
|
16
|
-
return
|
13
|
+
def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
|
14
|
+
if not shape: return ()
|
15
|
+
strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))
|
16
|
+
return canonicalize_strides(shape, strides[::-1])
|
17
17
|
|
18
18
|
@functools.lru_cache(maxsize=None)
|
19
|
-
def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]
|
19
|
+
def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
|
20
20
|
# merge contiguous subparts or zero strided dims. ret = List[(merged_dims, stride, merged dims w/o zero stride), ...]
|
21
21
|
if not shape: return tuple()
|
22
22
|
assert len(shape) == len(strides)
|
23
23
|
ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)]
|
24
|
-
#
|
25
|
-
|
24
|
+
# wrt merging zero strided dimensions
|
25
|
+
merging = strides[0] == 0 and (mask[0][1] - mask[0][0] == 1 if mask else shape[0] == 1)
|
26
26
|
for i, (sh, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
|
27
27
|
if sh == 1: continue
|
28
|
-
if
|
29
|
-
ret[-1] = (ret[-1][0] * sh, st, (sh if
|
28
|
+
if merging or ret[-1][1] == sh * st: # mergeable
|
29
|
+
ret[-1] = (ret[-1][0] * sh, st, (sh if merging else ret[-1][2] * sh) if st else 0)
|
30
30
|
else: ret.append((sh, st, sh if st else 0)) # begin new
|
31
31
|
# merging ends with either non-zero strided dim or zero strided dim with mask range > 1
|
32
|
-
|
32
|
+
merging = st == 0 and (mask[i][1] - mask[i][0] == 1 if mask else sh == 1)
|
33
33
|
return tuple(ret)
|
34
34
|
|
35
35
|
@functools.lru_cache(maxsize=None)
|
@@ -52,7 +52,8 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
|
|
52
52
|
if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
|
53
53
|
|
54
54
|
else: # mask can only be splitted if reshape doesn't cut across the mask.
|
55
|
-
if ((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
|
55
|
+
if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
|
56
|
+
or old_dim % next_stride != 0): return view.mask, True
|
56
57
|
new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
|
57
58
|
curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
|
58
59
|
|
@@ -67,6 +68,15 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
|
|
67
68
|
|
68
69
|
return tuple(reversed(new_mask)), False
|
69
70
|
|
71
|
+
def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
|
72
|
+
strides = strides_for_shape(shape)
|
73
|
+
result = []
|
74
|
+
for stride in strides:
|
75
|
+
here = offs // stride if stride else 0
|
76
|
+
result.append(here)
|
77
|
+
offs -= here * stride
|
78
|
+
return result
|
79
|
+
|
70
80
|
@dataclass(frozen=True)
|
71
81
|
class View:
|
72
82
|
shape:Tuple[sint, ...]
|
@@ -75,11 +85,28 @@ class View:
|
|
75
85
|
mask:Optional[Tuple[Tuple[sint, sint], ...]]
|
76
86
|
contiguous:bool
|
77
87
|
|
88
|
+
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
89
|
+
def size(self) -> int:
|
90
|
+
# NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
|
91
|
+
ret = prod([x.max if isinstance(x, Node) else x for x in self.shape])
|
92
|
+
assert isinstance(ret, int), f"{ret=} is not int"
|
93
|
+
return ret
|
94
|
+
|
78
95
|
@staticmethod
|
79
96
|
@functools.lru_cache(maxsize=None)
|
80
97
|
def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
|
81
|
-
strides =
|
98
|
+
strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
|
99
|
+
# canonicalize empty mask
|
100
|
+
if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
|
82
101
|
contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
|
102
|
+
# if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
|
103
|
+
# then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
|
104
|
+
# TODO: assert comparison with LtNode to avoid mis-using symbolic
|
105
|
+
if mask and any(elim := [not (b+1 < e) for b,e in mask]):
|
106
|
+
if any(not (b < e) for b,e in mask):
|
107
|
+
strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
|
108
|
+
offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
|
109
|
+
strides = tuple(0 if e else st for st,e in zip(strides, elim))
|
83
110
|
return View(shape, strides, offset, mask, contiguous)
|
84
111
|
|
85
112
|
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
@@ -88,13 +115,81 @@ class View:
|
|
88
115
|
return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set())
|
89
116
|
|
90
117
|
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
91
|
-
def unbind(self) -> View:
|
92
|
-
|
118
|
+
def unbind(self) -> Tuple[View, Dict[Variable, int]]:
|
119
|
+
var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.val is not None]
|
120
|
+
unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
|
93
121
|
new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
|
94
122
|
new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
|
95
123
|
new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
|
96
|
-
new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars),
|
97
|
-
|
124
|
+
new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars),
|
125
|
+
b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
|
126
|
+
return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
|
127
|
+
|
128
|
+
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
129
|
+
def __add__(self, vm1:View) -> Optional[View]:
|
130
|
+
vm2 = self
|
131
|
+
if vm2.contiguous: return vm1
|
132
|
+
if vm1.contiguous and vm1.shape == vm2.shape: return vm2
|
133
|
+
if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
|
134
|
+
if vm1.mask:
|
135
|
+
for b,e in vm1.mask:
|
136
|
+
if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
|
137
|
+
return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
|
138
|
+
|
139
|
+
# Project vm1's offset and strides on to vm2.
|
140
|
+
origin = un1d(vm2.shape, vm1.offset)
|
141
|
+
terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
|
142
|
+
strides: List[sint] = [0] * len(vm1.shape)
|
143
|
+
for d1, st in enumerate(vm1.strides):
|
144
|
+
if st == 0: continue
|
145
|
+
for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
|
146
|
+
if (s1 := s1 - o) == 0: continue
|
147
|
+
terms[d2].append((d1, s1))
|
148
|
+
strides[d1] += s1 * vm2.strides[d2]
|
149
|
+
|
150
|
+
# Merge dimensions in vm2 if required.
|
151
|
+
# NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
|
152
|
+
idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
|
153
|
+
merged_size, merged_term = 1, NumNode(0)
|
154
|
+
extents: List[Tuple[sint, Node]] = []
|
155
|
+
for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
|
156
|
+
merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
|
157
|
+
merged_size *= s
|
158
|
+
if not (merged_term >= merged_size) and not (merged_term < 0):
|
159
|
+
extents.append((merged_size, merged_term))
|
160
|
+
merged_size, merged_term = 1, NumNode(0)
|
161
|
+
if merged_term: return None
|
162
|
+
if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
|
163
|
+
return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1
|
164
|
+
|
165
|
+
if vm2.mask:
|
166
|
+
# Try to project vm2's mask on to vm1.
|
167
|
+
newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
|
168
|
+
for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
|
169
|
+
if not (t.min < b or t.max >= e): continue
|
170
|
+
if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
|
171
|
+
bad = True
|
172
|
+
continue
|
173
|
+
term = terms[d2]
|
174
|
+
if len(term) != 1:
|
175
|
+
if not term and newe: newe[0] = 0
|
176
|
+
else: bad = True
|
177
|
+
continue
|
178
|
+
d1, s1 = term[0]
|
179
|
+
if not isinstance(s1, int) or not isinstance(newe[d1], int):
|
180
|
+
bad = True
|
181
|
+
continue
|
182
|
+
newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
|
183
|
+
newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
|
184
|
+
|
185
|
+
# If any of vm1 was masked off, try again with that mask in place.
|
186
|
+
for b, e, s in zip(newb, newe, vm1.shape):
|
187
|
+
if b != 0 or e != s:
|
188
|
+
return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
|
189
|
+
# Otherwise if vm2's mask was violated, then cannot merge.
|
190
|
+
if bad: return None
|
191
|
+
|
192
|
+
return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
|
98
193
|
|
99
194
|
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
100
195
|
def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
|
@@ -103,7 +198,10 @@ class View:
|
|
103
198
|
ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
|
104
199
|
return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)
|
105
200
|
|
106
|
-
#
|
201
|
+
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
202
|
+
def minify(self):
|
203
|
+
min_shape = tuple(x[0] for x in _merge_dims(self.shape, self.strides, self.mask))
|
204
|
+
return nv if (nv := self.reshape(min_shape)) else self
|
107
205
|
|
108
206
|
def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
|
109
207
|
offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
|
@@ -117,8 +215,8 @@ class View:
|
|
117
215
|
return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)
|
118
216
|
|
119
217
|
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
120
|
-
def pad(self, arg: Tuple[Tuple[
|
121
|
-
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
|
218
|
+
def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
|
219
|
+
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}"
|
122
220
|
if any(b or e for b, e in arg):
|
123
221
|
zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
|
124
222
|
mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
|
@@ -145,7 +243,8 @@ class View:
|
|
145
243
|
def permute(self, axis: Tuple[int, ...]) -> View:
|
146
244
|
assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
|
147
245
|
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
|
148
|
-
return View.create(tuple(
|
246
|
+
return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
|
247
|
+
tuple(self.mask[a] for a in axis) if self.mask is not None else None)
|
149
248
|
|
150
249
|
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
151
250
|
def stride(self, mul: Tuple[int, ...]) -> View:
|
@@ -154,7 +253,8 @@ class View:
|
|
154
253
|
strides = tuple([z*m for z,m in zip(self.strides, mul)])
|
155
254
|
new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
|
156
255
|
offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
|
157
|
-
mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m))
|
256
|
+
mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) \
|
257
|
+
for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
|
158
258
|
return View.create(new_shape, strides, self.offset + offset, mask)
|
159
259
|
|
160
260
|
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
@@ -177,18 +277,20 @@ class View:
|
|
177
277
|
if self.contiguous: return View.create(new_shape)
|
178
278
|
|
179
279
|
strides, r_new_shape = [], reversed(new_shape)
|
180
|
-
for merged_dim,
|
181
|
-
acc
|
280
|
+
for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
|
281
|
+
acc = 1
|
282
|
+
# TODO: this <= and != is for symbolic!?
|
182
283
|
while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, None)):
|
183
|
-
strides.append(new_stride
|
184
|
-
if new_dim
|
185
|
-
new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
|
284
|
+
strides.append(new_stride)
|
285
|
+
if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
|
186
286
|
if acc != merged_dim: break
|
187
287
|
else:
|
188
288
|
strides += [0,] * (len(new_shape) - len(strides))
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
289
|
+
new_mask, extra = _reshape_mask(self, new_shape)
|
290
|
+
if not extra:
|
291
|
+
new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask) if new_mask else new_shape, tuple(reversed(strides)))
|
292
|
+
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
|
293
|
+
(sum(m[0] * s for m,s in zip(new_mask, new_strides)) if new_mask else 0)
|
294
|
+
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
|
193
295
|
|
194
296
|
return None
|