tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
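Most of the churn in this list is a package reorganization: the UOp machinery moves from the top level (`tinygrad/ops.py`, `tinygrad/spec.py`, `tinygrad/codegen/symbolic.py`) into `tinygrad/uop/`, kernel optimization moves from `tinygrad/engine/search.py` and `tinygrad/codegen/kernel.py` into `tinygrad/codegen/opt/`, and scheduling moves into `tinygrad/schedule/`. A hedged sketch of what this implies for imports; the 0.11.0 paths below are the ones visible in the cstyle.py diff further down, and anything else should be verified against the 0.11.0 sources:

```python
# Hedged sketch of the import moves implied by the file list above.

# tinygrad 0.10.2
# from tinygrad.ops import UOp, Ops, UPat, PatternMatcher

# tinygrad 0.11.0
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher
from tinygrad.codegen.opt import tc  # new home of the tensor-core definitions
```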
tinygrad/renderer/cstyle.py
CHANGED
@@ -1,31 +1,31 @@
-from typing import
+from typing import Literal, Callable, cast
 import os, math, sys
 from collections import defaultdict, Counter
-from tinygrad.
+from tinygrad.codegen.opt import tc
+from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
-from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
-from tinygrad.renderer import Renderer
+from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate
+from tinygrad.renderer import Renderer
 from tinygrad.codegen.devectorizer import no_vectorized_alu

 base_rewrite = PatternMatcher([
-(UPat(Ops.
-(UPat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
+(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
 (UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
 (UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
 (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
 # r method accesses
 (UPat(Ops.RANGE, name="x"),
-lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} =
+lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = 0; {ctx[x]} < {ctx[x.src[0]]}; {ctx[x]}++) {{"),
 (UPat(Ops.VECTORIZE, name="x"),
 lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
-
+f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
 (UPat(Ops.CAST, name="x"), lambda ctx,x:
 f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
 (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
 (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
 (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
 (UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
-(UPat(Ops.
+(UPat(Ops.PRECAST, name="x"), lambda ctx,x: ctx[x.src[0]]),
 (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
 # const
 (UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"),
@@ -33,39 +33,38 @@ base_rewrite = PatternMatcher([
 (UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx.nan)})" if math.isnan(x.arg) else None),
 (UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
 (UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
-(UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
-(UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
+(UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}ull"),
+(UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}u"),
 (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
 # consts are rendered to larger type and casted
 (UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
 (UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u')})"),
-(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, x.arg)})"),
+(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
 # default const render
 (UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
 # new load/store
-(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
+(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
 lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
-(UPat(Ops.LOAD, src=(UPat.
-
+(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var")), allow_any_len=True),
+lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
+(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"(*{ctx[bidx]})"),
 (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
 # alu/gep
+# TODO: look for left-associative
 (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
-*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
+*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR, Ops.OR, Ops.AND} else ctx[v] for v in x.src]), x.dtype)),
 (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
-(f"[{x.arg[0]}]" if x.src[0].dtype.count >
-f".{'xyzwabcd'[x.arg[0]]}")),
+(f"[{x.arg[0]}]" if x.src[0].dtype.count > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
 # custom passes through with format
-(UPat(Ops.CUSTOM, name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
+(UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
 ])

 extra_pm = PatternMatcher([
-# insert a
-(UPat(Ops.BITCAST, name="x"),
-
-# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
-(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+# insert a PRECAST before BITCAST to force it to be rendered. not needed on all backends?
+(UPat(Ops.BITCAST, name="x"), lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.PRECAST, x.src[0].dtype, x.src),))
+if x.src[0].op not in {Ops.PRECAST, Ops.LOAD, Ops.CUSTOM} else None),
 # devectorize any bools
-(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.
+(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
 # CAST (from bool) can't be vectorized
 (UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu),
 # WHERE can't be vectorized
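As context for the rewrite tables above: each entry pairs a `UPat` pattern with a lambda that returns the C-source string for a matching `UOp`, and the `_render` loop further down applies them via `string_rewrite.rewrite(u, ctx=self)`. A minimal sketch of that mechanism, assuming the 0.11.0 module paths and the `ctx` keyword used in this file:

```python
from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher
from tinygrad.dtype import dtypes

# one rule, equivalent to base_rewrite's default CONST renderer
pm = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda ctx, x: str(x.arg))])
u = UOp(Ops.CONST, dtypes.int, arg=3)
print(pm.rewrite(u, ctx=None))  # expected: "3"
```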
@@ -74,8 +73,12 @@ extra_pm = PatternMatcher([

 def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))

+# (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
+def wmma_args(uops:list[UOp]):
+return dedup((uop.arg[0], uop.arg[1], uop.src[0].dtype.scalar(), uop.dtype.scalar(), *(uop.arg[4:8])) for uop in uops if uop.op is Ops.WMMA)
+
 class CStyleLanguage(Renderer):
-
+kernel_typedef: str = "void"
 buffer_prefix: str = ""
 buffer_suffix: str = ""
 smem_align: str = ""
@@ -83,30 +86,33 @@ class CStyleLanguage(Renderer):
 smem_prefix_for_cast: bool = True
 arg_int_prefix: str = "const int"
 barrier: str = ""
-code_for_workitem: dict[
+code_for_workitem: dict[Literal["g", "l", "i"], Callable] = {}
 extra_args: list[str] = []
-float4:
+float4: str|None = None
+float4_style: tuple[str, str] = ('(', ')')
+gep_arr_threshold: int = 4
 type_map: dict[DType, str] = {}
 infinity: str = "INFINITY"
 nan: str = "NAN"
 code_for_op: dict = {
 Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
 Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
+Ops.TRUNC: lambda x,dtype: f"trunc({x})",
 Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
 Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
 Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
 Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})",
-Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" }
+Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})", Ops.CMPEQ: lambda a,b,dtype: f"({a}=={b})"}

 string_rewrite = base_rewrite
 extra_matcher = extra_pm

-def get_kernel_modifier(self, uops:list[UOp]) -> str: return ""
 def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
 tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
 buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
 self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
-
+launch_bounds = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
+prg = ''.join([f"{self.kernel_typedef.format(launch_bounds=launch_bounds)} {function_name}(",] +
 [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
 [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
 return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
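Note how the old `get_kernel_modifier` hook is folded into `kernel_typedef`: `render_kernel` now computes `launch_bounds` as the product of the local SPECIAL sizes and formats it into the typedef string. A small illustration of just the string mechanics, using the CUDA typedef that appears later in this diff and an assumed 16×16 local size:

```python
import math

# illustration only; in render_kernel the local dims come from Ops.SPECIAL uops
kernel_typedef = 'extern "C" __global__ void __launch_bounds__({launch_bounds})'
local_sizes = [16, 16]  # assumed "l" SPECIAL sizes for this example
prg_head = kernel_typedef.format(launch_bounds=math.prod(local_sizes))
print(prg_head)  # extern "C" __global__ void __launch_bounds__(256)
```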
@@ -115,12 +121,15 @@ class CStyleLanguage(Renderer):
 def render_dtype(self, dt:DType, mutable=True) -> str:
 if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
 if isinstance(dt, PtrDType):
-
+prefix = ""
+if dt.addrspace == AddrSpace.LOCAL and self.smem_prefix_for_cast: prefix = self.smem_prefix
+if dt.addrspace == AddrSpace.GLOBAL: prefix = self.buffer_prefix
+return prefix + self.render_dtype(dt.base) + "*"
 if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
 return self.type_map.get(scalar:=dt.scalar(), scalar.name)

 def __getitem__(self, key): return self.r[key] # hacky helper
-def
+def _render(self, uops:list[UOp]) -> tuple[str, list[str], list[tuple[str,tuple[DType,bool]]]]:
 r: dict[UOp, str] = {}
 self.r = r

@@ -131,98 +140,107 @@ class CStyleLanguage(Renderer):
 c: defaultdict[str, int] = defaultdict(int)
 name = "test"
 for u in uops:
-if u.op is Ops.
-
+if u.op is Ops.NOOP: continue
+if u.op is Ops.SINK:
+if u.arg is not None: name = u.arg.function_name
 continue
 if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
-r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
+r[u] = (f"data{u.arg}_{sz}" if (sz:=cast(PtrDType, u.dtype).size) > 0 else f"data{u.arg}") if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
 bufs[u] = (r[u], (u.dtype, False))
 continue

 # mark buffers that we store to writable
 if u.op is Ops.STORE:
-for up in u.src[0].toposort:
+for up in u.src[0].toposort():
 if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))

 # naming
 prefix = None
-if u.op is Ops.SPECIAL:
-
+if u.op is Ops.SPECIAL: r[u] = u.arg[0]
+elif u.op is Ops.RANGE: r[u] = f"ridx{u.arg}"
 else:
-prefix = {Ops.
-Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.
-Ops.INDEX: "bidx", Ops.
+prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
+Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast",
+Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
 r[u] = f"{prefix}{c[prefix]}"

 l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
 assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"

 if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
-if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.
-(u.op
+if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
+(u.op is Ops.LOAD and cast(PtrDType, u.src[0].dtype).addrspace == AddrSpace.REG) or \
+(u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
+(u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
 r[u] = l
 else:
-if u.op in {Ops.RANGE, Ops.
-
-else:
-l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
+if u.op in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} or u.dtype == dtypes.void: pass
+else: l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
 kernel.append(" "*depth + l)
 if prefix: c[prefix] += 1 # if it was used, increment
 if u.op in {Ops.IF, Ops.RANGE}: depth += 1
 del self.r

 # NOTE: this relies on bufs dict preserving order
-return
+return (name, kernel, list(bufs.values()))
+def render(self, uops:list[UOp]) -> str: return self.render_kernel(*self._render(uops), uops)

 class ClangRenderer(CStyleLanguage):
 device = "CPU"
 float4 = "(float4)"
+float4_style = ('{', '}')
+gep_arr_threshold = 0
 has_local = False
 global_max = None
 infinity = "__builtin_inff()"
 nan = '__builtin_nanf("")'
-
-opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
-if AMX: tensor_cores = amx_tc
+if AMX: tensor_cores = tc.amx

 # language options
 buffer_suffix = " restrict"
 type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
-code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
-Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"
+code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.TRUNC]}),
+Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})",
+Ops.TRUNC: lambda x,dtype: f"__builtin_trunc({x})" if dtype == dtypes.float64 else f"__builtin_truncf({x})"}
 # LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
-extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))
-CStyleLanguage.extra_matcher
+extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16)),
+(UPat((Ops.SQRT, Ops.TRUNC), name="alu"), no_vectorized_alu),]) + CStyleLanguage.extra_matcher

 if sys.platform == 'win32':
-
+kernel_typedef = "__attribute__((ms_abi)) void"
 def render_vector_prefix(self, dt:DType) -> str:
-# round (down) to power of two
-alignment = 2**int(math.log2(dt.itemsize))
+# round (down) to power of two (this is actually the default clang behavior)
+alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) else 1
 return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),vector_size({dt.itemsize})));"

-def
+def _render_defines(self, uops) -> list[str]:
 prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
 # https://github.com/corsix/amx
-for name, (N, M, _), dtype_in, _, _, _, _, _ in
+for name, (N, M, _), dtype_in, _, _, _, _, _ in wmma_args(uops):
 prefix += [
 '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
 '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
 ]
 # 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
-# to just jump at the start of a shellcode
+# to just jump at the start of a shellcode without having to deal with symbols or trampolines at all. This is better than having to inline
 # wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
 prefix += [f"""static {(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
 AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
 AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
 for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501
-return
+return prefix
+def _render_body(self, function_name, kernel, bufs, uops, pref=None) -> str: return super().render_kernel(function_name, kernel, bufs, uops, pref)
+def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str: return ""
+
+def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+defines = '\n'.join(self._render_defines(uops))
+return defines + "\n" + self._render_body(function_name, kernel, bufs, uops, prefix) + "\n" + self._render_entry(function_name, bufs)

 class OpenCLRenderer(CStyleLanguage):
 device = "GPU"

 # language options
-
+kernel_typedef = "__kernel void"
 buffer_prefix = "__global "
 smem_align = "__attribute__ ((aligned (16))) "
 smem_prefix = "__local "
@@ -235,7 +253,7 @@ class OpenCLRenderer(CStyleLanguage):
 string_rewrite = PatternMatcher([
 (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
 # load/store image (OpenCL)
-(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))
+(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
 lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
 (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
 lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
@@ -248,35 +266,31 @@ class OpenCLRenderer(CStyleLanguage):
 return super().render_kernel(function_name, kernel, bufs, uops, prefix)

 class IntelRenderer(OpenCLRenderer):
-device, suffix,
-tensor_cores =
-opts=("l0","l0","l0","u1","u1","u1"), swizzle=(((4,5,6),(0,1,2,3,7,8,9)), ((0,1,2),(7,8,9,3,4,5,6))))]
+device, suffix, kernel_typedef = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void"
+tensor_cores = tc.intel

 string_rewrite = PatternMatcher([
-(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x]})"),
-(UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x]})"),
+(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float),)), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x]})"),
+(UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16),)), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x]})"),
 ]) + OpenCLRenderer.string_rewrite

 def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
 prefix = []
-for
-dt_in = ("ushort", "bf16") if
-prefix.append(f"""{
+for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops):
+dt_in = ("ushort", "bf16") if dtype_in == dtypes.bfloat16 else (dtype_in.name, "f16")
+prefix.append(f"""{dtype_out.name}8 __{name}({dt_in[0]}16 a, {dt_in[0]}16 b, {dtype_out.name}8 c) {{
 return intel_sub_group_{dt_in[1]}_{dt_in[1]}_matrix_mad_k16(as_int8(a), as_int8(b), c);\n}}""")
 return super().render_kernel(function_name, kernel, bufs, uops, prefix or None)

 class MetalRenderer(CStyleLanguage):
 device = "METAL"
 shared_max = 32768
-tensor_cores =
-swizzle=(((6,1,2,7,4),(8,0,3,5)), ((0,5,6,3,7),(1,2,4,8)))) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
-(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
-def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if hasattr(os, 'uname') and os.uname().machine == "arm64" else []
+def __init__(self): self.tensor_cores = tc.metal if hasattr(os, 'uname') and os.uname().machine == "arm64" else []

 # language options
-
+kernel_typedef = "kernel void"
 buffer_prefix = "device "
-smem_prefix = "threadgroup "
+smem_prefix = "threadgroup __attribute__((aligned(16))) "
 arg_int_prefix = "constant int&"
 barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
 float4 = "float4"
@@ -300,45 +314,35 @@ class MetalRenderer(CStyleLanguage):
 ]) + base_rewrite

 def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
-prefix
-for
-f"""{(
-simdgroup_{self.render_dtype(
+prefix = ["#include <metal_stdlib>","using namespace metal;"]
+for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): prefix.append(
+f"""{(dstr_out:=self.render_dtype(dtype_out.vec(2)))} __{name}({(dstr_in:=self.render_dtype(dtype_in.vec(2)))} a, {dstr_in} b, {dstr_out} c){{
+simdgroup_{self.render_dtype(dtype_in)}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(dtype_out)}8x8 mat_c;
 mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0];
 mat_a.thread_elements()[1] = a[1]; mat_b.thread_elements()[1] = b[1]; mat_c.thread_elements()[1] = c[1];
-simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {
+simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dstr_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
 return super().render_kernel(function_name, kernel, bufs, uops, prefix)

 _nms = "xyzwabcdefghijkl"
-cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1") # shared by all shapes with M=16 N=8

 class CUDARenderer(CStyleLanguage):
 device = "CUDA"
 global_max = (2147483647, 65535, 65535)
 local_max = (1024, 1024, 64)
 shared_max = 49152
-
-tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
-swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float),
-(dtypes.half,dtypes.half)]]
-tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
-swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
-tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
-swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]
-
-tc_sm80 = tc_81616 + tc_8168_f16
-if getenv("ALLOW_TF32", 0): tc_sm80 += tc_8168_tf32
-tc_sm75 = tc_8168_f16
+
 def __init__(self, arch:str):
-self.tensor_cores, self.arch =
+self.tensor_cores, self.arch = tc.cuda_sm80 if int(arch[3:]) >= 80 else tc.cuda_sm75 if int(arch[3:]) >= 75 else [], arch
 def __reduce__(self): return self.__class__, (self.arch,)

 # language options
-
-
+# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
+kernel_typedef = "extern \"C\" __global__ void __launch_bounds__({launch_bounds})"
+smem_prefix = "__shared__ __align__(16) "
 smem_prefix_for_cast = False
 barrier = "__syncthreads();"
 float4 = "make_float4"
+gep_arr_threshold = 8
 code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
 "i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"}
 code_for_op = { **CStyleLanguage.code_for_op,
@@ -365,7 +369,7 @@ class CUDARenderer(CStyleLanguage):

 dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
 dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
-for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in
+for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in wmma_args(uops):
 upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
 wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
 n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)] # 4 => CUDA reg size in bytes
@@ -383,11 +387,6 @@ class CUDARenderer(CStyleLanguage):

 return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)

-def get_kernel_modifier(self, uops:list[UOp]) -> str:
-maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
-# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
-return f"__launch_bounds__({maxThreadsPerBlock}) "
-
 def cast_float_to_bf16(x: UOp) -> UOp:
 assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
 x = x.bitcast(dtypes.uint)
@@ -397,27 +396,40 @@ def cast_float_to_bf16(x: UOp) -> UOp:
 class AMDRenderer(CStyleLanguage):
 device = "AMD"
 shared_max = 65536
-#
-
-
-
+# NOTE: this is only really needed on gfx12, even though gfx11 reports the same limitation
+global_max = (2147483647, 65535, 65535)
+
+@staticmethod
+def get_tensor_cores(arch):
+return {"gfx942": tc.amd_cdna, "gfx950": tc.amd_cdna, "gfx1200": tc.amd_rdna4, "gfx1201": tc.amd_rdna4}.get(arch.split(":")[0], tc.amd_rdna3)
+def __init__(self, arch:str): # gfx942 => MI300, gfx1100 => RX 7900, gfx1201 => RX 9700
+self.arch = arch
+self.tensor_cores = self.get_tensor_cores(arch)
+if self.tensor_cores == tc.amd_cdna:
+self.string_rewrite = PatternMatcher([
+(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]}, 0, 0, 0)")]) + base_rewrite
+def __reduce__(self): return self.__class__, (self.arch,)

 # language options
 ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
 ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
 for dt, n in [(dtype.name, dtype.itemsize * 8) for dtype in [dtypes.float, dtypes.double, dtypes.half]]
-for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]
+for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", ""), ("trunc", "")]]

-
-
+kernel_typedef = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
+# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
+# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
+kernel_typedef += '\nextern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, {launch_bounds})))'
 code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
 "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
 code_for_op = { **CStyleLanguage.code_for_op,
+Ops.TRUNC: lambda x,dtype: f"__ocml_trunc_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
 Ops.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
 Ops.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
 Ops.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
 Ops.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})" }
-smem_prefix = "__attribute__((shared))"
+smem_prefix = "__attribute__((shared, aligned(16)))"
+smem_prefix_for_cast: bool = False
 barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
 '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
 float4 = "make_float4"
@@ -431,12 +443,15 @@ class AMDRenderer(CStyleLanguage):
 (UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
 lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
 # add float intermediate casting for bfloat16
-(UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)
-
+(UPat(Ops.CAST, name="x", src=(UPat.var("y", dtypes.bfloat16),)),
+lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
+(UPat(Ops.CAST, dtypes.bfloat16, (UPat.var("x"),)),
+lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
 # bfloat16 casting
 (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
-(UPat(Ops.CAST, dtypes.float, UPat.var("x", dtypes.bfloat16)
-
+(UPat(Ops.CAST, dtypes.float, (UPat.var("x", dtypes.bfloat16),)),
+lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
+(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var("x", dtype=dtypes.float),)), cast_float_to_bf16)]) + extra_pm

 def render_vector_prefix(self, dtype:DType) -> str:
 vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
@@ -445,25 +460,25 @@ class AMDRenderer(CStyleLanguage):

 def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
 prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
-
+type_map = { dtypes.bfloat16: "bf16", dtypes.float: "f32", dtypes.half: "f16" }
 used_dtypes = uops_to_dtypes(uops)
 if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
 prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]

-for
-if
-
+for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+if self.tensor_cores == tc.amd_cdna:
+prefix.append(f"#define __{name} __builtin_amdgcn_mfma_f32_16x16x16{'f16' if dtype_in == dtypes.half else 'bf16_1k'}")
+# #define __WMMA_16_16_16_half_half __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12
+elif self.tensor_cores == tc.amd_rdna4:
+prefix.append(f"#define __{name} __builtin_amdgcn_wmma_{type_map[dtype_out]}_16x16x16_{type_map[dtype_in]}_w32_gfx12")
+elif dtype_out == dtypes.float:
+prefix.append(f"#define __{name} __builtin_amdgcn_wmma_f32_16x16x16_{'f16' if dtype_in == dtypes.half else 'bf16'}_w32")
+else: prefix.append(f"static inline __attribute__((device)) half8 __{name}"+"""(half16 a, half16 b, half8 c) {
 half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
 c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
 for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
 return super().render_kernel(function_name, kernel, bufs, uops, prefix)

-def get_kernel_modifier(self, uops:list[UOp]) -> str:
-requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
-# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
-# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
-return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
-
 class NVRenderer(CUDARenderer): device = "NV"
 class HIPRenderer(AMDRenderer): device = "HIP"
 class QCOMRenderer(OpenCLRenderer): device = "QCOM"
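One closing illustration of the AMD change above: tensor-core tables are now chosen per architecture at construction time instead of being fixed on the class. A standalone, hypothetical mirror of the `get_tensor_cores` lookup (the real tables live in `tinygrad.codegen.opt.tc`; strings stand in for them here):

```python
# hypothetical mirror of AMDRenderer.get_tensor_cores; strings stand in for the
# tc.amd_cdna / tc.amd_rdna4 / tc.amd_rdna3 tables from tinygrad.codegen.opt.tc
def pick_amd_tensor_cores(arch: str) -> str:
    tables = {"gfx942": "amd_cdna", "gfx950": "amd_cdna",
              "gfx1200": "amd_rdna4", "gfx1201": "amd_rdna4"}
    # "gfx942:sramecc+:xnack-" -> "gfx942"; anything else falls back to RDNA3
    return tables.get(arch.split(":")[0], "amd_rdna3")

assert pick_amd_tensor_cores("gfx1100") == "amd_rdna3"         # RX 7900 class
assert pick_amd_tensor_cores("gfx942:sramecc+") == "amd_cdna"  # MI300 class
```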