tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +253 -225
- tinygrad/codegen/linearizer.py +398 -436
- tinygrad/codegen/uops.py +451 -0
- tinygrad/device.py +268 -274
- tinygrad/dtype.py +56 -40
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +198 -0
- tinygrad/engine/realize.py +192 -0
- tinygrad/engine/schedule.py +370 -0
- tinygrad/engine/search.py +199 -0
- tinygrad/{mlops.py → function.py} +40 -32
- tinygrad/helpers.py +144 -46
- tinygrad/lazy.py +143 -242
- tinygrad/multi.py +173 -0
- tinygrad/nn/__init__.py +180 -9
- tinygrad/nn/datasets.py +8 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +87 -19
- tinygrad/ops.py +104 -45
- tinygrad/renderer/__init__.py +65 -0
- tinygrad/renderer/assembly.py +269 -0
- tinygrad/renderer/cstyle.py +308 -210
- tinygrad/renderer/llvmir.py +119 -124
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +13403 -0
- tinygrad/runtime/autogen/comgr.py +891 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5893 -0
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33597 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +56 -0
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +39 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +187 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +550 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +129 -37
- tinygrad/runtime/ops_disk.py +111 -43
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +41 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +625 -0
- tinygrad/runtime/ops_python.py +208 -0
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +46 -107
- tinygrad/shape/symbolic.py +99 -98
- tinygrad/shape/view.py +162 -45
- tinygrad/tensor.py +2492 -483
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,208 @@
|
|
1
|
+
# pylint: disable=cell-var-from-loop
|
2
|
+
# a python uops emulator
|
3
|
+
# works to test the tensor cores, and all the uops in general
|
4
|
+
# this is the (living) definition of uops
|
5
|
+
from typing import Tuple, List, Optional, Any, Dict
|
6
|
+
import pickle, base64, itertools, time, struct
|
7
|
+
from tinygrad.dtype import DType, dtypes, ImageDType
|
8
|
+
from tinygrad.helpers import all_same, getenv, flatten
|
9
|
+
from tinygrad.device import Compiled, Compiler, Allocator
|
10
|
+
from tinygrad.codegen.uops import UOpGraph, UOps
|
11
|
+
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
|
12
|
+
from tinygrad.renderer import Renderer
|
13
|
+
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer
|
14
|
+
|
15
|
+
def _load(m, i):
  """Read element `i` of buffer `m` with an explicit bounds check.

  Negative indices are rejected rather than using Python's wrap-around indexing.
  Raises IndexError for any index outside [0, len(m)).
  """
  size = len(m)
  if not 0 <= i < size:
    raise IndexError(f"load out of bounds, size is {size} and access is {i}")
  return m[i]
|
18
|
+
|
19
|
+
def load(inp, j=0):
  """Gather one value per entry from parallel per-lane lists.

  inp is either (buffers, offsets) or (buffers, offsets, gates, defaults).
  Entry k reads buffers[k][offsets[k] + j] via the bounds-checked _load.
  In the 4-list form, an entry whose gate is falsy yields defaults[k] instead, and no read happens.
  Any other length only uses the first two lists.
  """
  if len(inp) != 4:
    return [_load(buf, off + j) for buf, off in zip(inp[0], inp[1])]
  bufs, offs, gates, defaults = inp
  out = []
  for buf, off, gate, fallback in zip(bufs, offs, gates, defaults):
    out.append(_load(buf, off + j) if gate else fallback)
  return out
|
22
|
+
|
23
|
+
def _store(m, i, v):
  """Write `v` into buffer `m` at index `i` with an explicit bounds check; returns None.

  Negative indices are rejected rather than using Python's wrap-around indexing.
  Raises IndexError (and leaves `m` untouched) for any index outside [0, len(m)).
  """
  size = len(m)
  if 0 <= i < size:
    m[i] = v
  else:
    raise IndexError(f"store out of bounds, size is {size}, access is {i}, value is {v}")
