tinygrad 0.8.0-py3-none-any.whl → 0.9.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +253 -225
- tinygrad/codegen/linearizer.py +398 -436
- tinygrad/codegen/uops.py +451 -0
- tinygrad/device.py +268 -274
- tinygrad/dtype.py +56 -40
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +198 -0
- tinygrad/engine/realize.py +192 -0
- tinygrad/engine/schedule.py +370 -0
- tinygrad/engine/search.py +199 -0
- tinygrad/{mlops.py → function.py} +40 -32
- tinygrad/helpers.py +144 -46
- tinygrad/lazy.py +143 -242
- tinygrad/multi.py +173 -0
- tinygrad/nn/__init__.py +180 -9
- tinygrad/nn/datasets.py +8 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +87 -19
- tinygrad/ops.py +104 -45
- tinygrad/renderer/__init__.py +65 -0
- tinygrad/renderer/assembly.py +269 -0
- tinygrad/renderer/cstyle.py +308 -210
- tinygrad/renderer/llvmir.py +119 -124
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +13403 -0
- tinygrad/runtime/autogen/comgr.py +891 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5893 -0
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33597 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +56 -0
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +39 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +187 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +550 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +129 -37
- tinygrad/runtime/ops_disk.py +111 -43
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +41 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +625 -0
- tinygrad/runtime/ops_python.py +208 -0
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +46 -107
- tinygrad/shape/symbolic.py +99 -98
- tinygrad/shape/view.py +162 -45
- tinygrad/tensor.py +2492 -483
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
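Note on upgrading: the moves listed above (tinygrad/jit.py → tinygrad/engine/jit.py, tinygrad/features/search.py → tinygrad/engine/search.py, tinygrad/mlops.py → tinygrad/function.py, plus the removal of the cpu/torch/hip runtimes) mean import paths change between 0.8.0 and 0.9.1. A hedged compatibility sketch, assuming TinyJit keeps its name inside the moved module (verify against the 0.9.1 sources):

# sketch only: the import path follows the file move listed above; the TinyJit name is assumed unchanged
try:
  from tinygrad.engine.jit import TinyJit   # 0.9.1 layout (tinygrad/engine/jit.py)
except ImportError:
  from tinygrad.jit import TinyJit          # 0.8.0 layout (tinygrad/jit.py)
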
tinygrad/ops.py
CHANGED
@@ -1,32 +1,37 @@
 from __future__ import annotations
-from typing import
-import functools
+from typing import Union, Tuple, Any, List, Dict, Callable
+import functools, hashlib, math, operator, ctypes, struct
 from enum import Enum, auto
-from tinygrad.helpers import prod, dedup
-from tinygrad.dtype import dtypes, DType
-from tinygrad.shape.symbolic import Variable
 from dataclasses import dataclass
+from tinygrad.helpers import prod, dedup
+from tinygrad.dtype import dtypes, DType, ConstType
+from tinygrad.shape.symbolic import Variable, sint
+from tinygrad.shape.shapetracker import ShapeTracker
 
 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
 # NOTE: MOD, CMPLT don't have to be implemented on vectors, just scalars
 # NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
-class UnaryOps(Enum):
+class UnaryOps(Enum):
+  """A -> A (elementwise)"""
+  EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto(); RECIP = auto() # noqa: E702
 class BinaryOps(Enum):
-
-
-
+  """A + A -> A (elementwise)"""
+  ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
+  SHR = auto(); SHL = auto() # noqa: E702
+class TernaryOps(Enum):
+  """A + A + A -> A (elementwise)"""
+  WHERE = auto(); MULACC = auto() # noqa: E702
+class ReduceOps(Enum):
+  """A -> B (reduce)"""
+  SUM = auto(); MAX = auto() # noqa: E702
 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
-
-class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
-class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
+class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
 
-Op = Union[UnaryOps, BinaryOps, ReduceOps,
-OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
+Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
 
-
-
-from tinygrad.lazy import LazyBuffer
+# do not preserve f(0) = 0
+UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
 
 @dataclass(frozen=True)
 class MemBuffer:
@@ -36,17 +41,10 @@ class MemBuffer:
 
 @dataclass(frozen=True)
 class ConstBuffer:
-  val:
+  val: ConstType | Variable
   dtype: DType
   st: ShapeTracker
 
-@dataclass(frozen=True)
-class ScheduleItem:
-  ast: LazyOp
-  out: LazyBuffer
-  inputs: Tuple[LazyBuffer, ...]
-  var_vals: Dict[Variable, int]
-
 @dataclass(frozen=True, eq=False)
 class LazyOp:
   op: Op
@@ -61,20 +59,30 @@ class LazyOp:
   def __eq__(self, x): return self.cached_compare(x, context={})
   def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
   @functools.cached_property
+  def dtype(self) -> DType:
+    if self.op in BufferOps: return self.arg.dtype
+    if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
+    return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
+
+  @functools.cached_property
+  def key(self) -> bytes:
+    return hashlib.sha256(functools.reduce(lambda x,y: x+y, [s.key for s in self.src], str((self.op, self.arg)).encode())).digest()
+  @functools.cached_property
   def hash(self): return hash((self.op, self.src, self.arg))
   def __hash__(self): return self.hash
  @functools.cached_property
   def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
   def vars(self) -> List[Variable]:
-
+    extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
+    const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
+    return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)
 
 # **************** independent FlopCounter ****************
 
 @dataclass
 class FlopCounter:
   shape: Tuple[int, ...]
-
-  flops: int
+  flops: sint
   mem: Dict[int, int]
   @property
   def mem_estimate(self): return sum(self.mem.values())
@@ -83,14 +91,15 @@ class FlopCounter:
     return ret
 
 InterpretedFlopCounter: Dict[Op, Callable] = {
-  BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape,
-  BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape,
-  BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape,
-  UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape,
-
-  **{op:lambda self
-  **{op:lambda self,
-
+  BufferOps.LOAD: lambda arg: FlopCounter(arg.st.shape, 0, {arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
+  BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, 0, {}),
+  BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize * arg.st.real_size()}),
+  UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # cast uses no flops
+  UnaryOps.BITCAST: lambda self,arg: FlopCounter(self.shape, self.consume_flops(), self.mem), # bitcast uses no flops
+  **{op:lambda self: FlopCounter(self.shape, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op not in {UnaryOps.CAST, UnaryOps.BITCAST}}, # noqa: E501
+  **{op:lambda self,y: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
+  **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
+  TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501
 
 @functools.lru_cache(None)
 def get_lazyop_info(ast:LazyOp) -> FlopCounter:
@@ -98,13 +107,63 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
   def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
   return run_ast(ast)
 
-# ****************
+# **************** ops in python ****************
+
+def hook_overflow(dv, fxn):
+  def wfxn(*args):
+    try: return fxn(*args)
+    except OverflowError: return dv
+  return wfxn
+
+python_alu = {
+  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
+  UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
+  UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan,
+  UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
+  UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
+  UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
+  BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
+  BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
+  BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
+  BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
+  TernaryOps.WHERE: lambda x,y,z: y if x else z}
+
+def truncate_fp16(x):
+  try:
+    x = float(x)
+    struct.pack("@e", x)
+    return x
+  except OverflowError: return math.copysign(math.inf, x)
+
+truncate: Dict[DType, Callable] = {dtypes.bool: bool,
+  # TODO: bfloat16
+  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
+  dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
+
+def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
 
-
-
-
-
-
-
-
-
+# the living definition of LazyOps
+def verify_lazyop(*ast:LazyOp):
+  sts: Dict[LazyOp, ShapeTracker] = {}
+  def dfs(op:LazyOp, st:ShapeTracker):
+    if op in sts: return
+    for x in op.src: dfs(x, st)
+    # only reduceop is allowed to change shape, limited to turning n to 1
+    if op.op in ReduceOps:
+      expected_shape = tuple(1 if i in op.arg else s for i,s in enumerate(sts[op.src[0]].shape))
+      assert st.shape == expected_shape, f"unexpected reduceop shape {st.shape} != {expected_shape}"
+      st = ShapeTracker.from_shape(expected_shape)
+    else:
+      # movementops are pushed to the edges with LOAD
+      if op.op in BufferOps: st = op.arg.st
+      else: st = sts[op.src[0]]
+      for x in op.src: assert sts[x].shape == st.shape, f"found implicit movement op {x.op} {sts[x].shape} != {op.op} {st.shape}"
+    sts[op] = st
+  for i, out in enumerate(ast):
+    assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
+    assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
+    assert out.arg.st.size == ast[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
+    dfs(out, out.arg.st)

tinygrad/renderer/__init__.py
ADDED
@@ -0,0 +1,65 @@
+from typing import Optional, List, Tuple, Dict
+import functools
+from dataclasses import dataclass
+from tinygrad.helpers import getenv, to_function_name
+from tinygrad.codegen.uops import UOpGraph
+from tinygrad.shape.symbolic import sym_infer, sint, Variable
+from tinygrad.dtype import DType
+
+@dataclass(frozen=True)
+class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
+  dims: Tuple[int,int,int] # N, M, K
+  dtype_in: DType # dtype for A and B
+  dtype_out: DType # dtype for C and D
+  threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
+  thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
+  thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
+  def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
+  def num_upcasts(self): return len(self.thread_local_aliases[0]) - len(self.threads)
+
+@dataclass(frozen=True)
+class Program:
+  name:str
+  src:str
+  dname:str
+  global_size:Optional[List[int]]=None
+  local_size:Optional[List[int]]=None
+  uops:Optional[UOpGraph]=None
+  op_estimate:sint=0
+  mem_estimate:sint=0
+
+  @functools.cached_property
+  def vars(self) -> List[Variable]: return [] if self.uops is None else self.uops.vars()
+
+  @functools.cached_property
+  def globals(self) -> List[Tuple[int, bool]]: return [] if self.uops is None else self.uops.globals()
+
+  @functools.cached_property
+  def outcount(self) -> int: return sum(x[1] for x in self.globals)
+
+  @functools.cached_property
+  def function_name(self) -> str: return to_function_name(self.name)
+
+  def launch_dims(self, var_vals:Dict[Variable, int]):
+    global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else None
+    local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
+    return global_size, local_size
+
+class Renderer:
+  device: str = ""
+  suffix: str = ""
+  # TODO: make this generic with a list of supported types
+  supports_float4: bool = True
+  has_local: bool = True
+  has_shared: bool = True
+  # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
+  global_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
+  local_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
+  shared_max: int = 32768
+  tensor_cores: List[TensorCore] = []
+  @functools.cached_property
+  def tc_opt(self): return getenv("TC_OPT")
+  @functools.cached_property
+  def tc(self): return getenv("TC", 1)
+
+  def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")

tinygrad/renderer/assembly.py
ADDED
@@ -0,0 +1,269 @@
+from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
+import struct, math
+from collections import defaultdict
+from tinygrad.helpers import DEBUG
+from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
+from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
+from tinygrad.codegen.uops import UOps, UOp, UOpGraph, PatternMatcher, UPat
+from tinygrad.renderer import Renderer, TensorCore
+
+def render_val(x, dtype):
+  if dtypes.is_float(dtype):
+    if dtype == dtypes.double: return "0d%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
+    if dtype == dtypes.half: return "0x%02X%02X" % tuple(struct.pack("e",x)[::-1])
+    return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
+  return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
+
+class PTXRenderer(Renderer):
+  device = "CUDA"
+  suffix = "PTX"
+  global_max = (2147483647, 65535, 65535)
+  local_max = (1024, 1024, 64)
+  shared_max = 49152
+  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
+  def __init__(self, arch:str): self.tensor_cores = PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
+
+  # language options
+  kernel_prefix = """.version VERSION
+.target TARGET
+.address_size 64
+.visible .entry"""
+  barrier = "bar.sync\t0;"
+  gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
+  gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
+  lid = [f'%tid.{chr(120+i)}' for i in range(3)]
+  asm_for_op: Dict[Op, Callable] = {
+    UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"sub.{name} {d}, 0, {a};" if dtypes.is_unsigned(dt) \
+      else f"neg.{name} {d}, {a};",
+    UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};",
+    UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
+    UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
+    BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};",
+    BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
+    BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
+    BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
+    BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};",
+    BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
+    BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
+    BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};",
+    TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};",
+    TernaryOps.WHERE: lambda d,a,b,c,dt,name:
+      f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
+  }
+  supports_half: List[Op] = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
+                             TernaryOps.WHERE]
+  # HACK: Use s16 and u16 for int8 and uint8 buffers. This can be wrong in cast.
+  types: Dict[DType, str] = { dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
+                              dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
+                              dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64", dtypes.bool: "pred" }
+
+  mem_types: Dict[DType, str] = types.copy()
+  mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"})
+
+  const_requires_mov: List[DType] = [dtypes.half, dtypes.bool]
+
+  def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
+    val = render_val(x, dtype)
+    if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
+    return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val
+
+  def render_local(self, dest, name, size, dtype) -> List[str]:
+    return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]
+
+  def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]
+
+  def render_bra(self, b1, pred=None) -> List[str]: return [f"@{pred} bra {b1};"] if pred else [f"bra {b1};"]
+
+  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
+    assert dtype != dtypes.bool
+    if gate: return [f"@{gate} ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
+    return [f"ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];"]
+
+  def render_store(self, loc, val, dtype, gate=None, ss="", offset=0) -> List[str]:
+    return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_types[dtype]} [{loc}+{offset}], {val};"]
+
+  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
+    if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
+    if atype == dtypes.bool: return[f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"]
+    if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"]
+    rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
+           '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
+    return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]
+
+  def render_kernel(self, kernel, function_name, bufs, regs) -> str:
+    kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
+    def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
+    return (f"{self.kernel_prefix} {function_name}(\n\t" +
+            ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
+            '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
+            "\n}")
+
+  def render(self, name:str, uops:UOpGraph) -> str:
+    kernel:List[str] = []
+    bufs = []
+
+    uops.linearize(ptx_matcher)
+    if DEBUG >= 4: uops.print()
+
+    def kk(*s: str): kernel.append("\n".join(s))
+
+    c: DefaultDict[str, int] = defaultdict(int)
+    r: Dict[UOp, Union[List[str], str]] = {}
+    def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
+      nonlocal c, r
+      prefix += f"_{dtype if dtype is not None else self.types[cast(DType, cast(UOp, u).dtype)]}_"
+      c[prefix] += 1
+      if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
+      return f"%{prefix}{c[prefix]-1}"
+
+    def const(x:ConstType, dtype:DType, mov=False):
+      if mov or dtype in self.const_requires_mov:
+        kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype]))))
+        return out
+      return self.render_const(x, dtype)
+
+    def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
+      if atype == dtype or isinstance(atype, PtrDType):
+        if u: r[u] = a
+        return a
+      kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast))
+      return ret
+
+    for u in uops:
+      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
+      if uop is UOps.IF:
+        assert src[0].dtype is not None
+        kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)))
+      elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
+      elif uop is UOps.ENDRANGE:
+        kk(self.asm_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
+           self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
+        kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
+      elif uop is UOps.ENDIF:
+        kk(f"IF_{r[src[0].src[0]][1:]}_{cast(List, uops._uops).index(src[0])}:")
+      elif uop is UOps.STORE:
+        assert src[0].dtype is not None and src[2].dtype is not None
+        assert src[0].dtype == dtypes.int64, "store isn't int64"
+        assert src[1].op is UOps.CONST, f"store isn't const {u}"
+        mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
+        if src[2].dtype.count > 1:
+          kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
+             f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
+        else:
+          kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
+      else:
+        assert dtype is not None, f"None dtype for uop {uop}"
+        if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
+        elif uop is UOps.ALU:
+          assert src[0].dtype is not None
+          if args is BinaryOps.CMPLT or args is BinaryOps.CMPNE:
+            # pass in the other dtype here
+            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], src[0].dtype, self.types[src[0].dtype]))
+          else:
+            kk(self.asm_for_op[args](ssa("alu", u), *[r[x] for x in src], dtype, self.types[dtype]))
+        elif uop is UOps.DEFINE_ACC:
+          if dtype.count > 1:
+            r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
+            for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
+          else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
+        elif uop is UOps.SPECIAL:
+          assert args[1][0] != "i", "idx not supported"
+          kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
+          r[u] = "%" + args[1]
+          kernel = [f".reg .u32 %{args[1]};"] + kernel
+        elif uop is UOps.CONST:
+          if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
+          else: r[u] = const(args, dtype, mov=True)
+        elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
+        elif uop is UOps.LOAD:
+          assert src[0].dtype == dtypes.int64, "load isn't int64"
+          assert src[1].op is UOps.CONST, f"load isn't const {u}"
+          mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
+          if dtype.count > 1:
+            r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
+            if(len(src)>3):
+              for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};")
+            kk((f"@{r[src[2]]}"if len(src) > 3 else "")
+               + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+{src[1].arg}];")
+          else:
+            kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if len(src) > 3 else None,
+                                 alt=r[src[3]] if len(src) > 3 else None, ss=mem_type, offset=src[1].arg))
+        elif uop is UOps.PHI:
+          if dtype.count > 1:
+            for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
+          else:
+            kk(f"mov.b{self.types[dtype][1:]} {r[src[0]]}, {r[src[1]]};")
+          r[u] = r[src[0]]
+        elif uop in {UOps.CAST, UOps.BITCAST}:
+          assert src[0].dtype is not None
+          if dtype.count>1: r[u] = [r[x] for x in src] # type: ignore
+          else: _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
+        elif uop is UOps.DEFINE_LOCAL:
+          # TODO: we should sum these, and fetch 0xC000 from somewhere
+          assert args[1]*dtype.itemsize <= 0xC000, "too large local"
+          kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
+        elif uop is UOps.DEFINE_VAR:
+          bufs.append((args.expr, dtype))
+          r[u] = f"%{args.expr}"
+          kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
+        elif uop is UOps.DEFINE_GLOBAL:
+          bufs.append((nm:=f"data{args[0]}", dtype))
+          r[u] = f"%{nm}"
+          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
+          kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
+        elif uop is UOps.WMMA:
+          wmma = []
+          for vv in src[:2]:
+            for i in range(0, len(r[vv]), 2):
+              wmma.append(ssa("wmma", dtype="b32"))
+              kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
+          r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
+          kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
+            {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[src[2]])}}};')
+        else: raise NotImplementedError(f"no code for {uop}")
+
+    return self.render_kernel(kernel, name, bufs, c.items())
+
+ptx_matcher = PatternMatcher([
+  (UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
+        src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
+   lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHL)),
+  (UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
+        src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
+   lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(dtypes.int, int(math.log2(const.arg)))), BinaryOps.SHR)),
+  (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
+  (UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
+   lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
+  (UPat(UOps.ALU, BinaryOps.ADD,
+        [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
+   lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
+  *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
+     lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),)))
+    for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
+  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
+   lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
+  (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
+   lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.uint8, root.src, root.arg),))),
+  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
+   lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
+  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
+   lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
+  (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
+   lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
+  # ptr_ar (load/store)
+  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+        UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
+   lambda root, alu, const: UOp(root.op, root.dtype,
+                                (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+                                 UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
+  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+        UPat(UOps.CONST, name="const"))),
+   lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64),
+                                                 UOp.const(dtypes.int64, const.arg * root.src[0].dtype.itemsize),
+                                                 )+root.src[2:])),
+  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+        UPat(name="alu"))), # no const here
+   lambda root, alu: UOp(root.op, root.dtype,
+                         (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+                          UOp.const(dtypes.int64, 0))+root.src[2:])),
+])