tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/shape/view.py
CHANGED
@@ -1,23 +1,48 @@
 from __future__ import annotations
 import functools, operator, itertools
 from dataclasses import dataclass
-from typing import Optional, cast, Sequence
+from typing import cast, Sequence
 from tinygrad.dtype import dtypes
-from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop
+from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops, ssimplify
 from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv

-@functools.lru_cache(maxsize=None)
+# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
+def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
+  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
+  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
+  except ValueError: return None
+  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
+
+def get_contraction_with_reduce(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...], reduce_axis:tuple[int, ...]) -> list[list[int]]|None:
+  if (contraction:=get_contraction(old_shape, new_shape)) is None: return None
+  # contraction returns the 1s as right justified as possible
+  # normally this contraction is good, but sometimes the reduce dim is empty. borrow from the next one, leaving one
+  # this ensures there's always ones available in the reduce dimension. this is also a valid contraction
+  for i in range(len(contraction)):
+    if i in reduce_axis and len(contraction[i]) == 0:
+      take_from = i+1
+      while take_from < len(contraction) and len(contraction[take_from]) == 0:
+        assert new_shape[take_from] == 1
+        take_from += 1
+      if take_from == len(contraction) or new_shape[take_from] != 1: return None # nothing to take
+      for j in range(take_from, i, -1):
+        assert len(contraction[j]) > 0
+        contraction[j-1] = contraction[j][:-1]
+        contraction[j] = contraction[j][-1:]
+  return contraction
+
+@functools.cache
 def canonicalize_strides(shape:tuple[sint, ...], strides:tuple[sint, ...]) -> tuple[sint, ...]:
   return tuple(0 if s == 1 else st for s, st in zip(shape, strides))

-@functools.lru_cache(maxsize=None)
+@functools.cache
 def strides_for_shape(shape:tuple[sint, ...]) -> tuple[sint, ...]:
   if not shape: return ()
   strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1]
   return canonicalize_strides(shape, strides)

-@functools.lru_cache(maxsize=None)
-def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:Optional[tuple[tuple[int, int], ...]]=None) -> tuple[tuple[int, int, int], ...]:
+@functools.cache
+def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:tuple[tuple[int, int], ...]|None=None) -> tuple[tuple[int, int, int], ...]:
   # merge contiguous sub-parts or zero strided dims
   # any stride 0, masked from dim=1, or contiguous part is merged into next dim.
   # stride != 0 to stride == 0 starts a new merging block
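The new get_contraction helper added in this hunk is self-contained, so its behavior is easy to check outside the library. A minimal int-only sketch (the function body is copied from the hunk above; only the symbolic sint annotations are dropped):

import itertools, operator

def get_contraction(old_shape, new_shape):
  # cumulative products of both shapes; new_shape is reachable only if every
  # prefix product of new_shape also appears among old_shape's prefix products
  acc_old = list(itertools.accumulate(old_shape, operator.mul))
  acc_new = list(itertools.accumulate(new_shape, operator.mul))
  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
  except ValueError: return None
  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]

print(get_contraction((2,3,4), (6,4)))   # [[0, 1], [2]]: axes 0,1 merge into the 6
print(get_contraction((2,3,4), (2,12)))  # [[0], [1, 2]]: axes 1,2 merge into the 12
print(get_contraction((4,6), (6,4)))     # None: (6,4) is not a merge of adjacent axes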
@@ -38,9 +63,9 @@ def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:Optional[tup
     merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
   return tuple(ret)

-@functools.lru_cache(maxsize=None)
-def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) \
-  -> Optional[tuple[tuple[sint, sint], ...]]:
+@functools.cache
+def _reshape_mask(_mask:tuple[tuple[sint, sint], ...]|None, old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) \
+  -> tuple[tuple[sint, sint], ...]|None:
   """Returns the new mask if reshape is possible, and None if not possible."""
   if _mask is None: return tuple((0, s) for s in new_shape)
   if not all_int(flatten(_mask)): return None
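The recurring decorator change in this file (and throughout the release) appears to swap @functools.lru_cache(maxsize=None) for @functools.cache. The two are equivalent, since functools.cache (Python 3.9+) is defined as an unbounded lru_cache, so this is a spelling change rather than a behavior change:

import functools

@functools.cache  # same cache object as @functools.lru_cache(maxsize=None)
def fib(n: int) -> int:
  return n if n < 2 else fib(n-1) + fib(n-2)

print(fib(64))           # 10610209857723, linear time thanks to memoization
print(fib.cache_info())  # CacheInfo(..., maxsize=None, ...), as with lru_cache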
@@ -51,7 +76,7 @@ def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple
   curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))

   while len(new_mask) < len(new_shape):
-    (l, r), next_stride = mask, new_dim * curr_stride
+    (l, r), next_stride = mask, ssimplify(new_dim * curr_stride)

     # need to split mask
     if old_dim == next_stride: # simply copy the mask and get next batch for merging
@@ -66,7 +91,7 @@ def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple
       next_mask = next(r_masks, (0, 1))
       # combine if the mask can unfold continuously
       if mask != (0, old_dim) and l != r and next_mask[1] - next_mask[0] != 1: return None
-      mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
+      mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), ssimplify(old_dim * next(r_shape, 1))

   return tuple(reversed(new_mask))

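The only functional change in these two _reshape_mask hunks is wrapping the symbolic products in ssimplify, so masks built from Variable-sized dimensions don't carry unsimplified expression trees. A sketch of what ssimplify does, assuming tinygrad 0.11 is installed (Variable and ssimplify are the names this file imports from tinygrad.uop.ops; the exact printed reprs are not guaranteed):

from tinygrad.uop.ops import Variable, ssimplify

d = Variable("d", 1, 32)  # a symbolic dimension bounded to [1, 32]
print(ssimplify(d * 1))   # should fold back to the bare variable UOp
print(ssimplify(d * 0))   # should fold all the way down to the plain int 0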
@@ -84,12 +109,12 @@ class View:
   shape:tuple[sint, ...]
   strides:tuple[sint, ...]
   offset:sint
-  mask:Optional[tuple[tuple[sint, sint], ...]]
+  mask:tuple[tuple[sint, sint], ...]|None
   contiguous:bool

-  def to_indexed_uops(self:View, idxs:Optional[Sequence[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
+  def to_indexed_uops(self:View, idxs:Sequence[UOp]|None=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
     """(idx, valid)"""
-    if idxs is None: idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)]
+    if idxs is None: idxs = [UOp.range(dtypes.int, s, i) for i,s in enumerate(self.shape)]
     iexpr = sint_to_uop(self.offset)
     for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
       if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
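to_indexed_uops turns a View into a flat index expression plus a validity predicate, and now accepts any Sequence[UOp] for the indices. A usage sketch, assuming tinygrad 0.11 is installed; calling UOp.render for display is an assumption of this sketch, not something this diff shows:

from tinygrad.shape.view import View

v = View.create((4, 4), mask=((0, 4), (1, 3)))  # 4x4 view, columns valid for 1 <= c < 3
idx, valid = v.to_indexed_uops()
# idx is the flat index (row*4 + col); valid encodes the mask bounds on the column index
print(idx.render(), "|", valid.render())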
@@ -98,16 +123,17 @@ class View:
         if resolve(m[1] != sh): vexpr = vexpr * (idx < m[1])
     return iexpr, vexpr

-  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
+  @functools.cache  # pylint: disable=method-cache-max-size-none
   def size(self) -> int:
     ret = prod([x.vmax if isinstance(x, UOp) else x for x in self.shape])
     assert isinstance(ret, int), f"{ret=} is not int"
     return ret

   @staticmethod
-  @functools.lru_cache(maxsize=None)
-  def create(shape:tuple[sint, ...], strides:Optional[tuple[sint, ...]]=None, offset:sint=0, mask:Optional[tuple[tuple[sint, sint], ...]]=None):
-    if not all(s >= 0 for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
+  @functools.cache
+  def create(shape:tuple[sint, ...], strides:tuple[sint, ...]|None=None, offset:sint=0, mask:tuple[tuple[sint, sint], ...]|None=None):
+    # TODO: resolve shouldn't be needed here
+    if not all(resolve(s >= 0) for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
     strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
     # canonicalize 0 in shape
     if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
@@ -131,43 +157,50 @@ class View:
     contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
     return View(shape, strides, offset, mask, contiguous)

-  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
+  @functools.cache  # pylint: disable=method-cache-max-size-none
   def vars(self) -> set[Variable]:
     flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
     return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, UOp)], set())

-  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
+  @functools.cache  # pylint: disable=method-cache-max-size-none
   def unbind(self) -> tuple[View, dict[Variable, int]]:
-    var_unboundvar_val = [(v, v.unbind()) for v in self.vars()]
+    var_unboundvar_val = [(v, v.unbind()) for v in self.vars() if v.op is Ops.BIND]
     unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
-    def substitute(x:sint): return x if isinstance(x, int) else x.substitute(unbound_vars)
-    new_shape = tuple(map(substitute, self.shape))
-    new_strides = tuple(map(substitute, self.strides))
-    new_offset = substitute(self.offset)
-    new_mask = tuple((substitute(x[0]), substitute(x[1])) for x in self.mask) if self.mask is not None else None
-    return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
-
-  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
-  def __add__(self, vm1:View) -> Optional[View]:
+    return self.substitute(unbound_vars), dict(x[1] for x in var_unboundvar_val)
+
+  def substitute(self, dvars:dict[UOp, UOp]):
+    def _substitute(x:sint): return x if isinstance(x, int) else x.substitute(dvars)
+    new_shape = tuple(map(_substitute, self.shape))
+    new_strides = tuple(map(_substitute, self.strides))
+    new_offset = _substitute(self.offset)
+    new_mask = tuple((_substitute(x[0]), _substitute(x[1])) for x in self.mask) if self.mask is not None else None
+    return View.create(new_shape, new_strides, new_offset, new_mask)
+
+  @functools.cache  # pylint: disable=method-cache-max-size-none
+  def __add__(self, vm1:View) -> View|None:
     vm2 = self
-    if vm2.contiguous: return vm1
+    if vm2.contiguous or vm1.size() == 0: return vm1
     if vm1.contiguous and vm1.shape == vm2.shape: return vm2
     if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
     if vm1.mask:
       if (new_vm1 := vm1.shrink(vm1.mask)) == vm1 or (merged := vm2 + new_vm1) is None: return None
       return merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))
-    if not all_int(vm1.shape): return None
+    if not all_int(vm1.shape):
+      # if all strides are 0 and vm2 is unmasked, return vm1
+      if all(x == 0 for x in vm2.strides+vm1.strides) and vm2.mask is None: return vm1
+      # TODO: handle more cases
+      return None

     # Project vm1's offset and strides on to vm2.
-    origin = unravel(vm2.shape, vm1.offset)
+    origin = [ssimplify(o) for o in unravel(vm2.shape, vm1.offset)]
     terms: list[list[tuple[int, sint]]] = [[] for _ in vm2.shape]
     strides: list[sint] = [0] * len(vm1.shape)
     for d1, st in enumerate(vm1.strides):
       if st == 0: continue
       for d2, (o, s1) in enumerate(zip(origin, unravel(vm2.shape, vm1.offset + st))):
-        if (s1 := s1 - o) == 0: continue
+        if not resolve((s1 := s1 - o)!=0): continue # if s1 can possibly be 0
         terms[d2].append((d1, s1))
-        strides[d1] += s1 * vm2.strides[d2]
+        strides[d1] += ssimplify(s1 * vm2.strides[d2])

     # Merge dimensions in vm2 if required.
     # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
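unravel, defined elsewhere in view.py and unchanged here, is what __add__ uses to decompose vm1's offset into per-dimension coordinates of vm2. A minimal int-only re-implementation of the behavior the projection above relies on (illustrative, not the library's code):

def unravel(shape, offset):
  # peel off one coordinate per dimension, innermost (fastest-moving) first
  idxs = []
  for s in reversed(shape):
    idxs.append(offset % s)
    offset //= s
  return idxs[::-1]

print(unravel((2, 3, 4), 17))  # [1, 1, 1], since 17 == 1*12 + 1*4 + 1*1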
@@ -190,14 +223,17 @@ class View:
     # Try to project vm2's mask on to vm1.
     newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
     for (b, e), o, term, (_, t) in zip(vm2.mask, origin, terms, reversed(extents)):
-      if resolve(b <= t.vmin and t.vmax < e, False): continue
+      if resolve(b <= (t := t.simplify()).vmin and t.vmax < e, False): continue
       if len(term) != 1:
-        if not term and newe: newe[0] = 0
+        if not term and newe:
+          # t should be a constant if no terms contribute to this dimension, but it might not be simplified
+          if t.vmin != t.vmax: return None
+          newe[0] = 0
        else: bad = True
         continue
       d1, s1 = term[0]
-      newb[d1] = max(newb[d1], ceildiv(b - o if s1 > 0 else e - o - 1, s1))
-      newe[d1] = min(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)
+      newb[d1] = smax(newb[d1], ceildiv(b - o if s1 > 0 else e - o - 1, s1))
+      newe[d1] = smin(newe[d1], (b - o if s1 < 0 else e - o - 1) // s1 + 1)

     # If any of vm1 was masked off, try again with that mask in place.
     if any((b, e) != (0, s) for b, e, s in zip(newb, newe, vm1.shape)):
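The newb/newe updates project vm2's mask interval [b, e) through the affine index map o + s1*d1 to get bounds on vm1's index d1. A worked instance with made-up numbers, following the same ceiling and floor division formulas as the hunk above:

from math import ceil

b, e, o, s1 = 2, 9, 1, 2      # keep o + s1*d1 inside [2, 9), origin 1, positive step 2
newb = ceil((b - o) / s1)     # smallest valid d1, i.e. ceildiv(b - o, s1) -> 1
newe = (e - o - 1) // s1 + 1  # one past the largest valid d1 -> 4
print(newb, newe)             # d1 in [1, 4): indices 1,2,3 map to 3,5,7, all inside [2, 9)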
@@ -205,16 +241,16 @@ class View:
     # Otherwise if vm2's mask was violated, then cannot merge.
     if bad: return None

-    return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
+    return View.create(vm1.shape, tuple(strides), ssimplify(sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset))

-  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
-  def invert(self, out_shape:tuple[sint, ...]) -> Optional[View]:
+  @functools.cache  # pylint: disable=method-cache-max-size-none
+  def invert(self, out_shape:tuple[sint, ...]) -> View|None:
     ret = View.create(self.shape)
     if self.mask: ret = ret.shrink(self.mask)
     ret = ret.flip(tuple(x < 0 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
     return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)

-  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
+  @functools.cache  # pylint: disable=method-cache-max-size-none
   def minify(self):
     min_shape = tuple(x[0] for x in merge_dims(self.shape, self.strides, self.mask))
     return nv if (nv := self.reshape(min_shape)) else self
@@ -228,7 +264,7 @@ class View:
     mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
     return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask)

-  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
+  @functools.cache  # pylint: disable=method-cache-max-size-none
   def pad(self, arg: tuple[tuple[sint, sint], ...]) -> View:
     assert len(arg) == len(self.shape), f"invalid pad {arg} for {self.shape}"
     # NOTE: not checking for symbolic arg
@@ -239,38 +275,38 @@ class View:
       return self.__unsafe_resize(zvarg, mask=mask)
     return self

-  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
+  @functools.cache  # pylint: disable=method-cache-max-size-none
   def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> View:
     assert len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
     # NOTE: not checking for symbolic arg
     for s,(b,e) in zip(self.shape,arg): assert not all_int([s,b,e]) or (0<=b<=e<=s), f"invalid shrink {arg} for {self.shape}"
     return self.__unsafe_resize(arg)

-  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
+  @functools.cache  # pylint: disable=method-cache-max-size-none
   def expand(self, new_shape: tuple[sint, ...]) -> View:
     if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
     # NOTE: does not check multiple of symbolic shape
     assert all(resolve(s == ns) or s == 1 for s,ns in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
     if 0 in self.shape: return View.create(new_shape)
-    # TODO:
-    mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \
+    # TODO: resolve may not be needed, but it's hard because vars need to be canonicalized
+    mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns) and resolve(s == 1, False) else m) \
       for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
     return View.create(new_shape, self.strides, self.offset, mask)

-  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
+  @functools.cache  # pylint: disable=method-cache-max-size-none
   def permute(self, axis: tuple[int, ...]) -> View:
     assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}"
     return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
                        tuple(self.mask[a] for a in axis) if self.mask is not None else None)

-  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
+  @functools.cache  # pylint: disable=method-cache-max-size-none
   def flip(self, arg: tuple[bool, ...]) -> View:
     offset = sum((s-1)*z for s,z,f in zip(self.shape, self.strides, arg) if f)
     mask = tuple((s-my,s-mx) if f else (mx,my) for (mx,my),s,f in zip(self.mask, self.shape, arg)) if self.mask is not None else None
     return View.create(self.shape, tuple(-z if f else z for z,f in zip(self.strides, arg)), self.offset+offset, mask)

-  @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
-  def reshape(self, new_shape: tuple[sint, ...]) -> Optional[View]:
+  @functools.cache  # pylint: disable=method-cache-max-size-none
+  def reshape(self, new_shape: tuple[sint, ...]) -> View|None:
     if self.shape == new_shape: return self

     if not all(x >= 0 for x in new_shape): raise ValueError(f"shape can't contain negative numbers {new_shape}")