|
26
|
+
|
27
|
+
class PythonProgram:
  """Host emulator that interprets a pickled uop list, running one warp per global index.

  Values are kept per lane. Each entry of the value table is a list with one item per lane of the warp.
  A vector dtype (count > 1) is a list of such lists, one per element.
  """
  def __init__(self, name:str, lib:bytes):
    # lib is presumably the bytes returned by PythonCompiler.compile: a pickled list of
    # (op, dtype, src indices, arg) tuples as built by PythonRenderer.render.
    # NOTE: pickle.loads can execute arbitrary code for crafted input; only load bytes this backend produced.
    self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib)

  def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
    """Execute the program over the whole launch grid and return the elapsed wall-clock seconds.

    bufs: buffers consumed in order by DEFINE_GLOBAL uops (each is memoryview-cast to that uop's dtype format).
    global_size / local_size: launch extents; every global index runs one warp of prod(local_size) lanes.
    vals: values consumed in order by DEFINE_VAR uops.
    wait: accepted but not read by this body.
    """
    st = time.perf_counter()
    # loop invariants, hoisted: these were previously rebuilt/re-queried for every uop of every global index
    void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}  # uops that produce no value
    trace = getenv("TRACE")
    # every lane of the warp; local_size is reversed, so lane[2-d] is the index along dimension d
    warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
    warp_size = len(warp)
    for idxs in itertools.product(*[range(x) for x in global_size[::-1]]):
      ul: Dict[int, Any] = {}          # uop index -> per-lane values
      dl: Dict[int, DType] = {}        # uop index -> dtype
      pbufs: List[memoryview] = list(bufs)
      pvals: List[int] = list(vals)
      i = 0                            # program counter into self.uops
      loop_ends: Dict[int, int] = {}   # RANGE index -> its ENDRANGE index, recorded when ENDRANGE is first reached
      while i < len(self.uops):
        uop, dtype, idp, arg = self.uops[i]
        # DEFINE_ACC only reads its first source (the value it is initialized with)
        if uop is UOps.DEFINE_ACC: idp = [idp[0]]
        inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
        dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
        if trace: print(i, uop, dtype, arg, inp, dtp)
        if uop is UOps.STORE:
          if len(inp) == 3: inp.append([True] * len(inp[0]))  # set the gate to True
          if isinstance(dtp[0], ImageDType):
            # image store: 4 channels per pixel, pixels laid out row by row
            assert dtp[2].count == 4
            for j,val in enumerate(inp[2]):
              for m,ox,oy,v,g in zip(inp[0], inp[1][0], inp[1][1], val, inp[3]):
                assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
                if g: _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
          elif dtp[2].count > 1:
            # vector store: element j goes to offset+j
            for j,val in enumerate(inp[2]):
              for m,o,v,g in zip(inp[0], inp[1], val, inp[3]):
                if g: _store(m, o+j, v)
          else:
            for m,o,v,g in zip(*inp):
              if g: _store(m, o, v)
          i += 1
          continue
        if uop is UOps.ENDRANGE:
          # jump back to the RANGE, which increments the counter and decides whether to exit
          loop_ends[idp[0]] = i
          i = idp[0]
          continue
        if uop in (UOps.BARRIER, UOps.IF, UOps.ENDIF):
          # in the python emulator, the warp is always in sync
          i += 1
          continue
        assert dtype is not None, f"{uop} is missing a dtype"
        dl[i] = dtype
        if uop is UOps.DEFINE_GLOBAL:
          assert dtype.fmt is not None
          ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
        elif uop is UOps.DEFINE_LOCAL:
          assert dtype.fmt is not None
          lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
          ul[i] = [lbuf.cast(dtype.fmt)] * warp_size  # one buffer shared by every lane
        elif uop is UOps.DEFINE_VAR:
          ul[i] = [pvals.pop(0)] * warp_size
        elif uop is UOps.SPECIAL:
          if arg[1][0] == 'g':
            ul[i] = [idxs[2-arg[0]]] * warp_size
          elif arg[1][0] == 'l':
            ul[i] = [x[2-arg[0]] for x in warp]
        elif uop is UOps.CONST:
          ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
        elif uop is UOps.DEFINE_ACC:
          ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
        elif uop is UOps.RANGE:
          # NOTE(review): the body is entered before the bound is compared, and the exit test is `==`,
          # so a range with start >= end would not terminate; presumably never emitted -- confirm.
          if i not in ul: ul[i] = [inp[0][0]] * warp_size
          else:
            for j in range(len(ul[i])):
              ul[i][j] += 1
            if ul[i][0] == inp[1][0]:
              del ul[i]
              i = loop_ends[i] + 1
              continue
        elif uop in (UOps.CAST, UOps.BITCAST):
          if dtype.count > 1: ul[i] = inp
          else:
            assert dtp[0].fmt and dtype.fmt
            pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
            if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
            else:
              casted = [dtypes.as_const(x, dtype) for x in inp[0]]
              if dtypes.is_int(dtype):
                # wrap into the target integer range (two's complement for signed types)
                overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
                casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
              elif dtypes.is_float(dtype):
                casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
              ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
        elif uop is UOps.LOAD:
          if isinstance(dtp[0], ImageDType):
            # image load: out-of-range pixel coordinates read as 0
            assert dtype.count == 4
            ul[i] = []
            for j in range(dtype.count):
              ret = []
              for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
                if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
                else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
              ul[i].append(ret)
          elif dtype.count > 1:
            # vector load: element j reads offset+j; vector-typed sources (e.g. defaults) contribute their element j
            ul[i] = [load([x[j] if dt.count > 1 else x for x, dt in zip(inp, dtp)], j) for j in range(dtype.count)]
          else:
            ul[i] = load(inp)
        elif uop is UOps.PHI:
          # write the new value into the accumulator's lane list in place, so the accumulator sees it
          for j in range(len(inp[0])): inp[0][j] = inp[1][j]
          ul[i] = inp[0]
        elif uop is UOps.GEP:
          ul[i] = inp[0][arg]
        elif uop is UOps.WMMA:
          # here are the models for the WMMA instruction on the different hardware
          def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
            assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread"
            assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread"
            assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread"
            assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
            assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
            assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
            assert warp_size > 0 and warp_size % WARP_THREADS == 0, f"must have multiples of {WARP_THREADS} warp threads"
            out = [inp[2][elem_idx][:] for elem_idx in range(NUM_C)]
            for goff in range(0, warp_size, WARP_THREADS):
              for lane_id in range(WARP_THREADS):
                for elem_idx in range(NUM_C): # calculate new muls and add to acc
                  (c_i, c_j) = c_map(lane_id, elem_idx)
                  out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
            return out

          # TODO: refactor these to a shared TensorCoreLayout in kernel.py
          if arg[5] == "METAL":
            # A (2 elements on 32 threads): row major
            def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
            # (i, j), C, D (2 elements on 32 threads): row major same as A/B
            def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
            ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
          elif arg[5] == "AMD":
            # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
            def a_elem(x, i, j, goff):
              assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
              return x[i][goff+j]
            # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
            def b_elem(x, i, j, goff): return a_elem(x, j, i, goff) # pylint: disable=arguments-out-of-order
            def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
            ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
          elif arg[5] == "CUDA":
            # A (8 elements on 32 threads)
            def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
            # B (4 elements on 32 threads)
            def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4]
            # (i, j), C, D (4 elements on 32 threads)
            def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)
            ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
          else: raise NotImplementedError(f"unimplemented tensor core {arg}")
        elif uop is UOps.ALU:
          assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
          assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
          ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]
        assert i in ul, (uop, dtype, idp, arg)
        i += 1
    return time.perf_counter() - st
