tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/renderer/llvmir.py
CHANGED
@@ -1,10 +1,13 @@
-from typing import
-import math, struct
+from typing import cast
+import math, struct, sys
 from tinygrad.renderer import Renderer
+from tinygrad.renderer.cstyle import ClangRenderer
 from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
 from tinygrad.dtype import dtypes, DType, PtrDType, truncate
+from tinygrad.helpers import prod, AMX
 
 def ldt(dt:DType):
+  if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
   if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
   return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
           dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
@@ -20,7 +23,7 @@ def lcast(input_type:DType, output_type:DType):
   if dtypes.is_float(input_type):
     if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc'
     if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi'
-  if dtypes.is_unsigned(input_type) or
+  if dtypes.is_unsigned(input_type) or dtypes.is_bool(input_type):
     if dtypes.is_float(output_type): return 'uitofp'
     if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext'
   if dtypes.is_int(input_type):
@@ -28,6 +31,19 @@ def lcast(input_type:DType, output_type:DType):
     if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
   raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
 
+# https://github.com/corsix/amx
+def render_wmma(ctx, wmma: UOp) -> str:
+  def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))", "i,r,~{{memory}}"(i32 {op}, i64 {gpr}) #0; AMX'
+
+  return "\n".join([
+    *[f' store {ldt(src.dtype)} {ctx[src]}, {ldt(src.dtype.ptr())} {ctx[wmma]}_amx{i}, align {src.dtype.itemsize}' for i,src in enumerate(wmma.src)],
+    f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 0})", "~{{memory}}"() #0; AMX set', # set
+    *[f' {ctx[wmma]}_ld{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(4,f"{ctx[wmma]}_ld{i}")} ldz' for i in range(16)], # ldz
+    f' {AMX(0, f"{ctx[wmma]}_ptr_amx1")} ldx\n {AMX(1, f"{ctx[wmma]}_ptr_amx0")} ldy\n {AMX(12, 0)} fma32', # ldx ldy fma
+    *[f' {ctx[wmma]}_st{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(5,f"{ctx[wmma]}_st{i}")} stz' for i in range(16)], # stz
+    f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 1})", "~{{memory}}"() #0; AMX clr', # clr
+    f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
+
 # llvm ops, lop[<dtype>][<op>]
 unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
   Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
@@ -36,7 +52,7 @@ flags = " nsz arcp contract afn"
 float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags}
 lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}
 
-llvm_rewrite = PatternMatcher([
+base_rewrite = PatternMatcher([
   # memory load/store
   (UPat(Ops.INDEX, name="x"), lambda ctx,x:
     f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
@@ -49,73 +65,98 @@ llvm_rewrite = PatternMatcher([
   (UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
   (UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
 
+  # GEP/VECTORIZE/CAST for float4 support
+  (UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
+  (UPat(Ops.VECTORIZE, src=UPat.var('y'), name="x"), lambda ctx,x,y:
+    f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n"
+    f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"),
+  (UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
+    f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
+    f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
+  (UPat(Ops.CAST, name="x"), lambda ctx,x:
+    f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None),
+
   # unary/binary/ternary ops
-  (UPat(Ops.SQRT, name="x"), lambda ctx,x:
-    f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
   (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
   (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
-  (UPat(GroupOp.Binary, name="x"), lambda ctx,x:
+  (UPat(GroupOp.Binary, name="x"), lambda ctx,x:
+    f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
   (UPat(Ops.WHERE, name="x"), lambda ctx,x:
     f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
 
   # range
   (UPat(Ops.RANGE, name="x"), lambda ctx,x:
-    f" br label %loop_entry_{x.arg
-    f" br label %loop_body_{x.arg
-    f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg
+    f" br label %loop_entry_{x.arg}\nloop_entry_{x.arg}:\n"
+    f" br label %loop_body_{x.arg}\nloop_body_{x.arg}:\n"
+    f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg}], [{ctx[x]}phi, %loop_latch_{x.arg}]"),
   (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
-    f" br label %loop_latch_{x.src[0].arg
+    f" br label %loop_latch_{x.src[0].arg}\nloop_latch_{x.src[0].arg}:\n"
     f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
-    f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg
+    f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg}, label %loop_exit_{x.src[0].arg}\nloop_exit_{x.src[0].arg}:"),
 
   # if
   (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
   (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
+
+  # wmma
+  (UPat(Ops.WMMA, name="wmma"), render_wmma),
 ])
 
+def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
+  u16_buf = buf.replace(dtype=dtypes.ushort.ptr(size=cast(PtrDType,buf.dtype).size))
+  return UOp.load(UOp.index(u16_buf, idx), dtype=dtypes.ushort).cast(dtypes.uint).mul(1<<16).bitcast(dtypes.float32).cast(root.dtype)
+
 class LLVMRenderer(Renderer):
   device = "LLVM"
-
+  abi = 'win64cc' if sys.platform == 'win32' else None
+  supports_float4 = True
   has_local = False
   has_shared = False
   global_max = None
+  string_rewrite = base_rewrite
+  if AMX: tensor_cores = ClangRenderer.amx_tc
 
   extra_matcher = PatternMatcher([
     # rewrite RECIP with FDIV
     (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
     # rewrite cast to bool to CMPNE 0
     (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
-    # *** also in cstyle ***
-    # gate any stores that aren't gated with ifs
-    (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
-      lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
     # rewrite MAX to CMPLT + WHERE
     (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+    # rewrite bf16 CAST(LOAD) to CAST(BITCAST)
+    (UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
   ])
 
-  def render(self,
-    r:
-    args:
-    kernel:
-    end_lines:
+  def render(self, uops: list[UOp]) -> str:
+    r: dict[UOp, str] = {}
+    args: list[str] = []
+    kernel: list[str] = []
+    end_lines: dict[str, None] = {}
     vc = -1
 
-
-    acc_to_assign: Dict[UOp, UOp] = {}
+    acc_to_assign: dict[UOp, UOp] = {}
     for u in uops:
-      if u.op is Ops.ASSIGN:
+      if u.op is Ops.ASSIGN: # prealloc all assigns
        vc += 1
        r[u] = r[u.src[1]] = f"%assign{vc}"
        assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
        acc_to_assign[u.src[0]] = u.src[1]
+      if u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
+        vc += 1
+        r[u] = f"%wmma{vc}"
+        for i, dtype in enumerate(u.arg[2].vec(sz) for sz in [prod(size for _, size in upcast) for upcast in u.arg[6]]):
+          kernel += [f" {r[u]}_amx{i} = alloca {ldt(dtype)}, align {dtype.itemsize}",
+                     f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]
 
+    name = "test"
     for u in uops:
-
-
-
+      if u.op is Ops.NAME:
+        name = u.arg
+        continue
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
        r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
-
+        # NOTE: MallocAllocator promises 0x20 alignment
+        args.append(f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
      elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
      elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
      elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
@@ -127,16 +168,24 @@ class LLVMRenderer(Renderer):
        r[u] = f"%v{vc}"
 
      # do the rendering of the llvm ir code
-      if (l:=
+      if (l:=self.string_rewrite.rewrite(u, ctx=r)) is None:
+        raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
      kernel.append(cast(str, l))
 
      # generate the phi nodes for the assigns
      if u.op is Ops.RANGE:
        for x in acc_to_assign:
-          if u in x.src: # if this range is
+          if u in x.src: # if this range is relevant for this acc
            vc += 1
-            kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg
+            kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg}]")
            r[x] = f"%acc{vc}"
 
-    # output the function
-    return f
+    # output the function. chr(10) is '\n' (python < 3.12 doesn't support backslashes in f-strings)
+    return f'''\
+define{(' '+self.abi) if self.abi is not None else ''} void @{name}({','.join(args)}) #0 {{
+{chr(10).join(kernel)}
+  ret void
+}}
+{chr(10).join(end_lines.keys())}
+attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }}
+'''
tinygrad/renderer/ptx.py
CHANGED
@@ -1,11 +1,11 @@
-from typing import
+from typing import cast, Callable
 import struct
 from collections import defaultdict
 from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
 from tinygrad.dtype import dtypes, DType, PtrDType
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer
-from tinygrad.helpers import
+from tinygrad.helpers import flatten, get_single_element
 
 def render_val(x, dtype):
   if dtypes.is_float(dtype):
@@ -14,30 +14,30 @@ def render_val(x, dtype):
     return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
   return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
 
-asm_for_op:
+asm_for_op: dict[Ops, Callable] = {
   Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
   Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
   Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
   Ops.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", Ops.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
-  Ops.ADD: lambda d,a,b,dt,name: f"{'or' if
-  Ops.MUL: lambda d,a,b,dt,name:
-  Ops.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if
-  Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if
-  Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if
-  Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
-  Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};",
+  Ops.ADD: lambda d,a,b,dt,name: f"{'or' if dt == dtypes.bool else 'add'}.{name} {d}, {a}, {b};",
+  Ops.MUL: lambda d,a,b,dt,name: f"{'and' if dt == dtypes.bool else 'mul'}{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
+  Ops.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if dt == dtypes.bool else f"xor.b{name[1:]} {d}, {a}, {b};",
+  Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if dt == dtypes.bool else f"and.b{name[1:]} {d}, {a}, {b};",
+  Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if dt == dtypes.bool else f"or.b{name[1:]} {d}, {a}, {b};",
+  Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
+  Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};",
   Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
   Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
-  Ops.WHERE: lambda d,a,b,c,dt,name:
-    f"
+  Ops.WHERE: lambda d,a,b,c,dt,name: [f"@{a} mov.{name} {d}, {b};", f"@!{a} mov.{name} {d}, {c};"] if dt == dtypes.bool else \
+    f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
 }
 
-supports_half
-doesnt_support_half:
+supports_half = (Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE)
+doesnt_support_half: tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
 ptx_matcher = PatternMatcher([
   # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
   (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
-  (UPat.var('x', dtype=dtypes.bool)
+  (UPat.var('x', dtype=dtypes.bool)<UPat.var('y'), lambda x,y: (x^True)&y),
   # upcast to float32 all the ops that don't support half
   (UPat(doesnt_support_half, dtype=dtypes.half, name="x"),
    lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))),
@@ -54,46 +54,46 @@ ptx_matcher = PatternMatcher([
   (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
 ])
 
-def mem_type(x: UOp): return 'shared' if
+def mem_type(x: UOp): return 'shared' if any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].toposort) else 'global'
 
-def
-  gate = f"@{ctx.r[pred]} " if pred is not None and pred.op is not Ops.IF else ""
-  return [f"{gate}st.{mem_type(bidx)}.v{var.dtype.count}.{ctx.mem_types[var.dtype.scalar()]} [{ctx.r[bidx]}+0], {{{', '.join(ctx.r[var])}}};"] \
-    if var.dtype.count > 1 else [f"{gate}st.{mem_type(bidx)}.{ctx.mem_types[var.dtype]} [{ctx.r[bidx]}+0], {ctx.r[var]};"]
-
-def render_wmma(ctx: "PTXRenderer", x: UOp):
+def render_wmma(ctx: "PTXRenderer", wmma: UOp):
   assert ctx.wmma_r, "registry values for wmma must be populated"
-
-
-
-
-
-
-
-
-  yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.
-
-
+  (N, M, K), dtype_in, dtype_out = wmma.arg[1], wmma.arg[2], wmma.arg[3]
+
+  for src, regs in zip(wmma.src, ctx.wmma_r):
+    for i, reg in enumerate(regs): # pack input and acc registers
+      if (elems_per_reg := 4 // src.dtype.scalar().itemsize) == 1: yield f"mov.b32 {reg}, {ctx.r[src][i]};"
+      else: yield f"mov.b32 {reg}, {{{', '.join(ctx.r[src][i * elems_per_reg : (i+1) * elems_per_reg])}}};"
+
+  dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32", dtypes.half: "f16"}
+  yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}{" "*12}'+\
+    f'{{{", ".join(ctx.wmma_r[2])}}}, {{{", ".join(ctx.wmma_r[0])}}}, {{{", ".join(ctx.wmma_r[1])}}}, {{{", ".join(ctx.wmma_r[2])}}};'
+
+  for i, reg in enumerate(ctx.wmma_r[2]): # unpack acc registers
+    if (elems_per_reg := 4 // dtype_out.itemsize) == 1: yield f"mov.b32 {ctx.r[wmma][i]}, {reg};"
+    else: yield f"mov.b32 {{{', '.join(ctx.r[wmma][i * elems_per_reg : (i+1) * elems_per_reg])}}}, {reg};"
 
 def modifier(a: DType, b: DType): return '.rzi' if dtypes.is_int(a) and dtypes.is_float(b) else '.rn' if dtypes.is_float(a) and \
     (a.itemsize < b.itemsize or dtypes.is_int(b) or b == dtypes.bool) else ''
 
 string_rewrite = PatternMatcher([
-  (UPat(
-  (UPat(
-  (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var")
+  (UPat.cvar("x", dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"),
+  (UPat.cvar("x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"),
+  (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx, x, bidx, var: f"st.{mem_type(bidx)}" + \
+    f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
+    f"[{ctx.r[bidx]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
   (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),
   (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
-  (UPat((Ops.CMPLT, Ops.CMPNE), name="x"),
-
+  (UPat((Ops.CMPLT, Ops.CMPNE), name="x", allow_any_len=True, src=(UPat.var("src0"),)),
+   lambda ctx, x, src0: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], src0.dtype, ctx.types[src0.dtype])),
   (UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
-  (UPat(Ops.BITCAST, name="x", src=(UPat.var("a")), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
-  (UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"))),
+  (UPat(Ops.BITCAST, name="x", src=(UPat.var("a"),), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
+  (UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"),)),
   lambda ctx, x, a: f"selp.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(1, x.dtype)}, {render_val(0, x.dtype)}, {ctx.r[a]};"),
-  (UPat(Ops.CAST, name="x", dtype=dtypes.bool),
-   lambda ctx, x: f"setp.ne.b{ctx.types[
-  (UPat(Ops.CAST, name="x", src=(UPat.var("a"))),
-   lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.types[x.dtype]}.{ctx.types[
+  (UPat(Ops.CAST, name="x", dtype=dtypes.bool, src=(UPat.var("a"),)),
+   lambda ctx, x, a: f"setp.ne.b{ctx.types[a.dtype][1:]} {ctx.r[x]}, {ctx.r[a]}, {render_val(0, a.dtype)};"),
+  (UPat(Ops.CAST, name="x", src=(UPat.var("a"),)),
+   lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.types[x.dtype]}.{ctx.types[a.dtype]} {ctx.r[x]}, {ctx.r[a]};"),
  (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'), UPat(name='alt'), UPat(name="gate", op=GroupOp.ALU))), lambda ctx, x, loc, alt, gate: flatten([
    [f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
    [f"@{ctx.r[gate]} ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
@@ -101,20 +101,11 @@ string_rewrite = PatternMatcher([
    f"@{ctx.r[gate]} ld.{mem_type(x)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
    f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
  (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'),), allow_any_len=True),
-   lambda ctx, x, loc: f"
+   lambda ctx, x, loc: f"ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
    if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
-  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(
-   lambda ctx, x, pred: flatten([
-     [f"setp.ne.s16 {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())}, 0;",
-      f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
-  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.half),), allow_any_len=True),
-   lambda ctx, x, pred: flatten([[f"mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())};",
-     f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
-  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE),), allow_any_len=True), lambda ctx, x, pred: [
-   f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {render_val(pred.src[0].arg, x.dtype.scalar())};" for i, uu in enumerate(ctx.r[x])]),
-  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST, dtype=dtypes.bool), ), allow_any_len=True), lambda ctx, x, pred: [
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat.cvar("pred", dtype=dtypes.bool),), allow_any_len=True), lambda ctx, x, pred: [
    f"setp.ne.s16 {ctx.r[pred]}, {render_val(pred.arg, pred.dtype)}, 0;", f"mov.pred {ctx.r[x]}, {ctx.r[pred]};"]),
-  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat.cvar("pred"),), allow_any_len=True),
   lambda ctx, x, pred: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(pred.arg, x.dtype)};"),
  (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, {ctx.r[x.src[0]]};", "LOOP_" + f"{ctx.r[x][1:]}:"]),
  (UPat(Ops.ASSIGN, name="x", dtype=dtypes.bool), lambda ctx, x: [f"mov.pred {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"]),
@@ -124,7 +115,7 @@ string_rewrite = PatternMatcher([
    ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[1]], dtypes.int, ctx.types[dtypes.int]),
    f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
  (UPat(Ops.DEFINE_LOCAL, name="x"),
-   lambda ctx, x: [f".shared .align 4 .b8 {x.arg
+   lambda ctx, x: [f".shared .align 4 .b8 {x.arg}[{x.dtype.size*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg}[0];"]),
  (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
  (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
  (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
@@ -136,11 +127,12 @@ class PTXRenderer(Renderer):
  device = "CUDA"
  suffix = "PTX"
  global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
-
+  tc_sm80 = [tc for tc in CUDARenderer.tc_sm80 if tc.dtype_in in [dtypes.half, dtypes.float]]
  code_for_op = asm_for_op
  extra_matcher = ptx_matcher
  def __init__(self, arch:str, device="CUDA"):
-    self.device, self.
+    self.device, self.arch = device, arch
+    self.tensor_cores = PTXRenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else []
  def __reduce__(self): return self.__class__, (self.arch, self.device)
 
  # language options
@@ -149,75 +141,67 @@ class PTXRenderer(Renderer):
  .address_size 64
  .visible .entry"""
  barrier = "bar.sync\t0;"
-  supports_half = supports_half
  # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
-  types:
+  types: dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
                              dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
                              dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
 
-  mem_types:
-  mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})
+  mem_types: dict[DType, str] = {**types, dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"}
 
  def render_kernel(self, kernel, function_name, bufs, regs) -> str:
-    kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
    def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
-
-
-
-      "\n}")
+    kernel = '\n'.join(map(fmt, [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]))
+    params = ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs])
+    return f"{self.kernel_prefix} {function_name}(\n\t{params}\n)\n{{\n{kernel}\n}}"
 
-  def render(self,
-    kernel:
+  def render(self, uops:list[UOp]) -> str:
+    kernel:list[str] = []
    bufs = []
 
-    c:
-    r:
+    c: defaultdict[str, int] = defaultdict(int)
+    r: dict[UOp, list[str]|str] = {}
    self.r = r
    self.uops = uops
 
-    def ssa(prefix:str, u:
+    def ssa(prefix:str, u:UOp|None=None, dtype:str|None=None) -> str:
      nonlocal c, r
      prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype]}_"
      c[prefix] += 1
      return f"%{prefix}{c[prefix]-1}"
 
+    name = "test"
    for u in uops:
+      if u.op is Ops.NAME:
+        name = u.arg
+        continue
      if u.op is Ops.VECTORIZE:
        r[u] = [cast(str,r[x]) for x in u.src]
        continue
      if u.op is Ops.GEP:
-
-
+        r[u] = r[u.src[0]][get_single_element(u.arg)]
+        continue
+      if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)):
+        r[u] = r[u.src[0]]
        continue
-      if u.op
-
-        r[u] = r[u.src[0]]
-        continue
-      r[u] = ssa('cast', u, self.types[u.dtype])
-      elif u.op is Ops.ENDRANGE: r[u] = ssa("pred", u, dtype="pred")
-      elif u.op is Ops.RANGE: r[u] = ssa("ridx", u)
-      elif u.op in GroupOp.ALU: r[u] = ssa("alu", u)
-      elif u.op is Ops.DEFINE_ACC:
-        if u.dtype.scalar() in [dtypes.half, dtypes.bool]:
-          r[u.src[0]] = [ssa("const", u.src[0].src[0]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("const", u.src[0])
-        r[u] = [ssa('acc', u, dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("acc", u)
-      elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
-      elif u.op is Ops.DEFINE_VAR:
-        bufs.append((u.arg[0], u.dtype))
-        r[u] = ssa("dat", u, self.types[u.dtype])
-      elif u.op is Ops.CONST: r[u] = ssa("const", u, dtype=self.types[u.dtype])
+      if u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
+      elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype))
      elif u.op is Ops.LOAD:
        assert u.src[0].dtype == dtypes.int64, "load isn't int64"
        r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
-      elif u.op is Ops.
-      elif u.op is Ops.DEFINE_GLOBAL:
-        bufs.append((f"data{u.arg}", u.dtype))
-        r[u] = ssa('dat', u, self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
+      elif u.op is Ops.DEFINE_GLOBAL: bufs.append((f"data{u.arg}", u.dtype))
      elif u.op is Ops.WMMA:
-
+        # registers for packing/unpacking input and acc
+        self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.arg[2].itemsize)],
+                       [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.arg[2].itemsize)],
+                       [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.arg[3].itemsize)]]
        r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
-
-
+      prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None),
+        Ops.DEFINE_ACC: ("acc", None), Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL:("local",self.types[dtypes.ulong]),
+        Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
+      if prefix: r[u] = ssa(prefix, u, dtype)
+
+      if (l:=cast(str|list[str], string_rewrite.rewrite(u, ctx=self))) is None:
+        raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
      kernel.extend([l] if isinstance(l, str) else l)
 
      if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
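In the rewritten render_wmma for PTX, every mma.sync operand is first repacked into 32-bit registers: each b32 register holds 4 // itemsize elements, so two f16 values share a register while each f32/tf32 value takes one. A small sketch of that packing arithmetic, using a hypothetical helper name rather than code from the diff:

```python
# Number of b32 registers a WMMA operand occupies, following the
# "range(0, len(...), 4 // itemsize)" stride used in PTXRenderer.render.
def b32_regs(num_elements: int, itemsize: int) -> int:
    elems_per_reg = 4 // itemsize              # 2 for half (itemsize 2), 1 for float/tf32 (itemsize 4)
    return -(-num_elements // elems_per_reg)   # ceiling division

assert b32_regs(8, 2) == 4   # 8 half values pack into 4 registers
assert b32_regs(8, 4) == 8   # 8 float values need 8 registers, one each
```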
|
@@ -0,0 +1,95 @@
|
|
1
|
+
from tinygrad.dtype import DType, PtrDType, dtypes
|
2
|
+
from tinygrad.ops import UOp, Ops, PatternMatcher, UPat
|
3
|
+
from tinygrad.renderer.cstyle import CStyleLanguage, base_rewrite, extra_pm
|
4
|
+
from tinygrad.helpers import strip_parens
|
5
|
+
import math
|
6
|
+
|
7
|
+
def sign_extend(val:UOp, sext_am:int):
|
8
|
+
return (UOp.where((val >> (sext_am - 1)) > 0, UOp.const(dtypes.uint32, 0xffffffff) << sext_am, UOp.const(dtypes.uint32, 0)) \
|
9
|
+
| val.bitcast(dtypes.uint32)).bitcast(dtypes.int)
|
10
|
+
|
11
|
+
# store for char: buf[idx/4] <- (var << (idx%4)*8))
|
12
|
+
def packed_store(bidx:UOp, var:UOp):
|
13
|
+
shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//var.dtype.itemsize))*UOp.const(dtypes.uint32, 8*var.dtype.itemsize)
|
14
|
+
new_v = (var & (0xFF if var.dtype.itemsize == 1 else 0xFFFF)).cast(dtypes.uint32) << shift_am
|
15
|
+
mask = (((0xFF if var.dtype.itemsize == 1 else 0xFFFF) << shift_am) ^ 0xFFFFFFFF).cast(dtypes.uint32)
|
16
|
+
buf = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), dtype=dtypes.uint32)
|
17
|
+
return UOp.store(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), ((buf & mask) | new_v.cast(dtypes.uint32)))
|
18
|
+
|
19
|
+
# load for char: sign_extend(buf[idx/4] >> ((idx%4)*8))
|
20
|
+
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
|
21
|
+
div_idx = bidx.src[1]//(4//dtype.itemsize)
|
22
|
+
shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//dtype.itemsize))*UOp.const(dtypes.uint32, 8*dtype.itemsize)
|
23
|
+
if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), var, root.src[2], dtype=dtypes.uint32, arg=root.arg)
|
24
|
+
else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=dtypes.uint32, arg=root.arg)
|
25
|
+
val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF)
|
26
|
+
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
|
27
|
+
|
28
|
+
def is_packed(dt:DType) -> bool: return dt.itemsize < 4 and dt.base != dtypes.half
|
29
|
+
|
30
|
+
wgsl_matcher = PatternMatcher([
|
31
|
+
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
|
32
|
+
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
|
33
|
+
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if is_packed(l.dtype) else None),
|
34
|
+
(UPat(Ops.LOAD, name="l", src=(UPat.var('b'), UPat.var('c'), UPat())),
|
35
|
+
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype) else None),
|
36
|
+
(UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype) else None),
|
37
|
+
# TODO: why is this needed, and only for this MUL order
|
38
|
+
(UPat(Ops.MUL, src=(UPat.var("a"), UPat.var("g").where(UPat.cvar("c1"), UPat.cvar("c2")))),
|
39
|
+
lambda a,g,c1,c2: g.where(c1, a) if math.isnan(c1.arg) and c2.arg == 1.0 else None),
|
40
|
+
(UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None)
|
41
|
+
]) + extra_pm
|
42
|
+
|
43
|
+
class WGSLRenderer(CStyleLanguage):
|
44
|
+
device = "WEBGPU"
|
45
|
+
global_max = (65535, 65535, 65535)
|
46
|
+
local_max = (256, 256, 64)
|
47
|
+
code_for_workitem = {"g": lambda x: f"i32(gindex.{'xyz'[int(x)]})", "l": lambda x: f"i32(lindex.{'xyz'[int(x)]})"}
|
48
|
+
extra_matcher = wgsl_matcher
|
49
|
+
supports_float4 = False
|
50
|
+
barrier = "workgroupBarrier();"
|
51
|
+
code_for_op = {**CStyleLanguage.code_for_op, Ops.WHERE: lambda a,b,c,dtype: f"select({c},{b},{a})"}
|
52
|
+
nan = "nan()"
|
53
|
+
type_map = { dtypes.float: "f32", dtypes.uchar: "u32", dtypes.ushort: "u32", dtypes.short: "i32",
|
54
|
+
dtypes.char: "i32", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool", dtypes.half: "f16" }
|
55
|
+
|
56
|
+
string_rewrite = PatternMatcher([
|
57
|
+
(UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "true" if x.arg else "false"),
|
58
|
+
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"), lambda ctx,x: f"bitcast<u32>({x.arg})" \
|
59
|
+
if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
|
60
|
+
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)}, {x.dtype.size}>;"),
|
61
|
+
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x"), lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]" \
|
62
|
+
if x.src[0].dtype in [dtypes.short, dtypes.ushort, dtypes.uint32] else None),
|
63
|
+
(UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
|
64
|
+
(UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"),lambda ctx,x:f"bitcast<{ctx.type_map[x.dtype]}>(vec2<f16>({ctx[x.src[0]]},0))" \
|
65
|
+
if x.src[0].dtype == dtypes.half else f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
|
66
|
+
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
|
67
|
+
(UPat.load(UPat.var("b"),UPat.var("v"),UPat.var("g")),lambda ctx,b,v,g:f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[g]})"),
|
68
|
+
(UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.src[0].dtype)),
|
69
|
+
(UPat.index(UPat.var("b"), UPat.var("idx")), lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
|
70
|
+
(UPat.store(UPat.var('b'), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
|
71
|
+
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
72
|
+
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
|
73
|
+
else f"{ctx[b]} = {ctx[v]};"),
|
74
|
+
# fix nan check: 'a != a -> is_nan()'
|
75
|
+
(UPat.var("a") != UPat.var("a"), lambda ctx,a: f"(min({ctx[a]}, 1.0) == 1.0 && max({ctx[a]}, -1.0) == -1.0)"),
|
76
|
+
]) + base_rewrite
|
77
|
+
|
78
|
+
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
|
79
|
+
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
|
80
|
+
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
|
81
|
+
def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
|
82
|
+
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
|
83
|
+
local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
|
84
|
+
if not local_size: local_size = [1]
|
85
|
+
bind_it = iter(range(len(bufs)))
|
86
|
+
external_local_bufs = [line.lstrip() for line in kernel if "var<workgroup>" in line]
|
87
|
+
kernel[:] = [line for line in kernel if "var<workgroup>" not in line]
|
88
|
+
prg = "enable f16;\n" if any(uop.dtype.base == dtypes.half for uop in uops) else ""
|
89
|
+
prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
|
90
|
+
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
|
91
|
+
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
|
92
|
+
f"{'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'}" +
|
93
|
+
f"{name}:{f'array<{self.buf_map(dtype.base)}>' if isinstance(dtype,PtrDType) else self.buf_map(dtype)};" for name,(dtype,_) in bufs])
|
94
|
+
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
|
95
|
+
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
|