tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff compares the contents of publicly available package versions as published to their public registries. It is provided for informational purposes only and reflects the packages exactly as they appear in those registries.
Files changed (77)
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
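The listing above reflects a broad reorganization: jit.py and graph.py move under tinygrad/engine/, mlops.py is renamed function.py, dtypes split out of helpers.py into tinygrad/dtype.py, and new runtime backends (AMD, NV, HSA, a pure-Python emulator) arrive alongside generated runtime/autogen bindings. A hedged sketch of what the moves mean for imports against the 0.9.0 layout; only the module paths come from the listing above, treat the exact imported names as assumptions:

  from tinygrad.dtype import dtypes, DType       # 0.7.0: from tinygrad.helpers import dtypes, DType
  from tinygrad.engine.jit import TinyJit        # 0.7.0: from tinygrad.jit import TinyJit
  import tinygrad.function as function           # 0.7.0: import tinygrad.mlops as mlops

The single hunk reproduced below is tinygrad/renderer/cstyle.py (item 24 above).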
@@ -1,197 +1,384 @@
1
- from typing import Dict, List, Optional, NamedTuple, Tuple, Union
2
- import math
3
- from tinygrad.codegen.linearizer import UOps, UOp, MemOp, ConstOp
1
+ from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
2
+ import os, math
3
+ from collections import defaultdict, Counter
4
+ from tinygrad.codegen.linearizer import UOps, UOp
4
5
  from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
5
- from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType
6
- from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, sym_render
6
+ from tinygrad.helpers import strip_parens, getenv, prod
7
+ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
8
+ from tinygrad.codegen.uops import UOpGraph
9
+ from tinygrad.renderer import Renderer, TensorCore
7
10
 
8
- # div is different in cl than python
9
- render_cl = render_python.copy()
10
- render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})"
11
- render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
12
-
13
- class CStyleLanguage(NamedTuple):
14
- size_prefix: str = "int"
15
- generic_var_prefix: str = ""
11
+ class CStyleLanguage(Renderer):
16
12
  kernel_prefix: str = ""
17
13
  buffer_prefix: str = ""
18
14
  buffer_suffix: str = ""
15
+ smem_align: str = ""
19
16
  smem_prefix: str = ""
20
- arg_int_prefix: str = ""
17
+ smem_prefix_for_cast: bool = True
18
+ arg_int_prefix: str = "const int"
21
19
  barrier: str = ""
22
- gid: List[str] = []
23
- lid: List[str] = []
24
- global_max: List[int] = []
25
- local_max: List[int] = []
20
+ code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
26
21
  extra_args: List[str] = []
27
22
  float4: Optional[str] = None
28
- half_prekernel: Optional[str] = None
29
23
  uses_vload: bool = False
30
- external_local_bufs: bool = False
31
24
  uses_ptr_arithmetic: bool = False
32
- launch_bounds: bool = False
25
+ type_map: Dict[DType, str] = {}
33
26
  code_for_op: Dict = {
34
- UnaryOps.EXP2: lambda x: f"exp2({x})",
35
- UnaryOps.LOG2: lambda x: f"log2({x})",
36
- UnaryOps.SIN: lambda x: f"sin({x})",
37
- UnaryOps.SQRT: lambda x: f"sqrt({x})",
38
- BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
39
- BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
40
- BinaryOps.MAX: lambda a,b: f"max({a},{b})", BinaryOps.MOD: lambda a,b: f"({a}%{b})",
41
- BinaryOps.CMPLT: lambda a,b: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
42
- TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})"
43
- }
27
+ UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype == dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
28
+ UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
29
+ BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
30
+ BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
31
+ BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
32
+ TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}
44
33
 
45
34
  # returns a str expression of the casted xs with the given type
46
- def render_cast(self, x:List[str], var_dtype:DType) -> str:
47
- assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
48
- assert self.float4 is not None, "cast is not supported on this platform"
49
- if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
50
- if var_dtype == dtypes._float2: return f"{self.float4.replace('float4', 'float2')}({','.join(x)})"
51
- raise NotImplementedError(f"no cast for {var_dtype}")
35
+ def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
36
+ if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x[0]}))"
37
+ if len(x) == 1: return f"({self.render_dtype(var_dtype)})({x[0]})"
38
+ assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
39
+ assert self.float4 is not None, "vectorized cast is not supported on this platform"
40
+ return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
52
41
 
53
42
  # returns a str expression of the const with the given type
54
- def render_const(self, x:Union[float,int], var_dtype) -> str:
43
+ def render_const(self, x:ConstType, dtype:DType) -> str:
55
44
  if math.isnan(x): val = "NAN"
56
45
  elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
57
- else: val = f"{x}f" if dtypes.is_float(var_dtype) and isinstance(x, float) else f"{int(x)}"
58
- return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
46
+ elif dtype == dtypes.bool: val = "1" if x else "0"
47
+ elif dtype == dtypes.float: val = f"{x}f"
48
+ else: val = str(x)
49
+ return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
59
50
 
60
51
  # returns a str expression of the loaded value with the output type
61
52
  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
62
53
  if isinstance(buf_dtype, ImageDType):
63
- assert output_dtype == dtypes._float4, "images must be float4"
64
- return f"read_imagef({buf_name}, smp, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}))"
65
- if self.uses_vload and buf_dtype == dtypes.float16:
66
- return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx.render(render_cl, strip_parens=True)})"
67
- if output_dtype.sz > 1:
68
- return f"({output_dtype.name})(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})))"
69
- return f"*({buf_name}+{idx.render(render_cl, strip_parens=True)})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx.render(render_cl)}]"
70
-
71
- def render_local(self, name:str, size:int):
72
- return self.smem_prefix + f"float {name}[{size}];"
73
-
74
- def render_for(self, expr: str, _min:int, _max:Union[int,str]) -> str:
75
- return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"
76
-
77
- def render_conditional(self, cond: str, x:str, y:str) -> str:
78
- return f"({cond})?({x}):{y}"
79
-
80
- def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
81
- tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else ""
82
- buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
83
- self.arg_int_prefix if dtype == dtypes._arg_int32 else
84
- ("const " if i > 0 else "")+self.buffer_prefix+dtype.name+"*"+self.buffer_suffix) for i,(name,dtype) in enumerate(bufs)]
85
- prg = ''.join([f"{self.kernel_prefix} void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
54
+ assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
55
+ return f"read_imagef({buf_name}, smp, {idx})"
56
+ if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
57
+ return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
58
+ if output_dtype.count > 1:
59
+ out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(buf_dtype)}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501
60
+ else:
61
+ out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
62
+ return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
63
+
64
+ def get_kernel_modifier(self, uops:UOpGraph) -> str: return ""
65
+ def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str:
66
+ tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
67
+ buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
68
+ ("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
69
+ self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
70
+ prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
86
71
  [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
87
72
  [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
88
- if self.half_prekernel and any(dtype == dtypes.float16 for _,dtype in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
89
-
90
- return prg, global_size[::-1], local_size[::-1]
73
+ return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
91
74
 
92
75
  # returns a str statement that does the store
93
- def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
76
+ def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
94
77
  if isinstance(buf_dtype, ImageDType):
95
- assert var_dtype == dtypes._float4, "images must be float4"
96
- return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});"
97
- if self.uses_vload and buf_dtype == dtypes.float16:
98
- return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx.render(render_cl, strip_parens=True)});"
99
- if var_dtype.sz > 1:
100
- return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
101
- return f"*({buf_name}+{idx.render(render_cl, strip_parens=True)}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
102
-
103
- def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:List[str]):
104
- # for M1 tensor core stuff, support > 3 dims
105
- if i >= 2 and len(args[0]) > len(xid):
106
- # do this on the x dim for warps
107
- if len(local_size) == 2: local_size.append(1)
108
- local_size[-1] *= var.max+1
109
- lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1)
110
- lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
111
- assert lidx.max == var.max and lidx.min == var.min
112
- return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
113
- local_size.append(var.max+1)
114
- return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
115
-
116
- def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, List[int], List[int]]:
117
- global_size: List[int] = []
118
- local_size: List[int] = []
119
- kernel,prekernel = [],[]
120
- pend_close = None
121
- bufs = []
122
- depth = 0
123
- def kk(s): kernel.append(" "*depth+s)
124
-
125
- for uop,newvar,vin,args in uops:
126
- if uop == UOps.LOOP:
127
- for i,var in enumerate(args[0]):
128
- if args[1] == "global" and lang.gid:
129
- kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
130
- elif args[1] == "local" and lang.lid:
131
- kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
132
- else:
133
- if getenv("NOUNROLL") and not isinstance(var, NumNode): kk("#pragma unroll(1)") # prevent loop unrolling
134
- kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, sym_render(var.max)))
135
- depth += 1
136
- elif uop == UOps.BARRIER:
137
- kk(lang.barrier)
138
- elif uop == UOps.ENDLOOP:
139
- if args[1] == "local" and lang.lid:
140
- # TODO: this is a bit of a hack. the local loop isn't real on the GPU
141
- kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{")
142
- pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */"
143
- else:
144
- if args[1] == "global" and pend_close:
145
- depth -= 1
146
- kk(pend_close)
147
- pend_close = None
78
+ assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}"
79
+ return f"write_imagef({buf_name}, {idx}, {var_name});"
80
+ if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
81
+ return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
82
+ if var_dtype.count > 1:
83
+ prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
84
+ return f"*(({prefix}{self.render_dtype(buf_dtype)}{var_dtype.count}*)({buf_name}+{idx})) = {var_name};"
85
+ return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
86
+
87
+ def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"
88
+ def render_dtype(self, var_dtype:DType) -> str: return self.type_map.get(var_dtype, var_dtype.name)
89
+
90
+ def render(self, name:str, uops:UOpGraph) -> str:
91
+ kernel = []
92
+ bufs: List[Tuple[str, Tuple[DType, bool]]] = []
93
+ depth = 1
94
+ def kk(s): kernel.append(" "*depth+s)
95
+
96
+ c: DefaultDict[str, int] = defaultdict(int)
97
+ r: Dict[UOp, str] = {}
98
+
99
+ def ssa(prefix:str, u:Optional[UOp]=None):
100
+ nonlocal c, r
101
+ ret = f"{prefix}{c[prefix]}"
102
+ if u is not None: r[u] = ret
103
+ c[prefix] += 1
104
+ return ret
105
+
106
+ child_count = Counter(v for ru in uops for v in ru.vin)
107
+
108
+ for u in uops:
109
+ uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
110
+ # these four uops don't have output dtypes
111
+ if uop is UOps.IF:
112
+ kk(f"if ({r[vin[0]]}) {{")
113
+ depth += 1
114
+ elif uop is UOps.BARRIER: kk(self.barrier)
115
+ elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
148
116
  depth -= 1
149
- kk("}"*len(args[0]) + f" /* {args[1]} */")
150
- elif uop == UOps.WMMA:
151
- if args == "METAL":
152
- # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
153
- kk("{ simdgroup_float8x8 a,b,c;")
154
- kk(f"a.thread_elements()[0] = {vin[0].render()}; a.thread_elements()[1] = {vin[1].render()};")
155
- kk(f"b.thread_elements()[0] = {vin[2].render()}; b.thread_elements()[1] = {vin[3].render()};")
156
- kk(f"c.thread_elements()[0] = {vin[4].render()}; c.thread_elements()[1] = {vin[5].render()};")
157
- kk("simdgroup_multiply_accumulate(c, a, b, c);")
158
- kk(f"{vin[4].render()} = c.thread_elements()[0]; {vin[5].render()} = c.thread_elements()[1]; }}")
159
- elif args == "HIP":
160
- kk("{")
161
- kk(f"half16 a_frag = {{ {','.join(['(half)'+x.render() for x in vin[8:8+16]])} }};")
162
- kk(f"half16 b_frag = {{ {','.join(['(half)'+x.render() for x in vin[8+16:8+32]])} }};")
163
- kk(f"float8 c_frag = {{ {','.join([x.render() for x in vin[:8]])} }};")
164
- kk("c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);")
165
- for i in range(8): kk(f"{vin[i].render()} = c_frag[{i}];")
166
117
  kk("}")
118
+ elif uop is UOps.STORE:
119
+ assert vin[0].dtype is not None and vin[2].dtype is not None
120
+ rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
121
+ kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
167
122
  else:
168
- raise NotImplementedError(f"WMMA not implemented for {args}")
169
- elif uop == UOps.ALU:
170
- assert newvar is not None
171
- kk(f"{lang.generic_var_prefix if newvar not in vin else ''}{newvar.render(newvar not in vin and lang.generic_var_prefix == '')} = {lang.code_for_op[args](*[x.render() for x in vin])};")
172
- elif uop == UOps.LOAD:
173
- assert newvar is not None and isinstance(args, (MemOp, ConstOp))
174
- # valids are handled here
175
- if isinstance(args, ConstOp):
176
- val = lang.render_const(args.value, newvar.dtype)
177
- else:
178
- val = lang.render_load(newvar.dtype, args.name, args.memory_dtype, args.idx, args.local)
179
- if args.valid.min == 0 and args.valid.max == 1: val = lang.render_conditional(args.valid.render(render_cl), val, lang.render_const(args.invalid_value, newvar.dtype))
180
- kk(f"{lang.generic_var_prefix}{newvar.render(lang.generic_var_prefix == '')} = {val};")
181
- elif uop == UOps.STORE:
182
- assert args.valid.min == 1 and isinstance(args, MemOp), "store must be valid and to memory"
183
- # TODO: instead of dtypes.float, a base type
184
- kk(lang.render_store(args.name, args.memory_dtype, vin[0].render(), vin[0].dtype if vin[0].offset is None else dtypes.float, args.idx, args.local))
185
- elif uop == UOps.CAST and newvar is not None and newvar.dtype.sz > 1:
186
- kk(f"{newvar.render(True)} = {lang.render_cast([x.render() for x in vin], newvar.dtype)};")
187
- elif uop == UOps.DEFINE_LOCAL:
188
- if lang.external_local_bufs:
189
- prekernel.append(lang.render_local(args[0], args[1]))
190
- else:
191
- kk(lang.render_local(args[0], args[1]))
192
- elif uop == UOps.DEFINE_GLOBAL:
193
- bufs.append(args)
194
- else:
195
- raise RuntimeError(f"failed to render {uop}")
123
+ assert dtype is not None, f"None dtype for uop {uop}"
124
+ if uop is UOps.RANGE:
125
+ kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
126
+ depth += 1
127
+ elif uop is UOps.ALU:
128
+ # remove parens if ALU types are the same. TODO: can do more here
129
+ if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in vin]
130
+ else: operands = [r[v] for v in vin]
131
+ val = self.code_for_op[args](*operands, dtype)
132
+ assert child_count[u] != 0, f"childless ALU op found {u}"
133
+ # TODO: fix index rendering issue. fix clang nested max macro issue
134
+ if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
135
+ else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
136
+ elif uop is UOps.SPECIAL:
137
+ kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
138
+ r[u] = args[1]
139
+ elif uop is UOps.LOAD:
140
+ val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
141
+ # NOTE: this relies on the load not happening if it's in the unselected branch
142
+ if len(vin) > 3: val = self.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
143
+ kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
144
+ elif uop is UOps.PHI:
145
+ kk(f"{r[vin[0]]} = {r[vin[1]]};")
146
+ r[u] = r[vin[0]]
147
+ elif uop in {UOps.CAST, UOps.BITCAST}:
148
+ if uop is UOps.BITCAST:
149
+ assert len(vin) == 1
150
+ precast = ssa('precast')
151
+ kk(f"{self.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
152
+ val = self.render_cast([precast], dtype, bitcast=True)
153
+ else:
154
+ val = self.render_cast([r[x] for x in vin], dtype, bitcast=False)
155
+ if child_count[u] <= 1: r[u] = val
156
+ else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
157
+ elif uop is UOps.DEFINE_LOCAL:
158
+ kk(self.render_local(args[0], dtype, args[1]))
159
+ r[u] = args[0]
160
+ elif uop is UOps.DEFINE_VAR:
161
+ bufs.append((args.expr, (dtype,False)))
162
+ r[u] = args.expr
163
+ elif uop is UOps.DEFINE_GLOBAL:
164
+ bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
165
+ r[u] = nm
166
+ elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
167
+ elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(args[0], dtype)};")
168
+ elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
169
+ elif uop is UOps.GEP:
170
+ assert vin[0].dtype is not None
171
+ from_ssa = vin[0].uop in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
172
+ r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
173
+ else: raise RuntimeError(f"failed to render {uop}")
174
+
175
+ return self.render_kernel(name, kernel, bufs, uops)
176
+
177
+ class ClangRenderer(CStyleLanguage):
178
+ device = "CLANG"
179
+ supports_float4 = False
180
+ has_local = False
181
+
182
+ # language options
183
+ buffer_suffix = " restrict"
184
+ type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
185
+ code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
186
+
187
+ class OpenCLRenderer(CStyleLanguage):
188
+ device = "GPU"
189
+
190
+ # language options
191
+ kernel_prefix = "__kernel "
192
+ buffer_prefix = "__global "
193
+ smem_align = "__attribute__ ((aligned (16))) "
194
+ smem_prefix = "__local "
195
+ barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
196
+ float4 = "(float4)"
197
+ code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
198
+ uses_vload = True
199
+ type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
200
+ def render_cast(self, x, var_dtype, bitcast=False) -> str:
201
+ return f"as_{self.render_dtype(var_dtype)}({x[0]})" if bitcast else super().render_cast(x, var_dtype)
202
+
203
+ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
204
+ if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
205
+ return super().render_kernel(function_name, kernel, bufs, uops, prefix)
206
+
207
+ class MetalRenderer(CStyleLanguage):
208
+ device = "METAL"
209
+ shared_max = 32768
210
+ tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[0],[2],[0],[4],[-1, 1, 3],[0]], [[1],[0],[3],[0],[2, 4],[-1]], [[1],[2],[3],[4],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
211
+ def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if os.uname().machine == "arm64" else []
212
+
213
+ # language options
214
+ kernel_prefix = "kernel "
215
+ buffer_prefix = "device "
216
+ smem_prefix = "threadgroup "
217
+ arg_int_prefix = "constant int&"
218
+ barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
219
+ float4 = "float4"
220
+ uses_ptr_arithmetic = True
221
+ code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
222
+ extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
223
+ type_map = {dtypes.bfloat16: "bfloat"}
224
+ code_for_op = {**CStyleLanguage().code_for_op,
225
+ BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
226
+ UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
227
+ UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
228
+ UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
229
+ UnaryOps.SIN: lambda x,dtype: f"(bfloat)sin({x})" if dtype == dtypes.bfloat16 else f"sin({x})",}
230
+
231
+ def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
232
+ return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
233
+
234
+ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
235
+ prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop is UOps.WMMA])
236
+ for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
237
+ simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
238
+ b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
239
+ return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
240
+ return super().render_kernel(function_name, kernel, bufs, uops, prefix)
241
+
242
+ code_for_op_half = {BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
243
+ UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
244
+ UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
245
+ UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
246
+ UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
247
+
248
+ _nms = "xyzwabcdefghijkl"
249
+ def _make_cuda_dtype(base_type, name, cnt):
250
+ vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
251
+ return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
252
+
253
+ class CUDARenderer(CStyleLanguage):
254
+ device = "CUDA"
255
+ global_max = [65535, 65535, 2147483647]
256
+ local_max = [64, 1024, 1024]
257
+ shared_max = 49152
258
+ tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
259
+ def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []
260
+
261
+ # language options
262
+ kernel_prefix = "extern \"C\" __global__ "
263
+ smem_prefix = "__shared__ "
264
+ smem_prefix_for_cast = False
265
+ barrier = "__syncthreads();"
266
+ float4 = "make_float4"
267
+ code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}",
268
+ "i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"}
269
+ code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
270
+ type_map = {dtypes.bfloat16: "nv_bfloat16"}
271
+
272
+ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
273
+ # TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
274
+ dt_map = { dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16"), }
275
+
276
+ prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
277
+ if any(uop.dtype == dtypes.half for uop in uops):
278
+ prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]
279
+
280
+ if any(uop.dtype == dtypes.bfloat16 for uop in uops):
281
+ prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
282
+
283
+ # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
284
+ for arg in set([uop.arg for uop in uops if uop.uop is UOps.WMMA]):
285
+ fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
286
+ prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
287
+ asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
288
+ : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
289
+ return c;}}""")
290
+
291
+ return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
292
+
293
+ code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
294
+ UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
295
+ UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
296
+ UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
297
+ # TODO: MAX with int uses fmax_f32?
298
+ BinaryOps.MAX: lambda a,b,dtype: f"__ocml_fmax_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32) }({a},{b})",}
299
+
300
+ def _make_hip_code_for_op():
301
+ def wrapper(key, func):
302
+ def cast_bf16(*args):
303
+ if args[-1] == dtypes.bfloat16:
304
+ operands = tuple(f"(float)({arg})" for arg in (args[1:-1] if key is TernaryOps.WHERE else args[:-1]))
305
+ return f"(hip_bfloat16)({func(*(((args[0],) if key is TernaryOps.WHERE else ()) + operands), dtypes.float)})"
306
+ return func(*args)
307
+ return cast_bf16
308
+ return { k:wrapper(k,v) for k,v in {**CStyleLanguage().code_for_op, **code_for_op_hip}.items() }
309
+
310
+ def _make_hip_dtype(base_type, name, cnt):
311
+ elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
312
+ return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
313
+ f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
314
+
315
+ class HIPRenderer(CStyleLanguage):
316
+ device = "HSA"
317
+ shared_max = 65536
318
+ tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
319
+
320
+ # language options
321
+ kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
322
+ extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
323
+ extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
324
+ extern "C" {\n""" + "".join([
325
+ f""" __attribute__((device)) __attribute__((const)) {dt} __ocml_fmax_f{n}({dt}, {dt});
326
+ __attribute__((device)) __attribute__((pure)) {dt} __ocml_exp2_f{n}({dt});
327
+ __attribute__((device)) __attribute__((pure)) {dt} __ocml_log2_f{n}({dt});
328
+ __attribute__((device)) __attribute__((const)) {dt} __ocml_sqrt_f{n}({dt});
329
+ __attribute__((device)) {dt} __ocml_sin_f{n}({dt});\n""" for dt,n in [("float",32), ("double",64), ("_Float16",16)]]) +\
330
+ '}\nextern "C" __attribute__((global))'
331
+ code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
332
+ "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
333
+ code_for_op = _make_hip_code_for_op()
334
+ smem_prefix = "__attribute__((shared))"
335
+ barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
336
+ '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
337
+ float4 = "make_float4"
338
+ uses_ptr_arithmetic = False # NOTE: this fixes TestLinearizerOverflowAlt
339
+ type_map = {dtypes.bfloat16: "hip_bfloat16"}
340
+
341
+ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
342
+ prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;"]
343
+ vec_dts = [("float", "float", 2), ("float", "float", 4), ("float", "float", 8), ("signed int", "int", 4), ("signed int", "int", 2)]
344
+
345
+ # TODO: add BF16 vec dts
346
+ if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
347
+ struct hip_bfloat16 {
348
+ unsigned short data;
349
+ __attribute__((device)) hip_bfloat16(float val) {
350
+ union { float fp32; unsigned int u32; } u = {val};
351
+ if (~u.u32 & 0x7f800000) { u.u32 += 0x7fff + ((u.u32 >> 16) & 1); } else if (u.u32 & 0xffff) { u.u32 |= 0x10000; }
352
+ data = (u.u32 >> 16);
353
+ }
354
+ __attribute__((device)) operator float() const {
355
+ unsigned int uval = data << 16;
356
+ return *reinterpret_cast<float*>(&uval);
357
+ }
358
+ };
359
+ static __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
360
+ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
361
+ """)
362
+
363
+ if any(uop.dtype == dtypes.half for uop in uops):
364
+ prefix.append("#define half _Float16")
365
+ vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
366
+
367
+ prefix += [_make_hip_dtype(*x) for x in vec_dts]
368
+
369
+ for arg in set([uop.arg for uop in uops if uop.uop is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
370
+ if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
371
+ else: prefix.append(f"static __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
372
+ half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
373
+ c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
374
+ for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
375
+ return super().render_kernel(function_name, kernel, bufs, uops, prefix)
376
+
377
+ def get_kernel_modifier(self, uops:UOpGraph) -> str:
378
+ requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.uop is UOps.SPECIAL and u.arg[1][0] == "l")
379
+ # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
380
+ # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
381
+ return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
196
382
 
197
- return lang.render_kernel(function_name, kernel, bufs, global_size, local_size, prekernel)
383
+ class NVRenderer(CUDARenderer): device = "NV"
384
+ class AMDRenderer(HIPRenderer): device = "AMD"
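The hunk above captures the renderer rework in cstyle.py: the 0.7.0 CStyleLanguage NamedTuple plus the free function uops_to_cstyle(lang, function_name, uops) -> (src, global_size, local_size) become a Renderer subclass whose per-backend subclasses (Clang, OpenCL, Metal, CUDA, HIP, and the NV/AMD aliases) carry device, optional tensor_cores, and the language options as class attributes, and emit code through the render(name, uops) method, which returns only the source string. A minimal sketch of a custom backend in the new style, modeled on the ClangRenderer shown above; the class name and attribute values are hypothetical:

  from tinygrad.renderer.cstyle import CStyleLanguage
  from tinygrad.ops import BinaryOps

  class MyCRenderer(CStyleLanguage):     # hypothetical backend, not part of this diff
    device = "MYDEV"                     # Renderer.device, as ClangRenderer sets "CLANG"
    has_local = False                    # no shared/local memory, like ClangRenderer
    supports_float4 = False
    buffer_suffix = " restrict"
    # plain C has no max(); mirror the ClangRenderer override with a ternary
    code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}

  # src = MyCRenderer().render("my_kernel", uop_graph)   # uop_graph: a UOpGraph from the linearizer
  # 0.7.0 equivalent: uops_to_cstyle(lang, "my_kernel", uops) returned (src, global_size, local_size)

Note that launch dimensions no longer come back from the renderer: the 0.9.0 render_kernel returns just the program text, where the 0.7.0 version also returned the reversed global and local sizes.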