tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
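
The file list already tells the structural story of this release: the UOp IR moves out of tinygrad/ops.py into a dedicated tinygrad/uop package, kernel optimization moves from tinygrad/codegen/kernel.py into tinygrad/codegen/opt, scheduling splits into a tinygrad/schedule package, and the CLOUD backend is superseded by REMOTE. A sketch of how common import paths move in 0.11.0, inferred only from the renames and add/remove pairs above (the exact re-exported symbols are an assumption):

# import-path moves inferred from the file list above; symbol names are illustrative
from tinygrad.uop.ops import UOp, Ops, GroupOp       # was tinygrad.ops (tinygrad/ops.py removed)
from tinygrad.uop.symbolic import symbolic           # was tinygrad.codegen.symbolic
from tinygrad.codegen.opt.kernel import Kernel       # was tinygrad.codegen.kernel (removed)
from tinygrad.codegen.opt.search import beam_search  # was tinygrad.engine.search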
tinygrad/runtime/ops_python.py
CHANGED
@@ -2,14 +2,14 @@
 # a python uops emulator
 # works to test the tensor cores, and all the uops in general
 # this is the (living) definition of uops
-from typing import
+from typing import Any, TYPE_CHECKING
 import pickle, base64, itertools, time, struct, sys
 from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
 from tinygrad.helpers import all_same, getenv, flatten, get_single_element
 from tinygrad.device import Compiled, Compiler, Allocator
-from tinygrad.
+from tinygrad.codegen.opt import tc
+from tinygrad.uop.ops import exec_alu, Ops, UOp, GroupOp
 from tinygrad.renderer import Renderer
-from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer

 def _load(m, i):
   if i is None: return 0.0
@@ -17,8 +17,8 @@ def _load(m, i):
   return m[i]

 def load(inp, j=0):
-  if len(inp) ==
-  return [_load(m, x+j if x is not None else None) for m,x in inp[0]]
+  if len(inp) == 2: return [_load(m, x+j if x is not None else None) if gate else default for (m,x,gate),default in zip(*inp)]
+  return [_load(m, x+j if x is not None else None) for m,x,_ in inp[0]]

 def _store(m, i, v):
   if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
@@ -26,7 +26,7 @@ def _store(m, i, v):

 class PythonProgram:
   def __init__(self, name:str, lib:bytes):
-    self.uops: list[tuple[Ops,
+    self.uops: list[tuple[Ops, DType|None, list[int], Any]] = pickle.loads(lib)
   def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
     st = time.perf_counter()
     warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
@@ -40,79 +40,74 @@ class PythonProgram:
       loop_ends: dict[int, int] = {}
       while i < len(self.uops):
         uop, dtype, idp, arg = self.uops[i]
-        void_ops = {Ops.
-        if uop is Ops.DEFINE_ACC: idp = [idp[0]]
+        void_ops = {Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP, Ops.STORE}
         inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
         dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
         if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
-        if uop is Ops.STORE:
-          if len(inp) == 2: inp.append([True] * len(inp[0])) # set the gate to True
-          if dtp[1].count > 1:
-            for j,val in enumerate(inp[1]):
-              for (m,o),v,g in zip(inp[0], val, inp[2]):
-                if g: _store(m, o+j, v)
-          else:
-            for (m,o),v,g in zip(*inp):
-              if g: _store(m, o, v)
-          i += 1
-          continue
         if uop is Ops.ENDRANGE:
           loop_ends[idp[0]] = i
           i = idp[0]
           continue
-        if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.
+        if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.SINK, Ops.NOOP):
           # in the python emulator, the warp is always in sync
           i += 1
           continue
         assert dtype is not None, f"{uop} is missing a dtype"
         dl[i] = dtype
-        if uop
-
+        if uop is Ops.STORE:
+          for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]):
+            for (m,o,g),v in zip(inp[0], val):
+              if g: _store(m, o+j, v)
+          i += 1
+          continue
+        if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
+          assert isinstance(dtype, PtrDType), dtype
+          if dtype.fmt is None: raise RuntimeError(f"{dtype=} is not supported")
           if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e"
-
-
+          if uop is Ops.DEFINE_REG:
+            # REGs are per thread
+            ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(dtype.fmt) for _ in range(warp_size)]
+          else:
+            buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.DEFINE_GLOBAL else pbufs.pop(0)
+            ul[i] = [buf.cast(dtype.fmt)] * warp_size
         elif uop is Ops.DEFINE_VAR:
           ul[i] = [pvals.pop(0)] * warp_size
         elif uop is Ops.SPECIAL:
           if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size
           elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp]
         elif uop is Ops.CONST: ul[i] = [arg] * warp_size
-        elif uop is Ops.DEFINE_ACC:
-          ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
         elif uop is Ops.INDEX:
-          ret = []
+          ret:list = []
           if isinstance(dtp[0], ImageDType):
             for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
               if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append((m, None))
               else: ret.append((m, ox*4 + oy*dtp[0].shape[1]*4))
           else:
             for m,o in zip(inp[0], inp[1]): ret.append((m,o))
-          ul[i] = ret
+          ul[i] = [(m,o,g) for (m,o),g in zip(ret, inp[2] if len(inp) == 3 else [True]*len(ret))] # set the gate last
         elif uop is Ops.CAST and isinstance(dtype, PtrDType):
           ul[i] = inp[0]
         elif uop is Ops.RANGE:
-          if i not in ul: ul[i] = [
+          if i not in ul: ul[i] = [0] * warp_size
           else:
             for j in range(len(ul[i])):
               ul[i][j] += 1
-            if ul[i][0] == inp[
+            if ul[i][0] == inp[0][0]:
              del ul[i]
              i = loop_ends[i] + 1
              continue
         elif uop is Ops.VECTORIZE: ul[i] = inp
-        elif uop
+        elif uop is Ops.BITCAST:
           assert dtp[0].fmt and dtype.fmt
           pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
-
-
+          ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
+        elif uop is Ops.CAST:
+          ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
         elif uop is Ops.LOAD:
           if dtype.count > 1:
             ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
           else:
             ul[i] = load(inp)
-        elif uop is Ops.ASSIGN:
-          for j in range(len(inp[0])): inp[0][j] = inp[1][j]
-          ul[i] = inp[0]
         elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)]
         elif uop is Ops.WMMA:
           # here are the models for the WMMA instruction on the different hardware
@@ -129,14 +124,27 @@ class PythonProgram:
                   out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
             return out

+          first_src_dtype = self.uops[idp[0]][1]
+          assert isinstance(first_src_dtype, DType) # mypy
+          dims, dtype_in, device, threads = arg[1], first_src_dtype.scalar(), arg[4], arg[5]
           # TODO: refactor these to a shared TensorCoreLayout in kernel.py
-          if
+          if device == "METAL":
             # A (2 elements on 32 threads): row major
             def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
             # (i, j), C, D (2 elements on 32 threads): row major same as A/B
             def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
             ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
-          elif
+          elif device == "AMD" and threads == 64:
+            def a_elem(x, k, row, goff): return x[k%4][goff + (k//4)*16 + row]
+            def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
+            def c_map(lane, elem): return (lane%16, (lane//16)*4 + elem)
+            ul[i] = wmma_helper(64, 16, 4, 4, 4, a_elem, b_elem, c_map)
+          elif device == "AMD" and len(inp[0]) == 8: # RDNA4
+            def a_elem(x, k, row, goff): return x[k - [0, 4, 4, 8][k//4]][goff + row + [0, 16, 0, 16][k//4]]
+            def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)
+            def c_map(lane, elem): return (lane%16, (lane//16)*8 + elem)
+            ul[i] = wmma_helper(32, 16, 8, 8, 8, a_elem, b_elem, c_map)
+          elif device == "AMD":
             # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
             def a_elem(x, k, row, goff):
               assert x[k][goff+row] == x[k][goff+row+16], "warp elements not duplicated properly across lanes"
@@ -145,27 +153,27 @@ class PythonProgram:
             def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
             def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
             ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
-          elif
+          elif device == "CUDA":
             # (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8
             def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8)

-            if
+            if dims == (8,16,16):
               def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4]
               def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
               ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)

-            elif
+            elif dims == (8,16,8) and dtype_in == dtypes.half:
               def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
               def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
               ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)

-            elif
+            elif dims == (8,16,8) and dtype_in == dtypes.float:
               def a_elem(x, k, row, goff): return x[(k//4)*2 + row//8][goff + k%4 + (row%8)*4]
               def b_elem(x, col, k, goff): return x[k//4][goff + k%4 + col*4]
               ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)

             else: raise NotImplementedError(f"unimplemented tensor core {arg}")
-          elif
+          elif device == "INTEL":
             # A (16 elements on 8 threads)
             def a_elem(x, k, row, goff): return x[k%2+row*2][goff+k//2]
             # B (16 elements on 8 threads)
@@ -173,7 +181,7 @@ class PythonProgram:
             # C, D (8 elements on 8 threads)
             def c_map(lane, elem): return (lane, elem)
             ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
-          elif
+          elif device == "CPU":
             def elem(x, col, row, _): return x[col+row][0] # k is always 0
             def c_map(_, elem): return (elem%16, elem//16)
             ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
@@ -189,12 +197,14 @@
 class PythonRenderer(Renderer):
   device = "PYTHON"
   def __init__(self):
-    if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL",
-    if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD",
-    if getenv("
-    if getenv("
-    if getenv("
-    if getenv("
+    if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", tc.metal
+    if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", tc.amd_rdna3
+    if getenv("EMULATE_AMD_MFMA"): self.device, self.tensor_cores = "AMD", tc.amd_cdna
+    if getenv("EMULATE_AMD_RDNA4"): self.device, self.tensor_cores = "AMD", tc.amd_rdna4
+    if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", tc.cuda_sm80
+    if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", tc.cuda_sm75
+    if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", tc.intel
+    if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CPU", tc.amx

   def render(self, uops:list[UOp]) -> str:
     lops = [(u.op, u.dtype, [uops.index(v) for v in u.src], u.arg) for u in uops]
@@ -203,10 +213,10 @@ class PythonRenderer(Renderer):
 class PythonCompiler(Compiler):
   def compile(self, src:str) -> bytes: return base64.b64decode(src)

-class PythonAllocator(Allocator):
+class PythonAllocator(Allocator['PythonDevice']):
   def _alloc(self, size, options): return memoryview(bytearray(size))
   def _copyin(self, dest, src:memoryview): dest[:] = src
   def _copyout(self, dest:memoryview, src): dest[:] = src

 class PythonDevice(Compiled):
-  def __init__(self, device:str): super().__init__(device, PythonAllocator(), PythonRenderer(), PythonCompiler(), PythonProgram)
+  def __init__(self, device:str): super().__init__(device, PythonAllocator(self), PythonRenderer(), PythonCompiler(), PythonProgram)
tinygrad/runtime/ops_qcom.py
CHANGED
@@ -1,11 +1,11 @@
 from __future__ import annotations
-import os, ctypes, functools, mmap, struct, array, math, sys
+import os, ctypes, functools, mmap, struct, array, math, sys, weakref
 assert sys.platform != 'win32'
 from types import SimpleNamespace
 from typing import Any, cast
 from tinygrad.device import BufferSpec
 from tinygrad.runtime.support.hcq import HCQBuffer, HWQueue, HCQProgram, HCQCompiled, HCQAllocatorBase, HCQSignal, HCQArgsState, BumpAllocator
-from tinygrad.runtime.support.hcq import
+from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface
 from tinygrad.runtime.autogen import kgsl, adreno
 from tinygrad.runtime.ops_gpu import CLCompiler, CLDevice
 from tinygrad.renderer.cstyle import QCOMRenderer
@@ -37,17 +37,12 @@ class QCOMCompiler(CLCompiler):
   def disassemble(self, lib:bytes): fromimport('extra.disassemblers.adreno', 'disasm')(lib)

 class QCOMSignal(HCQSignal):
-  def __init__(self,
-    super().__init__(QCOMDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=19.2)
-
-  def __del__(self):
-    if isinstance(self.base_addr, int): QCOMDevice.signals_pool.append(self.base_addr)
+  def __init__(self, *args, **kwargs): super().__init__(*args, **{**kwargs, 'timestamp_divider': 19.2})

   def _sleep(self, time_spent_waiting_ms:int):
-    # Sleep only for
-    if self.
-      kgsl.IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID(self.
-        timestamp=self.timeline_for_device.last_cmd, timeout=0xffffffff)
+    # Sleep only for timeline signals. Do it immediately to free cpu.
+    if self.is_timeline and self.owner is not None:
+      kgsl.IOCTL_KGSL_DEVICE_WAITTIMESTAMP_CTXTID(self.owner.fd, context_id=self.owner.ctx, timestamp=self.owner.last_cmd, timeout=0xffffffff)

 class QCOMComputeQueue(HWQueue):
   def __del__(self):
@@ -135,7 +130,7 @@ class QCOMComputeQueue(HWQueue):

     self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
                                                                state_block=adreno.SB6_CS_SHADER, num_unit=1024 // 4),
-             *data64_le(args_state.
+             *data64_le(args_state.buf.va_addr))
     self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_SHADER, state_src=adreno.SS6_INDIRECT,
                                                                state_block=adreno.SB6_CS_SHADER, num_unit=round_up(prg.image_size, 128) // 128),
              *data64_le(prg.lib_gpu.va_addr))
@@ -148,21 +143,21 @@ class QCOMComputeQueue(HWQueue):
     if args_state.prg.samp_cnt > 0:
       self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_SHADER, state_src=adreno.SS6_INDIRECT,
                                                                  state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.samp_cnt),
-               *data64_le(args_state.
-      self.reg(adreno.REG_A6XX_SP_CS_TEX_SAMP, *data64_le(args_state.
+               *data64_le(args_state.buf.va_addr + args_state.prg.samp_off))
+      self.reg(adreno.REG_A6XX_SP_CS_TEX_SAMP, *data64_le(args_state.buf.va_addr + args_state.prg.samp_off))
       self.reg(adreno.REG_A6XX_SP_PS_TP_BORDER_COLOR_BASE_ADDR, *data64_le(prg.dev.border_color_buf.va_addr))

     if args_state.prg.tex_cnt > 0:
       self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
                                                                  state_block=adreno.SB6_CS_TEX, num_unit=min(16, args_state.prg.tex_cnt)),
-               *data64_le(args_state.
-      self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.
+               *data64_le(args_state.buf.va_addr + args_state.prg.tex_off))
+      self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.buf.va_addr + args_state.prg.tex_off))

     if args_state.prg.ibo_cnt > 0:
       self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST6_IBO, state_src=adreno.SS6_INDIRECT,
                                                                  state_block=adreno.SB6_CS_SHADER, num_unit=args_state.prg.ibo_cnt),
-               *data64_le(args_state.
-      self.reg(adreno.REG_A6XX_SP_CS_IBO, *data64_le(args_state.
+               *data64_le(args_state.buf.va_addr + args_state.prg.ibo_off))
+      self.reg(adreno.REG_A6XX_SP_CS_IBO, *data64_le(args_state.buf.va_addr + args_state.prg.ibo_off))

     self.reg(adreno.REG_A6XX_SP_CS_CONFIG,
              qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.prg.samp_cnt, ntex=args_state.prg.tex_cnt, nibo=args_state.prg.ibo_cnt))
@@ -171,24 +166,24 @@ class QCOMComputeQueue(HWQueue):
     return self

 class QCOMArgsState(HCQArgsState):
-  def __init__(self,
-    super().__init__(
+  def __init__(self, buf:HCQBuffer, prg:QCOMProgram, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=()):
+    super().__init__(buf, prg, bufs, vals=vals)

     if len(bufs) + len(vals) != len(prg.buf_info): raise RuntimeError(f'incorrect args size given={len(bufs)+len(vals)} != want={len(prg.buf_info)}')

-    self.buf_info, self.args_info
+    self.buf_info, self.args_info = prg.buf_info[:len(bufs)], prg.buf_info[len(bufs):]

-    ctypes.memset(self.
-    for cnst_val,
+    ctypes.memset(cast(int, self.buf.va_addr), 0, prg.kernargs_alloc_size)
+    for cnst_val,cnst_off,cnst_sz in prg.consts_info: to_mv(self.buf.va_addr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little')

-    if prg.samp_cnt > 0: to_mv(self.
+    if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
     for i, b in enumerate(bufs):
       if prg.buf_info[i].type in {BUFTYPE_TEX, BUFTYPE_IBO}:
         obj = b.texture_info.desc if prg.buf_info[i].type is BUFTYPE_TEX else b.texture_info.ibo
-        to_mv(self.
-      self.
+        to_mv(self.buf.va_addr + prg.buf_info[i].offset, len(obj) * 4).cast('I')[:] = array.array('I', obj)
+      self.bind_sints_to_buf(b.va_addr, buf=self.buf, fmt='Q', offset=self.buf_info[i].offset+(0 if self.buf_info[i].type is BUFTYPE_BUF else 16))

-    for i, v in enumerate(vals): self.
+    for i, v in enumerate(vals): self.bind_sints_to_buf(v, buf=self.buf, fmt='I', offset=self.args_info[i].offset)

 class QCOMProgram(HCQProgram):
   def __init__(self, dev: QCOMDevice, name: str, lib: bytes):
@@ -196,7 +191,7 @@ class QCOMProgram(HCQProgram):
     self.name, self.lib = name, lib
     self._parse_lib()

-    self.lib_gpu: HCQBuffer = self.dev.allocator.alloc(self.image_size,
+    self.lib_gpu: HCQBuffer = self.dev.allocator.alloc(self.image_size, buf_spec:=BufferSpec(cpu_access=True, nolru=True))
     to_mv(cast(int, self.lib_gpu.va_addr), self.image_size)[:] = self.image

     self.pvtmem_size_per_item: int = round_up(self.pvtmem, 512) >> 9
@@ -208,6 +203,7 @@ class QCOMProgram(HCQProgram):

     kernargs_alloc_size = round_up(2048 + (self.tex_cnt + self.ibo_cnt) * 0x40 + self.samp_cnt * 0x10, 0x100)
     super().__init__(QCOMArgsState, self.dev, self.name, kernargs_alloc_size=kernargs_alloc_size)
+    weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec)

   def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
     if self.max_threads < prod(local_size): raise RuntimeError("Too many resources requested for launch")
@@ -265,9 +261,6 @@ class QCOMProgram(HCQProgram):
     reg_desc_off = _read_lib(0x34)
     self.fregs, self.hregs = _read_lib(reg_desc_off + 0x14), _read_lib(reg_desc_off + 0x18)

-  def __del__(self):
-    if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferSpec(cpu_access=True, nolru=True))
-
 class QCOMTextureInfo:
   def __init__(self, pitch:int, real_stride:int, desc:list[int], ibo:list[int]):
     self.pitch, self.real_stride, self.desc, self.ibo = pitch, real_stride, desc, ibo
@@ -285,7 +278,7 @@ class QCOMAllocator(HCQAllocatorBase):
       pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
       size = pitch * imgh

-    buf = HCQBuffer(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size)
+    buf = HCQBuffer(options.external_ptr, size, owner=self.dev) if options.external_ptr else self.dev._gpu_alloc(size)

     if options.image is not None:
       tex_fmt = adreno.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else adreno.FMT6_16_16_16_16_FLOAT
@@ -320,16 +313,12 @@ class QCOMAllocator(HCQAllocatorBase):
     self.dev._gpu_free(opaque)

 class QCOMDevice(HCQCompiled):
-  signals_page: Any = None
-  signals_pool: list[int] = []
   gpu_id: int = 0
   dummy_addr: int = 0

   def __init__(self, device:str=""):
-    self.fd =
+    self.fd = FileIOInterface('/dev/kgsl-3d0', os.O_RDWR)
     QCOMDevice.dummy_addr = cast(int, self._gpu_alloc(0x1000).va_addr)
-    QCOMDevice.signals_page = self._gpu_alloc(16 * 65536, uncached=True)
-    QCOMDevice.signals_pool = [self.signals_page.va_addr + off for off in range(0, self.signals_page.size, 16)]

     flags = kgsl.KGSL_CONTEXT_PREAMBLE | kgsl.KGSL_CONTEXT_PWR_CONSTRAINT | kgsl.KGSL_CONTEXT_NO_FAULT_TOLERANCE | kgsl.KGSL_CONTEXT_NO_GMEM_ALLOC \
             | kgsl.KGSL_CONTEXT_PRIORITY(8) | kgsl.KGSL_CONTEXT_PREEMPT_STYLE(kgsl.KGSL_CONTEXT_PREEMPT_STYLE_FINEGRAIN)
@@ -363,11 +352,11 @@ class QCOMDevice(HCQCompiled):
     va_addr = self.fd.mmap(0, bosz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, alloc.id * 0x1000)

     if fill_zeroes: ctypes.memset(va_addr, 0, size)
-    return HCQBuffer(va_addr=va_addr, size=size, meta=alloc)
+    return HCQBuffer(va_addr=va_addr, size=size, meta=alloc, view=MMIOInterface(va_addr, size, fmt='B'), owner=self)

   def _gpu_free(self, mem:HCQBuffer):
     kgsl.IOCTL_KGSL_GPUOBJ_FREE(self.fd, id=mem.meta.id)
-
+    FileIOInterface.munmap(mem.va_addr, mem.meta.mmapsize)

   def _ensure_stack_size(self, sz):
     if not hasattr(self, '_stack'): self._stack = self._gpu_alloc(sz)
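
One pattern in this file worth calling out: QCOMProgram drops its __del__ in favor of weakref.finalize, passing the device and buffer to the finalizer explicitly so it holds no reference to self. A minimal sketch of the pattern with generic names (not tinygrad's API):

# sketch of the __del__ -> weakref.finalize pattern; FakeDev is hypothetical
import weakref

class FakeDev:
  def alloc(self, size): return bytearray(size)
  def free(self, buf): print(f"freed {len(buf)} bytes")

class Program:
  def __init__(self, dev, size):
    self.buf = dev.alloc(size)
    # the finalizer captures dev and buf, never self, so it cannot keep Program alive;
    # unlike __del__, it also runs at interpreter exit if not triggered earlier
    weakref.finalize(self, dev.free, self.buf)

p = Program(FakeDev(), 4096)
del p  # prints "freed 4096 bytes"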
|