tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
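The largest structural change in this release is the move of the UOp machinery out of the deleted top-level modules (`tinygrad/ops.py`, `tinygrad/spec.py`) into the new `tinygrad.uop` package, as the renames above show. A minimal, hedged shim for downstream code that imports these names across both versions (only the two import paths are taken from this diff; any further API compatibility is an assumption):

```python
try:
    # tinygrad 0.11.0: the UOp machinery lives in the tinygrad.uop package
    from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
except ImportError:
    # tinygrad 0.10.2 and earlier: the old flat module
    from tinygrad.ops import UOp, Ops, PatternMatcher, UPat
```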
tinygrad/renderer/llvmir.py
CHANGED
```diff
@@ -1,17 +1,18 @@
 from typing import cast
 import math, struct, sys
+from tinygrad.codegen.opt import tc
 from tinygrad.renderer import Renderer
-from tinygrad.renderer.cstyle import …
-from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
+from tinygrad.renderer.cstyle import AMDRenderer
+from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
 from tinygrad.dtype import dtypes, DType, PtrDType, truncate
 from tinygrad.helpers import prod, AMX
 
 def ldt(dt:DType):
   if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
   if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
-  return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
+  return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
           dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
-          dtypes.float16: "half", dtypes.…
+          dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
 
 def lconst(x, dtype:DType):
   if dtype in dtypes.floats:
@@ -32,7 +33,7 @@ def lcast(input_type:DType, output_type:DType):
   raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
 
 # https://github.com/corsix/amx
-def …
+def render_wmma_amx(ctx, wmma: UOp) -> str:
   def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))", "i,r,~{{memory}}"(i32 {op}, i64 {gpr}) #0; AMX'
 
   return "\n".join([
@@ -44,25 +45,40 @@ def render_wmma(ctx, wmma: UOp) -> str:
     f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 1})", "~{{memory}}"() #0; AMX clr', # clr
     f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
 
+def render_wmma_amd(ctx, wmma: UOp, arch: str) -> str:
+  dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.bfloat16: "bf16", dtypes.ushort: "bf16"}
+  # https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl
+  if arch.split(":")[0] in {"gfx942", "gfx950"}:
+    return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
+      f".16x16x16{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
+  # https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
+  # example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101)
+  return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
+    f"{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + (", i1 false)" \
+    if wmma.dtype.scalar() != dtypes.float else ")")
+
 # llvm ops, lop[<dtype>][<op>]
 unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
-  Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor",…
-signed_lop = {**unsigned_lop, Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"}
+  Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.CMPEQ: "icmp eq", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor",}
+signed_lop = {**unsigned_lop, Ops.ADD: "add nsw", Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"}
 flags = " nsz arcp contract afn"
-float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult",…
+float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult",
+             Ops.CMPNE: f"fcmp{flags} une", Ops.CMPEQ: f"fcmp{flags} oeq", Ops.FDIV: "fdiv"+flags}
 lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}
 
 base_rewrite = PatternMatcher([
   # memory load/store
   (UPat(Ops.INDEX, name="x"), lambda ctx,x:
     f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
-  (UPat(Ops.LOAD, src=(UPat.…
+  (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("mask"))).or_casted("idx"), UPat.var("alt")), name="x"),
+   lambda ctx,x,idx,alt,mask:
    f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
    f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
    f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
    f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
    f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
-  (UPat(Ops.LOAD, src=(UPat.var('idx'),),…
+  (UPat(Ops.LOAD, src=(UPat.var('idx'),), allow_any_len=True, name="x"),
+   lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
   (UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
 
   # GEP/VECTORIZE/CAST for float4 support
@@ -73,12 +89,11 @@ base_rewrite = PatternMatcher([
   (UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
    f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
    f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
-  (UPat(Ops.CAST, name="x"), lambda ctx,x:
-    f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None),
-
   # unary/binary/ternary ops
   (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
   (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
+  (UPat(Ops.TRUNC, name="x"),
+   lambda ctx,x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.trunc.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
   (UPat(GroupOp.Binary, name="x"), lambda ctx,x:
    f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
   (UPat(Ops.WHERE, name="x"), lambda ctx,x:
@@ -88,33 +103,27 @@ base_rewrite = PatternMatcher([
   (UPat(Ops.RANGE, name="x"), lambda ctx,x:
    f" br label %loop_entry_{x.arg}\nloop_entry_{x.arg}:\n"
    f" br label %loop_body_{x.arg}\nloop_body_{x.arg}:\n"
-   f" {ctx[x]} = phi {ldt(x.dtype)} [ …
+   f" {ctx[x]} = phi {ldt(x.dtype)} [ 0, %loop_entry_{x.arg} ], [ {ctx[x]}phi, %loop_latch_{x.arg} ]"),
   (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
    f" br label %loop_latch_{x.src[0].arg}\nloop_latch_{x.src[0].arg}:\n"
-   f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[…
+   f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[0]]}\n"
    f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg}, label %loop_exit_{x.src[0].arg}\nloop_exit_{x.src[0].arg}:"),
 
   # if
   (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
   (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
 
-
-  (UPat(Ops.WMMA, name="wmma"), render_wmma),
+  (UPat(Ops.BARRIER), lambda ctx: "")
 ])
 
-def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
-  u16_buf = buf.replace(dtype=dtypes.ushort.ptr(size=cast(PtrDType,buf.dtype).size))
-  return UOp.load(UOp.index(u16_buf, idx), dtype=dtypes.ushort).cast(dtypes.uint).mul(1<<16).bitcast(dtypes.float32).cast(root.dtype)
-
 class LLVMRenderer(Renderer):
   device = "LLVM"
   abi = 'win64cc' if sys.platform == 'win32' else None
   supports_float4 = True
   has_local = False
-… (old lines 114-116 not shown in the source diff)
-  if AMX: tensor_cores = ClangRenderer.amx_tc
+  global_max: tuple[int, ...] | None = None
+  string_rewrite = base_rewrite + PatternMatcher([(UPat(Ops.WMMA, name="wmma"), render_wmma_amx)])
+  if AMX: tensor_cores = tc.amx
 
   extra_matcher = PatternMatcher([
     # rewrite RECIP with FDIV
@@ -123,25 +132,30 @@ class LLVMRenderer(Renderer):
     (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
     # rewrite MAX to CMPLT + WHERE
     (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
-    # …
-    (UPat(Ops.…
+    # copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16
+    (UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
+     lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
+    # copied from cstyle.py, add float intermediate casting
+    (UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
+    (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
   ])
 
-  def render(self, uops: list[UOp]) -> str:
+  def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer(uops)))
+  def _render_footer(self, uops: list[UOp]) -> str: return 'attributes #0 = { alwaysinline nounwind "no-builtins" "no-trapping-math"="true" }'
+  def _render_fn(self, name:str, args:list[tuple[str,DType]], kernel:list[str], prefix:list[str]|None=None) -> str:
+    # NOTE: CPUAllocator promises 0x20 alignment
+    sargs = ", ".join([f"{ldt(dt)}{' noalias align 32' if isinstance(dt, PtrDType) else ''} {name}" for name,dt in args])
+    sprefix = "".join([f" {x}" for x in (prefix or []) + [self.abi] if x is not None])
+    return "\n".join([f"define{sprefix} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"])
+  def _render_kernel(self, uops: list[UOp], prefix:list[str]|None=None) -> tuple[tuple[str, ...], str]:
     r: dict[UOp, str] = {}
-    args: list[str] = []
+    args: list[tuple[str, DType]] = []
     kernel: list[str] = []
-    end_lines: dict[str, None] = {}
     vc = -1
 
-…
+    local_args: list[str] = []
     for u in uops:
-      if u.op is Ops.…
-        vc += 1
-        r[u] = r[u.src[1]] = f"%assign{vc}"
-        assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
-        acc_to_assign[u.src[0]] = u.src[1]
-      if u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
+      if AMX and u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
        vc += 1
        r[u] = f"%wmma{vc}"
        for i, dtype in enumerate(u.arg[2].vec(sz) for sz in [prod(size for _, size in upcast) for upcast in u.arg[6]]):
@@ -150,17 +164,24 @@
 
     name = "test"
     for u in uops:
-      if u.op is Ops.…
-…
+      if u.op is Ops.NOOP: continue
+      if u.op is Ops.SINK:
+        if u.arg is not None: name = u.arg.function_name
         continue
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
         r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
-… (old lines 158-161 not shown in the source diff)
+        args.append((r[u], u.dtype))
+      elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG):
+        r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"
+        assert isinstance(u.dtype, PtrDType)
+        if self.device == "LLVM" or u.op is Ops.DEFINE_REG:
+          kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}]")
+        else:
+          local_args.append(f"@{r[u][1:]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {ldt(u.dtype)}] undef, align 16")
+          kernel.append(f" {r[u]} = addrspacecast [{u.dtype.size} x {ldt(u.dtype)}] addrspace(3)* @{r[u][1:]} to [{u.dtype.size} x {ldt(u.dtype)}]*")
       elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
-      elif u.op is Ops.CAST and ldt(u.dtype) == ldt(u.src[0].dtype)…
+      elif u.op is Ops.CAST and (ldt(u.dtype) == ldt(u.src[0].dtype) or isinstance(u.dtype, PtrDType)):
+        r[u] = r[u.src[0]]  # cast from signed to unsigned of the same size is a noop, or pointer cast
       else:
         # if it's an assign target, it's already preallocated
         if u not in r:
@@ -171,21 +192,51 @@
        if (l:=self.string_rewrite.rewrite(u, ctx=r)) is None:
          raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
        kernel.append(cast(str, l))
-… (old lines 174-191 not shown in the source diff)
+    return tuple(local_args), self._render_fn(name, args, kernel, prefix)
+
+barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
+code_for_workitem = {"g": lambda x: f"tail call i32 @llvm.amdgcn.workgroup.id.{chr(120+int(x))}()",
+                     "l": lambda x: f"tail call i32 @llvm.amdgcn.workitem.id.{chr(120+int(x))}()"}
+class AMDLLVMRenderer(LLVMRenderer):
+  device = "AMD"
+  has_local = True
+  shared_max = AMDRenderer.shared_max
+  global_max = AMDRenderer.global_max
+  abi = "amdgpu_kernel"
+  string_rewrite = PatternMatcher([
+    (UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f" {ctx[x]} = " + f"{ code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; "),
+    (UPat(Ops.BARRIER), lambda ctx: barrier),
+  ]) + base_rewrite
+  extra_matcher = LLVMRenderer.extra_matcher + PatternMatcher([
+    (UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))),
+     lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
+    (UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))),
+     lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(8), tuple(y.gep(i * 2) for i in range(8)))),
+  ])
+  def _render_footer(self, uops: list[UOp]) -> str:
+    # TODO: this is copied from cstyle
+    requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
+    attributes = ["alwaysinline", "nounwind", '"no-builtins"',
+                  f'"amdgpu-flat-work-group-size"="1,{requiredMaxThreadsPerBlock}"', '"no-trapping-math"="true"']
+    return 'attributes #0 = { ' + ' '.join(attributes) + ' }'
+  def __init__(self, arch:str):
+    self.arch = arch
+    self.tensor_cores = AMDRenderer.get_tensor_cores(arch)
+    self.string_rewrite += PatternMatcher([(UPat(Ops.WMMA, name="wmma"), lambda ctx, wmma, arch=arch: render_wmma_amd(ctx, wmma, arch))])
+    if self.arch.split(":")[0] == "gfx1100":
+      self.extra_matcher += PatternMatcher([
+        (UPat(Ops.WMMA, name="x", dtype=dtypes.half.vec(8)),
+         lambda x: UOp(Ops.WMMA, dtypes.half.vec(16), (x.src[0], x.src[1], x.src[2].cast(dtypes.half.vec(16))), (*x.arg,)).cast(dtypes.half.vec(8))),
+        (UPat(Ops.WMMA, name="x"), lambda x: UOp(Ops.WMMA, x.dtype, (x.src[0].bitcast(dtypes.uint16.vec(16)), x.src[1].bitcast(dtypes.uint16.vec(16)),
+          x.src[2]), x.arg) if x.src[0].dtype == dtypes.bfloat16.vec(16) else None),
+      ])
+    if self.arch.split(":")[0] == "gfx1201":
+      self.extra_matcher += PatternMatcher([
+        (UPat(Ops.WMMA, name="x", dtype=dtypes.bfloat16.vec(8)), lambda x: UOp(Ops.WMMA, dtypes.uint16.vec(8),
+          (x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)), x.src[2].bitcast(dtypes.uint16.vec(8))), (*x.arg,))
+          .bitcast(dtypes.bfloat16.vec(8)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None),
+        (UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(8)),
+         lambda x: UOp(Ops.WMMA, dtypes.float.vec(8), (x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)),
+           x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None)
+      ])
+  def __reduce__(self): return self.__class__, (self.arch,)
```
tinygrad/renderer/ptx.py
CHANGED
```diff
@@ -1,11 +1,12 @@
 from typing import cast, Callable
 import struct
 from collections import defaultdict
-from tinygrad.…
-from tinygrad.…
+from tinygrad.codegen.opt import tc
+from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
+from tinygrad.dtype import dtypes, DType, PtrDType, AddrSpace
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer
-from tinygrad.helpers import flatten, get_single_element
+from tinygrad.helpers import flatten, get_single_element, prod
 
 def render_val(x, dtype):
   if dtypes.is_float(dtype):
@@ -18,6 +19,7 @@ asm_for_op: dict[Ops, Callable] = {
   Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
   Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
   Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
+  Ops.TRUNC: lambda d,a,dt,name: f"cvt.rzi.{name}.{name} {d}, {a};",
   Ops.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", Ops.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
   Ops.ADD: lambda d,a,b,dt,name: f"{'or' if dt == dtypes.bool else 'add'}.{name} {d}, {a}, {b};",
   Ops.MUL: lambda d,a,b,dt,name: f"{'and' if dt == dtypes.bool else 'mul'}{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
@@ -25,18 +27,19 @@ asm_for_op: dict[Ops, Callable] = {
   Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if dt == dtypes.bool else f"and.b{name[1:]} {d}, {a}, {b};",
   Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if dt == dtypes.bool else f"or.b{name[1:]} {d}, {a}, {b};",
   Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
-  Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};",
+  Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", Ops.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};",
   Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
   Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
   Ops.WHERE: lambda d,a,b,c,dt,name: [f"@{a} mov.{name} {d}, {b};", f"@!{a} mov.{name} {d}, {c};"] if dt == dtypes.bool else \
     f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
 }
 
-supports_half = (Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE)
+supports_half = (Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE, Ops.TRUNC)
 doesnt_support_half: tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
 ptx_matcher = PatternMatcher([
   # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
   (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
+  (UPat.var('x', dtype=dtypes.bool).alu(Ops.CMPEQ, UPat.var('y')), lambda x,y: (x^y)^True),
   (UPat.var('x', dtype=dtypes.bool)<UPat.var('y'), lambda x,y: (x^True)&y),
   # upcast to float32 all the ops that don't support half
   (UPat(doesnt_support_half, dtype=dtypes.half, name="x"),
@@ -47,14 +50,20 @@ ptx_matcher = PatternMatcher([
   (UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
    lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
   # load/store use pointer arithmetic, and the cast does nothing
-  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))),…
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))),
+   lambda buf,idx: (buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize) if buf.dtype.addrspace != AddrSpace.REG else None),
+  (UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) or x.src[0].dtype == dtypes.void else None),
+  # move mask from INDEX to the load/store to enable pointer arithmetic
+  (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("gate"))), UPat.var("alt"))),
+   lambda buf,idx,gate,alt: UOp(Ops.LOAD, alt.dtype, (buf.index(idx), alt, gate))),
+  (UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat())), UPat.var("val"), UPat.var("gate")), allow_any_len=True),
+   lambda buf,idx,val,gate: UOp.store(buf.index(idx), val, gate)),
   # ptx shr and shl instructions require y to be uint
   (UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
   (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
 ])
 
-def mem_type(x: UOp): return 'shared' if any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].toposort) else 'global'
+def mem_type(x: UOp): return 'shared' if any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].toposort()) else 'global'
 
 def render_wmma(ctx: "PTXRenderer", wmma: UOp):
   assert ctx.wmma_r, "registry values for wmma must be populated"
@@ -84,7 +93,7 @@ string_rewrite = PatternMatcher([
    f"[{ctx.r[bidx]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
   (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),
   (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
-  (UPat((Ops.CMPLT, Ops.CMPNE), name="x", allow_any_len=True, src=(UPat.var("src0"),)),
+  (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), name="x", allow_any_len=True, src=(UPat.var("src0"),)),
    lambda ctx, x, src0: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], src0.dtype, ctx.types[src0.dtype])),
   (UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
   (UPat(Ops.BITCAST, name="x", src=(UPat.var("a"),), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
@@ -103,19 +112,14 @@ string_rewrite = PatternMatcher([
   (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'),), allow_any_len=True),
    lambda ctx, x, loc: f"ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
      if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
-  (UPat(Ops.…
-  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat.cvar("pred"),), allow_any_len=True),
-   lambda ctx, x, pred: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(pred.arg, x.dtype)};"),
-  (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, {ctx.r[x.src[0]]};", "LOOP_" + f"{ctx.r[x][1:]}:"]),
-  (UPat(Ops.ASSIGN, name="x", dtype=dtypes.bool), lambda ctx, x: [f"mov.pred {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"]),
-  (UPat(Ops.ASSIGN, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"),
+  (UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
+  (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, 0;", "LOOP_" + f"{ctx.r[x][1:]}:"]),
   (UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
     ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
-    ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[…
+    ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]),
     f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
   (UPat(Ops.DEFINE_LOCAL, name="x"),
-   lambda ctx, x: [f".shared .align …
+   lambda ctx, x: [f".shared .align 16 .b8 {x.arg}[{x.dtype.size*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg}[0];"]),
   (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
   (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
   (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
@@ -127,12 +131,12 @@ class PTXRenderer(Renderer):
   device = "CUDA"
   suffix = "PTX"
   global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
-  tc_sm80 = …
+  tc_sm80 = [x for x in tc.cuda_sm80 if x.dtype_in in [dtypes.half, dtypes.float]]
   code_for_op = asm_for_op
   extra_matcher = ptx_matcher
   def __init__(self, arch:str, device="CUDA"):
     self.device, self.arch = device, arch
-    self.tensor_cores = PTXRenderer.tc_sm80 if int(arch[3:]) >= 80 else …
+    self.tensor_cores = PTXRenderer.tc_sm80 if int(arch[3:]) >= 80 else tc.cuda_sm75 if int(arch[3:]) >= 75 else []
   def __reduce__(self): return self.__class__, (self.arch, self.device)
 
 # language options
@@ -148,11 +152,12 @@ class PTXRenderer(Renderer):
 
   mem_types: dict[DType, str] = {**types, dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"}
 
-  def render_kernel(self, kernel, function_name, bufs, regs) -> str:
+  def render_kernel(self, kernel, function_name, bufs, regs, uops) -> str:
     def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
     kernel = '\n'.join(map(fmt, [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]))
+    launch_bounds = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
     params = ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs])
-    return f"{self.kernel_prefix} {function_name}(\n\t{params}\n)\n{{\n{kernel}\n}}"
+    return f"{self.kernel_prefix.format(launch_bounds=launch_bounds)} {function_name} (\n\t{params}\n)\n.maxntid {launch_bounds}\n{{\n{kernel}\n}}"
 
   def render(self, uops:list[UOp]) -> str:
     kernel:list[str] = []
@@ -165,14 +170,15 @@ class PTXRenderer(Renderer):
 
     def ssa(prefix:str, u:UOp|None=None, dtype:str|None=None) -> str:
       nonlocal c, r
-      prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype]}_"
+      prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype.base]}_"
       c[prefix] += 1
       return f"%{prefix}{c[prefix]-1}"
 
     name = "test"
     for u in uops:
-      if u.op is Ops.…
-…
+      if u.op is Ops.NOOP: continue
+      if u.op is Ops.SINK:
+        if u.arg is not None: name = u.arg.function_name
         continue
       if u.op is Ops.VECTORIZE:
         r[u] = [cast(str,r[x]) for x in u.src]
@@ -183,6 +189,19 @@ class PTXRenderer(Renderer):
       if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)):
         r[u] = r[u.src[0]]
         continue
+      if u.op is Ops.DEFINE_REG:
+        r[u] = [ssa("reg", u, self.types[u.dtype.base.scalar()]) for _ in range(cast(PtrDType, u.dtype).size)]
+        continue
+      if u.op in {Ops.INDEX, Ops.LOAD, Ops.STORE} and isinstance(u.src[0].dtype, PtrDType) and u.src[0].dtype.addrspace == AddrSpace.REG:
+        if u.op is Ops.INDEX:
+          assert u.src[1].op == Ops.CONST, f"index on REG in ptx only supported on CONST, not {u.src[1].op}"
+          r[u] = r[u.src[0]][u.src[1].arg]
+        else:
+          r[u] = r[u.src[0]]
+          if u.op is Ops.STORE:
+            typ = "pred" if u.src[1].dtype == dtypes.bool else ("b"+self.types[u.src[1].dtype][1:])
+            kernel.append(f"mov.{typ} {self.r[u.src[0]]}, {self.r[u.src[1]]};")
+        continue
       if u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
       elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype))
       elif u.op is Ops.LOAD:
@@ -191,12 +210,12 @@ class PTXRenderer(Renderer):
       elif u.op is Ops.DEFINE_GLOBAL: bufs.append((f"data{u.arg}", u.dtype))
       elif u.op is Ops.WMMA:
         # registers for packing/unpacking input and acc
-        self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.…
-          [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.…
-          [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.…
+        self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.src[0].dtype.scalar().itemsize)],
+          [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)],
+          [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
         r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
       prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None),
-        Ops.…
+        Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL:("local",self.types[dtypes.ulong]),
         Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
       if prefix: r[u] = ssa(prefix, u, dtype)
@@ -204,6 +223,5 @@ class PTXRenderer(Renderer):
         raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
       kernel.extend([l] if isinstance(l, str) else l)
 
-      if u.op is Ops.…
-…
-    return self.render_kernel(kernel, name, bufs, c.items())
+      if u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel
+    return self.render_kernel(kernel, name, bufs, c.items(), uops)
```
tinygrad/renderer/wgsl.py
CHANGED
```diff
@@ -1,8 +1,7 @@
-from tinygrad.dtype import DType, PtrDType, dtypes
-from tinygrad.ops import UOp, Ops, PatternMatcher, UPat
+from tinygrad.dtype import DType, PtrDType, dtypes, AddrSpace
+from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
 from tinygrad.renderer.cstyle import CStyleLanguage, base_rewrite, extra_pm
 from tinygrad.helpers import strip_parens
-import math
 
 def sign_extend(val:UOp, sext_am:int):
   return (UOp.where((val >> (sext_am - 1)) > 0, UOp.const(dtypes.uint32, 0xffffffff) << sext_am, UOp.const(dtypes.uint32, 0)) \
@@ -20,24 +19,25 @@ def packed_store(bidx:UOp, var:UOp):
 def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
   div_idx = bidx.src[1]//(4//dtype.itemsize)
   shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//dtype.itemsize))*UOp.const(dtypes.uint32, 8*dtype.itemsize)
-  if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx…
+  if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx, bidx.src[2])), var, dtype=dtypes.uint32, arg=root.arg)
   else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=dtypes.uint32, arg=root.arg)
   val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF)
   return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
 
-def is_packed(dt:DType) -> bool:…
+def is_packed(dt:DType, odt:DType|None = None) -> bool:
+  if odt is None: odt = dt
+  return dt.itemsize < 4 and dt.base != dtypes.half and (not isinstance(odt, PtrDType) or odt.addrspace != AddrSpace.REG)
 
 wgsl_matcher = PatternMatcher([
   (UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
    lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
-  (UPat…
-  (UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True),…
-  (UPat…
-  (UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None)
+  (UPat.load(UPat.var("b"), UPat.cvar("c"), name="l"),
+   lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype, b.dtype) else None),
+  (UPat.load(UPat.var("b"), name='l', allow_any_len=True), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
+  (UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True),
+   lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype, bidx.dtype) else None),
+  (UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None),
+  (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
 ]) + extra_pm
@@ -54,23 +54,27 @@ class WGSLRenderer(CStyleLanguage):
     dtypes.char: "i32", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool", dtypes.half: "f16" }
 
   string_rewrite = PatternMatcher([
-    (UPat(…
-    (UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),…
-      …if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
-    (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x:…
+    (UPat.cvar("x", dtype=dtypes.bool), lambda x: "true" if x.arg else "false"),
+    (UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),
+     lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
+    (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x:
+     f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{x.dtype.size//(4//x.dtype.itemsize) if is_packed(x.dtype) else x.dtype.size}>;"),
+    (UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:
+     f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{x.dtype.size//(4//x.dtype.itemsize) if is_packed(x.dtype) else x.dtype.size}>;"),
+    (UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
+     lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"),
     (UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
     (UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"),lambda ctx,x:f"bitcast<{ctx.type_map[x.dtype]}>(vec2<f16>({ctx[x.src[0]]},0))" \
       if x.src[0].dtype == dtypes.half else f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
     (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
-    (UPat.load(UPat.var("b"),UPat.…
-    (UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.…
-    (UPat.…
-    (UPat.store(UPat.var('b'), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
+    (UPat.load(UPat.var("b"), UPat.cvar("v")),lambda ctx,b,v: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[b.src[2]]})"),
+    (UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
+    (UPat.store(UPat.var("b"), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
      # (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
      f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
      else f"{ctx[b]} = {ctx[v]};"),
+    (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx")), allow_any_len=True),
+     lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
     # fix nan check: 'a != a -> is_nan()'
     (UPat.var("a") != UPat.var("a"), lambda ctx,a: f"(min({ctx[a]}, 1.0) == 1.0 && max({ctx[a]}, -1.0) == -1.0)"),
   ]) + base_rewrite
```