tinygrad 0.10.1-py3-none-any.whl → 0.10.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +35 -37
- tinygrad/codegen/linearize.py +19 -10
- tinygrad/codegen/lowerer.py +31 -8
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +10 -0
- tinygrad/device.py +28 -11
- tinygrad/dtype.py +12 -3
- tinygrad/engine/jit.py +3 -2
- tinygrad/engine/multi.py +0 -1
- tinygrad/engine/realize.py +7 -4
- tinygrad/engine/schedule.py +227 -255
- tinygrad/engine/search.py +20 -27
- tinygrad/gradient.py +3 -0
- tinygrad/helpers.py +7 -4
- tinygrad/nn/state.py +2 -2
- tinygrad/ops.py +64 -329
- tinygrad/renderer/__init__.py +19 -3
- tinygrad/renderer/cstyle.py +39 -18
- tinygrad/renderer/llvmir.py +55 -18
- tinygrad/renderer/ptx.py +6 -2
- tinygrad/renderer/wgsl.py +20 -12
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/metal.py +28 -29
- tinygrad/runtime/ops_amd.py +37 -34
- tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
- tinygrad/runtime/ops_disk.py +1 -1
- tinygrad/runtime/ops_dsp.py +59 -33
- tinygrad/runtime/ops_llvm.py +14 -12
- tinygrad/runtime/ops_metal.py +78 -62
- tinygrad/runtime/ops_nv.py +9 -6
- tinygrad/runtime/ops_python.py +5 -5
- tinygrad/runtime/ops_webgpu.py +200 -38
- tinygrad/runtime/support/am/amdev.py +23 -11
- tinygrad/runtime/support/am/ip.py +10 -10
- tinygrad/runtime/support/elf.py +2 -0
- tinygrad/runtime/support/hcq.py +7 -5
- tinygrad/runtime/support/llvm.py +8 -14
- tinygrad/shape/shapetracker.py +3 -2
- tinygrad/shape/view.py +2 -3
- tinygrad/spec.py +21 -20
- tinygrad/tensor.py +150 -90
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- tinygrad/codegen/rewriter.py +0 -516
- tinygrad-0.10.1.dist-info/RECORD +0 -86
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/renderer/cstyle.py
CHANGED

@@ -5,6 +5,7 @@ from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
 from tinygrad.renderer import Renderer, TensorCore
+from tinygrad.codegen.devectorizer import no_vectorized_alu
 
 base_rewrite = PatternMatcher([
   (UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
@@ -17,7 +18,9 @@ base_rewrite = PatternMatcher([
    lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
   (UPat(Ops.VECTORIZE, name="x"),
    lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
-    (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device
+    (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CPU', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")),
+  (UPat(Ops.CAST, name="x"), lambda ctx,x:
+   f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
   (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
   (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
   (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
@@ -49,7 +52,10 @@ base_rewrite = PatternMatcher([
   (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
    *([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
   (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
-   (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device
+   (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CPU', 'DSP'} else \
+    f".{'xyzwabcd'[x.arg[0]]}")),
+  # custom passes through with format
+  (UPat(Ops.CUSTOM, name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
 ])
 
 extra_pm = PatternMatcher([
@@ -58,6 +64,12 @@ extra_pm = PatternMatcher([
    lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
   # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
   (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+  # devectorize any bools
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
+  # CAST (from bool) can't be vectorized
+  (UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu),
+  # WHERE can't be vectorized
+  (UPat(Ops.WHERE, name="alu"), no_vectorized_alu),
 ])
 
 def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
@@ -104,10 +116,11 @@ class CStyleLanguage(Renderer):
     if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
     if isinstance(dt, PtrDType):
       return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + self.render_dtype(dt.base) + "*"
-    return self.type_map.get(scalar:=dt.scalar(), scalar.name)
+    if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
+    return self.type_map.get(scalar:=dt.scalar(), scalar.name)
 
   def __getitem__(self, key): return self.r[key] # hacky helper
-  def render(self,
+  def render(self, uops:list[UOp]) -> str:
     r: dict[UOp, str] = {}
     self.r = r
 
@@ -116,7 +129,11 @@ class CStyleLanguage(Renderer):
     kernel = []
     depth = 1
     c: defaultdict[str, int] = defaultdict(int)
+    name = "test"
     for u in uops:
+      if u.op is Ops.NAME:
+        name = u.arg
+        continue
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
         r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
         bufs[u] = (r[u], (u.dtype, False))
@@ -141,7 +158,7 @@ class CStyleLanguage(Renderer):
       assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
 
       if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
-      if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or \
+      if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOM} or \
        (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA")):
         r[u] = l
       else:
@@ -158,12 +175,15 @@ class CStyleLanguage(Renderer):
     return self.render_kernel(name, kernel, list(bufs.values()), uops)
 
 class ClangRenderer(CStyleLanguage):
-  device = "
+  device = "CPU"
   float4 = "(float4)"
   has_local = False
   global_max = None
   infinity = "__builtin_inff()"
   nan = '__builtin_nanf("")'
+  amx_tc = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt, swizzle=(None,((),(4,5,6,7,0,1,2,3))),
+            opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
+  if AMX: tensor_cores = amx_tc
 
   # language options
   buffer_suffix = " restrict"
@@ -174,14 +194,12 @@ class ClangRenderer(CStyleLanguage):
   extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))]) + \
     CStyleLanguage.extra_matcher
 
-  if AMX:
-    tensor_cores = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt,
-      swizzle=(None, ((),(4,5,6,7,0,1,2,3))), opts=("u0","u0","u0","u0","u1","u1","u1","u1"))
-      for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
   if sys.platform == 'win32':
     kernel_prefix = "__attribute__((ms_abi)) "
   def render_vector_prefix(self, dt:DType) -> str:
-
+    # round (down) to power of two
+    alignment = 2**int(math.log2(dt.itemsize))
+    return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),vector_size({dt.itemsize})));"
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
     prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
@@ -300,10 +318,11 @@ class CUDARenderer(CStyleLanguage):
   local_max = (1024, 1024, 64)
   shared_max = 49152
   # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
-  tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di,dtype_out=do, opts=cuda_tc_opts,
-    swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float)
-
-
+  tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+    swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float),
+    (dtypes.half,dtypes.half)]]
+  tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+    swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
   tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
    swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]
 
@@ -344,7 +363,8 @@ class CUDARenderer(CStyleLanguage):
     if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
     prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]
 
-
+    dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+    dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
     for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
       upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
       wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
@@ -353,10 +373,11 @@ class CUDARenderer(CStyleLanguage):
 
      # mma operands => {c}, {a}, {b}, {c}
      prefix.append(f"""__device__ {wmma_dtypes[2]} __{name}({wmma_dtypes[0]} a, {wmma_dtypes[1]} b, {wmma_dtypes[2]} c){{
-  int *a_pk = (int *)(&a), *b_pk = (int *)(&b)
+  int *a_pk = (int *)(&a), *b_pk = (int *)(&b), *c_pk = (int *)(&c);
+  asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}"
     "{{{", ".join(operands[:n_operands[2]])}}}, {{{", ".join(operands[n_operands[2]:n_operands[2]+n_operands[0]])}}},"
     "{{{", ".join(operands[-n_operands[1]:])}}}, {{{", ".join(operands[:n_operands[2]])}}};"
-    : {", ".join([f'"+
+    : {", ".join([f'"+r"(c_pk[{i}])' for i in range(n_operands[2])])}
     : {", ".join([f'"r"(a_pk[{i}])' for i in range(n_operands[0])])}, {", ".join([f'"r"(b_pk[{i}])' for i in range(n_operands[1])])});
   return c;\n}}""")
 
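Note on the render_dtype and vectorized CAST changes above: vector dtypes now get C names built from the scalar type name plus the lane count, and vector-to-vector casts are emitted with clang's __builtin_convertvector instead of a scalar C cast. Below is a minimal standalone sketch of that naming and emission logic; the (scalar_name, count) tuple and the tiny type_map are stand-ins for illustration, not tinygrad's real DType API.

```python
# Sketch only: dtypes are modeled as (scalar_name, count) tuples, not tinygrad DType objects.
type_map = {"signed char": "char"}  # hypothetical per-backend override

def render_dtype(dt):
  scalar, count = dt
  base = type_map.get(scalar, scalar)
  # vector dtypes become e.g. "float4" / "unsigned_char4": spaces in the scalar name turn into underscores
  return base.replace(" ", "_") + str(count) if count > 1 else base

def render_vector_cast(src_expr, dst):
  # vector CASTs are lowered to the clang builtin rather than "(type)expr"
  return f"__builtin_convertvector({src_expr}, {render_dtype(dst)})"

print(render_dtype(("float", 4)))                        # float4
print(render_vector_cast("val0", ("unsigned char", 4)))  # __builtin_convertvector(val0, unsigned_char4)
```

The underscore naming matches the typedefs emitted by render_vector_prefix, so the generated C can refer to the vector type by a single identifier.
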
tinygrad/renderer/llvmir.py
CHANGED

@@ -1,10 +1,13 @@
 from typing import cast
-import math, struct
+import math, struct, sys
 from tinygrad.renderer import Renderer
+from tinygrad.renderer.cstyle import ClangRenderer
 from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
 from tinygrad.dtype import dtypes, DType, PtrDType, truncate
+from tinygrad.helpers import prod, AMX
 
 def ldt(dt:DType):
+  if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
   if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
   return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
           dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
@@ -20,7 +23,7 @@ def lcast(input_type:DType, output_type:DType):
   if dtypes.is_float(input_type):
     if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc'
     if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi'
-  if dtypes.is_unsigned(input_type) or
+  if dtypes.is_unsigned(input_type) or dtypes.is_bool(input_type):
     if dtypes.is_float(output_type): return 'uitofp'
     if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext'
   if dtypes.is_int(input_type):
@@ -28,6 +31,19 @@ def lcast(input_type:DType, output_type:DType):
     if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
   raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
 
+# https://github.com/corsix/amx
+def render_wmma(ctx, wmma: UOp) -> str:
+  def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))", "i,r,~{{memory}}"(i32 {op}, i64 {gpr}) #0; AMX'
+
+  return "\n".join([
+    *[f' store {ldt(src.dtype)} {ctx[src]}, {ldt(src.dtype.ptr())} {ctx[wmma]}_amx{i}, align {src.dtype.itemsize}' for i,src in enumerate(wmma.src)],
+    f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 0})", "~{{memory}}"() #0; AMX set', # set
+    *[f' {ctx[wmma]}_ld{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(4,f"{ctx[wmma]}_ld{i}")} ldz' for i in range(16)], # ldz
+    f' {AMX(0, f"{ctx[wmma]}_ptr_amx1")} ldx\n {AMX(1, f"{ctx[wmma]}_ptr_amx0")} ldy\n {AMX(12, 0)} fma32', # ldx ldy fma
+    *[f' {ctx[wmma]}_st{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(5,f"{ctx[wmma]}_st{i}")} stz' for i in range(16)], # stz
+    f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 1})", "~{{memory}}"() #0; AMX clr', # clr
+    f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
+
 # llvm ops, lop[<dtype>][<op>]
 unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
                  Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
@@ -36,7 +52,7 @@ flags = " nsz arcp contract afn"
 float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags}
 lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}
 
-llvm_rewrite = PatternMatcher([
+base_rewrite = PatternMatcher([
   # memory load/store
   (UPat(Ops.INDEX, name="x"), lambda ctx,x:
    f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
@@ -49,12 +65,22 @@ llvm_rewrite = PatternMatcher([
   (UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
   (UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
 
+  # GEP/VECTORIZE/CAST for float4 support
+  (UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
+  (UPat(Ops.VECTORIZE, src=UPat.var('y'), name="x"), lambda ctx,x,y:
+   f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n"
+   f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"),
+  (UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
+   f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
+   f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
+  (UPat(Ops.CAST, name="x"), lambda ctx,x:
+   f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None),
+
   # unary/binary/ternary ops
-  (UPat(Ops.SQRT, name="x"), lambda ctx,x:
-   f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
   (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
   (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
-  (UPat(GroupOp.Binary, name="x"), lambda ctx,x:
+  (UPat(GroupOp.Binary, name="x"), lambda ctx,x:
+   f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
   (UPat(Ops.WHERE, name="x"), lambda ctx,x:
    f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
 
@@ -71,6 +97,9 @@ llvm_rewrite = PatternMatcher([
   # if
   (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
   (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
+
+  # wmma
+  (UPat(Ops.WMMA, name="wmma"), render_wmma),
 ])
 
 def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
@@ -79,10 +108,13 @@ def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
 
 class LLVMRenderer(Renderer):
   device = "LLVM"
-
+  abi = 'win64cc' if sys.platform == 'win32' else None
+  supports_float4 = True
   has_local = False
   has_shared = False
   global_max = None
+  string_rewrite = base_rewrite
+  if AMX: tensor_cores = ClangRenderer.amx_tc
 
   extra_matcher = PatternMatcher([
     # rewrite RECIP with FDIV
@@ -95,32 +127,36 @@ class LLVMRenderer(Renderer):
     (UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
   ])
 
-  def
-    self.abi = abi
-
-  def render(self, name: str, uops: list[UOp]) -> str:
+  def render(self, uops: list[UOp]) -> str:
     r: dict[UOp, str] = {}
     args: list[str] = []
     kernel: list[str] = []
     end_lines: dict[str, None] = {}
     vc = -1
 
-    # prealloc all assigns
     acc_to_assign: dict[UOp, UOp] = {}
     for u in uops:
-      if u.op is Ops.ASSIGN:
+      if u.op is Ops.ASSIGN: # prealloc all assigns
        vc += 1
        r[u] = r[u.src[1]] = f"%assign{vc}"
        assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
        acc_to_assign[u.src[0]] = u.src[1]
+      if u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
+        vc += 1
+        r[u] = f"%wmma{vc}"
+        for i, dtype in enumerate(u.arg[2].vec(sz) for sz in [prod(size for _, size in upcast) for upcast in u.arg[6]]):
+          kernel += [f" {r[u]}_amx{i} = alloca {ldt(dtype)}, align {dtype.itemsize}",
+                     f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]
 
+    name = "test"
     for u in uops:
-
-
-
+      if u.op is Ops.NAME:
+        name = u.arg
+        continue
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
        r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
-
+        # NOTE: MallocAllocator promises 0x20 alignment
+        args.append(f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
      elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
      elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
      elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
@@ -132,7 +168,8 @@ class LLVMRenderer(Renderer):
        r[u] = f"%v{vc}"
 
      # do the rendering of the llvm ir code
-      if (l:=
+      if (l:=self.string_rewrite.rewrite(u, ctx=r)) is None:
+        raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
      kernel.append(cast(str, l))
 
      # generate the phi nodes for the assigns
tinygrad/renderer/ptx.py
CHANGED

@@ -65,7 +65,7 @@ def render_wmma(ctx: "PTXRenderer", wmma: UOp):
     if (elems_per_reg := 4 // src.dtype.scalar().itemsize) == 1: yield f"mov.b32 {reg}, {ctx.r[src][i]};"
     else: yield f"mov.b32 {reg}, {{{', '.join(ctx.r[src][i * elems_per_reg : (i+1) * elems_per_reg])}}};"
 
-  dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32"}
+  dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32", dtypes.half: "f16"}
   yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}{" "*12}'+\
     f'{{{", ".join(ctx.wmma_r[2])}}}, {{{", ".join(ctx.wmma_r[0])}}}, {{{", ".join(ctx.wmma_r[1])}}}, {{{", ".join(ctx.wmma_r[2])}}};'
 
@@ -154,7 +154,7 @@ class PTXRenderer(Renderer):
     params = ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs])
     return f"{self.kernel_prefix} {function_name}(\n\t{params}\n)\n{{\n{kernel}\n}}"
 
-  def render(self,
+  def render(self, uops:list[UOp]) -> str:
     kernel:list[str] = []
     bufs = []
 
@@ -169,7 +169,11 @@ class PTXRenderer(Renderer):
       c[prefix] += 1
       return f"%{prefix}{c[prefix]-1}"
 
+    name = "test"
     for u in uops:
+      if u.op is Ops.NAME:
+        name = u.arg
+        continue
      if u.op is Ops.VECTORIZE:
        r[u] = [cast(str,r[x]) for x in u.src]
        continue
tinygrad/renderer/wgsl.py
CHANGED

@@ -25,16 +25,19 @@ def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
   val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF)
   return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
 
+def is_packed(dt:DType) -> bool: return dt.itemsize < 4 and dt.base != dtypes.half
+
 wgsl_matcher = PatternMatcher([
   (UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
    lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
-  (UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if l.dtype
+  (UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if is_packed(l.dtype) else None),
   (UPat(Ops.LOAD, name="l", src=(UPat.var('b'), UPat.var('c'), UPat())),
-   lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if l.dtype
-  (UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if var.dtype
+   lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype) else None),
+  (UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype) else None),
   # TODO: why is this needed, and only for this MUL order
   (UPat(Ops.MUL, src=(UPat.var("a"), UPat.var("g").where(UPat.cvar("c1"), UPat.cvar("c2")))),
    lambda a,g,c1,c2: g.where(c1, a) if math.isnan(c1.arg) and c2.arg == 1.0 else None),
+  (UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None)
 ]) + extra_pm
 
 class WGSLRenderer(CStyleLanguage):
@@ -48,38 +51,43 @@ class WGSLRenderer(CStyleLanguage):
   code_for_op = {**CStyleLanguage.code_for_op, Ops.WHERE: lambda a,b,c,dtype: f"select({c},{b},{a})"}
   nan = "nan()"
   type_map = { dtypes.float: "f32", dtypes.uchar: "u32", dtypes.ushort: "u32", dtypes.short: "i32",
-               dtypes.char: "i32", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool" }
+               dtypes.char: "i32", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool", dtypes.half: "f16" }
 
   string_rewrite = PatternMatcher([
     (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "true" if x.arg else "false"),
     (UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"), lambda ctx,x: f"bitcast<u32>({x.arg})" \
      if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
     (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)}, {x.dtype.size}>;"),
-    (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<
+    (UPat(Ops.BITCAST, dtype=dtypes.half, name="x"), lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]" \
+     if x.src[0].dtype in [dtypes.short, dtypes.ushort, dtypes.uint32] else None),
+    (UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
+    (UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"),lambda ctx,x:f"bitcast<{ctx.type_map[x.dtype]}>(vec2<f16>({ctx[x.src[0]]},0))" \
+     if x.src[0].dtype == dtypes.half else f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
+    (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
    (UPat.load(UPat.var("b"),UPat.var("v"),UPat.var("g")),lambda ctx,b,v,g:f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[g]})"),
    (UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.src[0].dtype)),
    (UPat.index(UPat.var("b"), UPat.var("idx")), lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
    (UPat.store(UPat.var('b'), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
     # (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
-    f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype
+    f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
     else f"{ctx[b]} = {ctx[v]};"),
    # fix nan check: 'a != a -> is_nan()'
-    (UPat.var("a") != UPat.var("a"), lambda ctx,a: f"
+    (UPat.var("a") != UPat.var("a"), lambda ctx,a: f"(min({ctx[a]}, 1.0) == 1.0 && max({ctx[a]}, -1.0) == -1.0)"),
  ]) + base_rewrite
 
  def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
  def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
-  def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if dt
-  def buf_map(self, dt:DType) -> str: return "atomic<u32>" if dt
+  def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
+  def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
  def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
    local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
    if not local_size: local_size = [1]
    bind_it = iter(range(len(bufs)))
    external_local_bufs = [line.lstrip() for line in kernel if "var<workgroup>" in line]
    kernel[:] = [line for line in kernel if "var<workgroup>" not in line]
-    prg = "
-
-    prg += "
+    prg = "enable f16;\n" if any(uop.dtype.base == dtypes.half for uop in uops) else ""
+    prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
+    prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
    prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
      f"{'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'}" +
      f"{name}:{f'array<{self.buf_map(dtype.base)}>' if isinstance(dtype,PtrDType) else self.buf_map(dtype)};" for name,(dtype,_) in bufs])
|