tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/ops.py CHANGED
@@ -1,32 +1,37 @@
  from __future__ import annotations
- from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Dict, Callable, ClassVar
- import functools
+ from typing import Union, Tuple, Any, List, Dict, Callable
+ import functools, hashlib, math, operator, ctypes, struct
  from enum import Enum, auto
- from tinygrad.helpers import prod, dedup
- from tinygrad.dtype import dtypes, DType
- from tinygrad.shape.symbolic import Variable
  from dataclasses import dataclass
+ from tinygrad.helpers import prod, dedup
+ from tinygrad.dtype import dtypes, DType, ConstType
+ from tinygrad.shape.symbolic import Variable, sint
+ from tinygrad.shape.shapetracker import ShapeTracker

  # these are the llops your accelerator must implement, along with toCpu
  # the Enum class doesn't work with mypy, this is static. sorry it's ugly
  # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
  # NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
- class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
+ class UnaryOps(Enum):
+   """A -> A (elementwise)"""
+   EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto(); RECIP = auto() # noqa: E702
  class BinaryOps(Enum):
-   ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
- class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
- class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
+   """A + A -> A (elementwise)"""
+   ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
+   SHR = auto(); SHL = auto() # noqa: E702
+ class TernaryOps(Enum):
+   """A + A + A -> A (elementwise)"""
+   WHERE = auto(); MULACC = auto() # noqa: E702
+ class ReduceOps(Enum):
+   """A -> B (reduce)"""
+   SUM = auto(); MAX = auto() # noqa: E702
  class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
- # Ops below this line are not allowed in ASTs
- class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
- class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
+ class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702

- Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
- OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
+ Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]

- if TYPE_CHECKING:
-   from tinygrad.shape.shapetracker import ShapeTracker
-   from tinygrad.lazy import LazyBuffer
+ # do not preserve f(0) = 0
+ UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}

  @dataclass(frozen=True)
  class MemBuffer:
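The new UNSAFE_PAD_OPS set marks ops the scheduler must not fold across zero-padding, because they do not map 0 to 0. A quick illustration in plain Python, mirroring the `python_alu` semantics defined further down this same file:

```python
import math

# Zero-padding is only transparent to an elementwise op f when f(0) == 0.
# Evaluating each UNSAFE_PAD_OPS member at 0 shows why none of them qualify:
exp2  = lambda x: 2 ** x                                           # EXP2(0)   -> 1
log2  = lambda x: math.log2(x) if x > 0 else -math.inf             # LOG2(0)   -> -inf
recip = lambda x: 1 / x if x != 0 else math.copysign(math.inf, x)  # RECIP(0)  -> inf
idiv  = lambda x, y: int(x / y) if y != 0 else x * math.inf        # IDIV(0,0) -> nan
print(exp2(0), log2(0), recip(0), idiv(0, 0))  # 1 -inf inf nan
```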
@@ -36,17 +41,10 @@ class MemBuffer:

  @dataclass(frozen=True)
  class ConstBuffer:
-   val: Union[int, float]
+   val: ConstType | Variable
    dtype: DType
    st: ShapeTracker

- @dataclass(frozen=True)
- class ScheduleItem:
-   ast: LazyOp
-   out: LazyBuffer
-   inputs: Tuple[LazyBuffer, ...]
-   var_vals: Dict[Variable, int]
-
  @dataclass(frozen=True, eq=False)
  class LazyOp:
    op: Op
@@ -61,20 +59,30 @@ class LazyOp:
    def __eq__(self, x): return self.cached_compare(x, context={})
    def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
    @functools.cached_property
+   def dtype(self) -> DType:
+     if self.op in BufferOps: return self.arg.dtype
+     if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
+     return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
+
+   @functools.cached_property
+   def key(self) -> bytes:
+     return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.arg)).encode())).digest()
+   @functools.cached_property
    def hash(self): return hash((self.op, self.src, self.arg))
    def __hash__(self): return self.hash
    @functools.cached_property
    def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
    def vars(self) -> List[Variable]:
-     return sorted(set.union(*[x.arg.st.vars() for x in self.lazyops if x.op in BufferOps], set()), key=lambda x: str(x.expr))
+     extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
+     const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
+     return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)

  # **************** independent FlopCounter ****************

  @dataclass
  class FlopCounter:
    shape: Tuple[int, ...]
-   dtype: DType
-   flops: int
+   flops: sint
    mem: Dict[int, int]
    @property
    def mem_estimate(self): return sum(self.mem.values())
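The new `LazyOp.key` is a Merkle-style content hash: each node digests its own `(op, arg)` together with its children's digests, so structurally identical ASTs share a key while any change or reordering propagates to the root. A self-contained sketch of the same pattern (`Node` here is a hypothetical stand-in for `LazyOp`):

```python
import functools, hashlib

class Node:
    def __init__(self, op, src=(), arg=None): self.op, self.src, self.arg = op, src, arg
    @property
    def key(self) -> bytes:
        # seed with this node's (op, arg), then fold in the children's digests in order
        seed = str((self.op, self.arg)).encode()
        return hashlib.sha256(functools.reduce(lambda x, y: x + y, [s.key for s in self.src], seed)).digest()

a, b = Node("LOAD", arg=0), Node("LOAD", arg=1)
assert Node("ADD", (a, b)).key == Node("ADD", (Node("LOAD", arg=0), Node("LOAD", arg=1))).key
assert Node("ADD", (a, b)).key != Node("ADD", (b, a)).key  # source order changes the digest
```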
@@ -83,14 +91,15 @@ class FlopCounter:
      return ret

  InterpretedFlopCounter: Dict[Op, Callable] = {
-   BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {arg.idx: arg.dtype.itemsize*arg.st.real_size()}),
-   BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
-   BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.real_size()}), # noqa: E501
-   UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
-   **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST}, # noqa: E501
-   **{op:lambda self,y,op=op: FlopCounter(self.shape, dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
-   **{op:lambda self,new_shape: FlopCounter(new_shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps},
-   TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
+   BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
+   BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
+   BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
+   UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
+   UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
+   **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
+   **{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
+   **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
+   TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501

  @functools.lru_cache(None)
  def get_lazyop_info(ast:LazyOp) -> FlopCounter:
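`InterpretedFlopCounter` is a table-driven fold over the AST: `run_ast` first evaluates a node's sources, then dispatches on its op to combine their flop and memory tallies (note that ReduceOps now derive the output shape from the `axis` arg instead of receiving a `new_shape`). A toy standalone version of the same interpreter pattern, with string ops standing in for the enums:

```python
from typing import Callable, Dict

table: Dict[str, Callable] = {
    "LOAD": lambda arg: {"flops": 0, "mem": arg},  # arg: bytes read
    "ADD":  lambda x, y: {"flops": x["flops"] + y["flops"] + 1, "mem": x["mem"] + y["mem"]},
}

class Node:
    def __init__(self, op, src=(), arg=None): self.op, self.src, self.arg = op, src, arg

# same shape as run_ast above: evaluate sources, append arg if present, dispatch on op
def run_ast(ast: Node):
    return table[ast.op](*([run_ast(x) for x in ast.src] + ([ast.arg] if ast.arg is not None else [])))

print(run_ast(Node("ADD", (Node("LOAD", arg=4), Node("LOAD", arg=4)))))  # {'flops': 1, 'mem': 8}
```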
@@ -98,13 +107,63 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
    def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
    return run_ast(ast)

- # **************** global state Counters ****************
+ # **************** ops in python ****************
+
+ def hook_overflow(dv, fxn):
+   def wfxn(*args):
+     try: return fxn(*args)
+     except OverflowError: return dv
+   return wfxn
+
+ python_alu = {
+   UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
+   UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
+   UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan,
+   UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
+   UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
+   UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
+   BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
+   BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
+   BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
+   BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
+   TernaryOps.WHERE: lambda x,y,z: y if x else z}
+
+ def truncate_fp16(x):
+   try:
+     x = float(x)
+     struct.pack("@e", x)
+     return x
+   except OverflowError: return math.copysign(math.inf, x)
+
+ truncate: Dict[DType, Callable] = {dtypes.bool: bool,
+   # TODO: bfloat16
+   dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+   dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+   dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+   dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
+   dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
+
+ def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))

- class GlobalCounters:
-   global_ops: ClassVar[int] = 0
-   global_mem: ClassVar[int] = 0
-   time_sum_s: ClassVar[float] = 0.0
-   kernel_count: ClassVar[int] = 0
-   mem_used: ClassVar[int] = 0 # NOTE: this is not reset
-   @staticmethod
-   def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
+ # the living definition of LazyOps
+ def verify_lazyop(*ast:LazyOp):
+   sts: Dict[LazyOp, ShapeTracker] = {}
+   def dfs(op:LazyOp, st:ShapeTracker):
+     if op in sts: return
+     for x in op.src: dfs(x, st)
+     # only reduceop is allowed to change shape, limited to turning n to 1
+     if op.op in ReduceOps:
+       expected_shape = tuple(1 if i in op.arg else s for i,s in enumerate(sts[op.src[0]].shape))
+       assert st.shape == expected_shape, f"unexpected reduceop shape {st.shape} != {expected_shape}"
+       st = ShapeTracker.from_shape(expected_shape)
+     else:
+       # movementops are pushed to the edges with LOAD
+       if op.op in BufferOps: st = op.arg.st
+       else: st = sts[op.src[0]]
+       for x in op.src: assert sts[x].shape == st.shape, f"found implicit movement op {x.op} {sts[x].shape} != {op.op} {st.shape}"
+     sts[op] = st
+   for i, out in enumerate(ast):
+     assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
+     assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
+     assert out.arg.st.size == ast[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
+     dfs(out, out.arg.st)
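`python_alu` gives every ALU op an exact pure-Python semantics and `truncate` wraps results to the target dtype's width; `exec_alu` composes the two (this is what the new pure-Python backend, `ops_python.py` in the file list above, interprets kernels with). A few worked values, assuming the 0.9.1 module layout shown in this diff:

```python
from tinygrad.ops import exec_alu, BinaryOps, UnaryOps
from tinygrad.dtype import dtypes

print(exec_alu(BinaryOps.ADD, dtypes.uint8, (200, 100)))  # 44: 300 wrapped by ctypes.c_uint8
print(exec_alu(BinaryOps.IDIV, dtypes.int32, (7, -2)))    # -3: int(7 / -2) truncates toward zero
print(exec_alu(UnaryOps.RECIP, dtypes.float32, (0.0,)))   # inf: the x == 0 special case
```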
tinygrad/renderer/__init__.py ADDED
@@ -0,0 +1,65 @@
+ from typing import Optional, List, Tuple, Dict
+ import functools
+ from dataclasses import dataclass
+ from tinygrad.helpers import getenv, to_function_name
+ from tinygrad.codegen.uops import UOpGraph
+ from tinygrad.shape.symbolic import sym_infer, sint, Variable
+ from tinygrad.dtype import DType
+
+ @dataclass(frozen=True)
+ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
+   dims: Tuple[int,int,int] # N, M, K
+   dtype_in: DType # dtype for A and B
+   dtype_out: DType # dtype for C and D
+   threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
+   thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
+   thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
+   def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
+   def num_upcasts(self): return len(self.thread_local_aliases[0]) - len(self.threads)
+
+ @dataclass(frozen=True)
+ class Program:
+   name:str
+   src:str
+   dname:str
+   global_size:Optional[List[int]]=None
+   local_size:Optional[List[int]]=None
+   uops:Optional[UOpGraph]=None
+   op_estimate:sint=0
+   mem_estimate:sint=0
+
+   @functools.cached_property
+   def vars(self) -> List[Variable]: return [] if self.uops is None else self.uops.vars()
+
+   @functools.cached_property
+   def globals(self) -> List[Tuple[int, bool]]: return [] if self.uops is None else self.uops.globals()
+
+   @functools.cached_property
+   def outcount(self) -> int: return sum(x[1] for x in self.globals)
+
+   @functools.cached_property
+   def function_name(self) -> str: return to_function_name(self.name)
+
+   def launch_dims(self, var_vals:Dict[Variable, int]):
+     global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
+     local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
+     return global_size, local_size
+
+ class Renderer:
+   device: str = ""
+   suffix: str = ""
+   # TODO: make this generic with a list of supported types
+   supports_float4: bool = True
+   has_local: bool = True
+   has_shared: bool = True
+   # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
+   global_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
+   local_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
+   shared_max: int = 32768
+   tensor_cores: List[TensorCore] = []
+   @functools.cached_property
+   def tc_opt(self): return getenv("TC_OPT")
+   @functools.cached_property
+   def tc(self): return getenv("TC", 1)
+
+   def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")
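`Program.launch_dims` is where symbolic launch sizes become concrete: each `sint` entry is resolved against the launch-time variable bindings with `sym_infer`, exactly as in the method above. A sketch of that resolution step in isolation, using the `Variable`/`sym_infer` API this file imports from `tinygrad.shape.symbolic`:

```python
from tinygrad.shape.symbolic import Variable, sym_infer

i = Variable("i", 1, 10)      # a symbolic dimension bounded to [1, 10]
global_size = [i * 4, 16, 1]  # mixed symbolic/concrete sizes, as Program.global_size allows
print([sym_infer(sz, {i: 3}) for sz in global_size])  # [12, 16, 1]
```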
tinygrad/renderer/assembly.py ADDED
@@ -0,0 +1,269 @@
+ from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
+ import struct, math
+ from collections import defaultdict
+ from tinygrad.helpers import DEBUG
+ from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
+ from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
+ from tinygrad.codegen.uops import UOps, UOp, UOpGraph, PatternMatcher, UPat
+ from tinygrad.renderer import Renderer, TensorCore
+
+ def render_val(x, dtype):
+   if dtypes.is_float(dtype):
+     if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
+     if dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
+     return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
+   return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
+
+ class PTXRenderer(Renderer):
+   device = "CUDA"
+   suffix = "PTX"
+   global_max = (2147483647, 65535, 65535)
+   local_max = (1024, 1024, 64)
+   shared_max = 49152
+   tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
+   def __init__(self, arch:str): self.tensor_cores = PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
+
+   # language options
+   kernel_prefix = """.version VERSION
+ .target TARGET
+ .address_size 64
+ .visible .entry"""
+   barrier = "bar.sync\t0;"
+   gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
+   gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
+   lid = [f'%tid.{chr(120+i)}' for i in range(3)]
+   asm_for_op: Dict[Op, Callable] = {
+     UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"sub.{name} {d}, 0, {a};" if dtypes.is_unsigned(dt) \
+       else f"neg.{name} {d}, {a};",
+     UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
+     UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
+     UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
+     BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
+     BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
+     BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
+     BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
+     BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
+     BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
+     BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
+     BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
+     TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
+     TernaryOps.WHERE: lambda d,a,b,c,dt,name:
+       f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
+   }
+   supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
+                              TernaryOps.WHERE]
+   # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
+   types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
+                               dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
+                               dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
+
+   mem_types: Dict[DType, str] = types.copy()
+   mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})
+
+   const_requires_mov: List[DType] = [dtypes.half, dtypes.bool]
+
+   def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
+     val = render_val(x, dtype)
+     if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
+     return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val
+
+   def render_local(self, dest, name, size, dtype) -> List[str]:
+     return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]
+
+   def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]
+
+   def render_bra(self, b1, pred=None) -> List[str]: return [f"@{pred} bra {b1};"] if pred else [f"bra {b1};"]
+
+   def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
+     assert dtype != dtypes.bool
+     if gate: return [f"@{gate} ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
+     return [f"ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];"]
+
+   def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
+     return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_types[dtype]} [{loc}+{offset}], {val};"]
+
+   def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
+     if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
+     if atype == dtypes.bool: return[f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
+     if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
+     rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
+            '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
+     return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]
+
+   def render_kernel(self, kernel, function_name, bufs, regs) -> str:
+     kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
+     def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
+     return (f"{self.kernel_prefix} {function_name}(\n\t" +
+             ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
+             '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
+             "\n}")
+
+   def render(self, name:str, uops:UOpGraph) -> str:
+     kernel:List[str] = []
+     bufs = []
+
+     uops.linearize(ptx_matcher)
+     if DEBUG >= 4: uops.print()
+
+     def kk(*s: str): kernel.append("\n".join(s))
+
+     c: DefaultDict[str, int] = defaultdict(int)
+     r: Dict[UOp, Union[List[str], str]] = {}
+     def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
+       nonlocal c, r
+       prefix += f"_{dtype if dtype is not None else self.types[cast(DType, cast(UOp, u).dtype)]}_"
+       c[prefix] += 1
+       if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
+       return f"%{prefix}{c[prefix]-1}"
+
+     def const(x:ConstType, dtype:DType, mov=False):
+       if mov or dtype in self.const_requires_mov:
+         kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
+         return out
+       return self.render_const(x, dtype)
+
+     def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
+       if atype == dtype or isinstance(atype, PtrDType):
+         if u: r[u] = a
+         return a
+       kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
+       return ret
+
+     for u in uops:
+       uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
+       if uop is UOps.IF:
+         assert src[0].dtype is not None
+         kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
+       elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
+       elif uop is UOps.ENDRANGE:
+         kk(self.asm_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
+            self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
+         kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
+       elif uop is UOps.ENDIF:
+         kk(f"IF_{r[src[0].src[0]][1:]}_{cast(List, uops._uops).index(src[0])}:")
+       elif uop is UOps.STORE:
+         assert src[0].dtype is not None and src[2].dtype is not None
+         assert src[0].dtype == dtypes.int64, "store isn't int64"
+         assert src[1].op is UOps.CONST, f"store isn't const {u}"
+         mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
+         if src[2].dtype.count > 1:
+           kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
+              f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
+         else:
+           kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
+       else:
+         assert dtype is not None, f"None dtype for uop {uop}"
+         if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
+         elif uop is UOps.ALU:
+           assert src[0].dtype is not None
+           if args is BinaryOps.CMPLT or args is BinaryOps.CMPNE:
+             # pass in the other dtype here
+             kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], src[0].dtype, self.types[src[0].dtype]))
+           else:
+             kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], dtype, self.types[dtype]))
+         elif uop is UOps.DEFINE_ACC:
+           if dtype.count > 1:
+             r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
+             for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
+           else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
+         elif uop is UOps.SPECIAL:
+           assert args[1][0] != "i", "idx not supported"
+           kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
+           r[u] = "%" + args[1]
+           kernel = [f".reg .u32 %{args[1]};"] + kernel
+         elif uop is UOps.CONST:
+           if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
+           else: r[u] = const(args, dtype, mov=True)
+         elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
+         elif uop is UOps.LOAD:
+           assert src[0].dtype == dtypes.int64, "load isn't int64"
+           assert src[1].op is UOps.CONST, f"load isn't const {u}"
+           mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
+           if dtype.count > 1:
+             r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
+             if(len(src)>3):
+               for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
+             kk((f"@{r[src[2]]}"if len(src) > 3 else "")
+                + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
+           else:
+             kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if len(src) > 3 else None,
+                                  alt=r[src[3]] if len(src) > 3 else None, ss=mem_type, offset=src[1].arg))
+         elif uop is UOps.PHI:
+           if dtype.count > 1:
+             for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
+           else:
+             kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
+           r[u] = r[src[0]]
+         elif uop in {UOps.CAST, UOps.BITCAST}:
+           assert src[0].dtype is not None
+           if dtype.count>1: r[u] = [r[x] for x in src] # type: ignore
+           else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
+         elif uop is UOps.DEFINE_LOCAL:
+           # TODO: we should sum these, and fetch 0xC000 from somewhere
+           assert args[1]*dtype.itemsize <= 0xC000, "too large local"
+           kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
+         elif uop is UOps.DEFINE_VAR:
+           bufs.append((args.expr, dtype))
+           r[u] = f"%{args.expr}"
+           kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
+         elif uop is UOps.DEFINE_GLOBAL:
+           bufs.append((nm:=f"data{args[0]}", dtype))
+           r[u] = f"%{nm}"
+           dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
+           kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
+         elif uop is UOps.WMMA:
+           wmma = []
+           for vv in src[:2]:
+             for i in range(0, len(r[vv]), 2):
+               wmma.append(ssa("wmma", dtype="b32"))
+               kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
+           r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
+           kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
+  {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[src[2]])}}};')
+         else: raise NotImplementedError(f"no code for {uop}")
+
+     return self.render_kernel(kernel, name, bufs, c.items())
+
+ ptx_matcher = PatternMatcher([
+   (UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
+         src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
+    lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL)),
+   (UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
+         src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
+    lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR)),
+   (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
+   (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
+    lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
+   (UPat(UOps.ALU, BinaryOps.ADD,
+         [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
+    lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
+   *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
+      lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),)))
+     for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
+   (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
+    lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
+   (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
+    lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.uint8, root.src, root.arg),))),
+   (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
+    lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
+   (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
+    lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
+   (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
+    lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
+   # ptr_ar (load/store)
+   (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+         UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
+    lambda root, alu, const: UOp(root.op, root.dtype,
+                                 (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+                                  UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
+   (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+         UPat(UOps.CONST, name="const"))),
+    lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64),
+                                                  UOp.const(dtypes.int64, const.arg * root.src[0].dtype.itemsize),
+                                                  )+root.src[2:])),
+   (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+         UPat(name="alu"))), # no const here
+    lambda root, alu: UOp(root.op, root.dtype,
+                          (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+                           UOp.const(dtypes.int64, 0))+root.src[2:])),
+ ])
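`render_val` above emits PTX immediates: floats are rendered in the PTX `0f`/`0d`/`0x` exact-bits form, with `[::-1]` reversing `struct.pack`'s little-endian byte order so the hex digits read most-significant-byte first. The float32 case, reproduced standalone:

```python
import struct

def render_f32(x: float) -> str:  # same formatting as render_val's float32 branch
    return "0f%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])

print(render_f32(1.0))   # 0f3F800000
print(render_f32(-2.5))  # 0fC0200000
```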