tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff shows the content changes between publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the versions as they appear in their respective public registries.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/renderer/llvmir.py CHANGED
@@ -1,10 +1,13 @@
-from typing import List, Dict, cast
-import math, struct
+from typing import cast
+import math, struct, sys
 from tinygrad.renderer import Renderer
+from tinygrad.renderer.cstyle import ClangRenderer
 from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
 from tinygrad.dtype import dtypes, DType, PtrDType, truncate
+from tinygrad.helpers import prod, AMX
 
 def ldt(dt:DType):
+  if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
   if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
   return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
           dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
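The new vcount branch means vector dtypes now render as first-class LLVM vector types, with ldt recursing on the scalar element, so a float4 becomes "<4 x float>". A minimal standalone sketch of that composition (hypothetical helper, not tinygrad's ldt):

# sketch: how LLVM type strings compose for scalars, vectors, and pointers
scalar_types = {"float32": "float", "int32": "i32", "uint8": "i8"}

def render_type(name: str, vcount: int = 1, ptr: bool = False) -> str:
  s = scalar_types[name]
  if vcount > 1: s = f"<{vcount} x {s}>"  # mirrors the vcount branch above
  return s + "*" if ptr else s

assert render_type("float32", vcount=4) == "<4 x float>"
assert render_type("uint8", ptr=True) == "i8*"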
@@ -20,7 +23,7 @@ def lcast(input_type:DType, output_type:DType):
   if dtypes.is_float(input_type):
     if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc'
     if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi'
-  if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
+  if dtypes.is_unsigned(input_type) or dtypes.is_bool(input_type):
     if dtypes.is_float(output_type): return 'uitofp'
     if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext'
   if dtypes.is_int(input_type):
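lcast picks the LLVM cast instruction from the (source, destination) dtype pair: float to float chooses fpext or fptrunc by width, unsigned-or-bool to float is uitofp, and integer widening picks zext or sext by signedness. A toy mirror of two of those rows (hypothetical, just to make the table concrete):

# sketch: cast instruction selection by (kind, width) of src and dst
def pick_cast(src_kind: str, src_bytes: int, dst_kind: str, dst_bytes: int) -> str:
  if src_kind == "float" and dst_kind == "float":
    return "fpext" if dst_bytes > src_bytes else "fptrunc"
  if src_kind in ("unsigned", "bool") and dst_kind == "float":
    return "uitofp"
  raise NotImplementedError((src_kind, dst_kind))

assert pick_cast("float", 4, "float", 8) == "fpext"   # f32 -> f64 widens
assert pick_cast("bool", 1, "float", 4) == "uitofp"   # bool is zero/one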
@@ -28,6 +31,19 @@ def lcast(input_type:DType, output_type:DType):
     if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
   raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
 
+# https://github.com/corsix/amx
+def render_wmma(ctx, wmma: UOp) -> str:
+  def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))", "i,r,~{{memory}}"(i32 {op}, i64 {gpr}) #0; AMX'
+
+  return "\n".join([
+    *[f' store {ldt(src.dtype)} {ctx[src]}, {ldt(src.dtype.ptr())} {ctx[wmma]}_amx{i}, align {src.dtype.itemsize}' for i,src in enumerate(wmma.src)],
+    f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 0})", "~{{memory}}"() #0; AMX set', # set
+    *[f' {ctx[wmma]}_ld{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(4,f"{ctx[wmma]}_ld{i}")} ldz' for i in range(16)], # ldz
+    f' {AMX(0, f"{ctx[wmma]}_ptr_amx1")} ldx\n {AMX(1, f"{ctx[wmma]}_ptr_amx0")} ldy\n {AMX(12, 0)} fma32', # ldx ldy fma
+    *[f' {ctx[wmma]}_st{i} = add i64 {ctx[wmma]}_ptr_amx2, {i*4<<56 | i*64}\n {AMX(5,f"{ctx[wmma]}_st{i}")} stz' for i in range(16)], # stz
+    f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 1})", "~{{memory}}"() #0; AMX clr', # clr
+    f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
+
 # llvm ops, lop[<dtype>][<op>]
 unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
   Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
@@ -36,7 +52,7 @@ flags = " nsz arcp contract afn"
 float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags}
 lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}
 
-llvm_rewrite = PatternMatcher([
+base_rewrite = PatternMatcher([
   # memory load/store
   (UPat(Ops.INDEX, name="x"), lambda ctx,x:
     f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
@@ -49,73 +65,98 @@ llvm_rewrite = PatternMatcher([
   (UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
   (UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
 
+  # GEP/VECTORIZE/CAST for float4 support
+  (UPat(Ops.GEP, name="x"), lambda ctx,x: f" {ctx[x]} = extractelement {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i32 {x.arg[0]}"),
+  (UPat(Ops.VECTORIZE, src=UPat.var('y'), name="x"), lambda ctx,x,y:
+    f" {ctx[x]}_z = insertelement <1 x {ldt(y.dtype)}> poison, {ldt(y.dtype)} {ctx[y]}, i32 0\n"
+    f" {ctx[x]} = shufflevector <1 x {ldt(y.dtype)}> {ctx[x]}_z, <1 x {ldt(y.dtype)}> poison, <{x.dtype.count} x i32> zeroinitializer"),
+  (UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
+    f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
+    f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
+  (UPat(Ops.CAST, name="x"), lambda ctx,x:
+    f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None),
+
   # unary/binary/ternary ops
-  (UPat(Ops.SQRT, name="x"), lambda ctx,x:
-    f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
   (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
   (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
-  (UPat(GroupOp.Binary, name="x"), lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
+  (UPat(GroupOp.Binary, name="x"), lambda ctx,x:
+    f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
   (UPat(Ops.WHERE, name="x"), lambda ctx,x:
     f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
 
   # range
   (UPat(Ops.RANGE, name="x"), lambda ctx,x:
-    f" br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n"
-    f" br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n"
-    f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"),
+    f" br label %loop_entry_{x.arg}\nloop_entry_{x.arg}:\n"
+    f" br label %loop_body_{x.arg}\nloop_body_{x.arg}:\n"
+    f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg}], [{ctx[x]}phi, %loop_latch_{x.arg}]"),
   (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
-    f" br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n"
+    f" br label %loop_latch_{x.src[0].arg}\nloop_latch_{x.src[0].arg}:\n"
     f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
-    f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"),
+    f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg}, label %loop_exit_{x.src[0].arg}\nloop_exit_{x.src[0].arg}:"),
 
   # if
   (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
   (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
+
+  # wmma
+  (UPat(Ops.WMMA, name="wmma"), render_wmma),
 ])
 
+def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
+  u16_buf = buf.replace(dtype=dtypes.ushort.ptr(size=cast(PtrDType,buf.dtype).size))
+  return UOp.load(UOp.index(u16_buf, idx), dtype=dtypes.ushort).cast(dtypes.uint).mul(1<<16).bitcast(dtypes.float32).cast(root.dtype)
+
 class LLVMRenderer(Renderer):
   device = "LLVM"
-  supports_float4 = False
+  abi = 'win64cc' if sys.platform == 'win32' else None
+  supports_float4 = True
   has_local = False
   has_shared = False
   global_max = None
+  string_rewrite = base_rewrite
+  if AMX: tensor_cores = ClangRenderer.amx_tc
 
   extra_matcher = PatternMatcher([
     # rewrite RECIP with FDIV
     (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
     # rewrite cast to bool to CMPNE 0
     (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
-    # *** also in cstyle ***
-    # gate any stores that aren't gated with ifs
-    (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
-      lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
     # rewrite MAX to CMPLT + WHERE
     (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+    # rewrite bf16 CAST(LOAD) to CAST(BITCAST)
+    (UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
   ])
 
-  def render(self, name: str, uops: List[UOp]) -> str:
-    r: Dict[UOp, str] = {}
-    args: List[str] = []
-    kernel: List[str] = []
-    end_lines: Dict[str, None] = {}
+  def render(self, uops: list[UOp]) -> str:
+    r: dict[UOp, str] = {}
+    args: list[str] = []
+    kernel: list[str] = []
+    end_lines: dict[str, None] = {}
     vc = -1
 
-    # prealloc all assigns
-    acc_to_assign: Dict[UOp, UOp] = {}
+    acc_to_assign: dict[UOp, UOp] = {}
     for u in uops:
-      if u.op is Ops.ASSIGN:
+      if u.op is Ops.ASSIGN: # prealloc all assigns
         vc += 1
         r[u] = r[u.src[1]] = f"%assign{vc}"
         assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
         acc_to_assign[u.src[0]] = u.src[1]
+      if u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
+        vc += 1
+        r[u] = f"%wmma{vc}"
+        for i, dtype in enumerate(u.arg[2].vec(sz) for sz in [prod(size for _, size in upcast) for upcast in u.arg[6]]):
+          kernel += [f" {r[u]}_amx{i} = alloca {ldt(dtype)}, align {dtype.itemsize}",
+                     f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]
 
+    name = "test"
     for u in uops:
-      # hack for defining sqrt function (TODO: can we get a transcendental for this?)
-      if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
-
+      if u.op is Ops.NAME:
+        name = u.arg
+        continue
      if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
        r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
-        args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
+        # NOTE: MallocAllocator promises 0x20 alignment
+        args.append(f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
      elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
      elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
      elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
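The llvm_bf16_cast helper added above leans on bfloat16 being exactly the top 16 bits of a float32: load the raw u16, widen it, multiply by 1<<16 (a left shift into the high half), and bitcast the result to float32. A standalone check of that bit arithmetic:

# bf16 bits 0x3F80 are the high half of 0x3F800000, i.e. 1.0f
import struct

def bf16_bits_to_f32(bits: int) -> float:
  return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

assert bf16_bits_to_f32(0x3F80) == 1.0
assert bf16_bits_to_f32(0xBF80) == -1.0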
@@ -127,16 +168,24 @@ class LLVMRenderer(Renderer):
         r[u] = f"%v{vc}"
 
       # do the rendering of the llvm ir code
-      if (l:=llvm_rewrite.rewrite(u, ctx=r)) is None: raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
+      if (l:=self.string_rewrite.rewrite(u, ctx=r)) is None:
+        raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
       kernel.append(cast(str, l))
 
       # generate the phi nodes for the assigns
       if u.op is Ops.RANGE:
         for x in acc_to_assign:
-          if u in x.src: # if this range is relevent for this acc
+          if u in x.src: # if this range is relevant for this acc
             vc += 1
-            kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]")
+            kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg}]")
             r[x] = f"%acc{vc}"
 
-    # output the function
-    return f"define void @{name}({','.join(args)}) {{\n" + '\n'.join(kernel) + "\n ret void\n}\n"+'\n'.join(end_lines.keys())
+    # output the function. chr(10) is '\n' (python < 3.12 doesn't support backslashes in f-strings)
+    return f'''\
+define{(' '+self.abi) if self.abi is not None else ''} void @{name}({','.join(args)}) #0 {{
+{chr(10).join(kernel)}
+ ret void
+}}
+{chr(10).join(end_lines.keys())}
+attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }}
+'''
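Taken together, the RANGE/ENDRANGE patterns and the per-range phi emission give each loop a fixed four-block shape, with a phi carrying the index and one %acc phi per accumulator. An illustrative sketch of the emitted IR for one range (register and label names made up; held in a Python string so it stays self-contained):

loop_skeleton = """\
  br label %loop_entry_0
loop_entry_0:
  br label %loop_body_0
loop_body_0:
  %ridx0 = phi i32 [%start, %loop_entry_0], [%ridx0phi, %loop_latch_0]
  ; ...loop body, plus one %accN phi per accumulator...
  br label %loop_latch_0
loop_latch_0:
  %ridx0phi = add i32 %ridx0, 1
  %cond = icmp ult i32 %ridx0phi, %end
  br i1 %cond, label %loop_body_0, label %loop_exit_0
loop_exit_0:"""
print(loop_skeleton)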
tinygrad/renderer/ptx.py CHANGED
@@ -1,11 +1,11 @@
-from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tuple
+from typing import cast, Callable
 import struct
 from collections import defaultdict
 from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
 from tinygrad.dtype import dtypes, DType, PtrDType
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer
-from tinygrad.helpers import prod, flatten
+from tinygrad.helpers import flatten, get_single_element
 
 def render_val(x, dtype):
   if dtypes.is_float(dtype):
@@ -14,30 +14,30 @@ def render_val(x, dtype):
     return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
   return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
 
-asm_for_op: Dict[Ops, Callable] = {
+asm_for_op: dict[Ops, Callable] = {
   Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
   Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
   Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
   Ops.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", Ops.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
-  Ops.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
-  Ops.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
-  Ops.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
-  Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
-  Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
-  Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
-  Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
+  Ops.ADD: lambda d,a,b,dt,name: f"{'or' if dt == dtypes.bool else 'add'}.{name} {d}, {a}, {b};",
+  Ops.MUL: lambda d,a,b,dt,name: f"{'and' if dt == dtypes.bool else 'mul'}{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
+  Ops.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if dt == dtypes.bool else f"xor.b{name[1:]} {d}, {a}, {b};",
+  Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if dt == dtypes.bool else f"and.b{name[1:]} {d}, {a}, {b};",
+  Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if dt == dtypes.bool else f"or.b{name[1:]} {d}, {a}, {b};",
+  Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
+  Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};",
   Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
   Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
-  Ops.WHERE: lambda d,a,b,c,dt,name:
-    f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
+  Ops.WHERE: lambda d,a,b,c,dt,name: [f"@{a} mov.{name} {d}, {b};", f"@!{a} mov.{name} {d}, {c};"] if dt == dtypes.bool else \
+    f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
 }
 
-supports_half: List[Ops] = [Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE]
-doesnt_support_half: Tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
+supports_half = (Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE)
+doesnt_support_half: tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
 ptx_matcher = PatternMatcher([
   # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
   (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
-  (UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
+  (UPat.var('x', dtype=dtypes.bool)<UPat.var('y'), lambda x,y: (x^True)&y),
   # upcast to float32 all the ops that don't support half
   (UPat(doesnt_support_half, dtype=dtypes.half, name="x"),
     lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))),
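render_val's float branch (context at the top of this hunk) spells float immediates the way PTX expects: 0f followed by the IEEE-754 bits in big-endian hex. A standalone check:

import struct

def ptx_float(x: float) -> str:
  return "0f%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])

assert ptx_float(1.0) == "0f3F800000"
assert ptx_float(-2.0) == "0fC0000000"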
@@ -54,46 +54,46 @@ ptx_matcher = PatternMatcher([
   (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
 ])
 
-def mem_type(x: UOp): return 'shared' if x.src[0].op is Ops.DEFINE_LOCAL or any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].parents) else 'global'
+def mem_type(x: UOp): return 'shared' if any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].toposort) else 'global'
 
-def render_store(ctx: "PTXRenderer", x: UOp, bidx: UOp, var: UOp, pred: Optional[UOp]=None):
-  gate = f"@{ctx.r[pred]} " if pred is not None and pred.op is not Ops.IF else ""
-  return [f"{gate}st.{mem_type(bidx)}.v{var.dtype.count}.{ctx.mem_types[var.dtype.scalar()]} [{ctx.r[bidx]}+0], {{{', '.join(ctx.r[var])}}};"] \
-    if var.dtype.count > 1 else [f"{gate}st.{mem_type(bidx)}.{ctx.mem_types[var.dtype]} [{ctx.r[bidx]}+0], {ctx.r[var]};"]
-
-def render_wmma(ctx: "PTXRenderer", x: UOp):
+def render_wmma(ctx: "PTXRenderer", wmma: UOp):
   assert ctx.wmma_r, "registry values for wmma must be populated"
-  _, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = x.arg
-  n_operands = tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2])
-  dt_map = { dtypes.half: "f16" }
-  _i = 0
-  for vv in x.src[:2]:
-    for i in range(0, len(ctx.r[vv]), 2):
-      yield f"mov.b32 {ctx.wmma_r[_i]}, {{{', '.join(ctx.r[vv][i:i+2])}}};"
-      _i += 1
-  yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32{" "*12}' +\
-    f'{{{", ".join(ctx.r[x])}}}, {{{", ".join(ctx.wmma_r[:n_operands[0]])}}}, {{{", ".join(ctx.wmma_r[-n_operands[1]:])}}}, ' + \
-    f'{{{", ".join(ctx.r[x.src[2]])}}};'
+  (N, M, K), dtype_in, dtype_out = wmma.arg[1], wmma.arg[2], wmma.arg[3]
+
+  for src, regs in zip(wmma.src, ctx.wmma_r):
+    for i, reg in enumerate(regs): # pack input and acc registers
+      if (elems_per_reg := 4 // src.dtype.scalar().itemsize) == 1: yield f"mov.b32 {reg}, {ctx.r[src][i]};"
+      else: yield f"mov.b32 {reg}, {{{', '.join(ctx.r[src][i * elems_per_reg : (i+1) * elems_per_reg])}}};"
+
+  dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32", dtypes.half: "f16"}
+  yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}{" "*12}'+\
+    f'{{{", ".join(ctx.wmma_r[2])}}}, {{{", ".join(ctx.wmma_r[0])}}}, {{{", ".join(ctx.wmma_r[1])}}}, {{{", ".join(ctx.wmma_r[2])}}};'
+
+  for i, reg in enumerate(ctx.wmma_r[2]): # unpack acc registers
+    if (elems_per_reg := 4 // dtype_out.itemsize) == 1: yield f"mov.b32 {ctx.r[wmma][i]}, {reg};"
+    else: yield f"mov.b32 {{{', '.join(ctx.r[wmma][i * elems_per_reg : (i+1) * elems_per_reg])}}}, {reg};"
 
 def modifier(a: DType, b: DType): return '.rzi' if dtypes.is_int(a) and dtypes.is_float(b) else '.rn' if dtypes.is_float(a) and \
   (a.itemsize < b.itemsize or dtypes.is_int(b) or b == dtypes.bool) else ''
 
 string_rewrite = PatternMatcher([
-  (UPat(Ops.CONST, name="x", dtype=dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"),
-  (UPat(Ops.CONST, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"),
-  (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var"), UPat.var("pred")), allow_any_len=True), render_store),
+  (UPat.cvar("x", dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"),
+  (UPat.cvar("x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"),
+  (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx, x, bidx, var: f"st.{mem_type(bidx)}" + \
+    f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
+    f"[{ctx.r[bidx]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
   (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),
   (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
-  (UPat((Ops.CMPLT, Ops.CMPNE), name="x"),
-    lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.src[0].dtype, ctx.types[x.src[0].dtype])),
+  (UPat((Ops.CMPLT, Ops.CMPNE), name="x", allow_any_len=True, src=(UPat.var("src0"),)),
+    lambda ctx, x, src0: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], src0.dtype, ctx.types[src0.dtype])),
   (UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
-  (UPat(Ops.BITCAST, name="x", src=(UPat.var("a")), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
-  (UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"))),
+  (UPat(Ops.BITCAST, name="x", src=(UPat.var("a"),), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
+  (UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"),)),
    lambda ctx, x, a: f"selp.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(1, x.dtype)}, {render_val(0, x.dtype)}, {ctx.r[a]};"),
-  (UPat(Ops.CAST, name="x", dtype=dtypes.bool),
-    lambda ctx, x: f"setp.ne.b{ctx.types[x.src[0].dtype][1:]} {ctx.r[x]}, {ctx.r[x.src[0]]}, {render_val(0, x.src[0].dtype)};"),
-  (UPat(Ops.CAST, name="x", src=(UPat.var("a"))),
-    lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.types[x.dtype]}.{ctx.types[x.src[0].dtype]} {ctx.r[x]}, {ctx.r[x.src[0]]};"),
+  (UPat(Ops.CAST, name="x", dtype=dtypes.bool, src=(UPat.var("a"),)),
+    lambda ctx, x, a: f"setp.ne.b{ctx.types[a.dtype][1:]} {ctx.r[x]}, {ctx.r[a]}, {render_val(0, a.dtype)};"),
+  (UPat(Ops.CAST, name="x", src=(UPat.var("a"),)),
+    lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.types[x.dtype]}.{ctx.types[a.dtype]} {ctx.r[x]}, {ctx.r[a]};"),
   (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'), UPat(name='alt'), UPat(name="gate", op=GroupOp.ALU))), lambda ctx, x, loc, alt, gate: flatten([
     [f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
     [f"@{ctx.r[gate]} ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
@@ -101,20 +101,11 @@ string_rewrite = PatternMatcher([
     f"@{ctx.r[gate]} ld.{mem_type(x)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
     f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
   (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'),), allow_any_len=True),
-    lambda ctx, x, loc: f" ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
+    lambda ctx, x, loc: f"ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
     if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
-  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.bool),), allow_any_len=True),
-    lambda ctx, x, pred: flatten([
-      [f"setp.ne.s16 {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())}, 0;",
-       f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
-  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.half),), allow_any_len=True),
-    lambda ctx, x, pred: flatten([[f"mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())};",
-      f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
-  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE),), allow_any_len=True), lambda ctx, x, pred: [
-    f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {render_val(pred.src[0].arg, x.dtype.scalar())};" for i, uu in enumerate(ctx.r[x])]),
-  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST, dtype=dtypes.bool), ), allow_any_len=True), lambda ctx, x, pred: [
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat.cvar("pred", dtype=dtypes.bool),), allow_any_len=True), lambda ctx, x, pred: [
     f"setp.ne.s16 {ctx.r[pred]}, {render_val(pred.arg, pred.dtype)}, 0;", f"mov.pred {ctx.r[x]}, {ctx.r[pred]};"]),
-  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST), ), allow_any_len=True),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat.cvar("pred"),), allow_any_len=True),
     lambda ctx, x, pred: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(pred.arg, x.dtype)};"),
   (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, {ctx.r[x.src[0]]};", "LOOP_" + f"{ctx.r[x][1:]}:"]),
   (UPat(Ops.ASSIGN, name="x", dtype=dtypes.bool), lambda ctx, x: [f"mov.pred {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"]),
@@ -124,7 +115,7 @@ string_rewrite = PatternMatcher([
     ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[1]], dtypes.int, ctx.types[dtypes.int]),
     f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
   (UPat(Ops.DEFINE_LOCAL, name="x"),
-    lambda ctx, x: [f".shared .align 4 .b8 {x.arg[0]}[{x.arg[1]*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg[0]}[0];"]),
+    lambda ctx, x: [f".shared .align 4 .b8 {x.arg}[{x.dtype.size*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg}[0];"]),
   (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
   (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
   (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
@@ -136,11 +127,12 @@ class PTXRenderer(Renderer):
   device = "CUDA"
   suffix = "PTX"
   global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
-  tensor_cores = [tc for tc in CUDARenderer.tensor_cores if tc.dtype_in == dtypes.half]
+  tc_sm80 = [tc for tc in CUDARenderer.tc_sm80 if tc.dtype_in in [dtypes.half, dtypes.float]]
   code_for_op = asm_for_op
   extra_matcher = ptx_matcher
   def __init__(self, arch:str, device="CUDA"):
-    self.device, self.tensor_cores, self.arch = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
+    self.device, self.arch = device, arch
+    self.tensor_cores = PTXRenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else []
   def __reduce__(self): return self.__class__, (self.arch, self.device)
 
   # language options
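The constructor now grades tensor-core support by compute capability instead of all-or-nothing: the digits after "sm_" pick the sm80 set for SM 8.0+, the CUDA sm75 set for SM 7.5, and nothing older. The parsing is just a slice:

# "sm_86"[3:] -> "86"; compare numerically to pick the tensor-core set
def pick_tc(arch: str) -> str:
  num = int(arch[3:])
  return "tc_sm80" if num >= 80 else "tc_sm75" if num >= 75 else "none"

assert pick_tc("sm_86") == "tc_sm80"
assert pick_tc("sm_75") == "tc_sm75"
assert pick_tc("sm_60") == "none"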
@@ -149,75 +141,67 @@
 .address_size 64
 .visible .entry"""
   barrier = "bar.sync\t0;"
-  supports_half = supports_half
   # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
-  types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
+  types: dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
                               dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
                               dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
 
-  mem_types: Dict[DType, str] = types.copy()
-  mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})
+  mem_types: dict[DType, str] = {**types, dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"}
 
   def render_kernel(self, kernel, function_name, bufs, regs) -> str:
-    kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
     def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
-    return (f"{self.kernel_prefix} {function_name}(\n\t" +
-      ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
-      '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
-      "\n}")
+    kernel = '\n'.join(map(fmt, [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]))
+    params = ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs])
+    return f"{self.kernel_prefix} {function_name}(\n\t{params}\n)\n{{\n{kernel}\n}}"
 
-  def render(self, name:str, uops:List[UOp]) -> str:
-    kernel:List[str] = []
+  def render(self, uops:list[UOp]) -> str:
+    kernel:list[str] = []
     bufs = []
 
-    c: DefaultDict[str, int] = defaultdict(int)
-    r: Dict[UOp, Union[List[str], str]] = {}
+    c: defaultdict[str, int] = defaultdict(int)
+    r: dict[UOp, list[str]|str] = {}
     self.r = r
     self.uops = uops
 
-    def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
+    def ssa(prefix:str, u:UOp|None=None, dtype:str|None=None) -> str:
       nonlocal c, r
       prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype]}_"
       c[prefix] += 1
       return f"%{prefix}{c[prefix]-1}"
 
+    name = "test"
     for u in uops:
+      if u.op is Ops.NAME:
+        name = u.arg
+        continue
      if u.op is Ops.VECTORIZE:
        r[u] = [cast(str,r[x]) for x in u.src]
        continue
      if u.op is Ops.GEP:
-        assert len(u.arg) == 1
-        r[u] = r[u.src[0]][u.arg[0]]
+        r[u] = r[u.src[0]][get_single_element(u.arg)]
+        continue
+      if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)):
+        r[u] = r[u.src[0]]
        continue
-      if u.op in {Ops.CAST, Ops.BITCAST}:
-        if u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType):
-          r[u] = r[u.src[0]]
-          continue
-        r[u] = ssa('cast', u, self.types[u.dtype])
-      elif u.op is Ops.ENDRANGE: r[u] = ssa("pred", u, dtype="pred")
-      elif u.op is Ops.RANGE: r[u] = ssa("ridx", u)
-      elif u.op in GroupOp.ALU: r[u] = ssa("alu", u)
-      elif u.op is Ops.DEFINE_ACC:
-        if u.dtype.scalar() in [dtypes.half, dtypes.bool]:
-          r[u.src[0]] = [ssa("const", u.src[0].src[0]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("const", u.src[0])
-        r[u] = [ssa('acc', u, dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("acc", u)
-      elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
-      elif u.op is Ops.DEFINE_VAR:
-        bufs.append((u.arg[0], u.dtype))
-        r[u] = ssa("dat", u, self.types[u.dtype])
-      elif u.op is Ops.CONST: r[u] = ssa("const", u, dtype=self.types[u.dtype])
+      if u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
+      elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype))
      elif u.op is Ops.LOAD:
        assert u.src[0].dtype == dtypes.int64, "load isn't int64"
        r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
-      elif u.op is Ops.DEFINE_LOCAL: r[u] = ssa('local', u, self.types[dtypes.ulong])
-      elif u.op is Ops.DEFINE_GLOBAL:
-        bufs.append((f"data{u.arg}", u.dtype))
-        r[u] = ssa('dat', u, self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
+      elif u.op is Ops.DEFINE_GLOBAL: bufs.append((f"data{u.arg}", u.dtype))
      elif u.op is Ops.WMMA:
-        self.wmma_r = [ssa("wmma", dtype="b32") for vv in u.src[:2] for i in range(0, len(r[vv]), 2)]
+        # registers for packing/unpacking input and acc
+        self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.arg[2].itemsize)],
+                       [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.arg[2].itemsize)],
+                       [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.arg[3].itemsize)]]
        r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
-      if (l:=cast(Union[str, List[str]], string_rewrite.rewrite(u, ctx=self))) is None:
-        raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.u.src]}")
+      prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None),
+        Ops.DEFINE_ACC: ("acc", None), Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL:("local",self.types[dtypes.ulong]),
+        Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
+      if prefix: r[u] = ssa(prefix, u, dtype)
+
+      if (l:=cast(str|list[str], string_rewrite.rewrite(u, ctx=self))) is None:
+        raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
      kernel.extend([l] if isinstance(l, str) else l)
 
      if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
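The long elif chain for register naming collapses into one table lookup: each op maps to an ssa (prefix, dtype) pair, and ssa itself just keeps a counter per prefix+dtype string. A model of the counter behavior (same scheme as the ssa closure above, pulled out standalone for brevity):

from collections import defaultdict

c: defaultdict[str, int] = defaultdict(int)
def ssa(prefix: str, dtype: str) -> str:
  key = f"{prefix}_{dtype}_"
  c[key] += 1
  return f"%{key}{c[key]-1}"

assert ssa("alu", "f32") == "%alu_f32_0"
assert ssa("alu", "f32") == "%alu_f32_1"  # counter is per prefix+dtype
assert ssa("ridx", "s32") == "%ridx_s32_0"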
tinygrad/renderer/wgsl.py ADDED
@@ -0,0 +1,95 @@
+from tinygrad.dtype import DType, PtrDType, dtypes
+from tinygrad.ops import UOp, Ops, PatternMatcher, UPat
+from tinygrad.renderer.cstyle import CStyleLanguage, base_rewrite, extra_pm
+from tinygrad.helpers import strip_parens
+import math
+
+def sign_extend(val:UOp, sext_am:int):
+  return (UOp.where((val >> (sext_am - 1)) > 0, UOp.const(dtypes.uint32, 0xffffffff) << sext_am, UOp.const(dtypes.uint32, 0)) \
+    | val.bitcast(dtypes.uint32)).bitcast(dtypes.int)
+
+# store for char: buf[idx/4] <- (var << (idx%4)*8))
+def packed_store(bidx:UOp, var:UOp):
+  shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//var.dtype.itemsize))*UOp.const(dtypes.uint32, 8*var.dtype.itemsize)
+  new_v = (var & (0xFF if var.dtype.itemsize == 1 else 0xFFFF)).cast(dtypes.uint32) << shift_am
+  mask = (((0xFF if var.dtype.itemsize == 1 else 0xFFFF) << shift_am) ^ 0xFFFFFFFF).cast(dtypes.uint32)
+  buf = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), dtype=dtypes.uint32)
+  return UOp.store(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], bidx.src[1]//(4//var.dtype.itemsize))), ((buf & mask) | new_v.cast(dtypes.uint32)))
+
+# load for char: sign_extend(buf[idx/4] >> ((idx%4)*8))
+def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
+  div_idx = bidx.src[1]//(4//dtype.itemsize)
+  shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//dtype.itemsize))*UOp.const(dtypes.uint32, 8*dtype.itemsize)
+  if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), var, root.src[2], dtype=dtypes.uint32, arg=root.arg)
+  else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=dtypes.uint32, arg=root.arg)
+  val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF)
+  return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
+
+def is_packed(dt:DType) -> bool: return dt.itemsize < 4 and dt.base != dtypes.half
+
+wgsl_matcher = PatternMatcher([
+  (UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
+    lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
+  (UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if is_packed(l.dtype) else None),
+  (UPat(Ops.LOAD, name="l", src=(UPat.var('b'), UPat.var('c'), UPat())),
+    lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype) else None),
+  (UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype) else None),
+  # TODO: why is this needed, and only for this MUL order
+  (UPat(Ops.MUL, src=(UPat.var("a"), UPat.var("g").where(UPat.cvar("c1"), UPat.cvar("c2")))),
+    lambda a,g,c1,c2: g.where(c1, a) if math.isnan(c1.arg) and c2.arg == 1.0 else None),
+  (UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None)
+]) + extra_pm
+
+class WGSLRenderer(CStyleLanguage):
+  device = "WEBGPU"
+  global_max = (65535, 65535, 65535)
+  local_max = (256, 256, 64)
+  code_for_workitem = {"g": lambda x: f"i32(gindex.{'xyz'[int(x)]})", "l": lambda x: f"i32(lindex.{'xyz'[int(x)]})"}
+  extra_matcher = wgsl_matcher
+  supports_float4 = False
+  barrier = "workgroupBarrier();"
+  code_for_op = {**CStyleLanguage.code_for_op, Ops.WHERE: lambda a,b,c,dtype: f"select({c},{b},{a})"}
+  nan = "nan()"
+  type_map = { dtypes.float: "f32", dtypes.uchar: "u32", dtypes.ushort: "u32", dtypes.short: "i32",
+               dtypes.char: "i32", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool", dtypes.half: "f16" }
+
+  string_rewrite = PatternMatcher([
+    (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "true" if x.arg else "false"),
+    (UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"), lambda ctx,x: f"bitcast<u32>({x.arg})" \
+      if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
+    (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)}, {x.dtype.size}>;"),
+    (UPat(Ops.BITCAST, dtype=dtypes.half, name="x"), lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]" \
+      if x.src[0].dtype in [dtypes.short, dtypes.ushort, dtypes.uint32] else None),
+    (UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
+    (UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"),lambda ctx,x:f"bitcast<{ctx.type_map[x.dtype]}>(vec2<f16>({ctx[x.src[0]]},0))" \
+      if x.src[0].dtype == dtypes.half else f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
+    (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
+    (UPat.load(UPat.var("b"),UPat.var("v"),UPat.var("g")),lambda ctx,b,v,g:f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[g]})"),
+    (UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.src[0].dtype)),
+    (UPat.index(UPat.var("b"), UPat.var("idx")), lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
+    (UPat.store(UPat.var('b'), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
+      # (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
+      f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
+      else f"{ctx[b]} = {ctx[v]};"),
+    # fix nan check: 'a != a -> is_nan()'
+    (UPat.var("a") != UPat.var("a"), lambda ctx,a: f"(min({ctx[a]}, 1.0) == 1.0 && max({ctx[a]}, -1.0) == -1.0)"),
+  ]) + base_rewrite
+
+  def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
+  def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
+  def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if is_packed(dt) else x
+  def buf_map(self, dt:DType) -> str: return "atomic<u32>" if is_packed(dt) else self.type_map[dt.base]
+  def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
+    local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
+    if not local_size: local_size = [1]
+    bind_it = iter(range(len(bufs)))
+    external_local_bufs = [line.lstrip() for line in kernel if "var<workgroup>" in line]
+    kernel[:] = [line for line in kernel if "var<workgroup>" not in line]
+    prg = "enable f16;\n" if any(uop.dtype.base == dtypes.half for uop in uops) else ""
+    prg += "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
+    prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
+    prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
+      f"{'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'}" +
+      f"{name}:{f'array<{self.buf_map(dtype.base)}>' if isinstance(dtype,PtrDType) else self.buf_map(dtype)};" for name,(dtype,_) in bufs])
+    prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
+    return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"