tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/shape/view.py ADDED
@@ -0,0 +1,296 @@
1
+ from __future__ import annotations
2
+ import functools, operator, itertools, math
3
+ from dataclasses import dataclass
4
+ from typing import Tuple, List, Optional, Dict, Set, cast
5
+ from tinygrad.helpers import prod, all_int, argsort
6
+ from tinygrad.shape.symbolic import Node, NumNode, Variable, sint
7
+
8
@functools.lru_cache(maxsize=None)
def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
  """Return `strides` with the stride of every size-1 dimension replaced by 0.

  `shape` and `strides` are zipped pairwise, so the result has the length of the shorter of the two.
  Both arguments must be hashable because the result is memoized with `lru_cache`.
  """
  return tuple(0 if s == 1 else st for s, st in zip(shape, strides))
11
+
12
@functools.lru_cache(maxsize=None)
def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
  """Return the canonical row-major strides for `shape`.

  The last dimension gets stride 1 and each earlier dimension gets the product of all later sizes;
  the result is then passed through `canonicalize_strides`. An empty shape yields an empty tuple.
  """
  if not shape: return ()
  # running product over the trailing dims, built innermost-first and flipped below
  strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))
  return canonicalize_strides(shape, strides[::-1])
17
+
18
@functools.lru_cache(maxsize=None)
def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
  """Merge adjacent dimensions that are contiguous with each other or zero strided.

  Returns a tuple of `(merged_dims, stride, merged dims w/o zero stride)` triples, outermost first.
  Size-1 dimensions after the first are skipped. An empty `shape` yields an empty tuple.
  `shape` and `strides` must have equal length (asserted).
  """
  # merge contiguous subparts or zero strided dims. ret = List[(merged_dims, stride, merged dims w/o zero stride), ...]
  if not shape: return tuple()
  assert len(shape) == len(strides)
  ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)]
  # wrt merging zero strided dimensions
  merging = strides[0] == 0 and (mask[0][1] - mask[0][0] == 1 if mask else shape[0] == 1)
  for i, (sh, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
    if sh == 1: continue
    if merging or ret[-1][1] == sh * st: # mergeable
      ret[-1] = (ret[-1][0] * sh, st, (sh if merging else ret[-1][2] * sh) if st else 0)
    else: ret.append((sh, st, sh if st else 0)) # begin new
    # merging ends with either non-zero strided dim or zero strided dim with mask range > 1
    merging = st == 0 and (mask[i][1] - mask[i][0] == 1 if mask else sh == 1)
  return tuple(ret)
34
+
35
@functools.lru_cache(maxsize=None)
def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tuple[Tuple[sint, sint], ...]], bool]:
  """Re-express `view.mask` in terms of `new_shape`.

  Returns `(mask, extra)`:
    * `(view.mask, False)` when the view has no mask,
    * `(view.mask, True)` when any mask bound is not an int, or the mask cannot be split/combined to fit `new_shape`,
    * `(((0, 0),) * len(new_shape), False)` when an invalid (empty) mask is encountered,
    * `(new_mask, False)` otherwise, with one `(begin, end)` pair per dimension of `new_shape`.

  Only `view.mask` and `view.shape` are read. Dimensions are walked from innermost to outermost.
  """
  if view.mask is None: return view.mask, False
  if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in view.mask): return view.mask, True
  new_mask: List[Tuple[int, int]] = []

  r_masks, r_shape, r_new_shape = reversed(view.mask), reversed(view.shape), reversed(new_shape)
  curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
  if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask

  while len(new_mask) < len(new_shape):
    (l, r), next_stride = mask, new_dim * curr_stride

    if old_dim >= next_stride: # need to split mask.
      if old_dim == next_stride: # simply copy the mask and get next batch for merging
        new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
        curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
        if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask

      else: # mask can only be splitted if reshape doesn't cut across the mask.
        if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
            or old_dim % next_stride != 0): return view.mask, True
        new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
        curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension

    else:
      next_mask = next(r_masks, (0, 1))
      # combine if the mask can unfold continuously
      if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return view.mask, True
      mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)

  for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1)
    if mask != (0, 1): return ((0, 0),) * len(new_shape), False # invalid mask

  return tuple(reversed(new_mask)), False
