tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/renderer/cstyle.py
@@ -1,12 +1,71 @@
- from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
+ from __future__ import annotations
+ from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
  import os, math
  from collections import defaultdict, Counter
- from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
- from tinygrad.helpers import strip_parens, getenv, prod, dedup
- from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
- from tinygrad.codegen.uops import UOps, UOp, UOpGraph
+ from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
+ from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
+ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
  from tinygrad.renderer import Renderer, TensorCore

+ base_rewrite = PatternMatcher([
+ (UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
+ (UPat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
+ (UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
+ (UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
+ (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
+ # r method accesses
+ (UPat(Ops.RANGE, name="x"),
+ lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
+ (UPat(Ops.VECTORIZE, name="x"),
+ lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
+ (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),
+ (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
+ (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
+ (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"),
+ (UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
+ (UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
+ (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
+ # const
+ (UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"),
+ (UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, f'-{ctx.infinity}')})"),
+ (UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx.nan)})" if math.isnan(x.arg) else None),
+ (UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
+ (UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
+ (UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
+ (UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
+ (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
+ # consts are rendered to larger type and casted
+ (UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
+ (UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u')})"),
+ (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, x.arg)})"),
+ # default const render
+ (UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
+ # new load/store
+ (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
+ lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
+ (UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
+ (UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
+ (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
+ # alu/gep
+ (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
+ *([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
+ (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
+ (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
+ ])
+
+ extra_pm = PatternMatcher([
+ # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
+ (UPat(Ops.BITCAST, name="x"),
+ lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
+ # gate any stores that aren't gated with ifs
+ (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+ lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
+ # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
+ (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+ ])
+
+ def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
+
  class CStyleLanguage(Renderer):
  kernel_prefix: str = ""
  buffer_prefix: str = ""
@@ -19,174 +78,125 @@ class CStyleLanguage(Renderer):
  code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
  extra_args: List[str] = []
  float4: Optional[str] = None
- uses_vload: bool = False
- uses_ptr_arithmetic: bool = False
  type_map: Dict[DType, str] = {}
+ infinity: str = "INFINITY"
+ nan: str = "NAN"
  code_for_op: Dict = {
- UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype == dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
- UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
- UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
- BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
- BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
- BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
- TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
-
- # returns a str expression of the casted xs with the given type
- def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
- if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x[0]}))"
- if len(x) == 1: return f"({self.render_dtype(var_dtype)})({x[0]})"
- assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
- assert self.float4 is not None, "vectorized cast is not supported on this platform"
- return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
-
- # returns a str expression of the const with the given type
- def render_const(self, x:ConstType, dtype:DType) -> str:
- if math.isnan(x): val = "NAN"
- elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
- elif dtype == dtypes.bool: val = "1" if x else "0"
- elif dtype == dtypes.float: val = f"{x}f"
- else: val = str(x)
- return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
-
- # returns a str expression of the loaded value with the output type
- def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
- if isinstance(buf_dtype, ImageDType):
- assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
- return f"read_imagef({buf_name}, smp, {idx})"
- if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
- return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
- if output_dtype.count > 1:
- out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(buf_dtype)}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501
- else:
- out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
- return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
-
- def get_kernel_modifier(self, uops:UOpGraph) -> str: return ""
- def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str:
+ Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
+ Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
+ Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
+ Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
+ Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
+ Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})",
+ Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" }
+
+ string_rewrite = base_rewrite
+ extra_matcher = extra_pm
+
+ def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
+ def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
  tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
- buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
- ("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
+ buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
  self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
  prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
  [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
  [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
  return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"

- # returns a str statement that does the store
- def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
- if isinstance(buf_dtype, ImageDType):
- assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}"
- return f"write_imagef({buf_name}, {idx}, {var_name});"
- if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
- return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
- if var_dtype.count > 1:
- prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
- return f"*(({prefix}{self.render_dtype(buf_dtype)}{var_dtype.count}*)({buf_name}+{idx})) = {var_name};"
- return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
-
- def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"
- def render_dtype(self, var_dtype:DType) -> str: return self.type_map.get(var_dtype, var_dtype.name)
-
- def render(self, name:str, uops:UOpGraph) -> str:
- kernel = []
- bufs: List[Tuple[str, Tuple[DType, bool]]] = []
- depth = 1
- def kk(s): kernel.append(" "*depth+s)
-
- c: DefaultDict[str, int] = defaultdict(int)
+ def render_cast(self, dt:DType, val: str) -> str: return f"({self.render_dtype(dt)})({val})"
+ def render_dtype(self, dt:DType, mutable=True) -> str:
+ if isinstance(dt, ImageDType):
+ return f"{'write_only' if mutable else 'read_only'} image2d_t"
+ if isinstance(dt, PtrDType):
+ return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + \
+ self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "")
+ return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")
+
+ def __getitem__(self, key): return self.r[key] # hacky helper
+ def render(self, name:str, uops:List[UOp]) -> str:
  r: Dict[UOp, str] = {}
-
- def ssa(prefix:str, u:Optional[UOp]=None):
- nonlocal c, r
- ret = f"{prefix}{c[prefix]}"
- if u is not None: r[u] = ret
- c[prefix] += 1
- return ret
+ self.r = r

  child_count = Counter(v for ru in uops for v in ru.src)
-
- seen_vars = set()
+ bufs: Dict[UOp, Tuple[str, Tuple[DType, bool]]] = {}
+ kernel = []
+ depth = 1
+ c: DefaultDict[str, int] = defaultdict(int)
  for u in uops:
- uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
- # these four uops don't have output dtypes
- if uop is UOps.IF:
- kk(f"if ({r[src[0]]}) {{")
- depth += 1
- elif uop is UOps.BARRIER: kk(self.barrier)
- elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
- depth -= 1
- kk("}")
- elif uop is UOps.STORE:
- assert src[0].dtype is not None and src[2].dtype is not None
- rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
- kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
+ if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
+ r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
+ bufs[u] = (r[u], (u.dtype, False))
+ continue
+
+ # mark buffers that we store to writable
+ if u.op is Ops.STORE:
+ for up in u.src[0].sparents:
+ if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
+
+ # naming
+ prefix = None
+ if u.op is Ops.SPECIAL:
+ r[u] = u.arg[0]
+ else:
+ prefix = {Ops.RANGE: "ridx", Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
+ Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.NOOP: "precast",
+ Ops.INDEX: "bidx", Ops.DEFINE_ACC: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
+ r[u] = f"{prefix}{c[prefix]}"
+
+ l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
+ assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
+
+ if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
+ if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST}
+ and child_count[u] == 1 and not getenv("EXPAND_SSA")):
+ r[u] = l
  else:
- assert dtype is not None, f"None dtype for uop {uop}"
- if uop is UOps.RANGE:
- kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{")
- depth += 1
- elif uop is UOps.ALU:
- # remove parens if ALU types are the same. TODO: can do more here
- if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src]
- else: operands = [r[v] for v in src]
- val = self.code_for_op[args](*operands, dtype)
- assert child_count[u] != 0, f"childless ALU op found {u}"
- # TODO: fix index rendering issue. fix clang nested max macro issue
- if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
- else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
- elif uop is UOps.SPECIAL:
- kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
- r[u] = args[1]
- elif uop is UOps.LOAD:
- val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
- # NOTE: this relies on the load not happening if it's in the unselected branch
- if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
- kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
- elif uop is UOps.PHI:
- kk(f"{r[src[0]]} = {r[src[1]]};")
- r[u] = r[src[0]]
- elif uop in {UOps.CAST, UOps.BITCAST}:
- if uop is UOps.BITCAST:
- assert len(src) == 1
- precast = ssa('precast')
- kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
- val = self.render_cast([precast], dtype, bitcast=True)
- else:
- val = self.render_cast([r[x] for x in src], dtype, bitcast=False)
- if child_count[u] <= 1: r[u] = val
- else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
- elif uop is UOps.DEFINE_LOCAL:
- kk(self.render_local(args[0], dtype, args[1]))
- r[u] = args[0]
- elif uop is UOps.DEFINE_VAR:
- assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
- seen_vars.add(args.expr)
- bufs.append((args.expr, (dtype,False)))
- r[u] = args.expr
- elif uop is UOps.DEFINE_GLOBAL:
- bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
- r[u] = nm
- elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
- elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
- elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
- elif uop is UOps.GEP:
- assert src[0].dtype is not None
- from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
- r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
- else: raise RuntimeError(f"failed to render {uop}")
-
- return self.render_kernel(name, kernel, bufs, uops)
+ if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
+ if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
+ else:
+ l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
+ kernel.append(" "*depth + l)
+ if prefix: c[prefix] += 1 # if it was used, increment
+ if u.op in {Ops.IF, Ops.RANGE}: depth += 1
+ del self.r
+
+ # NOTE: this relies on bufs dict preserving order
+ return self.render_kernel(name, kernel, list(bufs.values()), uops)

  class ClangRenderer(CStyleLanguage):
  device = "CLANG"
- supports_float4 = False
+ float4 = "(float4)"
  has_local = False
  global_max = None
+ infinity = "__builtin_inff()"
+ nan = '__builtin_nanf("")'

  # language options
  buffer_suffix = " restrict"
  type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
- code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
+ code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
+ Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
+
+ if AMX:
+ tensor_cores = [TensorCore(dims=(sz,sz,1), threads=[], reduce_axes=[], upcast_axes=([(1,sz)],[(0,sz)],[(1,sz),(0,sz)]), dtype_in=dt, dtype_out=dt)
+ for dt, sz in [(dt, 64//dt.itemsize) for dt in [dtypes.float]]]
+
+ def render_vector_prefix(self, dt:DType) -> str:
+ return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({(sz:=dt.itemsize)}),vector_size({sz})));"
+
+ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+ prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
+ # https://github.com/corsix/amx
+ for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
+ prefix += [
+ '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
+ '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
+ ]
+ prefix += [f"""{(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
+ AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
+ AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
+ for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501
+ return super().render_kernel(function_name, kernel, bufs, uops, prefix)

  class OpenCLRenderer(CStyleLanguage):
  device = "GPU"
@@ -199,20 +209,48 @@ class OpenCLRenderer(CStyleLanguage):
  barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
  float4 = "(float4)"
  code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
- uses_vload = True
- type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
- def render_cast(self, x, var_dtype, bitcast=False) -> str:
- return f"as_{self.render_dtype(var_dtype)}({x[0]})" if bitcast else super().render_cast(x, var_dtype)
+ type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
+
+ string_rewrite = PatternMatcher([
+ (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
+ # load/store image (OpenCL)
+ (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
+ lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
+ (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
+ lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
+ (UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
+ lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
+ ]) + base_rewrite

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
- if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
+ if any(uop.dtype == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)

+ class IntelRenderer(OpenCLRenderer):
+ device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel "
+ tensor_cores = [TensorCore(dims=(8,8,16),threads=[(0,8)],dtype_in=di,dtype_out=do,reduce_axes=[(0,16)],upcast_axes=([(0,16)],[(0,16)],[(1,8)]),
+ st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
+
+ string_rewrite = PatternMatcher([
+ (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
+ (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
+ ]) + OpenCLRenderer.string_rewrite
+
+ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+ prefix = []
+ for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
+ dt_in = ("ushort", "bf16") if arg[2] == dtypes.bfloat16 else (arg[2].name, "f16")
+ prefix.append(f"""{arg[3].name}8 __{arg[0]}({dt_in[0]}16 a, {dt_in[0]}16 b, {arg[3].name}8 c) {{
+ return intel_sub_group_{dt_in[1]}_{dt_in[1]}_matrix_mad_k16(as_int8(a), as_int8(b), c);\n}}""")
+ return super().render_kernel(function_name, kernel, bufs, uops, prefix or None)
+
  class MetalRenderer(CStyleLanguage):
  device = "METAL"
  shared_max = 32768
- tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[0],[2],[0],[4],[-1, 1, 3],[0]], [[1],[0],[3],[0],[2, 4],[-1]], [[1],[2],[3],[4],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
- def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if os.uname().machine == "arm64" else []
+ tensor_cores = [TensorCore(dims=(8,8,8),threads=[(0,2),(1,4),(0,2),(1,2)],expanded_shape=(2,2,2,2),upcast_axes=([(1,2)],[(1,2)],[(1,2)]),
+ st1_pattern=(((1,1),(0,1),(1,0),(0,3)),((0,0),(0,2),(1,3),(1,2))),st2_pattern=(((0,0),(1,1),(1,2),(0,2),(1,0)),((0,1),(0,3),(1,3))),
+ dtype_in=di,dtype_out=do,reduce_axes=[(0,8)]) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]]
+ def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if hasattr(os, 'uname') and os.uname().machine == "arm64" else []

  # language options
  kernel_prefix = "kernel "
@@ -221,48 +259,49 @@ class MetalRenderer(CStyleLanguage):
  arg_int_prefix = "constant int&"
  barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
  float4 = "float4"
- uses_ptr_arithmetic = True
- code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
+ code_for_workitem = {"g": lambda x: f"gid.{chr(120+int(x))}", "l": lambda x: f"lid.{chr(120+int(x))}"}
  # uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
  extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
  type_map = {dtypes.bfloat16: "bfloat"}
- code_for_op = {**CStyleLanguage().code_for_op,
- BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
- UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
- UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
- UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
- UnaryOps.SIN: lambda x,dtype: f"(bfloat)sin({x})" if dtype == dtypes.bfloat16 else f"sin({x})",}

- def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
- return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
+ # precise::sin
+ code_for_op = {**CStyleLanguage.code_for_op, Ops.SIN: lambda x,dtype: f"precise::sin({x})"}
+
+ # upcast to float32 all the ops that don't support bfloat16
+ extra_matcher = PatternMatcher([
+ # NOTE: this is copied from PTX
+ (UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
+ lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
+ ]) + extra_pm
+
+ string_rewrite = PatternMatcher([
+ (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
+ ]) + base_rewrite

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
- prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
- for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
- simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
- b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
- return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
+ prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is Ops.WMMA])
+ for arg in wmma_args: prefix.append(
+ f"""{(dtype_out:=self.render_dtype(arg[3].vec(2)))} __{arg[0]}({(dtype_in:=self.render_dtype(arg[2].vec(2)))} a, {dtype_in} b, {dtype_out} c){{
+ simdgroup_{self.render_dtype(arg[2])}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(arg[3])}8x8 mat_c;
+ mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0];
+ mat_a.thread_elements()[1] = a[1]; mat_b.thread_elements()[1] = b[1]; mat_c.thread_elements()[1] = c[1];
+ simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dtype_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)

- code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"1/{x}",
- BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
- UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
- UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
- UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
- UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
-
  _nms = "xyzwabcdefghijkl"
- def _make_cuda_dtype(base_type, name, cnt):
- vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
- return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"

  class CUDARenderer(CStyleLanguage):
  device = "CUDA"
  global_max = (2147483647, 65535, 65535)
  local_max = (1024, 1024, 64)
  shared_max = 49152
- tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
- def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
+ tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do, expanded_shape=(2,2,2,2,2,2),
+ st1_pattern=(((1,1),(1,0),(0,2),(0,3),(0,4)),((1,3),(1,5),(1,2),(0,0),(0,1),(1,4))),
+ st2_pattern=(((1,1),(1,0),(1,4),(0,0),(0,1)),((0,4),(0,2),(1,5),(0,3),(1,3),(1,2))), reduce_axes=[(0,8),(1,2)],
+ upcast_axes=([(0,8)],[(2,2),(3,2)],[(3,2),(2,2)])) for di, do in ([(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)])]
+ def __init__(self, arch:str): self.tensor_cores, self.arch = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
+ def __reduce__(self): return self.__class__, (self.arch,)

  # language options
  kernel_prefix = "extern \"C\" __global__ "
@@ -270,109 +309,111 @@ class CUDARenderer(CStyleLanguage):
  smem_prefix_for_cast = False
  barrier = "__syncthreads();"
  float4 = "make_float4"
- code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}",
- "i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"}
- code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
+ code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
+ "i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"}
+ code_for_op = { **CStyleLanguage.code_for_op,
+ Ops.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
+ Ops.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
+ Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
+ Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
+ Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
  type_map = {dtypes.bfloat16: "nv_bfloat16"}

+ def render_vector_prefix(self, dt:DType) -> str:
+ vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()),
+ elems, header = ', '.join(_nms[:dt.count]), ', '.join([f"{scal} {x}" for x in _nms[:dt.count]])
+ return f"struct __align__({dt.itemsize}) {vec} {{ {scal} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
+
  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
  # TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
- dt_map = { dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16"), }
-
  prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
- if any(uop.dtype == dtypes.half for uop in uops):
- prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]

- if any(uop.dtype == dtypes.bfloat16 for uop in uops):
- prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
-
- # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
- for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
- fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
- prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
- asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
- : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
- return c;}}""")
+ used_dtypes = uops_to_dtypes(uops)
+ if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include <cuda_fp16.h>")
+ if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
+ prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]
+
+ dt_map = { dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+ for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
+ upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
+ wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
+ n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)] # 4 => CUDA reg size in bytes
+ operands = [f"%{i}" for i in range(sum(n_operands))]
+
+ # mma operands => {c}, {a}, {b}, {c}
+ prefix.append(f"""__device__ {wmma_dtypes[2]} __{name}({wmma_dtypes[0]} a, {wmma_dtypes[1]} b, {wmma_dtypes[2]} c){{
+ int *a_pk = (int *)(&a), *b_pk = (int *)(&b);\n asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32"
+ "{{{", ".join(operands[:n_operands[2]])}}}, {{{", ".join(operands[n_operands[2]:n_operands[2]+n_operands[0]])}}},"
+ "{{{", ".join(operands[-n_operands[1]:])}}}, {{{", ".join(operands[:n_operands[2]])}}};"
+ : {", ".join([f'"+f"(c.{_nms[i]})' for i in range(n_operands[2])])}
+ : {", ".join([f'"r"(a_pk[{i}])' for i in range(n_operands[0])])}, {", ".join([f'"r"(b_pk[{i}])' for i in range(n_operands[1])])});
+ return c;\n}}""")

  return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)

- code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
- UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
- UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
- UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
- # TODO: MAX with int uses fmax_f32?
- BinaryOps.MAX: lambda a,b,dtype: f"__ocml_fmax_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32) }({a},{b})",}
-
- def _make_hip_code_for_op():
- def wrapper(key, func):
- def cast_bf16(*args):
- if args[-1] == dtypes.bfloat16:
- operands = tuple(f"(float)({arg})" for arg in (args[1:-1] if key is TernaryOps.WHERE else args[:-1]))
- return f"(hip_bfloat16)({func(*(((args[0],) if key is TernaryOps.WHERE else ()) + operands), dtypes.float)})"
- return func(*args)
- return cast_bf16
- return { k:wrapper(k,v) for k,v in {**CStyleLanguage().code_for_op, **code_for_op_hip}.items() }
-
- def _make_hip_dtype(base_type, name, cnt):
- elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
- return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
- f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
+ def get_kernel_modifier(self, uops:List[UOp]) -> str:
+ maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
+ # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
+ return f"__launch_bounds__({maxThreadsPerBlock}) "

  class AMDRenderer(CStyleLanguage):
  device = "AMD"
  shared_max = 65536
- tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
+ # https://gpuopen.com/learn/wmma_on_rdna3/
+ tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], dtype_in=di, dtype_out=do, reduce_axes=[(0,16)], opts_seq=("LC","UP"),
+ upcast_axes = ([(0,16)],[(0,16)],[(1,8)]), st1_pattern=(((1,2),(0,2),(1,1),(0,1)),((1,0),(0,0))), expanded_shape=(16,2,4))
+ for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]]

  # language options
- kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
- extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
- extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
- extern "C" {\n""" + "".join([
- f""" __attribute__((device)) __attribute__((const)) {dt} __ocml_fmax_f{n}({dt}, {dt});
- __attribute__((device)) __attribute__((pure)) {dt} __ocml_exp2_f{n}({dt});
- __attribute__((device)) __attribute__((pure)) {dt} __ocml_log2_f{n}({dt});
- __attribute__((device)) __attribute__((const)) {dt} __ocml_sqrt_f{n}({dt});
- __attribute__((device)) {dt} __ocml_sin_f{n}({dt});\n""" for dt,n in [("float",32), ("double",64), ("_Float16",16)]]) +\
- '}\nextern "C" __attribute__((global))'
+ ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
+ ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
+ for dt, n in [(dtype.name, dtype.itemsize * 8) for dtype in [dtypes.float, dtypes.double, dtypes.half]]
+ for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]
+
+ kernel_prefix = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
+ kernel_prefix += '\nextern "C" __attribute__((global))'
  code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
  "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
- code_for_op = _make_hip_code_for_op()
+ code_for_op = { **CStyleLanguage.code_for_op,
+ Ops.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+ Ops.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+ Ops.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+ Ops.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})" }
  smem_prefix = "__attribute__((shared))"
  barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
  '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
  float4 = "make_float4"
- uses_ptr_arithmetic = False # NOTE: this fixes TestLinearizerOverflowAlt
  type_map = {dtypes.bfloat16: "hip_bfloat16"}
+ extra_matcher = PatternMatcher([
+ # cast bfloat16 alus to float
+ (UPat(Ops.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
+ lambda b,x,y: UOp(Ops.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
+ (UPat(GroupOp.ALU, dtype=dtypes.bfloat16, name="x"),
+ lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
+ (UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
+ lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
+ # add float intermediate casting for bfloat16
+ (UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
+ (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
+ # bfloat16 casting
+ (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
+ (UPat(Ops.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
+ lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
+ (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
+
+ def render_vector_prefix(self, dtype:DType) -> str:
+ vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
+ return f"typedef {scal} {vec} __attribute__((ext_vector_type({dtype.count})));\nstatic inline __attribute__((device)) "+ \
+ f"{vec} make_{vec}({', '.join([f'{scal} {x}' for x in _nms[:dtype.count]])}) {{ return {{ {', '.join(_nms[:dtype.count])} }}; }}"

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
- prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;"]
- vec_dts = [("float", "float", 2), ("float", "float", 4), ("float", "float", 8), ("signed int", "int", 4), ("signed int", "int", 2)]
-
- # TODO: add BF16 vec dts
- if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
- struct hip_bfloat16 {
- unsigned short data;
- inline __attribute__((device)) hip_bfloat16(float val) {
- union { float fp32; unsigned int u32; } u = {val};
- if (~u.u32 & 0x7f800000) { u.u32 += 0x7fff + ((u.u32 >> 16) & 1); } else if (u.u32 & 0xffff) { u.u32 |= 0x10000; }
- data = (u.u32 >> 16);
- }
- inline __attribute__((device)) operator float() const {
- unsigned int uval = data << 16;
- return *reinterpret_cast<float*>(&uval);
- }
- };
- static inline __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
- static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
- """)
-
- if any(uop.dtype == dtypes.half for uop in uops):
- prefix.append("#define half _Float16")
- vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
-
- prefix += [_make_hip_dtype(*x) for x in vec_dts]
-
- for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+ prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
+
+ used_dtypes = uops_to_dtypes(uops)
+ if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("struct hip_bfloat16 { unsigned short data; };")
+ prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]
+
+ for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
  if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
  else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
  half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
@@ -380,10 +421,42 @@ static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat
  for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)

- def get_kernel_modifier(self, uops:UOpGraph) -> str:
- requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.op is UOps.SPECIAL and u.arg[1][0] == "l")
+ def get_kernel_modifier(self, uops:List[UOp]) -> str:
+ requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
  # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
  # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
  return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"

+ class DSPRenderer(ClangRenderer):
+ device = "DSP"
+ supports_float4 = False
+ buffer_suffix = " restrict __attribute__((align_value(128)))"
+ kernel_prefix = "__attribute__((noinline)) "
+ type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
+ code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
+ Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
+ Ops.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"}
+
+ def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
+ ret = super().render_kernel(function_name, kernel, bufs, uops, prefix)
+ msrc = ['''struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency; _Bool set_dcvs_params;
+ short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3]; };''', 'int HAP_power_set(void*, void*);',
+ 'typedef union { struct { void *pv; unsigned int len; } buf; struct { int fd; unsigned int offset; } dma; } remote_arg;',
+ 'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);',
+ 'unsigned long long HAP_perf_get_time_us(void);', 'int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {',
+ 'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};',
+ 'HAP_power_set((void*)handle, (void*)&req);']
+ msrc += ['if ((sc>>24) != 2) return 0;']
+ msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)]
+ msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
+ msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
+ msrc += ["unsigned long long start = HAP_perf_get_time_us();"]
+ msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"]
+ msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"]
+ msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)]
+ msrc += ["return 0; }"]
+ return ret + '\n' + '\n'.join(msrc)
+
  class NVRenderer(CUDARenderer): device = "NV"
+ class HIPRenderer(AMDRenderer): device = "HIP"
+ class QCOMRenderer(OpenCLRenderer): device = "QCOM"
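
Reader's note on the change above: in 0.10.0, CStyleLanguage no longer emits source text through per-op methods (render_load, render_store and render_const are removed); render() instead resolves each UOp through the string_rewrite PatternMatcher (base_rewrite by default). The following is a minimal illustrative sketch of that mechanism, using only the Ops/UOp/UPat/PatternMatcher names as they appear in the hunks above; it is not part of the packaged diff, and details of the released API may differ.

# Illustrative sketch (not from the package): rendering one UOp with a PatternMatcher rule,
# mirroring how CStyleLanguage.render() calls self.string_rewrite.rewrite(u, ctx=self) above.
from tinygrad.ops import Ops, UOp, UPat, PatternMatcher
from tinygrad.dtype import dtypes

pm = PatternMatcher([
  # same shape as the float-const rule in base_rewrite: float consts get an 'f' suffix
  (UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx, x: f"{x.arg}f"),
])

u = UOp.const(dtypes.float, 1.5)
print(pm.rewrite(u, ctx=None))  # expected to print: 1.5f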
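
The per-backend extra_matcher objects introduced above are ordinary UOp-graph rewrites applied before rendering. As a hedged example, the MAX rule from extra_pm replaces a MAX node with CMPLT plus WHERE; the sketch below applies that single rule directly (again assuming the 0.10.0 APIs shown in the diff; the compiler's actual rewrite entry point may differ).

# Illustrative sketch (not from the package): the extra_pm rule that turns MAX into CMPLT + WHERE.
from tinygrad.ops import Ops, UOp, UPat, PatternMatcher
from tinygrad.dtype import dtypes

extra = PatternMatcher([
  (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
])

a, b = UOp.const(dtypes.float, 2.0), UOp.const(dtypes.float, 3.0)
rewritten = extra.rewrite(UOp(Ops.MAX, dtypes.float, (a, b)), ctx=None)
print(rewritten.op)  # expected: Ops.WHERE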