tinygrad 0.7.0-py3-none-any.whl → 0.9.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
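The listing shows the 0.9.0 restructuring: a new tinygrad/engine/ package (schedule, realize, jit, search), mlops.py renamed to function.py, device/dtype/renderer split into their own modules, and large autogenerated runtime bindings for AMD, CUDA, NV, HSA, and OpenCL. The 6 lines added to tinygrad/__init__.py suggest the user-facing API is now re-exported at the package root; a minimal sketch of what 0.9.0 usage might look like under that assumption (the re-export list itself is not shown in this diff):

from tinygrad import Tensor   # assumes Tensor is re-exported from tinygrad/__init__.py

x = Tensor([[1.0, 2.0], [3.0, 4.0]])
y = (x @ x.T).relu()
print(y.numpy())              # realization now flows through the new tinygrad/engine/ modules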
@@ -0,0 +1,204 @@
+# a python uops emulator
+# works to test the tensor cores, and all the uops in general
+# this is the (living) definition of uops
+from typing import Tuple, List, Optional, Any, Dict
+import pickle, base64, itertools, time, struct
+from tinygrad.dtype import DType, dtypes, ImageDType
+from tinygrad.helpers import all_same, getenv, flatten
+from tinygrad.device import Compiled, Compiler, Allocator
+from tinygrad.codegen.uops import UOpGraph, UOps
+from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
+from tinygrad.renderer import Renderer
+from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, HIPRenderer
+
+def _load(m, i):
+  if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
+  return m[i]
+
+def load(inp, j=0):
+  if len(inp) == 4: return [_load(m, x+j) if gate else default for m,x,gate,default in zip(*inp)]
+  else: return [_load(m, x+j) for m,x in zip(inp[0], inp[1])]
+
+def _store(m, i, v):
+  if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
+  m[i] = v
+
+class PythonProgram:
+  def __init__(self, name:str, lib:bytes):
+    self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib)
+  def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
+    st = time.perf_counter()
+    warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
+    warp_size = len(warp)
+    for idxs in itertools.product(*[range(x) for x in global_size[::-1]]):
+      ul: Dict[int, Any] = {}
+      dl: Dict[int, DType] = {}
+      pbufs: List[memoryview] = list(bufs)
+      pvals: List[int] = list(vals)
+      i = 0
+      loop_ends: Dict[int, int] = {}
+      while i < len(self.uops):
+        uop, dtype, idp, arg = self.uops[i]
+        void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
+        if uop is UOps.DEFINE_ACC: idp.clear()
+        inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
+        dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
+        if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
+        if uop is UOps.STORE:
+          if len(inp) == 3: inp.append([True] * len(inp[0])) # set the gate to True
+          if isinstance(dtp[0], ImageDType):
+            # image store
+            assert dtp[2].count == 4
+            for j,val in enumerate(inp[2]):
+              for m,ox,oy,v,g in zip(inp[0], inp[1][0], inp[1][1], val, inp[3]):
+                assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
+                if g: _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
+          elif dtp[2].count > 1:
+            for j,val in enumerate(inp[2]):
+              for m,o,v,g in zip(inp[0], inp[1], val, inp[3]):
+                if g: _store(m, o+j, v)
+          else:
+            for m,o,v,g in zip(*inp):
+              if g: _store(m, o, v)
+          i += 1
+          continue
+        elif uop is UOps.ENDRANGE:
+          loop_ends[idp[0]] = i
+          i = idp[0]
+          continue
+        elif uop in (UOps.BARRIER, UOps.IF, UOps.ENDIF):
+          # in the python emulator, the warp is always in sync
+          i += 1
+          continue
+        assert dtype is not None, f"{uop} is missing a dtype"
+        dl[i] = dtype
+        if uop is UOps.DEFINE_GLOBAL:
+          assert dtype.fmt is not None
+          ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
+        elif uop is UOps.DEFINE_LOCAL:
+          assert dtype.fmt is not None
+          lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
+          ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
+        elif uop is UOps.DEFINE_VAR:
+          ul[i] = [pvals.pop(0)] * warp_size
+        elif uop is UOps.SPECIAL:
+          if arg[1][0] == 'g':
+            ul[i] = [idxs[2-arg[0]]] * warp_size
+          elif arg[1][0] == 'l':
+            ul[i] = [x[2-arg[0]] for x in warp]
+        elif uop is UOps.CONST:
+          ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
+        elif uop is UOps.DEFINE_ACC:
+          ul[i] = [[arg[0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg[0]] * warp_size
+        elif uop is UOps.RANGE:
+          if i not in ul: ul[i] = [inp[0][0]] * warp_size
+          else:
+            for j in range(len(ul[i])):
+              ul[i][j] += 1
+            if ul[i][0] == inp[1][0]:
+              del ul[i]
+              i = loop_ends[i] + 1
+              continue
+        elif uop in {UOps.CAST, UOps.BITCAST}:
+          if dtype.count > 1: ul[i] = inp
+          else:
+            assert dtp[0].fmt and dtype.fmt
+            pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
+            if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
+            else:
+              casted = [dtypes.as_const(x, dtype) for x in inp[0]]
+              overflow_adjust = 2**(dtype.itemsize*8 - 1) if (dtypes.is_int(dtype) and not dtypes.is_unsigned(dtype)) else 0
+              overflow_fixed = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) if dtypes.is_int(dtype) else x for x in casted]
+              ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *overflow_fixed)))
+        elif uop is UOps.LOAD:
+          if isinstance(dtp[0], ImageDType):
+            assert dtype.count == 4
+            ul[i] = []
+            for j in range(dtype.count):
+              ret = []
+              for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
+                if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
+                else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
+              ul[i].append(ret)
+          elif dtype.count > 1:
+            ul[i] = [load([inp[i][j] if dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
+          else:
+            ul[i] = load(inp)
+        elif uop is UOps.PHI:
+          for j in range(len(inp[0])): inp[0][j] = inp[1][j]
+          ul[i] = inp[0]
+        elif uop is UOps.GEP:
+          ul[i] = inp[0][arg]
+        elif uop is UOps.WMMA:
+          # here are the models for the WMMA instruction on the different hardware
+          def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
+            assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread"
+            assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread"
+            assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread"
+            assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
+            assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
+            assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
+            assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads"
+            out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)]
+            for goff in range(0, warp_size, WARP_THREADS):
+              for lane_id in range(WARP_THREADS):
+                for elem_idx in range(NUM_C): # calculate new muls and add to acc
+                  (c_i, c_j) = c_map(lane_id, elem_idx)
+                  out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
+            return out
+
+          # TODO: refactor these to a shared TensorCoreLayout in kernel.py
+          if arg[5] == "METAL":
+            # A (2 elements on 32 threads): row major
+            def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
+            # (i, j), C, D (2 elements on 32 threads): row major same as A/B
+            def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
+            ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
+          elif arg[5] == "HSA":
+            # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
+            def a_elem(x, i, j, goff):
+              assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
+              return x[i][goff+j]
+            # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
+            def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)
+            def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
+            ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
+          elif arg[5] == "CUDA":
+            # A (8 elements on 32 threads)
+            def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
+            # B (4 elements on 32 threads)
+            def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4]
+            # (i, j), C, D (4 elements on 32 threads)
+            def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)
+            ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
+          else: raise NotImplementedError(f"unimplemented tensor core {arg}")
+        elif uop is UOps.ALU:
+          assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
+          assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
+          ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]
+        assert i in ul, (uop, dtype, idp, arg)
+        i += 1
+    return time.perf_counter() - st
+
+class PythonRenderer(Renderer):
+  device = "PYTHON"
+  def __init__(self):
+    if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
+    if getenv("EMULATE_HSA"): self.device, self.tensor_cores = "HSA", HIPRenderer.tensor_cores
+    if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
+
+  def render(self, name:str, uops:UOpGraph) -> str:
+    lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
+    return base64.b64encode(pickle.dumps(lops)).decode()
+
+class PythonCompiler(Compiler):
+  def compile(self, src:str) -> bytes: return base64.b64decode(src)
+
+class PythonAllocator(Allocator):
+  def _alloc(self, size, options): return memoryview(bytearray(size))
+  def copyin(self, dest, src:memoryview): dest[:] = src
+  def copyout(self, dest:memoryview, src): dest[:] = src
+
+class PythonDevice(Compiled):
+  def __init__(self, device:str):
+    super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)
tinygrad/shape/shapetracker.py
CHANGED
@@ -1,286 +1,118 @@
 # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
 from __future__ import annotations
-import
-from typing import
-from tinygrad.helpers import
-from tinygrad.shape.symbolic import Variable, MulNode, NumNode,
-
-
-def
-  assert len(
-
-
-
-
-
-
-
-
-
-
-@functools.lru_cache(maxsize=None)
-def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all(s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape)))
-
-@functools.lru_cache(maxsize=None)
-def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
-  return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape))
-
-class ViewInternal(NamedTuple):
-  shape:Tuple[int, ...]
-  strides:Tuple[int, ...]
-  offset:int
-  mask:Optional[Tuple[Tuple[int, int]]]
-  contiguous:bool
-  shape_strides:Tuple[Tuple[int, int], ...]
-
-@functools.lru_cache(maxsize=None)
-class View(ViewInternal):
-  def __new__(cls, shape, strides=None, offset=0, mask=None):
-    strides_from_shape = strides_for_shape(shape)
-    strides = strides_from_shape if not strides else filter_strides(shape, strides)
-    contiguous = offset == 0 and is_contiguous(shape, strides) and mask is None
-    return super().__new__(cls, shape, strides, offset, mask, contiguous, to_shape_strides(shape, strides))
-  def __init__(self, shape, strides=None, offset=0, mask=None, contiguous=False, shape_strides=()): super().__init__()
-
-  def expr_node_mask(self, idx, valid=None) -> Node:
-    expr = [valid] if valid is not None else []
-    if self.mask is not None:
-      acc = 1
-      for ns,(x,y) in reversed(list(zip(self.shape, self.mask))):
-        base = ((idx//acc) % ns)
-        expr += [base >= x, base < y]
-        acc *= ns
-    return Variable.ands(expr)
-
-  # generate an expression if you have a single idx variable
-  def expr_node(self, idx=None) -> Node:
-    if idx is None: idx = Variable('idx', 0, prod(self.shape)-1)
-    ret: List[Node] = [Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] if self.offset else []
-    acc = 1
-    for d,s in reversed(self.shape_strides):
-      ret.append(((idx//acc)%d)*s)
-      acc *= d
-    return Variable.sum(ret)
-
-  # generate an expression if you have a variable or expression for each index
-  def expr_idxs(self, idxs) -> Node:
-    assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
-    return Variable.sum([Variable.num(self.offset) if isinstance(self.offset, int) else self.offset] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
-
-@functools.lru_cache(maxsize=None)
-def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
-  assert len(idxs) == len(shape), "need an idx for all dimensions"
-  acc = 1
-  ret = []
-  for tidx,d in reversed(list(zip(idxs, shape))):
-    ret.append(tidx * acc)
-    acc *= d
-  return Variable.sum(ret)
-
-@functools.lru_cache(maxsize=None)
-def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
-  strides = [1] if shape else []
-  for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
-  return filter_strides(shape, tuple(strides))
-
-@functools.lru_cache(maxsize=None)
-def merge_views(vm2:View, vm1:View) -> Optional[View]:
-  if vm2.mask: return None # this isn't supported yet
-  mst = ShapeTracker(vm1.shape, [vm2, vm1])
-  strides = mst.real_strides()
-  if None in strides: return None
-  return View(vm1.shape, strides, mst.real_offset(), vm1.mask)
-
-@functools.lru_cache(maxsize=None)
-def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
-  shape, mask, strides, offset = view.shape, view.mask, view.strides, view.offset
-  # check if this is adding or removing 1s (only)
-  # NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
-  if [x for x in shape if x != 1] == [x for x in new_shape if x != 1]:
-    new_strides: List[int] = [y for x,y in zip(shape, strides) if x != 1]
-    new_strides_tuple: Tuple[int, ...] = tuple([0 if x == 1 else new_strides.pop(0) for x in new_shape])
-    new_mask_tuple = None
-    if mask:
-      for x,y in zip(shape, mask):
-        if x == 1 and y != (0, 1):
-          new_mask_tuple = ((0,0),) * len(new_shape)
-          break
-      else:
-        new_mask: List[Tuple[int, int]] = [y for x,y in zip(shape, mask) if x != 1]
-        new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape])
-    return View(new_shape, new_strides_tuple, offset, new_mask_tuple), False
+from dataclasses import dataclass
+from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
+from tinygrad.helpers import merge_dicts, getenv
+from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, create_lt_node, create_ge_node, sint
+from tinygrad.shape.view import View, strides_for_shape
+
+def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
+  assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
+  iexpr: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset]
+  vexpr: List[Node] = [valid] if valid is not None else []
+  for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
+    if sh != 1 and st != 0: iexpr.append(idx*st)
+    if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
+  return Node.sum(iexpr), Node.ands(vexpr)
+
+@dataclass(frozen=True)
+class ShapeTracker:
+  views: Tuple[View, ...]
 
-
-
-
-
-  return new_view, True
+  def __add__(self, st:ShapeTracker) -> ShapeTracker:
+    ret = self
+    for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
+    return ret
 
-
-
-
+  def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
+    ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]))
+    return ShapeTracker(cast(Tuple[View, ...], ret)).reshape(out_shape) if all(x is not None for x in ret) else None
 
-@
-def
-  return sum([s * x[0] for s, x in zip(strides,arg)])
-
-class ShapeTracker:
-  __slots__ = "views", "var_vals"
-  def __init__(self, shape:Union[ShapeTracker, Tuple[Union[Node,int], ...]], views:Optional[List[View]]=None):
-    self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [View(shape)])
-    self.var_vals: Dict[Variable, int] = shape.var_vals if isinstance(shape, ShapeTracker) else {}
-  def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views}, var_vals={self.var_vals})"
-  def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
+  @staticmethod
+  def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
 
   @property
   def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
 
   @property
-  def
+  def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)
+
+  @property
+  def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
 
   @property
-  def
+  def size(self) -> int: return self.views[-1].size()
+
+  def real_size(self) -> int:
+    if 0 in self.shape: return 0
+    idx, valid = self.expr_idxs()
+    if not valid: return 0
+    # TODO: it's possible that the real_size is smaller condition on valid being true
+    ret = idx.max
+    if not isinstance(ret, int): ret = ret.max # might be represent by symbolic shape, take one more max for int max
+    assert isinstance(ret, int), f"ret must be integer, {ret=} isn't"
+    return ret+1
 
-
-  def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
+  def vars(self) -> Set[Variable]: return set.union(*[v.vars() for v in self.views], set())
 
-
-
-
-
-
-    return
+  @property
+  def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
+
+  def unbind(self) -> Tuple[ShapeTracker, Dict[Variable, int]]:
+    unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
+    return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
 
   # NOTE: if a stride is not always valid, it will be None
-  def real_strides(self, ignore_valid=False) -> Tuple[Optional[
+  def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
     if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
-    idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
+    idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
     idx, valid = self.expr_idxs(idxs)
-    ret: List[Optional[
+    ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
+    bad_idx_vars: Set[Variable] = set()
     for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
-
-
-
-        ret[idxs.index(this_dim)] = 1
+      idx_maybe, stride_maybe = (this_dim.a, this_dim.b) if isinstance(this_dim, MulNode) else (this_dim, 1)
+      try: ret[idxs.index(idx_maybe)] = cast(sint, stride_maybe)
+      except ValueError: bad_idx_vars = bad_idx_vars.union(idx_maybe.vars())
     idx_vars, valid_vars = idx.vars(), valid.vars()
     for i,tidx in enumerate(idxs):
-      if tidx in valid_vars and not ignore_valid: ret[i] = None
+      if tidx in bad_idx_vars or (tidx in valid_vars and not ignore_valid): ret[i] = None
       elif tidx not in idx_vars: ret[i] = 0
     return tuple(ret)
+
   def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
 
-  def
-    for
-
-
+  def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
+    idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
+    idx, valid = _expr_view(self.views[-1], idxs)
+    for view in reversed(self.views[0:-1]):
+      if valid.max == 0: return NumNode(-1), valid
+      view = view.minify()
+      acc, idxs = 1, []
+      for d in reversed(view.shape):
+        idxs.append((idx//acc)%d)
+        acc *= d
+      idx, valid = _expr_view(view, idxs[::-1], valid)
+    assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
+    assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
    return idx, valid
 
-  def
-
-
-      if new_view:
-        if DEBUG >= 4: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
-        self.views = self.views[:-2] + [new_view]
-        self.simplify()
-
-  def expr_idxs(self, idxs=None):
-    if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
-    idx = self.views[-1].expr_idxs(tuple(idxs))
-    valid = self.views[-1].expr_node_mask(idxs_to_idx(self.views[-1].shape, tuple(idxs)))
-    return self._expr_idx(idx, valid)
-
-  def expr_node(self, idx='idx'):
-    if idx.__class__ is str: idx = Variable(idx, 0, prod(self.shape)-1)
-    return self._expr_idx(self.views[-1].expr_node(idx), self.views[-1].expr_node_mask(idx))
-
-  def needs_valid(self) -> bool:
-    return any(v.mask is not None for v in self.views)
-
-  # *** under this line are the movement ops ***
-
-  def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...], mask=None):
-    offset = get_unsafe_resize_offset(self.views[-1].strides, arg)
-    if self.views[-1].mask:
-      # move the old mask
-      nmask = tuple([(max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.views[-1].mask, arg)])
-      # merge the masks if we have two
-      mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
-    self.views[-1] = View(tuple([y-x for x,y in arg]), self.views[-1].strides, self.views[-1].offset+offset, mask)
+  def axis_is_masked(self, axis:int) -> bool:
+    _, valid = self.expr_idxs()
+    return f'idx{axis}' in [v.expr for v in valid.vars()]
 
-  def
-
-
-    zvarg, mask = get_pad_args(self.shape, arg)
-    self.__unsafe_resize(zvarg, mask=mask)
+  def simplify(self) -> ShapeTracker:
+    if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
+      return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
     return self
 
-
-    assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape)
-    self.__unsafe_resize(arg)
-    return self
-
-  def expand(self, new_shape: Tuple[Union[Node,int], ...]) -> ShapeTracker:
-    assert len(new_shape) == len(self.views[-1].shape)
-    assert all(is_sym_int(x) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.views[-1].strides)), f"can't expand {self.shape} into {new_shape}"
-    # NOTE: can the mask ever be (0,0)?
-    mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.views[-1].mask, self.shape, new_shape)]) if self.views[-1].mask else None
-    self.views[-1] = View(new_shape, self.views[-1].strides, self.views[-1].offset, mask)
-    return self
-
-  def reshape(self, new_shape: Tuple[Union[Node,int], ...]):
-    if self.views[-1].shape == new_shape: return self
-    new_ints, new_nodes = partition(new_shape, lambda s: isinstance(s, int))
-    if new_nodes and all(isinstance(s, int) for s in self.shape):
-      # reshape from all int shape into shape with a variable, update the variable value
-      assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape"
-      new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints)
-      if new_var not in self.var_vals:
-        assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]"
-        self.var_vals[new_var] = new_val
-      else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}"
-    elif not new_nodes: self.var_vals = {}
-    assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
-    # only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done
-    if all(isinstance(s, int) for s in self.shape) and all(isinstance(s, int) for s in new_shape):
-      assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}" # type: ignore # mypy cannot resolve, all ints here
-    new_view, extra = _reshape(self.views[-1], new_shape)
-    if extra: self.views.append(new_view)
-    else: self.views[-1] = new_view
-    return self
-
-  def permute(self, axis: Tuple[int, ...]):
-    assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
-    assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
-    self.views[-1] = View(tuple([self.views[-1].shape[a] for a in axis]), tuple([self.views[-1].strides[a] for a in axis]), self.views[-1].offset, tuple([self.views[-1].mask[a] for a in axis]) if self.views[-1].mask is not None else None)
-    return self
+  # *** under this line are the movement ops ***
 
-
-  def
-
-
-
-    offset = sum([(s-1)*z for s,z,m in zip(self.views[-1].shape, self.views[-1].strides, mul) if m < 0])
-    mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.views[-1].mask, self.views[-1].shape, mul)]) if self.views[-1].mask is not None else None
-    self.views[-1] = View(new_shape, strides, self.views[-1].offset + offset, mask)
-    return self
+  def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg), ))
+  def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
+  def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
+  def permute(self, axis: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
+  def stride(self, mul: Tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))
 
-
-
-
-  axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
-  # Index for new_shape and axis_groups.
-  i: int = 0
-  old_shape_i: int = 0
-  while old_shape_i < len(old_shape):
-    # 1s exist in new_shape only will lead to empty axes group creations.
-    if new_shape[i] == 1 and old_shape[old_shape_i] != 1:
-      if i < len(new_shape) - 1: i += 1
-    else:
-      axis_groups[i].append(old_shape_i)
-      axis_group_size = prod([old_shape[x] for x in axis_groups[i]])
-      # Move to next axes group if total size of all dimensions match.
-      if axis_group_size == new_shape[i]:
-        if i < len(new_shape) - 1: i += 1
-      elif axis_group_size > new_shape[i]: return None
-    old_shape_i += 1
-  return axis_groups
+  def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
+    if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
+    return ShapeTracker(self.views + (View.create(new_shape), ))