tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff compares the published contents of two package versions as they appear in their public registry. It is provided for informational purposes only.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
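Beyond the line counts, the list encodes the release's restructuring: tinygrad/lazy.py moved under tinygrad/engine/, a new tinygrad/runtime/support/ package absorbed the old runtime/driver/ helpers, and tinygrad/shape/symbolic.py is gone entirely, its symbolic-integer machinery folded into the much larger tinygrad/ops.py. A hypothetical migration sketch for imports from the removed locations (the file moves above are certain; the re-exported names are assumed):

    from tinygrad.engine.lazy import LazyBuffer   # 0.9.1: from tinygrad.lazy import LazyBuffer
    from tinygrad.ops import Variable, sym_infer  # 0.9.1: from tinygrad.shape.symbolic import ...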
tinygrad/shape/view.py
CHANGED
@@ -2,8 +2,9 @@ from __future__ import annotations
 import functools, operator, itertools, math
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, Set, cast
-from tinygrad.
-from tinygrad.
+from tinygrad.dtype import dtypes
+from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin
+from tinygrad.helpers import prod, all_int, argsort, flatten
 
 @functools.lru_cache(maxsize=None)
 def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
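Every symbolic quantity below is now an sint (a plain int or a UOp) from tinygrad.ops, replacing the deleted tinygrad.shape.symbolic module. A minimal sketch of the helpers this file leans on, assuming only the names in the import line above:

    from tinygrad.ops import Variable, resolve, sym_infer

    v = Variable("i", 1, 10)              # integer variable bounded to [1, 10]
    print(sym_infer(v * 2 + 1, {v: 4}))   # 9: evaluate an expression at a concrete value
    print(resolve(v < 11))                # True: provable from the bounds alone
    print(resolve(v < 5, False))          # False: undecidable, so the default is returned

The resolve(cond, False) calls throughout the new code are this idiom: when a comparison over symbolic dimensions cannot be decided from variable bounds, the view code takes the conservative branch.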
@@ -20,7 +21,7 @@ def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tu
   # merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o zero stride), ...]
   if not shape: return ()
   assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask))
-  ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)]
+  ret = [(shape[0], strides[0], shape[0] if strides[0] != 0 else 0)]
   # merge this dim to next dim if size is 1
   merging = (mask[0][1] - mask[0][0] == 1) if mask is not None else shape[0] == 1
   for i, (s, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
@@ -28,21 +29,24 @@ def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tu
     # always merge 1
     if s == 1: continue
     # merge last dim with this dim if merging or strides matched
-    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st else 0)
-    else: ret.append((s, st, s if st else 0))
+    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st != 0 else 0)
+    else: ret.append((s, st, s if st != 0 else 0))
     # merge this dim to next dim if size is 1
     merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
   return tuple(ret)
 
 @functools.lru_cache(maxsize=None)
-def _reshape_mask(
-
-
-
+def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) \
+    -> Optional[Tuple[Tuple[sint, sint], ...]]:
+  """Returns the new mask if reshape is possible, and None if not possible."""
+  if _mask is None: return tuple((0, s) for s in new_shape)
+  if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in _mask): return None
+  if any(m[1] - m[0] < 1 for m in _mask): return ((0, 0),) * len(new_shape) # zero mask
 
-
+  new_mask: List[Tuple[int, int]] = []
+  # _mask is all int here
+  r_masks, r_shape, r_new_shape = reversed(cast(Tuple[Tuple[int, int], ...], _mask)), reversed(old_shape), reversed(new_shape)
   curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
-  if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
 
   while len(new_mask) < len(new_shape):
     (l, r), next_stride = mask, new_dim * curr_stride
@@ -51,34 +55,35 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
       if old_dim == next_stride: # simply copy the mask and get next batch for merging
         new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
         curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
-        if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask
 
       else: # mask can only be splitted if reshape doesn't cut across the mask.
         if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
-            or old_dim % next_stride != 0): return
+            or old_dim % next_stride != 0): return None
         new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
         curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
 
     else:
       next_mask = next(r_masks, (0, 1))
       # combine if the mask can unfold continuously
-      if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return
+      if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return None
       mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
 
   for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1)
-    if mask != (0, 1): return ((0, 0),) * len(new_shape)
+    if mask != (0, 1): return ((0, 0),) * len(new_shape) # invalid mask
 
-  return tuple(reversed(new_mask))
+  return tuple(reversed(new_mask))
 
 def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
   strides = strides_for_shape(shape)
   result = []
   for stride in strides:
-    here = offs // stride if stride else 0
+    here = offs // stride if stride != 0 else 0
     result.append(here)
     offs -= here * stride
   return result
 
+def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
+
 @dataclass(frozen=True)
 class View:
   shape:Tuple[sint, ...]
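_reshape_mask now takes the mask and both shapes directly instead of a View, and reports failure by returning None rather than a (mask, valid) pair. Two calls traced by hand against the code above; a (4, 4) view masked to rows [1, 3) occupies flat elements [4, 12):

    from tinygrad.shape.view import _reshape_mask

    print(_reshape_mask(((1, 3), (0, 4)), (4, 4), (16,)))   # ((4, 12),): the mask survives flattening
    print(_reshape_mask(((1, 3), (0, 4)), (4, 4), (2, 8)))  # None: [4, 12) cuts across the new row boundary at 8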
@@ -87,16 +92,33 @@ class View:
   mask:Optional[Tuple[Tuple[sint, sint], ...]]
   contiguous:bool
 
+  @functools.cached_property
+  def t(self):
+    return tuple(x.tuplize if isinstance(x, UOp) else (x,) \
+                 for x in self.shape+self.strides+(self.offset,)+(tuple(flatten(self.mask)) if self.mask is not None else tuple()))
+  def __lt__(self, o:View): return self.t < o.t
+
+  def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
+    idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
+    iexpr = variable_to_uop(self.offset)
+    for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
+      if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
+      if m is not None:
+        if resolve(m[0] != 0): vexpr = vexpr * idx.ge(m[0])
+        if resolve(m[1] != sh): vexpr = vexpr * idx.lt(m[1])
+    return iexpr, vexpr
+
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def size(self) -> int:
-
-    ret = prod([x.max if isinstance(x, Node) else x for x in self.shape])
+    ret = prod([x.vmax if isinstance(x, UOp) else x for x in self.shape])
     assert isinstance(ret, int), f"{ret=} is not int"
     return ret
 
   @staticmethod
   @functools.lru_cache(maxsize=None)
   def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
+    # TODO: this resolve shouldn't be needed
+    if not all(resolve(s >= 0) for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
     strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
     # canonicalize 0 in shape
     if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
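The new to_indexed_uops folds a view's offset, strides, and mask into a single symbolic index expression (iexpr) and validity predicate (vexpr) over RANGE UOps; its resolve calls only skip terms that are provably zero. A plain-int model of the same arithmetic, for intuition rather than as the real API:

    def index_view(idxs, shape, strides, offset, mask):
      # mirrors iexpr: offset plus idx*stride for every non-trivial dim
      iexpr = offset + sum(i*st for i, st, s in zip(idxs, strides, shape) if s != 1 and st != 0)
      # mirrors vexpr: every masked dim must satisfy begin <= idx < end
      valid = all(b <= i < e for i, (b, e) in zip(idxs, mask)) if mask else True
      return iexpr, valid

    # a length-3 row padded by one on the left: shape (4,), strides (1,), offset -1, mask ((1, 4),)
    for i in range(4): print(index_view((i,), (4,), (1,), -1, ((1, 4),)))
    # (-1, False) (0, True) (1, True) (2, True): the out-of-range load is masked invalid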
@@ -104,29 +126,36 @@ class View:
     if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
     # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
     # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
-
-
-    if any(not (b < e) for b,e in mask):
+    if mask and any(elim := [not resolve(b+1 < e) for b,e in mask]):
+      if any(not resolve(b < e) for b,e in mask):
         strides, offset, mask = (0,) * len(shape), 0, ((0,0),) * len(shape)
       offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
       strides = tuple(0 if e else st for st,e in zip(strides, elim))
+    # simplify as we go
+    if isinstance(offset, UOp): offset = cast(sint, offset.ssimplify())
+    shape = tuple(cast(sint, x.ssimplify()) if isinstance(x, UOp) else x for x in shape)
+    # TODO: enabling stride simplification breaks it
+    """
+    strides = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in strides)
+    if mask: mask = tuple((s.ssimplify() if isinstance(s, UOp) else s, e.ssimplify() if isinstance(e, UOp) else e) for s,e in mask)
+    """
     contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
     return View(shape, strides, offset, mask, contiguous)
 
   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
   def vars(self) -> Set[Variable]:
     flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
-    return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x,
+    return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, UOp)], set())
 
   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
   def unbind(self) -> Tuple[View, Dict[Variable, int]]:
-    var_unboundvar_val = [(v, v.unbind()) for v in self.vars()
+    var_unboundvar_val = [(v, v.unbind()) for v in self.vars()]
     unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
-
-
-
-
-
+    def substitute(x): return x if isinstance(x, int) else x.substitute(unbound_vars)
+    new_shape = tuple(map(substitute, self.shape))
+    new_strides = tuple(map(substitute, self.strides))
+    new_offset = substitute(self.offset)
+    new_mask = tuple((substitute(x[0]), substitute(x[1])) for x in self.mask) if self.mask is not None else None
     return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
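unbind strips the concrete values off every bound variable in the view and returns them in a dict, so a kernel cached on the unbound view can be reused across values. A sketch of the Variable side of that contract, assuming the UOp bind/unbind API used above:

    from tinygrad.ops import Variable

    v = Variable("start", 0, 64).bind(16)  # a bound Variable carries a concrete value
    var, val = v.unbind()                  # -> (the unbound Variable, 16)
    print(val)                             # 16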
@@ -137,7 +166,7 @@ class View:
     if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
     if vm1.mask:
       for b,e in vm1.mask:
-        if
+        if resolve(b >= e, False): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
       return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
 
     # Project vm1's offset and strides on to vm2.
@@ -153,28 +182,30 @@ class View:
 
     # Merge dimensions in vm2 if required.
     # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
-
-
-
+    if not all_int(vm1.shape): return None
+    idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+    merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
+    extents: List[Tuple[sint, UOp]] = []
     for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
-      merged_term +=
+      merged_term += (sum([idxs[d1] * s1 for d1, s1 in term]) + o) * merged_size
       merged_size *= s
-      if
+      if resolve(merged_term < merged_size, False) and resolve(0 <= merged_term, False):
         extents.append((merged_size, merged_term))
-        merged_size, merged_term = 1,
-    if merged_term: return None
+        merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
+    if resolve(merged_term != 0): return None
     if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
-
+      reshaped_vm2 = vm2.reshape(vm2_shape)
+      if reshaped_vm2 is None: return None
+      if reshaped_vm2.shape != vm2.shape: return reshaped_vm2 + vm1
 
     if vm2.mask:
       # Try to project vm2's mask on to vm1.
       newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
-      for
-        if
-        if not
+      for (b, e), o, term, (_, t) in zip(vm2.mask, origin, terms, reversed(extents)):
+        if resolve(b <= t.vmin and t.vmax < e, False): continue
+        if not all_int([o, b, e]):
           bad = True
           continue
-        term = terms[d2]
         if len(term) != 1:
           if not term and newe: newe[0] = 0
           else: bad = True
@@ -188,7 +219,7 @@ class View:
 
       # If any of vm1 was masked off, try again with that mask in place.
       for b, e, s in zip(newb, newe, vm1.shape):
-        if b != 0
+        if (b, e) != (0, s):
           return vm2 + View.create(vm1.shape, vm1.strides, vm1.offset, tuple(zip(newb, newe)))
       # Otherwise if vm2's mask was violated, then cannot merge.
       if bad: return None
@@ -211,17 +242,19 @@ class View:
     offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
     if self.mask:
       # move the old mask
-      nmask = tuple([(
+      nmask = tuple([(smax(0, smin(mx-ax,ay-ax)), smax(0, smin(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
       # merge the masks if we have two
-      mask = tuple([(
+      mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
     shape = [y-x for x,y in arg]
     if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
-    return View.create(tuple(s.
+    return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask)
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
-    assert
-
+    assert len(arg) == len(self.shape), f"invalid pad {arg} for {self.shape}"
+    # NOTE: not checking for symbolic arg
+    for b,e in arg: assert not all_int([b,e]) or b>=0 and e>=0, f"invalid pad {arg} for {self.shape}"
+    if any(resolve(b!=0) or resolve(e!=0) for b, e in arg):
       zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
       mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
       return self.__unsafe_resize(zvarg, mask=mask)
@@ -229,7 +262,9 @@ class View:
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
-    assert
+    assert len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
+    # NOTE: not checking for symbolic arg
+    for s,(b,e) in zip(self.shape,arg): assert not all_int([s,b,e]) or (0<=b<=e<=s), f"invalid shrink {arg} for {self.shape}"
     return self.__unsafe_resize(arg)
@@ -238,9 +273,12 @@ class View:
     if 0 in self.shape:
       assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
       return View.create(new_shape)
-
+    # TODO: this resolve might be wrong
+    assert all((not resolve(s != x, False) or s == 1) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
     # NOTE: can the mask ever be (0,0)?
-
+    # TODO: this resolve may not be needed, but it's hard because vars need to be sorted
+    mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \
+                  for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
     return View.create(new_shape, self.strides, self.offset, mask)
 
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
@@ -264,14 +302,15 @@ class View:
   def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
     if self.shape == new_shape: return self
 
-
+    # TODO: this resolve shouldn't be needed
+    assert all(resolve(x >= 0) for x in new_shape), f"shape can't contain negative numbers {new_shape}"
     if 0 in self.shape:
       assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
       return View.create(new_shape)
     # check for the same size
     if (self_all_int := all_int(self.shape)):
-      assert all(isinstance(s, (int,
-      if prod(self.shape) != prod(
+      assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
+      if resolve(prod(self.shape) != prod(new_shape), False):
         raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
 
     if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
@@ -286,7 +325,7 @@ class View:
       if isinstance(so, int):
         if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
       else:
-        var_vals =
+        var_vals = dict([v.unbind() for v in so.vars()])
         if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
     # all dimensions matched, return the new view directly
     return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
@@ -294,18 +333,18 @@ class View:
     strides, r_new_shape = [], reversed(new_shape)
     for merged_dim, new_stride, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
       acc = 1
-      # TODO:
-      while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape,
+      # TODO: third resolve shouldn't be needed
+      while resolve(acc <= merged_dim) and resolve(acc != merged_dim) and resolve((new_dim := next(r_new_shape, 0)) > 0):
         strides.append(new_stride)
-        if new_dim != 1: new_stride *= (new_dim if (acc :=
-      if acc != merged_dim: break
+        if resolve(new_dim != 1): new_stride *= (new_dim if resolve((acc := acc * new_dim) < real_dim) else 0)
+      if resolve(acc != merged_dim): break
     else:
       strides += [0,] * (len(new_shape) - len(strides))
-      new_mask
-      if not
-      new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask)
+      new_mask = _reshape_mask(self.mask, self.shape, new_shape)
+      if new_mask is not None:
+        new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask), tuple(reversed(strides)))
         extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
-          (sum(m[0] * s for m,s in zip(new_mask, new_strides))
+                       (sum(m[0] * s for m,s in zip(new_mask, new_strides)))
         return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
 
     return None