tinygrad-0.8.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
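Most of the churn above is a reorganization of the package: the scheduler, JIT, kernel search, and graph tooling move from the package root and tinygrad/features/ into the new tinygrad/engine/ subpackage, mlops.py becomes function.py, and the interpreted CPU/torch backends are dropped while the new AMD/NV/HSA runtimes and the ops_python.py reference backend are added. As a rough, illustrative sketch of what the moves mean for code importing these internals (paths are taken from the list above; the symbol names, e.g. beam_search, are used here only for illustration):

    # 0.8.0-style internal imports (these modules no longer exist in 0.9.0):
    #   from tinygrad.jit import TinyJit
    #   from tinygrad.features.search import beam_search
    # 0.9.0 equivalents, following the file moves listed above:
    from tinygrad.engine.jit import TinyJit
    from tinygrad.engine.search import beam_search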
tinygrad/ops.py CHANGED
@@ -1,32 +1,36 @@
  from __future__ import annotations
- from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Dict, Callable, ClassVar
- import functools
+ from typing import Union, Tuple, Any, List, Dict, Callable
+ import functools, hashlib, math, operator, ctypes
  from enum import Enum, auto
- from tinygrad.helpers import prod, dedup
- from tinygrad.dtype import dtypes, DType
- from tinygrad.shape.symbolic import Variable
  from dataclasses import dataclass
+ from tinygrad.helpers import prod, dedup
+ from tinygrad.dtype import dtypes, DType, ConstType
+ from tinygrad.shape.symbolic import Variable, sint
+ from tinygrad.shape.shapetracker import ShapeTracker

  # these are the llops your accelerator must implement, along with toCpu
  # the Enum class doesn't work with mypy, this is static. sorry it's ugly
  # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
  # NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
- class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
+ class UnaryOps(Enum):
+ """A -> A (elementwise)"""
+ EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
  class BinaryOps(Enum):
+ """A + A -> A (elementwise)"""
  ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
- class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
- class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
+ class TernaryOps(Enum):
+ """A + A + A -> A (elementwise)"""
+ WHERE = auto(); MULACC = auto() # noqa: E702
+ class ReduceOps(Enum):
+ """A -> B (reduce)"""
+ SUM = auto(); MAX = auto() # noqa: E702
  class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
- # Ops below this line are not allowed in ASTs
- class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
- class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
+ class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702

- Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
- OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
+ Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]

- if TYPE_CHECKING:
- from tinygrad.shape.shapetracker import ShapeTracker
- from tinygrad.lazy import LazyBuffer
+ # do not preserve f(0) = 0
+ UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}

  @dataclass(frozen=True)
  class MemBuffer:
@@ -36,17 +40,10 @@ class MemBuffer:

  @dataclass(frozen=True)
  class ConstBuffer:
- val: Union[int, float]
+ val: ConstType
  dtype: DType
  st: ShapeTracker

- @dataclass(frozen=True)
- class ScheduleItem:
- ast: LazyOp
- out: LazyBuffer
- inputs: Tuple[LazyBuffer, ...]
- var_vals: Dict[Variable, int]
-
  @dataclass(frozen=True, eq=False)
  class LazyOp:
  op: Op
@@ -61,20 +58,30 @@ class LazyOp:
  def __eq__(self, x): return self.cached_compare(x, context={})
  def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
  @functools.cached_property
+ def dtype(self) -> DType:
+ if self.op in BufferOps: return self.arg.dtype
+ if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
+ return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else self.src[-1].dtype
+
+ @functools.cached_property
+ def key(self) -> bytes:
+ return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.arg)).encode())).digest()
+ @functools.cached_property
  def hash(self): return hash((self.op, self.src, self.arg))
  def __hash__(self): return self.hash
  @functools.cached_property
  def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
  def vars(self) -> List[Variable]:
- return sorted(set.union(*[x.arg.st.vars() for x in self.lazyops if x.op in BufferOps], set()), key=lambda x: str(x.expr))
+ extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
+ const_vars = [x.arg.val.unbind()[0] for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
+ return sorted(set.union(*extract_vars, set(const_vars)), key=lambda x: str(x.expr))

  # **************** independent FlopCounter ****************

  @dataclass
  class FlopCounter:
  shape: Tuple[int, ...]
- dtype: DType
- flops: int
+ flops: sint
  mem: Dict[int, int]
  @property
  def mem_estimate(self): return sum(self.mem.values())
@@ -83,14 +90,15 @@ class FlopCounter:
  return ret

  InterpretedFlopCounter: Dict[Op, Callable] = {
- BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.real_size()}),
- BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
- BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.real_size()}), # noqa: E501
- UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
- **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST}, # noqa: E501
- **{op:lambda self,y,op=op: FlopCounter(self.shape, dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
- **{op:lambda self,new_shape: FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps},
- TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
+ BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
+ BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
+ BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
+ UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
+ UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
+ **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
+ **{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
+ **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
+ TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501

  @functools.lru_cache(None)
  def get_lazyop_info(ast:LazyOp) -> FlopCounter:
@@ -98,13 +106,31 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
  def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
  return run_ast(ast)

- # **************** global state Counters ****************
+ # **************** ops in python ****************
+
+ def hook_overflow(dv, fxn):
+ def wfxn(*args):
+ try: return fxn(*args)
+ except OverflowError: return dv
+ return wfxn
+
+ python_alu = {
+ UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
+ UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
+ UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
+ UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
+ BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
+ BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt,
+ BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0],
+ BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
+ TernaryOps.WHERE: lambda x,y,z: y if x else z}
+
+ truncate: Dict[DType, Callable] = {dtypes.bool: bool,
+ # TODO: float16 and bfloat16?
+ dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+ dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+ dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+ dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
+ dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}

- class GlobalCounters:
- global_ops: ClassVar[int] = 0
- global_mem: ClassVar[int] = 0
- time_sum_s: ClassVar[float] = 0.0
- kernel_count: ClassVar[int] = 0
- mem_used: ClassVar[int] = 0 # NOTE: this is not reset
- @staticmethod
- def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
+ def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
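Since ops.py now carries a pure-Python ALU, here is a small usage sketch of the new exec_alu helper (the operand values are made up for illustration): it looks up the op's lambda in python_alu, applies it, then passes the result through the ctypes-based truncate table for the target dtype, so integer results wrap like real hardware.

    from tinygrad.ops import exec_alu, BinaryOps, TernaryOps
    from tinygrad.dtype import dtypes

    print(exec_alu(BinaryOps.ADD, dtypes.uint8, (200, 100)))      # 44, since 300 wraps modulo 2**8
    print(exec_alu(BinaryOps.MUL, dtypes.float32, (1.5, 2.0)))    # 3.0
    print(exec_alu(TernaryOps.WHERE, dtypes.int32, (True, 7, 3))) # 7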
tinygrad/renderer/__init__.py ADDED
@@ -0,0 +1,61 @@
+ from typing import Optional, List, Tuple, Dict
+ import functools
+ from dataclasses import dataclass
+ from tinygrad.helpers import to_function_name
+ from tinygrad.codegen.uops import UOpGraph
+ from tinygrad.shape.symbolic import sym_infer, sint, Variable
+ from tinygrad.dtype import DType
+
+ @dataclass(frozen=True)
+ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
+ dims: Tuple[int,int,int] # N, M, K
+ dtype_in: DType # dtype for A and B
+ dtype_out: DType # dtype for C and D
+ threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
+ thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
+ thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
+ def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
+ def num_upcasts(self): return len(self.thread_local_aliases[0]) - len(self.threads)
+
+ @dataclass(frozen=True)
+ class Program:
+ name:str
+ src:str
+ dname:str
+ global_size:Optional[List[int]]=None
+ local_size:Optional[List[int]]=None
+ uops:Optional[UOpGraph]=None
+ op_estimate:sint=0
+ mem_estimate:sint=0
+
+ @functools.cached_property
+ def vars(self) -> List[Variable]: return [] if self.uops is None else self.uops.vars()
+
+ @functools.cached_property
+ def globals(self) -> List[Tuple[int, bool]]: return [] if self.uops is None else self.uops.globals()
+
+ @functools.cached_property
+ def outcount(self) -> int: return sum(x[1] for x in self.globals)
+
+ @functools.cached_property
+ def function_name(self) -> str: return to_function_name(self.name)
+
+ def launch_dims(self, var_vals:Dict[Variable, int]):
+ global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
+ local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
+ return global_size, local_size
+
+ class Renderer:
+ device: str = ""
+ suffix: str = ""
+ # TODO: make this generic with a list of supported types
+ supports_float4: bool = True
+ has_local: bool = True
+ has_shared: bool = True
+ # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
+ global_max: Optional[List[int]] = None
+ local_max: Optional[List[int]] = None
+ shared_max: int = 32768
+ tensor_cores: List[TensorCore] = []
+
+ def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")
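The Program dataclass above is what a compiled kernel is described by; its launch_dims method resolves symbolic global/local sizes through sym_infer. A minimal sketch, with a made-up kernel name, source string, device name, and Variable bounds:

    from tinygrad.renderer import Program
    from tinygrad.shape.symbolic import Variable

    n = Variable("n", 1, 1024)   # symbolic launch dimension
    p = Program(name="example", src="...", dname="CLANG",
                global_size=[n, 1, 1], local_size=[1, 1, 1])
    print(p.launch_dims({n: 64}))  # ([64, 1, 1], [1, 1, 1])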
tinygrad/renderer/assembly.py ADDED
@@ -0,0 +1,276 @@
+ from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
+ import struct
+ from collections import defaultdict
+ from tinygrad.helpers import DEBUG
+ from tinygrad.codegen.linearizer import UOps, UOp
+ from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
+ from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
+ from tinygrad.codegen.uops import UOpGraph, PatternMatcher
+ from tinygrad.renderer import Renderer, TensorCore
+
+ def render_val(x, dtype):
+ if dtypes.is_float(dtype):
+ if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
+ if dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
+ return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
+ return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
+
+ class PTXRenderer(Renderer):
+ device = "CUDA"
+ suffix = "PTX"
+ global_max = [65535, 65535, 2147483647]
+ local_max = [64, 1024, 1024]
+ shared_max = 49152
+ tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
+ def __init__(self, arch:str): self.tensor_cores = PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
+
+ # language options
+ kernel_prefix = """.version VERSION
+ .target TARGET
+ .address_size 64
+ .visible .entry"""
+ barrier = "bar.sync\t0;"
+ has_pred = True
+ load_global = True
+ label_prefix = "$"
+ gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
+ gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
+ lid = [f'%tid.{chr(120+i)}' for i in range(3)]
+ asm_for_op: Dict[Op, Callable] = {
+ UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"neg.{name} {d}, {a};",
+ UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
+ UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
+ BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
+ BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
+ BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
+ BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
+ BinaryOps.DIV: lambda d,a,b,dt,name: f"div{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a}, {b};",
+ BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
+ BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
+ BinaryOps.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};",
+ TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
+ TernaryOps.WHERE: lambda d,a,b,c,dt,name:
+ f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
+ }
+ supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
+ TernaryOps.WHERE]
+ # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
+ types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
+ dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
+ dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
+
+ mem_types: Dict[DType, str] = types.copy()
+ mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})
+
+ const_requires_mov: List[DType] = [dtypes.half, dtypes.bool]
+
+ def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
+ val = render_val(x, dtype)
+ if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
+ return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val
+
+ def render_local(self, dest, name, size, dtype) -> List[str]:
+ return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]
+
+ def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]
+
+ def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"]
+
+ def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
+ assert dtype != dtypes.bool
+ if gate: return [f"@{gate} ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
+ return [f"ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];"]
+
+ def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
+ return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_types[dtype]} [{loc}+{offset}], {val};"]
+
+ def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
+ if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
+ if atype == dtypes.bool: return[f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
+ if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
+ rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
+ '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
+ return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]
+
+ def render_kernel(self, kernel, function_name, bufs, regs) -> str:
+ kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
+ def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
+ return (f"{self.kernel_prefix} {function_name}(\n\t" +
+ ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
+ '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
+ "\n}")
+
+ def render(self, name:str, uops:UOpGraph) -> str:
+ kernel:List[str] = []
+ bufs = []
+
+ uops.linearize(ptx_matcher)
+ if DEBUG >= 4: uops.print()
+
+ def kk(*s: str): kernel.append("\n".join(s))
+
+ c: DefaultDict[str, int] = defaultdict(int)
+ r: Dict[UOp, Union[List[str], str]] = {}
+ def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
+ nonlocal c, r
+ prefix += f"_{dtype if dtype is not None else self.types[cast(DType, cast(UOp, u).dtype)]}_"
+ c[prefix] += 1
+ if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
+ return f"%{prefix}{c[prefix]-1}"
+
+ c_label: DefaultDict[str, int] = defaultdict(int)
+ r_label: Dict[UOp, str] = {}
+ def ssa_label(prefix:str, u:UOp):
+ nonlocal c_label, r_label
+ c_label[prefix] += 1
+ r_label[u] = f"{self.label_prefix}{prefix}_{c_label[prefix]-1}"
+ return r_label[u]
+
+ def const(x:ConstType, dtype:DType, mov=False):
+ if mov or dtype in self.const_requires_mov:
+ kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
+ return out
+ return self.render_const(x, dtype)
+
+ def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
+ if atype == dtype or isinstance(atype, PtrDType):
+ if u: r[u] = a
+ return a
+ kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
+ return ret
+
+ for u in uops:
+ uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
+ if uop is UOps.IF:
+ assert vin[0].dtype is not None
+ kk(*self.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
+ elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
+ elif uop is UOps.ENDRANGE:
+ kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
+ self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
+ kk(*self.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
+ elif uop is UOps.ENDIF:
+ kk(f"{r_label[vin[0]]}:")
+ elif uop is UOps.STORE:
+ assert vin[0].dtype is not None and vin[2].dtype is not None
+ assert vin[0].dtype == dtypes.int64, "store isn't int64"
+ assert vin[1].uop is UOps.CONST, f"store isn't const {u}"
+ mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
+ if vin[2].dtype.count > 1:
+ kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
+ f"st{mem_type}.v{vin[2].dtype.count}.{self.mem_types[vin[2].dtype.scalar()]} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
+ else:
+ kk(*self.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=mem_type, offset=vin[1].arg))
+ else:
+ assert dtype is not None, f"None dtype for uop {uop}"
+ if uop is UOps.RANGE: kk(*self.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
+ elif uop is UOps.ALU:
+ assert vin[0].dtype is not None
+ if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
+ # pass in the other dtype here
+ kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, self.types[vin[0].dtype]))
+ else:
+ kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, self.types[dtype]))
+ elif uop is UOps.DEFINE_ACC:
+ if dtype.count > 1:
+ r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
+ for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(args[0], dtype.scalar())};")
+ else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(args[0], dtype)};")
+ elif uop is UOps.SPECIAL:
+ assert args[1][0] != "i", "idx not supported"
+ kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
+ r[u] = "%" + args[1]
+ kernel = [f".reg .u32 %{args[1]};"] + kernel
+ elif uop is UOps.CONST:
+ if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
+ else: r[u] = const(args, dtype, mov=True)
+ elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
+ elif uop is UOps.LOAD:
+ assert vin[0].dtype == dtypes.int64, "load isn't int64"
+ assert vin[1].uop is UOps.CONST, f"load isn't const {u}"
+ mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
+ if dtype.count > 1:
+ r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
+ if(len(vin)>3):
+ for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
+ kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
+ + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
+ else:
+ kk(*self.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
+ alt=r[vin[3]] if len(vin) > 3 else None, ss=mem_type, offset=vin[1].arg))
+ elif uop is UOps.PHI:
+ if dtype.count > 1:
+ for x0, x1 in zip(r[vin[0]], r[vin[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
+ else:
+ kk(f"mov.b{self.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
+ r[u] = r[vin[0]]
+ elif uop in {UOps.CAST, UOps.BITCAST}:
+ assert vin[0].dtype is not None
+ if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
+ else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
+ elif uop is UOps.DEFINE_LOCAL:
+ # TODO: we should sum these, and fetch 0xC000 from somewhere
+ assert args[1]*dtype.itemsize <= 0xC000, "too large local"
+ kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
+ elif uop is UOps.DEFINE_VAR:
+ bufs.append((args.expr, dtype))
+ r[u] = f"%{args.expr}"
+ if self.load_global: kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
+ elif uop is UOps.DEFINE_GLOBAL:
+ bufs.append((nm:=f"data{args[0]}", dtype))
+ r[u] = f"%{nm}"
+ if self.load_global:
+ dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
+ kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
+ elif uop is UOps.WMMA:
+ wmma = []
+ for vv in vin[:2]:
+ for i in range(0, len(r[vv]), 2):
+ wmma.append(ssa("wmma", dtype="b32"))
+ kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
+ r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
+ kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
+ {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[vin[2]])}}};')
+ else: raise NotImplementedError(f"no code for {uop}")
+
+ return self.render_kernel(kernel, name, bufs, c.items())
+
+ ptx_matcher = PatternMatcher([
+ ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
+ lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),
+ ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"__name__": "x", "dtype": dtypes.bool},{"__name__": "y"})},
+ lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
+ ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.ADD,
+ "vin": [{"__name__": "non_muls"}, {"__name__": "muls", "uop": UOps.ALU, "arg": BinaryOps.MUL}]},
+ lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
+ *[({"__name__": "x", "uop": UOps.ALU, "dtype": dtypes.half, "arg": op},
+ lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
+ for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
+ ({"__name__": "root", "uop": UOps.LOAD, "dtype": dtypes.bool,
+ "vin": ({"__name__": "x"},{"__name__": "y"},{"__name__": "z"},{"__name__": "k"})},
+ lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
+ ({"__name__": "root", "uop": UOps.LOAD,"dtype": dtypes.bool, "vin": ({},{})},
+ lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
+ ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool}, {})},
+ lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
+ ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{"__name__": "z","dtype": dtypes.bool})},
+ lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
+ ({"__name__": "root", "uop": UOps.STORE, "vin": ({},{},{},{"__name__": "g", "dtype": dtypes.int})},
+ lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
+ # ptr_ar (load/store)
+ ({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
+ {"uop": UOps.ALU, "arg": BinaryOps.ADD,"vin":[{"__name__": "alu"}, {"__name__": "const", "uop":UOps.CONST}]})},
+ lambda root, alu, const: UOp(root.uop, root.dtype,
+ (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
+ UOp.const(const.dtype, root.vin[0].dtype.itemsize)*const)+root.vin[2:])),
+ ({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
+ {"__name__": "const", "uop":UOps.CONST})},
+ lambda root, const: UOp(root.uop, root.dtype, (root.vin[0].cast(dtypes.int64),
+ UOp.const(dtypes.int64, const.arg * root.vin[0].dtype.itemsize),
+ )+root.vin[2:])),
+ ({"__name__": "root", "uop": {UOps.LOAD, UOps.STORE}, "__allow_len__":[2,3,4,5], "vin": ({"uop":{UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}},
+ {"__name__": "alu"})}, # no const here
+ lambda root, alu: UOp(root.uop, root.dtype,
+ (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
+ UOp.const(dtypes.int64, 0))+root.vin[2:])),
+ ])
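For a sense of what the render_val helper at the top of this new renderer emits, here is a short sketch (assuming the 276-line module above lands as tinygrad/renderer/assembly.py, as the file list indicates): floats become hex-encoded IEEE-754 PTX immediates, integers plain decimals with a U suffix when unsigned.

    from tinygrad.renderer.assembly import render_val
    from tinygrad.dtype import dtypes

    print(render_val(1.0, dtypes.float32))  # 0f3F800000
    print(render_val(-2.0, dtypes.double))  # 0dC000000000000000
    print(render_val(255, dtypes.uint32))   # 255U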