tinygrad 0.10.2-py3-none-any.whl → 0.11.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
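Note on the restructure visible in the file list: the core UOp machinery moves from top-level modules into packages (ops.py → uop/ops.py, spec.py → uop/spec.py, codegen/symbolic.py → uop/symbolic.py), scheduling splits into tinygrad/schedule/, and kernel optimization moves under tinygrad/codegen/opt/ (including engine/search.py → codegen/opt/search.py). Below is a minimal sketch of what this means for code that imports tinygrad internals; the uop.ops path is confirmed by the cstyle.py diff that follows, the other paths come only from the file moves listed above, and the exact symbols each module exports are assumptions to verify against the 0.11.0 source.

# hypothetical migration sketch for downstream code importing tinygrad internals (Python)
# tinygrad 0.10.2:
#   from tinygrad.ops import UOp, Ops, PatternMatcher, UPat   # tinygrad/ops.py (deleted in 0.11.0)
#   from tinygrad.engine.search import ...                    # engine/search.py (moved)
# tinygrad 0.11.0:
from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat   # now lives in the uop package
# from tinygrad.codegen.opt.search import ...                 # search now lives under codegen/opt

tensor.py itself is only modified in place (entry 102), so the top-level `from tinygrad import Tensor` path appears unaffected by these moves.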
@@ -1,31 +1,31 @@
- from typing import Optional, Union, Literal, Callable, cast
+ from typing import Literal, Callable, cast
  import os, math, sys
  from collections import defaultdict, Counter
- from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
+ from tinygrad.codegen.opt import tc
+ from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
  from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
- from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
- from tinygrad.renderer import Renderer, TensorCore
+ from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate
+ from tinygrad.renderer import Renderer
  from tinygrad.codegen.devectorizer import no_vectorized_alu

  base_rewrite = PatternMatcher([
- (UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
- (UPat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
+ (UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
  (UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
  (UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
  (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
  # r method accesses
  (UPat(Ops.RANGE, name="x"),
- lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
+ lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = 0; {ctx[x]} < {ctx[x.src[0]]}; {ctx[x]}++) {{"),
  (UPat(Ops.VECTORIZE, name="x"),
  lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
- (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CPU', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")),
+ f"{ctx.float4_style[0]}{','.join([ctx[y] for y in x.src])}{ctx.float4_style[1]}"),
  (UPat(Ops.CAST, name="x"), lambda ctx,x:
  f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
  (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
  (UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
- (UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
+ (UPat(Ops.PRECAST, name="x"), lambda ctx,x: ctx[x.src[0]]),
  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
  # const
  (UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"),
@@ -33,39 +33,38 @@ base_rewrite = PatternMatcher([
  (UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx.nan)})" if math.isnan(x.arg) else None),
  (UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
  (UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
- (UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
- (UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
+ (UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}ull"),
+ (UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}u"),
  (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
  # consts are rendered to larger type and casted
  (UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
  (UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u')})"),
- (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, x.arg)})"),
+ (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
  # default const render
  (UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
  # new load/store
- (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
+ (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
  lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
- (UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
- (UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
+ (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var")), allow_any_len=True),
+ lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
+ (UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"(*{ctx[bidx]})"),
  (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
  # alu/gep
+ # TODO: look for left-associative
  (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
- *([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
+ *([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR, Ops.OR, Ops.AND} else ctx[v] for v in x.src]), x.dtype)),
  (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
- (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CPU', 'DSP'} else \
- f".{'xyzwabcd'[x.arg[0]]}")),
+ (f"[{x.arg[0]}]" if x.src[0].dtype.count > ctx.gep_arr_threshold else f".{'xyzwabcd'[x.arg[0]]}")),
  # custom passes through with format
- (UPat(Ops.CUSTOM, name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
+ (UPat((Ops.CUSTOM, Ops.CUSTOMI), name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
  ])

  extra_pm = PatternMatcher([
- # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
- (UPat(Ops.BITCAST, name="x"),
- lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
- # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
- (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+ # insert a PRECAST before BITCAST to force it to be rendered. not needed on all backends?
+ (UPat(Ops.BITCAST, name="x"), lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.PRECAST, x.src[0].dtype, x.src),))
+ if x.src[0].op not in {Ops.PRECAST, Ops.LOAD, Ops.CUSTOM} else None),
  # devectorize any bools
- (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
+ (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.INDEX), dtype=dtypes.bool, name="alu"), no_vectorized_alu),
  # CAST (from bool) can't be vectorized
  (UPat(Ops.CAST, src=(UPat(dtype=dtypes.bool),), name="alu"), no_vectorized_alu),
  # WHERE can't be vectorized
@@ -74,8 +73,12 @@ extra_pm = PatternMatcher([

  def uops_to_dtypes(uops:list[UOp]) -> list[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))

+ # (name, dims, dtype_in, dtype_out, device, threads, upcast_axes, reduce_axes)
+ def wmma_args(uops:list[UOp]):
+ return dedup((uop.arg[0], uop.arg[1], uop.src[0].dtype.scalar(), uop.dtype.scalar(), *(uop.arg[4:8])) for uop in uops if uop.op is Ops.WMMA)
+
  class CStyleLanguage(Renderer):
- kernel_prefix: str = ""
+ kernel_typedef: str = "void"
  buffer_prefix: str = ""
  buffer_suffix: str = ""
  smem_align: str = ""
@@ -83,30 +86,33 @@ class CStyleLanguage(Renderer):
  smem_prefix_for_cast: bool = True
  arg_int_prefix: str = "const int"
  barrier: str = ""
- code_for_workitem: dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
+ code_for_workitem: dict[Literal["g", "l", "i"], Callable] = {}
  extra_args: list[str] = []
- float4: Optional[str] = None
+ float4: str|None = None
+ float4_style: tuple[str, str] = ('(', ')')
+ gep_arr_threshold: int = 4
  type_map: dict[DType, str] = {}
  infinity: str = "INFINITY"
  nan: str = "NAN"
  code_for_op: dict = {
  Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}",
  Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})",
+ Ops.TRUNC: lambda x,dtype: f"trunc({x})",
  Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})",
  Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
  Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
  Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})",
- Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" }
+ Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})", Ops.CMPEQ: lambda a,b,dtype: f"({a}=={b})"}

  string_rewrite = base_rewrite
  extra_matcher = extra_pm

- def get_kernel_modifier(self, uops:list[UOp]) -> str: return ""
  def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
  tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
  buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
  self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
- prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
+ launch_bounds = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
+ prg = ''.join([f"{self.kernel_typedef.format(launch_bounds=launch_bounds)} {function_name}(",] +
  [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
  [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
  return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
@@ -115,12 +121,15 @@ class CStyleLanguage(Renderer):
  def render_dtype(self, dt:DType, mutable=True) -> str:
  if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
  if isinstance(dt, PtrDType):
- return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + self.render_dtype(dt.base) + "*"
+ prefix = ""
+ if dt.addrspace == AddrSpace.LOCAL and self.smem_prefix_for_cast: prefix = self.smem_prefix
+ if dt.addrspace == AddrSpace.GLOBAL: prefix = self.buffer_prefix
+ return prefix + self.render_dtype(dt.base) + "*"
  if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
  return self.type_map.get(scalar:=dt.scalar(), scalar.name)

  def __getitem__(self, key): return self.r[key] # hacky helper
- def render(self, uops:list[UOp]) -> str:
+ def _render(self, uops:list[UOp]) -> tuple[str, list[str], list[tuple[str,tuple[DType,bool]]]]:
  r: dict[UOp, str] = {}
  self.r = r

@@ -131,98 +140,107 @@ class CStyleLanguage(Renderer):
  c: defaultdict[str, int] = defaultdict(int)
  name = "test"
  for u in uops:
- if u.op is Ops.NAME:
- name = u.arg
+ if u.op is Ops.NOOP: continue
+ if u.op is Ops.SINK:
+ if u.arg is not None: name = u.arg.function_name
  continue
  if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
- r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
+ r[u] = (f"data{u.arg}_{sz}" if (sz:=cast(PtrDType, u.dtype).size) > 0 else f"data{u.arg}") if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
  bufs[u] = (r[u], (u.dtype, False))
  continue

  # mark buffers that we store to writable
  if u.op is Ops.STORE:
- for up in u.src[0].toposort:
+ for up in u.src[0].toposort():
  if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))

  # naming
  prefix = None
- if u.op is Ops.SPECIAL:
- r[u] = u.arg[0]
+ if u.op is Ops.SPECIAL: r[u] = u.arg[0]
+ elif u.op is Ops.RANGE: r[u] = f"ridx{u.arg}"
  else:
- prefix = {Ops.RANGE: "ridx", Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
- Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.NOOP: "precast",
- Ops.INDEX: "bidx", Ops.DEFINE_ACC: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
+ prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
+ Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast",
+ Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
  r[u] = f"{prefix}{c[prefix]}"

  l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
  assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"

  if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
- if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOM} or \
- (u.op in {Ops.VECTORIZE, *GroupOp.ALU, Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA")):
+ if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
+ (u.op is Ops.LOAD and cast(PtrDType, u.src[0].dtype).addrspace == AddrSpace.REG) or \
+ (u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
+ (u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
  r[u] = l
  else:
- if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
- if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
- else:
- l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
+ if u.op in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} or u.dtype == dtypes.void: pass
+ else: l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
  kernel.append(" "*depth + l)
  if prefix: c[prefix] += 1 # if it was used, increment
  if u.op in {Ops.IF, Ops.RANGE}: depth += 1
  del self.r

  # NOTE: this relies on bufs dict preserving order
- return self.render_kernel(name, kernel, list(bufs.values()), uops)
+ return (name, kernel, list(bufs.values()))
+ def render(self, uops:list[UOp]) -> str: return self.render_kernel(*self._render(uops), uops)

  class ClangRenderer(CStyleLanguage):
  device = "CPU"
  float4 = "(float4)"
+ float4_style = ('{', '}')
+ gep_arr_threshold = 0
  has_local = False
  global_max = None
  infinity = "__builtin_inff()"
  nan = '__builtin_nanf("")'
- amx_tc = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt, swizzle=(None,((),(4,5,6,7,0,1,2,3))),
- opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
- if AMX: tensor_cores = amx_tc
+ if AMX: tensor_cores = tc.amx

  # language options
  buffer_suffix = " restrict"
  type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"}
- code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}),
- Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"}
+ code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2, Ops.TRUNC]}),
+ Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})",
+ Ops.TRUNC: lambda x,dtype: f"__builtin_trunc({x})" if dtype == dtypes.float64 else f"__builtin_truncf({x})"}
  # LLVM legalizes double => half cast on systems that don't support it natively (like x86 cpus without AVX512-FP16) into a compiler-rt libcall.
- extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16))]) + \
- CStyleLanguage.extra_matcher
+ extra_matcher = PatternMatcher([(UPat.var("x", dtypes.float64).cast(dtypes.float16), lambda x: x.cast(dtypes.float32).cast(dtypes.float16)),
+ (UPat((Ops.SQRT, Ops.TRUNC), name="alu"), no_vectorized_alu),]) + CStyleLanguage.extra_matcher

  if sys.platform == 'win32':
- kernel_prefix = "__attribute__((ms_abi)) "
+ kernel_typedef = "__attribute__((ms_abi)) void"
  def render_vector_prefix(self, dt:DType) -> str:
- # round (down) to power of two
- alignment = 2**int(math.log2(dt.itemsize))
+ # round (down) to power of two (this is actually the default clang behavior)
+ alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) else 1
  return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),vector_size({dt.itemsize})));"

- def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+ def _render_defines(self, uops) -> list[str]:
  prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
  # https://github.com/corsix/amx
- for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
+ for name, (N, M, _), dtype_in, _, _, _, _, _ in wmma_args(uops):
  prefix += [
  '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
  '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
  ]
  # 'static' in C roughly means that function symbol isn't exported. LLVM puts those symbols at the end of object file which allows Clang JIT
- # to just jump at the start of a shellcode whithout having to deal with symbols or trampolines at all. This is better than having to inline
+ # to just jump at the start of a shellcode without having to deal with symbols or trampolines at all. This is better than having to inline
  # wmma function every time it is called or wasting complexity on a symbol parsing and a memory page on trampoline.
  prefix += [f"""static {(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
  AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
  AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
  for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501
- return super().render_kernel(function_name, kernel, bufs, uops, prefix)
+ return prefix
+ def _render_body(self, function_name, kernel, bufs, uops, pref=None) -> str: return super().render_kernel(function_name, kernel, bufs, uops, pref)
+ def _render_entry(self, function_name:str, bufs:list[tuple[str,tuple[DType,bool]]]) -> str: return ""
+
+ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
+ defines = '\n'.join(self._render_defines(uops))
+ return defines + "\n" + self._render_body(function_name, kernel, bufs, uops, prefix) + "\n" + self._render_entry(function_name, bufs)

  class OpenCLRenderer(CStyleLanguage):
  device = "GPU"

  # language options
- kernel_prefix = "__kernel "
+ kernel_typedef = "__kernel void"
  buffer_prefix = "__global "
  smem_align = "__attribute__ ((aligned (16))) "
  smem_prefix = "__local "
@@ -235,7 +253,7 @@ class OpenCLRenderer(CStyleLanguage):
  string_rewrite = PatternMatcher([
  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
  # load/store image (OpenCL)
- (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
+ (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
  lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
  (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
  lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
@@ -248,35 +266,31 @@ class OpenCLRenderer(CStyleLanguage):
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)

  class IntelRenderer(OpenCLRenderer):
- device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel "
- tensor_cores = [TensorCore(dims=(8,8,16), threads=8, elements_per_thread=(16,16,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
- opts=("l0","l0","l0","u1","u1","u1"), swizzle=(((4,5,6),(0,1,2,3,7,8,9)), ((0,1,2),(7,8,9,3,4,5,6))))]
+ device, suffix, kernel_typedef = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void"
+ tensor_cores = tc.intel

  string_rewrite = PatternMatcher([
- (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x]})"),
- (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x]})"),
+ (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float),)), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x]})"),
+ (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16),)), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x]})"),
  ]) + OpenCLRenderer.string_rewrite

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
  prefix = []
- for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
- dt_in = ("ushort", "bf16") if arg[2] == dtypes.bfloat16 else (arg[2].name, "f16")
- prefix.append(f"""{arg[3].name}8 __{arg[0]}({dt_in[0]}16 a, {dt_in[0]}16 b, {arg[3].name}8 c) {{
+ for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops):
+ dt_in = ("ushort", "bf16") if dtype_in == dtypes.bfloat16 else (dtype_in.name, "f16")
+ prefix.append(f"""{dtype_out.name}8 __{name}({dt_in[0]}16 a, {dt_in[0]}16 b, {dtype_out.name}8 c) {{
  return intel_sub_group_{dt_in[1]}_{dt_in[1]}_matrix_mad_k16(as_int8(a), as_int8(b), c);\n}}""")
  return super().render_kernel(function_name, kernel, bufs, uops, prefix or None)

  class MetalRenderer(CStyleLanguage):
  device = "METAL"
  shared_max = 32768
- tensor_cores = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do, opts=("u0","l0","l1","l1","l0","l1"),
- swizzle=(((6,1,2,7,4),(8,0,3,5)), ((0,5,6,3,7),(1,2,4,8)))) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
- (dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
- def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if hasattr(os, 'uname') and os.uname().machine == "arm64" else []
+ def __init__(self): self.tensor_cores = tc.metal if hasattr(os, 'uname') and os.uname().machine == "arm64" else []

  # language options
- kernel_prefix = "kernel "
+ kernel_typedef = "kernel void"
  buffer_prefix = "device "
- smem_prefix = "threadgroup "
+ smem_prefix = "threadgroup __attribute__((aligned(16))) "
  arg_int_prefix = "constant int&"
  barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
  float4 = "float4"
@@ -300,45 +314,35 @@ class MetalRenderer(CStyleLanguage):
  ]) + base_rewrite

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
- prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is Ops.WMMA])
- for arg in wmma_args: prefix.append(
- f"""{(dtype_out:=self.render_dtype(arg[3].vec(2)))} __{arg[0]}({(dtype_in:=self.render_dtype(arg[2].vec(2)))} a, {dtype_in} b, {dtype_out} c){{
- simdgroup_{self.render_dtype(arg[2])}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(arg[3])}8x8 mat_c;
+ prefix = ["#include <metal_stdlib>","using namespace metal;"]
+ for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): prefix.append(
+ f"""{(dstr_out:=self.render_dtype(dtype_out.vec(2)))} __{name}({(dstr_in:=self.render_dtype(dtype_in.vec(2)))} a, {dstr_in} b, {dstr_out} c){{
+ simdgroup_{self.render_dtype(dtype_in)}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(dtype_out)}8x8 mat_c;
  mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0];
  mat_a.thread_elements()[1] = a[1]; mat_b.thread_elements()[1] = b[1]; mat_c.thread_elements()[1] = c[1];
- simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dtype_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
+ simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dstr_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""")
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)

  _nms = "xyzwabcdefghijkl"
- cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1") # shared by all shapes with M=16 N=8

  class CUDARenderer(CStyleLanguage):
  device = "CUDA"
  global_max = (2147483647, 65535, 65535)
  local_max = (1024, 1024, 64)
  shared_max = 49152
- # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
- tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
- swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float),
- (dtypes.half,dtypes.half)]]
- tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
- swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
- tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
- swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]
-
- tc_sm80 = tc_81616 + tc_8168_f16
- if getenv("ALLOW_TF32", 0): tc_sm80 += tc_8168_tf32
- tc_sm75 = tc_8168_f16
+
  def __init__(self, arch:str):
- self.tensor_cores, self.arch = CUDARenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else [], arch
+ self.tensor_cores, self.arch = tc.cuda_sm80 if int(arch[3:]) >= 80 else tc.cuda_sm75 if int(arch[3:]) >= 75 else [], arch
  def __reduce__(self): return self.__class__, (self.arch,)

  # language options
- kernel_prefix = "extern \"C\" __global__ "
- smem_prefix = "__shared__ "
+ # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
+ kernel_typedef = "extern \"C\" __global__ void __launch_bounds__({launch_bounds})"
+ smem_prefix = "__shared__ __align__(16) "
  smem_prefix_for_cast = False
  barrier = "__syncthreads();"
  float4 = "make_float4"
+ gep_arr_threshold = 8
  code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}",
  "i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"}
  code_for_op = { **CStyleLanguage.code_for_op,
@@ -365,7 +369,7 @@ class CUDARenderer(CStyleLanguage):

  dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
  dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
- for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
+ for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in wmma_args(uops):
  upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
  wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
  n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)] # 4 => CUDA reg size in bytes
@@ -383,11 +387,6 @@ class CUDARenderer(CStyleLanguage):

  return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)

- def get_kernel_modifier(self, uops:list[UOp]) -> str:
- maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
- # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
- return f"__launch_bounds__({maxThreadsPerBlock}) "
-
  def cast_float_to_bf16(x: UOp) -> UOp:
  assert x.dtype == dtypes.float, "cast float -> bf16 must start with float"
  x = x.bitcast(dtypes.uint)
@@ -397,27 +396,40 @@ def cast_float_to_bf16(x: UOp) -> UOp:
  class AMDRenderer(CStyleLanguage):
  device = "AMD"
  shared_max = 65536
- # https://gpuopen.com/learn/wmma_on_rdna3/
- tensor_cores = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
- opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8))))
- for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]]
+ # NOTE: this is only really needed on gfx12, even though gfx11 reports the same limitation
+ global_max = (2147483647, 65535, 65535)
+
+ @staticmethod
+ def get_tensor_cores(arch):
+ return {"gfx942": tc.amd_cdna, "gfx950": tc.amd_cdna, "gfx1200": tc.amd_rdna4, "gfx1201": tc.amd_rdna4}.get(arch.split(":")[0], tc.amd_rdna3)
+ def __init__(self, arch:str): # gfx942 => MI300, gfx1100 => RX 7900, gfx1201 => RX 9700
+ self.arch = arch
+ self.tensor_cores = self.get_tensor_cores(arch)
+ if self.tensor_cores == tc.amd_cdna:
+ self.string_rewrite = PatternMatcher([
+ (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]}, 0, 0, 0)")]) + base_rewrite
+ def __reduce__(self): return self.__class__, (self.arch,)

  # language options
  ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
  ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
  for dt, n in [(dtype.name, dtype.itemsize * 8) for dtype in [dtypes.float, dtypes.double, dtypes.half]]
- for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", "")]]
+ for name, atr in [("fmax", "const"), ("exp2", "pure"), ("log2", "pure"), ("sqrt", "const"), ("sin", ""), ("trunc", "")]]

- kernel_prefix = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
- kernel_prefix += '\nextern "C" __attribute__((global))'
+ kernel_typedef = "\n".join(f'extern "C" __attribute__((device{f", {atr}" if atr else ""})) {dto} {meth}({dti});' for meth,dti,dto,atr in ockl+ocml)
+ # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
+ # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
+ kernel_typedef += '\nextern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, {launch_bounds})))'
  code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
  "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
  code_for_op = { **CStyleLanguage.code_for_op,
+ Ops.TRUNC: lambda x,dtype: f"__ocml_trunc_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
  Ops.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
  Ops.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
  Ops.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})",
  Ops.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})" }
- smem_prefix = "__attribute__((shared))"
+ smem_prefix = "__attribute__((shared, aligned(16)))"
+ smem_prefix_for_cast: bool = False
  barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
  '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
  float4 = "make_float4"
@@ -431,12 +443,15 @@ class AMDRenderer(CStyleLanguage):
  (UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
  lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
  # add float intermediate casting for bfloat16
- (UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
- (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
+ (UPat(Ops.CAST, name="x", src=(UPat.var("y", dtypes.bfloat16),)),
+ lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
+ (UPat(Ops.CAST, dtypes.bfloat16, (UPat.var("x"),)),
+ lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
  # bfloat16 casting
  (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
- (UPat(Ops.CAST, dtypes.float, UPat.var("x", dtypes.bfloat16)), lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
- (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
+ (UPat(Ops.CAST, dtypes.float, (UPat.var("x", dtypes.bfloat16),)),
+ lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
+ (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var("x", dtype=dtypes.float),)), cast_float_to_bf16)]) + extra_pm

  def render_vector_prefix(self, dtype:DType) -> str:
  vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
@@ -445,25 +460,25 @@ class AMDRenderer(CStyleLanguage):

  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
  prefix = ["#define INFINITY (__builtin_inff())","#define NAN (__builtin_nanf(\"\"))","typedef long unsigned int size_t;","#define half _Float16"]
-
+ type_map = { dtypes.bfloat16: "bf16", dtypes.float: "f32", dtypes.half: "f16" }
  used_dtypes = uops_to_dtypes(uops)
  if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("typedef unsigned short hip_bfloat16;")
  prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]

- for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
- if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
- else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
+ for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+ if self.tensor_cores == tc.amd_cdna:
+ prefix.append(f"#define __{name} __builtin_amdgcn_mfma_f32_16x16x16{'f16' if dtype_in == dtypes.half else 'bf16_1k'}")
+ # #define __WMMA_16_16_16_half_half __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12
+ elif self.tensor_cores == tc.amd_rdna4:
+ prefix.append(f"#define __{name} __builtin_amdgcn_wmma_{type_map[dtype_out]}_16x16x16_{type_map[dtype_in]}_w32_gfx12")
+ elif dtype_out == dtypes.float:
+ prefix.append(f"#define __{name} __builtin_amdgcn_wmma_f32_16x16x16_{'f16' if dtype_in == dtypes.half else 'bf16'}_w32")
+ else: prefix.append(f"static inline __attribute__((device)) half8 __{name}"+"""(half16 a, half16 b, half8 c) {
  half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
  c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
  for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
  return super().render_kernel(function_name, kernel, bufs, uops, prefix)

- def get_kernel_modifier(self, uops:list[UOp]) -> str:
- requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
- # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
- # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
- return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
-
  class NVRenderer(CUDARenderer): device = "NV"
  class HIPRenderer(AMDRenderer): device = "HIP"
  class QCOMRenderer(OpenCLRenderer): device = "QCOM"
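A recurring change in the cstyle.py diff above is the removal of get_kernel_modifier: the old kernel_prefix string plus per-backend modifier is folded into a single kernel_typedef format string, and CStyleLanguage.render_kernel now computes launch_bounds from the "l"-type SPECIAL dims and formats it in. A rough standalone sketch of that mechanism (the two format strings are copied from the diff; the 16x16 workgroup and the kernel name r_example are made up for illustration):

# Python sketch of the kernel_typedef / launch_bounds flow shown in the diff above.
from math import prod

local_dims = [16, 16]                # hypothetical sizes of the "l" SPECIAL uops
launch_bounds = prod(local_dims)     # 256, as computed in CStyleLanguage.render_kernel

cuda_typedef = "extern \"C\" __global__ void __launch_bounds__({launch_bounds})"
amd_typedef = 'extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, {launch_bounds})))'

print(cuda_typedef.format(launch_bounds=launch_bounds) + " r_example(float* data0) {")
# extern "C" __global__ void __launch_bounds__(256) r_example(float* data0) {
print(amd_typedef.format(launch_bounds=launch_bounds) + " r_example(float* data0) {")
# extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, 256))) r_example(float* data0) {

On AMD the real kernel_typedef also carries the ockl/ocml extern declarations in front of this attribute string; that part is omitted here.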