tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/renderer/llvmir.py
@@ -1,69 +1,77 @@
- from typing import Final, Dict, Callable, Any, List, Optional
- from llvmlite import ir
- from tinygrad.dtype import DType, PtrDType, dtypes
- from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
- from tinygrad.codegen.uops import UOps, UOp, UOpGraph
+ from typing import List, Dict, cast
+ import math, struct
  from tinygrad.renderer import Renderer
-
- MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
-
- def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
-
- code_for_op: Final[Dict[Op, Callable]] = {
- UnaryOps.NEG: lambda builder, x, dtype: builder.neg(x) if dtypes.is_int(dtype) else \
- (builder.not_(x) if dtype == dtypes.bool else builder.fneg(x, flags=MFLAGS)),
- UnaryOps.EXP2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
- UnaryOps.LOG2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
- UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
- UnaryOps.SIN: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
- UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
- BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
- BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
- BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
- BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
- BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
- BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
- BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
- BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
- TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
-
- dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
- dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),
- dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType() }
-
- def cast(bb, val, input_type, output_type, bitcast=False):
- if input_type == output_type: return val
- llvm_type = dtype_to_llvm_dtype[output_type]
- if bitcast: return bb[-1].bitcast(val, llvm_type)
-
- if input_type == dtypes.bfloat16:
- val = bb[-1].bitcast(bb[-1].shl(bb[-1].sext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)),val, ir.FloatType())
- input_type = dtypes.float32
- if output_type == dtypes.bfloat16:
- val = cast(bb, val, input_type, dtypes.float32)
- return bb[-1].trunc(bb[-1].lshr(bb[-1].bitcast(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.IntType(16))
-
+ from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
+ from tinygrad.dtype import dtypes, DType, PtrDType, truncate
+
+ def ldt(dt:DType):
+ if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
+ return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
+ dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
+ dtypes.float16: "half", dtypes.float32: "float", dtypes.float64: "double", dtypes.bool: "i1", dtypes.void: "void"}[dt]
+
+ def lconst(x, dtype:DType):
+ if dtype in dtypes.floats:
+ if math.isinf(x) or math.isnan(x): return "0x%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
+ return truncate[dtype](x)
+ return int(x)
+
+ def lcast(input_type:DType, output_type:DType):
  if dtypes.is_float(input_type):
- if dtypes.is_float(output_type):
- return bb[-1].fpext(val, llvm_type) if output_type.itemsize > input_type.itemsize else bb[-1].fptrunc(val, llvm_type)
- if dtypes.is_int(output_type): return bb[-1].fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else bb[-1].fptosi(val, llvm_type)
- if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0))
-
+ if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc'
+ if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi'
  if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
- if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].uitofp(val, ir.FloatType()), ir.HalfType())
- if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type])
- if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].zext(val, llvm_type)
- if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0))
-
+ if dtypes.is_float(output_type): return 'uitofp'
+ if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext'
  if dtypes.is_int(input_type):
- if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].sitofp(val, ir.FloatType()), ir.HalfType())
- if dtypes.is_float(output_type): return bb[-1].sitofp(val, llvm_type)
- if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].sext(val, llvm_type)
- if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0))
-
+ if dtypes.is_float(output_type): return 'sitofp'
+ if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
  raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
 
