tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/renderer/__init__.py
CHANGED
@@ -1,41 +1,100 @@
-from
-import
-
+from __future__ import annotations
+from typing import Optional, Callable
+import functools, math
+from enum import Enum, auto
+from dataclasses import dataclass, field, replace
 from tinygrad.helpers import to_function_name, dedup, prod
-from tinygrad.ops import Ops, UOp,
+from tinygrad.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
 from tinygrad.dtype import DType

+class OptOps(Enum):
+  TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
+  GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto(); SWAP = auto() # noqa: E702
+  def __lt__(self, x:OptOps): return self.value < x.value
+
+@dataclass(frozen=True, order=True)
+class Opt:
+  op: OptOps
+  axis: Optional[int] = None
+  arg: Optional[int | tuple] = None
+  def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})"
+
 @dataclass(frozen=True)
 class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
-  dims:
+  dims: tuple[int,int,int] # N, M, K
+  threads: int # number of threads that construct the warp
+  elements_per_thread: tuple[int, int, int] # elements per-thread to load/store from A/B/C
   dtype_in: DType # dtype for A and B
   dtype_out: DType # dtype for C and D
-
-
-
-  def
-
-  upcast_axes: Tuple[List[Tuple[int,int]], List[Tuple[int,int]], List[Tuple[int,int]]] # list of (TC dim,amt) that upcast A, B and C
-  st1_pattern: Optional[Tuple[Tuple[Tuple[int,int], ...], Tuple[Tuple[int,int], ...]]] = None # pattern to fix shapetracker for A
-  st2_pattern: Optional[Tuple[Tuple[Tuple[int,int], ...], Tuple[Tuple[int,int], ...]]] = None # pattern to fix shapetracker for B
-  expanded_shape: Optional[Tuple[int, ...]] = None
-  opts_seq: Tuple[str,str] = ("UP","LC") # upcast input, local the thread pattern
+  opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifing kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
+  swizzle: tuple[Optional[tuple[tuple[int, ...], tuple[int, ...]]], Optional[tuple[tuple[int, ...], tuple[int, ...]]]] = (None, None)
+  def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
+  def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
+  def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]
   def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
+  def __post_init__(self):
+    local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
+    assert self.dims[0] * self.dims[1] == 2**(local_axes + upcast_axes), (
+      f"N({self.dims[0]}) x M({self.dims[1]}) != local({2**local_axes}) x upcast({2**upcast_axes}) with opts({self.opts})")
+    assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
+    assert 2**upcast_axes == self.elements_per_thread[2], (
+      f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}")
+    assert all(len(perm[0]) == local_axes and len(perm[1]) == reduce_axes + upcast_axes for perm in self.swizzle if perm), (
+      f"swizzle perm should be of len (({local_axes})({reduce_axes + upcast_axes}))")
+
+@dataclass(frozen=True)
+class Estimates:
+  # number of FLOPS used in the Kernel
+  ops:sint = 0
+  # bytes accessed in loads and stores
+  lds:sint = 0
+  # total bytes accessed, counting only once for bytes that are accessed multiple times
+  mem:sint = 0
+  def __add__(self, o:Estimates): return Estimates(self.ops + o.ops, self.lds + o.lds, self.mem + o.mem)
+  def simplify(self): return Estimates(ssimplify(self.ops), ssimplify(self.lds), ssimplify(self.mem))
+  @staticmethod
+  def from_uops(uops:list[UOp], ignore_indexing=False) -> Estimates:
+    flops: sint = 0
+    lds: sint = 0
+    mults: sint = 1
+    mult_stack: list[sint] = []
+    dont_count: set[UOp] = set()
+    if ignore_indexing:
+      for u in uops:
+        if u.op in {Ops.LOAD, Ops.STORE}:
+          dont_count = dont_count.union(u.src[0].toposort)
+          if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort)
+        elif u.op is Ops.IF:
+          dont_count = dont_count.union(u.src[0].toposort)
+    for u in uops:
+      if u.op is Ops.RANGE:
+        mult_stack.append(mults)
+        mults *= (u.src[1] - u.src[0]).ssimplify()
+      elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
+      elif u.op is Ops.SPECIAL: mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
+      elif u.op is Ops.LOAD: lds += u.dtype.itemsize * mults
+      elif u.op is Ops.STORE: lds += u.src[1].dtype.itemsize * mults
+      elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
+      elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
+    return Estimates(flops, lds, lds) # TODO: properly track memory, lds is always a high estimate

 @dataclass
-class
+class ProgramSpec:
   name:str
   src:str
-
-
+  device:str
+  ast:UOp # save the base ast (this is method cache key)
+  uops:Optional[list[UOp]]=None
+  applied_opts:Optional[list[Opt]]=None
   mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good

   # filled in from uops (if we have uops)
-  global_size:Optional[
-  local_size:Optional[
-  vars:
-  globals:
-  outs:
+  global_size:Optional[list[int]]=None
+  local_size:Optional[list[int]]=None
+  vars:list[Variable]=field(default_factory=list)
+  globals:list[int]=field(default_factory=list)
+  outs:list[int]=field(default_factory=list)
+  ins:list[int]=field(default_factory=list)
   _ran_post_init:bool=False # NOTE: this is needed if you call replace on the Program

   def __post_init__(self):
@@ -44,7 +103,8 @@ class Program:
     for u in self.uops:
       if u.op is Ops.DEFINE_VAR: self.vars.append(u)
       if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
-      if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].
+      if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
+      if u.op is Ops.LOAD: self.ins.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
       if u.op is Ops.SPECIAL:
         # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
         if u.arg[0][0] == 'i': self.local_size = None
@@ -53,19 +113,17 @@ class Program:
         special_size[int(u.arg[0][-1])] = u.arg[1]
     self.vars = sorted(self.vars, key=lambda v: v.arg)
     self.outs = sorted(dedup(self.outs))
+    self.ins = sorted(dedup(self.ins))
     self._ran_post_init = True

-  @property
-  def op_estimate(self) -> sint: return self._ops_lds[0]
-  @property
-  def lds_estimate(self) -> sint: return self._ops_lds[1]
   @functools.cached_property
-  def
+  def estimates(self) -> Estimates:
+    return replace(Estimates() if self.uops is None else Estimates.from_uops(self.uops, ignore_indexing=True), mem=self.mem_estimate)

   @functools.cached_property
   def function_name(self) -> str: return to_function_name(self.name)

-  def launch_dims(self, var_vals:
+  def launch_dims(self, var_vals:dict[Variable, int]):
     global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
     local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
     return global_size, local_size
@@ -78,12 +136,13 @@ class Renderer:
   has_local: bool = True
   has_shared: bool = True
   # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
-  global_max: Optional[
-  local_max: Optional[
+  global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
+  local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
   shared_max: int = 32768
-  tensor_cores:
-
-
+  tensor_cores: list[TensorCore] = []
+  pre_matcher: Optional[PatternMatcher] = None
+  extra_matcher: Optional[PatternMatcher] = None
+  code_for_op: dict[Ops, Callable] = {}

   def __reduce__(self): return self.__class__, ()
-  def render(self,
+  def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
tinygrad/renderer/cstyle.py
CHANGED
@@ -1,11 +1,11 @@
-from
-
-import os, math
+from typing import Optional, Union, Literal, Callable, cast
+import os, math, sys
 from collections import defaultdict, Counter
-from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
+from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
 from tinygrad.renderer import Renderer, TensorCore
+from tinygrad.codegen.devectorizer import no_vectorized_alu

 base_rewrite = PatternMatcher([
   (UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
@@ -18,10 +18,12 @@ base_rewrite = PatternMatcher([
     lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
   (UPat(Ops.VECTORIZE, name="x"),
     lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
-    (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device
+    (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CPU', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")),
+  (UPat(Ops.CAST, name="x"), lambda ctx,x:
+    f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
   (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
   (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
-  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.
+  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
   (UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
   (UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
   (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
@@ -50,21 +52,27 @@ base_rewrite = PatternMatcher([
   (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
     *([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
   (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
-    (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device
+    (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CPU', 'DSP'} else \
+    f".{'xyzwabcd'[x.arg[0]]}")),
+  # custom passes through with format
+  (UPat(Ops.CUSTOM, name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
 ])

 extra_pm = PatternMatcher([
   # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
   (UPat(Ops.BITCAST, name="x"),
     lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
-  # gate any stores that aren't gated with ifs
-  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
-    lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
   # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
   (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+  # devectorize any bools
+  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
+  # CAST (from bool) can't be vectorized
+  (UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu),
+  # WHERE can't be vectorized
+  (UPat(Ops.WHERE, name="alu"), no_vectorized_alu),
 ])

-def uops_to_dtypes(uops:
+def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))

 class CStyleLanguage(Renderer):
   kernel_prefix: str = ""
@@ -75,13 +83,13 @@ class CStyleLanguage(Renderer):
   smem_prefix_for_cast: bool = True
   arg_int_prefix: str = "const int"
   barrier: str = ""
-  code_for_workitem:
-  extra_args:
+  code_for_workitem: dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
+  extra_args: list[str] = []
   float4: Optional[str] = None
-  type_map:
+  type_map: dict[DType, str] = {}
   infinity: str = "INFINITY"
   nan: str = "NAN"
-  code_for_op:
+  code_for_op: dict = {
     Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
     Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
     Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
@@ -93,8 +101,8 @@ class CStyleLanguage(Renderer):
   string_rewrite = base_rewrite
   extra_matcher = extra_pm

-  def get_kernel_modifier(self, uops:
-  def render_kernel(self, function_name:str, kernel:
+  def get_kernel_modifier(self, uops:list[UOp]) -> str: return ""
+  def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
     tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
     buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
                 self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
@@ -105,24 +113,27 @@ class CStyleLanguage(Renderer):

   def render_cast(self, dt:DType, val: str) -> str: return f"({self.render_dtype(dt)})({val})"
   def render_dtype(self, dt:DType, mutable=True) -> str:
-    if isinstance(dt, ImageDType):
-      return f"{'write_only' if mutable else 'read_only'} image2d_t"
+    if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
     if isinstance(dt, PtrDType):
-      return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) +
-
-    return self.type_map.get(scalar:=dt.scalar(), scalar.name)
+      return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + self.render_dtype(dt.base) + "*"
+    if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
+    return self.type_map.get(scalar:=dt.scalar(), scalar.name)

   def __getitem__(self, key): return self.r[key] # hacky helper
-  def render(self,
-    r:
+  def render(self, uops:list[UOp]) -> str:
+    r: dict[UOp, str] = {}
     self.r = r

     child_count = Counter(v for ru in uops for v in ru.src)
-    bufs:
+    bufs: dict[UOp, tuple[str, tuple[DType, bool]]] = {}
     kernel = []
     depth = 1
-    c:
+    c: defaultdict[str, int] = defaultdict(int)
+    name = "test"
     for u in uops:
+      if u.op is Ops.NAME:
+        name = u.arg
+        continue
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
         r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
         bufs[u] = (r[u], (u.dtype, False))
@@ -130,7 +141,7 @@ class CStyleLanguage(Renderer):

       # mark buffers that we store to writable
       if u.op is Ops.STORE:
-        for up in u.src[0].
+        for up in u.src[0].toposort:
           if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))

       # naming
@@ -147,8 +158,8 @@ class CStyleLanguage(Renderer):
       assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"

       if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
-      if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX
-
+      if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOM} or \
+        (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA")):
         r[u] = l
       else:
         if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
@@ -164,25 +175,31 @@ class CStyleLanguage(Renderer):
     return self.render_kernel(name, kernel, list(bufs.values()), uops)

 class ClangRenderer(CStyleLanguage):
-  device = "
+  device = "CPU"
   float4 = "(float4)"
   has_local = False
   global_max = None
   infinity = "__builtin_inff()"
   nan = '__builtin_nanf("")'
+  amx_tc = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt, swizzle=(None,((),(4,5,6,7,0,1,2,3))),
+            opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
+  if AMX: tensor_cores = amx_tc

   # language options
   buffer_suffix = " restrict"
   type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
   code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
                  Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
+  # LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
+  extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))]) + \
+    CStyleLanguage.extra_matcher

-  if
-
-  for dt, sz in [(dt, 64//dt.itemsize) for dt in [dtypes.float]]]
-
+  if sys.platform == 'win32':
+    kernel_prefix = "__attribute__((ms_abi)) "
   def render_vector_prefix(self, dt:DType) -> str:
-
+    # round (down) to power of two
+    alignment = 2**int(math.log2(dt.itemsize))
+    return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),vector_size({dt.itemsize})));"

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
     prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
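The rewritten render_vector_prefix in the ClangRenderer hunk above derives the alignment by rounding the vector's byte size down to a power of two. A small sketch of the string it would emit for a 4-wide float (16 bytes), assuming dtypes.float renders as plain "float" here, as the type_map in this file implies:

import math

itemsize = 16                              # dtypes.float.vec(4).itemsize
alignment = 2 ** int(math.log2(itemsize))  # round (down) to a power of two -> 16
print(f"typedef float float4 __attribute__((aligned({alignment}),vector_size({itemsize})));")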
@@ -192,7 +209,10 @@ class ClangRenderer(CStyleLanguage):
       '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
       '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
     ]
-
+    # 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
+    # to just jump at the start of a shellcode whithout having to deal with symbols or trampolines at all. This is better than having to inline
+    # wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
+    prefix += [f"""static {(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
 AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
 AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
 for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501
@@ -209,7 +229,8 @@ class OpenCLRenderer(CStyleLanguage):
   barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
   float4 = "(float4)"
   code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
-  type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong",
+  type_map = { dtypes.int8: "char", dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong",
+               dtypes.bfloat16: "ushort" }

   string_rewrite = PatternMatcher([
     (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
@@ -223,17 +244,17 @@ class OpenCLRenderer(CStyleLanguage):
   ]) + base_rewrite

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
-    if any(uop.dtype == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
+    if any(uop.dtype.base == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)

 class IntelRenderer(OpenCLRenderer):
   device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel "
-  tensor_cores = [TensorCore(dims=(8,8,16),threads=
-
+  tensor_cores = [TensorCore(dims=(8,8,16), threads=8, elements_per_thread=(16,16,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
+    opts=("l0","l0","l0","u1","u1","u1"), swizzle=(((4,5,6),(0,1,2,3,7,8,9)), ((0,1,2),(7,8,9,3,4,5,6))))]

   string_rewrite = PatternMatcher([
-    (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x
-    (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x
+    (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x]})"),
+    (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x]})"),
   ]) + OpenCLRenderer.string_rewrite

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
@@ -247,9 +268,9 @@ class IntelRenderer(OpenCLRenderer):
 class MetalRenderer(CStyleLanguage):
   device = "METAL"
   shared_max = 32768
-  tensor_cores = [TensorCore(dims=(8,8,8),threads=
-
-
+  tensor_cores = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do, opts=("u0","l0","l1","l1","l0","l1"),
+    swizzle=(((6,1,2,7,4),(8,0,3,5)), ((0,5,6,3,7),(1,2,4,8)))) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
+    (dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
   def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if hasattr(os, 'uname') and os.uname().machine == "arm64" else []

   # language options
@@ -289,18 +310,27 @@ class MetalRenderer(CStyleLanguage):
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)

 _nms = "xyzwabcdefghijkl"
+cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1") # shared by all shapes with M=16 N=8

 class CUDARenderer(CStyleLanguage):
   device = "CUDA"
   global_max = (2147483647, 65535, 65535)
   local_max = (1024, 1024, 64)
   shared_max = 49152
-  # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-
-
-
-
-
-
+  # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
+  tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+    swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float),
+    (dtypes.half,dtypes.half)]]
+  tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+    swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
+  tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
+    swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]
+
+  tc_sm80 = tc_81616 + tc_8168_f16
+  if getenv("ALLOW_TF32", 0): tc_sm80 += tc_8168_tf32
+  tc_sm75 = tc_8168_f16
+  def __init__(self, arch:str):
+    self.tensor_cores, self.arch = CUDARenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else [], arch
   def __reduce__(self): return self.__class__, (self.arch,)

   # language options
@@ -333,7 +363,8 @@ class CUDARenderer(CStyleLanguage):
     if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
     prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]

-
+    dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+    dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
     for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
       upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
       wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
@@ -342,27 +373,34 @@ class CUDARenderer(CStyleLanguage):

       # mma operands => {c}, {a}, {b}, {c}
       prefix.append(f"""__device__ {wmma_dtypes[2]} __{name}({wmma_dtypes[0]} a, {wmma_dtypes[1]} b, {wmma_dtypes[2]} c){{
-  int *a_pk = (int *)(&a), *b_pk = (int *)(&b)
+  int *a_pk = (int *)(&a), *b_pk = (int *)(&b), *c_pk = (int *)(&c);
+  asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}"
     "{{{", ".join(operands[:n_operands[2]])}}}, {{{", ".join(operands[n_operands[2]:n_operands[2]+n_operands[0]])}}},"
     "{{{", ".join(operands[-n_operands[1]:])}}}, {{{", ".join(operands[:n_operands[2]])}}};"
-    : {", ".join([f'"+
+    : {", ".join([f'"+r"(c_pk[{i}])' for i in range(n_operands[2])])}
     : {", ".join([f'"r"(a_pk[{i}])' for i in range(n_operands[0])])}, {", ".join([f'"r"(b_pk[{i}])' for i in range(n_operands[1])])});
   return c;\n}}""")

     return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)

-  def get_kernel_modifier(self, uops:
+  def get_kernel_modifier(self, uops:list[UOp]) -> str:
     maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
     return f"__launch_bounds__({maxThreadsPerBlock}) "

+def cast_float_to_bf16(x: UOp) -> UOp:
+  assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
+  x = x.bitcast(dtypes.uint)
+  x = (-x & 0x7f800000).where(x + ((x >> 16) & 1) + 0x7fff, (x & 0xffff).where((x | 0x10000), x))
+  return (x >> 16).cast(dtypes.ushort).bitcast(dtypes.bfloat16)
+
 class AMDRenderer(CStyleLanguage):
   device = "AMD"
   shared_max = 65536
   # https://gpuopen.com/learn/wmma_on_rdna3/
-  tensor_cores = [TensorCore(dims=(16,16,16), threads=
-
-  for
+  tensor_cores = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
+    opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8))))
+    for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]]

   # language options
   ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
@@ -397,8 +435,7 @@ class AMDRenderer(CStyleLanguage):
     (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
     # bfloat16 casting
     (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
-    (UPat(Ops.CAST,
-      lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
+    (UPat(Ops.CAST, dtypes.float, UPat.var("x", dtypes.bfloat16)), lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
     (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm

   def render_vector_prefix(self, dtype:DType) -> str:
@@ -410,7 +447,7 @@ class AMDRenderer(CStyleLanguage):
     prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]

     used_dtypes = uops_to_dtypes(uops)
-    if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("
+    if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
     prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]

     for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
@@ -421,42 +458,12 @@ class AMDRenderer(CStyleLanguage):
 for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)

-  def get_kernel_modifier(self, uops:
+  def get_kernel_modifier(self, uops:list[UOp]) -> str:
     requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
     # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
     # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
     return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"

-class DSPRenderer(ClangRenderer):
-  device = "DSP"
-  supports_float4 = False
-  buffer_suffix = " restrict __attribute__((align_value(128)))"
-  kernel_prefix = "__attribute__((noinline)) "
-  type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
-  code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
-                 Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
-                 Ops.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"}
-
-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
-    ret = super().render_kernel(function_name, kernel, bufs, uops, prefix)
-    msrc = ['''struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency; _Bool set_dcvs_params;
-      short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3]; };''', 'int HAP_power_set(void*, void*);',
-      'typedef union { struct { void *pv; unsigned int len; } buf; struct { int fd; unsigned int offset; } dma; } remote_arg;',
-      'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
-      'unsigned long long HAP_perf_get_time_us(void);', 'int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
-      'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
-      'HAP_power_set((void*)handle, (void*)&req);']
-    msrc += ['if ((sc>>24) != 2) return 0;']
-    msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
-    msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
-    msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
-    msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
-    msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
-    msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
-    msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
-    msrc += ["return 0; }"]
-    return ret + '\n' + '\n'.join(msrc)
-
 class NVRenderer(CUDARenderer): device = "NV"
 class HIPRenderer(AMDRenderer): device = "HIP"
 class QCOMRenderer(OpenCLRenderer): device = "QCOM"