tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/renderer/__init__.py
CHANGED
@@ -1,41 +1,85 @@
-from
-import
-
+from __future__ import annotations
+from typing import Optional, Callable
+import functools, math
+from dataclasses import dataclass, field, replace
 from tinygrad.helpers import to_function_name, dedup, prod
-from tinygrad.ops import Ops, UOp,
+from tinygrad.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
 from tinygrad.dtype import DType
 
 @dataclass(frozen=True)
 class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
-  dims:
+  dims: tuple[int,int,int] # N, M, K
+  threads: int # number of threads that construct the warp
+  elements_per_thread: tuple[int, int, int] # elements per-thread to load/store from A/B/C
   dtype_in: DType # dtype for A and B
   dtype_out: DType # dtype for C and D
-
-
-
-  def
-
-  upcast_axes: Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]]] # list of (TC dim,amt) that upcast A, B and C
-  st1_pattern: Optional[Tuple[Tuple[Tuple[int,int], ...], Tuple[Tuple[int,int], ...]]] = None # pattern to fix shapetracker for A
-  st2_pattern: Optional[Tuple[Tuple[Tuple[int,int], ...], Tuple[Tuple[int,int], ...]]] = None # pattern to fix shapetracker for B
-  expanded_shape: Optional[Tuple[int, ...]] = None
-  opts_seq: Tuple[str,str] = ("UP","LC") # upcast input, local the thread pattern
+  opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifing kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
+  swizzle: tuple[Optional[tuple[tuple[int, ...], tuple[int, ...]]], Optional[tuple[tuple[int, ...], tuple[int, ...]]]] = (None, None)
+  def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
+  def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
+  def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]
   def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
+  def __post_init__(self):
+    local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
+    assert self.dims[0] * self.dims[1] == 2**(local_axes + upcast_axes), (
+      f"N({self.dims[0]}) x M({self.dims[1]}) != local({2**local_axes}) x upcast({2**upcast_axes}) with opts({self.opts})")
+    assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
+    assert 2**upcast_axes == self.elements_per_thread[2], (
+      f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}")
+    assert all(len(perm[0]) == local_axes and len(perm[1]) == reduce_axes + upcast_axes for perm in self.swizzle if perm), (
+      f"swizzle perm should be of len (({local_axes})({reduce_axes + upcast_axes}))")
+
+@dataclass(frozen=True)
+class Estimates:
+  # number of FLOPS used in the Kernel
+  ops:sint = 0
+  # bytes accessed in loads and stores
+  lds:sint = 0
+  # total bytes accessed, counting only once for bytes that are accessed multiple times
+  mem:sint = 0
+  def __add__(self, o:Estimates): return Estimates(self.ops + o.ops, self.lds + o.lds, self.mem + o.mem)
+  def simplify(self): return Estimates(ssimplify(self.ops), ssimplify(self.lds), ssimplify(self.mem))
+  @staticmethod
+  def from_uops(uops:list[UOp], ignore_indexing=False) -> Estimates:
+    flops: sint = 0
+    lds: sint = 0
+    mults: sint = 1
+    mult_stack: list[sint] = []
+    dont_count: set[UOp] = set()
+    if ignore_indexing:
+      for u in uops:
+        if u.op in {Ops.LOAD, Ops.STORE}:
+          dont_count = dont_count.union(u.src[0].toposort)
+          if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort)
+        elif u.op is Ops.IF:
+          dont_count = dont_count.union(u.src[0].toposort)
+    for u in uops:
+      if u.op is Ops.RANGE:
+        mult_stack.append(mults)
+        mults *= (u.src[1] - u.src[0]).ssimplify()
+      elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
+      elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
+      elif u.op is Ops.LOAD: lds += u.dtype.itemsize * mults
+      elif u.op is Ops.STORE: lds += u.src[1].dtype.itemsize * mults
+      elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
+      elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
+    return Estimates(flops, lds, lds) # TODO: properly track memory, lds is always a high estimate
 
 @dataclass
-class
+class ProgramSpec:
   name:str
   src:str
-
-  uops:Optional[
+  device:str
+  uops:Optional[list[UOp]]=None
   mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good
 
   # filled in from uops (if we have uops)
-  global_size:Optional[
-  local_size:Optional[
-  vars:
-  globals:
-  outs:
+  global_size:Optional[list[int]]=None
+  local_size:Optional[list[int]]=None
+  vars:list[Variable]=field(default_factory=list)
+  globals:list[int]=field(default_factory=list)
+  outs:list[int]=field(default_factory=list)
+  ins:list[int]=field(default_factory=list)
  _ran_post_init:bool=False # NOTE: this is needed if you call replace on the Program
 
   def __post_init__(self):
@@ -44,7 +88,8 @@ class Program:
     for u in self.uops:
       if u.op is Ops.DEFINE_VAR: self.vars.append(u)
       if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
-      if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].
+      if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
+      if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
       if u.op is Ops.SPECIAL:
         # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
         if u.arg[0][0] == 'i': self.local_size = None
@@ -53,19 +98,17 @@ class Program:
         special_size[int(u.arg[0][-1])] = u.arg[1]
     self.vars = sorted(self.vars, key=lambda v: v.arg)
     self.outs = sorted(dedup(self.outs))
+    self.ins = sorted(dedup(self.ins))
     self._ran_post_init = True
 
-  @property
-  def op_estimate(self) -> sint: return self._ops_lds[0]
-  @property
-  def lds_estimate(self) -> sint: return self._ops_lds[1]
   @functools.cached_property
-  def
+  def estimates(self) -> Estimates:
+    return replace(Estimates() if self.uops is None else Estimates.from_uops(self.uops, ignore_indexing=True), mem=self.mem_estimate)
 
   @functools.cached_property
   def function_name(self) -> str: return to_function_name(self.name)
 
-  def launch_dims(self, var_vals:
+  def launch_dims(self, var_vals:dict[Variable, int]):
     global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
     local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
     return global_size, local_size
@@ -78,12 +121,12 @@ class Renderer:
   has_local: bool = True
   has_shared: bool = True
   # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
-  global_max: Optional[
-  local_max: Optional[
+  global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
+  local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
   shared_max: int = 32768
-  tensor_cores:
-  extra_matcher:
-  code_for_op:
+  tensor_cores: list[TensorCore] = []
+  extra_matcher: Optional[PatternMatcher] = None
+  code_for_op: dict[Ops, Callable] = {}
 
   def __reduce__(self): return self.__class__, ()
-  def render(self, name:str, uops:
+  def render(self, name:str, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
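Note on the new Estimates.from_uops above: it keeps a running loop multiplier (mults) that is pushed when a RANGE is entered and popped at the matching ENDRANGE, and every counted op contributes mults times its cost. A minimal standalone sketch of that accounting, using plain tuples and ints in place of tinygrad UOps and the symbolic sint type (illustrative only, not the package's API):

# Sketch of the multiplier-stack accounting behind Estimates.from_uops: entering a
# RANGE multiplies the running trip count, ENDRANGE restores the previous one, and
# every ALU op inside contributes `mults` flops.
ops = [("RANGE", 16), ("RANGE", 32), ("ALU", None), ("ALU", None), ("ENDRANGE", None),
       ("ALU", None), ("ENDRANGE", None)]

flops, mults, mult_stack = 0, 1, []
for op, arg in ops:
  if op == "RANGE":
    mult_stack.append(mults)
    mults *= arg                 # trip count of this loop
  elif op == "ENDRANGE":
    mults = mult_stack.pop()     # back to the enclosing loop's multiplier
  elif op == "ALU":
    flops += mults               # one flop per iteration of the enclosing loops

print(flops)  # 2*16*32 (inner body) + 16 (outer body) = 1040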
tinygrad/renderer/cstyle.py
CHANGED
@@ -1,8 +1,7 @@
-from
-
-import os, math
+from typing import Optional, Union, Literal, Callable, cast
+import os, math, sys
 from collections import defaultdict, Counter
-from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
+from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
 from tinygrad.renderer import Renderer, TensorCore
@@ -21,7 +20,7 @@ base_rewrite = PatternMatcher([
     (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),
   (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
   (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
-  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.
+  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
   (UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
   (UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
   (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
@@ -57,14 +56,11 @@ extra_pm = PatternMatcher([
   # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
   (UPat(Ops.BITCAST, name="x"),
    lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
-  # gate any stores that aren't gated with ifs
-  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
-   lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
   # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
   (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
 ])
 
-def uops_to_dtypes(uops:
+def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
 
 class CStyleLanguage(Renderer):
   kernel_prefix: str = ""
@@ -75,13 +71,13 @@ class CStyleLanguage(Renderer):
   smem_prefix_for_cast: bool = True
   arg_int_prefix: str = "const int"
   barrier: str = ""
-  code_for_workitem:
-  extra_args:
+  code_for_workitem: dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
+  extra_args: list[str] = []
   float4: Optional[str] = None
-  type_map:
+  type_map: dict[DType, str] = {}
   infinity: str = "INFINITY"
   nan: str = "NAN"
-  code_for_op:
+  code_for_op: dict = {
     Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
     Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
     Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
@@ -93,8 +89,8 @@ class CStyleLanguage(Renderer):
   string_rewrite = base_rewrite
   extra_matcher = extra_pm
 
-  def get_kernel_modifier(self, uops:
-  def render_kernel(self, function_name:str, kernel:
+  def get_kernel_modifier(self, uops:list[UOp]) -> str: return ""
+  def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
     buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
                 self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
@@ -105,23 +101,21 @@ class CStyleLanguage(Renderer):
 
   def render_cast(self, dt:DType, val: str) -> str: return f"({self.render_dtype(dt)})({val})"
   def render_dtype(self, dt:DType, mutable=True) -> str:
-    if isinstance(dt, ImageDType):
-      return f"{'write_only' if mutable else 'read_only'} image2d_t"
+    if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
     if isinstance(dt, PtrDType):
-      return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) +
-        self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "")
+      return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + self.render_dtype(dt.base) + "*"
     return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")
 
   def __getitem__(self, key): return self.r[key] # hacky helper
-  def render(self, name:str, uops:
-    r:
+  def render(self, name:str, uops:list[UOp]) -> str:
+    r: dict[UOp, str] = {}
     self.r = r
 
     child_count = Counter(v for ru in uops for v in ru.src)
-    bufs:
+    bufs: dict[UOp, tuple[str, tuple[DType, bool]]] = {}
     kernel = []
     depth = 1
-    c:
+    c: defaultdict[str, int] = defaultdict(int)
     for u in uops:
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
         r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
@@ -130,7 +124,7 @@ class CStyleLanguage(Renderer):
 
       # mark buffers that we store to writable
       if u.op is Ops.STORE:
-        for up in u.src[0].
+        for up in u.src[0].toposort:
          if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
 
       # naming
@@ -147,8 +141,8 @@ class CStyleLanguage(Renderer):
       assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
 
       if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
-      if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or
-
+      if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or \
+        (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA")):
         r[u] = l
       else:
         if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
@@ -176,11 +170,16 @@ class ClangRenderer(CStyleLanguage):
   type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
   code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
                  Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
+  # LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
+  extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))]) + \
+    CStyleLanguage.extra_matcher
 
   if AMX:
-    tensor_cores = [TensorCore(dims=(sz,sz,1), threads=
-
-
+    tensor_cores = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt,
+      swizzle=(None, ((),(4,5,6,7,0,1,2,3))), opts=("u0","u0","u0","u0","u1","u1","u1","u1"))
+      for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
+  if sys.platform == 'win32':
+    kernel_prefix = "__attribute__((ms_abi)) "
   def render_vector_prefix(self, dt:DType) -> str:
     return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({(sz:=dt.itemsize)}),vector_size({sz})));"
 
@@ -192,7 +191,10 @@ class ClangRenderer(CStyleLanguage):
       '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
       '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
     ]
-
+    # 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
+    # to just jump at the start of a shellcode whithout having to deal with symbols or trampolines at all. This is better than having to inline
+    # wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
+    prefix += [f"""static {(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
 AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
 AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
 for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501
@@ -209,7 +211,8 @@ class OpenCLRenderer(CStyleLanguage):
   barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
   float4 = "(float4)"
   code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
-  type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong",
+  type_map = { dtypes.int8: "char", dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong",
+               dtypes.bfloat16: "ushort" }
 
   string_rewrite = PatternMatcher([
     (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
@@ -223,17 +226,17 @@ class OpenCLRenderer(CStyleLanguage):
   ]) + base_rewrite
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
-    if any(uop.dtype == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
+    if any(uop.dtype.base == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
 class IntelRenderer(OpenCLRenderer):
   device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel "
-  tensor_cores = [TensorCore(dims=(8,8,16),threads=
-
+  tensor_cores = [TensorCore(dims=(8,8,16), threads=8, elements_per_thread=(16,16,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
+    opts=("l0","l0","l0","u1","u1","u1"), swizzle=(((4,5,6),(0,1,2,3,7,8,9)), ((0,1,2),(7,8,9,3,4,5,6))))]
 
   string_rewrite = PatternMatcher([
-    (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x
-    (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x
+    (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x]})"),
+    (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x]})"),
   ]) + OpenCLRenderer.string_rewrite
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
@@ -247,9 +250,9 @@ class IntelRenderer(OpenCLRenderer):
 class MetalRenderer(CStyleLanguage):
   device = "METAL"
   shared_max = 32768
-  tensor_cores = [TensorCore(dims=(8,8,8),threads=
-
-
+  tensor_cores = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do, opts=("u0","l0","l1","l1","l0","l1"),
+    swizzle=(((6,1,2,7,4),(8,0,3,5)), ((0,5,6,3,7),(1,2,4,8)))) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
+    (dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
   def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if hasattr(os, 'uname') and os.uname().machine == "arm64" else []
 
   # language options
@@ -289,18 +292,26 @@ class MetalRenderer(CStyleLanguage):
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
 _nms = "xyzwabcdefghijkl"
+cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1") # shared by all shapes with M=16 N=8
 
 class CUDARenderer(CStyleLanguage):
   device = "CUDA"
   global_max = (2147483647, 65535, 65535)
   local_max = (1024, 1024, 64)
   shared_max = 49152
-  # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-
-
-
-
-
-
+  # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
+  tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di,dtype_out=do, opts=cuda_tc_opts,
+    swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float)]]
+  tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.half, dtype_out=dtypes.float, opts=cuda_tc_opts,
+    swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5))))]
+  tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
+    swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]
+
+  tc_sm80 = tc_81616 + tc_8168_f16
+  if getenv("ALLOW_TF32", 0): tc_sm80 += tc_8168_tf32
+  tc_sm75 = tc_8168_f16
+  def __init__(self, arch:str):
+    self.tensor_cores, self.arch = CUDARenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else [], arch
   def __reduce__(self): return self.__class__, (self.arch,)
 
   # language options
@@ -333,7 +344,7 @@ class CUDARenderer(CStyleLanguage):
     if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
     prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]
 
-    dt_map = { dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+    dt_map = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
     for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
       upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
       wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
@@ -351,18 +362,24 @@ class CUDARenderer(CStyleLanguage):
 
     return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
 
-  def get_kernel_modifier(self, uops:
+  def get_kernel_modifier(self, uops:list[UOp]) -> str:
    maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
    return f"__launch_bounds__({maxThreadsPerBlock}) "
 
+def cast_float_to_bf16(x: UOp) -> UOp:
+  assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
+  x = x.bitcast(dtypes.uint)
+  x = (-x & 0x7f800000).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).where((x | 0x10000), x))
+  return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
+
 class AMDRenderer(CStyleLanguage):
   device = "AMD"
   shared_max = 65536
   # https://gpuopen.com/learn/wmma_on_rdna3/
-  tensor_cores = [TensorCore(dims=(16,16,16), threads=
-
-    for
+  tensor_cores = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
+    opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8))))
+    for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]]
 
   # language options
   ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
@@ -397,8 +414,7 @@ class AMDRenderer(CStyleLanguage):
     (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
     # bfloat16 casting
     (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
-    (UPat(Ops.CAST,
-     lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
+    (UPat(Ops.CAST, dtypes.float, UPat.var("x", dtypes.bfloat16)), lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
     (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
 
   def render_vector_prefix(self, dtype:DType) -> str:
@@ -410,7 +426,7 @@ class AMDRenderer(CStyleLanguage):
     prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
 
     used_dtypes = uops_to_dtypes(uops)
-    if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("
+    if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
     prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]
 
     for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
@@ -421,42 +437,12 @@ class AMDRenderer(CStyleLanguage):
       for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
-  def get_kernel_modifier(self, uops:
+  def get_kernel_modifier(self, uops:list[UOp]) -> str:
    requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
    # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
    # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
    return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
 
-class DSPRenderer(ClangRenderer):
-  device = "DSP"
-  supports_float4 = False
-  buffer_suffix = " restrict __attribute__((align_value(128)))"
-  kernel_prefix = "__attribute__((noinline)) "
-  type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
-  code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
-                 Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
-                 Ops.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"}
-
-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
-    ret = super().render_kernel(function_name, kernel, bufs, uops, prefix)
-    msrc = ['''struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency; _Bool set_dcvs_params;
-    short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3]; };''', 'int HAP_power_set(void*, void*);',
-      'typedef union { struct { void *pv; unsigned int len; } buf; struct { int fd; unsigned int offset; } dma; } remote_arg;',
-      'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
-      'unsigned long long HAP_perf_get_time_us(void);', 'int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
-      'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
-      'HAP_power_set((void*)handle, (void*)&req);']
-    msrc += ['if ((sc>>24) != 2) return 0;']
-    msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
-    msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
-    msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
-    msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
-    msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
-    msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
-    msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
-    msrc += ["return 0; }"]
-    return ret + '\n' + '\n'.join(msrc)
-
 class NVRenderer(CUDARenderer): device = "NV"
 class HIPRenderer(AMDRenderer): device = "HIP"
 class QCOMRenderer(OpenCLRenderer): device = "QCOM"
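Note on the TensorCore configs above: the threads / elements_per_thread / opts values added for Metal, CUDA and AMD are constrained by TensorCore.__post_init__ in renderer/__init__.py (N*M == 2**(local axes + upcast axes), 2**local == threads, 2**upcast == elements_per_thread[2]). A quick standalone check of those invariants against the configs quoted in this diff, using plain tuples (illustrative only; the real check runs inside the frozen dataclass):

# Verify the __post_init__ invariants for the TensorCore shapes shown in this diff.
configs = {
  # name: (dims (N, M, K), threads, elements_per_thread, opts)
  "METAL 8x8x8":  ((8, 8, 8),    32, (2, 2, 2),   ("u0","l0","l1","l1","l0","l1")),
  "CUDA 8x16x16": ((8, 16, 16),  32, (8, 4, 4),   ("u0","l0","l0","l1","l1","l1","u1")),
  "AMD 16x16x16": ((16, 16, 16), 32, (16, 16, 8), ("l0","l0","l0","l0","l1","u1","u1","u1")),
}
for name, (dims, threads, ept, opts) in configs.items():
  local_axes, upcast_axes = sum(o[0] == "l" for o in opts), sum(o[0] == "u" for o in opts)
  assert dims[0] * dims[1] == 2 ** (local_axes + upcast_axes)   # N*M covered by local+upcast
  assert 2 ** local_axes == threads                             # locals form the warp
  assert 2 ** upcast_axes == ept[2]                             # upcasts give C elements per thread
  print(f"{name}: ok ({local_axes} local axes, {upcast_axes} upcast axes)")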
tinygrad/renderer/llvmir.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import
+from typing import cast
 import math, struct
 from tinygrad.renderer import Renderer
 from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
@@ -60,19 +60,23 @@ llvm_rewrite = PatternMatcher([
 
   # range
   (UPat(Ops.RANGE, name="x"), lambda ctx,x:
-   f" br label %loop_entry_{x.arg
-   f" br label %loop_body_{x.arg
-   f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg
+   f" br label %loop_entry_{x.arg}\nloop_entry_{x.arg}:\n"
+   f" br label %loop_body_{x.arg}\nloop_body_{x.arg}:\n"
+   f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg}], [{ctx[x]}phi, %loop_latch_{x.arg}]"),
   (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
-   f" br label %loop_latch_{x.src[0].arg
+   f" br label %loop_latch_{x.src[0].arg}\nloop_latch_{x.src[0].arg}:\n"
    f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
-   f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg
+   f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg}, label %loop_exit_{x.src[0].arg}\nloop_exit_{x.src[0].arg}:"),
 
   # if
   (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
   (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
 ])
 
+def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
+  u16_buf = buf.replace(dtype=dtypes.ushort.ptr(size=cast(PtrDType,buf.dtype).size))
+  return UOp.load(UOp.index(u16_buf, idx), dtype=dtypes.ushort).cast(dtypes.uint).mul(1<<16).bitcast(dtypes.float32).cast(root.dtype)
+
 class LLVMRenderer(Renderer):
   device = "LLVM"
   supports_float4 = False
@@ -85,23 +89,24 @@ class LLVMRenderer(Renderer):
     (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
     # rewrite cast to bool to CMPNE 0
     (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
-    # *** also in cstyle ***
-    # gate any stores that aren't gated with ifs
-    (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
-     lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
     # rewrite MAX to CMPLT + WHERE
     (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+    # rewrite bf16 CAST(LOAD) to CAST(BITCAST)
+    (UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
   ])
 
-  def
-
-
-
-
+  def __init__(self, abi:str|None=None):
+    self.abi = abi
+
+  def render(self, name: str, uops: list[UOp]) -> str:
+    r: dict[UOp, str] = {}
+    args: list[str] = []
+    kernel: list[str] = []
+    end_lines: dict[str, None] = {}
     vc = -1
 
     # prealloc all assigns
-    acc_to_assign:
+    acc_to_assign: dict[UOp, UOp] = {}
    for u in uops:
      if u.op is Ops.ASSIGN:
        vc += 1
@@ -133,10 +138,17 @@ class LLVMRenderer(Renderer):
      # generate the phi nodes for the assigns
      if u.op is Ops.RANGE:
        for x in acc_to_assign:
-          if u in x.src: # if this range is
+          if u in x.src: # if this range is relevant for this acc
            vc += 1
-            kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg
+            kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg}]")
            r[x] = f"%acc{vc}"
 
-    # output the function
-    return f
+    # output the function. chr(10) is '\n' (python < 3.12 doesn't support backslashes in f-strings)
+    return f'''\
+define{(' '+self.abi) if self.abi is not None else ''} void @{name}({','.join(args)}) #0 {{
+{chr(10).join(kernel)}
+  ret void
+}}
+{chr(10).join(end_lines.keys())}
+attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }}
+'''