tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
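The source diff reproduced below covers tinygrad/renderer/cstyle.py (+308 -210). Its hunks show the 0.8.0 CStyleLanguage NamedTuple and the module-level uops_to_cstyle() function being folded into Renderer subclasses: each backend (ClangRenderer, OpenCLRenderer, MetalRenderer, CUDARenderer, AMDRenderer, NVRenderer) is now a class carrying device metadata, tensor core descriptions, and a render() method, rather than a functools.partial binding of uops_to_cstyle. A rough sketch of the calling-convention change, using a hypothetical kernel name and assuming an already-built uop list/graph:

# 0.8.0: a language description is passed to a module-level function; the exported
# renderers are functools.partial bindings (signatures taken from the diff below).
#   from tinygrad.renderer.cstyle import uops_to_cstyle, OpenCLLanguage
#   src = uops_to_cstyle(OpenCLLanguage(), "E_4", uops)   # uops: List[UOp]; "E_4" is a made-up name
#
# 0.9.1: CStyleLanguage subclasses Renderer and each backend is a class with a render() method.
#   from tinygrad.renderer.cstyle import OpenCLRenderer
#   src = OpenCLRenderer().render("E_4", uops)            # uops: UOpGraph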
@@ -1,14 +1,13 @@
- from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast, Literal, Callable
- import math, functools
+ from typing import Dict, List, Optional, Tuple, Union, DefaultDict, cast, Literal, Callable
+ import os, math
  from collections import defaultdict, Counter
- from tinygrad.codegen.linearizer import UOps, UOp
  from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
- from tinygrad.helpers import prod, strip_parens, getenv
- from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
+ from tinygrad.helpers import strip_parens, getenv, prod, dedup
+ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
+ from tinygrad.codegen.uops import UOps, UOp, UOpGraph
+ from tinygrad.renderer import Renderer, TensorCore

- class CStyleLanguage(NamedTuple):
- size_prefix: str = "int"
- generic_var_prefix: str = ""
+ class CStyleLanguage(Renderer):
  kernel_prefix: str = ""
  buffer_prefix: str = ""
  buffer_suffix: str = ""
@@ -18,39 +17,36 @@ class CStyleLanguage(NamedTuple):
  arg_int_prefix: str = "const int"
  barrier: str = ""
  code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
- global_max: List[int] = []
- local_max: List[int] = []
  extra_args: List[str] = []
  float4: Optional[str] = None
- half_prekernel: Optional[str] = None
  uses_vload: bool = False
- external_local_bufs: bool = False
  uses_ptr_arithmetic: bool = False
- launch_bounds: bool = False
  type_map: Dict[DType, str] = {}
  code_for_op: Dict = {
- UnaryOps.NEG: lambda x,dtype: f"(-{x})" if dtype != dtypes.bool else f"(!{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
+ UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype == dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
+ UnaryOps.RECIP: lambda x,dtype: f"(1/{x})",
  UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
- BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
- BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
- BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPEQ: lambda a,b,dtype: f"({a}=={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
- TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})", TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"
- }
+ BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})",
+ BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
+ BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})",
+ TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})"}

  # returns a str expression of the casted xs with the given type
  def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
  if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x[0]}))"
  if len(x) == 1: return f"({self.render_dtype(var_dtype)})({x[0]})"
- assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
+ assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
  assert self.float4 is not None, "vectorized cast is not supported on this platform"
  return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"

  # returns a str expression of the const with the given type
- def render_const(self, x:Union[float,int,bool], var_dtype) -> str:
+ def render_const(self, x:ConstType, dtype:DType) -> str:
  if math.isnan(x): val = "NAN"
  elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
- else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}" if dtypes.is_int(var_dtype) else f"{bool(x)}".lower()
- return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val
+ elif dtype == dtypes.bool: val = "1" if x else "0"
+ elif dtype == dtypes.float: val = f"{x}f"
+ else: val = str(x)
+ return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)

  # returns a str expression of the loaded value with the output type
  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
@@ -58,32 +54,23 @@ class CStyleLanguage(NamedTuple):
  assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
  return f"read_imagef({buf_name}, smp, {idx})"
  if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
- return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
- if output_dtype.sz > 1:
- out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" # noqa: E501
+ return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
+ if output_dtype.count > 1:
+ out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(buf_dtype)}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501
  else:
  out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
-
  return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val

- def render_local(self, name:str, dtype:DType, size:int):
- return self.smem_align + self.smem_prefix + f"{dtype.name} {name}[{size}];"
-
- def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
- return f"for ({self.generic_var_prefix if self.generic_var_prefix else 'int'} {expr} = {_min}; {expr} < {_max}; {expr}++) {{"
-
- def render_if(self, cond: str): return f"if ({cond}) {{"
-
- def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
- tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,dtype in bufs) else "" # noqa: E501
- buftypes = [(name,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if dtype.name.startswith('image') else
- ("const " if i > 0 else "")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
- self.arg_int_prefix if dtype == dtypes.int else None) for i,(name,dtype) in enumerate(bufs)]
- prg = ''.join([f"{self.kernel_prefix}void {f'__launch_bounds__ ({prod(local_size)}, 1) ' if self.launch_bounds else ''}{function_name}(",] +
+ def get_kernel_modifier(self, uops:UOpGraph) -> str: return ""
+ def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str:
+ tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
+ buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
+ ("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
+ self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
+ prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
  [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
  [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
- if self.half_prekernel and any(dtype in [dtypes.float16, dtypes.bfloat16] for _,dtype in bufs): prg = ''.join((self.half_prekernel, "\n", prg))
- return prg
+ return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"

  # returns a str statement that does the store
  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
@@ -91,201 +78,312 @@ class CStyleLanguage(NamedTuple):
  assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}"
  return f"write_imagef({buf_name}, {idx}, {var_name});"
  if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
- return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
- if var_dtype.sz > 1:
- return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" # noqa: E501
+ return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
+ if var_dtype.count > 1:
+ prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
+ return f"*(({prefix}{self.render_dtype(buf_dtype)}{var_dtype.count}*)({buf_name}+{idx})) = {var_name};"
  return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"

- def render_dtype(self, var_dtype:DType) -> str: return self.type_map[var_dtype] if var_dtype in self.type_map else var_dtype.name
-
- def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str:
- local_size: List[int] = []
- kernel,prekernel,bufs = [],[],[]
- #pend_close = None
- depth = 1
- def kk(s): kernel.append(" "*depth+s)
-
- c: DefaultDict[str, int] = defaultdict(int)
- r: Dict[UOp, str] = {}
- def ssa(u, prefix="t"):
- nonlocal c, r
- ret = f"{prefix}{c[prefix]}"
- if u is not None: r[u] = ret
- c[prefix] += 1
- return ret
-
- child_count = Counter(v for ru in uops for v in ru.vin)
-
- for u in uops:
- uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
- # these four uops don't have output dtypes
- if uop == UOps.IF:
- kk(lang.render_if(r[vin[0]]))
- depth += 1
- elif uop == UOps.BARRIER:
- kk(lang.barrier)
- elif uop == UOps.END:
- depth -= 1
- kk("}")
- elif uop == UOps.STORE:
- assert vin[0].dtype is not None and vin[2].dtype is not None
- if len(vin) > 3: kk(lang.render_if(r[vin[3]]))
- kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
- if len(vin) > 3: kk("}")
- else:
- assert dtype is not None, f"None dtype for uop {uop}"
- if uop == UOps.LOOP:
- kk(lang.render_for(ssa(u,'ridx'), r[vin[0]], r[vin[1]]))
+ def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"
+ def render_dtype(self, var_dtype:DType) -> str: return self.type_map.get(var_dtype, var_dtype.name)
+
+ def render(self, name:str, uops:UOpGraph) -> str:
+ kernel = []
+ bufs: List[Tuple[str, Tuple[DType, bool]]] = []
+ depth = 1
+ def kk(s): kernel.append(" "*depth+s)
+
+ c: DefaultDict[str, int] = defaultdict(int)
+ r: Dict[UOp, str] = {}
+
+ def ssa(prefix:str, u:Optional[UOp]=None):
+ nonlocal c, r
+ ret = f"{prefix}{c[prefix]}"
+ if u is not None: r[u] = ret
+ c[prefix] += 1
+ return ret
+
+ child_count = Counter(v for ru in uops for v in ru.src)
+
+ seen_vars = set()
+ for u in uops:
+ uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
+ # these four uops don't have output dtypes
+ if uop is UOps.IF:
+ kk(f"if ({r[src[0]]}) {{")
  depth += 1
- elif uop == UOps.WMMA:
- if args[0] == "METAL":
- assert dtype == dtypes.float.vec(2), "output dtype of METAL TC is _float2"
- # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
- output = ssa(u, 'wmma')
- kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {output};")
- kk("{ simdgroup_float8x8 a,b,c;")
- kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};")
- kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};")
- kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
- kk("simdgroup_multiply_accumulate(c, a, b, c);")
- kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}")
- elif args[0] == "HIP":
- assert dtype == dtypes.float.vec(8), "output dtype of HIP TC is _float8"
- kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") # noqa: E501
- else:
- raise NotImplementedError(f"WMMA not implemented for {args}")
- elif uop == UOps.ALU:
- # remove parens if ALU types are the same. TODO: can do more here
- if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.XOR}:
- val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]], dtype)
- else:
- val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype])
- assert child_count[u] != 0, f"childless ALU op found {u}"
- # TODO: fix index rendering issue. fix clang nested max macro issue
- if child_count[u] <= 1 and args != BinaryOps.MAX and not getenv("EXPAND_SSA"):
- r[u] = val
- else:
- kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};")
- elif uop == UOps.DEFINE_ACC:
- kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
- elif uop == UOps.SPECIAL:
- kk(f"{lang.size_prefix} {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
- if args[1].startswith("l"): local_size.append(args[2])
- r[u] = args[1]
- elif uop == UOps.CONST:
- r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
- elif uop == UOps.LOAD:
- val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
- # NOTE: this relies on the load not happening if it's in the unselected branch
- if len(vin) > 3: val = lang.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
- kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else lang.render_dtype(dtype)} {ssa(u,'val')} = {val};")
- elif uop == UOps.PHI:
- kk(f"{r[vin[0]]} = {r[vin[1]]};")
- r[u] = r[vin[0]]
- elif uop == UOps.CAST:
- if isinstance(args, tuple) and args[1]: # bitcast
- assert len(vin) == 1
- precast = ssa(None,'precast')
- kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else lang.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
- val = lang.render_cast([precast], dtype, bitcast=True)
- else:
- val = lang.render_cast([r[x] for x in vin], dtype, bitcast=False)
- if child_count[u] <= 1: r[u] = val
- else: kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'cast')} = {val};")
- elif uop == UOps.DEFINE_LOCAL:
- if lang.external_local_bufs:
- prekernel.append(lang.render_local(args[0], dtype, args[1]))
- else:
- kk(lang.render_local(args[0], dtype, args[1]))
- r[u] = args[0]
- elif uop == UOps.DEFINE_GLOBAL:
- bufs.append((args, dtype))
- r[u] = args
- elif uop == UOps.GEP:
- if cast(DType, vin[0].dtype).sz > 4:
- r[u] = f"({r[vin[0]]})[{args}]" # this is correct for HIP
- else:
- r[u] = f"({r[vin[0]]}).{'xyzw'[args]}"
+ elif uop is UOps.BARRIER: kk(self.barrier)
+ elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
+ depth -= 1
+ kk("}")
+ elif uop is UOps.STORE:
+ assert src[0].dtype is not None and src[2].dtype is not None
+ rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
+ kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
  else:
- raise RuntimeError(f"failed to render {uop}")
+ assert dtype is not None, f"None dtype for uop {uop}"
+ if uop is UOps.RANGE:
+ kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{")
+ depth += 1
+ elif uop is UOps.ALU:
+ # remove parens if ALU types are the same. TODO: can do more here
+ if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src]
+ else: operands = [r[v] for v in src]
+ val = self.code_for_op[args](*operands, dtype)
+ assert child_count[u] != 0, f"childless ALU op found {u}"
+ # TODO: fix index rendering issue. fix clang nested max macro issue
+ if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
+ else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
+ elif uop is UOps.SPECIAL:
+ kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
+ r[u] = args[1]
+ elif uop is UOps.LOAD:
+ val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
+ # NOTE: this relies on the load not happening if it's in the unselected branch
+ if len(src) > 3: val = self.code_for_op[TernaryOps.WHERE](r[src[2]], val, r[src[3]], dtype)
+ kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
+ elif uop is UOps.PHI:
+ kk(f"{r[src[0]]} = {r[src[1]]};")
+ r[u] = r[src[0]]
+ elif uop in {UOps.CAST, UOps.BITCAST}:
+ if uop is UOps.BITCAST:
+ assert len(src) == 1
+ precast = ssa('precast')
+ kk(f"{self.render_dtype(cast(DType, src[0].dtype))} {precast} = {r[src[0]]};")
+ val = self.render_cast([precast], dtype, bitcast=True)
+ else:
+ val = self.render_cast([r[x] for x in src], dtype, bitcast=False)
+ if child_count[u] <= 1: r[u] = val
+ else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
+ elif uop is UOps.DEFINE_LOCAL:
+ kk(self.render_local(args[0], dtype, args[1]))
+ r[u] = args[0]
+ elif uop is UOps.DEFINE_VAR:
+ assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
+ seen_vars.add(args.expr)
+ bufs.append((args.expr, (dtype,False)))
+ r[u] = args.expr
+ elif uop is UOps.DEFINE_GLOBAL:
+ bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
+ r[u] = nm
+ elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
+ elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
+ elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
+ elif uop is UOps.GEP:
+ assert src[0].dtype is not None
+ from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
+ r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
+ else: raise RuntimeError(f"failed to render {uop}")
+
+ return self.render_kernel(name, kernel, bufs, uops)

- return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel)
+ class ClangRenderer(CStyleLanguage):
+ device = "CLANG"
+ supports_float4 = False
+ has_local = False
+ global_max = None

- class OpenCLLanguage(CStyleLanguage):
+ # language options
+ buffer_suffix = " restrict"
+ type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
+ code_for_op = {**CStyleLanguage().code_for_op, BinaryOps.MAX: lambda a,b,dtype: f"(({a}>{b})?{a}:{b})"}
+
+ class OpenCLRenderer(CStyleLanguage):
+ device = "GPU"
+
+ # language options
  kernel_prefix = "__kernel "
  buffer_prefix = "__global "
  smem_align = "__attribute__ ((aligned (16))) "
  smem_prefix = "__local "
- half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable"
  barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
  float4 = "(float4)"
- code_for_workitem ={ "g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})" }
+ code_for_workitem = {"g": lambda x: f"get_group_id({x})", "l": lambda x: f"get_local_id({x})", "i": lambda x: f"get_global_id({x})"}
  uses_vload = True
- # NOTE: mad is used so the loads aren't reordered into the math on 845
- code_for_op = {**CStyleLanguage().code_for_op,
- TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})" if dtypes.is_float(dtype) else f"(({a}*{b})+{c})"}
  type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong" }
  def render_cast(self, x, var_dtype, bitcast=False) -> str:
- return f"as_{self.type_map.get(var_dtype) or var_dtype.name}({x[0]})" if bitcast else super().render_cast(x, var_dtype)
- OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())
+ return f"as_{self.render_dtype(var_dtype)}({x[0]})" if bitcast else super().render_cast(x, var_dtype)

- class MetalLanguage(CStyleLanguage):
- kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel "
+ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+ if any(uop.dtype == dtypes.half for uop in uops): prefix = ["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"]
+ return super().render_kernel(function_name, kernel, bufs, uops, prefix)
+
+ class MetalRenderer(CStyleLanguage):
+ device = "METAL"
+ shared_max = 32768
+ tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[0],[2],[0],[4],[-1, 1, 3],[0]], [[1],[0],[3],[0],[2, 4],[-1]], [[1],[2],[3],[4],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
+ def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if os.uname().machine == "arm64" else []
+
+ # language options
+ kernel_prefix = "kernel "
  buffer_prefix = "device "
  smem_prefix = "threadgroup "
  arg_int_prefix = "constant int&"
  barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
  float4 = "float4"
- uses_ptr_arithmetic=True
+ uses_ptr_arithmetic = True
  code_for_workitem = {"g": lambda x: f"gid.{chr(120+x)}", "l": lambda x: f"lid.{chr(120+x)}"}
+ # uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
  extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
+ type_map = {dtypes.bfloat16: "bfloat"}
+ code_for_op = {**CStyleLanguage().code_for_op,
+ BinaryOps.MAX: lambda a,b,dtype: f"(bfloat)max((float){a},(float){b})" if dtype == dtypes.bfloat16 else f"max({a},{b})",
+ UnaryOps.SQRT: lambda x,dtype: f"(bfloat)sqrt({x})" if dtype == dtypes.bfloat16 else f"sqrt({x})",
+ UnaryOps.EXP2: lambda x,dtype: f"(bfloat)exp2({x})" if dtype == dtypes.bfloat16 else f"exp2({x})",
+ UnaryOps.LOG2: lambda x,dtype: f"(bfloat)log2({x})" if dtype == dtypes.bfloat16 else f"log2({x})",
+ UnaryOps.SIN: lambda x,dtype: f"(bfloat)sin({x})" if dtype == dtypes.bfloat16 else f"sin({x})",}
+
  def render_cast(self, x: List[str], var_dtype: DType, bitcast=False) -> str:
- return f"as_type<{var_dtype.name}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
- MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())
-
- code_for_op_half = {
- BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})" if dtype != dtypes.half else f"__hmax({a},{b})",
- UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})" if dtype != dtypes.half else f"hsqrt({x})",
- UnaryOps.SIN: lambda x,dtype: f"sin({x})" if dtype != dtypes.half else f"hsin({x})",
- UnaryOps.LOG2: lambda x,dtype: f"log2({x})" if dtype != dtypes.half else f"hlog2({x})",
- UnaryOps.EXP2: lambda x,dtype: f"exp2({x})" if dtype != dtypes.half else f"hexp2({x})",
- }
-
- class CUDALanguage(CStyleLanguage):
- kernel_prefix = "#define INFINITY (__int_as_float(0x7f800000))\n#define NAN (__int_as_float(0x7fffffff))\nextern \"C\" __global__ "
+ return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
+
+ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
+ prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
+ for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
+ simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
+ b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
+ return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
+ return super().render_kernel(function_name, kernel, bufs, uops, prefix)
+
+ code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"1/{x}",
+ BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
+ UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
+ UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",
+ UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
+ UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
+
+ _nms = "xyzwabcdefghijkl"
+ def _make_cuda_dtype(base_type, name, cnt):
+ vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
+ return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
+
+ class CUDARenderer(CStyleLanguage):
+ device = "CUDA"
+ global_max = (2147483647, 65535, 65535)
+ local_max = (1024, 1024, 64)
+ shared_max = 49152
+ tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
+ def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []
+
+ # language options
+ kernel_prefix = "extern \"C\" __global__ "
  smem_prefix = "__shared__ "
  smem_prefix_for_cast = False
  barrier = "__syncthreads();"
  float4 = "make_float4"
- code_for_workitem = {
- "g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}",
- "i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"
- }
+ code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}",
+ "i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"}
  code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
- half_prekernel ="#include <cuda_fp16.h>\n"+"#include <cuda_bf16.h>\n"+"""
- struct half4 { half x, y, z, w; };
- __device__ half4 make_half4(half x, half y, half z, half w) { half4 ret; ret.x = x; ret.y = y; ret.z = z; ret.w = w; return ret; }
- """
  type_map = {dtypes.bfloat16: "nv_bfloat16"}
- CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
-
- class HIPLanguage(CUDALanguage):
- kernel_prefix = "#include <hip/hip_common.h>\n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))" + """
- typedef float float8 __attribute__((ext_vector_type(8)));
- __device__ float8 make_float8(float x, float y, float z, float w, float a, float b, float c, float d) { return {x, y, z, w, a, b, c, d}; }
- extern "C" __global__
- """
- launch_bounds = True
- uses_ptr_arithmetic=True
- half_prekernel = "#include <hip/hip_fp16.h>\n" + """
- typedef union { struct { half x, y, z, w; } __attribute__((aligned(8))); half data[4]; } half4;
- __device__ half4 make_half4(half x, half y, half z, half w) { return {x, y, z, w}; }
- typedef union { struct { half x, y, z, w, a, b, c, d; } __attribute__((aligned(16))); half data[8]; } half8;
- __device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { return {x, y, z, w, a, b, c, d}; }
- typedef _Float16 half16 __attribute__((ext_vector_type(16)));
- __device__ half16 make_half16(half x, half y, half z, half w, half a, half b, half c, half d,
- half e, half f, half g, half h, half i, half j, half k, half l) {
- return {x, y, z, w, a, b, c, d, e, f, g, h, i, j, k, l}; }
- """
+
+ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
+ # TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
+ dt_map = { dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16"), }
+
+ prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
+ if any(uop.dtype == dtypes.half for uop in uops):
+ prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]
+
+ if any(uop.dtype == dtypes.bfloat16 for uop in uops):
+ prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
+
+ # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
+ for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
+ fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
+ prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
+ asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
+ : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
+ return c;}}""")
+
+ return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
+
+ code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+ UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+ UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+ UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
+ # TODO: MAX with int uses fmax_f32?
+ BinaryOps.MAX: lambda a,b,dtype: f"__ocml_fmax_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32) }({a},{b})",}
+
+ def _make_hip_code_for_op():
+ def wrapper(key, func):
+ def cast_bf16(*args):
+ if args[-1] == dtypes.bfloat16:
+ operands = tuple(f"(float)({arg})" for arg in (args[1:-1] if key is TernaryOps.WHERE else args[:-1]))
+ return f"(hip_bfloat16)({func(*(((args[0],) if key is TernaryOps.WHERE else ()) + operands), dtypes.float)})"
+ return func(*args)
+ return cast_bf16
+ return { k:wrapper(k,v) for k,v in {**CStyleLanguage().code_for_op, **code_for_op_hip}.items() }
+
+ def _make_hip_dtype(base_type, name, cnt):
+ elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
+ return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
+ f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
+
+ class AMDRenderer(CStyleLanguage):
+ device = "AMD"
+ shared_max = 65536
+ tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
+
+ # language options
+ kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
+ extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
+ extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
+ extern "C" {\n""" + "".join([
+ f""" __attribute__((device)) __attribute__((const)) {dt} __ocml_fmax_f{n}({dt}, {dt});
+ __attribute__((device)) __attribute__((pure)) {dt} __ocml_exp2_f{n}({dt});
+ __attribute__((device)) __attribute__((pure)) {dt} __ocml_log2_f{n}({dt});
+ __attribute__((device)) __attribute__((const)) {dt} __ocml_sqrt_f{n}({dt});
+ __attribute__((device)) {dt} __ocml_sin_f{n}({dt});\n""" for dt,n in [("float",32), ("double",64), ("_Float16",16)]]) +\
+ '}\nextern "C" __attribute__((global))'
+ code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
+ "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
+ code_for_op = _make_hip_code_for_op()
+ smem_prefix = "__attribute__((shared))"
+ barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
+ '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
+ float4 = "make_float4"
+ uses_ptr_arithmetic = False # NOTE: this fixes TestLinearizerOverflowAlt
  type_map = {dtypes.bfloat16: "hip_bfloat16"}
- HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
+
+ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+ prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;"]
+ vec_dts = [("float", "float", 2), ("float", "float", 4), ("float", "float", 8), ("signed int", "int", 4), ("signed int", "int", 2)]
+
+ # TODO: add BF16 vec dts
+ if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
+ struct hip_bfloat16 {
+ unsigned short data;
+ inline __attribute__((device)) hip_bfloat16(float val) {
+ union { float fp32; unsigned int u32; } u = {val};
+ if (~u.u32 & 0x7f800000) { u.u32 += 0x7fff + ((u.u32 >> 16) & 1); } else if (u.u32 & 0xffff) { u.u32 |= 0x10000; }
+ data = (u.u32 >> 16);
+ }
+ inline __attribute__((device)) operator float() const {
+ unsigned int uval = data << 16;
+ return *reinterpret_cast<float*>(&uval);
+ }
+ };
+ static inline __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
+ static inline __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
+ """)
+
+ if any(uop.dtype == dtypes.half for uop in uops):
+ prefix.append("#define half _Float16")
+ vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
+
+ prefix += [_make_hip_dtype(*x) for x in vec_dts]
+
+ for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+ if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
+ else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
+ half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
+ c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
+ for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
+ return super().render_kernel(function_name, kernel, bufs, uops, prefix)
+
+ def get_kernel_modifier(self, uops:UOpGraph) -> str:
+ requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.op is UOps.SPECIAL and u.arg[1][0] == "l")
+ # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
+ # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
+ return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
+
+ class NVRenderer(CUDARenderer): device = "NV"