tinygrad 0.9.0-py3-none-any.whl → 0.9.1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +78 -90
- tinygrad/codegen/linearizer.py +237 -169
- tinygrad/codegen/uops.py +278 -242
- tinygrad/device.py +147 -10
- tinygrad/dtype.py +7 -7
- tinygrad/engine/graph.py +16 -16
- tinygrad/engine/jit.py +39 -36
- tinygrad/engine/realize.py +6 -5
- tinygrad/engine/schedule.py +15 -7
- tinygrad/engine/search.py +6 -3
- tinygrad/function.py +17 -23
- tinygrad/helpers.py +77 -8
- tinygrad/lazy.py +26 -26
- tinygrad/multi.py +13 -9
- tinygrad/nn/__init__.py +1 -1
- tinygrad/nn/datasets.py +2 -1
- tinygrad/nn/state.py +3 -4
- tinygrad/ops.py +49 -16
- tinygrad/renderer/__init__.py +8 -4
- tinygrad/renderer/assembly.py +93 -100
- tinygrad/renderer/cstyle.py +47 -42
- tinygrad/renderer/llvmir.py +30 -30
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +11504 -1
- tinygrad/runtime/autogen/comgr.py +36 -10
- tinygrad/runtime/autogen/hsa.py +146 -14
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/nv_gpu.py +269 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +20 -11
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +3 -2
- tinygrad/runtime/graph/cuda.py +2 -2
- tinygrad/runtime/graph/hcq.py +122 -78
- tinygrad/runtime/ops_amd.py +302 -316
- tinygrad/runtime/ops_cuda.py +3 -3
- tinygrad/runtime/ops_disk.py +70 -5
- tinygrad/runtime/ops_gpu.py +2 -2
- tinygrad/runtime/ops_metal.py +5 -6
- tinygrad/runtime/ops_npy.py +1 -1
- tinygrad/runtime/ops_nv.py +161 -166
- tinygrad/runtime/ops_python.py +20 -16
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +5 -2
- tinygrad/shape/symbolic.py +1 -3
- tinygrad/shape/view.py +34 -19
- tinygrad/tensor.py +219 -135
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/runtime/driver/hsa.py +0 -143
- tinygrad/runtime/graph/hsa.py +0 -171
- tinygrad/runtime/ops_hsa.py +0 -278
- tinygrad-0.9.0.dist-info/RECORD +0 -60
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
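The per-file +N/-N counts above are plain line-diff totals over the two published wheels. Below is a minimal sketch of reproducing such a summary locally, assuming both wheels have already been downloaded next to the script (for example with `pip download tinygrad==0.9.0 --no-deps`); only the .py members are compared here, and the filenames are the standard wheel names rather than anything taken from this page.

```python
import difflib, zipfile

# assumed local wheel filenames (standard wheel naming, downloaded beforehand)
OLD, NEW = "tinygrad-0.9.0-py3-none-any.whl", "tinygrad-0.9.1-py3-none-any.whl"

def wheel_sources(path):
    # a wheel is a zip archive; map each .py member to its list of source lines
    with zipfile.ZipFile(path) as z:
        return {n: z.read(n).decode("utf-8", "replace").splitlines() for n in z.namelist() if n.endswith(".py")}

old, new = wheel_sources(OLD), wheel_sources(NEW)
for name in sorted(set(old) | set(new)):
    diff = list(difflib.unified_diff(old.get(name, []), new.get(name, []), lineterm=""))
    added = sum(1 for line in diff if line.startswith("+") and not line.startswith("+++"))
    removed = sum(1 for line in diff if line.startswith("-") and not line.startswith("---"))
    if added or removed: print(f"{name} +{added} -{removed}")
```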
tinygrad/renderer/assembly.py
CHANGED
@@ -1,11 +1,10 @@
 from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
-import struct
+import struct, math
 from collections import defaultdict
 from tinygrad.helpers import DEBUG
-from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
 from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
-from tinygrad.codegen.uops import UOpGraph, PatternMatcher
+from tinygrad.codegen.uops import UOps, UOp, UOpGraph, PatternMatcher, UPat
 from tinygrad.renderer import Renderer, TensorCore
 
 def render_val(x, dtype):
@@ -18,8 +17,8 @@ def render_val(x, dtype):
 class PTXRenderer(Renderer):
 device = "CUDA"
 suffix = "PTX"
-global_max =
-local_max =
+global_max = (2147483647, 65535, 65535)
+local_max = (1024, 1024, 64)
 shared_max = 49152
 tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
 def __init__(self, arch:str): self.tensor_cores = PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
@@ -30,29 +29,28 @@ class PTXRenderer(Renderer):
 .address_size 64
 .visible .entry"""
 barrier = "bar.sync\t0;"
-has_pred = True
-load_global = True
-label_prefix = "$"
 gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
 gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
 lid = [f'%tid.{chr(120+i)}' for i in range(3)]
 asm_for_op: Dict[Op, Callable] = {
-UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"
+UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"sub.{name} {d}, 0, {a};" if dtypes.is_unsigned(dt) \
+else f"neg.{name} {d}, {a};",
+UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
 UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
 UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
+BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
 BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
-BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
 BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
 BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
-BinaryOps.
+BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
 BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
 BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
-BinaryOps.
+BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
 TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
 TernaryOps.WHERE: lambda d,a,b,c,dt,name:
 f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
 }
-supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.
+supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
 TernaryOps.WHERE]
 # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
 types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
@@ -74,7 +72,7 @@ class PTXRenderer(Renderer):
 
 def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]
 
-def render_bra(self, b1, pred=None
+def render_bra(self, b1, pred=None) -> List[str]: return [f"@{pred} bra {b1};"] if pred else [f"bra {b1};"]
 
 def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
 assert dtype != dtypes.bool
@@ -118,14 +116,6 @@ class PTXRenderer(Renderer):
 if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
 return f"%{prefix}{c[prefix]-1}"
 
-c_label: DefaultDict[str, int] = defaultdict(int)
-r_label: Dict[UOp, str] = {}
-def ssa_label(prefix:str, u:UOp):
-nonlocal c_label, r_label
-c_label[prefix] += 1
-r_label[u] = f"{self.label_prefix}{prefix}_{c_label[prefix]-1}"
-return r_label[u]
-
 def const(x:ConstType, dtype:DType, mov=False):
 if mov or dtype in self.const_requires_mov:
 kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
@@ -140,42 +130,42 @@ class PTXRenderer(Renderer):
 return ret
 
 for u in uops:
-uop,dtype,
+uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
 if uop is UOps.IF:
-assert
-kk(*self.render_bra(
+assert src[0].dtype is not None
+kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
 elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
 elif uop is UOps.ENDRANGE:
-kk(self.asm_for_op[BinaryOps.ADD](r[
-self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[
-kk(*self.render_bra(
+kk(self.asm_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
+self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
+kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
 elif uop is UOps.ENDIF:
-kk(f"{
+kk(f"IF_{r[src[0].src[0]][1:]}_{cast(List, uops._uops).index(src[0])}:")
 elif uop is UOps.STORE:
-assert
-assert
-assert
-mem_type = '.shared' if
-if
-kk((f"@{r[
-f"st{mem_type}.v{
+assert src[0].dtype is not None and src[2].dtype is not None
+assert src[0].dtype == dtypes.int64, "store isn't int64"
+assert src[1].op is UOps.CONST, f"store isn't const {u}"
+mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
+if src[2].dtype.count > 1:
+kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
+f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
 else:
-kk(*self.render_store(r[
+kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
 else:
 assert dtype is not None, f"None dtype for uop {uop}"
-if uop is UOps.RANGE: kk(*self.render_loop(ssa('ridx', u), r[
+if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
 elif uop is UOps.ALU:
-assert
-if args is BinaryOps.CMPLT or args is BinaryOps.
+assert src[0].dtype is not None
+if args is BinaryOps.CMPLT or args is BinaryOps.CMPNE:
 # pass in the other dtype here
-kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in
+kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], src[0].dtype, self.types[src[0].dtype]))
 else:
-kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in
+kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], dtype, self.types[dtype]))
 elif uop is UOps.DEFINE_ACC:
 if dtype.count > 1:
 r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(
-else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(
+for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
+else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
 elif uop is UOps.SPECIAL:
 assert args[1][0] != "i", "idx not supported"
 kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
@@ -184,30 +174,30 @@ class PTXRenderer(Renderer):
 elif uop is UOps.CONST:
 if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
 else: r[u] = const(args, dtype, mov=True)
-elif uop is UOps.GEP: r[u] = r[
+elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
 elif uop is UOps.LOAD:
-assert
-assert
-mem_type = '.shared' if
+assert src[0].dtype == dtypes.int64, "load isn't int64"
+assert src[1].op is UOps.CONST, f"load isn't const {u}"
+mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
 if dtype.count > 1:
 r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-if(len(
+if(len(src)>3):
 for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
-kk((f"@{r[
-+ f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[
+kk((f"@{r[src[2]]}"if len(src) > 3 else "")
++ f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
 else:
-kk(*self.render_load(r[
-alt=r[
+kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if len(src) > 3 else None,
+alt=r[src[3]] if len(src) > 3 else None, ss=mem_type, offset=src[1].arg))
 elif uop is UOps.PHI:
 if dtype.count > 1:
-for x0, x1 in zip(r[
+for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
 else:
-kk(f"mov.b{self.types[dtype][1:]} {r[
-r[u] = r[
+kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
+r[u] = r[src[0]]
 elif uop in {UOps.CAST, UOps.BITCAST}:
-assert
-if dtype.count>1: r[u] = [r[x] for x in
-else: _cast(r[
+assert src[0].dtype is not None
+if dtype.count>1: r[u] = [r[x] for x in src] # type: ignore
+else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
 elif uop is UOps.DEFINE_LOCAL:
 # TODO: we should sum these, and fetch 0xC000 from somewhere
 assert args[1]*dtype.itemsize <= 0xC000, "too large local"
@@ -215,62 +205,65 @@ class PTXRenderer(Renderer):
 elif uop is UOps.DEFINE_VAR:
 bufs.append((args.expr, dtype))
 r[u] = f"%{args.expr}"
-
+kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
 elif uop is UOps.DEFINE_GLOBAL:
 bufs.append((nm:=f"data{args[0]}", dtype))
 r[u] = f"%{nm}"
-if
-
-kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
+dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
+kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
 elif uop is UOps.WMMA:
 wmma = []
-for vv in
+for vv in src[:2]:
 for i in range(0, len(r[vv]), 2):
 wmma.append(ssa("wmma", dtype="b32"))
 kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
 r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
 kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
-{{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[
+{{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[src[2]])}}};')
 else: raise NotImplementedError(f"no code for {uop}")
 
 return self.render_kernel(kernel, name, bufs, c.items())
 
 ptx_matcher = PatternMatcher([
-(
-
-
-
-
-
-
-
-lambda x: UOp(
+(UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
+src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
+lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL)),
+(UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
+src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
+lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR)),
+(UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
+(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
+lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
+(UPat(UOps.ALU, BinaryOps.ADD,
+[UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
+lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
+*[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
+lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),)))
 for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
-(
-
-
-
-
-
-
-
-
-lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
+(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
+lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
+(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
+lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.uint8, root.src, root.arg),))),
+(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
+lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
+(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
+lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
+(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
+lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
 # ptr_ar (load/store)
-({
-
-lambda root, alu, const: UOp(root.
-(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.
-UOp.const(const.dtype, root.
-({
-
-lambda root, const: UOp(root.
-UOp.const(dtypes.int64, const.arg * root.
-)+root.
-({
-
-lambda root, alu: UOp(root.
-(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.
-UOp.const(dtypes.int64, 0))+root.
+(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
+lambda root, alu, const: UOp(root.op, root.dtype,
+(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
+(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+UPat(UOps.CONST, name="const"))),
+lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64),
+UOp.const(dtypes.int64, const.arg * root.src[0].dtype.itemsize),
+)+root.src[2:])),
+(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+UPat(name="alu"))), # no const here
+lambda root, alu: UOp(root.op, root.dtype,
+(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+UOp.const(dtypes.int64, 0))+root.src[2:])),
 ])
tinygrad/renderer/cstyle.py
CHANGED
@@ -1,11 +1,10 @@
 from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
 import os, math
 from collections import defaultdict, Counter
-from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
-from tinygrad.helpers import strip_parens, getenv, prod
+from tinygrad.helpers import strip_parens, getenv, prod, dedup
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
-from tinygrad.codegen.uops import UOpGraph
+from tinygrad.codegen.uops import UOps, UOp, UOpGraph
 from tinygrad.renderer import Renderer, TensorCore
 
 class CStyleLanguage(Renderer):
@@ -25,10 +24,11 @@ class CStyleLanguage(Renderer):
 type_map: Dict[DType, str] = {}
 code_for_op: Dict = {
 UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype == dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
+UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
 UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
-BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.
-BinaryOps.
-BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.
+BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
+BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
+BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
 TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
 
 # returns a str expression of the casted xs with the given type
@@ -103,31 +103,32 @@ class CStyleLanguage(Renderer):
 c[prefix] += 1
 return ret
 
-child_count = Counter(v for ru in uops for v in ru.
+child_count = Counter(v for ru in uops for v in ru.src)
 
+seen_vars = set()
 for u in uops:
-uop,dtype,
+uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
 # these four uops don't have output dtypes
 if uop is UOps.IF:
-kk(f"if ({r[
+kk(f"if ({r[src[0]]}) {{")
 depth += 1
 elif uop is UOps.BARRIER: kk(self.barrier)
 elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
 depth -= 1
 kk("}")
 elif uop is UOps.STORE:
-assert
-rendered_store = self.render_store(r[
-kk(f"if ({r[
+assert src[0].dtype is not None and src[2].dtype is not None
+rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
+kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
 else:
 assert dtype is not None, f"None dtype for uop {uop}"
 if uop is UOps.RANGE:
-kk(f"for (int {(expr := ssa('ridx',u))} = {r[
+kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{")
 depth += 1
 elif uop is UOps.ALU:
 # remove parens if ALU types are the same. TODO: can do more here
-if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in
-else: operands = [r[v] for v in
+if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src]
+else: operands = [r[v] for v in src]
 val = self.code_for_op[args](*operands, dtype)
 assert child_count[u] != 0, f"childless ALU op found {u}"
 # TODO: fix index rendering issue. fix clang nested max macro issue
@@ -137,39 +138,41 @@ class CStyleLanguage(Renderer):
 kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
 r[u] = args[1]
 elif uop is UOps.LOAD:
-val = self.render_load(dtype, r[
+val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
 # NOTE: this relies on the load not happening if it's in the unselected branch
-if len(
+if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
 kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
 elif uop is UOps.PHI:
-kk(f"{r[
-r[u] = r[
+kk(f"{r[src[0]]} = {r[src[1]]};")
+r[u] = r[src[0]]
 elif uop in {UOps.CAST, UOps.BITCAST}:
 if uop is UOps.BITCAST:
-assert len(
+assert len(src) == 1
 precast = ssa('precast')
-kk(f"{self.render_dtype(cast(DType,
+kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
 val = self.render_cast([precast], dtype, bitcast=True)
 else:
-val = self.render_cast([r[x] for x in
+val = self.render_cast([r[x] for x in src], dtype, bitcast=False)
 if child_count[u] <= 1: r[u] = val
 else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
 elif uop is UOps.DEFINE_LOCAL:
 kk(self.render_local(args[0], dtype, args[1]))
 r[u] = args[0]
 elif uop is UOps.DEFINE_VAR:
+assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
+seen_vars.add(args.expr)
 bufs.append((args.expr, (dtype,False)))
 r[u] = args.expr
 elif uop is UOps.DEFINE_GLOBAL:
 bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
 r[u] = nm
-elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[
-elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(
+elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
+elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
 elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
 elif uop is UOps.GEP:
-assert
-from_ssa =
-r[u] = (r[
+assert src[0].dtype is not None
+from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
+r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
 else: raise RuntimeError(f"failed to render {uop}")
 
 return self.render_kernel(name, kernel, bufs, uops)
@@ -178,6 +181,7 @@ class ClangRenderer(CStyleLanguage):
 device = "CLANG"
 supports_float4 = False
 has_local = False
+global_max = None
 
 # language options
 buffer_suffix = " restrict"
@@ -219,6 +223,7 @@ class MetalRenderer(CStyleLanguage):
 float4 = "float4"
 uses_ptr_arithmetic = True
 code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
+# uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
 extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
 type_map = {dtypes.bfloat16: "bfloat"}
 code_for_op = {**CStyleLanguage().code_for_op,
@@ -232,14 +237,15 @@ class MetalRenderer(CStyleLanguage):
 return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
 
 def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
-prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.
+prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
 for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
 simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
 b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
 return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
 return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
-code_for_op_half = {
+code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"1/{x}",
+BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
 UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
 UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
 UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
@@ -252,8 +258,8 @@ def _make_cuda_dtype(base_type, name, cnt):
 
 class CUDARenderer(CStyleLanguage):
 device = "CUDA"
-global_max =
-local_max =
+global_max = (2147483647, 65535, 65535)
+local_max = (1024, 1024, 64)
 shared_max = 49152
 tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
 def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []
@@ -281,7 +287,7 @@ class CUDARenderer(CStyleLanguage):
 prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
 
 # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
-for arg in
+for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
 fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
 prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
 asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
@@ -312,8 +318,8 @@ def _make_hip_dtype(base_type, name, cnt):
 return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
 f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
 
-class
-device = "
+class AMDRenderer(CStyleLanguage):
+device = "AMD"
 shared_max = 65536
 tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
 
@@ -346,18 +352,18 @@ f""" __attribute__((device)) __attribute__((const)) {dt} __ocml_fmax_f{n}({dt},
 if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
 struct hip_bfloat16 {
 unsigned short data;
-__attribute__((device)) hip_bfloat16(float val) {
+inline __attribute__((device)) hip_bfloat16(float val) {
 union { float fp32; unsigned int u32; } u = {val};
 if (~u.u32 & 0x7f800000) { u.u32 += 0x7fff + ((u.u32 >> 16) & 1); } else if (u.u32 & 0xffff) { u.u32 |= 0x10000; }
 data = (u.u32 >> 16);
 }
-__attribute__((device)) operator float() const {
+inline __attribute__((device)) operator float() const {
 unsigned int uval = data << 16;
 return *reinterpret_cast<float*>(&uval);
 }
 };
-static __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
-static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
+static inline __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
+static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
 """)
 
 if any(uop.dtype == dtypes.half for uop in uops):
@@ -366,19 +372,18 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
 
 prefix += [_make_hip_dtype(*x) for x in vec_dts]
 
-for arg in
+for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
 if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
-else: prefix.append(f"static __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
+else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
 half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
 c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
 for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
 return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
 def get_kernel_modifier(self, uops:UOpGraph) -> str:
-requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.
+requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.op is UOps.SPECIAL and u.arg[1][0] == "l")
 # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
 # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
 return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
 
 class NVRenderer(CUDARenderer): device = "NV"
-class AMDRenderer(HIPRenderer): device = "AMD"
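For context on the code_for_op tables edited above: each entry maps an op to a lambda that returns a C expression string, so a backend renders a kernel by looking up each ALU uop and substituting the already-rendered operands. Below is a minimal standalone sketch of that table-driven rendering, using the same expression shapes as the diff (illustrative only, not tinygrad's CStyleLanguage class):

```python
# toy op table in the style of CStyleLanguage.code_for_op
code_for_op = {
    "ADD":   lambda a, b: f"({a}+{b})",
    "MUL":   lambda a, b: f"({a}*{b})",
    "WHERE": lambda a, b, c: f"({a}?{b}:{c})",
}

# render the C expression for: cond ? (x*y)+1 : 0
mul = code_for_op["MUL"]("x", "y")             # "(x*y)"
add = code_for_op["ADD"](mul, "1")             # "((x*y)+1)"
print(code_for_op["WHERE"]("cond", add, "0"))  # (cond?((x*y)+1):0)
```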