tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
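Most of the churn above is internal restructuring: `lazy.py` moves under `tinygrad/engine/`, the old `codegen/linearizer.py`, `codegen/uops.py`, `shape/symbolic.py`, and `renderer/assembly.py` are removed, and new `codegen/`, `runtime/support/`, and backend modules (`ops_cloud.py`, `ops_dsp.py`, `ops_qcom.py`, `ops_hip.py`, `renderer/ptx.py`) appear. For code that reached into these private modules, the path change looks roughly like the sketch below; the `LazyBuffer` symbol is an assumption for illustration, and none of this is public API.

```python
# Hedged sketch: the lazy module moved from tinygrad/lazy.py to tinygrad/engine/lazy.py
# in 0.10.0 (per the listing above). LazyBuffer is assumed to still be defined there.
try:
    from tinygrad.engine.lazy import LazyBuffer  # 0.10.0 layout
except ImportError:
    from tinygrad.lazy import LazyBuffer  # 0.9.1 layout
```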
tinygrad/renderer/llvmir.py
CHANGED
@@ -1,69 +1,77 @@
-from typing import
-
-from tinygrad.dtype import DType, PtrDType, dtypes
-from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
-from tinygrad.codegen.uops import UOps, UOp, UOpGraph
+from typing import List, Dict, cast
+import math, struct
 from tinygrad.renderer import Renderer
-
-
-
-def
-
-
-
-
-
-
-
-
-
-
-
-  BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
-  BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
-  BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
-  BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
-  BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
-  TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
-
-dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
-  dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),
-  dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType() }
-
-def cast(bb, val, input_type, output_type, bitcast=False):
-  if input_type == output_type: return val
-  llvm_type = dtype_to_llvm_dtype[output_type]
-  if bitcast: return bb[-1].bitcast(val, llvm_type)
-
-  if input_type == dtypes.bfloat16:
-    val = bb[-1].bitcast(bb[-1].shl(bb[-1].sext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)),val, ir.FloatType())
-    input_type = dtypes.float32
-  if output_type == dtypes.bfloat16:
-    val = cast(bb, val, input_type, dtypes.float32)
-    return bb[-1].trunc(bb[-1].lshr(bb[-1].bitcast(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.IntType(16))
-
+from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
+from tinygrad.dtype import dtypes, DType, PtrDType, truncate
+
+def ldt(dt:DType):
+  if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
+  return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
+          dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
+          dtypes.float16: "half", dtypes.float32: "float", dtypes.float64: "double", dtypes.bool: "i1", dtypes.void: "void"}[dt]
+
+def lconst(x, dtype:DType):
+  if dtype in dtypes.floats:
+    if math.isinf(x) or math.isnan(x): return "0x%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
+    return truncate[dtype](x)
+  return int(x)
+
+def lcast(input_type:DType, output_type:DType):
   if dtypes.is_float(input_type):
-    if dtypes.is_float(output_type):
-
-    if dtypes.is_int(output_type): return bb[-1].fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else bb[-1].fptosi(val, llvm_type)
-    if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0))
-
+    if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc'
+    if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi'
   if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
-    if
-    if dtypes.
-    if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].zext(val, llvm_type)
-    if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0))
-
+    if dtypes.is_float(output_type): return 'uitofp'
+    if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext'
  if dtypes.is_int(input_type):
-    if
-    if dtypes.
-    if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].sext(val, llvm_type)
-    if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0))
-
+    if dtypes.is_float(output_type): return 'sitofp'
+    if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
   raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")

-
+# llvm ops, lop[<dtype>][<op>]
+unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
+                 Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
+signed_lop = {**unsigned_lop, Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"}
+flags = " nsz arcp contract afn"
+float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags}
+lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}
+
+llvm_rewrite = PatternMatcher([
+  # memory load/store
+  (UPat(Ops.INDEX, name="x"), lambda ctx,x:
+    f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
+  (UPat(Ops.LOAD, src=(UPat.var('idx'), UPat.var('alt'), UPat.var('mask')), name="x"), lambda ctx,x,idx,alt,mask:
+    f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
+    f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
+    f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
+    f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
+    f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
+  (UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
+  (UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
+
+  # unary/binary/ternary ops
+  (UPat(Ops.SQRT, name="x"), lambda ctx,x:
+    f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
+  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
+  (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
+  (UPat(GroupOp.Binary, name="x"), lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
+  (UPat(Ops.WHERE, name="x"), lambda ctx,x:
+    f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
+
+  # range
+  (UPat(Ops.RANGE, name="x"), lambda ctx,x:
+    f" br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n"
+    f" br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n"
+    f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"),
+  (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
+    f" br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n"
+    f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
+    f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"),
+
+  # if
+  (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
+  (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
+])

 class LLVMRenderer(Renderer):
   device = "LLVM"
@@ -72,89 +80,63 @@ class LLVMRenderer(Renderer):
   has_shared = False
   global_max = None

-
-  #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  extra_matcher = PatternMatcher([
+    # rewrite RECIP with FDIV
+    (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
+    # rewrite cast to bool to CMPNE 0
+    (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
+    # *** also in cstyle ***
+    # gate any stores that aren't gated with ifs
+    (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+      lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
+    # rewrite MAX to CMPLT + WHERE
+    (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+  ])
+
+  def render(self, name: str, uops: List[UOp]) -> str:
+    r: Dict[UOp, str] = {}
+    args: List[str] = []
+    kernel: List[str] = []
+    end_lines: Dict[str, None] = {}
+    vc = -1
+
+    # prealloc all assigns
+    acc_to_assign: Dict[UOp, UOp] = {}
+    for u in uops:
+      if u.op is Ops.ASSIGN:
+        vc += 1
+        r[u] = r[u.src[1]] = f"%assign{vc}"
+        assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
+        acc_to_assign[u.src[0]] = u.src[1]

     for u in uops:
-
-      if
-
-
-
-
-
-      elif
-
-        idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
-        lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
-        for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
-        bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
-        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
+      # hack for defining sqrt function (TODO: can we get a transcendental for this?)
+      if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
+
+      if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
+        r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
+        args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
+      elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
+      elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
+      elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
+      elif u.op is Ops.CAST and ldt(u.dtype) == ldt(u.src[0].dtype): r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop
       else:
-
-        if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if len(src) > 2:
-          aug_idx = bb[-1].select(lvars[src[2]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
-          val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
-          val = bb[-1].select(lvars[src[2]], val, lvars[src[3]])
-        else:
-          val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
-        lvars[u] = val
-      elif uop is UOps.PHI:
-        lvars[u] = lvars[src[1]]
-        # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
-        backward = src[0]
-        while backward.op is UOps.PHI: backward = backward.src[0]
-        lvars[backward] = lvars[u]
-      elif uop is UOps.ALU:
-        lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else src[0].dtype)
-      elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
-      elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
-      elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
-      elif uop is UOps.CONST: lvars[u] = const(args, dtype)
-      else: raise RuntimeError(f"failed to render {uop}")
-
-    bb[-1].ret_void()
-    return str(module)
+        # if it's an assign target, it's already preallocated
+        if u not in r:
+          vc += 1
+          r[u] = f"%v{vc}"
+
+        # do the rendering of the llvm ir code
+        if (l:=llvm_rewrite.rewrite(u, ctx=r)) is None: raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
+        kernel.append(cast(str, l))
+
+      # generate the phi nodes for the assigns
+      if u.op is Ops.RANGE:
+        for x in acc_to_assign:
+          if u in x.src: # if this range is relevent for this acc
+            vc += 1
+            kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]")
+            r[x] = f"%acc{vc}"
+
+    # output the function
+    return f"define void @{name}({','.join(args)}) {{\n" + '\n'.join(kernel) + "\n ret void\n}\n"+'\n'.join(end_lines.keys())
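The rewrite above replaces the llvmlite IRBuilder-driven renderer of 0.9.1 with one that emits LLVM IR as plain text: each UOp is looked up in the `llvm_rewrite` PatternMatcher and rendered to a format string, with `%v{n}` names preallocated per op. Below is a simplified, self-contained sketch of that table-driven style; the `Node` and `table` names are invented for illustration and are not tinygrad's API.

```python
# Toy sketch of table-driven IR-text rendering, assuming a tiny UOp-like node type.
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    op: str
    dtype: str
    src: tuple = ()

def render(nodes):
    r, lines, vc = {}, [], 0           # r maps each node to its rendered name
    table = {
        "ADD":   lambda n: f"  {r[n]} = add {n.dtype} {r[n.src[0]]}, {r[n.src[1]]}",
        "MUL":   lambda n: f"  {r[n]} = mul {n.dtype} {r[n.src[0]]}, {r[n.src[1]]}",
        "CONST": lambda n: None,       # constants are rendered inline, no instruction emitted
    }
    for n in nodes:
        if n.op == "CONST": r[n] = str(n.src[0])
        else: r[n], vc = f"%v{vc}", vc + 1
        if (line := table[n.op](n)) is not None: lines.append(line)
    return "\n".join(lines)

a, b = Node("CONST", "i32", (2,)), Node("CONST", "i32", (3,))
c = Node("ADD", "i32", (a, b))
print(render([a, b, c, Node("MUL", "i32", (c, c))]))  # two lines of LLVM-looking IR
```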
tinygrad/renderer/ptx.py
ADDED
@@ -0,0 +1,225 @@
+from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tuple
+import struct
+from collections import defaultdict
+from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
+from tinygrad.dtype import dtypes, DType, PtrDType
+from tinygrad.renderer import Renderer
+from tinygrad.renderer.cstyle import CUDARenderer
+from tinygrad.helpers import prod, flatten
+
+def render_val(x, dtype):
+  if dtypes.is_float(dtype):
+    if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
+    if dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
+    return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
+  return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
+
+asm_for_op: Dict[Ops, Callable] = {
+  Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
+  Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
+  Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
+  Ops.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", Ops.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
+  Ops.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
+  Ops.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
+  Ops.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
+  Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
+  Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
+  Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
+  Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
+  Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
+  Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
+  Ops.WHERE: lambda d,a,b,c,dt,name:
+    f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
+}
+
+supports_half: List[Ops] = [Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE]
+doesnt_support_half: Tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
+ptx_matcher = PatternMatcher([
+  # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
+  (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
+  (UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
+  # upcast to float32 all the ops that don't support half
+  (UPat(doesnt_support_half, dtype=dtypes.half, name="x"),
+    lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))),
+  # load/store bool -> uint8
+  (UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
+    lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
+  (UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
+    lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
+  # load/store use pointer arithmetic, and the cast does nothing
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
+  (UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
+  # ptx shr and shl instructions require y to be uint
+  (UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
+  (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
+])
+
+def mem_type(x: UOp): return 'shared' if x.src[0].op is Ops.DEFINE_LOCAL or any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].parents) else 'global'
+
+def render_store(ctx: "PTXRenderer", x: UOp, bidx: UOp, var: UOp, pred: Optional[UOp]=None):
+  gate = f"@{ctx.r[pred]} " if pred is not None and pred.op is not Ops.IF else ""
+  return [f"{gate}st.{mem_type(bidx)}.v{var.dtype.count}.{ctx.mem_types[var.dtype.scalar()]} [{ctx.r[bidx]}+0], {{{', '.join(ctx.r[var])}}};"] \
+    if var.dtype.count > 1 else [f"{gate}st.{mem_type(bidx)}.{ctx.mem_types[var.dtype]} [{ctx.r[bidx]}+0], {ctx.r[var]};"]
+
+def render_wmma(ctx: "PTXRenderer", x: UOp):
+  assert ctx.wmma_r, "registry values for wmma must be populated"
+  _, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = x.arg
+  n_operands = tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2])
+  dt_map = { dtypes.half: "f16" }
+  _i = 0
+  for vv in x.src[:2]:
+    for i in range(0, len(ctx.r[vv]), 2):
+      yield f"mov.b32 {ctx.wmma_r[_i]}, {{{', '.join(ctx.r[vv][i:i+2])}}};"
+      _i += 1
+  yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32{" "*12}' +\
+    f'{{{", ".join(ctx.r[x])}}}, {{{", ".join(ctx.wmma_r[:n_operands[0]])}}}, {{{", ".join(ctx.wmma_r[-n_operands[1]:])}}}, ' + \
+    f'{{{", ".join(ctx.r[x.src[2]])}}};'
+
+def modifier(a: DType, b: DType): return '.rzi' if dtypes.is_int(a) and dtypes.is_float(b) else '.rn' if dtypes.is_float(a) and \
+  (a.itemsize < b.itemsize or dtypes.is_int(b) or b == dtypes.bool) else ''
+
+string_rewrite = PatternMatcher([
+  (UPat(Ops.CONST, name="x", dtype=dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"),
+  (UPat(Ops.CONST, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"),
+  (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var"), UPat.var("pred")), allow_any_len=True), render_store),
+  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),
+  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
+  (UPat((Ops.CMPLT, Ops.CMPNE), name="x"),
+    lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.src[0].dtype, ctx.types[x.src[0].dtype])),
+  (UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
+  (UPat(Ops.BITCAST, name="x", src=(UPat.var("a")), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
+  (UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"))),
+    lambda ctx, x, a: f"selp.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(1, x.dtype)}, {render_val(0, x.dtype)}, {ctx.r[a]};"),
+  (UPat(Ops.CAST, name="x", dtype=dtypes.bool),
+    lambda ctx, x: f"setp.ne.b{ctx.types[x.src[0].dtype][1:]} {ctx.r[x]}, {ctx.r[x.src[0]]}, {render_val(0, x.src[0].dtype)};"),
+  (UPat(Ops.CAST, name="x", src=(UPat.var("a"))),
+    lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.types[x.dtype]}.{ctx.types[x.src[0].dtype]} {ctx.r[x]}, {ctx.r[x.src[0]]};"),
+  (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'), UPat(name='alt'), UPat(name="gate", op=GroupOp.ALU))), lambda ctx, x, loc, alt, gate: flatten([
+    [f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
+    [f"@{ctx.r[gate]} ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
+  ]) if alt.dtype.count > 1 else [
+    f"@{ctx.r[gate]} ld.{mem_type(x)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
+    f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
+  (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'),), allow_any_len=True),
+    lambda ctx, x, loc: f" ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
+      if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.bool),), allow_any_len=True),
+    lambda ctx, x, pred: flatten([
+      [f"setp.ne.s16 {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())}, 0;",
+       f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.half),), allow_any_len=True),
+    lambda ctx, x, pred: flatten([[f"mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())};",
+      f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE),), allow_any_len=True), lambda ctx, x, pred: [
+    f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {render_val(pred.src[0].arg, x.dtype.scalar())};" for i, uu in enumerate(ctx.r[x])]),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST, dtype=dtypes.bool), ), allow_any_len=True), lambda ctx, x, pred: [
+    f"setp.ne.s16 {ctx.r[pred]}, {render_val(pred.arg, pred.dtype)}, 0;", f"mov.pred {ctx.r[x]}, {ctx.r[pred]};"]),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST), ), allow_any_len=True),
+    lambda ctx, x, pred: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(pred.arg, x.dtype)};"),
+  (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, {ctx.r[x.src[0]]};", "LOOP_" + f"{ctx.r[x][1:]}:"]),
+  (UPat(Ops.ASSIGN, name="x", dtype=dtypes.bool), lambda ctx, x: [f"mov.pred {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"]),
+  (UPat(Ops.ASSIGN, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"),
+  (UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
+    ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
+    ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[1]], dtypes.int, ctx.types[dtypes.int]),
+    f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
+  (UPat(Ops.DEFINE_LOCAL, name="x"),
+    lambda ctx, x: [f".shared .align 4 .b8 {x.arg[0]}[{x.arg[1]*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg[0]}[0];"]),
+  (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
+  (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
+  (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
+  (UPat(Ops.BARRIER, name="x"), lambda ctx, x: ctx.barrier),
+  (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"),
+])
+
+class PTXRenderer(Renderer):
+  device = "CUDA"
+  suffix = "PTX"
+  global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
+  tensor_cores = [tc for tc in CUDARenderer.tensor_cores if tc.dtype_in == dtypes.half]
+  code_for_op = asm_for_op
+  extra_matcher = ptx_matcher
+  def __init__(self, arch:str, device="CUDA"):
+    self.device, self.tensor_cores, self.arch = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
+  def __reduce__(self): return self.__class__, (self.arch, self.device)
+
+  # language options
+  kernel_prefix = """.version VERSION
+.target TARGET
+.address_size 64
+.visible .entry"""
+  barrier = "bar.sync\t0;"
+  supports_half = supports_half
+  # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
+  types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
+                              dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
+                              dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
+
+  mem_types: Dict[DType, str] = types.copy()
+  mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})
+
+  def render_kernel(self, kernel, function_name, bufs, regs) -> str:
+    kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
+    def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
+    return (f"{self.kernel_prefix} {function_name}(\n\t" +
+            ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
+            '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
+            "\n}")
+
+  def render(self, name:str, uops:List[UOp]) -> str:
+    kernel:List[str] = []
+    bufs = []
+
+    c: DefaultDict[str, int] = defaultdict(int)
+    r: Dict[UOp, Union[List[str], str]] = {}
+    self.r = r
+    self.uops = uops
+
+    def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
+      nonlocal c, r
+      prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype]}_"
+      c[prefix] += 1
+      return f"%{prefix}{c[prefix]-1}"
+
+    for u in uops:
+      if u.op is Ops.VECTORIZE:
+        r[u] = [cast(str,r[x]) for x in u.src]
+        continue
+      if u.op is Ops.GEP:
+        assert len(u.arg) == 1
+        r[u] = r[u.src[0]][u.arg[0]]
+        continue
+      if u.op in {Ops.CAST, Ops.BITCAST}:
+        if u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType):
+          r[u] = r[u.src[0]]
+          continue
+        r[u] = ssa('cast', u, self.types[u.dtype])
+      elif u.op is Ops.ENDRANGE: r[u] = ssa("pred", u, dtype="pred")
+      elif u.op is Ops.RANGE: r[u] = ssa("ridx", u)
+      elif u.op in GroupOp.ALU: r[u] = ssa("alu", u)
+      elif u.op is Ops.DEFINE_ACC:
+        if u.dtype.scalar() in [dtypes.half, dtypes.bool]:
+          r[u.src[0]] = [ssa("const", u.src[0].src[0]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("const", u.src[0])
+        r[u] = [ssa('acc', u, dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("acc", u)
+      elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
+      elif u.op is Ops.DEFINE_VAR:
+        bufs.append((u.arg[0], u.dtype))
+        r[u] = ssa("dat", u, self.types[u.dtype])
+      elif u.op is Ops.CONST: r[u] = ssa("const", u, dtype=self.types[u.dtype])
+      elif u.op is Ops.LOAD:
+        assert u.src[0].dtype == dtypes.int64, "load isn't int64"
+        r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
+      elif u.op is Ops.DEFINE_LOCAL: r[u] = ssa('local', u, self.types[dtypes.ulong])
+      elif u.op is Ops.DEFINE_GLOBAL:
+        bufs.append((f"data{u.arg}", u.dtype))
+        r[u] = ssa('dat', u, self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
+      elif u.op is Ops.WMMA:
+        self.wmma_r = [ssa("wmma", dtype="b32") for vv in u.src[:2] for i in range(0, len(r[vv]), 2)]
+        r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
+      if (l:=cast(Union[str, List[str]], string_rewrite.rewrite(u, ctx=self))) is None:
+        raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.u.src]}")
+      kernel.extend([l] if isinstance(l, str) else l)
+
+      if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
+      elif u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel
+    return self.render_kernel(kernel, name, bufs, c.items())
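`PTXRenderer.render` above makes two passes per UOp: it first reserves virtual register names through `ssa` (a per-prefix/type counter), then lets `string_rewrite` emit the instruction text, and `render_kernel` finally expands the counters into `.reg` declarations. Below is a minimal standalone sketch of that counter-based naming scheme; it is a simplification for illustration (the real `ssa` derives the type suffix from the UOp's dtype), not the actual tinygrad code.

```python
# Toy sketch of SSA-style register naming with per-(prefix, type) counters.
from collections import defaultdict

def make_ssa():
    counters: defaultdict = defaultdict(int)
    def ssa(prefix: str, dtype: str) -> str:
        key = f"{prefix}_{dtype}_"
        name = f"%{key}{counters[key]}"   # e.g. %alu_f32_0
        counters[key] += 1
        return name
    return ssa, counters

ssa, counters = make_ssa()
print(ssa("alu", "f32"), ssa("alu", "f32"), ssa("ridx", "s32"))
# -> %alu_f32_0 %alu_f32_1 %ridx_s32_0
# render_kernel-style declarations derived from the counters:
print([f".reg .{k.split('_')[-2]} %{k}<{v}>;" for k, v in counters.items()])
# -> ['.reg .f32 %alu_f32_<2>;', '.reg .s32 %ridx_s32_<1>;']
```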