|
186
|
+
|
187
|
+
class PythonRenderer(Renderer):
  """Renderer for the PYTHON device: serializes the linearized uops instead of emitting source code.

  The environment flags EMULATE_METAL, EMULATE_AMD and EMULATE_CUDA make the renderer advertise that
  backend's device name and tensor cores. If several flags are set, the last check wins (CUDA, then AMD, then METAL).
  """
  device = "PYTHON"
  def __init__(self):
    if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
    if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
    if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores

  def render(self, name:str, uops:UOpGraph) -> str:
    """Return base64 text of the pickled list of (op, dtype, src positions, arg) tuples.

    Each source is encoded as its position in the linearized uop list. `name` is not used here.
    """
    linear = uops.uops
    # first position of each uop, built once; this replaces a list.index scan per source,
    # which was quadratic in the number of uops. setdefault keeps the first match, like list.index.
    position: Dict[Any, int] = {}
    for idx, u in enumerate(linear): position.setdefault(u, idx)
    lops = [(u.op, u.dtype, [position[v] for v in u.src], u.arg) for u in uops]
    return base64.b64encode(pickle.dumps(lops)).decode()
|
197
|
+
|
198
|
+
class PythonCompiler(Compiler):
  """'Compiler' for the PYTHON device.

  There is no real compilation step: the rendered program is base64 text, and compiling just decodes it back to bytes.
  """
  def compile(self, src:str) -> bytes:
    decoded: bytes = base64.b64decode(src)
    return decoded
