tinygrad 0.7.0-py3-none-any.whl → 0.9.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
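The listing above moves several top-level modules: tinygrad/mlops.py becomes tinygrad/function.py, and the JIT, graph, schedule, and search machinery moves under tinygrad/engine/. A minimal import-migration sketch follows; it assumes the moved modules keep their 0.7.0 public names (TinyJit's new location in particular is an assumption based on this listing, not something the diff itself confirms).

# Hypothetical migration of import paths implied by the renames above (verify against 0.9.0).
# 0.7.0:
#   from tinygrad.jit import TinyJit   # tinygrad/jit.py (removed)
#   import tinygrad.mlops as mlops     # tinygrad/mlops.py (renamed)
# 0.9.0:
from tinygrad.engine.jit import TinyJit  # assumed new home: tinygrad/engine/jit.py
import tinygrad.function as function     # mlops.py renamed to function.py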
tinygrad/shape/view.py
ADDED
@@ -0,0 +1,296 @@
+from __future__ import annotations
+import functools, operator, itertools, math
+from dataclasses import dataclass
+from typing import Tuple, List, Optional, Dict, Set, cast
+from tinygrad.helpers import prod, all_int, argsort
+from tinygrad.shape.symbolic import Node, NumNode, Variable, sint
+
+@functools.lru_cache(maxsize=None)
+def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
+  return tuple(0 if s == 1 else st for s, st in zip(shape, strides))
+
+@functools.lru_cache(maxsize=None)
+def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
+  if not shape: return ()
+  strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))
+  return canonicalize_strides(shape, strides[::-1])
+
+@functools.lru_cache(maxsize=None)
+def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
+  # merge contiguous subparts or zero strided dims. ret = List[(merged_dims, stride, merged dims w/o zero stride), ...]
+  if not shape: return tuple()
+  assert len(shape) == len(strides)
+  ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)]
+  # wrt merging zero strided dimensions
+  merging = strides[0] == 0 and (mask[0][1] - mask[0][0] == 1 if mask else shape[0] == 1)
+  for i, (sh, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
+    if sh == 1: continue
+    if merging or ret[-1][1] == sh * st: # mergeable
+      ret[-1] = (ret[-1][0] * sh, st, (sh if merging else ret[-1][2] * sh) if st else 0)
+    else: ret.append((sh, st, sh if st else 0)) # begin new
+    # merging ends with either non-zero strided dim or zero strided dim with mask range > 1
+    merging = st == 0 and (mask[i][1] - mask[i][0] == 1 if mask else sh == 1)
+  return tuple(ret)
+
+@functools.lru_cache(maxsize=None)
+def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tuple[Tuple[sint, sint], ...]], bool]:
+  if view.mask is None: return view.mask, False
+  if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in view.mask): return view.mask, True
+  new_mask: List[Tuple[int, int]] = []
+
+  r_masks, r_shape, r_new_shape = reversed(view.mask), reversed(view.shape), reversed(new_shape)
+  curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
+  if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
+
+  while len(new_mask) < len(new_shape):
+    (l, r), next_stride = mask, new_dim * curr_stride
+
+    if old_dim >= next_stride: # need to split mask.
+      if old_dim == next_stride: # simply copy the mask and get next batch for merging
+        new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
+        curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
+        if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
+
+      else: # mask can only be splitted if reshape doesn't cut across the mask.
+        if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
+            or old_dim % next_stride != 0): return view.mask, True
+        new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
+        curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
+
+    else:
+      next_mask = next(r_masks, (0, 1))
+      # combine if the mask can unfold continuously
+      if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return view.mask, True
+      mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
+
+  for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1)
+    if mask != (0, 1): return ((0, 0),) * len(new_shape), False # invalid mask
+
+  return tuple(reversed(new_mask)), False
+
+def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
+  strides = strides_for_shape(shape)
+  result = []
+  for stride in strides:
+    here = offs // stride if stride else 0
+    result.append(here)
+    offs -= here * stride
+  return result
+
+@dataclass(frozen=True)
+class View:
+  shape:Tuple[sint, ...]
+  strides:Tuple[sint, ...]
+  offset:sint
+  mask:Optional[Tuple[Tuple[sint, sint], ...]]
+  contiguous:bool
+
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def size(self) -> int:
+    # NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
+    ret = prod([x.max if isinstance(x, Node) else x for x in self.shape])
+    assert isinstance(ret, int), f"{ret=} is not int"
+    return ret
+
+  @staticmethod
+  @functools.lru_cache(maxsize=None)
+  def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
+    strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
+    # canonicalize empty mask
+    if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
+    contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
+    # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
+    # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
+    # TODO: assert comparison with LtNode to avoid mis-using symbolic
+    if mask and any(elim := [not (b+1 < e) for b,e in mask]):
+      if any(not (b < e) for b,e in mask):
+        strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
+      offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
+      strides = tuple(0 if e else st for st,e in zip(strides, elim))
+    return View(shape, strides, offset, mask, contiguous)
+
+  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  def vars(self) -> Set[Variable]:
+    flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
+    return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set())
+
+  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  def unbind(self) -> Tuple[View, Dict[Variable, int]]:
+    var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.val is not None]
+    unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
+    new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
+    new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
+    new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
+    new_mask = tuple((a if isinstance(a, int) else a.substitute(unbound_vars),
+                      b if isinstance(b, int) else b.substitute(unbound_vars)) for (a, b) in self.mask) if self.mask is not None else None
+    return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
+
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def __add__(self, vm1:View) -> Optional[View]:
+    vm2 = self
+    if vm2.contiguous: return vm1
+    if vm1.contiguous and vm1.shape == vm2.shape: return vm2
+    if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
+    if vm1.mask:
+      for b,e in vm1.mask:
+        if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
+      return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
+
+    # Project vm1's offset and strides on to vm2.
+    origin = un1d(vm2.shape, vm1.offset)
+    terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
+    strides: List[sint] = [0] * len(vm1.shape)
+    for d1, st in enumerate(vm1.strides):
+      if st == 0: continue
+      for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
+        if (s1 := s1 - o) == 0: continue
+        terms[d2].append((d1, s1))
+        strides[d1] += s1 * vm2.strides[d2]
+
+    # Merge dimensions in vm2 if required.
+    # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
+    idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+    merged_size, merged_term = 1, NumNode(0)
+    extents: List[Tuple[sint, Node]] = []
+    for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
+      merged_term += Variable.sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
+      merged_size *= s
+      if not (merged_term >= merged_size) and not (merged_term < 0):
+        extents.append((merged_size, merged_term))
+        merged_size, merged_term = 1, NumNode(0)
+    if merged_term: return None
+    if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
+      return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1
+
+    if vm2.mask:
+      # Try to project vm2's mask on to vm1.
+      newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
+      for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
+        if not (t.min < b or t.max >= e): continue
+        if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
+          bad = True
+          continue
+        term = terms[d2]
+        if len(term) != 1:
+          if not term and newe: newe[0] = 0
+          else: bad = True
+          continue
+        d1, s1 = term[0]
+        if not isinstance(s1, int) or not isinstance(newe[d1], int):
+          bad = True
+          continue
+        newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
+        newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
+
+      # If any of vm1 was masked off, try again with that mask in place.
+      for b, e, s in zip(newb, newe, vm1.shape):
+        if b != 0 or e != s:
+          return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
+      # Otherwise if vm2's mask was violated, then cannot merge.
+      if bad: return None
+
+    return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
+
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def invert(self, out_shape:Tuple[sint, ...]) -> Optional[View]:
+    ret = View.create(self.shape)
+    if self.mask: ret = ret.shrink(self.mask)
+    ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
+    return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)
+
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def minify(self):
+    min_shape = tuple(x[0] for x in _merge_dims(self.shape, self.strides, self.mask))
+    return nv if (nv := self.reshape(min_shape)) else self
+
+  def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
+    offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
+    if self.mask:
+      # move the old mask
+      nmask = tuple([(max(0, min(mx-ax,ay-ax)), max(0, min(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
+      # merge the masks if we have two
+      mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
+    shape = [y-x for x,y in arg]
+    if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
+    return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset+offset, mask)
+
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
+    assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}"
+    if any(b or e for b, e in arg):
+      zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
+      mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
+      return self.__unsafe_resize(zvarg, mask=mask)
+    return self
+
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
+    assert all((0<=b<=e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
+    return self.__unsafe_resize(arg)
+
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def expand(self, new_shape: Tuple[sint, ...]) -> View:
+    if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
+    if 0 in self.shape:
+      assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
+      return View.create(new_shape)
+    assert all((s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
+    # NOTE: can the mask ever be (0,0)?
+    mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
+    return View.create(new_shape, self.strides, self.offset, mask)
+
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def permute(self, axis: Tuple[int, ...]) -> View:
+    assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
+    assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
+    return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
+                       tuple(self.mask[a] for a in axis) if self.mask is not None else None)
+
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def stride(self, mul: Tuple[int, ...]) -> View:
+    # except for the negative case, you can build this from the others. invertible in the negative case
+    assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
+    strides = tuple([z*m for z,m in zip(self.strides, mul)])
+    new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)])
+    offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
+    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) \
+                  for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
+    return View.create(new_shape, strides, self.offset + offset, mask)
+
+  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
+    if self.shape == new_shape: return self
+
+    assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}"
+    if 0 in self.shape:
+      assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
+      return View.create(new_shape)
+    # check for the same size
+    if all_int(self.shape):
+      assert all(isinstance(s, (int, Variable)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
+      if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable,s).val for s in new_shape]):
+        raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
+
+    if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
+
+    # after the asserts, it's okay to check contiguous
+    if self.contiguous: return View.create(new_shape)
+
+    strides, r_new_shape = [], reversed(new_shape)
+    for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
+      acc = 1
+      # TODO: this <= and != is for symbolic!?
+      while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, None)):
+        strides.append(new_stride)
+        if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
+      if acc != merged_dim: break
+    else:
+      strides += [0,] * (len(new_shape) - len(strides))
+      new_mask, extra = _reshape_mask(self, new_shape)
+      if not extra:
+        new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask) if new_mask else new_shape, tuple(reversed(strides)))
+        extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
+                       (sum(m[0] * s for m,s in zip(new_mask, new_strides)) if new_mask else 0)
+        return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
+
+    return None
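As a quick illustration of the API added above (not part of the packaged file): View.create builds a contiguous view, the movement methods return new View objects, and reshape returns None when the requested shape cannot be expressed with a single strides/offset/mask combination.

# Illustration only, using the functions defined in the file above.
from tinygrad.shape.view import View

v = View.create((4, 6))         # contiguous: strides (6, 1), offset 0, no mask
p = v.permute((1, 0))           # shape (6, 4), strides (1, 6); no longer contiguous
print(p.reshape((24,)))         # None: a transposed view can't be flattened into one stride pattern
print(v.pad(((1, 1), (0, 0))))  # shape (6, 6) with mask ((1, 5), (0, 6)) covering the original rows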