tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/renderer/cstyle.py
CHANGED
@@ -1,12 +1,71 @@
-from
+from __future__ import annotations
+from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
 import os, math
 from collections import defaultdict, Counter
-from tinygrad.ops import
-from tinygrad.helpers import strip_parens, getenv, prod, dedup
-from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
-from tinygrad.codegen.uops import UOps, UOp, UOpGraph
+from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
+from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
+from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
 from tinygrad.renderer import Renderer, TensorCore

+base_rewrite = PatternMatcher([
+  (UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
+  (UPat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
+  (UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
+  (UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
+  (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
+  # r method accesses
+  (UPat(Ops.RANGE, name="x"),
+   lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
+  (UPat(Ops.VECTORIZE, name="x"),
+   lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
+    (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),
+  (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
+  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
+  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"),
+  (UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
+  (UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
+  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
+  # const
+  (UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"),
+  (UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, f'-{ctx.infinity}')})"),
+  (UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx.nan)})" if math.isnan(x.arg) else None),
+  (UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
+  (UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
+  (UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
+  (UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
+  (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
+  # consts are rendered to larger type and casted
+  (UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
+  (UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u')})"),
+  (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, x.arg)})"),
+  # default const render
+  (UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
+  # new load/store
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
+   lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
+  (UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
+  (UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
+  (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
+  # alu/gep
+  (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
+    *([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
+  (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
+    (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
+])
+
+extra_pm = PatternMatcher([
+  # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
+  (UPat(Ops.BITCAST, name="x"),
+   lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
+  # gate any stores that aren't gated with ifs
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+   lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
+  # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
+  (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+])
+
+def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
+
 class CStyleLanguage(Renderer):
   kernel_prefix: str = ""
   buffer_prefix: str = ""
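Note (illustrative, not part of the published diff): the new base_rewrite table pairs each UPat with a lambda that receives the renderer as ctx and returns the C fragment for one UOp; ctx[u] looks up the string already rendered for a source UOp. A minimal sketch of that mechanism, using only the tinygrad.ops names imported above (exact 0.10.0 internals may differ slightly):

# sketch: render a float ADD the way base_rewrite does
from tinygrad.ops import Ops, UOp, UPat, PatternMatcher
from tinygrad.dtype import dtypes
pm = PatternMatcher([
  (UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
  (UPat(Ops.ADD, name="x"), lambda ctx,x: f"({ctx[x.src[0]]}+{ctx[x.src[1]]})"),
])
a, b = UOp.const(dtypes.float, 1.0), UOp.const(dtypes.float, 2.0)
add = UOp(Ops.ADD, dtypes.float, (a, b))
rendered = {a: pm.rewrite(a, ctx={}), b: pm.rewrite(b, ctx={})}  # {a: "1.0f", b: "2.0f"}
print(pm.rewrite(add, ctx=rendered))                             # (1.0f+2.0f)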
@@ -19,174 +78,125 @@ class CStyleLanguage(Renderer):
   code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
   extra_args: List[str] = []
   float4: Optional[str] = None
-  uses_vload: bool = False
-  uses_ptr_arithmetic: bool = False
   type_map: Dict[DType, str] = {}
+  infinity: str = "INFINITY"
+  nan: str = "NAN"
   code_for_op: Dict = {
-  … (old lines 26-38 not captured in this view)
-    assert self.float4 is not None, "vectorized cast is not supported on this platform"
-    return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
-
-  # returns a str expression of the const with the given type
-  def render_const(self, x:ConstType, dtype:DType) -> str:
-    if math.isnan(x): val = "NAN"
-    elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
-    elif dtype == dtypes.bool: val = "1" if x else "0"
-    elif dtype == dtypes.float: val = f"{x}f"
-    else: val = str(x)
-    return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
-
-  # returns a str expression of the loaded value with the output type
-  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
-    if isinstance(buf_dtype, ImageDType):
-      assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
-      return f"read_imagef({buf_name}, smp, {idx})"
-    if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
-      return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
-    if output_dtype.count > 1:
-      out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(buf_dtype)}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501
-    else:
-      out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
-    return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
-
-  def get_kernel_modifier(self, uops:UOpGraph) -> str: return ""
-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str:
+    Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
+    Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
+    Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
+    Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
+    Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
+    Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})",
+    Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" }
+
+  string_rewrite = base_rewrite
+  extra_matcher = extra_pm
+
+  def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
+  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else ""  # noqa: E501
-    buftypes = [(name,
-      ("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
+    buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
                 self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
     prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
     [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
     [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
     return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"

-
-  def
-  if isinstance(
-  … (old lines 78-81 not captured in this view)
-  if
-  … (old lines 83-86 not captured in this view)
-  def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"
-  def render_dtype(self, var_dtype:DType) -> str: return self.type_map.get(var_dtype, var_dtype.name)
-
-  def render(self, name:str, uops:UOpGraph) -> str:
-    kernel = []
-    bufs: List[Tuple[str, Tuple[DType, bool]]] = []
-    depth = 1
-    def kk(s): kernel.append(" "*depth+s)
-
-    c: DefaultDict[str, int] = defaultdict(int)
+  def render_cast(self, dt:DType, val: str) -> str: return f"({self.render_dtype(dt)})({val})"
+  def render_dtype(self, dt:DType, mutable=True) -> str:
+    if isinstance(dt, ImageDType):
+      return f"{'write_only' if mutable else 'read_only'} image2d_t"
+    if isinstance(dt, PtrDType):
+      return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + \
+        self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "")
+    return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")
+
+  def __getitem__(self, key): return self.r[key]  # hacky helper
+  def render(self, name:str, uops:List[UOp]) -> str:
     r: Dict[UOp, str] = {}
-
-    def ssa(prefix:str, u:Optional[UOp]=None):
-      nonlocal c, r
-      ret = f"{prefix}{c[prefix]}"
-      if u is not None: r[u] = ret
-      c[prefix] += 1
-      return ret
+    self.r = r

     child_count = Counter(v for ru in uops for v in ru.src)
-    … (old lines 107-108 not captured in this view)
+    bufs: Dict[UOp, Tuple[str, Tuple[DType, bool]]] = {}
+    kernel = []
+    depth = 1
+    c: DefaultDict[str, int] = defaultdict(int)
     for u in uops:
-      … (old lines 110-122 not captured in this view)
+      if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
+        r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
+        bufs[u] = (r[u], (u.dtype, False))
+        continue
+
+      # mark buffers that we store to writable
+      if u.op is Ops.STORE:
+        for up in u.src[0].sparents:
+          if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
+
+      # naming
+      prefix = None
+      if u.op is Ops.SPECIAL:
+        r[u] = u.arg[0]
+      else:
+        prefix = {Ops.RANGE: "ridx", Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
+                  Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.NOOP: "precast",
+                  Ops.INDEX: "bidx", Ops.DEFINE_ACC: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
+        r[u] = f"{prefix}{c[prefix]}"
+
+      l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
+      assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
+
+      if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
+      if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST}
+                                                     and child_count[u] == 1 and not getenv("EXPAND_SSA")):
+        r[u] = l
       else:
-        … (old lines 124-134 not captured in this view)
-        if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
-        else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
-      elif uop is UOps.SPECIAL:
-        kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
-        r[u] = args[1]
-      elif uop is UOps.LOAD:
-        val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
-        # NOTE: this relies on the load not happening if it's in the unselected branch
-        if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
-        kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
-      elif uop is UOps.PHI:
-        kk(f"{r[src[0]]} = {r[src[1]]};")
-        r[u] = r[src[0]]
-      elif uop in {UOps.CAST, UOps.BITCAST}:
-        if uop is UOps.BITCAST:
-          assert len(src) == 1
-          precast = ssa('precast')
-          kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
-          val = self.render_cast([precast], dtype, bitcast=True)
-        else:
-          val = self.render_cast([r[x] for x in src], dtype, bitcast=False)
-        if child_count[u] <= 1: r[u] = val
-        else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
-      elif uop is UOps.DEFINE_LOCAL:
-        kk(self.render_local(args[0], dtype, args[1]))
-        r[u] = args[0]
-      elif uop is UOps.DEFINE_VAR:
-        assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
-        seen_vars.add(args.expr)
-        bufs.append((args.expr, (dtype,False)))
-        r[u] = args.expr
-      elif uop is UOps.DEFINE_GLOBAL:
-        bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
-        r[u] = nm
-      elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
-      elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
-      elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
-      elif uop is UOps.GEP:
-        assert src[0].dtype is not None
-        from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
-        r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
-      else: raise RuntimeError(f"failed to render {uop}")
-
-    return self.render_kernel(name, kernel, bufs, uops)
+        if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
+          if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
+        else:
+          l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
+        kernel.append(" "*depth + l)
+        if prefix: c[prefix] += 1  # if it was used, increment
+      if u.op in {Ops.IF, Ops.RANGE}: depth += 1
+    del self.r
+
+    # NOTE: this relies on bufs dict preserving order
+    return self.render_kernel(name, kernel, list(bufs.values()), uops)

 class ClangRenderer(CStyleLanguage):
   device = "CLANG"
-  … (old line 182 not captured in this view)
+  float4 = "(float4)"
   has_local = False
   global_max = None
+  infinity = "__builtin_inff()"
+  nan = '__builtin_nanf("")'

   # language options
   buffer_suffix = " restrict"
   type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
-  code_for_op = {**CStyleLanguage().
+  code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
+                 Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
+
+  if AMX:
+    tensor_cores = [TensorCore(dims=(sz,sz,1), threads=[], reduce_axes=[], upcast_axes=([(1,sz)],[(0,sz)],[(1,sz),(0,sz)]), dtype_in=dt, dtype_out=dt)
+      for dt, sz in [(dt, 64//dt.itemsize) for dt in [dtypes.float]]]
+
+  def render_vector_prefix(self, dt:DType) -> str:
+    return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({(sz:=dt.itemsize)}),vector_size({sz})));"
+
+  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+    prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
+    # https://github.com/corsix/amx
+    for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
+      prefix += [
+        '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
+        '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
+      ]
+      prefix += [f"""{(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
+AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
+AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
+for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""]  # noqa: E501
+    return super().render_kernel(function_name, kernel, bufs, uops, prefix)

 class OpenCLRenderer(CStyleLanguage):
   device = "GPU"
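Note (illustrative, not part of the published diff): render_const/render_load/render_local are gone; the surviving type helpers are render_dtype and render_cast shown above. A quick sketch of what they produce with the CStyleLanguage defaults (empty type_map; per-device type_map entries change these names):

from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.dtype import dtypes
r = CStyleLanguage()
r.render_dtype(dtypes.float)         # "float"
r.render_dtype(dtypes.float.vec(4))  # "float4": scalar name plus the vector count
r.render_cast(dtypes.int, "x")       # "(int)(x)": a plain C-style cast around an already-rendered value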
@@ -199,20 +209,48 @@ class OpenCLRenderer(CStyleLanguage):
   barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
   float4 = "(float4)"
   code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
-  … (old lines 202-205 not captured in this view)
+  type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
+
+  string_rewrite = PatternMatcher([
+    (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
+    # load/store image (OpenCL)
+    (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
+     lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
+    (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
+     lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
+    (UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
+     lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
+  ]) + base_rewrite

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
-    if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
+    if any(uop.dtype == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)

+class IntelRenderer(OpenCLRenderer):
+  device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel "
+  tensor_cores = [TensorCore(dims=(8,8,16),threads=[(0,8)],dtype_in=di,dtype_out=do,reduce_axes=[(0,16)],upcast_axes=([(0,16)],[(0,16)],[(1,8)]),
+    st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
+
+  string_rewrite = PatternMatcher([
+    (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
+    (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
+  ]) + OpenCLRenderer.string_rewrite
+
+  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+    prefix = []
+    for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
+      dt_in = ("ushort", "bf16") if arg[2] == dtypes.bfloat16 else (arg[2].name, "f16")
+      prefix.append(f"""{arg[3].name}8 __{arg[0]}({dt_in[0]}16 a, {dt_in[0]}16 b, {arg[3].name}8 c) {{
+return intel_sub_group_{dt_in[1]}_{dt_in[1]}_matrix_mad_k16(as_int8(a), as_int8(b), c);\n}}""")
+    return super().render_kernel(function_name, kernel, bufs, uops, prefix or None)
+
 class MetalRenderer(CStyleLanguage):
   device = "METAL"
   shared_max = 32768
-  tensor_cores = [TensorCore(dims=(8,8,8),
-  … (old line 215 not captured in this view)
+  tensor_cores = [TensorCore(dims=(8,8,8),threads=[(0,2),(1,4),(0,2),(1,2)],expanded_shape=(2,2,2,2),upcast_axes=([(1,2)],[(1,2)],[(1,2)]),
+    st1_pattern=(((1,1),(0,1),(1,0),(0,3)),((0,0),(0,2),(1,3),(1,2))),st2_pattern=(((0,0),(1,1),(1,2),(0,2),(1,0)),((0,1),(0,3),(1,3))),
+    dtype_in=di,dtype_out=do,reduce_axes=[(0,8)]) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]]
+  def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if hasattr(os, 'uname') and os.uname().machine == "arm64" else []

   # language options
   kernel_prefix = "kernel "
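Note (illustrative, not part of the published diff): the per-device tables above are built by composing a small device-specific PatternMatcher with the shared base_rewrite via "+". The BITCAST override in OpenCLRenderer only works if the left-hand rules are consulted first, so composition is assumed to concatenate the rule lists in that order. The idiom, sketched with the names defined earlier in this file:

from tinygrad.ops import Ops, UPat, PatternMatcher
from tinygrad.renderer.cstyle import base_rewrite
# device-specific rules first, shared base_rewrite as the fallback for everything else
device_rewrite = PatternMatcher([
  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
]) + base_rewrite  # CONST, RANGE, LOAD, STORE, ALU, ... still render via base_rewrite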
@@ -221,48 +259,49 @@ class MetalRenderer(CStyleLanguage):
   arg_int_prefix = "constant int&"
   barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
   float4 = "float4"
-  … (old line 224 not captured in this view)
-  code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
+  code_for_workitem = {"g": lambda x: f"gid.{chr(120+int(x))}", "l": lambda x: f"lid.{chr(120+int(x))}"}
   # uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
   extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
   type_map = {dtypes.bfloat16: "bfloat"}
-  code_for_op = {**CStyleLanguage().code_for_op,
-    BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
-    UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
-    UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
-    UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
-    UnaryOps.SIN: lambda x,dtype: f"(bfloat)sin({x})" if dtype == dtypes.bfloat16 else f"sin({x})",}

-  … (old lines 236-237 not captured in this view)
+  # precise::sin
+  code_for_op = {**CStyleLanguage.code_for_op, Ops.SIN: lambda x,dtype: f"precise::sin({x})"}
+
+  # upcast to float32 all the ops that don't support bfloat16
+  extra_matcher = PatternMatcher([
+    # NOTE: this is copied from PTX
+    (UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
+     lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
+  ]) + extra_pm
+
+  string_rewrite = PatternMatcher([
+    (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
+  ]) + base_rewrite

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
-    prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is
-    for arg in wmma_args: prefix.append(
-    … (old lines 242-244 not captured in this view)
+    prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is Ops.WMMA])
+    for arg in wmma_args: prefix.append(
+      f"""{(dtype_out:=self.render_dtype(arg[3].vec(2)))} __{arg[0]}({(dtype_in:=self.render_dtype(arg[2].vec(2)))} a, {dtype_in} b, {dtype_out} c){{
+simdgroup_{self.render_dtype(arg[2])}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(arg[3])}8x8 mat_c;
+mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0];
+mat_a.thread_elements()[1] = a[1]; mat_b.thread_elements()[1] = b[1]; mat_c.thread_elements()[1] = c[1];
+simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dtype_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)

-code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"1/{x}",
-  BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
-  UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
-  UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
-  UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
-  UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
-
 _nms = "xyzwabcdefghijkl"
-def _make_cuda_dtype(base_type, name, cnt):
-  vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
-  return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"

 class CUDARenderer(CStyleLanguage):
   device = "CUDA"
   global_max = (2147483647, 65535, 65535)
   local_max = (1024, 1024, 64)
   shared_max = 49152
-  … (old lines 264-265 not captured in this view)
+  # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
+  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do, expanded_shape=(2,2,2,2,2,2),
+    st1_pattern=(((1,1),(1,0),(0,2),(0,3),(0,4)),((1,3),(1,5),(1,2),(0,0),(0,1),(1,4))),
+    st2_pattern=(((1,1),(1,0),(1,4),(0,0),(0,1)),((0,4),(0,2),(1,5),(0,3),(1,3),(1,2))), reduce_axes=[(0,8),(1,2)],
+    upcast_axes=([(0,8)],[(2,2),(3,2)],[(3,2),(2,2)])) for di, do in ([(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)])]
+  def __init__(self, arch:str): self.tensor_cores, self.arch = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
+  def __reduce__(self): return self.__class__, (self.arch,)

   # language options
   kernel_prefix = "extern \"C\" __global__ "
@@ -270,109 +309,111 @@ class CUDARenderer(CStyleLanguage):
   smem_prefix_for_cast = False
   barrier = "__syncthreads();"
   float4 = "make_float4"
-  code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}",
-    "i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"}
-  code_for_op = {**CStyleLanguage
+  code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
+                       "i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"}
+  code_for_op = { **CStyleLanguage.code_for_op,
+    Ops.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
+    Ops.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
+    Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
+    Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
+    Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
   type_map = {dtypes.bfloat16: "nv_bfloat16"}

+  def render_vector_prefix(self, dt:DType) -> str:
+    vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()),
+    elems, header = ', '.join(_nms[:dt.count]), ', '.join([f"{scal} {x}" for x in _nms[:dt.count]])
+    return f"struct __align__({dt.itemsize}) {vec} {{ {scal} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
+
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
     # TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
-    dt_map = { dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16"), }
-
     prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
-    if any(uop.dtype == dtypes.half for uop in uops):
-      prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]

-    … (old lines 286-295 not captured in this view)
+    used_dtypes = uops_to_dtypes(uops)
+    if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include <cuda_fp16.h>")
+    if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
+    prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]
+
+    dt_map = { dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+    for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
+      upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
+      wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
+      n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]  # 4 => CUDA reg size in bytes
+      operands = [f"%{i}" for i in range(sum(n_operands))]
+
+      # mma operands => {c}, {a}, {b}, {c}
+      prefix.append(f"""__device__ {wmma_dtypes[2]} __{name}({wmma_dtypes[0]} a, {wmma_dtypes[1]} b, {wmma_dtypes[2]} c){{
+int *a_pk = (int *)(&a), *b_pk = (int *)(&b);\n asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32"
+"{{{", ".join(operands[:n_operands[2]])}}}, {{{", ".join(operands[n_operands[2]:n_operands[2]+n_operands[0]])}}},"
+"{{{", ".join(operands[-n_operands[1]:])}}}, {{{", ".join(operands[:n_operands[2]])}}};"
+: {", ".join([f'"+f"(c.{_nms[i]})' for i in range(n_operands[2])])}
+: {", ".join([f'"r"(a_pk[{i}])' for i in range(n_operands[0])])}, {", ".join([f'"r"(b_pk[{i}])' for i in range(n_operands[1])])});
+return c;\n}}""")

     return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)

-… (old lines 299-302 not captured in this view)
-  # TODO: MAX with int uses fmax_f32?
-  BinaryOps.MAX: lambda a,b,dtype: f"__ocml_fmax_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32) }({a},{b})",}
-
-def _make_hip_code_for_op():
-  def wrapper(key, func):
-    def cast_bf16(*args):
-      if args[-1] == dtypes.bfloat16:
-        operands = tuple(f"(float)({arg})" for arg in (args[1:-1] if key is TernaryOps.WHERE else args[:-1]))
-        return f"(hip_bfloat16)({func(*(((args[0],) if key is TernaryOps.WHERE else ()) + operands), dtypes.float)})"
-      return func(*args)
-    return cast_bf16
-  return { k:wrapper(k,v) for k,v in {**CStyleLanguage().code_for_op, **code_for_op_hip}.items() }
-
-def _make_hip_dtype(base_type, name, cnt):
-  elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
-  return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
-    f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
+  def get_kernel_modifier(self, uops:List[UOp]) -> str:
+    maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
+    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
+    return f"__launch_bounds__({maxThreadsPerBlock}) "

 class AMDRenderer(CStyleLanguage):
   device = "AMD"
   shared_max = 65536
-  … (old line 324 not captured in this view)
+  # https://gpuopen.com/learn/wmma_on_rdna3/
+  tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], dtype_in=di, dtype_out=do, reduce_axes=[(0,16)], opts_seq=("LC","UP"),
+    upcast_axes = ([(0,16)],[(0,16)],[(1,8)]), st1_pattern=(((1,2),(0,2),(1,1),(0,1)),((1,0),(0,0))), expanded_shape=(16,2,4))
+    for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]]

   # language options
-  … (old lines 327-331 not captured in this view)
-  __attribute__((device
-  … (old line 333 not captured in this view)
-  __attribute__((device)) __attribute__((const)) {dt} __ocml_sqrt_f{n}({dt});
-  __attribute__((device)) {dt} __ocml_sin_f{n}({dt});\n""" for dt,n in [("float",32), ("double",64), ("_Float16",16)]]) +\
-  '}\nextern "C" __attribute__((global))'
+  ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
+  ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
+    for dt, n in [(dtype.name, dtype.itemsize * 8) for dtype in [dtypes.float, dtypes.double, dtypes.half]]
+    for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]
+
+  kernel_prefix = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
+  kernel_prefix += '\nextern "C" __attribute__((global))'
   code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
                        "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
-  code_for_op =
+  code_for_op = { **CStyleLanguage.code_for_op,
+    Ops.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+    Ops.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+    Ops.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+    Ops.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})" }
   smem_prefix = "__attribute__((shared))"
   barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
     '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
   float4 = "make_float4"
-  uses_ptr_arithmetic = False  # NOTE: this fixes TestLinearizerOverflowAlt
   type_map = {dtypes.bfloat16: "hip_bfloat16"}
+  extra_matcher = PatternMatcher([
+    # cast bfloat16 alus to float
+    (UPat(Ops.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
+     lambda b,x,y: UOp(Ops.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
+    (UPat(GroupOp.ALU, dtype=dtypes.bfloat16, name="x"),
+     lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
+    (UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
+     lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
+    # add float intermediate casting for bfloat16
+    (UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
+    (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
+    # bfloat16 casting
+    (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
+    (UPat(Ops.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
+     lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
+    (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
+
+  def render_vector_prefix(self, dtype:DType) -> str:
+    vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
+    return f"typedef {scal} {vec} __attribute__((ext_vector_type({dtype.count})));\nstatic inline __attribute__((device)) "+ \
+      f"{vec} make_{vec}({', '.join([f'{scal} {x}' for x in _nms[:dtype.count]])}) {{ return {{ {', '.join(_nms[:dtype.count])} }}; }}"

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
-    prefix = ["#define INFINITY (__builtin_inff())",
-    … (old lines 349-354 not captured in this view)
-  inline __attribute__((device)) hip_bfloat16(float val) {
-    union { float fp32; unsigned int u32; } u = {val};
-    if (~u.u32 & 0x7f800000) { u.u32 += 0x7fff + ((u.u32 >> 16) & 1); } else if (u.u32 & 0xffff) { u.u32 |= 0x10000; }
-    data = (u.u32 >> 16);
-  }
-  inline __attribute__((device)) operator float() const {
-    unsigned int uval = data << 16;
-    return *reinterpret_cast<float*>(&uval);
-  }
-};
-static inline __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
-static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
-""")
-
-    if any(uop.dtype == dtypes.half for uop in uops):
-      prefix.append("#define half _Float16")
-      vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
-
-    prefix += [_make_hip_dtype(*x) for x in vec_dts]
-
-    for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+    prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
+
+    used_dtypes = uops_to_dtypes(uops)
+    if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("struct hip_bfloat16 { unsigned short data; };")
+    prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]
+
+    for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
       if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
       else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
   half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
@@ -380,10 +421,42 @@ static inline __attribute__((device)) bool operator==(hip_bfloat
   for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)

-  def get_kernel_modifier(self, uops:
-    requiredMaxThreadsPerBlock = prod(u.arg[
+  def get_kernel_modifier(self, uops:List[UOp]) -> str:
+    requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
     # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
     # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
     return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"

+class DSPRenderer(ClangRenderer):
+  device = "DSP"
+  supports_float4 = False
+  buffer_suffix = " restrict __attribute__((align_value(128)))"
+  kernel_prefix = "__attribute__((noinline)) "
+  type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
+  code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
+    Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
+    Ops.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"}
+
+  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
+    ret = super().render_kernel(function_name, kernel, bufs, uops, prefix)
+    msrc = ['''struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency; _Bool set_dcvs_params;
+short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3]; };''', 'int HAP_power_set(void*, void*);',
+      'typedef union { struct { void *pv; unsigned int len; } buf; struct { int fd; unsigned int offset; } dma; } remote_arg;',
+      'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
+      'unsigned long long HAP_perf_get_time_us(void);', 'int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
+      'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
+      'HAP_power_set((void*)handle, (void*)&req);']
+    msrc += ['if ((sc>>24) != 2) return 0;']
+    msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
+    msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
+    msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
+    msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
+    msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
+    msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
+    msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
+    msrc += ["return 0; }"]
+    return ret + '\n' + '\n'.join(msrc)
+
 class NVRenderer(CUDARenderer): device = "NV"
+class HIPRenderer(AMDRenderer): device = "HIP"
+class QCOMRenderer(OpenCLRenderer): device = "QCOM"