|
200
|
+
|
201
|
+
class PythonAllocator(Allocator):
  """Host-memory allocator: every buffer is a memoryview over a freshly zero-filled bytearray."""
  def _alloc(self, size, options):
    backing = bytearray(size)  # zero-initialized; `options` is not consulted
    return memoryview(backing)

  def copyin(self, dest, src:memoryview):
    # slice assignment copies src into dest in place
    dest[:] = src

  def copyout(self, dest:memoryview, src):
    # slice assignment copies src into dest in place
    dest[:] = src
|
205
|
+
|
206
|
+
class PythonDevice(Compiled):
  """The PYTHON device: host-memory buffers, the uop-pickling renderer/compiler, and PythonProgram as the runtime."""
  def __init__(self, device:str):
    allocator = PythonAllocator()
    renderer = PythonRenderer()
    compiler = PythonCompiler()
    super().__init__(device, allocator, renderer, compiler, PythonProgram)
|
File without changes
|
tinygrad/shape/shapetracker.py
CHANGED
@@ -1,70 +1,35 @@
|
|
1
1
|
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
2
2
|
from __future__ import annotations
|
3
|
-
import functools, itertools, operator
|
4
3
|
from dataclasses import dataclass
|
5
|
-
from typing import Tuple, List, Optional, Dict, Set,
|
6
|
-
from tinygrad.
|
7
|
-
from tinygrad.
|
8
|
-
from tinygrad.shape.
|
9
|
-
|
10
|
-
|
11
|
-
def expr_node_mask(view:View, idx:Node, valid:Optional[Node]=None) -> Node:
|
12
|
-
expr = [valid] if valid is not None else []
|
13
|
-
if view.mask is not None:
|
14
|
-
acc = 1
|
15
|
-
for d,(x,y) in zip(reversed(view.shape), reversed(view.mask)):
|
16
|
-
if (x,y) != (0,d):
|
17
|
-
base = ((idx//acc)%d)
|
18
|
-
expr += [base >= x, base < y]
|
19
|
-
acc *= d
|
20
|
-
return Node.ands(expr)
|
21
|
-
|
22
|
-
# generate an expression if you have a single idx variable
|
23
|
-
def expr_node(view:View, idx:Optional[Node]=None) -> Node:
|
24
|
-
if idx is None: idx = Variable('idx', 0, prod(view.shape)-1)
|
25
|
-
ret: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset] if view.offset else []
|
26
|
-
acc = 1
|
27
|
-
for d,s,_ in reversed(_merge_dims(view.shape, view.strides)):
|
28
|
-
ret.append(((idx//acc)%d)*s)
|
29
|
-
acc *= d
|
30
|
-
return Node.sum(ret)
|
31
|
-
|
32
|
-
# generate an expression if you have a variable or expression for each index
|
33
|
-
def expr_idxs(view:View, idxs:Tuple[Node, ...]) -> Node:
|
4
|
+
from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
|
5
|
+
from tinygrad.helpers import merge_dicts, getenv
|
6
|
+
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, create_lt_node, create_ge_node, sint
|
7
|
+
from tinygrad.shape.view import View, strides_for_shape
|
8
|
+
|
9
|
+
def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
|
34
10
|
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
if vm2.mask or vm1.offset != 0: return None # this isn't supported yet
|
42
|
-
if None in (strides := ShapeTracker((vm2, vm1)).real_strides()): return None
|
43
|
-
return View.create(vm1.shape, cast(Tuple[sint, ...], strides), vm2.offset, vm1.mask)
|
44
|
-
|
45
|
-
@functools.lru_cache(maxsize=None)
|
46
|
-
def idxs_to_idx(shape:Tuple[int, ...], idxs:Tuple[Node, ...]) -> Node:
|
47
|
-
assert len(idxs) == len(shape), "need an idx for all dimensions"
|
48
|
-
acc, ret = 1, []
|
49
|
-
for tidx,d in zip(reversed(idxs), reversed(shape)):
|
50
|
-
ret.append(tidx * acc)
|
51
|
-
acc *= d
|
52
|
-
return Node.sum(ret)
|
11
|
+
iexpr: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset]
|
12
|
+
vexpr: List[Node] = [valid] if valid is not None else []
|
13
|
+
for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
|
14
|
+
if sh != 1 and st != 0: iexpr.append(idx*st)
|
15
|
+
if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
|
16
|
+
return Node.sum(iexpr), Node.ands(vexpr)
|
53
17
|
|
54
18
|
@dataclass(frozen=True)
|
55
19
|
class ShapeTracker:
|
56
20
|
views: Tuple[View, ...]
|
57
|
-
def __post_init__(self):
|
58
|
-
assert isinstance(self.views, tuple) and all(isinstance(v, View) for v in self.views), "ShapeTracker must be created with a tuple of Views"
|
59
21
|
|
60
22
|
def __add__(self, st:ShapeTracker) -> ShapeTracker:
|
61
|
-
|
62
|
-
for v in st.views:
|
63
|
-
return
|
23
|
+
ret = self
|
24
|
+
for v in st.views: ret = ShapeTracker(ret.views + (v,)).simplify() # one view at a time = better simplification
|
25
|
+
return ret
|
64
26
|
|
65
27
|
def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
|
66
|
-
|
67
|
-
|
28
|
+
inverted_views:List[View] = []
|
29
|
+
for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
|
30
|
+
if (inverted:= v.invert(s)) is None: return None
|
31
|
+
inverted_views.append(inverted)
|
32
|
+
return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
|
68
33
|
|
69
34
|
@staticmethod
|
70
35
|
def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
|
@@ -72,16 +37,22 @@ class ShapeTracker:
|
|
72
37
|
@property
|
73
38
|
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
|
74
39
|
|
40
|
+
@property
|
41
|
+
def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)
|
42
|
+
|
75
43
|
@property
|
76
44
|
def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
|
77
45
|
|
78
46
|
@property
|
79
|
-
def size(self) -> int: return
|
47
|
+
def size(self) -> int: return self.views[-1].size()
|
80
48
|
|
81
49
|
def real_size(self) -> int:
|
82
50
|
if 0 in self.shape: return 0
|
83
|
-
|
84
|
-
|
51
|
+
idx, valid = self.expr_idxs()
|
52
|
+
if not valid: return 0
|
53
|
+
# TODO: it's possible that the real_size is smaller condition on valid being true
|
54
|
+
ret = idx.max
|
55
|
+
if not isinstance(ret, int): ret = ret.max # might be represent by symbolic shape, take one more max for int max
|
85
56
|
assert isinstance(ret, int), f"ret must be integer, {ret=} isn't"
|
86
57
|
return ret+1
|
87
58
|
|
@@ -90,30 +61,9 @@ class ShapeTracker:
|
|
90
61
|
@property
|
91
62
|
def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
|
92
63
|
|
93
|
-
def unbind(self) -> ShapeTracker
|
94
|
-
|
95
|
-
|
96
|
-
to_apply:List[Tuple[MovementOps, Tuple]] = []
|
97
|
-
for v in self.views:
|
98
|
-
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
|
99
|
-
real_offset = 0 if 0 in real_shape else (v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0))
|
100
|
-
# first, we apply the offset
|
101
|
-
# then, we make it the correct shape
|
102
|
-
# then, we apply permutations
|
103
|
-
to_apply.append((MovementOps.AS_STRIDED, (tuple([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)]), v.strides, real_offset)))
|
104
|
-
# then, we apply pre expand pads
|
105
|
-
if v.mask is not None:
|
106
|
-
pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
|
107
|
-
post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
|
108
|
-
if any(x != (0,0) for x in pre_expand_pads):
|
109
|
-
to_apply.append((MovementOps.PAD, pre_expand_pads))
|
110
|
-
real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
|
111
|
-
# then, we do any expands
|
112
|
-
# NOTE: this is a good idea even without masks, since torch doesn't support negative strides and has to make a copy
|
113
|
-
if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
|
114
|
-
# lastly, we apply post expand pads
|
115
|
-
if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))
|
116
|
-
return to_apply
|
64
|
+
def unbind(self) -> Tuple[ShapeTracker, Dict[Variable, int]]:
|
65
|
+
unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
|
66
|
+
return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
|
117
67
|
|
118
68
|
# NOTE: if a stride is not always valid, it will be None
|
119
69
|
def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
|
@@ -124,7 +74,7 @@ class ShapeTracker:
|
|
124
74
|
bad_idx_vars: Set[Variable] = set()
|
125
75
|
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
|
126
76
|
idx_maybe, stride_maybe = (this_dim.a, this_dim.b) if isinstance(this_dim, MulNode) else (this_dim, 1)
|
127
|
-
try: ret[idxs.index(idx_maybe)] = stride_maybe
|
77
|
+
try: ret[idxs.index(idx_maybe)] = cast(sint, stride_maybe)
|
128
78
|
except ValueError: bad_idx_vars = bad_idx_vars.union(idx_maybe.vars())
|
129
79
|
idx_vars, valid_vars = idx.vars(), valid.vars()
|
130
80
|
for i,tidx in enumerate(idxs):
|
@@ -134,30 +84,27 @@ class ShapeTracker:
|
|
134
84
|
|
135
85
|
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
|
136
86
|
|
137
|
-
def
|
138
|
-
for
|
87
|
+
def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
|
88
|
+
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
|
89
|
+
idx, valid = _expr_view(self.views[-1], idxs)
|
90
|
+
for view in reversed(self.views[0:-1]):
|
139
91
|
if valid.max == 0: return NumNode(-1), valid
|
140
|
-
|
141
|
-
|
92
|
+
view = view.minify()
|
93
|
+
acc, idxs = 1, []
|
94
|
+
for d in reversed(view.shape):
|
95
|
+
idxs.append((idx//acc)%d)
|
96
|
+
acc *= d
|
97
|
+
idx, valid = _expr_view(view, idxs[::-1], valid)
|
98
|
+
assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
|
99
|
+
assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
|
142
100
|
return idx, valid
|
143
101
|
|
144
|
-
def expr_idxs(self, idxs:Optional[Iterable[Node]]=None):
|
145
|
-
if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
|
146
|
-
idx = expr_idxs(self.views[-1], tuple(idxs))
|
147
|
-
valid = expr_node_mask(self.views[-1], idxs_to_idx(self.views[-1].shape, tuple(idxs)))
|
148
|
-
return self._expr_idx(idx, valid)
|
149
|
-
|
150
|
-
def expr_node(self, idx:Union[Node,str]='idx'):
|
151
|
-
if isinstance(idx, str): idx = Variable(idx, 0, prod(self.shape)-1)
|
152
|
-
return self._expr_idx(expr_node(self.views[-1], idx), expr_node_mask(self.views[-1], idx))
|
153
|
-
|
154
102
|
def axis_is_masked(self, axis:int) -> bool:
|
155
103
|
_, valid = self.expr_idxs()
|
156
104
|
return f'idx{axis}' in [v.expr for v in valid.vars()]
|
157
105
|
|
158
106
|
def simplify(self) -> ShapeTracker:
|
159
|
-
if len(self.views) >= 2 and (new_view :=
|
160
|
-
if DEBUG >= 5: print(f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}")
|
107
|
+
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
|
161
108
|
return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
|
162
109
|
return self
|
163
110
|
|
@@ -172,11 +119,3 @@ class ShapeTracker:
|
|
172
119
|
def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
|
173
120
|
if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
|
174
121
|
return ShapeTracker(self.views + (View.create(new_shape), ))
|
175
|
-
|
176
|
-
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
177
|
-
# TODO: if we remove movementops from lazy.py we can delete this
|
178
|
-
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
|
179
|
-
acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
|
180
|
-
try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
|
181
|
-
except ValueError: return None
|
182
|
-
return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
|