tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +253 -225
- tinygrad/codegen/linearizer.py +398 -436
- tinygrad/codegen/uops.py +451 -0
- tinygrad/device.py +268 -274
- tinygrad/dtype.py +56 -40
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +198 -0
- tinygrad/engine/realize.py +192 -0
- tinygrad/engine/schedule.py +370 -0
- tinygrad/engine/search.py +199 -0
- tinygrad/{mlops.py → function.py} +40 -32
- tinygrad/helpers.py +144 -46
- tinygrad/lazy.py +143 -242
- tinygrad/multi.py +173 -0
- tinygrad/nn/__init__.py +180 -9
- tinygrad/nn/datasets.py +8 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +87 -19
- tinygrad/ops.py +104 -45
- tinygrad/renderer/__init__.py +65 -0
- tinygrad/renderer/assembly.py +269 -0
- tinygrad/renderer/cstyle.py +308 -210
- tinygrad/renderer/llvmir.py +119 -124
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +13403 -0
- tinygrad/runtime/autogen/comgr.py +891 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5893 -0
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33597 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +56 -0
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +39 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +187 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +550 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +129 -37
- tinygrad/runtime/ops_disk.py +111 -43
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +41 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +625 -0
- tinygrad/runtime/ops_python.py +208 -0
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +46 -107
- tinygrad/shape/symbolic.py +99 -98
- tinygrad/shape/view.py +162 -45
- tinygrad/tensor.py +2492 -483
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
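Before the per-file diff, note that the list above amounts to a module reorganization: the JIT, graph, realize, schedule, and search helpers move under tinygrad/engine/, tinygrad/features/multi.py is promoted to tinygrad/multi.py, and tinygrad/mlops.py is renamed to tinygrad/function.py. The sketch below restates those moves as import paths; it is inferred only from the file renames listed here, and the symbols inside each module are assumed, not verified, to be otherwise unchanged.

# Hypothetical migration sketch based only on the file moves listed above (0.8.0 -> 0.9.1).
import tinygrad.engine.jit        # was: tinygrad.jit
import tinygrad.engine.realize    # was: tinygrad.realize
import tinygrad.engine.schedule   # new module in 0.9.1
import tinygrad.engine.search     # was: tinygrad.features.search
import tinygrad.engine.graph      # was: tinygrad.graph
import tinygrad.multi             # was: tinygrad.features.multi
import tinygrad.function          # was: tinygrad.mlops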
tinygrad/renderer/cstyle.py
CHANGED
@@ -1,14 +1,13 @@
-from typing import Dict, List, Optional,
-import
+from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
+import os, math
 from collections import defaultdict, Counter
-from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
-from tinygrad.helpers import
-from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
+from tinygrad.helpers import strip_parens, getenv, prod, dedup
+from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
+from tinygrad.codegen.uops import UOps, UOp, UOpGraph
+from tinygrad.renderer import Renderer, TensorCore

-class CStyleLanguage(
-  size_prefix: str = "int"
-  generic_var_prefix: str = ""
+class CStyleLanguage(Renderer):
   kernel_prefix: str = ""
   buffer_prefix: str = ""
   buffer_suffix: str = ""
@@ -18,39 +17,36 @@ class CStyleLanguage(NamedTuple):
   arg_int_prefix: str = "const int"
   barrier: str = ""
   code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
-  global_max: List[int] = []
-  local_max: List[int] = []
   extra_args: List[str] = []
   float4: Optional[str] = None
-  half_prekernel: Optional[str] = None
   uses_vload: bool = False
-  external_local_bufs: bool = False
   uses_ptr_arithmetic: bool = False
-  launch_bounds: bool = False
   type_map: Dict[DType, str] = {}
   code_for_op: Dict = {
-    UnaryOps.NEG: lambda x,dtype: f"(
+    UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype == dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
+    UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
     UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
-    BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.
-    BinaryOps.
-    BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.
-    TernaryOps.
-  }
+    BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
+    BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
+    BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
+    TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}

   # returns a str expression of the casted xs with the given type
   def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
     if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x[0]}))"
     if len(x) == 1: return f"({self.render_dtype(var_dtype)})({x[0]})"
-    assert len(x) == var_dtype.
+    assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
     assert self.float4 is not None, "vectorized cast is not supported on this platform"
     return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"

   # returns a str expression of the const with the given type
-  def render_const(self, x:
+  def render_const(self, x:ConstType, dtype:DType) -> str:
     if math.isnan(x): val = "NAN"
     elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
-
-
+    elif dtype == dtypes.bool: val = "1" if x else "0"
+    elif dtype == dtypes.float: val = f"{x}f"
+    else: val = str(x)
+    return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)

   # returns a str expression of the loaded value with the output type
   def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
@@ -58,32 +54,23 @@ class CStyleLanguage(NamedTuple):
       assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
       return f"read_imagef({buf_name}, smp, {idx})"
     if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
-      return f"vload_half{'' if output_dtype.
-    if output_dtype.
-      out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype
+      return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
+    if output_dtype.count > 1:
+      out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(buf_dtype)}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501
     else:
       out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
-
     return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val

-  def
-
-
-
-
-
-
-
-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
-    tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501
-    buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
-                ("const " if i > 0 else "")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
-                self.arg_int_prefix if dtype == dtypes.int else None) for i,(name,dtype) in enumerate(bufs)]
-    prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
+  def get_kernel_modifier(self, uops:UOpGraph) -> str: return ""
+  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str:
+    tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
+    buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
+                ("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
+                self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
+    prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
     [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
     [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
-
-    return prg
+    return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"

   # returns a str statement that does the store
   def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
@@ -91,201 +78,312 @@ class CStyleLanguage(NamedTuple):
       assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}"
       return f"write_imagef({buf_name}, {idx}, {var_name});"
     if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
-      return f"vstore_half{'' if var_dtype.
-    if var_dtype.
-
+      return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
+    if var_dtype.count > 1:
+      prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
+      return f"*(({prefix}{self.render_dtype(buf_dtype)}{var_dtype.count}*)({buf_name}+{idx})) = {var_name};"
     return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"

-def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-      kk(lang.barrier)
-    elif uop == UOps.END:
-      depth -= 1
-      kk("}")
-    elif uop == UOps.STORE:
-      assert vin[0].dtype is not None and vin[2].dtype is not None
-      if len(vin) > 3: kk(lang.render_if(r[vin[3]]))
-      kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
-      if len(vin) > 3: kk("}")
-    else:
-      assert dtype is not None, f"None dtype for uop {uop}"
-      if uop == UOps.LOOP:
-        kk(lang.render_for(ssa(u,'ridx'), r[vin[0]], r[vin[1]]))
+  def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"
+  def render_dtype(self, var_dtype:DType) -> str: return self.type_map.get(var_dtype, var_dtype.name)
+
+  def render(self, name:str, uops:UOpGraph) -> str:
+    kernel = []
+    bufs: List[Tuple[str, Tuple[DType, bool]]] = []
+    depth = 1
+    def kk(s): kernel.append(" "*depth+s)
+
+    c: DefaultDict[str, int] = defaultdict(int)
+    r: Dict[UOp, str] = {}
+
+    def ssa(prefix:str, u:Optional[UOp]=None):
+      nonlocal c, r
+      ret = f"{prefix}{c[prefix]}"
+      if u is not None: r[u] = ret
+      c[prefix] += 1
+      return ret
+
+    child_count = Counter(v for ru in uops for v in ru.src)
+
+    seen_vars = set()
+    for u in uops:
+      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
+      # these four uops don't have output dtypes
+      if uop is UOps.IF:
+        kk(f"if ({r[src[0]]}) {{")
         depth += 1
-    elif uop
-
-
-
-
-
-
-
-          kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};")
-          kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
-          kk("simdgroup_multiply_accumulate(c, a, b, c);")
-          kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}")
-        elif args[0] == "HIP":
-          assert dtype == dtypes.float.vec(8), "output dtype of HIP TC is _float8"
-          kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") # noqa: E501
-        else:
-          raise NotImplementedError(f"WMMA not implemented for {args}")
-      elif uop == UOps.ALU:
-        # remove parens if ALU types are the same. TODO: can do more here
-        if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.XOR}:
-          val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]], dtype)
-        else:
-          val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype])
-        assert child_count[u] != 0, f"childless ALU op found {u}"
-        # TODO: fix index rendering issue. fix clang nested max macro issue
-        if child_count[u] <= 1 and args != BinaryOps.MAX and not getenv("EXPAND_SSA"):
-          r[u] = val
-        else:
-          kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};")
-      elif uop == UOps.DEFINE_ACC:
-        kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
-      elif uop == UOps.SPECIAL:
-        kk(f"{lang.size_prefix} {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
-        if args[1].startswith("l"): local_size.append(args[2])
-        r[u] = args[1]
-      elif uop == UOps.CONST:
-        r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
-      elif uop == UOps.LOAD:
-        val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
-        # NOTE: this relies on the load not happening if it's in the unselected branch
-        if len(vin) > 3: val = lang.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
-        kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else lang.render_dtype(dtype)} {ssa(u,'val')} = {val};")
-      elif uop == UOps.PHI:
-        kk(f"{r[vin[0]]} = {r[vin[1]]};")
-        r[u] = r[vin[0]]
-      elif uop == UOps.CAST:
-        if isinstance(args, tuple) and args[1]: # bitcast
-          assert len(vin) == 1
-          precast = ssa(None,'precast')
-          kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else lang.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
-          val = lang.render_cast([precast], dtype, bitcast=True)
-        else:
-          val = lang.render_cast([r[x] for x in vin], dtype, bitcast=False)
-        if child_count[u] <= 1: r[u] = val
-        else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};")
-      elif uop == UOps.DEFINE_LOCAL:
-        if lang.external_local_bufs:
-          prekernel.append(lang.render_local(args[0], dtype, args[1]))
-        else:
-          kk(lang.render_local(args[0], dtype, args[1]))
-        r[u] = args[0]
-      elif uop == UOps.DEFINE_GLOBAL:
-        bufs.append((args, dtype))
-        r[u] = args
-      elif uop == UOps.GEP:
-        if cast(DType, vin[0].dtype).sz > 4:
-          r[u] = f"({r[vin[0]]})[{args}]" # this is correct for HIP
-        else:
-          r[u] = f"({r[vin[0]]}).{'xyzw'[args]}"
+      elif uop is UOps.BARRIER: kk(self.barrier)
+      elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
+        depth -= 1
+        kk("}")
+      elif uop is UOps.STORE:
+        assert src[0].dtype is not None and src[2].dtype is not None
+        rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
+        kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
       else:
-
+        assert dtype is not None, f"None dtype for uop {uop}"
+        if uop is UOps.RANGE:
+          kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{")
+          depth += 1
+        elif uop is UOps.ALU:
+          # remove parens if ALU types are the same. TODO: can do more here
+          if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src]
+          else: operands = [r[v] for v in src]
+          val = self.code_for_op[args](*operands, dtype)
+          assert child_count[u] != 0, f"childless ALU op found {u}"
+          # TODO: fix index rendering issue. fix clang nested max macro issue
+          if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
+          else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
+        elif uop is UOps.SPECIAL:
+          kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
+          r[u] = args[1]
+        elif uop is UOps.LOAD:
+          val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
+          # NOTE: this relies on the load not happening if it's in the unselected branch
+          if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
+          kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
+        elif uop is UOps.PHI:
+          kk(f"{r[src[0]]} = {r[src[1]]};")
+          r[u] = r[src[0]]
+        elif uop in {UOps.CAST, UOps.BITCAST}:
+          if uop is UOps.BITCAST:
+            assert len(src) == 1
+            precast = ssa('precast')
+            kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
+            val = self.render_cast([precast], dtype, bitcast=True)
+          else:
+            val = self.render_cast([r[x] for x in src], dtype, bitcast=False)
+          if child_count[u] <= 1: r[u] = val
+          else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
+        elif uop is UOps.DEFINE_LOCAL:
+          kk(self.render_local(args[0], dtype, args[1]))
+          r[u] = args[0]
+        elif uop is UOps.DEFINE_VAR:
+          assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
+          seen_vars.add(args.expr)
+          bufs.append((args.expr, (dtype,False)))
+          r[u] = args.expr
+        elif uop is UOps.DEFINE_GLOBAL:
+          bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
+          r[u] = nm
+        elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
+        elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
+        elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
+        elif uop is UOps.GEP:
+          assert src[0].dtype is not None
+          from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
+          r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
+        else: raise RuntimeError(f"failed to render {uop}")
+
+    return self.render_kernel(name, kernel, bufs, uops)

-
+class ClangRenderer(CStyleLanguage):
+  device = "CLANG"
+  supports_float4 = False
+  has_local = False
+  global_max = None

-
+  # language options
+  buffer_suffix = " restrict"
+  type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
+  code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
+
+class OpenCLRenderer(CStyleLanguage):
+  device = "GPU"
+
+  # language options
   kernel_prefix = "__kernel "
   buffer_prefix = "__global "
   smem_align = "__attribute__ ((aligned (16))) "
   smem_prefix = "__local "
-  half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable"
   barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
   float4 = "(float4)"
-  code_for_workitem ={
+  code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
   uses_vload = True
-  # NOTE: mad is used so the loads aren't reordered into the math on 845
-  code_for_op = {**CStyleLanguage().code_for_op,
-    TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})" if dtypes.is_float(dtype) else f"(({a}*{b})+{c})"}
   type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
   def render_cast(self, x, var_dtype, bitcast=False) -> str:
-    return f"as_{self.
-OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())
+    return f"as_{self.render_dtype(var_dtype)}({x[0]})" if bitcast else super().render_cast(x, var_dtype)

-
-
+  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+    if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
+    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
+
+class MetalRenderer(CStyleLanguage):
+  device = "METAL"
+  shared_max = 32768
+  tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[0],[2],[0],[4],[-1, 1, 3],[0]], [[1],[0],[3],[0],[2, 4],[-1]], [[1],[2],[3],[4],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
+  def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if os.uname().machine == "arm64" else []
+
+  # language options
+  kernel_prefix = "kernel "
   buffer_prefix = "device "
   smem_prefix = "threadgroup "
   arg_int_prefix = "constant int&"
   barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
   float4 = "float4"
-  uses_ptr_arithmetic=True
+  uses_ptr_arithmetic = True
   code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
+  # uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
   extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
+  type_map = {dtypes.bfloat16: "bfloat"}
+  code_for_op = {**CStyleLanguage().code_for_op,
+    BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
+    UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
+    UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
+    UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
+    UnaryOps.SIN: lambda x,dtype: f"(bfloat)sin({x})" if dtype == dtypes.bfloat16 else f"sin({x})",}
+
   def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
-    return f"as_type<{var_dtype
-
-
-
-
-
-
-
-
-
-
-
-
+    return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
+
+  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
+    prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
+    for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
+simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
+b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
+return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
+    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
+
+code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"1/{x}",
+  BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
+  UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
+  UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
+  UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
+  UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
+
+_nms = "xyzwabcdefghijkl"
+def _make_cuda_dtype(base_type, name, cnt):
+  vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
+  return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
+
+class CUDARenderer(CStyleLanguage):
+  device = "CUDA"
+  global_max = (2147483647, 65535, 65535)
+  local_max = (1024, 1024, 64)
+  shared_max = 49152
+  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
+  def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []
+
+  # language options
+  kernel_prefix = "extern \"C\" __global__ "
   smem_prefix = "__shared__ "
   smem_prefix_for_cast = False
   barrier = "__syncthreads();"
   float4 = "make_float4"
-  code_for_workitem = {
-
-    "i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"
-  }
+  code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}",
+    "i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"}
   code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
-  half_prekernel ="#include <cuda_fp16.h>\n"+"#include <cuda_bf16.h>\n"+"""
-struct half4 { half x, y, z, w; };
-__device__ half4 make_half4(half x, half y, half z, half w) { half4 ret; ret.x = x; ret.y = y; ret.z = z; ret.w = w; return ret; }
-"""
   type_map = {dtypes.bfloat16: "nv_bfloat16"}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
+    # TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
+    dt_map = { dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16"), }
+
+    prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
+    if any(uop.dtype == dtypes.half for uop in uops):
+      prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]
+
+    if any(uop.dtype == dtypes.bfloat16 for uop in uops):
+      prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
+
+    # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
+    for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
+      fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
+      prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
+asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
+: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
+return c;}}""")
+
+    return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
+
+code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+  UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+  UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+  UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+  # TODO: MAX with int uses fmax_f32?
+  BinaryOps.MAX: lambda a,b,dtype: f"__ocml_fmax_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32) }({a},{b})",}
+
+def _make_hip_code_for_op():
+  def wrapper(key, func):
+    def cast_bf16(*args):
+      if args[-1] == dtypes.bfloat16:
+        operands = tuple(f"(float)({arg})" for arg in (args[1:-1] if key is TernaryOps.WHERE else args[:-1]))
+        return f"(hip_bfloat16)({func(*(((args[0],) if key is TernaryOps.WHERE else ()) + operands), dtypes.float)})"
+      return func(*args)
+    return cast_bf16
+  return { k:wrapper(k,v) for k,v in {**CStyleLanguage().code_for_op, **code_for_op_hip}.items() }
+
+def _make_hip_dtype(base_type, name, cnt):
+  elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
+  return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
+    f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
+
+class AMDRenderer(CStyleLanguage):
+  device = "AMD"
+  shared_max = 65536
+  tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
+
+  # language options
+  kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
+extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
+extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
+extern "C" {\n""" + "".join([
+f""" __attribute__((device)) __attribute__((const)) {dt} __ocml_fmax_f{n}({dt}, {dt});
+__attribute__((device)) __attribute__((pure)) {dt} __ocml_exp2_f{n}({dt});
+__attribute__((device)) __attribute__((pure)) {dt} __ocml_log2_f{n}({dt});
+__attribute__((device)) __attribute__((const)) {dt} __ocml_sqrt_f{n}({dt});
+__attribute__((device)) {dt} __ocml_sin_f{n}({dt});\n""" for dt,n in [("float",32), ("double",64), ("_Float16",16)]]) +\
+'}\nextern "C" __attribute__((global))'
+  code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
+    "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
+  code_for_op = _make_hip_code_for_op()
+  smem_prefix = "__attribute__((shared))"
+  barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
+    '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
+  float4 = "make_float4"
+  uses_ptr_arithmetic = False # NOTE: this fixes TestLinearizerOverflowAlt
   type_map = {dtypes.bfloat16: "hip_bfloat16"}
-
+
+  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+    prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;"]
+    vec_dts = [("float", "float", 2), ("float", "float", 4), ("float", "float", 8), ("signed int", "int", 4), ("signed int", "int", 2)]
+
+    # TODO: add BF16 vec dts
+    if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
+struct hip_bfloat16 {
+unsigned short data;
+inline __attribute__((device)) hip_bfloat16(float val) {
+union { float fp32; unsigned int u32; } u = {val};
+if (~u.u32 & 0x7f800000) { u.u32 += 0x7fff + ((u.u32 >> 16) & 1); } else if (u.u32 & 0xffff) { u.u32 |= 0x10000; }
+data = (u.u32 >> 16);
+}
+inline __attribute__((device)) operator float() const {
+unsigned int uval = data << 16;
+return *reinterpret_cast<float*>(&uval);
+}
+};
+static inline __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
+static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
+""")
+
+    if any(uop.dtype == dtypes.half for uop in uops):
+      prefix.append("#define half _Float16")
+      vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
+
+    prefix += [_make_hip_dtype(*x) for x in vec_dts]
+
+    for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+      if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
+      else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
+half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
+c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
+for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
+    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
+
+  def get_kernel_modifier(self, uops:UOpGraph) -> str:
+    requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.op is UOps.SPECIAL and u.arg[1][0] == "l")
+    # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
+    # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
+    return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
+
+class NVRenderer(CUDARenderer): device = "NV"
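Taken together, the cstyle.py diff replaces the 0.8.0 pair of a CStyleLanguage NamedTuple plus a module-level uops_to_cstyle() function with a single CStyleLanguage(Renderer) class whose render(name, uops) method walks the UOpGraph, and it adds per-device subclasses (ClangRenderer, OpenCLRenderer, MetalRenderer, CUDARenderer, AMDRenderer, NVRenderer). A hypothetical dialect written against the new 0.9.1 structure, modeled directly on the ClangRenderer shown above, would look roughly like this (illustrative sketch only; MyCRenderer is not part of the package):

# Sketch of a custom C-style dialect in the 0.9.1 Renderer style (hypothetical class name).
from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.ops import BinaryOps

class MyCRenderer(CStyleLanguage):
  device = "CLANG"            # dialect options are now plain class attributes
  has_local = False           # plain C target: no local/shared memory
  supports_float4 = False
  buffer_suffix = " restrict"
  # per-op overrides merge into the base table, as ClangRenderer does for MAX
  code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}

Kernel source then comes from the inherited render()/render_kernel() methods rather than from the removed uops_to_cstyle() helper.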