tinygrad-0.9.1-py3-none-any.whl → tinygrad-0.10.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
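Note on the layout changes above: tinygrad/lazy.py moved under tinygrad/engine/, the old codegen/linearizer.py, codegen/uops.py, shape/symbolic.py, engine/graph.py, and renderer/assembly.py were deleted, and codegen gained linearize.py, lowerer.py, transcendental.py, and uopgraph.py. A minimal sketch of what the rename means for downstream imports (the module move is from the list above; LazyBuffer as the imported name is an assumption about typical usage, not something this diff shows):

    # tinygrad 0.9.1
    # from tinygrad.lazy import LazyBuffer        # old path, removed in 0.10.0
    # tinygrad 0.10.0
    from tinygrad.engine.lazy import LazyBuffer   # lazy.py now lives under tinygrad/engine/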
    
tinygrad/renderer/llvmir.py
CHANGED

@@ -1,69 +1,77 @@
-from typing import …
-…
-from tinygrad.dtype import DType, PtrDType, dtypes
-from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
-from tinygrad.codegen.uops import UOps, UOp, UOpGraph
+from typing import List, Dict, cast
+import math, struct
 from tinygrad.renderer import Renderer
-…
-  BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS),  # noqa: E501
-  BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS),  # noqa: E501
-  BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y),  # noqa: E501
-  BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y),  # noqa: E501
-  BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y),
-  TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
-…
-dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
-  dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),
-  dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType() }
-…
-def cast(bb, val, input_type, output_type, bitcast=False):
-  if input_type == output_type: return val
-  llvm_type = dtype_to_llvm_dtype[output_type]
-  if bitcast: return bb[-1].bitcast(val, llvm_type)
-…
-  if input_type == dtypes.bfloat16:
-    val = bb[-1].bitcast(bb[-1].shl(bb[-1].sext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.FloatType())
-    input_type = dtypes.float32
-  if output_type == dtypes.bfloat16:
-    val = cast(bb, val, input_type, dtypes.float32)
-    return bb[-1].trunc(bb[-1].lshr(bb[-1].bitcast(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.IntType(16))
-…
+from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
+from tinygrad.dtype import dtypes, DType, PtrDType, truncate
+
+def ldt(dt:DType):
+  if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
+  return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
+          dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
+          dtypes.float16: "half", dtypes.float32: "float", dtypes.float64: "double", dtypes.bool: "i1", dtypes.void: "void"}[dt]
+
+def lconst(x, dtype:DType):
+  if dtype in dtypes.floats:
+    if math.isinf(x) or math.isnan(x): return "0x%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
+    return truncate[dtype](x)
+  return int(x)
+
+def lcast(input_type:DType, output_type:DType):
   if dtypes.is_float(input_type):
-    if dtypes.is_float(output_type): …
-    if dtypes.is_int(output_type): return bb[-1].fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else bb[-1].fptosi(val, llvm_type)
-    if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0))
+    if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc'
+    if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi'
   if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
-    if …
-    if dtypes.…
-    if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].zext(val, llvm_type)
-    if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0))
+    if dtypes.is_float(output_type): return 'uitofp'
+    if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext'
   if dtypes.is_int(input_type):
-    if …
-    if dtypes.…
-    if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].sext(val, llvm_type)
-    if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0))
+    if dtypes.is_float(output_type): return 'sitofp'
+    if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
   raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
 
-…
+# llvm ops, lop[<dtype>][<op>]
+unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
+                 Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
+signed_lop = {**unsigned_lop, Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"}
+flags = " nsz arcp contract afn"
+float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags}
+lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}
+
+llvm_rewrite = PatternMatcher([
+  # memory load/store
+  (UPat(Ops.INDEX, name="x"), lambda ctx,x:
+   f"  {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
+  (UPat(Ops.LOAD, src=(UPat.var('idx'), UPat.var('alt'), UPat.var('mask')), name="x"), lambda ctx,x,idx,alt,mask:
+   f"  br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
+   f"  br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
+   f"  {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
+   f"  br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
+   f"  {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
+  (UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f"  {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
+  (UPat(Ops.STORE, name="x"), lambda ctx,x: f"  store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
+
+  # unary/binary/ternary ops
+  (UPat(Ops.SQRT, name="x"), lambda ctx,x:
+   f"  {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
+  (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"  {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
+  (UPat(Ops.CAST, name="x"), lambda ctx,x: f"  {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
+  (UPat(GroupOp.Binary, name="x"), lambda ctx,x: f"  {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
+  (UPat(Ops.WHERE, name="x"), lambda ctx,x:
+   f"  {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
+
+  # range
+  (UPat(Ops.RANGE, name="x"), lambda ctx,x:
+   f"  br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n"
+   f"  br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n"
+   f"  {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"),
+  (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
+   f"  br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n"
+   f"  {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n  {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
+   f"  br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"),
+
+  # if
+  (UPat(Ops.IF, name="x"), lambda ctx,x: f"  br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
+  (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f"  br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
+])
 
 class LLVMRenderer(Renderer):
   device = "LLVM"
@@ -72,89 +80,63 @@ class LLVMRenderer(Renderer):
   has_shared = False
   global_max = None
 
-  …
+  extra_matcher = PatternMatcher([
+    # rewrite RECIP with FDIV
+    (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
+    # rewrite cast to bool to CMPNE 0
+    (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
+    # *** also in cstyle ***
+    # gate any stores that aren't gated with ifs
+    (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+      lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
+    # rewrite MAX to CMPLT + WHERE
+    (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+  ])
+
+  def render(self, name: str, uops: List[UOp]) -> str:
+    r: Dict[UOp, str] = {}
+    args: List[str] = []
+    kernel: List[str] = []
+    end_lines: Dict[str, None] = {}
+    vc = -1
+
+    # prealloc all assigns
+    acc_to_assign: Dict[UOp, UOp] = {}
+    for u in uops:
+      if u.op is Ops.ASSIGN:
+        vc += 1
+        r[u] = r[u.src[1]] = f"%assign{vc}"
+        assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
+        acc_to_assign[u.src[0]] = u.src[1]
 
     for u in uops:
-      …
-        idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
-        lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
-        for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
-        bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
-        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
+      # hack for defining sqrt function (TODO: can we get a transcendental for this?)
+      if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
+
+      if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
+        r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
+        args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
+      elif u.op is Ops.ASSIGN: pass  # assign is already handled by the first pass
+      elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]]  # a define acc can be used and never be assigned to
+      elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
+      elif u.op is Ops.CAST and ldt(u.dtype) == ldt(u.src[0].dtype): r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop
       else:
-        …
-          if len(src) > 2:
-            aug_idx = bb[-1].select(lvars[src[2]], lvars[src[1]], ir.Constant(ir.IntType(32), 0))
-            val = bb[-1].load(bb[-1].gep(lvars[src[0]], [aug_idx], inbounds=True))
-            val = bb[-1].select(lvars[src[2]], val, lvars[src[3]])
-          else:
-            val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
-          lvars[u] = val
-        elif uop is UOps.PHI:
-          lvars[u] = lvars[src[1]]
-          # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
-          backward = src[0]
-          while backward.op is UOps.PHI: backward = backward.src[0]
-          lvars[backward] = lvars[u]
-        elif uop is UOps.ALU:
-          lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in src], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else src[0].dtype)
-        elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
-        elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
-        elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
-        elif uop is UOps.CONST: lvars[u] = const(args, dtype)
-        else: raise RuntimeError(f"failed to render {uop}")
-
-    bb[-1].ret_void()
-    return str(module)
+        # if it's an assign target, it's already preallocated
+        if u not in r:
+          vc += 1
+          r[u] = f"%v{vc}"
+
+        # do the rendering of the llvm ir code
+        if (l:=llvm_rewrite.rewrite(u, ctx=r)) is None: raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
+        kernel.append(cast(str, l))
+
+        # generate the phi nodes for the assigns
+        if u.op is Ops.RANGE:
+          for x in acc_to_assign:
+            if u in x.src:  # if this range is relevent for this acc
+              vc += 1
+              kernel.append(f"  %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]")
+              r[x] = f"%acc{vc}"
+
+    # output the function
+    return f"define void @{name}({','.join(args)}) {{\n" + '\n'.join(kernel) + "\n  ret void\n}\n"+'\n'.join(end_lines.keys())
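The rewrite above replaces the llvmlite IRBuilder walk with a renderer that emits LLVM IR as plain strings: ldt maps a DType to an IR type name, lcast picks the conversion instruction (for example, lcast(dtypes.int32, dtypes.float32) returns 'sitofp'), and llvm_rewrite maps each UOp to a line of IR through a PatternMatcher. A minimal sketch of that mechanism with a hypothetical one-rule matcher; the UOp construction and the ctx dict mirror how render() calls llvm_rewrite.rewrite(u, ctx=r):

    from tinygrad.ops import UOp, UPat, PatternMatcher, Ops
    from tinygrad.dtype import dtypes

    # one rule: render a CONST UOp as a pseudo-IR line, looking up its register name in ctx
    demo = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda ctx, x: f"  {ctx[x]} = {x.arg}")])
    c = UOp(Ops.CONST, dtypes.int32, arg=5)
    print(demo.rewrite(c, ctx={c: "%v0"}))  # -> "  %v0 = 5"

Each rule pairs a UPat with a function; rewrite() returns the first non-None result, and render() raises if no rule matches.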
    
tinygrad/renderer/ptx.py
ADDED

@@ -0,0 +1,225 @@
+from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tuple
+import struct
+from collections import defaultdict
+from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
+from tinygrad.dtype import dtypes, DType, PtrDType
+from tinygrad.renderer import Renderer
+from tinygrad.renderer.cstyle import CUDARenderer
+from tinygrad.helpers import prod, flatten
+
+def render_val(x, dtype):
+  if dtypes.is_float(dtype):
+    if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
+    if dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
+    return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
+  return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
+
+asm_for_op: Dict[Ops, Callable] = {
+  Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
+  Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
+  Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
+  Ops.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", Ops.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
+  Ops.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
+  Ops.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
+  Ops.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
+  Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};",
+  Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};",
+  Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
+  Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
+  Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
+  Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
+  Ops.WHERE: lambda d,a,b,c,dt,name:
+    f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
+}
+
+supports_half: List[Ops] = [Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE]
+doesnt_support_half: Tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half)
+ptx_matcher = PatternMatcher([
+  # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only)
+  (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
+  (UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
+  # upcast to float32 all the ops that don't support half
+  (UPat(doesnt_support_half, dtype=dtypes.half, name="x"),
+    lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half))),
+  # load/store bool -> uint8
+  (UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
+   lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
+  (UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
+   lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
+  # load/store use pointer arithmetic, and the cast does nothing
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
+  (UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
+  # ptx shr and shl instructions require y to be uint
+  (UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
+  (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
+])
+
+def mem_type(x: UOp): return 'shared' if x.src[0].op is Ops.DEFINE_LOCAL or any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].parents) else 'global'
+
+def render_store(ctx: "PTXRenderer", x: UOp, bidx: UOp, var: UOp, pred: Optional[UOp]=None):
+  gate = f"@{ctx.r[pred]} " if pred is not None and pred.op is not Ops.IF else ""
+  return [f"{gate}st.{mem_type(bidx)}.v{var.dtype.count}.{ctx.mem_types[var.dtype.scalar()]} [{ctx.r[bidx]}+0], {{{', '.join(ctx.r[var])}}};"] \
+    if var.dtype.count > 1 else [f"{gate}st.{mem_type(bidx)}.{ctx.mem_types[var.dtype]} [{ctx.r[bidx]}+0], {ctx.r[var]};"]
+
+def render_wmma(ctx: "PTXRenderer", x: UOp):
+  assert ctx.wmma_r, "registry values for wmma must be populated"
+  _, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = x.arg
+  n_operands = tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2])
+  dt_map = { dtypes.half: "f16" }
+  _i = 0
+  for vv in x.src[:2]:
+    for i in range(0, len(ctx.r[vv]), 2):
+      yield f"mov.b32 {ctx.wmma_r[_i]}, {{{', '.join(ctx.r[vv][i:i+2])}}};"
+      _i += 1
+  yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32{" "*12}' +\
+  f'{{{", ".join(ctx.r[x])}}}, {{{", ".join(ctx.wmma_r[:n_operands[0]])}}}, {{{", ".join(ctx.wmma_r[-n_operands[1]:])}}}, ' + \
+  f'{{{", ".join(ctx.r[x.src[2]])}}};'
+
+def modifier(a: DType, b: DType): return '.rzi' if dtypes.is_int(a) and dtypes.is_float(b) else '.rn' if dtypes.is_float(a) and \
+  (a.itemsize < b.itemsize or dtypes.is_int(b) or b == dtypes.bool) else ''
+
+string_rewrite = PatternMatcher([
+  (UPat(Ops.CONST, name="x", dtype=dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"),
+  (UPat(Ops.CONST, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"),
+  (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var"), UPat.var("pred")), allow_any_len=True), render_store),
+  (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),
+  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"),
+  (UPat((Ops.CMPLT, Ops.CMPNE), name="x"),
+  lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.src[0].dtype, ctx.types[x.src[0].dtype])),
+  (UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])),
+  (UPat(Ops.BITCAST, name="x", src=(UPat.var("a")), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"),
+  (UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"))),
+   lambda ctx, x, a: f"selp.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(1, x.dtype)}, {render_val(0, x.dtype)}, {ctx.r[a]};"),
+  (UPat(Ops.CAST, name="x", dtype=dtypes.bool),
+   lambda ctx, x: f"setp.ne.b{ctx.types[x.src[0].dtype][1:]} {ctx.r[x]}, {ctx.r[x.src[0]]}, {render_val(0, x.src[0].dtype)};"),
+  (UPat(Ops.CAST, name="x", src=(UPat.var("a"))),
+   lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.types[x.dtype]}.{ctx.types[x.src[0].dtype]} {ctx.r[x]}, {ctx.r[x.src[0]]};"),
+  (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'), UPat(name='alt'), UPat(name="gate", op=GroupOp.ALU))), lambda ctx, x, loc, alt, gate: flatten([
+    [f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]],
+    [f"@{ctx.r[gate]} ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"]
+  ]) if alt.dtype.count > 1 else [
+    f"@{ctx.r[gate]} ld.{mem_type(x)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];",
+    f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]),
+  (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'),), allow_any_len=True),
+   lambda ctx, x, loc: f" ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \
+      if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.bool),), allow_any_len=True),
+   lambda ctx, x, pred: flatten([
+    [f"setp.ne.s16 {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())}, 0;",
+     f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.half),), allow_any_len=True),
+   lambda ctx, x, pred: flatten([[f"mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())};",
+      f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE),), allow_any_len=True), lambda ctx, x, pred: [
+    f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {render_val(pred.src[0].arg, x.dtype.scalar())};" for i, uu in enumerate(ctx.r[x])]),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST, dtype=dtypes.bool), ), allow_any_len=True), lambda ctx, x, pred: [
+    f"setp.ne.s16 {ctx.r[pred]}, {render_val(pred.arg, pred.dtype)}, 0;", f"mov.pred {ctx.r[x]}, {ctx.r[pred]};"]),
+  (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST), ), allow_any_len=True),
+   lambda ctx, x, pred: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(pred.arg, x.dtype)};"),
+  (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, {ctx.r[x.src[0]]};", "LOOP_" + f"{ctx.r[x][1:]}:"]),
+  (UPat(Ops.ASSIGN, name="x", dtype=dtypes.bool), lambda ctx, x: [f"mov.pred {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"]),
+  (UPat(Ops.ASSIGN, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"),
+  (UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
+    ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
+    ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[1]], dtypes.int, ctx.types[dtypes.int]),
+    f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
+  (UPat(Ops.DEFINE_LOCAL, name="x"),
+   lambda ctx, x: [f".shared .align 4 .b8 {x.arg[0]}[{x.arg[1]*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg[0]}[0];"]),
+  (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
+  (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
+  (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
+  (UPat(Ops.BARRIER, name="x"), lambda ctx, x: ctx.barrier),
+  (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"),
+])
+
+class PTXRenderer(Renderer):
+  device = "CUDA"
+  suffix = "PTX"
+  global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
+  tensor_cores = [tc for tc in CUDARenderer.tensor_cores if tc.dtype_in == dtypes.half]
+  code_for_op = asm_for_op
+  extra_matcher = ptx_matcher
+  def __init__(self, arch:str, device="CUDA"):
+    self.device, self.tensor_cores, self.arch = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
+  def __reduce__(self): return self.__class__, (self.arch, self.device)
+
+  # language options
+  kernel_prefix = """.version VERSION
+.target TARGET
+.address_size 64
+.visible .entry"""
+  barrier = "bar.sync\t0;"
+  supports_half = supports_half
+  # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
+  types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
+                              dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
+                              dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
+
+  mem_types: Dict[DType, str] = types.copy()
+  mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})
+
+  def render_kernel(self, kernel, function_name, bufs, regs) -> str:
+    kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
+    def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
+    return (f"{self.kernel_prefix} {function_name}(\n\t" +
+            ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
+            '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
+            "\n}")
+
+  def render(self, name:str, uops:List[UOp]) -> str:
+    kernel:List[str] = []
+    bufs = []
+
+    c: DefaultDict[str, int] = defaultdict(int)
+    r: Dict[UOp, Union[List[str], str]] = {}
+    self.r = r
+    self.uops = uops
+
+    def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
+      nonlocal c, r
+      prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype]}_"
+      c[prefix] += 1
+      return f"%{prefix}{c[prefix]-1}"
+
+    for u in uops:
+      if u.op is Ops.VECTORIZE:
+        r[u] = [cast(str,r[x]) for x in u.src]
+        continue
+      if u.op is Ops.GEP:
+        assert len(u.arg) == 1
+        r[u] = r[u.src[0]][u.arg[0]]
+        continue
+      if u.op in {Ops.CAST, Ops.BITCAST}:
+        if u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType):
+          r[u] = r[u.src[0]]
+          continue
+        r[u] = ssa('cast', u, self.types[u.dtype])
+      elif u.op is Ops.ENDRANGE: r[u] = ssa("pred", u, dtype="pred")
+      elif u.op is Ops.RANGE: r[u] = ssa("ridx", u)
+      elif u.op in GroupOp.ALU: r[u] = ssa("alu", u)
+      elif u.op is Ops.DEFINE_ACC:
+        if u.dtype.scalar() in [dtypes.half, dtypes.bool]:
+          r[u.src[0]] = [ssa("const", u.src[0].src[0]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("const", u.src[0])
+        r[u] = [ssa('acc', u, dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("acc", u)
+      elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
+      elif u.op is Ops.DEFINE_VAR:
+        bufs.append((u.arg[0], u.dtype))
+        r[u] = ssa("dat", u, self.types[u.dtype])
+      elif u.op is Ops.CONST: r[u] = ssa("const", u, dtype=self.types[u.dtype])
+      elif u.op is Ops.LOAD:
+        assert u.src[0].dtype == dtypes.int64, "load isn't int64"
+        r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u)
+      elif u.op is Ops.DEFINE_LOCAL: r[u] = ssa('local', u, self.types[dtypes.ulong])
+      elif u.op is Ops.DEFINE_GLOBAL:
+        bufs.append((f"data{u.arg}", u.dtype))
+        r[u] = ssa('dat', u, self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype])
+      elif u.op is Ops.WMMA:
+        self.wmma_r = [ssa("wmma", dtype="b32") for vv in u.src[:2] for i in range(0, len(r[vv]), 2)]
+        r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
+      if (l:=cast(Union[str, List[str]], string_rewrite.rewrite(u, ctx=self))) is None:
+        raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
+      kernel.extend([l] if isinstance(l, str) else l)
+
+      if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
+      elif u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel
+    return self.render_kernel(kernel, name, bufs, c.items())
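For reference, render_val's immediate encodings follow directly from the struct.pack logic above: struct.pack('f', 1.0) reversed is 3F 80 00 00, so a float32 constant renders as 0f3F800000; doubles get the 0d prefix and halves 0x. A minimal check (the import path is the new module from this diff):

    from tinygrad.renderer.ptx import render_val
    from tinygrad.dtype import dtypes

    print(render_val(1.0, dtypes.float32))   # 0f3F800000
    print(render_val(-2.0, dtypes.float64))  # 0dC000000000000000
    print(render_val(7, dtypes.uint32))      # 7U

In tinygrad this backend is normally selected by running on a CUDA device with PTX=1 set in the environment.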
         |