tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff shows the changes between two published versions of the package as they appear in their public registry. It is provided for informational purposes only.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +253 -225
- tinygrad/codegen/linearizer.py +398 -436
- tinygrad/codegen/uops.py +451 -0
- tinygrad/device.py +268 -274
- tinygrad/dtype.py +56 -40
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +198 -0
- tinygrad/engine/realize.py +192 -0
- tinygrad/engine/schedule.py +370 -0
- tinygrad/engine/search.py +199 -0
- tinygrad/{mlops.py → function.py} +40 -32
- tinygrad/helpers.py +144 -46
- tinygrad/lazy.py +143 -242
- tinygrad/multi.py +173 -0
- tinygrad/nn/__init__.py +180 -9
- tinygrad/nn/datasets.py +8 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +87 -19
- tinygrad/ops.py +104 -45
- tinygrad/renderer/__init__.py +65 -0
- tinygrad/renderer/assembly.py +269 -0
- tinygrad/renderer/cstyle.py +308 -210
- tinygrad/renderer/llvmir.py +119 -124
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +13403 -0
- tinygrad/runtime/autogen/comgr.py +891 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5893 -0
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33597 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +56 -0
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +39 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +187 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +550 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +129 -37
- tinygrad/runtime/ops_disk.py +111 -43
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +41 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +625 -0
- tinygrad/runtime/ops_python.py +208 -0
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +46 -107
- tinygrad/shape/symbolic.py +99 -98
- tinygrad/shape/view.py +162 -45
- tinygrad/tensor.py +2492 -483
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/shape/view.py
CHANGED
@@ -1,35 +1,37 @@
 from __future__ import annotations
-import functools, operator
+import functools, operator, itertools, math
 from dataclasses import dataclass
-from typing import Tuple, List, Optional, Dict, cast
+from typing import Tuple, List, Optional, Dict, Set, cast
 from tinygrad.helpers import prod, all_int, argsort
-from tinygrad.shape.symbolic import Node, NumNode, Variable,
+from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer

 @functools.lru_cache(maxsize=None)
-def
-  return tuple(
+def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
+  return tuple(0 if s == 1 else st for s, st in zip(shape, strides))

 @functools.lru_cache(maxsize=None)
-def strides_for_shape(shape:Tuple[
-
-
-  return
+def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
+  if not shape: return ()
+  strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1]
+  return canonicalize_strides(shape, strides)

 @functools.lru_cache(maxsize=None)
-def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]
-  # merge contiguous
-  if not shape: return
-  assert len(shape) == len(strides)
+def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
+  # merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o zero stride), ...]
+  if not shape: return ()
+  assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask))
   ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)]
-  #
-
-  for i, (
-
-
-
-
-
-
+  # merge this dim to next dim if size is 1
+  merging = (mask[0][1] - mask[0][0] == 1) if mask is not None else shape[0] == 1
+  for i, (s, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
+    last_s, last_st, last_pre_expand_s = ret[-1]
+    # always merge 1
+    if s == 1: continue
+    # merge last dim with this dim if merging or strides matched
+    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st else 0)
+    else: ret.append((s, st, s if st else 0))
+    # merge this dim to next dim if size is 1
+    merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
   return tuple(ret)

 @functools.lru_cache(maxsize=None)
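The new stride helpers above are plain module-level functions, so their behavior is easy to check by hand. A minimal sketch follows; the expected values are worked out from the definitions in this diff rather than taken from tinygrad's tests, and the import path assumes the 0.9.1 layout shown here.

```python
# Sketch of the new stride helpers, assuming the 0.9.1 definitions above.
from tinygrad.shape.view import strides_for_shape, canonicalize_strides, _merge_dims

# row-major strides, with size-1 dims canonicalized to stride 0
print(strides_for_shape((2, 1, 4)))                 # (4, 0, 1)
print(canonicalize_strides((2, 1, 4), (4, 4, 1)))   # (4, 0, 1)

# contiguous dims collapse into one (merged_size, stride, merged size w/o zero stride) triple
print(_merge_dims((2, 3, 4), (12, 4, 1)))           # ((24, 1, 24),)
```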
@@ -52,7 +54,8 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
   if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask

       else: # mask can only be splitted if reshape doesn't cut across the mask.
-        if ((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
+        if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
+            or old_dim % next_stride != 0): return view.mask, True
         new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
         curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension

@@ -67,6 +70,15 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl

   return tuple(reversed(new_mask)), False

+def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
+  strides = strides_for_shape(shape)
+  result = []
+  for stride in strides:
+    here = offs // stride if stride else 0
+    result.append(here)
+    offs -= here * stride
+  return result
+
 @dataclass(frozen=True)
 class View:
   shape:Tuple[sint, ...]
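The new `un1d` helper is the inverse of flattening an index: it splits a flat offset into per-dimension indices for a given shape, much like numpy's unravel_index. A small sketch, with the value worked out from the definition above:

```python
# Sketch: un1d splits a flat offset into per-dimension indices (like numpy.unravel_index).
from tinygrad.shape.view import un1d

print(un1d((3, 4, 5), 37))   # [1, 3, 2], since 37 == 1*20 + 3*5 + 2*1
```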
@@ -75,10 +87,29 @@ class View:
   mask:Optional[Tuple[Tuple[sint, sint], ...]]
   contiguous:bool

+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def size(self) -> int:
+    # NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
+    ret = prod([x.max if isinstance(x, Node) else x for x in self.shape])
+    assert isinstance(ret, int), f"{ret=} is not int"
+    return ret
+
   @staticmethod
   @functools.lru_cache(maxsize=None)
   def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
-    strides =
+    strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
+    # canonicalize 0 in shape
+    if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
+    # canonicalize empty mask
+    if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
+    # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
+    # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
+    # TODO: assert comparison with LtNode to avoid mis-using symbolic
+    if mask and any(elim := [not (b+1 < e) for b,e in mask]):
+      if any(not (b < e) for b,e in mask):
+        strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
+      offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
+      strides = tuple(0 if e else st for st,e in zip(strides, elim))
     contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
     return View(shape, strides, offset, mask, contiguous)

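`View.create` now canonicalizes its arguments up front: size-1 dimensions get stride 0, a mask that covers the whole shape is dropped, and fully masked-off dimensions get stride 0 with the offset adjusted. A minimal sketch of the observable effect, with expected values derived from the code above rather than from tinygrad documentation:

```python
# Sketch of View.create canonicalization and the new size(), assuming the 0.9.1 code above.
from tinygrad.shape.view import View

v = View.create((2, 1, 4))
print(v.strides, v.contiguous, v.size())                 # (4, 0, 1) True 8

# a mask spanning every dimension is redundant and gets canonicalized away
print(View.create((2, 4), mask=((0, 2), (0, 4))).mask)   # None
```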
@@ -88,13 +119,81 @@ class View:
     return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set())

   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
-  def unbind(self) -> View:
-
+  def unbind(self) -> Tuple[View, Dict[Variable, int]]:
+    var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.val is not None]
+    unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
     new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
     new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
     new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
-    new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars),
-
+    new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars),
+                      b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
+    return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
+
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def __add__(self, vm1:View) -> Optional[View]:
+    vm2 = self
+    if vm2.contiguous: return vm1
+    if vm1.contiguous and vm1.shape == vm2.shape: return vm2
+    if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
+    if vm1.mask:
+      for b,e in vm1.mask:
+        if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
+      return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
+
+    # Project vm1's offset and strides on to vm2.
+    origin = un1d(vm2.shape, vm1.offset)
+    terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
+    strides: List[sint] = [0] * len(vm1.shape)
+    for d1, st in enumerate(vm1.strides):
+      if st == 0: continue
+      for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
+        if (s1 := s1 - o) == 0: continue
+        terms[d2].append((d1, s1))
+        strides[d1] += s1 * vm2.strides[d2]
+
+    # Merge dimensions in vm2 if required.
+    # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
+    idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+    merged_size, merged_term = 1, NumNode(0)
+    extents: List[Tuple[sint, Node]] = []
+    for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
+      merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
+      merged_size *= s
+      if not (merged_term >= merged_size) and not (merged_term < 0):
+        extents.append((merged_size, merged_term))
+        merged_size, merged_term = 1, NumNode(0)
+    if merged_term: return None
+    if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
+      return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1
+
+    if vm2.mask:
+      # Try to project vm2's mask on to vm1.
+      newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
+      for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
+        if not (t.min < b or t.max >= e): continue
+        if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
+          bad = True
+          continue
+        term = terms[d2]
+        if len(term) != 1:
+          if not term and newe: newe[0] = 0
+          else: bad = True
+          continue
+        d1, s1 = term[0]
+        if not isinstance(s1, int) or not isinstance(newe[d1], int):
+          bad = True
+          continue
+        newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
+        newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
+
+      # If any of vm1 was masked off, try again with that mask in place.
+      for b, e, s in zip(newb, newe, vm1.shape):
+        if b != 0 or e != s:
+          return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
+      # Otherwise if vm2's mask was violated, then cannot merge.
+      if bad: return None
+
+    return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
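The new `__add__` composes two views: `vm2 + vm1` tries to express "apply vm1 on top of vm2" as a single View and returns None when the pair cannot be merged. A minimal sketch of the two simplest outcomes, traced by hand from the code above (not from tinygrad's test suite):

```python
# Sketch: view composition via __add__ (vm2 + vm1 applies vm1 on top of vm2), assuming the code above.
from tinygrad.shape.view import View

base = View.create((4, 6))                 # contiguous, so base + anything returns the other view
top = View.create((4, 6)).permute((1, 0))  # shape (6, 4), strides (1, 6)
print((base + top) == top)                 # True: the contiguous fast path hands back vm1 unchanged

# a composition that cannot be expressed as a single strided view yields None instead
print(View.create((2, 3)).permute((1, 0)) + View.create((6,)))   # None
```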
@@ -103,7 +202,10 @@ class View:
     ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
     return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)

-  #
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def minify(self):
+    min_shape = tuple(x[0] for x in _merge_dims(self.shape, self.strides, self.mask))
+    return nv if (nv := self.reshape(min_shape)) else self

   def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
     offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
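The new `minify` reshapes a view down to the smallest shape that merged dimensions allow, which keeps later indexing math small. A quick sketch; the values follow from `_merge_dims` above:

```python
# Sketch: minify collapses mergeable dimensions into the smallest equivalent shape.
from tinygrad.shape.view import View

print(View.create((2, 3, 4)).minify().shape)               # (24,)  contiguous dims merge completely
print(View.create((4, 6)).permute((1, 0)).minify().shape)  # (6, 4) permuted strides block the merge
```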
@@ -117,8 +219,8 @@ class View:
     return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
-  def pad(self, arg: Tuple[Tuple[
-    assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
+  def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
+    assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}"
     if any(b or e for b, e in arg):
       zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
       mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
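`pad` grows the view and records the valid region in the mask; the new assert message just makes shape/argument mismatches easier to debug. A minimal sketch of the behavior, with values derived from `pad` and `__unsafe_resize` above:

```python
# Sketch: padding a 2x3 view by one row before, one after, and two columns on the right.
from tinygrad.shape.view import View

p = View.create((2, 3)).pad(((1, 1), (0, 2)))
print(p.shape)   # (4, 5)
print(p.mask)    # ((1, 3), (0, 3))  only the original 2x3 region is valid
print(p.offset)  # -3  the padded row before the data shifts the origin back
```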
@@ -143,9 +245,9 @@ class View:

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def permute(self, axis: Tuple[int, ...]) -> View:
-    assert
-
-
+    assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}"
+    return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
+                       tuple(self.mask[a] for a in axis) if self.mask is not None else None)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def stride(self, mul: Tuple[int, ...]) -> View:
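`permute` now validates the axis tuple and reorders shape, strides, and mask together; no data movement is implied. A minimal sketch:

```python
# Sketch: permute reorders shape and strides in lockstep (and the mask, if present).
from tinygrad.shape.view import View

v = View.create((4, 6)).permute((1, 0))
print(v.shape, v.strides, v.contiguous)   # (6, 4) (1, 6) False
```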
@@ -154,7 +256,8 @@ class View:
     strides = tuple([z*m for z,m in zip(self.strides, mul)])
     new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
     offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
-    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m))
+    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) \
+                  for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
     return View.create(new_shape, strides, self.offset + offset, mask)

   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
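`stride` subsamples or flips each axis: strides are multiplied by the step, a negative step folds the end of the axis into `offset`, and the mask (when present) is rescaled by the same rule, now split over two lines. A minimal sketch with values worked out from the code above:

```python
# Sketch: striding a contiguous 4x6 view.
from tinygrad.shape.view import View

v = View.create((4, 6))
print(v.stride((2, 1)).shape, v.stride((2, 1)).strides)     # (2, 6) (12, 1)
print(v.stride((-1, 1)).strides, v.stride((-1, 1)).offset)  # (-6, 1) 18
```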
@@ -166,7 +269,7 @@ class View:
       assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
       return View.create(new_shape)
     # check for the same size
-    if all_int(self.shape):
+    if (self_all_int := all_int(self.shape)):
       assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
       if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
         raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
@@ -176,19 +279,33 @@ class View:
     # after the asserts, it's okay to check contiguous
     if self.contiguous: return View.create(new_shape)

+    # if it's not contiguous and new shape is symbolic, check if it's directly replaceable
+    if self_all_int and not all_int(new_shape):
+      if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
+      for si, so in zip(self.shape, new_shape):
+        if isinstance(so, int):
+          if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
+        else:
+          var_vals = {v: v.unbind()[1] for v in so.vars()}
+          if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
+      # all dimensions matched, return the new view directly
+      return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
+
     strides, r_new_shape = [], reversed(new_shape)
-    for merged_dim,
-      acc
+    for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
+      acc = 1
+      # TODO: this <= and != is for symbolic!?
       while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, None)):
-        strides.append(new_stride
-        if new_dim
-        new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
+        strides.append(new_stride)
+        if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
       if acc != merged_dim: break
     else:
       strides += [0,] * (len(new_shape) - len(strides))
-
-
-
-
+      new_mask, extra = _reshape_mask(self, new_shape)
+      if not extra:
+        new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask) if new_mask else new_shape, tuple(reversed(strides)))
+        extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
+                       (sum(m[0] * s for m,s in zip(new_mask, new_strides)) if new_mask else 0)
+        return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)

     return None
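`reshape` now also handles the non-contiguous symbolic case (swapping int dims for matching bound Variables in place) and routes the general integer case through `_merge_dims` and `_reshape_mask`. A minimal sketch of the two integer-shape outcomes, with expected results derived from the code above:

```python
# Sketch: reshape merges dims when strides allow it and returns None when they don't.
from tinygrad.shape.view import View

print(View.create((2, 3, 4)).reshape((6, 4)))             # contiguous: View(shape=(6, 4), strides=(4, 1), ...)
print(View.create((2, 3)).permute((1, 0)).reshape((6,)))  # permuted strides can't merge: None
```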