70
+
71
def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
  """Split the flat offset `offs` into one index per dimension of `shape`.

  Uses the strides from `strides_for_shape(shape)`, outermost dimension first. A dimension whose
  stride is 0 always gets index 0. Returns a list with one entry per stride.
  """
  strides = strides_for_shape(shape)
  result = []
  for stride in strides:
    here = offs // stride if stride else 0
    result.append(here)
    offs -= here * stride  # carry the remainder on to the inner dimensions
  return result
79
+
80
@dataclass(frozen=True)
class View:
  """Immutable description of how a shaped index space maps onto a flat buffer.

  Instances are normally built through `View.create`, which canonicalizes strides and mask and
  computes `contiguous`. Most methods are memoized with `lru_cache`, so arguments must be hashable.
  """
  shape:Tuple[sint, ...]    # size of each dimension
  strides:Tuple[sint, ...]  # per-dimension stride
  offset:sint               # constant added to the strided index
  mask:Optional[Tuple[Tuple[sint, sint], ...]]  # per-dimension (begin, end) bounds, or None
  contiguous:bool           # set by `create`: offset == 0, no mask, and strides == strides_for_shape(shape)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def size(self) -> int:
    """Return the product of the dimensions, using `.max` for any symbolic (`Node`) dimension."""
    # NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
    ret = prod([x.max if isinstance(x, Node) else x for x in self.shape])
    assert isinstance(ret, int), f"{ret=} is not int"
    return ret

  @staticmethod
  @functools.lru_cache(maxsize=None)
  def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
    """Build a canonical `View`.

    Strides default to `strides_for_shape(shape)` when falsy, otherwise they are canonicalized.
    A mask covering the full shape becomes None. A mask with an empty dimension collapses the view
    to all-zero strides, zero offset and an all-`(0, 0)` mask.
    """
    strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
    # canonicalize empty mask
    if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
    contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
    # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
    # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
    # TODO: assert comparison with LtNode to avoid mis-using symbolic
    if mask and any(elim := [not (b+1 < e) for b,e in mask]):
      if any(not (b < e) for b,e in mask):
        strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
      offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
      strides = tuple(0 if e else st for st,e in zip(strides, elim))
    return View(shape, strides, offset, mask, contiguous)

  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
  def vars(self) -> Set[Variable]:
    """Return the union of `.vars()` over every `Node` found in shape, strides, offset and mask."""
    flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
    return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set())

  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
  def unbind(self) -> Tuple[View, Dict[Variable, int]]:
    """Return a view with every bound variable (`v.val is not None`) substituted by its unbound form, plus a dict built from the `v.unbind()` results."""
    var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.val is not None]
    unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
    new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
    new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
    new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
    new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars),
                      b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
    return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def __add__(self, vm1:View) -> Optional[View]:
    """Try to compose `vm1` on top of this view (`vm2`) into a single view; return None when that is not possible."""
    vm2 = self
    if vm2.contiguous: return vm1
    if vm1.contiguous and vm1.shape == vm2.shape: return vm2
    if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
    if vm1.mask:
      for b,e in vm1.mask:
        if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
      # merge against the unmasked region of vm1, then pad the result back out to vm1's shape
      return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))

    # Project vm1's offset and strides on to vm2.
    origin = un1d(vm2.shape, vm1.offset)
    terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
    strides: List[sint] = [0] * len(vm1.shape)
    for d1, st in enumerate(vm1.strides):
      if st == 0: continue
      for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
        if (s1 := s1 - o) == 0: continue
        terms[d2].append((d1, s1))
        strides[d1] += s1 * vm2.strides[d2]

    # Merge dimensions in vm2 if required.
    # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
    idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
    merged_size, merged_term = 1, NumNode(0)
    extents: List[Tuple[sint, Node]] = []
    for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
      merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
      merged_size *= s
      if not (merged_term >= merged_size) and not (merged_term < 0):
        extents.append((merged_size, merged_term))
        merged_size, merged_term = 1, NumNode(0)
    if merged_term: return None
    if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
      return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1

    if vm2.mask:
      # Try to project vm2's mask on to vm1.
      newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
      for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
        if not (t.min < b or t.max >= e): continue
        if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
          bad = True
          continue
        term = terms[d2]
        if len(term) != 1:
          if not term and newe: newe[0] = 0
          else: bad = True
          continue
        d1, s1 = term[0]
        if not isinstance(s1, int) or not isinstance(newe[d1], int):
          bad = True
          continue
        newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
        newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)

      # If any of vm1 was masked off, try again with that mask in place.
      for b, e, s in zip(newb, newe, vm1.shape):
        if b != 0 or e != s:
          return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
      # Otherwise if vm2's mask was violated, then cannot merge.
      if bad: return None

    return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
    """Build the inverse of this view, or return None when its element count differs from `out_shape`'s."""
    ret = View.create(self.shape)
    if self.mask: ret = ret.shrink(self.mask)
    ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
    return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def minify(self):
    """Reshape to the merged dimensions from `_merge_dims`; return self unchanged if that reshape fails."""
    min_shape = tuple(x[0] for x in _merge_dims(self.shape, self.strides, self.mask))
    return nv if (nv := self.reshape(min_shape)) else self

  def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
    """Resize each dimension to the `(begin, end)` window in `arg` without validating it; shared by `pad` and `shrink`."""
    offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
    if self.mask:
      # move the old mask
      nmask = tuple([(max(0, min(mx-ax,ay-ax)), max(0, min(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
      # merge the masks if we have two
      mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
    shape = [y-x for x,y in arg]
    if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
    return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
    """Pad each dimension by `(before, after)`; the padded region is masked off. Returns self when all padding is zero."""
    assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}"
    if any(b or e for b, e in arg):
      zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
      mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
      return self.__unsafe_resize(zvarg, mask=mask)
    return self

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
    """Restrict each dimension to the `(begin, end)` window in `arg`, which must satisfy `0 <= begin <= end <= size`."""
    assert all((0<=b<=e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
    return self.__unsafe_resize(arg)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def expand(self, new_shape: Tuple[sint, ...]) -> View:
    """Expand to `new_shape`; a dimension may change only if it has size 1 and stride 0.

    Raises ValueError when `new_shape` has a different number of dimensions.
    """
    if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
    if 0 in self.shape:
      assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
      return View.create(new_shape)
    assert all((s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
    # NOTE: can the mask ever be (0,0)?
    mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
    return View.create(new_shape, self.strides, self.offset, mask)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def permute(self, axis: Tuple[int, ...]) -> View:
    """Reorder shape, strides and mask by `axis`, which must be a permutation of the dimension indices."""
    assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
    assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
    return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
                       tuple(self.mask[a] for a in axis) if self.mask is not None else None)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def stride(self, mul: Tuple[int, ...]) -> View:
    """Multiply each stride by the non-zero int in `mul`; a negative entry also reverses that dimension."""
    # except for the negative case, you can build this from the others. invertible in the negative case
    assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
    strides = tuple([z*m for z,m in zip(self.strides, mul)])
    new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
    offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) \
                  for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
    return View.create(new_shape, strides, self.offset + offset, mask)

  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
  def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
    """Reshape to `new_shape`, or return None when this view cannot express that shape.

    Returns self when the shape is unchanged. Raises ValueError on an element-count mismatch
    (checked only when the current shape is all ints).
    """
    if self.shape == new_shape: return self

    assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}"
    if 0 in self.shape:
      assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
      return View.create(new_shape)
    # check for the same size
    if all_int(self.shape):
      assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
      if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
        raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")

    if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None

    # after the asserts, it's okay to check contiguous
    if self.contiguous: return View.create(new_shape)

    strides, r_new_shape = [], reversed(new_shape)
    for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
      acc = 1
      # TODO: this <= and != is for symbolic!?
      while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, None)):
        strides.append(new_stride)
        if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
      if acc != merged_dim: break
    else:
      # runs only when every merged dim was matched exactly (the loop did not break)
      strides += [0,] * (len(new_shape) - len(strides))
      new_mask, extra = _reshape_mask(self, new_shape)
      if not extra:
        new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask) if new_mask else new_shape, tuple(reversed(strides)))
        extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
                       (sum(m[0] * s for m,s in zip(new_mask, new_strides)) if new_mask else 0)
        return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)

    return None