tinygrad-0.9.1-py3-none-any.whl → tinygrad-0.9.2-py3-none-any.whl
This diff reflects the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- tinygrad/codegen/kernel.py +248 -115
- tinygrad/codegen/lowerer.py +215 -0
- tinygrad/codegen/transcendental.py +310 -0
- tinygrad/codegen/uopgraph.py +622 -0
- tinygrad/codegen/uops.py +235 -393
- tinygrad/device.py +428 -69
- tinygrad/dtype.py +18 -4
- tinygrad/engine/graph.py +19 -32
- tinygrad/engine/jit.py +148 -70
- tinygrad/engine/realize.py +127 -51
- tinygrad/engine/schedule.py +259 -216
- tinygrad/engine/search.py +29 -22
- tinygrad/function.py +9 -0
- tinygrad/helpers.py +87 -49
- tinygrad/lazy.py +34 -35
- tinygrad/multi.py +41 -36
- tinygrad/nn/__init__.py +39 -22
- tinygrad/nn/state.py +3 -3
- tinygrad/ops.py +63 -62
- tinygrad/renderer/__init__.py +43 -21
- tinygrad/renderer/assembly.py +104 -106
- tinygrad/renderer/cstyle.py +87 -60
- tinygrad/renderer/llvmir.py +21 -30
- tinygrad/runtime/autogen/amd_gpu.py +25208 -5753
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/kfd.py +32 -0
- tinygrad/runtime/autogen/libc.py +4260 -0
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/graph/clang.py +2 -2
- tinygrad/runtime/graph/cuda.py +8 -11
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +18 -15
- tinygrad/runtime/ops_amd.py +197 -305
- tinygrad/runtime/ops_clang.py +2 -2
- tinygrad/runtime/ops_cuda.py +36 -94
- tinygrad/runtime/ops_disk.py +3 -7
- tinygrad/runtime/ops_gpu.py +4 -2
- tinygrad/runtime/ops_hip.py +70 -0
- tinygrad/runtime/ops_metal.py +38 -27
- tinygrad/runtime/ops_nv.py +283 -363
- tinygrad/runtime/ops_python.py +26 -30
- tinygrad/runtime/support/compiler_cuda.py +78 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +15 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/shape/shapetracker.py +5 -14
- tinygrad/shape/symbolic.py +4 -8
- tinygrad/shape/view.py +34 -22
- tinygrad/tensor.py +399 -97
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/METADATA +49 -48
- tinygrad-0.9.2.dist-info/RECORD +70 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/runtime/{driver → support}/__init__.py +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.9.2.dist-info}/top_level.txt +0 -0
tinygrad/renderer/cstyle.py
CHANGED
@@ -4,7 +4,7 @@ from collections import defaultdict, Counter
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
 from tinygrad.helpers import strip_parens, getenv, prod, dedup
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
-from tinygrad.codegen.uops import UOps, UOp
+from tinygrad.codegen.uops import UOps, UOp
 from tinygrad.renderer import Renderer, TensorCore
 
 class CStyleLanguage(Renderer):
@@ -22,6 +22,8 @@ class CStyleLanguage(Renderer):
   uses_vload: bool = False
   uses_ptr_arithmetic: bool = False
   type_map: Dict[DType, str] = {}
+  infinity: str = "INFINITY"
+  nan: str = "NAN"
   code_for_op: Dict = {
     UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype == dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
     UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
@@ -29,24 +31,30 @@ class CStyleLanguage(Renderer):
     BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
     BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
     BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
+    BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})",
     TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
 
   # returns a str expression of the casted xs with the given type
-  def render_cast(self, x:
-    if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x
-
+  def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
+    if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x}))"
+    return f"({self.render_dtype(var_dtype)})({x})"
+
+  # returns a str expression of the vectorized xs with the given type
+  def render_vectorize(self, x:List[str], var_dtype:DType) -> str:
     assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
     assert self.float4 is not None, "vectorized cast is not supported on this platform"
-    return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
+    return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}" + (f"{{{','.join(x)}}}" if self.device == "CLANG" else f"({','.join(x)})")
 
   # returns a str expression of the const with the given type
   def render_const(self, x:ConstType, dtype:DType) -> str:
-
-
+    assert dtype.count == 1, f"consts should be scalar, got {dtype}"
+    if math.isnan(x): val = self.nan
+    elif math.isinf(x): val = ("-" if x < 0 else "") + self.infinity
     elif dtype == dtypes.bool: val = "1" if x else "0"
     elif dtype == dtypes.float: val = f"{x}f"
+    elif dtype == dtypes.uint64: val = f"{x}ULL"
     else: val = str(x)
-    return (self.render_cast(
+    return (self.render_cast(val, dtype) if dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
 
   # returns a str expression of the loaded value with the output type
   def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
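For orientation, a minimal usage sketch of the new render_const path (assuming the 0.9.2 wheel is installed; the expected strings are inferred from the code above, not verified output):

from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.dtype import dtypes

r = CStyleLanguage()
# NaN/inf constants now go through the renderer's nan/infinity attributes
print(r.render_const(float("nan"), dtypes.float))    # expected: NAN
print(r.render_const(float("inf"), dtypes.float))    # expected: INFINITY
print(r.render_const(-float("inf"), dtypes.float))   # expected: -INFINITY
# uint64 constants pick up the new ULL suffix and are cast to the rendered dtype
print(r.render_const(5, dtypes.uint64))              # expected: roughly (unsigned long)(5ULL)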
@@ -56,13 +64,11 @@ class CStyleLanguage(Renderer):
     if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
       return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
     if output_dtype.count > 1:
-
-    else
-      out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
-    return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
+      return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501
+    return f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
 
-  def get_kernel_modifier(self, uops:
-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:
+  def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
+  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
     buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
                 ("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
@@ -81,15 +87,16 @@ class CStyleLanguage(Renderer):
       return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
     if var_dtype.count > 1:
       prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
-      return f"*(({prefix}{self.render_dtype(
+      return f"*(({prefix}{self.render_dtype(var_dtype)}*)({buf_name}+{idx})) = {var_name};"
     return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
 
   def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"
-  def render_dtype(self, var_dtype:DType) -> str:
+  def render_dtype(self, var_dtype:DType) -> str:
+    return self.type_map.get(scalar:=var_dtype.scalar(), scalar.name) + (str(var_dtype.count) if (var_dtype.count) > 1 else "")
 
-  def render(self, name:str, uops:
+  def render(self, name:str, uops:List[UOp]) -> str:
     kernel = []
-    bufs:
+    bufs: Dict[UOp, Tuple[str, Tuple[DType, bool]]] = {}
     depth = 1
     def kk(s): kernel.append(" "*depth+s)
 
@@ -118,6 +125,8 @@ class CStyleLanguage(Renderer):
         kk("}")
       elif uop is UOps.STORE:
         assert src[0].dtype is not None and src[2].dtype is not None
+        # mark DEFINE_GLOBAL buf as writable
+        if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True))
         rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
         kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
       else:
@@ -128,6 +137,7 @@ class CStyleLanguage(Renderer):
         elif uop is UOps.ALU:
           # remove parens if ALU types are the same. TODO: can do more here
           if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src]
+          elif args is BinaryOps.MAX: operands = [self.render_cast(r[v], cast(DType, v.dtype)) if v.op is UOps.CONST else r[v] for v in src]
          else: operands = [r[v] for v in src]
           val = self.code_for_op[args](*operands, dtype)
           assert child_count[u] != 0, f"childless ALU op found {u}"
@@ -135,58 +145,71 @@ class CStyleLanguage(Renderer):
           if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
           else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
         elif uop is UOps.SPECIAL:
-          kk(f"int {args[
-          r[u] = args[
+          kk(f"int {args[0]} = {self.code_for_workitem[args[0][0]](args[0][-1])}; /* {args[1]} */")
+          r[u] = args[0]
+        elif uop is UOps.DEFINE_VAR:
+          assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
+          seen_vars.add(args.expr)
+          bufs[u] = (args.expr, (dtype,False))
+          r[u] = args.expr
         elif uop is UOps.LOAD:
           val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
           # NOTE: this relies on the load not happening if it's in the unselected branch
-          if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[
+          if len(src) > 3 and src[3].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[3]], val, r[src[2]], dtype)
           kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
         elif uop is UOps.PHI:
           kk(f"{r[src[0]]} = {r[src[1]]};")
           r[u] = r[src[0]]
-        elif uop in {UOps.CAST, UOps.BITCAST}:
+        elif uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}:
+          assert len(src) == 1 or (uop is UOps.VECTORIZE and len(src) > 1), "Invalid source length for operation"
          if uop is UOps.BITCAST:
-            assert len(src) == 1
            precast = ssa('precast')
            kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
-            val = self.render_cast(
-
-
+            val = self.render_cast(precast, dtype, bitcast=True)
+          elif uop is UOps.CAST: val = self.render_cast(r[src[0]], dtype, bitcast=False)
+          else: val = self.render_vectorize([r[x] for x in src], dtype)
          if child_count[u] <= 1: r[u] = val
          else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
         elif uop is UOps.DEFINE_LOCAL:
           kk(self.render_local(args[0], dtype, args[1]))
           r[u] = args[0]
-        elif uop is UOps.DEFINE_VAR:
-          assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
-          seen_vars.add(args.expr)
-          bufs.append((args.expr, (dtype,False)))
-          r[u] = args.expr
         elif uop is UOps.DEFINE_GLOBAL:
-          bufs
+          bufs[u] = (nm:=f"data{args}", (dtype, False))
           r[u] = nm
         elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
-        elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {
+        elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {r[src[0]]};")
         elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
         elif uop is UOps.GEP:
           assert src[0].dtype is not None
           from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
-          r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") +
-
+          r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + \
+            (f"[{args}]" if src[0].dtype.count > (8 if self.device in {"CUDA", "NV"} else 4) or self.device == 'CLANG' else f".{'xyzwabcd'[args]}")
+        else: raise RuntimeError(f"failed to render {u}")
+
+    # NOTE: this relies on bufs dict preserving order
+    return self.render_kernel(name, kernel, list(bufs.values()), uops)
 
-
+def _make_clang_dtype(self, dtype):
+  return f"typedef {self.render_dtype(dtype.scalar())} {self.render_dtype(dtype)} __attribute__((aligned({(sz:=dtype.itemsize)}),vector_size({sz})));"
 
 class ClangRenderer(CStyleLanguage):
   device = "CLANG"
-
+  float4 = "(float4)"
   has_local = False
   global_max = None
+  infinity = "__builtin_inff()"
+  nan = '__builtin_nanf("")'
 
   # language options
   buffer_suffix = " restrict"
   type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
-  code_for_op = {**CStyleLanguage().code_for_op
+  code_for_op = {**({k:v for k,v in CStyleLanguage().code_for_op.items() if k not in [UnaryOps.EXP2, UnaryOps.SIN, UnaryOps.LOG2]}),
+                 UnaryOps.SQRT: lambda x,dtype: f"__builtin_sqrtl({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})",
+                 BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
+
+  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+    prefix = [_make_clang_dtype(self, dtype) for dtype in dedup(uop.dtype for uop in uops if uop.dtype is not None and uop.dtype.count>1)]
+    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
 class OpenCLRenderer(CStyleLanguage):
   device = "GPU"
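A similar sketch for the CLANG-specific additions above: ClangRenderer now builds vector values with brace syntax and emits a vector_size typedef per vector dtype through the module-level _make_clang_dtype helper (expected strings inferred from the code, not verified output):

from tinygrad.renderer.cstyle import ClangRenderer, _make_clang_dtype
from tinygrad.dtype import dtypes

r = ClangRenderer()
# brace-style vector construction, only taken when device == "CLANG"
print(r.render_vectorize(["a", "b", "c", "d"], dtypes.float.vec(4)))
# expected: (float4){a,b,c,d}
# typedef injected into the kernel prefix for every vector dtype a kernel uses
print(_make_clang_dtype(r, dtypes.float.vec(4)))
# expected: typedef float float4 __attribute__((aligned(16),vector_size(16)));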
@@ -202,7 +225,7 @@ class OpenCLRenderer(CStyleLanguage):
   uses_vload = True
   type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
   def render_cast(self, x, var_dtype, bitcast=False) -> str:
-    return f"as_{self.render_dtype(var_dtype)}({x
+    return f"as_{self.render_dtype(var_dtype)}({x})" if bitcast else super().render_cast(x, var_dtype)
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
     if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
@@ -211,7 +234,7 @@ class OpenCLRenderer(CStyleLanguage):
 class MetalRenderer(CStyleLanguage):
   device = "METAL"
   shared_max = 32768
-  tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)],
+  tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
   def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if os.uname().machine == "arm64" else []
 
   # language options
@@ -222,7 +245,7 @@ class MetalRenderer(CStyleLanguage):
   barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
   float4 = "float4"
   uses_ptr_arithmetic = True
-  code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
+  code_for_workitem = {"g": lambda x: f"gid.{chr(120+int(x))}", "l": lambda x: f"lid.{chr(120+int(x))}"}
   # uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
   extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
   type_map = {dtypes.bfloat16: "bfloat"}
@@ -231,10 +254,10 @@ class MetalRenderer(CStyleLanguage):
     UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
     UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
     UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
-    UnaryOps.SIN: lambda x,dtype: f"(bfloat)sin({x})" if dtype == dtypes.bfloat16 else f"sin({x})",}
+    UnaryOps.SIN: lambda x,dtype: f"(bfloat)precise::sin({x})" if dtype == dtypes.bfloat16 else f"precise::sin({x})",}
 
-  def render_cast(self, x:
-    return f"as_type<{self.render_dtype(var_dtype)}>({x
+  def render_cast(self, x:str, var_dtype:DType, bitcast=False) -> str:
+    return f"as_type<{self.render_dtype(var_dtype)}>({x})" if bitcast else super().render_cast(x, var_dtype)
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
     prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
@@ -252,16 +275,17 @@ code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dt
   UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
 
 _nms = "xyzwabcdefghijkl"
-def _make_cuda_dtype(
-  vec,
-
+def _make_cuda_dtype(renderer:CStyleLanguage, dtype:DType) -> str:
+  vec, scal = renderer.render_dtype(dtype), renderer.render_dtype(dtype.scalar()),
+  elems, header = ', '.join(_nms[:dtype.count]), ', '.join([f"{scal} {x}" for x in _nms[:dtype.count]])
+  return f"struct __align__({dtype.itemsize}) {vec} {{ {scal} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
 
 class CUDARenderer(CStyleLanguage):
   device = "CUDA"
   global_max = (2147483647, 65535, 65535)
   local_max = (1024, 1024, 64)
   shared_max = 49152
-  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(
+  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
   def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []
 
   # language options
@@ -270,25 +294,22 @@ class CUDARenderer(CStyleLanguage):
   smem_prefix_for_cast = False
   barrier = "__syncthreads();"
   float4 = "make_float4"
-  code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}",
-                       "i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"}
+  code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
+                       "i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+x)}+threadIdx.{chr(120+int(x))})"}
   code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
   type_map = {dtypes.bfloat16: "nv_bfloat16"}
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
     # TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
-    dt_map = { dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16"), }
-
     prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
-    if any(uop.dtype == dtypes.half for uop in uops):
-      prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]
 
-
-      prefix += ["#include <
+    for dtype in dedup(uop.dtype for uop in uops if uop.dtype is not None and uop.dtype in (dtypes.half, dtypes.bfloat16)):
+      prefix += [f"#include <cuda_{'fp' if dtype == dtypes.half else 'bf'}16.h>"] + [_make_cuda_dtype(self, dtype.vec(sz)) for sz in [4, 8]]
 
     # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
+    dt_map = { dtypes.float: "f32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
     for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
-      fn, ti, to, ci, co = arg[0],
+      fn, ti, to, ci, co = arg[0], self.render_dtype(arg[2]), self.render_dtype(arg[3]), dt_map[arg[2]], dt_map[arg[3]]
       prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
 asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
   : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
@@ -296,6 +317,11 @@ return c;}}""")
 
     return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
 
+  def get_kernel_modifier(self, uops:List[UOp]) -> str:
+    maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is UOps.SPECIAL and u.arg[0][0] == "l")
+    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
+    return f"__launch_bounds__({maxThreadsPerBlock}) "
+
 code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
   UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
   UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
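To illustrate the new CUDARenderer.get_kernel_modifier, a hedged sketch with SimpleNamespace stand-ins for SPECIAL uops; the ("lidx0", 16)-style args mirror how the code above indexes u.arg and are an assumption, not a documented API:

from types import SimpleNamespace
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.codegen.uops import UOps

# hypothetical SPECIAL uops: two local dims (16*8 threads per block) and one global dim
fake_uops = [SimpleNamespace(op=UOps.SPECIAL, arg=("lidx0", 16)),
             SimpleNamespace(op=UOps.SPECIAL, arg=("lidx1", 8)),
             SimpleNamespace(op=UOps.SPECIAL, arg=("gidx0", 1024))]
# the returned modifier is prepended to the kernel signature, capping threads per block
print(CUDARenderer("sm_80").get_kernel_modifier(fake_uops))  # expected: __launch_bounds__(128)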
@@ -321,7 +347,7 @@ def _make_hip_dtype(base_type, name, cnt):
 class AMDRenderer(CStyleLanguage):
   device = "AMD"
   shared_max = 65536
-  tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)],
+  tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
 
   # language options
   kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
@@ -380,10 +406,11 @@ static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat
 for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
-  def get_kernel_modifier(self, uops:
-    requiredMaxThreadsPerBlock = prod(u.arg[
+  def get_kernel_modifier(self, uops:List[UOp]) -> str:
+    requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is UOps.SPECIAL and u.arg[0][0] == "l")
     # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
     # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
     return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
 
 class NVRenderer(CUDARenderer): device = "NV"
+class HIPRenderer(AMDRenderer): device = "HIP"
tinygrad/renderer/llvmir.py
CHANGED
@@ -1,32 +1,14 @@
-from typing import
+from typing import Dict, Callable, Any, List, Optional
 from llvmlite import ir
 from tinygrad.dtype import DType, PtrDType, dtypes
 from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
-from tinygrad.codegen.uops import UOps, UOp
+from tinygrad.codegen.uops import UOps, UOp
 from tinygrad.renderer import Renderer
 
 MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
 
 def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
 
-code_for_op: Final[Dict[Op, Callable]] = {
-  UnaryOps.NEG: lambda builder, x, dtype: builder.neg(x) if dtypes.is_int(dtype) else \
-    (builder.not_(x) if dtype == dtypes.bool else builder.fneg(x, flags=MFLAGS)),
-  UnaryOps.EXP2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
-  UnaryOps.LOG2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
-  UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
-  UnaryOps.SIN: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
-  UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
-  BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
-  BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
-  BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
-  BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
-  BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
-  BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
-  BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
-  BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
-  TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
-
 dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
   dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),
   dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType() }
@@ -71,8 +53,22 @@ class LLVMRenderer(Renderer):
   has_local = False
   has_shared = False
   global_max = None
-
-
+  code_for_op: Dict[Op, Callable] = {
+    UnaryOps.NEG: lambda builder, x, dtype: builder.neg(x) if dtypes.is_int(dtype) else \
+      (builder.not_(x) if dtype == dtypes.bool else builder.fneg(x, flags=MFLAGS)),
+    UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
+    UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
+    BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
+    BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
+    BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
+    BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
+    BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
+    BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
+    BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
+    BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y), BinaryOps.AND: lambda builder, x, y, dtype: builder.and_(x, y), BinaryOps.OR: lambda builder, x, y, dtype: builder.or_(x, y), # noqa: E501
+    TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
+
+  def render(self, name:str, uops:List[UOp]) -> str:
     # all llvm stuff goes into a module
     module = ir.Module(name=__file__)
 
@@ -86,10 +82,6 @@ class LLVMRenderer(Renderer):
     for a in func.args:
       if a.type.is_pointer: a.add_attribute("noalias")
 
-    # add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
-    func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
-    func.attributes.add('"no-nans-fp-math"="true"')
-
     bb = [ir.IRBuilder(func.append_basic_block("entry"))]
     loop_blocks: List = []
     reduce_phis: List = []
@@ -136,9 +128,9 @@ class LLVMRenderer(Renderer):
        reduce_phis.append(u)
      elif uop is UOps.LOAD:
        if len(src) > 2:
-          aug_idx = bb[-1].select(lvars[src[
+          aug_idx = bb[-1].select(lvars[src[3]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
          val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
-          val = bb[-1].select(lvars[src[
+          val = bb[-1].select(lvars[src[3]], val, lvars[src[2]])
        else:
          val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
        lvars[u] = val
@@ -149,10 +141,9 @@ class LLVMRenderer(Renderer):
        while backward.op is UOps.PHI: backward = backward.src[0]
        lvars[backward] = lvars[u]
      elif uop is UOps.ALU:
-        lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args
+        lvars[u] = self.code_for_op[args](bb[-1], *[lvars[x] for x in src], src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype)
      elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
      elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
-      elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
      elif uop is UOps.CONST: lvars[u] = const(args, dtype)
      else: raise RuntimeError(f"failed to render {uop}")
 
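For context on the LLVMRenderer change, the new BinaryOps.AND / BinaryOps.OR entries map onto llvmlite's and_ / or_ builder methods. A standalone llvmlite sketch, independent of tinygrad and for illustration only:

from llvmlite import ir

# build a tiny function that ANDs two i32s, using the same builder call the new
# BinaryOps.AND entry in LLVMRenderer.code_for_op relies on
module = ir.Module(name="and_demo")
fnty = ir.FunctionType(ir.IntType(32), [ir.IntType(32), ir.IntType(32)])
func = ir.Function(module, fnty, name="band")
builder = ir.IRBuilder(func.append_basic_block("entry"))
x, y = func.args
builder.ret(builder.and_(x, y))   # builder.or_(x, y) would back BinaryOps.OR
print(module)                     # dump the generated LLVM IR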