- def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
+ # llvm ops, lop[<dtype>][<op>]
+ unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
+ Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
+ signed_lop = {**unsigned_lop, Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"}
+ flags = " nsz arcp contract afn"
+ float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags}
+ lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}
+
+ llvm_rewrite = PatternMatcher([
+ # memory load/store
+ (UPat(Ops.INDEX, name="x"), lambda ctx,x:
+ f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
+ (UPat(Ops.LOAD, src=(UPat.var('idx'), UPat.var('alt'), UPat.var('mask')), name="x"), lambda ctx,x,idx,alt,mask:
+ f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
+ f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
+ f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
+ f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
+ f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
+ (UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
+ (UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
+
+ # unary/binary/ternary ops
+ (UPat(Ops.SQRT, name="x"), lambda ctx,x:
+ f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
+ (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
+ (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
+ (UPat(GroupOp.Binary, name="x"), lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
+ (UPat(Ops.WHERE, name="x"), lambda ctx,x:
+ f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
+
+ # range
+ (UPat(Ops.RANGE, name="x"), lambda ctx,x:
+ f" br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n"
+ f" br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n"
+ f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"),
+ (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
+ f" br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n"
+ f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
+ f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"),
+
+ # if
+ (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
+ (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
+ ])
 
  class LLVMRenderer(Renderer):
  device = "LLVM"
@@ -72,89 +80,63 @@ class LLVMRenderer(Renderer):
  has_shared = False
  global_max = None
 
- def render(self, name:str, uops:UOpGraph) -> str:
- # all llvm stuff goes into a module
- module = ir.Module(name=__file__)
-
- # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
- buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
- buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
-
- # create llvm function
- func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
- func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
- for a in func.args:
- if a.type.is_pointer: a.add_attribute("noalias")
-
- # add the function attribute "no-nans-fp-math"="true", which informs llvm that it allowed to use vectorization optimizations
- func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
- func.attributes.add('"no-nans-fp-math"="true"')
-
- bb = [ir.IRBuilder(func.append_basic_block("entry"))]
- loop_blocks: List = []
- reduce_phis: List = []
- # TODO: newvar probably shouldn't be optional
- lvars: Dict[Optional[UOp], Any] = {} # this Any is an llvm type
-
- for bufname,dtype in buf_to_dtype.items():
- if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
+ extra_matcher = PatternMatcher([
+ # rewrite RECIP with FDIV
+ (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
+ # rewrite cast to bool to CMPNE 0
+ (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
+ # *** also in cstyle ***
+ # gate any stores that aren't gated with ifs
+ (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+ lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
+ # rewrite MAX to CMPLT + WHERE
+ (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+ ])
+
+ def render(self, name: str, uops: List[UOp]) -> str:
+ r: Dict[UOp, str] = {}
+ args: List[str] = []
+ kernel: List[str] = []
+ end_lines: Dict[str, None] = {}
+ vc = -1
+
+ # prealloc all assigns
+ acc_to_assign: Dict[UOp, UOp] = {}
+ for u in uops:
+ if u.op is Ops.ASSIGN:
+ vc += 1
+ r[u] = r[u.src[1]] = f"%assign{vc}"
+ assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
+ acc_to_assign[u.src[0]] = u.src[1]
 
  for u in uops:
- uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
- if uop is UOps.STORE:
- element = cast(bb, lvars[src[2]], src[2].dtype, src[0].dtype)
- if len(src) > 3:
- with bb[-1].if_then(lvars[src[3]]):
- bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
- else:
- bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
- elif uop is UOps.ENDRANGE:
- loop_entry_bb, phis = loop_blocks.pop()
- idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
- lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
- for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
- bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
- bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
+ # hack for defining sqrt function (TODO: can we get a transcendental for this?)
+ if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
+
+ if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
+ r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
+ args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
+ elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
+ elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
+ elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
+ elif u.op is Ops.CAST and ldt(u.dtype) == ldt(u.src[0].dtype): r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop
  else:
- assert dtype is not None, f"None dtype for uop {uop}"
- if uop is UOps.RANGE:
- bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
- bb[-2].branch(bb[-1].block)
-
- phis = []
- for rp in reduce_phis:
- incoming = lvars[rp]
- lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
- lvars[rp].add_incoming(incoming, bb[-2].block)
- phis.append((rp, lvars[rp]))
-
- lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
- lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
- loop_blocks.append((bb[-1].block, phis))
- elif uop is UOps.DEFINE_ACC:
- lvars[u] = const(src[0].arg, dtype)
- reduce_phis.append(u)
- elif uop is UOps.LOAD:
- if len(src) > 2:
- aug_idx = bb[-1].select(lvars[src[2]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
- val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
- val = bb[-1].select(lvars[src[2]], val, lvars[src[3]])
- else:
- val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
- lvars[u] = val
- elif uop is UOps.PHI:
- lvars[u] = lvars[src[1]]
- # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
- backward = src[0]
- while backward.op is UOps.PHI: backward = backward.src[0]
- lvars[backward] = lvars[u]
- elif uop is UOps.ALU:
- lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else src[0].dtype)
- elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
- elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
- elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
- elif uop is UOps.CONST: lvars[u] = const(args, dtype)
- else: raise RuntimeError(f"failed to render {uop}")
-
- bb[-1].ret_void()
- return str(module)
+ # if it's an assign target, it's already preallocated
+ if u not in r:
+ vc += 1
+ r[u] = f"%v{vc}"
+
+ # do the rendering of the llvm ir code
+ if (l:=llvm_rewrite.rewrite(u, ctx=r)) is None: raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
+ kernel.append(cast(str, l))
+
+ # generate the phi nodes for the assigns
+ if u.op is Ops.RANGE:
+ for x in acc_to_assign:
+ if u in x.src: # if this range is relevent for this acc
+ vc += 1
+ kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]")
+ r[x] = f"%acc{vc}"
+
+ # output the function
+ return f"define void @{name}({','.join(args)}) {{\n" + '\n'.join(kernel) + "\n ret void\n}\n"+'\n'.join(end_lines.keys())
tinygrad/renderer/ptx.py (new file)
@@ -0,0 +1,225 @@
+ from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tuple
+ import struct
+ from collections import defaultdict
+ from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
+ from tinygrad.dtype import dtypes, DType, PtrDType
+ from tinygrad.renderer import Renderer
+ from tinygrad.renderer.cstyle import CUDARenderer
+ from tinygrad.helpers import prod, flatten
+
+ def render_val(x, dtype):
+ if dtypes.is_float(dtype):
+ if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
+ if dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
+ return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
+ return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
+
+ asm_for_op: Dict[Ops, Callable] = {
+ Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
+ Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
+ Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
+ Ops.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", Ops.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
+ Ops.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
+ Ops.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
+ Ops.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
+ Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
+ Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
+ Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
+ Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
+ Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
+ Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
+ Ops.WHERE: lambda d,a,b,c,dt,name:
+ f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
+ }
+
+ supports_half: List[Ops] = [Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE]
+ doesnt_support_half: Tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
+ ptx_matcher = PatternMatcher([
+ # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
+ (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
+ (UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
+ # upcast to float32 all the ops that don't support half
+ (UPat(doesnt_support_half, dtype=dtypes.half, name="x"),
+ lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))),
+ # load/store bool -> uint8
+ (UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
+ lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
+ (UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
+ lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
+ # load/store use pointer arithmetic, and the cast does nothing
+ (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
+ (UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
+ # ptx shr and shl instructions require y to be uint
+ (UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
+ (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
+ ])
+
+ def mem_type(x: UOp): return 'shared' if x.src[0].op is Ops.DEFINE_LOCAL or any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].parents) else 'global'
+
+ def render_store(ctx: "PTXRenderer", x: UOp, bidx: UOp, var: UOp, pred: Optional[UOp]=None):
+ gate = f"@{ctx.r[pred]} " if pred is not None and pred.op is not Ops.IF else ""
+ return [f"{gate}st.{mem_type(bidx)}.v{var.dtype.count}.{ctx.mem_types[var.dtype.scalar()]} [{ctx.r[bidx]}+0], {{{', '.join(ctx.r[var])}}};"] \
+ if var.dtype.count > 1 else [f"{gate}st.{mem_type(bidx)}.{ctx.mem_types[var.dtype]} [{ctx.r[bidx]}+0], {ctx.r[var]};"]
+
+ def render_wmma(ctx: "PTXRenderer", x: UOp):
+ assert ctx.wmma_r, "registry values for wmma must be populated"
+ _, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = x.arg
+ n_operands = tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2])
+ dt_map = { dtypes.half: "f16" }
+ _i = 0
+ for vv in x.src[:2]:
+ for i in range(0, len(ctx.r[vv]), 2):
+ yield f"mov.b32 {ctx.wmma_r[_i]}, {{{', '.join(ctx.r[vv][i:i+2])}}};"
+ _i += 1
+ yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32{" "*12}' +\
+ f'{{{", ".join(ctx.r[x])}}}, {{{", ".join(ctx.wmma_r[:n_operands[0]])}}}, {{{", ".join(ctx.wmma_r[-n_operands[1]:])}}}, ' + \
+ f'{{{", ".join(ctx.r[x.src[2]])}}};'
+
+ def modifier(a: DType, b: DType): return '.rzi' if dtypes.is_int(a) and dtypes.is_float(b) else '.rn' if dtypes.is_float(a) and \
+ (a.itemsize < b.itemsize or dtypes.is_int(b) or b == dtypes.bool) else ''
+
+ string_rewrite = PatternMatcher([
+ (UPat(Ops.CONST, name="x", dtype=dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"),
+ (UPat(Ops.CONST, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"),
+ (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var"), UPat.var("pred")), allow_any_len=True), render_store),
+ (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),
+ (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
+ (UPat((Ops.CMPLT, Ops.CMPNE), name="x"),
+ lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.src[0].dtype, ctx.types[x.src[0].dtype])),
+ (UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
+ (UPat(Ops.BITCAST, name="x", src=(UPat.var("a")), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
+ (UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"))),
+ lambda ctx, x, a: f"selp.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(1, x.dtype)}, {render_val(0, x.dtype)}, {ctx.r[a]};"),
+ (UPat(Ops.CAST, name="x", dtype=dtypes.bool),
+ lambda ctx, x: f"setp.ne.b{ctx.types[x.src[0].dtype][1:]} {ctx.r[x]}, {ctx.r[x.src[0]]}, {render_val(0, x.src[0].dtype)};"),
+ (UPat(Ops.CAST, name="x", src=(UPat.var("a"))),
+ lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.types[x.dtype]}.{ctx.types[x.src[0].dtype]} {ctx.r[x]}, {ctx.r[x.src[0]]};"),
+ (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'), UPat(name='alt'), UPat(name="gate", op=GroupOp.ALU))), lambda ctx, x, loc, alt, gate: flatten([
+ [f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
+ [f"@{ctx.r[gate]} ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
+ ]) if alt.dtype.count > 1 else [
+ f"@{ctx.r[gate]} ld.{mem_type(x)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
+ f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
+ (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'),), allow_any_len=True),
+ lambda ctx, x, loc: f" ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
+ if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
+ (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.bool),), allow_any_len=True),
+ lambda ctx, x, pred: flatten([
+ [f"setp.ne.s16 {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())}, 0;",
+ f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
+ (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.half),), allow_any_len=True),
+ lambda ctx, x, pred: flatten([[f"mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())};",
+ f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
+ (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE),), allow_any_len=True), lambda ctx, x, pred: [
+ f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {render_val(pred.src[0].arg, x.dtype.scalar())};" for i, uu in enumerate(ctx.r[x])]),
+ (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST, dtype=dtypes.bool), ), allow_any_len=True), lambda ctx, x, pred: [
+ f"setp.ne.s16 {ctx.r[pred]}, {render_val(pred.arg, pred.dtype)}, 0;", f"mov.pred {ctx.r[x]}, {ctx.r[pred]};"]),
+ (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST), ), allow_any_len=True),
+ lambda ctx, x, pred: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(pred.arg, x.dtype)};"),
+ (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, {ctx.r[x.src[0]]};", "LOOP_" + f"{ctx.r[x][1:]}:"]),
+ (UPat(Ops.ASSIGN, name="x", dtype=dtypes.bool), lambda ctx, x: [f"mov.pred {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"]),
+ (UPat(Ops.ASSIGN, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"),
+ (UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
+ ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
+ ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[1]], dtypes.int, ctx.types[dtypes.int]),
+ f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
+ (UPat(Ops.DEFINE_LOCAL, name="x"),
+ lambda ctx, x: [f".shared .align 4 .b8 {x.arg[0]}[{x.arg[1]*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg[0]}[0];"]),
+ (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
+ (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
+ (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
+ (UPat(Ops.BARRIER, name="x"), lambda ctx, x: ctx.barrier),
+ (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"),
+ ])
+
+ class PTXRenderer(Renderer):
+ device = "CUDA"
+ suffix = "PTX"
+ global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
+ tensor_cores = [tc for tc in CUDARenderer.tensor_cores if tc.dtype_in == dtypes.half]
+ code_for_op = asm_for_op
+ extra_matcher = ptx_matcher
+ def __init__(self, arch:str, device="CUDA"):
+ self.device, self.tensor_cores, self.arch = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
+ def __reduce__(self): return self.__class__, (self.arch, self.device)
+
+ # language options
+ kernel_prefix = """.version VERSION
+ .target TARGET
+ .address_size 64
+ .visible .entry"""
+ barrier = "bar.sync\t0;"
+ supports_half = supports_half
+ # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
+ types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
+ dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
+ dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
+
+ mem_types: Dict[DType, str] = types.copy()
+ mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})
+
+ def render_kernel(self, kernel, function_name, bufs, regs) -> str:
+ kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
+ def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
+ return (f"{self.kernel_prefix} {function_name}(\n\t" +
+ ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
+ '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
+ "\n}")
+
+ def render(self, name:str, uops:List[UOp]) -> str:
+ kernel:List[str] = []
+ bufs = []
+
+ c: DefaultDict[str, int] = defaultdict(int)
+ r: Dict[UOp, Union[List[str], str]] = {}
+ self.r = r
+ self.uops = uops
+
+ def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
+ nonlocal c, r
+ prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype]}_"
+ c[prefix] += 1
+ return f"%{prefix}{c[prefix]-1}"
+
+ for u in uops:
+ if u.op is Ops.VECTORIZE:
+ r[u] = [cast(str,r[x]) for x in u.src]
+ continue
+ if u.op is Ops.GEP:
+ assert len(u.arg) == 1
+ r[u] = r[u.src[0]][u.arg[0]]
+ continue
+ if u.op in {Ops.CAST, Ops.BITCAST}:
+ if u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType):
+ r[u] = r[u.src[0]]
+ continue
+ r[u] = ssa('cast', u, self.types[u.dtype])
+ elif u.op is Ops.ENDRANGE: r[u] = ssa("pred", u, dtype="pred")
+ elif u.op is Ops.RANGE: r[u] = ssa("ridx", u)
+ elif u.op in GroupOp.ALU: r[u] = ssa("alu", u)
+ elif u.op is Ops.DEFINE_ACC:
+ if u.dtype.scalar() in [dtypes.half, dtypes.bool]:
+ r[u.src[0]] = [ssa("const", u.src[0].src[0]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("const", u.src[0])
+ r[u] = [ssa('acc', u, dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("acc", u)
+ elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
+ elif u.op is Ops.DEFINE_VAR:
+ bufs.append((u.arg[0], u.dtype))
+ r[u] = ssa("dat", u, self.types[u.dtype])
+ elif u.op is Ops.CONST: r[u] = ssa("const", u, dtype=self.types[u.dtype])
+ elif u.op is Ops.LOAD:
+ assert u.src[0].dtype == dtypes.int64, "load isn't int64"
+ r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
+ elif u.op is Ops.DEFINE_LOCAL: r[u] = ssa('local', u, self.types[dtypes.ulong])
+ elif u.op is Ops.DEFINE_GLOBAL:
+ bufs.append((f"data{u.arg}", u.dtype))
+ r[u] = ssa('dat', u, self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
+ elif u.op is Ops.WMMA:
+ self.wmma_r = [ssa("wmma", dtype="b32") for vv in u.src[:2] for i in range(0, len(r[vv]), 2)]
+ r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
+ if (l:=cast(Union[str, List[str]], string_rewrite.rewrite(u, ctx=self))) is None:
+ raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.u.src]}")
+ kernel.extend([l] if isinstance(l, str) else l)
+
+ if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
+ elif u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel
+ return self.render_kernel(kernel, name, bufs, c.items())