tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_python.py
CHANGED
@@ -4,21 +4,21 @@
|
|
4
4
|
# this is the (living) definition of uops
|
5
5
|
from typing import Tuple, List, Optional, Any, Dict
|
6
6
|
import pickle, base64, itertools, time, struct
|
7
|
-
from tinygrad.dtype import DType, dtypes, ImageDType
|
7
|
+
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
|
8
8
|
from tinygrad.helpers import all_same, getenv, flatten
|
9
9
|
from tinygrad.device import Compiled, Compiler, Allocator
|
10
|
-
from tinygrad.
|
11
|
-
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
|
10
|
+
from tinygrad.ops import exec_alu, Ops, UOp, GroupOp
|
12
11
|
from tinygrad.renderer import Renderer
|
13
|
-
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer
|
12
|
+
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer
|
14
13
|
|
15
14
|
def _load(m, i):
|
15
|
+
if i is None: return 0.0
|
16
16
|
if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
|
17
17
|
return m[i]
|
18
18
|
|
19
19
|
def load(inp, j=0):
|
20
|
-
if len(inp) ==
|
21
|
-
return [_load(m, x+j) for m,x in
|
20
|
+
if len(inp) == 3: return [_load(m, x+j if x is not None else None) if gate else default for (m,x),default,gate in zip(*inp)]
|
21
|
+
return [_load(m, x+j if x is not None else None) for m,x in inp[0]]
|
22
22
|
|
23
23
|
def _store(m, i, v):
|
24
24
|
if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
|
@@ -26,7 +26,7 @@ def _store(m, i, v):
|
|
26
26
|
|
27
27
|
class PythonProgram:
|
28
28
|
def __init__(self, name:str, lib:bytes):
|
29
|
-
self.uops: List[Tuple[
|
29
|
+
self.uops: List[Tuple[Ops, Optional[DType], List[int], Any]] = pickle.loads(lib)
|
30
30
|
def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
|
31
31
|
st = time.perf_counter()
|
32
32
|
warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
|
@@ -40,58 +40,59 @@ class PythonProgram:
|
|
40
40
|
loop_ends: Dict[int, int] = {}
|
41
41
|
while i < len(self.uops):
|
42
42
|
uop, dtype, idp, arg = self.uops[i]
|
43
|
-
void_ops = {
|
44
|
-
if uop is
|
43
|
+
void_ops = {Ops.STORE, Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF}
|
44
|
+
if uop is Ops.DEFINE_ACC: idp = [idp[0]]
|
45
45
|
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
|
46
46
|
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
|
47
47
|
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
|
48
|
-
if uop is
|
49
|
-
if len(inp) ==
|
50
|
-
if
|
51
|
-
|
52
|
-
|
53
|
-
for j,val in enumerate(inp[2]):
|
54
|
-
for m,ox,oy,v,g in zip(inp[0], inp[1][0], inp[1][1], val, inp[3]):
|
55
|
-
assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
|
56
|
-
if g: _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
|
57
|
-
elif dtp[2].count > 1:
|
58
|
-
for j,val in enumerate(inp[2]):
|
59
|
-
for m,o,v,g in zip(inp[0], inp[1], val, inp[3]):
|
48
|
+
if uop is Ops.STORE:
|
49
|
+
if len(inp) == 2: inp.append([True] * len(inp[0])) # set the gate to True
|
50
|
+
if dtp[1].count > 1:
|
51
|
+
for j,val in enumerate(inp[1]):
|
52
|
+
for (m,o),v,g in zip(inp[0], val, inp[2]):
|
60
53
|
if g: _store(m, o+j, v)
|
61
54
|
else:
|
62
|
-
for m,o,v,g in zip(*inp):
|
55
|
+
for (m,o),v,g in zip(*inp):
|
63
56
|
if g: _store(m, o, v)
|
64
57
|
i += 1
|
65
58
|
continue
|
66
|
-
if uop is
|
59
|
+
if uop is Ops.ENDRANGE:
|
67
60
|
loop_ends[idp[0]] = i
|
68
61
|
i = idp[0]
|
69
62
|
continue
|
70
|
-
if uop in (
|
63
|
+
if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF):
|
71
64
|
# in the python emulator, the warp is always in sync
|
72
65
|
i += 1
|
73
66
|
continue
|
74
67
|
assert dtype is not None, f"{uop} is missing a dtype"
|
75
68
|
dl[i] = dtype
|
76
|
-
if uop is
|
69
|
+
if uop is Ops.DEFINE_GLOBAL:
|
77
70
|
assert dtype.fmt is not None
|
78
71
|
ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
|
79
|
-
elif uop is
|
72
|
+
elif uop is Ops.DEFINE_LOCAL:
|
80
73
|
assert dtype.fmt is not None
|
81
74
|
lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
|
82
75
|
ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
|
83
|
-
elif uop is
|
76
|
+
elif uop is Ops.DEFINE_VAR:
|
84
77
|
ul[i] = [pvals.pop(0)] * warp_size
|
85
|
-
elif uop is
|
86
|
-
if arg[
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
78
|
+
elif uop is Ops.SPECIAL:
|
79
|
+
if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size
|
80
|
+
elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp]
|
81
|
+
elif uop is Ops.CONST: ul[i] = [arg] * warp_size
|
82
|
+
elif uop is Ops.DEFINE_ACC:
|
83
|
+
ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
|
84
|
+
elif uop is Ops.INDEX:
|
85
|
+
ret = []
|
86
|
+
if isinstance(dtp[0], ImageDType):
|
87
|
+
for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
|
88
|
+
if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append((m, None))
|
89
|
+
else: ret.append((m, ox*4 + oy*dtp[0].shape[1]*4))
|
90
|
+
else:
|
91
|
+
for m,o in zip(inp[0], inp[1]): ret.append((m,o))
|
92
|
+
ul[i] = ret
|
93
|
+
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
|
94
|
+
ul[i] = inp[0]
|
95
|
+
elif uop is Ops.RANGE:
|
95
96
|
if i not in ul: ul[i] = [inp[0][0]] * warp_size
|
96
97
|
else:
|
97
98
|
for j in range(len(ul[i])):
|
@@ -100,45 +101,29 @@ class PythonProgram:
|
|
100
101
|
del ul[i]
|
101
102
|
i = loop_ends[i] + 1
|
102
103
|
continue
|
103
|
-
elif uop
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
|
113
|
-
casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
|
114
|
-
elif dtypes.is_float(dtype):
|
115
|
-
casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
|
116
|
-
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
|
117
|
-
elif uop is UOps.LOAD:
|
118
|
-
if isinstance(dtp[0], ImageDType):
|
119
|
-
assert dtype.count == 4
|
120
|
-
ul[i] = []
|
121
|
-
for j in range(dtype.count):
|
122
|
-
ret = []
|
123
|
-
for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
|
124
|
-
if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
|
125
|
-
else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
|
126
|
-
ul[i].append(ret)
|
127
|
-
elif dtype.count > 1:
|
128
|
-
ul[i] = [load([inp[i][j] if dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
|
104
|
+
elif uop is Ops.VECTORIZE: ul[i] = inp
|
105
|
+
elif uop in {Ops.CAST, Ops.BITCAST}:
|
106
|
+
assert dtp[0].fmt and dtype.fmt
|
107
|
+
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
|
108
|
+
if uop is Ops.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
|
109
|
+
else: ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
|
110
|
+
elif uop is Ops.LOAD:
|
111
|
+
if dtype.count > 1:
|
112
|
+
ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
|
129
113
|
else:
|
130
114
|
ul[i] = load(inp)
|
131
|
-
elif uop is
|
115
|
+
elif uop is Ops.ASSIGN:
|
132
116
|
for j in range(len(inp[0])): inp[0][j] = inp[1][j]
|
133
117
|
ul[i] = inp[0]
|
134
|
-
elif uop is
|
135
|
-
|
136
|
-
|
118
|
+
elif uop is Ops.GEP:
|
119
|
+
assert len(arg) == 1
|
120
|
+
ul[i] = inp[0][arg[0]]
|
121
|
+
elif uop is Ops.WMMA:
|
137
122
|
# here are the models for the WMMA instruction on the different hardware
|
138
123
|
def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
|
139
|
-
assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread"
|
140
|
-
assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread"
|
141
|
-
assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread"
|
124
|
+
assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread, it has {len(inp[0])}"
|
125
|
+
assert len(inp[1]) == NUM_B, f"B must have {NUM_B} elements per thread, it has {len(inp[1])}"
|
126
|
+
assert len(inp[2]) == NUM_C, f"C must have {NUM_C} elements per thread, it has {len(inp[2])}"
|
142
127
|
assert len(flatten(inp[0])) == NUM_A * warp_size, f"WMMA must have {NUM_A * warp_size} total elements for A in WMMA"
|
143
128
|
assert len(flatten(inp[1])) == NUM_B * warp_size, f"WMMA must have {NUM_B * warp_size} total elements for B in WMMA"
|
144
129
|
assert len(flatten(inp[2])) == NUM_C * warp_size, f"WMMA must have {NUM_C * warp_size} total elements for C in WMMA"
|
@@ -152,13 +137,13 @@ class PythonProgram:
|
|
152
137
|
return out
|
153
138
|
|
154
139
|
# TODO: refactor these to a shared TensorCoreLayout in kernel.py
|
155
|
-
if arg[
|
140
|
+
if arg[4] == "METAL":
|
156
141
|
# A (2 elements on 32 threads): row major
|
157
142
|
def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
|
158
143
|
# (i, j), C, D (2 elements on 32 threads): row major same as A/B
|
159
144
|
def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
|
160
145
|
ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
|
161
|
-
elif arg[
|
146
|
+
elif arg[4] == "AMD":
|
162
147
|
# A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
|
163
148
|
def a_elem(x, i, j, goff):
|
164
149
|
assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
|
@@ -167,7 +152,7 @@ class PythonProgram:
|
|
167
152
|
def b_elem(x, i, j, goff): return a_elem(x, j, i, goff) # pylint: disable=arguments-out-of-order
|
168
153
|
def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
|
169
154
|
ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
|
170
|
-
elif arg[
|
155
|
+
elif arg[4] == "CUDA":
|
171
156
|
# A (8 elements on 32 threads)
|
172
157
|
def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
|
173
158
|
# B (4 elements on 32 threads)
|
@@ -175,11 +160,23 @@ class PythonProgram:
|
|
175
160
|
# (i, j), C, D (4 elements on 32 threads)
|
176
161
|
def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)
|
177
162
|
ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
|
163
|
+
elif arg[4] == "INTEL":
|
164
|
+
# A (16 elements on 8 threads)
|
165
|
+
def a_elem(x, i, j, goff): return x[i%2+j*2][goff+i//2]
|
166
|
+
# B (16 elements on 8 threads)
|
167
|
+
def b_elem(x, i, j, goff): return x[j][goff+i]
|
168
|
+
# C, D (8 elements on 8 threads)
|
169
|
+
def c_map(lane, elem): return (lane, elem)
|
170
|
+
ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
|
171
|
+
elif arg[4] == "CLANG":
|
172
|
+
def elem(x, i, j, _): return x[i+j][0]
|
173
|
+
def c_map(_, elem): return (elem%16, elem//16)
|
174
|
+
ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
|
178
175
|
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
|
179
|
-
elif uop
|
180
|
-
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {
|
181
|
-
assert all_same([dtype] + dtp) or
|
182
|
-
ul[i] = [exec_alu(
|
176
|
+
elif uop in GroupOp.ALU:
|
177
|
+
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}"
|
178
|
+
assert all_same([dtype] + dtp) or uop in {Ops.CMPNE, Ops.CMPLT, Ops.WHERE}, f"dtype mismatch on {uop}"
|
179
|
+
ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)]
|
183
180
|
assert i in ul, (uop, dtype, idp, arg)
|
184
181
|
i += 1
|
185
182
|
return time.perf_counter() - st
|
@@ -190,9 +187,11 @@ class PythonRenderer(Renderer):
|
|
190
187
|
if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
|
191
188
|
if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
|
192
189
|
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
|
190
|
+
if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", IntelRenderer.tensor_cores
|
191
|
+
if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CLANG", ClangRenderer.tensor_cores
|
193
192
|
|
194
|
-
def render(self, name:str, uops:
|
195
|
-
lops = [(u.op, u.dtype, [uops.
|
193
|
+
def render(self, name:str, uops:List[UOp]) -> str:
|
194
|
+
lops = [(u.op, u.dtype, [uops.index(v) for v in u.src], u.arg) for u in uops]
|
196
195
|
return base64.b64encode(pickle.dumps(lops)).decode()
|
197
196
|
|
198
197
|
class PythonCompiler(Compiler):
|