tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/shape/view.py CHANGED
@@ -1,35 +1,37 @@
 from __future__ import annotations
-import functools, operator
+import functools, operator, itertools, math
 from dataclasses import dataclass
-from typing import Tuple, List, Optional, Dict, cast
+from typing import Tuple, List, Optional, Dict, Set, cast
 from tinygrad.helpers import prod, all_int, argsort
-from tinygrad.shape.symbolic import Node, NumNode, Variable, Set, sint
+from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
 
 @functools.lru_cache(maxsize=None)
-def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
-  return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape))
+def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
+  return tuple(0 if s == 1 else st for s, st in zip(shape, strides))
 
 @functools.lru_cache(maxsize=None)
-def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
-  strides = [1] if shape else []
-  for d in reversed(shape[1:]): strides.append(d*strides[-1])
-  return filter_strides(shape, tuple(reversed(strides)))
+def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
+  if not shape: return ()
+  strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1]
+  return canonicalize_strides(shape, strides)
 
 @functools.lru_cache(maxsize=None)
-def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]] = None) -> Tuple[Tuple[int, int, int], ...]: # noqa: E501
-  # merge contiguous subparts or zero strided dims. ret = List[(merged_dims, stride, merged dims w/o zero stride), ...]
-  if not shape: return tuple()
-  assert len(shape) == len(strides)
+def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
+  # merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o zero stride), ...]
+  if not shape: return ()
+  assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask))
   ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)]
-  # state (0, 1, 2) -> (none, in-progress, done). wrt merging zero strided dimensions
-  state = 1 if mask and strides[0] == 0 and shape[0] != 1 and mask[0][1] - mask[0][0] == 1 else 0
-  for i, (sh, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
-    if sh == 1: continue
-    if state == 1 or ret[-1][1] == sh * st: # mergeable
-      ret[-1] = (ret[-1][0] * sh, st, (sh if state == 1 else ret[-1][2] * sh) if st else 0)
-    else: ret.append((sh, st, sh if st else 0)) # begin new
-    # merging ends with either non-zero strided dim or zero strided dim with mask range > 1
-    state = 1 if (st == 0 and mask and mask[i][1] - mask[i][0] == 1) else (2 if state != 0 else 0)
+  # merge this dim to next dim if size is 1
+  merging = (mask[0][1] - mask[0][0] == 1) if mask is not None else shape[0] == 1
+  for i, (s, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
+    last_s, last_st, last_pre_expand_s = ret[-1]
+    # always merge 1
+    if s == 1: continue
+    # merge last dim with this dim if merging or strides matched
+    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st else 0)
+    else: ret.append((s, st, s if st else 0))
+    # merge this dim to next dim if size is 1
+    merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
  return tuple(ret)
 
 @functools.lru_cache(maxsize=None)
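
Worked example: the renamed helper (`filter_strides` → `canonicalize_strides`) still zeroes the stride of size-1 dimensions, and `strides_for_shape` still produces row-major strides, now via itertools.accumulate. A quick sanity check, assuming tinygrad 0.9.1 is installed:

```python
from tinygrad.shape.view import strides_for_shape  # assumes tinygrad==0.9.1

assert strides_for_shape((2, 3, 4)) == (12, 4, 1)  # row-major: last dim is fastest
assert strides_for_shape((2, 1, 4)) == (4, 0, 1)   # size-1 dims canonicalized to stride 0
```
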
@@ -52,7 +54,8 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
         if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
 
       else: # mask can only be splitted if reshape doesn't cut across the mask.
-        if ((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride): return view.mask, True
+        if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
+            or old_dim % next_stride != 0): return view.mask, True
         new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
         curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
 
@@ -67,6 +70,15 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
 
   return tuple(reversed(new_mask)), False
 
+def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
+  strides = strides_for_shape(shape)
+  result = []
+  for stride in strides:
+    here = offs // stride if stride else 0
+    result.append(here)
+    offs -= here * stride
+  return result
+
 @dataclass(frozen=True)
 class View:
   shape:Tuple[sint, ...]
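
The new `un1d` inverts the flat-offset computation `offset == sum(idx[i] * stride[i])`, recovering one index per dimension; `View.__add__` below uses it to project one view's offset and strides onto another. Assuming 0.9.1 is installed:

```python
from tinygrad.shape.view import un1d  # assumes tinygrad==0.9.1

assert un1d((2, 3, 4), 17) == [1, 1, 1]  # 17 == 1*12 + 1*4 + 1*1
assert un1d((2, 3, 4), 5) == [0, 1, 1]   #  5 == 0*12 + 1*4 + 1*1
```
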
@@ -75,10 +87,29 @@ class View:
   mask:Optional[Tuple[Tuple[sint, sint], ...]]
   contiguous:bool
 
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def size(self) -> int:
+    # NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
+    ret = prod([x.max if isinstance(x, Node) else x for x in self.shape])
+    assert isinstance(ret, int), f"{ret=} is not int"
+    return ret
+
   @staticmethod
   @functools.lru_cache(maxsize=None)
   def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
-    strides = filter_strides(shape, strides) if strides else strides_for_shape(shape)
+    strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
+    # canonicalize 0 in shape
+    if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
+    # canonicalize empty mask
+    if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
+    # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
+    # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
+    # TODO: assert comparison with LtNode to avoid mis-using symbolic
+    if mask and any(elim := [not (b+1 < e) for b,e in mask]):
+      if any(not (b < e) for b,e in mask):
+        strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
+      offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
+      strides = tuple(0 if e else st for st,e in zip(strides, elim))
     contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
     return View(shape, strides, offset, mask, contiguous)
 
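`View.create` now canonicalizes more aggressively: zero-size shapes, no-op masks, and any dimension whose mask leaves exactly one valid index, whose stride is zeroed while the offset absorbs the pinned index. A concrete instance of that last rule, assuming 0.9.1:

```python
from tinygrad.shape.view import View  # assumes tinygrad==0.9.1

# shape (4, 3), row-major strides (3, 1); dim 0 is masked down to the single index 2
v = View.create((4, 3), mask=((2, 3), (0, 3)))
assert (v.strides, v.offset) == ((0, 1), 6)  # stride zeroed, offset absorbs 2*3
```
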
@@ -88,13 +119,81 @@ class View:
     return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set())
 
   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
-  def unbind(self) -> View:
-    unbound_vars:Dict[Variable,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None}
+  def unbind(self) -> Tuple[View, Dict[Variable, int]]:
+    var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.val is not None]
+    unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
     new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
     new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
     new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
-    new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars), b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None # noqa: E501
-    return View.create(new_shape, new_strides, new_offset, new_mask)
+    new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars),
+                      b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
+    return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
+
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def __add__(self, vm1:View) -> Optional[View]:
+    vm2 = self
+    if vm2.contiguous: return vm1
+    if vm1.contiguous and vm1.shape == vm2.shape: return vm2
+    if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
+    if vm1.mask:
+      for b,e in vm1.mask:
+        if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
+      return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
+
+    # Project vm1's offset and strides on to vm2.
+    origin = un1d(vm2.shape, vm1.offset)
+    terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
+    strides: List[sint] = [0] * len(vm1.shape)
+    for d1, st in enumerate(vm1.strides):
+      if st == 0: continue
+      for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
+        if (s1 := s1 - o) == 0: continue
+        terms[d2].append((d1, s1))
+        strides[d1] += s1 * vm2.strides[d2]
+
+    # Merge dimensions in vm2 if required.
+    # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
+    idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+    merged_size, merged_term = 1, NumNode(0)
+    extents: List[Tuple[sint, Node]] = []
+    for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
+      merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
+      merged_size *= s
+      if not (merged_term >= merged_size) and not (merged_term < 0):
+        extents.append((merged_size, merged_term))
+        merged_size, merged_term = 1, NumNode(0)
+    if merged_term: return None
+    if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
+      return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1
+
+    if vm2.mask:
+      # Try to project vm2's mask on to vm1.
+      newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
+      for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
+        if not (t.min < b or t.max >= e): continue
+        if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
+          bad = True
+          continue
+        term = terms[d2]
+        if len(term) != 1:
+          if not term and newe: newe[0] = 0
+          else: bad = True
+          continue
+        d1, s1 = term[0]
+        if not isinstance(s1, int) or not isinstance(newe[d1], int):
+          bad = True
+          continue
+        newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
+        newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
+
+      # If any of vm1 was masked off, try again with that mask in place.
+      for b, e, s in zip(newb, newe, vm1.shape):
+        if b != 0 or e != s:
+          return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
+      # Otherwise if vm2's mask was violated, then cannot merge.
+      if bad: return None
+
+    return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
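
`View.__add__` is the new view-merging entry point: `vm2 + vm1` tries to collapse "apply vm2, then vm1 on top" into a single View, and returns None when no single view can express the composition. Two illustrative cases, assuming 0.9.1 is installed (the easy merge, and the classic impossible one):

```python
from tinygrad.shape.view import View  # assumes tinygrad==0.9.1

a = View.create((4, 6))                # contiguous base
p = a.permute((1, 0))                  # shape (6, 4), strides (1, 6)
assert a + p == p                      # a contiguous vm2 merges for free
assert p + View.create((24,)) is None  # flattening a permute: no single view exists
```
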
@@ -103,7 +202,10 @@ class View:
     ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
     return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)
 
-  # MovementOps live here now
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def minify(self):
+    min_shape = tuple(x[0] for x in _merge_dims(self.shape, self.strides, self.mask))
+    return nv if (nv := self.reshape(min_shape)) else self
 
   def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
     offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
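
The new `minify` collapses a view to the fewest dimensions `_merge_dims` allows, giving later reshape and merge attempts less to chew on. For example (0.9.1):

```python
from tinygrad.shape.view import View  # assumes tinygrad==0.9.1

assert View.create((2, 3, 4)).minify().shape == (24,)  # fully contiguous dims collapse

v = View.create((2, 3, 4)).permute((2, 0, 1))  # shape (4, 2, 3), strides (1, 12, 4)
assert v.minify().shape == (4, 6)              # only the contiguous (2, 3) pair merges
```
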
@@ -117,8 +219,8 @@ class View:
     return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View:
-    assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
+  def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
+    assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}"
     if any(b or e for b, e in arg):
       zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
       mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
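
`pad` is built on `__unsafe_resize`: the shape grows, the offset backs up past the start of the data, and a mask records where the real elements live; the hunk above only widens the types to sint and adds a message to the assert. For example (0.9.1):

```python
from tinygrad.shape.view import View  # assumes tinygrad==0.9.1

v = View.create((3,)).pad(((1, 1),))  # one element of padding on each side
assert (v.shape, v.offset, v.mask) == ((5,), -1, ((1, 4),))  # data occupies indices 1-3
```
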
@@ -143,9 +245,9 @@ class View:
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def permute(self, axis: Tuple[int, ...]) -> View:
-    assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
-    assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
-    return View.create(tuple([self.shape[a] for a in axis]), tuple([self.strides[a] for a in axis]), self.offset, tuple([self.mask[a] for a in axis]) if self.mask is not None else None) # noqa: E501
+    assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}"
+    return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
+                       tuple(self.mask[a] for a in axis) if self.mask is not None else None)
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def stride(self, mul: Tuple[int, ...]) -> View:
@@ -154,7 +256,8 @@ class View:
     strides = tuple([z*m for z,m in zip(self.strides, mul)])
     new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
     offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
-    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None # noqa: E501
+    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) \
+                  for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
     return View.create(new_shape, strides, self.offset + offset, mask)
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
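
`stride` (unchanged here except for the reflowed mask expression) subsamples each dimension; a negative multiplier also flips it, moving the offset to the last selected element. For example (0.9.1):

```python
from tinygrad.shape.view import View  # assumes tinygrad==0.9.1

v = View.create((8,)).stride((-2,))
assert (v.shape, v.strides, v.offset) == ((4,), (-2,), 7)  # picks elements 7, 5, 3, 1
```
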
@@ -166,7 +269,7 @@ class View:
       assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
       return View.create(new_shape)
     # check for the same size
-    if all_int(self.shape):
+    if (self_all_int := all_int(self.shape)):
       assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
       if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
         raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
@@ -176,19 +279,33 @@ class View:
     # after the asserts, it's okay to check contiguous
     if self.contiguous: return View.create(new_shape)
 
+    # if it's not contiguous and new shape is symbolic, check if it's directly replaceable
+    if self_all_int and not all_int(new_shape):
+      if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
+      for si, so in zip(self.shape, new_shape):
+        if isinstance(so, int):
+          if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
+        else:
+          var_vals = {v: v.unbind()[1] for v in so.vars()}
+          if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
+      # all dimensions matched, return the new view directly
+      return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
+
     strides, r_new_shape = [], reversed(new_shape)
-    for merged_dim, s, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
-      acc, new_stride = 1, s
+    for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
+      acc = 1
+      # TODO: this <= and != is for symbolic!?
       while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, None)):
-        strides.append(new_stride if new_dim != 1 else 0)
-        if new_dim == 1: continue
-        new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
+        strides.append(new_stride)
+        if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
       if acc != merged_dim: break
     else:
       strides += [0,] * (len(new_shape) - len(strides))
-      mask, extra = _reshape_mask(self, new_shape)
-      fstrides = filter_strides(tuple(e-b for b,e in mask) if mask else new_shape, tuple(reversed(strides)))
-      extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - (sum(m[0] * s for m,s in zip(mask, fstrides)) if mask else 0) # noqa: E501
-      if not extra: return View.create(new_shape, fstrides, self.offset + extra_offset, mask)
+      new_mask, extra = _reshape_mask(self, new_shape)
+      if not extra:
+        new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask) if new_mask else new_shape, tuple(reversed(strides)))
+        extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
+                       (sum(m[0] * s for m,s in zip(new_mask, new_strides)) if new_mask else 0)
+        return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
 
     return None
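
End to end, the reshape changes are easiest to see on masked views: a mask that splits cleanly across the new dimensions is recomputed by `_reshape_mask`, while one that would be cut mid-dimension (the condition tightened in the second hunk above) makes `reshape` return None. Assuming 0.9.1 is installed:

```python
from tinygrad.shape.view import View  # assumes tinygrad==0.9.1

v = View.create((4,), mask=((0, 2),))  # only flat indices 0-1 are valid
r = v.reshape((2, 2))
assert r is not None and r.mask == ((0, 1), (0, 2))               # mask splits cleanly
assert View.create((4,), mask=((1, 4),)).reshape((2, 2)) is None  # mask would be cut
```
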