tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
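The file list above implies a reorganization of the package layout: the features/ and top-level jit/graph/realize modules are superseded by a new tinygrad/engine/ package, mlops.py is renamed to function.py, and several runtime backends are added (amd, nv, hsa, python) while cpu, torch, and hip are removed. As a quick orientation, the sketch below pairs removed 0.8.0 module paths with the 0.9.0 modules added in their place; it only assumes the new modules are importable and does not rely on any symbol names, which this diff does not show for most files.

import tinygrad.function        # was tinygrad/mlops.py (rename shown in entry 13)
import tinygrad.engine.jit      # was tinygrad/jit.py
import tinygrad.engine.realize  # was tinygrad/realize.py
import tinygrad.engine.search   # was tinygrad/features/search.py
import tinygrad.engine.graph    # was tinygrad/graph.py
import tinygrad.multi           # was tinygrad/features/multi.py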
tinygrad/shape/symbolic.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
  import functools
  from math import gcd
  from tinygrad.helpers import partition
- from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set
+ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set, Mapping

  # NOTE: Python has different behavior for negative mod and floor div than c
  # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
@@ -10,14 +10,14 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set
  class Node:
  b: Union[Node, int]
  min: int
- max: int
+ max: sint
  def render(self, ops=None, ctx=None) -> Any:
  if ops is None: ops = render_python
  assert self.__class__ in (Variable, NumNode) or self.min != self.max
  return ops[type(self)](self, ops, ctx)
  def vars(self) -> Set[Variable]: return set()
  # substitute Variables with the values in var_vals
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node: raise RuntimeError(self.__class__.__name__)
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: raise RuntimeError(self.__class__.__name__)
  def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None

  @functools.cached_property
@@ -32,7 +32,7 @@ class Node:
  if not isinstance(other, Node): return NotImplemented
  return self.key == other.key
  def __neg__(self): return self*-1
- def __add__(self, b:Union[Node,int]): return Node.sum([self, b if isinstance(b, Node) else NumNode(b)])
+ def __add__(self, b:Union[Node,int]): return Node.sum([self, NumNode(b) if isinstance(b, int) else b])
  def __radd__(self, b:int): return self+b
  def __sub__(self, b:Union[Node,int]): return self+-b
  def __rsub__(self, b:int): return -self+b
@@ -43,24 +43,20 @@ class Node:
  def __mul__(self, b:Union[Node, int]):
  if b == 0: return NumNode(0)
  if b == 1: return self
- if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b
  return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
  def __rmul__(self, b:int): return self*b

  # *** complex ops ***

- def __rfloordiv__(self, b:int):
- if self.min > b >= 0: return NumNode(0)
- if isinstance(self, NumNode): return NumNode(b // self.b)
- raise RuntimeError(f"not supported: {b} // {self}")
+ def __rfloordiv__(self, b:int): return NumNode(b) // self
  def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
  if isinstance(b, Node):
- if b.__class__ is NumNode: return self // b.b
+ if b.__class__ is NumNode: return self.__floordiv__(b.b, factoring_allowed)
  if self == b: return NumNode(1)
  if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
  raise RuntimeError(f"not supported: {self} // {b}")
  assert b != 0
- if b < 0: return (self//-b)*-1
+ if b < 0: return (self*-1).__floordiv__(-b, factoring_allowed)
  if b == 1: return self

  # the numerator of div is not allowed to be negative
@@ -70,10 +66,7 @@ class Node:
  return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
  return create_node(DivNode(self, b))

- def __rmod__(self, b:int):
- if self.min > b >= 0: return NumNode(b)
- if isinstance(self, NumNode): return NumNode(b % self.b)
- raise RuntimeError(f"not supported: {b} % {self}")
+ def __rmod__(self, b:int): return NumNode(b) % self
  def __mod__(self, b:Union[Node,int]):
  if isinstance(b, Node):
  if b.__class__ is NumNode: return self % b.b
@@ -102,7 +95,7 @@ class Node:
  else: mul_groups[node] = mul_groups.get(node, 0) + 1
  new_nodes = [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
  if num_node_sum: new_nodes.append(NumNode(num_node_sum))
- return create_rednode(SumNode, new_nodes) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
+ return create_node(SumNode(new_nodes)) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)

  @staticmethod
  def ands(nodes:List[Node]) -> Node:
@@ -112,19 +105,20 @@ class Node:

  # filter 1s
  nodes = [x for x in nodes if x.min != x.max]
- return create_rednode(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
+ return create_node(AndNode(nodes)) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))

  # 4 basic node types

  class Variable(Node):
  def __new__(cls, *args):
- if len(args) == 0: return super().__new__(cls) # fix pickle
  expr, nmin, nmax = args
  assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
  if nmin == nmax: return NumNode(nmin)
  return super().__new__(cls)

- def __init__(self, expr:Optional[str], nmin:int, nmax:int):
+ def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling
+
+ def __init__(self, expr:str, nmin:int, nmax:sint):
  self.expr, self.min, self.max = expr, nmin, nmax
  self._val: Optional[int] = None
  @property
@@ -139,7 +133,7 @@ class Variable(Node):
  assert self.val is not None, f"cannot unbind {self}"
  return Variable(self.expr, self.min, self.max), self.val
  def vars(self): return {self}
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node: return var_vals.get(self, self)
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return var_vals.get(self, self)

  class NumNode(Node):
  def __init__(self, num:int):
@@ -149,97 +143,129 @@ class NumNode(Node):
  def bind(self, val):
  assert self.b == val, f"cannot bind {val} to {self}"
  return self
+ def __mul__(self, b:Union[Node,int]): return NumNode(self.b*b) if isinstance(b, int) else b*self.b
  def __eq__(self, other): return self.b == other
- def __hash__(self): return self.hash # needed with __eq__ override
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node: return self
+ def __hash__(self): return hash(self.b) # needed with __eq__ override
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self

  def create_node(ret:Node):
  assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
  if ret.min == ret.max: return NumNode(ret.min)
  return ret

+ def create_lt_node(lhs:Node, b:Union[Node, int]):
+ if isinstance(lhs, SumNode):
+ if isinstance(b, int):
+ new_sum = []
+ for x in lhs.nodes:
+ # TODO: should we just force the last one to always be the number
+ if isinstance(x, NumNode): b -= x.b
+ else: new_sum.append(x)
+ lhs = Node.sum(new_sum)
+ nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
+ assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
+ muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
+ if muls:
+ # NOTE: gcd in python 3.8 takes exactly 2 args
+ mul_gcd = b
+ for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
+ all_others = Node.sum(others)
+ if all_others.min >= 0 and all_others.max < mul_gcd:
+ lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
+ return create_node(LtNode(lhs, b)) if isinstance(lhs, SumNode) else create_lt_node(lhs, b)
+ if isinstance(lhs, MulNode):
+ if isinstance(b, Node) or isinstance(lhs.b, Node) or lhs.b == -1: return create_node(LtNode(lhs, b))
+ sgn = 1 if lhs.b > 0 else -1
+ return create_node(LtNode(lhs.a*sgn, (b + abs(lhs.b) - 1)//abs(lhs.b)))
+ return create_node(LtNode(lhs, b))
+
+ def create_ge_node(lhs:Node, b:Union[Node, int]): return create_lt_node(-lhs, -b+1)
+
  class OpNode(Node):
  def __init__(self, a:Node, b:Union[Node, int]):
  self.a, self.b = a, b
  self.min, self.max = self.get_bounds()
  def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())
- def get_bounds(self) -> Tuple[int, int]: raise NotImplementedError("must be implemented")
+ def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")

  class LtNode(OpNode):
- def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
  def get_bounds(self) -> Tuple[int, int]:
+ if self.a == self.b: return (0, 0)
  if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
  return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node:
- return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
+ return create_lt_node(self.a.substitute(var_vals), (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)))

  class MulNode(OpNode):
- def __lt__(self, b: Union[Node, int]):
- if isinstance(b, Node) or isinstance(self.b, Node) or self.b == -1: return Node.__lt__(self, b)
- sgn = 1 if self.b > 0 else -1
- return Node.__lt__(self.a*sgn, (b + abs(self.b) - 1)//abs(self.b))
  def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
  def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
  if self.b % b == 0: return self.a*(self.b//b)
  if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
  return Node.__floordiv__(self, b, factoring_allowed)
  def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
- def get_bounds(self) -> Tuple[int, int]: return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node:
+ def get_bounds(self) -> Tuple[int, sint]:
+ assert self.a.min >= 0
+ if isinstance(self.b, int): return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
+ return (self.a.min*self.b.min, self.a.max*self.b.max) if self.b.min >= 0 else (self.a.max*self.b.min, self.a.min*self.b.max)
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
  return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))

  class DivNode(OpNode):
  def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
- def get_bounds(self) -> Tuple[int, int]:
+ def get_bounds(self) -> Tuple[int, sint]:
  assert self.a.min >= 0 and isinstance(self.b, int)
  return self.a.min//self.b, self.a.max//self.b
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node: return self.a.substitute(var_vals) // self.b
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) // self.b

  class ModNode(OpNode):
  def __mod__(self, b: Union[Node, int]):
- if isinstance(b, Node) or isinstance(self.b, Node): return Node.__mod__(self, b)
- return self.a % b if self.b % b == 0 else Node.__mod__(self, b)
+ if isinstance(b, int) and isinstance(self.b, int) and self.b % b == 0: return self.a % b
+ return Node.__mod__(self, b)
  def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
  return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
- def get_bounds(self) -> Tuple[int, int]:
+ def get_bounds(self) -> Tuple[int, sint]:
  assert self.a.min >= 0 and isinstance(self.b, int)
- return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b) # noqa: E501
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node: return self.a.substitute(var_vals) % self.b
+ if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b): return (0, self.b-1)
+ return (self.a.min%self.b, self.a.max%self.b)
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) % self.b

  class RedNode(Node):
- def __init__(self, nodes:List[Node]): self.nodes = nodes
+ def __init__(self, nodes:List[Node]):
+ self.nodes = nodes
+ self.min, self.max = self.get_bounds()
  def vars(self) -> Set[Variable]: return set.union(*[x.vars() for x in self.nodes], set())
+ def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")

  class SumNode(RedNode):
+ def get_bounds(self) -> Tuple[int, sint]: return sum([x.min for x in self.nodes]), sum([x.max for x in self.nodes])
  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
- def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
+ def __floordiv__(self, b: Union[Node, sint], factoring_allowed=True):
+ if self == b: return NumNode(1)
  fully_divided: List[Node] = []
  rest: List[Node] = []
- if isinstance(b, SumNode):
- nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
- de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
- if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b
  if isinstance(b, Node):
  for x in self.flat_components:
  if x % b == 0: fully_divided.append(x // b)
  else: rest.append(x)
- if (sum_fully_divided:=create_rednode(SumNode, fully_divided)) != 0: return sum_fully_divided + create_rednode(SumNode, rest) // b
+ if (sum_fully_divided:=create_node(SumNode(fully_divided))) != 0: return sum_fully_divided + create_node(SumNode(rest)) // b
  return Node.__floordiv__(self, b, False)
  if b == 1: return self
  if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
- fully_divided, rest = [], []
  _gcd = b
  divisor = 1
  for x in self.flat_components:
  if x.__class__ in (NumNode, MulNode):
- if x.b%b == 0: fully_divided.append(x//b)
+ if x.b % b == 0: fully_divided.append(x // b)
  else:
+ if x.__class__ is NumNode and (div := x.b // b):
+ fully_divided.append(NumNode(div))
+ x = NumNode(x.b - b * div)
  rest.append(x)
  if isinstance(x.b, int):
  _gcd = gcd(_gcd, x.b)
- if x.__class__ == MulNode and divisor == 1 and b%x.b == 0: divisor = x.b
+ if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
  else:
  _gcd = 1
  else:
@@ -251,39 +277,13 @@ class SumNode(RedNode):

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def __mod__(self, b: Union[Node, int]):
- if isinstance(b, SumNode):
- nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
- de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
- if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b
+ if self == b: return NumNode(0)
  if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
- new_nodes: List[Node] = []
- for x in self.nodes:
- if x.__class__ in (NumNode, MulNode): new_nodes.append(x%b) # might simplify
- else: new_nodes.append(x)
- return Node.__mod__(Node.sum(new_nodes), b)
-
- def __lt__(self, b:Union[Node,int]):
- lhs: Node = self
- if isinstance(b, int):
- new_sum = []
- for x in self.nodes:
- # TODO: should we just force the last one to always be the number
- if isinstance(x, NumNode): b -= x.b
- else: new_sum.append(x)
- lhs = Node.sum(new_sum)
- nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
- assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
- muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
- if muls:
- # NOTE: gcd in python 3.8 takes exactly 2 args
- mul_gcd = b
- for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
- all_others = Node.sum(others)
- if all_others.min >= 0 and all_others.max < mul_gcd:
- lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
- return Node.__lt__(lhs, b) if isinstance(lhs, SumNode) else lhs < b
+ new_sum = Node.sum([node%b if node.__class__ in (NumNode, MulNode) else node for node in self.nodes])
+ return Node.__mod__(new_sum, b)

- def substitute(self, var_vals: Dict[Variable, Node]) -> Node: return Node.sum([node.substitute(var_vals) for node in self.nodes])
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
+ return Node.sum([node.substitute(var_vals) for node in self.nodes])

  # recursively expand sumnode components
  # TODO: can remove this if there's no SumNode inside SumNode
@@ -291,36 +291,39 @@ class SumNode(RedNode):
  def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]

  class AndNode(RedNode):
- def substitute(self, var_vals: Dict[Variable, Node]) -> Node:
+ def get_bounds(self) -> Tuple[int, sint]: return min([x.min for x in self.nodes]), max([x.max for x in self.nodes])
+ def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
  subed = []
  for node in self.nodes:
  if not (sub:=node.substitute(var_vals)): return NumNode(0)
  subed.append(sub)
  return Node.ands(subed)

- def create_rednode(typ:Type[RedNode], nodes:List[Node]):
- ret = typ(nodes)
- if typ == SumNode: ret.min, ret.max = (sum([x.min for x in nodes]), sum([x.max for x in nodes]))
- elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
- return create_node(ret)
-
  def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
- def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
+ def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
  if isinstance(a, (int, float)): return a
- ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()})
+ ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()}) if var_vals is not None else a
  assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
  return ret.b

- # symbolic int
- sint = Union[Node, int]
+ # symbolic int, these are allowed in a Tensor shape
+ sint = Union[int, Variable, MulNode, SumNode]
+
+ def render_mulnode(node:MulNode, ops, ctx):
+ # TODO: add ProdNode and remove this case
+ if isinstance(node.a,Variable) and isinstance(node.b,Variable) and node.a.expr and node.b.expr and node.b.expr < node.a.expr:
+ return f"({sym_render(node.b,ops,ctx)}*{node.a.render(ops,ctx)})"
+ return f"({node.a.render(ops,ctx)}*{sym_render(node.b,ops,ctx)})"

- render_python: Dict[Type, Callable] = {
- Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" else f"{self.expr}"), # noqa: E501
+ render_python: Dict[Type, Callable[..., str]] = {
+ Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" \
+ else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" \
+ else f"{self.expr}"),
  NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
- MulNode: lambda self,ops,ctx: f"({sym_render(self.b,ops,ctx)}*{self.a.render(ops,ctx)})" if isinstance(self.a,Variable) and isinstance(self.b,Variable) and self.a.expr and self.b.expr and self.b.expr < self.a.expr else f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})", # noqa: E501
+ MulNode: render_mulnode,
  DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
  ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
  LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
  SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
- AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
+ AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
  }
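A minimal sketch of the reworked symbolic API shown above (assuming a tinygrad 0.9.0 install; the inputs and the expected outputs in the comments are illustrative and derived from the definitions in this diff, not taken from it):

from tinygrad.shape.symbolic import Variable, sym_infer, create_lt_node

i = Variable("i", 0, 9)                   # a bounded symbolic integer
expr = (i * 4 + 2) // 2                   # SumNode.__floordiv__ folds this to i*2 + 1
print(expr.render())                      # ((i*2)+1)
print(sym_infer(expr, {i: 3}))            # 7: substitute i=3, collapse to a NumNode, return its int value
print(create_lt_node(i * 4, 8).render())  # (i<2): create_lt_node replaces the old MulNode/SumNode __lt__ logic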
tinygrad/shape/view.py CHANGED
@@ -1,35 +1,35 @@
  from __future__ import annotations
- import functools, operator
+ import functools, operator, itertools, math
  from dataclasses import dataclass
- from typing import Tuple, List, Optional, Dict, cast
+ from typing import Tuple, List, Optional, Dict, Set, cast
  from tinygrad.helpers import prod, all_int, argsort
- from tinygrad.shape.symbolic import Node, NumNode, Variable, Set, sint
+ from tinygrad.shape.symbolic import Node, NumNode, Variable, sint

  @functools.lru_cache(maxsize=None)
- def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
- return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape))
+ def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
+ return tuple(0 if s == 1 else st for s, st in zip(shape, strides))

  @functools.lru_cache(maxsize=None)
- def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
- strides = [1] if shape else []
- for d in reversed(shape[1:]): strides.append(d*strides[-1])
- return filter_strides(shape, tuple(reversed(strides)))
+ def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
+ if not shape: return ()
+ strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))
+ return canonicalize_strides(shape, strides[::-1])

  @functools.lru_cache(maxsize=None)
- def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]] = None) -> Tuple[Tuple[int, int, int], ...]: # noqa: E501
+ def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
  # merge contiguous subparts or zero strided dims. ret = List[(merged_dims, stride, merged dims w/o zero stride), ...]
  if not shape: return tuple()
  assert len(shape) == len(strides)
  ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)]
- # state (0, 1, 2) -> (none, in-progress, done). wrt merging zero strided dimensions
- state = 1 if mask and strides[0] == 0 and shape[0] != 1 and mask[0][1] - mask[0][0] == 1 else 0
+ # wrt merging zero strided dimensions
+ merging = strides[0] == 0 and (mask[0][1] - mask[0][0] == 1 if mask else shape[0] == 1)
  for i, (sh, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
  if sh == 1: continue
- if state == 1 or ret[-1][1] == sh * st: # mergeable
- ret[-1] = (ret[-1][0] * sh, st, (sh if state == 1 else ret[-1][2] * sh) if st else 0)
+ if merging or ret[-1][1] == sh * st: # mergeable
+ ret[-1] = (ret[-1][0] * sh, st, (sh if merging else ret[-1][2] * sh) if st else 0)
  else: ret.append((sh, st, sh if st else 0)) # begin new
  # merging ends with either non-zero strided dim or zero strided dim with mask range > 1
- state = 1 if (st == 0 and mask and mask[i][1] - mask[i][0] == 1) else (2 if state != 0 else 0)
+ merging = st == 0 and (mask[i][1] - mask[i][0] == 1 if mask else sh == 1)
  return tuple(ret)

  @functools.lru_cache(maxsize=None)
@@ -52,7 +52,8 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
  if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask

  else: # mask can only be splitted if reshape doesn't cut across the mask.
- if ((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride): return view.mask, True
+ if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
+ or old_dim % next_stride != 0): return view.mask, True
  new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
  curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension

@@ -67,6 +68,15 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl

  return tuple(reversed(new_mask)), False

+ def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
+ strides = strides_for_shape(shape)
+ result = []
+ for stride in strides:
+ here = offs // stride if stride else 0
+ result.append(here)
+ offs -= here * stride
+ return result
+
  @dataclass(frozen=True)
  class View:
  shape:Tuple[sint, ...]
@@ -75,11 +85,28 @@ class View:
  mask:Optional[Tuple[Tuple[sint, sint], ...]]
  contiguous:bool

+ @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+ def size(self) -> int:
+ # NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
+ ret = prod([x.max if isinstance(x, Node) else x for x in self.shape])
+ assert isinstance(ret, int), f"{ret=} is not int"
+ return ret
+
  @staticmethod
  @functools.lru_cache(maxsize=None)
  def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
- strides = filter_strides(shape, strides) if strides else strides_for_shape(shape)
+ strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
+ # canonicalize empty mask
+ if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
  contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
+ # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
+ # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
+ # TODO: assert comparison with LtNode to avoid mis-using symbolic
+ if mask and any(elim := [not (b+1 < e) for b,e in mask]):
+ if any(not (b < e) for b,e in mask):
+ strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
+ offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
+ strides = tuple(0 if e else st for st,e in zip(strides, elim))
  return View(shape, strides, offset, mask, contiguous)

  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
@@ -88,13 +115,81 @@ class View:
  return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set())

  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
- def unbind(self) -> View:
- unbound_vars:Dict[Variable,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None}
+ def unbind(self) -> Tuple[View, Dict[Variable, int]]:
+ var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.val is not None]
+ unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
  new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
  new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
  new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
- new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None # noqa: E501
- return View.create(new_shape, new_strides, new_offset, new_mask)
+ new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars),
+ b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
+ return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
+
+ @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+ def __add__(self, vm1:View) -> Optional[View]:
+ vm2 = self
+ if vm2.contiguous: return vm1
+ if vm1.contiguous and vm1.shape == vm2.shape: return vm2
+ if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
+ if vm1.mask:
+ for b,e in vm1.mask:
+ if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
+ return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
+
+ # Project vm1's offset and strides on to vm2.
+ origin = un1d(vm2.shape, vm1.offset)
+ terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
+ strides: List[sint] = [0] * len(vm1.shape)
+ for d1, st in enumerate(vm1.strides):
+ if st == 0: continue
+ for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
+ if (s1 := s1 - o) == 0: continue
+ terms[d2].append((d1, s1))
+ strides[d1] += s1 * vm2.strides[d2]
+
+ # Merge dimensions in vm2 if required.
+ # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
+ idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+ merged_size, merged_term = 1, NumNode(0)
+ extents: List[Tuple[sint, Node]] = []
+ for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
+ merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
+ merged_size *= s
+ if not (merged_term >= merged_size) and not (merged_term < 0):
+ extents.append((merged_size, merged_term))
+ merged_size, merged_term = 1, NumNode(0)
+ if merged_term: return None
+ if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
+ return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1
+
+ if vm2.mask:
+ # Try to project vm2's mask on to vm1.
+ newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
+ for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
+ if not (t.min < b or t.max >= e): continue
+ if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
+ bad = True
+ continue
+ term = terms[d2]
+ if len(term) != 1:
+ if not term and newe: newe[0] = 0
+ else: bad = True
+ continue
+ d1, s1 = term[0]
+ if not isinstance(s1, int) or not isinstance(newe[d1], int):
+ bad = True
+ continue
+ newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
+ newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
+
+ # If any of vm1 was masked off, try again with that mask in place.
+ for b, e, s in zip(newb, newe, vm1.shape):
+ if b != 0 or e != s:
+ return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
+ # Otherwise if vm2's mask was violated, then cannot merge.
+ if bad: return None
+
+ return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
@@ -103,7 +198,10 @@ class View:
  ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
  return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)

- # MovementOps live here now
+ @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+ def minify(self):
+ min_shape = tuple(x[0] for x in _merge_dims(self.shape, self.strides, self.mask))
+ return nv if (nv := self.reshape(min_shape)) else self

  def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
  offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
@@ -117,8 +215,8 @@ class View:
  return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
- def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View:
- assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
+ def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
+ assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}"
  if any(b or e for b, e in arg):
  zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
  mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
@@ -145,7 +243,8 @@ class View:
  def permute(self, axis: Tuple[int, ...]) -> View:
  assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
  assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
- return View.create(tuple([self.shape[a] for a in axis]), tuple([self.strides[a] for a in axis]), self.offset, tuple([self.mask[a] for a in axis]) if self.mask is not None else None) # noqa: E501
+ return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
+ tuple(self.mask[a] for a in axis) if self.mask is not None else None)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def stride(self, mul: Tuple[int, ...]) -> View:
@@ -154,7 +253,8 @@ class View:
  strides = tuple([z*m for z,m in zip(self.strides, mul)])
  new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
  offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
- mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None # noqa: E501
+ mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) \
+ for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
  return View.create(new_shape, strides, self.offset + offset, mask)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
@@ -177,18 +277,20 @@ class View:
  if self.contiguous: return View.create(new_shape)

  strides, r_new_shape = [], reversed(new_shape)
- for merged_dim, s, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
- acc, new_stride = 1, s
+ for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
+ acc = 1
+ # TODO: this <= and != is for symbolic!?
  while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, None)):
- strides.append(new_stride if new_dim != 1 else 0)
- if new_dim == 1: continue
- new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
+ strides.append(new_stride)
+ if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
  if acc != merged_dim: break
  else:
  strides += [0,] * (len(new_shape) - len(strides))
- mask, extra = _reshape_mask(self, new_shape)
- fstrides = filter_strides(tuple(e-b for b,e in mask) if mask else new_shape, tuple(reversed(strides)))
- extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - (sum(m[0] * s for m,s in zip(mask, fstrides)) if mask else 0) # noqa: E501
- if not extra: return View.create(new_shape, fstrides, self.offset + extra_offset, mask)
+ new_mask, extra = _reshape_mask(self, new_shape)
+ if not extra:
+ new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask) if new_mask else new_shape, tuple(reversed(strides)))
+ extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
+ (sum(m[0] * s for m,s in zip(new_mask, new_strides)) if new_mask else 0)
+ return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)

  return None
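As a usage sketch for the View helpers changed above (an assumed example, not part of the diff; the values in the comments follow from the definitions in the hunks):

from tinygrad.shape.view import View, strides_for_shape, un1d

print(strides_for_shape((2, 3, 4)))      # (12, 4, 1): row-major strides built with itertools.accumulate
print(un1d((3, 4), 5))                   # [1, 1]: flat offset 5 maps to index (1, 1) in a 3x4 shape

v = View.create((2, 3, 4))               # contiguous view: offset 0, no mask, default strides
w = v.reshape((6, 4))                    # contiguous reshape short-circuits to View.create(new_shape)
print(w.shape, w.strides, w.contiguous)  # (6, 4) (4, 1) True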