tinygrad 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/renderer/assembly.py
CHANGED
@@ -1,10 +1,9 @@
 from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
 import struct, math
 from collections import defaultdict
-from tinygrad.helpers import DEBUG
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
 from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
-from tinygrad.codegen.uops import UOps, UOp,
+from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat
 from tinygrad.renderer import Renderer, TensorCore
 
 def render_val(x, dtype):
@@ -14,14 +13,85 @@ def render_val(x, dtype):
     return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
   return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
 
+asm_for_op: Dict[Op, Callable] = {
+  UnaryOps.NEG: lambda d,a,dt,name:
+    f"not.pred {d}, {a};" if name == "pred" else f"sub.{name} {d}, 0, {a};" if dtypes.is_unsigned(dt) else f"neg.{name} {d}, {a};",
+  UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
+  UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
+  UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
+  BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
+  BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
+  BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
+  BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
+  BinaryOps.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
+  BinaryOps.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
+  BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
+  BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
+  BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
+  TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
+  TernaryOps.WHERE: lambda d,a,b,c,dt,name:
+    f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
+}
+
+supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
+shiftable_consts = set([2**i for i in range(64)])
+ptx_matcher = PatternMatcher([
+  (UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
+        src=[UPat(UOps.CONST, name="const"), UPat(name="mul")]),
+   lambda root, mul, const: UOp(UOps.ALU, root.dtype,
+     (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL) if const.arg in shiftable_consts else None),
+  (UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
+        src=[UPat(UOps.CONST, name="const"), UPat(name="div")]),
+   lambda root, div, const: UOp(UOps.ALU, root.dtype,
+     (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR) if const.arg in shiftable_consts else None),
+  (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
+  (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
+   lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
+  (UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
+   lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
+  *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
+     lambda x: (UOp(x.op, dtypes.float32, tuple([vv.cast(dtypes.float32) for vv in x.src]), x.arg).cast(dtypes.half)))
+    for op in asm_for_op.keys() if op not in supports_half],
+  (UPat(UOps.ALU, name="x", dtype=dtypes.bool, arg=BinaryOps.MAX),
+   lambda x: UOp(UOps.ALU, dtypes.uint8, tuple(s.cast(dtypes.uint8) for s in x.src), x.arg).cast(dtypes.bool)),
+  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
+   lambda root,x,y,z,k: UOp(root.op, dtypes.uint8, (x,y,z.cast(dtypes.uint8),k)).cast(dtypes.bool)),
+  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
+   lambda root: UOp(root.op, dtypes.uint8, root.src, root.arg).cast(dtypes.bool)),
+  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
+   lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
+  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
+   lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (z.cast(dtypes.uint8),), root.arg)),
+  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
+   lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (g.cast(dtypes.uint8),), root.arg)),
+  # ptr_ar (load/store)
+  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+        UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
+   lambda root, alu, const: UOp(root.op, root.dtype,
+     (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+      const.const(root.src[0].dtype.itemsize)*const)+root.src[2:])),
+  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+        UPat(UOps.CONST, name="const"))),
+   lambda root, const: UOp(root.op, root.dtype,
+     (root.src[0].cast(dtypes.int64),
+      UOp.const(dtypes.int64, const.arg*root.src[0].dtype.itemsize),)+root.src[2:])),
+  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_any_len=True, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+        UPat(name="alu"))), # no const here
+   lambda root, alu: UOp(root.op, root.dtype,
+     (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+      UOp.const(dtypes.int64, 0))+root.src[2:])),
+])
+
 class PTXRenderer(Renderer):
   device = "CUDA"
   suffix = "PTX"
   global_max = (2147483647, 65535, 65535)
   local_max = (1024, 1024, 64)
   shared_max = 49152
-  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(
-
+  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
+  code_for_op = asm_for_op
+  extra_matcher = ptx_matcher
+  def __init__(self, arch:str, device="CUDA"): self.device, self.tensor_cores = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
 
   # language options
   kernel_prefix = """.version VERSION
@@ -29,29 +99,7 @@ class PTXRenderer(Renderer):
 .address_size 64
 .visible .entry"""
   barrier = "bar.sync\t0;"
-
-  gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
-  lid = [f'%tid.{chr(120+i)}' for i in range(3)]
-  asm_for_op: Dict[Op, Callable] = {
-    UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"sub.{name} {d}, 0, {a};" if dtypes.is_unsigned(dt) \
-      else f"neg.{name} {d}, {a};",
-    UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
-    UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
-    UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
-    BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
-    BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
-    BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
-    BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
-    BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
-    BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
-    BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
-    BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
-    TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
-    TernaryOps.WHERE: lambda d,a,b,c,dt,name:
-      f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
-  }
-  supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
-                             TernaryOps.WHERE]
+  supports_half = supports_half
   # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
   types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
                               dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
@@ -98,13 +146,10 @@ class PTXRenderer(Renderer):
             '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
             "\n}")
 
-  def render(self, name:str, uops:
+  def render(self, name:str, uops:List[UOp]) -> str:
     kernel:List[str] = []
     bufs = []
 
-    uops.linearize(ptx_matcher)
-    if DEBUG >= 4: uops.print()
-
     def kk(*s: str): kernel.append("\n".join(s))
 
     c: DefaultDict[str, int] = defaultdict(int)
@@ -133,14 +178,14 @@ class PTXRenderer(Renderer):
       uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
       if uop is UOps.IF:
         assert src[0].dtype is not None
-        kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{
+        kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
       elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
       elif uop is UOps.ENDRANGE:
-        kk(self.
-           self.
+        kk(self.code_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
+           self.code_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
        kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
       elif uop is UOps.ENDIF:
-        kk(f"IF_{r[src[0].src[0]][1:]}_{
+        kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:")
       elif uop is UOps.STORE:
         assert src[0].dtype is not None and src[2].dtype is not None
         assert src[0].dtype == dtypes.int64, "store isn't int64"
@@ -156,58 +201,54 @@ class PTXRenderer(Renderer):
        if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
        elif uop is UOps.ALU:
          assert src[0].dtype is not None
-          if args
-
-            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], src[0].dtype, self.types[src[0].dtype]))
-          else:
-            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], dtype, self.types[dtype]))
+          src_dtype = src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype
+          kk(self.code_for_op[args](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
        elif uop is UOps.DEFINE_ACC:
          if dtype.count > 1:
            r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-            for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
-          else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
+            for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].src[0].arg, dtype.scalar())};")
+          else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
        elif uop is UOps.SPECIAL:
-          assert args[
-          kk(f"mov.u32 %{args[
-          r[u] = "%" + args[
-          kernel = [f".reg .u32 %{args[
-        elif uop is UOps.
-
-
+          assert args[0][0] != "i", "idx not supported"
+          kk(f"mov.u32 %{args[0]}, %{'ctaid' if args[0][0] == 'g' else 'tid'}.{chr(120+int(args[0][-1]))};")
+          r[u] = "%" + args[0]
+          kernel = [f".reg .u32 %{args[0]};"] + kernel
+        elif uop is UOps.DEFINE_VAR:
+          bufs.append((args.expr, dtype))
+          r[u] = f"%{args.expr}"
+          kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
+        elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True)
        elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
        elif uop is UOps.LOAD:
          assert src[0].dtype == dtypes.int64, "load isn't int64"
          assert src[1].op is UOps.CONST, f"load isn't const {u}"
          mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
+          has_gate = len(src) > 3 and src[3].op is UOps.ALU
          if dtype.count > 1:
            r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-            if
+            if has_gate:
              for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
-            kk((f"@{r[src[
+            kk((f"@{r[src[3]]}"if has_gate else "")
               + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
          else:
-            kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[
-                                 alt=r[src[
+            kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[3]] if has_gate else None,
+                                 alt=r[src[2]] if has_gate else None, ss=mem_type, offset=src[1].arg))
        elif uop is UOps.PHI:
          if dtype.count > 1:
            for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
-          else:
-            kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
+          else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};")
          r[u] = r[src[0]]
+        # NOTE: casting to str is fine because you can't vectorize a vectorize
+        elif uop is UOps.VECTORIZE: r[u] = [cast(str,r[x]) for x in src]
        elif uop in {UOps.CAST, UOps.BITCAST}:
-          assert src[0].dtype is not None
-
-          else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
+          assert src[0].dtype is not None and dtype.count == 1
+          _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
        elif uop is UOps.DEFINE_LOCAL:
          # TODO: we should sum these, and fetch 0xC000 from somewhere
          assert args[1]*dtype.itemsize <= 0xC000, "too large local"
          kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
-        elif uop is UOps.DEFINE_VAR:
-          bufs.append((args.expr, dtype))
-          r[u] = f"%{args.expr}"
-          kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
        elif uop is UOps.DEFINE_GLOBAL:
-          bufs.append((nm:=f"data{args
+          bufs.append((nm:=f"data{args}", dtype))
          r[u] = f"%{nm}"
          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
          kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
@@ -224,46 +265,3 @@ class PTXRenderer(Renderer):
 
     return self.render_kernel(kernel, name, bufs, c.items())
 
-ptx_matcher = PatternMatcher([
-  (UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
-    src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
-    lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL)),
-  (UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
-    src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
-    lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR)),
-  (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
-  (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
-    lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
-  (UPat(UOps.ALU, BinaryOps.ADD,
-    [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
-    lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
-  *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
-    lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),)))
-    for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
-  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
-    lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
-  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
-    lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.uint8, root.src, root.arg),))),
-  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
-    lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
-  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
-    lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
-  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
-    lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
-  # ptr_ar (load/store)
-  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
-    UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
-    lambda root, alu, const: UOp(root.op, root.dtype,
-      (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
-       UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
-  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
-    UPat(UOps.CONST, name="const"))),
-    lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64),
-      UOp.const(dtypes.int64, const.arg * root.src[0].dtype.itemsize),
-    )+root.src[2:])),
-  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
-    UPat(name="alu"))), # no const here
-    lambda root, alu: UOp(root.op, root.dtype,
-      (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
-       UOp.const(dtypes.int64, 0))+root.src[2:])),
-])
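The structural change worth noting in this file: the PTX-specific rewrites (power-of-two MUL/IDIV turned into shifts, bool values routed through uint8, pointer arithmetic folded into LOAD/STORE addresses) now live in a module-level ptx_matcher PatternMatcher that PTXRenderer exposes as extra_matcher, instead of being applied inside render() via uops.linearize(ptx_matcher). As a rough illustration of what one such rule does, here is a minimal self-contained sketch of the multiply-by-power-of-two-to-shift rewrite on a toy expression type; it does not use the tinygrad UPat/UOp API, and Node/rewrite_mul_pow2 are invented names for the example.

# Standalone sketch (not the tinygrad UPat/UOp API); Node and rewrite_mul_pow2
# are invented for illustration. It mirrors the MUL-by-power-of-two -> SHL rule.
import math
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class Node:
    op: str                          # "const", "var", "mul", "shl"
    args: Tuple["Node", ...] = ()
    value: Optional[int] = None

def rewrite_mul_pow2(n: Node) -> Node:
    """Rewrite mul(x, 2**k) into shl(x, k) bottom-up; leave other nodes alone."""
    args = tuple(rewrite_mul_pow2(a) for a in n.args)
    if n.op == "mul":
        for x, c in ((args[0], args[1]), (args[1], args[0])):
            # a positive power-of-two constant becomes a shift by log2(const)
            if c.op == "const" and c.value is not None and c.value > 0 and c.value & (c.value - 1) == 0:
                return Node("shl", (x, Node("const", value=int(math.log2(c.value)))))
    return Node(n.op, args, n.value)

# mul(var0, 8) rewrites to shl(var0, 3)
expr = Node("mul", (Node("var", value=0), Node("const", value=8)))
print(rewrite_mul_pow2(expr))

In the real matcher the same idea is written declaratively: a UPat describes the shape to match (an integer ALU MUL with a CONST source) and a lambda builds the replacement UOp, returning None to decline the rewrite when the constant is not in shiftable_consts.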