tinygrad 0.10.2-py3-none-any.whl → 0.11.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
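
For readers who want to reproduce a comparison like this locally, here is a minimal sketch (an assumption-based illustration, not part of any registry tooling): it expects both wheels to have been downloaded beforehand, e.g. with `pip download tinygrad==0.10.2 --no-deps` and `pip download tinygrad==0.11.0 --no-deps`, and prints a unified diff of the Python sources they contain.

import difflib, zipfile

# Hypothetical local file names; adjust to wherever the wheels were saved.
OLD_WHL = "tinygrad-0.10.2-py3-none-any.whl"
NEW_WHL = "tinygrad-0.11.0-py3-none-any.whl"

def py_sources(path: str) -> dict[str, list[str]]:
    # map of archive member name -> list of source lines (wheels are plain zip files)
    with zipfile.ZipFile(path) as zf:
        return {n: zf.read(n).decode("utf-8", "replace").splitlines(keepends=True)
                for n in zf.namelist() if n.endswith(".py")}

old, new = py_sources(OLD_WHL), py_sources(NEW_WHL)
for name in sorted(old.keys() | new.keys()):
    diff = difflib.unified_diff(old.get(name, []), new.get(name, []),
                                fromfile=f"0.10.2/{name}", tofile=f"0.11.0/{name}")
    print("".join(diff), end="")
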
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
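
Most of the renames above reorganize internals: the UOp machinery moves from tinygrad/ops.py into the new tinygrad/uop/ package, and the kernel optimizer moves under tinygrad/codegen/opt/. A hedged sketch of how import paths shift for code that touched these internal modules (old paths as comments on top, new paths taken from the renames above and the renderer diffs below; these are internals, not the public Tensor API):

# tinygrad 0.10.2 (old internal paths)
# from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, GroupOp
# from tinygrad.codegen.symbolic import ...   # tinygrad/codegen/symbolic.py
# from tinygrad.engine.search import ...      # tinygrad/engine/search.py

# tinygrad 0.11.0 (new internal paths)
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp
# from tinygrad.uop.symbolic import ...       # now tinygrad/uop/symbolic.py
# from tinygrad.codegen.opt.search import ... # now tinygrad/codegen/opt/search.py
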
tinygrad/renderer/llvmir.py CHANGED
@@ -1,17 +1,18 @@
  from typing import cast
  import math, struct, sys
+ from tinygrad.codegen.opt import tc
  from tinygrad.renderer import Renderer
- from tinygrad.renderer.cstyle import ClangRenderer
- from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
+ from tinygrad.renderer.cstyle import AMDRenderer
+ from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
  from tinygrad.dtype import dtypes, DType, PtrDType, truncate
  from tinygrad.helpers import prod, AMX
 
  def ldt(dt:DType):
  if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
  if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
- return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
+ return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
  dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
- dtypes.float16: "half", dtypes.float32: "float", dtypes.float64: "double", dtypes.bool: "i1", dtypes.void: "void"}[dt]
+ dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
 
  def lconst(x, dtype:DType):
  if dtype in dtypes.floats:
@@ -32,7 +33,7 @@ def lcast(input_type:DType, output_type:DType):
  raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
 
  # https://github.com/corsix/amx
- def render_wmma(ctx, wmma: UOp) -> str:
+ def render_wmma_amx(ctx, wmma: UOp) -> str:
  def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1-((0$1>>4)*6))", "i,r,~{{memory}}"(i32 {op}, i64 {gpr}) #0; AMX'
 
  return "\n".join([
@@ -44,25 +45,40 @@ def render_wmma(ctx, wmma: UOp) -> str:
  f' call void asm sideeffect "nop\\0Anop\\0Anop\\0A.word ({0x201000 + (17 << 5) + 1})", "~{{memory}}"() #0; AMX clr', # clr
  f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
 
+ def render_wmma_amd(ctx, wmma: UOp, arch: str) -> str:
+ dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.bfloat16: "bf16", dtypes.ushort: "bf16"}
+ # https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl
+ if arch.split(":")[0] in {"gfx942", "gfx950"}:
+ return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
+ f".16x16x16{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
+ # https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
+ # example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101)
+ return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
+ f"{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + (", i1 false)" \
+ if wmma.dtype.scalar() != dtypes.float else ")")
+
  # llvm ops, lop[<dtype>][<op>]
  unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
- Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
- signed_lop = {**unsigned_lop, Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"}
+ Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.CMPEQ: "icmp eq", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor",}
+ signed_lop = {**unsigned_lop, Ops.ADD: "add nsw", Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"}
  flags = " nsz arcp contract afn"
- float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags}
+ float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult",
+ Ops.CMPNE: f"fcmp{flags} une", Ops.CMPEQ: f"fcmp{flags} oeq", Ops.FDIV: "fdiv"+flags}
  lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}
 
  base_rewrite = PatternMatcher([
  # memory load/store
  (UPat(Ops.INDEX, name="x"), lambda ctx,x:
  f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
- (UPat(Ops.LOAD, src=(UPat.var('idx'), UPat.var('alt'), UPat.var('mask')), name="x"), lambda ctx,x,idx,alt,mask:
+ (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("mask"))).or_casted("idx"), UPat.var("alt")), name="x"),
+ lambda ctx,x,idx,alt,mask:
  f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
  f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
  f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
  f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
  f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
- (UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
+ (UPat(Ops.LOAD, src=(UPat.var('idx'),), allow_any_len=True, name="x"),
+ lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
  (UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
 
  # GEP/VECTORIZE/CAST for float4 support
@@ -73,12 +89,11 @@ base_rewrite = PatternMatcher([
  (UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: "\n".join([(f" {ctx[x]}_{i}" if i+1 != len(x.src) else f" {ctx[x]}")+
  f" = insertelement {ldt(x.dtype)} "+(f"{ctx[x]}_{i-1}" if i != 0 else "poison")+
  f", {ldt(u.dtype)} {ctx[u]}, i32 {i}" for i,u in enumerate(x.src)])),
- (UPat(Ops.CAST, name="x"), lambda ctx,x:
- f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None),
-
  # unary/binary/ternary ops
  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
  (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
+ (UPat(Ops.TRUNC, name="x"),
+ lambda ctx,x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.trunc.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
  (UPat(GroupOp.Binary, name="x"), lambda ctx,x:
  f" {ctx[x]} = {lop[x.src[0].dtype.scalar()][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
  (UPat(Ops.WHERE, name="x"), lambda ctx,x:
@@ -88,33 +103,27 @@ base_rewrite = PatternMatcher([
  (UPat(Ops.RANGE, name="x"), lambda ctx,x:
  f" br label %loop_entry_{x.arg}\nloop_entry_{x.arg}:\n"
  f" br label %loop_body_{x.arg}\nloop_body_{x.arg}:\n"
- f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg}], [{ctx[x]}phi, %loop_latch_{x.arg}]"),
+ f" {ctx[x]} = phi {ldt(x.dtype)} [ 0, %loop_entry_{x.arg} ], [ {ctx[x]}phi, %loop_latch_{x.arg} ]"),
  (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
  f" br label %loop_latch_{x.src[0].arg}\nloop_latch_{x.src[0].arg}:\n"
- f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
+ f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[0]]}\n"
  f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg}, label %loop_exit_{x.src[0].arg}\nloop_exit_{x.src[0].arg}:"),
 
  # if
  (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
  (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
 
- # wmma
- (UPat(Ops.WMMA, name="wmma"), render_wmma),
+ (UPat(Ops.BARRIER), lambda ctx: "")
  ])
 
- def llvm_bf16_cast(buf:UOp, idx:UOp, root:UOp):
- u16_buf = buf.replace(dtype=dtypes.ushort.ptr(size=cast(PtrDType,buf.dtype).size))
- return UOp.load(UOp.index(u16_buf, idx), dtype=dtypes.ushort).cast(dtypes.uint).mul(1<<16).bitcast(dtypes.float32).cast(root.dtype)
-
  class LLVMRenderer(Renderer):
  device = "LLVM"
  abi = 'win64cc' if sys.platform == 'win32' else None
  supports_float4 = True
  has_local = False
- has_shared = False
- global_max = None
- string_rewrite = base_rewrite
- if AMX: tensor_cores = ClangRenderer.amx_tc
+ global_max: tuple[int, ...] | None = None
+ string_rewrite = base_rewrite + PatternMatcher([(UPat(Ops.WMMA, name="wmma"), render_wmma_amx)])
+ if AMX: tensor_cores = tc.amx
 
  extra_matcher = PatternMatcher([
  # rewrite RECIP with FDIV
@@ -123,25 +132,30 @@ class LLVMRenderer(Renderer):
  (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
  # rewrite MAX to CMPLT + WHERE
  (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
- # rewrite bf16 CAST(LOAD) to CAST(BITCAST)
- (UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
+ # copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16
+ (UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
+ lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
+ # copied from cstyle.py, add float intermediate casting
+ (UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
+ (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
  ])
 
- def render(self, uops: list[UOp]) -> str:
+ def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer(uops)))
+ def _render_footer(self, uops: list[UOp]) -> str: return 'attributes #0 = { alwaysinline nounwind "no-builtins" "no-trapping-math"="true" }'
+ def _render_fn(self, name:str, args:list[tuple[str,DType]], kernel:list[str], prefix:list[str]|None=None) -> str:
+ # NOTE: CPUAllocator promises 0x20 alignment
+ sargs = ", ".join([f"{ldt(dt)}{' noalias align 32' if isinstance(dt, PtrDType) else ''} {name}" for name,dt in args])
+ sprefix = "".join([f" {x}" for x in (prefix or []) + [self.abi] if x is not None])
+ return "\n".join([f"define{sprefix} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"])
+ def _render_kernel(self, uops: list[UOp], prefix:list[str]|None=None) -> tuple[tuple[str, ...], str]:
  r: dict[UOp, str] = {}
- args: list[str] = []
+ args: list[tuple[str, DType]] = []
  kernel: list[str] = []
- end_lines: dict[str, None] = {}
  vc = -1
 
- acc_to_assign: dict[UOp, UOp] = {}
+ local_args: list[str] = []
  for u in uops:
- if u.op is Ops.ASSIGN: # prealloc all assigns
- vc += 1
- r[u] = r[u.src[1]] = f"%assign{vc}"
- assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
- acc_to_assign[u.src[0]] = u.src[1]
- if u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
+ if AMX and u.op is Ops.WMMA: # prealloc aux buffers as AMX can only load from memory
  vc += 1
  r[u] = f"%wmma{vc}"
  for i, dtype in enumerate(u.arg[2].vec(sz) for sz in [prod(size for _, size in upcast) for upcast in u.arg[6]]):
@@ -150,17 +164,24 @@ class LLVMRenderer(Renderer):
 
  name = "test"
  for u in uops:
- if u.op is Ops.NAME:
- name = u.arg
+ if u.op is Ops.NOOP: continue
+ if u.op is Ops.SINK:
+ if u.arg is not None: name = u.arg.function_name
  continue
  if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
  r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
- # NOTE: MallocAllocator promises 0x20 alignment
- args.append(f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
- elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
- elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
+ args.append((r[u], u.dtype))
+ elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG):
+ r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"
+ assert isinstance(u.dtype, PtrDType)
+ if self.device == "LLVM" or u.op is Ops.DEFINE_REG:
+ kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}]")
+ else:
+ local_args.append(f"@{r[u][1:]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {ldt(u.dtype)}] undef, align 16")
+ kernel.append(f" {r[u]} = addrspacecast [{u.dtype.size} x {ldt(u.dtype)}] addrspace(3)* @{r[u][1:]} to [{u.dtype.size} x {ldt(u.dtype)}]*")
  elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
- elif u.op is Ops.CAST and ldt(u.dtype) == ldt(u.src[0].dtype): r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop
+ elif u.op is Ops.CAST and (ldt(u.dtype) == ldt(u.src[0].dtype) or isinstance(u.dtype, PtrDType)):
+ r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop, or pointer cast
  else:
  # if it's an assign target, it's already preallocated
  if u not in r:
@@ -171,21 +192,51 @@ class LLVMRenderer(Renderer):
  if (l:=self.string_rewrite.rewrite(u, ctx=r)) is None:
  raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
  kernel.append(cast(str, l))
-
- # generate the phi nodes for the assigns
- if u.op is Ops.RANGE:
- for x in acc_to_assign:
- if u in x.src: # if this range is relevant for this acc
- vc += 1
- kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg}]")
- r[x] = f"%acc{vc}"
-
- # output the function. chr(10) is '\n' (python < 3.12 doesn't support backslashes in f-strings)
- return f'''\
- define{(' '+self.abi) if self.abi is not None else ''} void @{name}({','.join(args)}) #0 {{
- {chr(10).join(kernel)}
- ret void
- }}
- {chr(10).join(end_lines.keys())}
- attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }}
- '''
+ return tuple(local_args), self._render_fn(name, args, kernel, prefix)
+
+ barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
+ code_for_workitem = {"g": lambda x: f"tail call i32 @llvm.amdgcn.workgroup.id.{chr(120+int(x))}()",
+ "l": lambda x: f"tail call i32 @llvm.amdgcn.workitem.id.{chr(120+int(x))}()"}
+ class AMDLLVMRenderer(LLVMRenderer):
+ device = "AMD"
+ has_local = True
+ shared_max = AMDRenderer.shared_max
+ global_max = AMDRenderer.global_max
+ abi = "amdgpu_kernel"
+ string_rewrite = PatternMatcher([
+ (UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f" {ctx[x]} = " + f"{ code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; "),
+ (UPat(Ops.BARRIER), lambda ctx: barrier),
+ ]) + base_rewrite
+ extra_matcher = LLVMRenderer.extra_matcher + PatternMatcher([
+ (UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))),
+ lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
+ (UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))),
+ lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(8), tuple(y.gep(i * 2) for i in range(8)))),
+ ])
+ def _render_footer(self, uops: list[UOp]) -> str:
+ # TODO: this is copied from cstyle
+ requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
+ attributes = ["alwaysinline", "nounwind", '"no-builtins"',
+ f'"amdgpu-flat-work-group-size"="1,{requiredMaxThreadsPerBlock}"', '"no-trapping-math"="true"']
+ return 'attributes #0 = { ' + ' '.join(attributes) + ' }'
+ def __init__(self, arch:str):
+ self.arch = arch
+ self.tensor_cores = AMDRenderer.get_tensor_cores(arch)
+ self.string_rewrite += PatternMatcher([(UPat(Ops.WMMA, name="wmma"), lambda ctx, wmma, arch=arch: render_wmma_amd(ctx, wmma, arch))])
+ if self.arch.split(":")[0] == "gfx1100":
+ self.extra_matcher += PatternMatcher([
+ (UPat(Ops.WMMA, name="x", dtype=dtypes.half.vec(8)),
+ lambda x: UOp(Ops.WMMA, dtypes.half.vec(16), (x.src[0], x.src[1], x.src[2].cast(dtypes.half.vec(16))), (*x.arg,)).cast(dtypes.half.vec(8))),
+ (UPat(Ops.WMMA, name="x"), lambda x: UOp(Ops.WMMA, x.dtype, (x.src[0].bitcast(dtypes.uint16.vec(16)), x.src[1].bitcast(dtypes.uint16.vec(16)),
+ x.src[2]), x.arg) if x.src[0].dtype == dtypes.bfloat16.vec(16) else None),
+ ])
+ if self.arch.split(":")[0] == "gfx1201":
+ self.extra_matcher += PatternMatcher([
+ (UPat(Ops.WMMA, name="x", dtype=dtypes.bfloat16.vec(8)), lambda x: UOp(Ops.WMMA, dtypes.uint16.vec(8),
+ (x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)), x.src[2].bitcast(dtypes.uint16.vec(8))), (*x.arg,))
+ .bitcast(dtypes.bfloat16.vec(8)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None),
+ (UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(8)),
+ lambda x: UOp(Ops.WMMA, dtypes.float.vec(8), (x.src[0].bitcast(dtypes.uint16.vec(8)), x.src[1].bitcast(dtypes.uint16.vec(8)),
+ x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(8) else None)
+ ])
+ def __reduce__(self): return self.__class__, (self.arch,)
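
The hunks above add an AMD-specific LLVM renderer whose constructor takes the GPU arch string and selects MFMA/WMMA rewrite patterns from it. A hedged usage sketch (assuming this file is tinygrad/renderer/llvmir.py, per item 32 in the file list, and that a 0.11.0 install is available):

from tinygrad.renderer.llvmir import AMDLLVMRenderer

# arch strings as used in the diff: "gfx1100", "gfx1201", "gfx942", ...
renderer = AMDLLVMRenderer("gfx1100")
print(renderer.device, renderer.abi)  # per the class attributes above: AMD amdgpu_kernel
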
tinygrad/renderer/ptx.py CHANGED
@@ -1,11 +1,12 @@
  from typing import cast, Callable
  import struct
  from collections import defaultdict
- from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
- from tinygrad.dtype import dtypes, DType, PtrDType
+ from tinygrad.codegen.opt import tc
+ from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
+ from tinygrad.dtype import dtypes, DType, PtrDType, AddrSpace
  from tinygrad.renderer import Renderer
  from tinygrad.renderer.cstyle import CUDARenderer
- from tinygrad.helpers import flatten, get_single_element
+ from tinygrad.helpers import flatten, get_single_element, prod
 
  def render_val(x, dtype):
  if dtypes.is_float(dtype):
@@ -18,6 +19,7 @@ asm_for_op: dict[Ops, Callable] = {
  Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
  Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
  Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
+ Ops.TRUNC: lambda d,a,dt,name: f"cvt.rzi.{name}.{name} {d}, {a};",
  Ops.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", Ops.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
  Ops.ADD: lambda d,a,b,dt,name: f"{'or' if dt == dtypes.bool else 'add'}.{name} {d}, {a}, {b};",
  Ops.MUL: lambda d,a,b,dt,name: f"{'and' if dt == dtypes.bool else 'mul'}{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
@@ -25,18 +27,19 @@ asm_for_op: dict[Ops, Callable] = {
  Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if dt == dtypes.bool else f"and.b{name[1:]} {d}, {a}, {b};",
  Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if dt == dtypes.bool else f"or.b{name[1:]} {d}, {a}, {b};",
  Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
- Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};",
+ Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", Ops.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};",
  Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
  Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
  Ops.WHERE: lambda d,a,b,c,dt,name: [f"@{a} mov.{name} {d}, {b};", f"@!{a} mov.{name} {d}, {c};"] if dt == dtypes.bool else \
  f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
  }
 
- supports_half = (Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE)
+ supports_half = (Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE, Ops.TRUNC)
  doesnt_support_half: tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
  ptx_matcher = PatternMatcher([
  # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
  (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
+ (UPat.var('x', dtype=dtypes.bool).alu(Ops.CMPEQ, UPat.var('y')), lambda x,y: (x^y)^True),
  (UPat.var('x', dtype=dtypes.bool)<UPat.var('y'), lambda x,y: (x^True)&y),
  # upcast to float32 all the ops that don't support half
  (UPat(doesnt_support_half, dtype=dtypes.half, name="x"),
@@ -47,14 +50,20 @@ ptx_matcher = PatternMatcher([
  (UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
  lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
  # load/store use pointer arithmetic, and the cast does nothing
- (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
- (UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
+ (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))),
+ lambda buf,idx: (buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize) if buf.dtype.addrspace != AddrSpace.REG else None),
+ (UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) or x.src[0].dtype == dtypes.void else None),
+ # move mask from INDEX to the load/store to enable pointer arithmetic
+ (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("gate"))), UPat.var("alt"))),
+ lambda buf,idx,gate,alt: UOp(Ops.LOAD, alt.dtype, (buf.index(idx), alt, gate))),
+ (UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat())), UPat.var("val"), UPat.var("gate")), allow_any_len=True),
+ lambda buf,idx,val,gate: UOp.store(buf.index(idx), val, gate)),
  # ptx shr and shl instructions require y to be uint
  (UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
  (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
  ])
 
- def mem_type(x: UOp): return 'shared' if any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].toposort) else 'global'
+ def mem_type(x: UOp): return 'shared' if any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].toposort()) else 'global'
 
  def render_wmma(ctx: "PTXRenderer", wmma: UOp):
  assert ctx.wmma_r, "registry values for wmma must be populated"
@@ -84,7 +93,7 @@ string_rewrite = PatternMatcher([
  f"[{ctx.r[bidx]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),
  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
- (UPat((Ops.CMPLT, Ops.CMPNE), name="x", allow_any_len=True, src=(UPat.var("src0"),)),
+ (UPat((Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ), name="x", allow_any_len=True, src=(UPat.var("src0"),)),
  lambda ctx, x, src0: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], src0.dtype, ctx.types[src0.dtype])),
  (UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
  (UPat(Ops.BITCAST, name="x", src=(UPat.var("a"),), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
@@ -103,19 +112,14 @@ string_rewrite = PatternMatcher([
  (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'),), allow_any_len=True),
  lambda ctx, x, loc: f"ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
  if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
- (UPat(Ops.DEFINE_ACC, name="x", src=(UPat.cvar("pred", dtype=dtypes.bool),), allow_any_len=True), lambda ctx, x, pred: [
- f"setp.ne.s16 {ctx.r[pred]}, {render_val(pred.arg, pred.dtype)}, 0;", f"mov.pred {ctx.r[x]}, {ctx.r[pred]};"]),
- (UPat(Ops.DEFINE_ACC, name="x", src=(UPat.cvar("pred"),), allow_any_len=True),
- lambda ctx, x, pred: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(pred.arg, x.dtype)};"),
- (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, {ctx.r[x.src[0]]};", "LOOP_" + f"{ctx.r[x][1:]}:"]),
- (UPat(Ops.ASSIGN, name="x", dtype=dtypes.bool), lambda ctx, x: [f"mov.pred {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"]),
- (UPat(Ops.ASSIGN, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"),
+ (UPat(Ops.DEFINE_REG, src=()), lambda ctx: []),
+ (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, 0;", "LOOP_" + f"{ctx.r[x][1:]}:"]),
  (UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
  ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
- ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[1]], dtypes.int, ctx.types[dtypes.int]),
+ ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]),
  f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
  (UPat(Ops.DEFINE_LOCAL, name="x"),
- lambda ctx, x: [f".shared .align 4 .b8 {x.arg}[{x.dtype.size*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg}[0];"]),
+ lambda ctx, x: [f".shared .align 16 .b8 {x.arg}[{x.dtype.size*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg}[0];"]),
  (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
  (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
  (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
@@ -127,12 +131,12 @@ class PTXRenderer(Renderer):
  device = "CUDA"
  suffix = "PTX"
  global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
- tc_sm80 = [tc for tc in CUDARenderer.tc_sm80 if tc.dtype_in in [dtypes.half, dtypes.float]]
+ tc_sm80 = [x for x in tc.cuda_sm80 if x.dtype_in in [dtypes.half, dtypes.float]]
  code_for_op = asm_for_op
  extra_matcher = ptx_matcher
  def __init__(self, arch:str, device="CUDA"):
  self.device, self.arch = device, arch
- self.tensor_cores = PTXRenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else []
+ self.tensor_cores = PTXRenderer.tc_sm80 if int(arch[3:]) >= 80 else tc.cuda_sm75 if int(arch[3:]) >= 75 else []
  def __reduce__(self): return self.__class__, (self.arch, self.device)
 
  # language options
@@ -148,11 +152,12 @@ class PTXRenderer(Renderer):
 
  mem_types: dict[DType, str] = {**types, dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"}
 
- def render_kernel(self, kernel, function_name, bufs, regs) -> str:
+ def render_kernel(self, kernel, function_name, bufs, regs, uops) -> str:
  def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
  kernel = '\n'.join(map(fmt, [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]))
+ launch_bounds = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
  params = ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs])
- return f"{self.kernel_prefix} {function_name}(\n\t{params}\n)\n{{\n{kernel}\n}}"
+ return f"{self.kernel_prefix.format(launch_bounds=launch_bounds)} {function_name} (\n\t{params}\n)\n.maxntid {launch_bounds}\n{{\n{kernel}\n}}"
 
  def render(self, uops:list[UOp]) -> str:
  kernel:list[str] = []
@@ -165,14 +170,15 @@ class PTXRenderer(Renderer):
 
  def ssa(prefix:str, u:UOp|None=None, dtype:str|None=None) -> str:
  nonlocal c, r
- prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype]}_"
+ prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype.base]}_"
  c[prefix] += 1
  return f"%{prefix}{c[prefix]-1}"
 
  name = "test"
  for u in uops:
- if u.op is Ops.NAME:
- name = u.arg
+ if u.op is Ops.NOOP: continue
+ if u.op is Ops.SINK:
+ if u.arg is not None: name = u.arg.function_name
  continue
  if u.op is Ops.VECTORIZE:
  r[u] = [cast(str,r[x]) for x in u.src]
@@ -183,6 +189,19 @@ class PTXRenderer(Renderer):
  if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)):
  r[u] = r[u.src[0]]
  continue
+ if u.op is Ops.DEFINE_REG:
+ r[u] = [ssa("reg", u, self.types[u.dtype.base.scalar()]) for _ in range(cast(PtrDType, u.dtype).size)]
+ continue
+ if u.op in {Ops.INDEX, Ops.LOAD, Ops.STORE} and isinstance(u.src[0].dtype, PtrDType) and u.src[0].dtype.addrspace == AddrSpace.REG:
+ if u.op is Ops.INDEX:
+ assert u.src[1].op == Ops.CONST, f"index on REG in ptx only supported on CONST, not {u.src[1].op}"
+ r[u] = r[u.src[0]][u.src[1].arg]
+ else:
+ r[u] = r[u.src[0]]
+ if u.op is Ops.STORE:
+ typ = "pred" if u.src[1].dtype == dtypes.bool else ("b"+self.types[u.src[1].dtype][1:])
+ kernel.append(f"mov.{typ} {self.r[u.src[0]]}, {self.r[u.src[1]]};")
+ continue
  if u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
  elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype))
  elif u.op is Ops.LOAD:
@@ -191,12 +210,12 @@ class PTXRenderer(Renderer):
  elif u.op is Ops.DEFINE_GLOBAL: bufs.append((f"data{u.arg}", u.dtype))
  elif u.op is Ops.WMMA:
  # registers for packing/unpacking input and acc
- self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.arg[2].itemsize)],
- [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.arg[2].itemsize)],
- [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.arg[3].itemsize)]]
+ self.wmma_r = [[ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[0]]), 4 // u.src[0].dtype.scalar().itemsize)],
+ [ssa("wmma_in", dtype="b32") for _ in range(0, len(r[u.src[1]]), 4 // u.src[0].dtype.scalar().itemsize)],
+ [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
  r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
  prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.ENDRANGE: ("pred", "pred"), Ops.RANGE: ("ridx", None),
- Ops.DEFINE_ACC: ("acc", None), Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL:("local",self.types[dtypes.ulong]),
+ Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL:("local",self.types[dtypes.ulong]),
  Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
  if prefix: r[u] = ssa(prefix, u, dtype)
 
@@ -204,6 +223,5 @@ class PTXRenderer(Renderer):
  raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
  kernel.extend([l] if isinstance(l, str) else l)
 
- if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
- elif u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel
- return self.render_kernel(kernel, name, bufs, c.items())
+ if u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel
+ return self.render_kernel(kernel, name, bufs, c.items(), uops)
tinygrad/renderer/wgsl.py CHANGED
@@ -1,8 +1,7 @@
- from tinygrad.dtype import DType, PtrDType, dtypes
- from tinygrad.ops import UOp, Ops, PatternMatcher, UPat
+ from tinygrad.dtype import DType, PtrDType, dtypes, AddrSpace
+ from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
  from tinygrad.renderer.cstyle import CStyleLanguage, base_rewrite, extra_pm
  from tinygrad.helpers import strip_parens
- import math
 
  def sign_extend(val:UOp, sext_am:int):
  return (UOp.where((val >> (sext_am - 1)) > 0, UOp.const(dtypes.uint32, 0xffffffff) << sext_am, UOp.const(dtypes.uint32, 0)) \
@@ -20,24 +19,25 @@ def packed_store(bidx:UOp, var:UOp):
  def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
  div_idx = bidx.src[1]//(4//dtype.itemsize)
  shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//dtype.itemsize))*UOp.const(dtypes.uint32, 8*dtype.itemsize)
- if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), var, root.src[2], dtype=dtypes.uint32, arg=root.arg)
+ if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx, bidx.src[2])), var, dtype=dtypes.uint32, arg=root.arg)
  else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=dtypes.uint32, arg=root.arg)
  val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF)
  return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
 
- def is_packed(dt:DType) -> bool: return dt.itemsize < 4 and dt.base != dtypes.half
+ def is_packed(dt:DType, odt:DType|None = None) -> bool:
+ if odt is None: odt = dt
+ return dt.itemsize < 4 and dt.base != dtypes.half and (not isinstance(odt, PtrDType) or odt.addrspace != AddrSpace.REG)
 
  wgsl_matcher = PatternMatcher([
  (UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
  lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
- (UPat(Ops.LOAD, name="l", src=(UPat.var('b'),)), lambda l,b: packed_load(l,b,l.dtype) if is_packed(l.dtype) else None),
- (UPat(Ops.LOAD, name="l", src=(UPat.var('b'), UPat.var('c'), UPat())),
- lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype) else None),
- (UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype) else None),
- # TODO: why is this needed, and only for this MUL order
- (UPat(Ops.MUL, src=(UPat.var("a"), UPat.var("g").where(UPat.cvar("c1"), UPat.cvar("c2")))),
- lambda a,g,c1,c2: g.where(c1, a) if math.isnan(c1.arg) and c2.arg == 1.0 else None),
- (UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None)
+ (UPat.load(UPat.var("b"), UPat.cvar("c"), name="l"),
+ lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype, b.dtype) else None),
+ (UPat.load(UPat.var("b"), name='l', allow_any_len=True), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype, b.dtype) else None),
+ (UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True),
+ lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype, bidx.dtype) else None),
+ (UPat.var("a") << UPat.var("b"),lambda a,b:(a.bitcast(dtypes.uint32)<<b.cast(dtypes.uint32)).bitcast(a.dtype) if b.dtype!=dtypes.uint32 else None),
+ (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
  ]) + extra_pm
 
  class WGSLRenderer(CStyleLanguage):
@@ -54,23 +54,27 @@ class WGSLRenderer(CStyleLanguage):
  dtypes.char: "i32", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool", dtypes.half: "f16" }
 
  string_rewrite = PatternMatcher([
- (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "true" if x.arg else "false"),
- (UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"), lambda ctx,x: f"bitcast<u32>({x.arg})" \
- if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
- (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)}, {x.dtype.size}>;"),
- (UPat(Ops.BITCAST, dtype=dtypes.half, name="x"), lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]" \
- if x.src[0].dtype in [dtypes.short, dtypes.ushort, dtypes.uint32] else None),
+ (UPat.cvar("x", dtype=dtypes.bool), lambda x: "true" if x.arg else "false"),
+ (UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),
+ lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
+ (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x:
+ f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{x.dtype.size//(4//x.dtype.itemsize) if is_packed(x.dtype) else x.dtype.size}>;"),
+ (UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:
+ f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{x.dtype.size//(4//x.dtype.itemsize) if is_packed(x.dtype) else x.dtype.size}>;"),
+ (UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
+ lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"),
  (UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
  (UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"),lambda ctx,x:f"bitcast<{ctx.type_map[x.dtype]}>(vec2<f16>({ctx[x.src[0]]},0))" \
  if x.src[0].dtype == dtypes.half else f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
- (UPat.load(UPat.var("b"),UPat.var("v"),UPat.var("g")),lambda ctx,b,v,g:f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[g]})"),
- (UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.src[0].dtype)),
- (UPat.index(UPat.var("b"), UPat.var("idx")), lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
- (UPat.store(UPat.var('b'), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
+ (UPat.load(UPat.var("b"), UPat.cvar("v")),lambda ctx,b,v: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[b.src[2]]})"),
+ (UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
+ (UPat.store(UPat.var("b"), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
  # (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
  f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
  else f"{ctx[b]} = {ctx[v]};"),
+ (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx")), allow_any_len=True),
+ lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
  # fix nan check: 'a != a -> is_nan()'
  (UPat.var("a") != UPat.var("a"), lambda ctx,a: f"(min({ctx[a]}, 1.0) == 1.0 && max({ctx[a]}, -1.0) == -1.0)"),
  ]) + base_rewrite