tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
--- tinygrad/engine/search.py
+++ tinygrad/codegen/opt/search.py
@@ -1,44 +1,45 @@
-from typing import cast,
+from typing import cast, Callable
 import itertools, functools, random, math, time, multiprocessing, traceback, signal, atexit
 from collections import defaultdict
 from dataclasses import replace
-from tinygrad.ops import UOp, Ops, Variable, sym_infer
+from tinygrad.uop.ops import UOp, Ops, Variable, sym_infer, AxisType
 from tinygrad.device import Device, Buffer, Compiler
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
 from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE
 from tinygrad.dtype import ImageDType, PtrDType
-from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
+from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps, KernelOptError
 from tinygrad.tensor import Tensor
-from tinygrad.engine.realize import CompiledRunner
+from tinygrad.engine.realize import CompiledRunner, get_program
 from tinygrad.renderer import ProgramSpec
 
-actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(
+actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(8)]
 actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
 actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
 actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
 actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)]
 if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)]
 actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)]
-actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0))]
-
+actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0, getenv("TC", 1)))]
+# covers resnet kernels (3 global * 3 reduce)
+actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2), getenv("TC", 1))) for axis in range(9)]
 actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
 
-def
-  test_global_size
+def get_test_global_size(global_size, max_global_size, var_vals):
+  test_global_size = [sym_infer(sz, var_vals) for sz in global_size]
+  input_size = prod(test_global_size)
   while prod(test_global_size) > max_global_size:
     for j in range(len(global_size)-1,-1,-1):
       if test_global_size[j] > 16:
         test_global_size[j] //= 2
-        factor *= 2
         break
-  return test_global_size,
+  return test_global_size, input_size / prod(test_global_size)
 
-def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:
-                  max_global_size:
+def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:float|None=None,
+                  allow_test_size:int=True, max_global_size:int|None=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
   factor = 1
-  if p.global_size is not None and max_global_size is not None:
-    global_size, factor =
+  if allow_test_size and p.global_size is not None and max_global_size is not None:
+    global_size, factor = get_test_global_size(p.global_size, max_global_size, var_vals)
     p = replace(p, global_size=global_size)
   try: car = CompiledRunner(p, precompiled=lib)
   except AssertionError: return [math.inf] * cnt
@@ -56,16 +57,18 @@ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbuf
 class TimeoutException(Exception): pass
 def timeout_handler(signum, frame): raise TimeoutException()
 
-def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int,
+def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, tuple[ProgramSpec, bytes, float]|None]:
   if hasattr(signal, "alarm"):
     signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
     # set timeout
     signal.alarm(getenv("BEAM_TIMEOUT_SEC", 10))
   ret = None
   try:
-    p = x[1].
+    p = get_program(x[1].copy().get_optimized_ast(name_override="test"), x[1].opts)
     assert p.uops is not None, "uop list wasn't generated?"
-    if len(p.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0:
+    if len(p.uops) >= (uops_max:=getenv("BEAM_UOPS_MAX", 3000)) > 0:
+      if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many uops. {len(p.uops)=}, {uops_max=}")
+      raise RuntimeError("too many uops")
     st = time.perf_counter()
     prog = compiler.compile(p.src)
     et = time.perf_counter() - st
@@ -78,10 +81,12 @@ def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tup
   if hasattr(signal, "alarm"): signal.alarm(0)
   return x[0], ret
 
-# workers should ignore ctrl c
-def _init_worker():
+# workers should not open devices and should ignore ctrl c and should not launch VIZ
+def _init_worker():
+  Context(ALLOW_DEVICE_USAGE=0, VIZ=0, TRACK_MATCH_STATS=0).__enter__()
+  signal.signal(signal.SIGINT, signal.SIG_IGN)
 
-def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() for buf in bufs]
+def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() if buf is not None else buf for buf in bufs]
 
 # *** external API ***
 
@@ -89,39 +94,46 @@ def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_
 def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
   bufsts: defaultdict[int, list[UOp]] = defaultdict(list)
   for x in lin.bufs:
-    if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
-
+    if x.src[0].base.op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].base.arg].append(x)
+  # TODO: Nones are staying in here if buffers are optimized out!
+  # TODO: add a test for this
+  rawbufs: list[Buffer|None] = [None]*(max(bufsts)+1)
   for k,lx in bufsts.items():
     buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
     assert isinstance(dtype, (PtrDType, ImageDType))
     if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
     buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
     rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype)
-  assert all(r is not None for r in rawbufs)
+  #assert all(r is not None for r in rawbufs)
   return cast(list[Buffer], rawbufs)
 
 # get dictionary of all possible actions
-def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
+def get_kernel_actions(lin:Kernel, include_0=True, candidates:list[Opt]|None=None) -> dict[int, Kernel]:
   acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
-  kernel_actions = actions.copy()
+  kernel_actions = (actions if candidates is None else candidates).copy()
 
   if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
     for i, action in enumerate(kernel_actions):
       if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
         # replace every tc_action with default tc with one tc_action for each available tc
-        kernel_actions[i:i+1] =
+        kernel_actions[i:i+1] = \
+          [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1], tc_arg[2])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]
 
   for i,a in enumerate(kernel_actions):
     if a.axis is not None and a.op is not OptOps.TC:
-
+      try: ax = lin.real_axis(a.op, a.axis)
+      except KernelOptError: continue
+      if (ax >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, a.axis, 0) in kernel_actions): continue
    lin2 = lin.copy()
    try:
      lin2.apply_opt(a)
      up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1
-      for s,c in zip(lin2.full_shape, lin2.
-        if c in
-        elif c in
-      if up//tc_up > max_up or lcl > max_lcl:
+      for s,c in zip(lin2.full_shape, lin2.axis_types):
+        if c in (AxisType.UPCAST, AxisType.UNROLL): up *= s
+        elif c in (AxisType.LOCAL, AxisType.GROUP_REDUCE): lcl *= s
+      if up//tc_up > max_up or lcl > max_lcl:
+        if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many upcast/local. {up//tc_up=}, {max_up=}, {lcl=}, {max_lcl=}")
+        continue
      acted_lins[i+1] = lin2
    except KernelOptError: pass
  return acted_lins
@@ -138,7 +150,7 @@ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True,
   beam: list[tuple[Kernel, float]] = [(lin, float("inf"))]
   seen_libs = set()
 
-  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL"} else 0
+  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL", "HIP"} else 0
   if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
     beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
     @atexit.register
@@ -166,8 +178,12 @@ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True,
         least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
         if least_compute_ops*1000 < this_compute_ops: continue
         seen_libs.add(lib)
-        try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0,
-
+        try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0,
+                                 allow_test_size=allow_test_size, clear_l2=hasattr(dev, 'invalidate_caches'))
+        except Exception as e:
+          if BEAM_DEBUG: print(f"BEAM failed for opts: {acted_lins[i].applied_opts}\n{e}")
+          if isinstance(e, RuntimeError): continue
+          raise
         timed_lins.append((acted_lins[i], min(tms)))
         if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
         elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
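Note (not part of the diff): the beam search above is normally driven through environment variables rather than by calling beam_search() directly. A minimal sketch, assuming tinygrad 0.11.0 is installed, using the BEAM knob that this module reads via getenv:

```python
# minimal sketch: enable BEAM kernel search for a matmul.
# BEAM must be set before tinygrad is imported so getenv() picks it up.
import os
os.environ["BEAM"] = "2"          # beam width; unset/0 disables the search
from tinygrad import Tensor

a, b = Tensor.rand(256, 256), Tensor.rand(256, 256)
(a @ b).realize()                 # kernels for this matmul are beam-searched before compile
```

BEAM_DEBUG, BEAM_UOPS_MAX, and BEAM_TIMEOUT_SEC seen in the hunks above tune the same search.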
--- /dev/null
+++ tinygrad/codegen/opt/swizzler.py
@@ -0,0 +1,134 @@
+from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, resolve, sint
+from tinygrad.helpers import all_same, prod, unwrap, colored
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
+from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
+from tinygrad.dtype import ImageDType, dtypes
+
+merge_views = PatternMatcher([
+  # merge adjacent views
+  (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="v1"),), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
+  # replace MovementOps with VIEW
+  (UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.base.view(mop.st)),
+  # remove NOOP views
+  (UPat.var("x").view(name="view"),
+   lambda x,view: x if x.st is not None and x.op not in GroupOp.Defines and view.st.contiguous and view.shape == x.shape else None),
+  (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}).view(name="view"),
+   lambda view: view.const_like(0) if (mask:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in mask) else None),
+  # only unmaksed VIEW on CONST replaces the ShapeTracker
+  (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"),
+   lambda x,view: x.replace(src=(x.src[0].replace(arg=x.st+view.st),)) if all(v.mask is None for v in (x.st+view.st).views) else None),
+])
+
+def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
+  # contiguous, expand, and the same with ones removed
+  if unwrap(view.st).contiguous and len(r.shape) < len(view.shape) and \
+      tuple(x for x in r.shape if resolve(x != 1)) == tuple(x for x in view.shape if resolve(x != 1)):
+    new_shape: list[sint] = []
+    new_reduce_axis = []
+    if (contraction:=get_contraction_with_reduce(view.shape, r.shape, r.arg[1])) is None: return None
+    for i,pairs in enumerate(contraction):
+      new_shape_chunk = [view.shape[p] for p in pairs]
+      if i in r.arg[1]:
+        # if this is a reduce axis, we need a 1 in the view here to put it
+        assert len(new_shape_chunk) > 0
+        new_shape += [1]*(len(pairs)-1) + [src.shape[i]]
+        new_reduce_axis.append(len(new_shape)-1)
+      else:
+        # otherwise, pass through the new_shape_chunk
+        new_shape += new_shape_chunk
+    ret = r.replace(src=(src.reshape(tuple(new_shape)),), arg=(r.arg[0], tuple(new_reduce_axis))+r.arg[2:])
+    assert ret.shape == view.shape, f"shape mismatch on reduce_push_add_ones, {ret.shape} != {view.shape}"
+    return ret
+  return None
+
+view_left = merge_views+PatternMatcher([
+  # view before elementwise and buffer ops
+  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
+   lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
+  # if there's ones added after reduce, put this before the reduce
+  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
+])
+
+view_left_through_load = PatternMatcher([
+  # view before load
+  (UPat(Ops.VIEW, src=(UPat(Ops.LOAD, name="e"),), name="view"),
+   lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
+])
+
+def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub View Left")
+
+# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
+def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False):
+  # contiguous and same size can push to children
+  # if there's a reduce child, shapes match with ones removed
+  if unwrap(view.st).contiguous and view.size == r.size and \
+      (not (len(r.arg) == 3 and r.arg[2]) or # arg[2] = True is fuse marker
+       tuple((i,x) for i,x in enumerate(r.shape) if resolve(x != 1)) == tuple((i,x) for i,x in enumerate(view.shape) if resolve(x != 1))):
+    return None
+  # swizzle the input
+  input_st = ShapeTracker.from_shape(src.shape)
+  tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
+  prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
+  strides = strides_for_shape(rshape)
+  nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
+                    v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
+  new_view = tmp + ShapeTracker(tuple(nv))
+  swizzled_input = apply_swizzle(src.view(new_view))
+  # create a new reduceop
+  new_axis = tuple(range(len(view.shape), len(view.shape) + len(r.axis_arg)))
+  if fuse: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input.fuse(),), (r.arg[0], new_axis, True))
+  else: red = UOp(Ops.REDUCE_AXIS, r.dtype, (swizzled_input,), (r.arg[0], new_axis))
+  return red.reshape(view.shape)
+
+def reduceop_view_right(src:UOp, v:UOp, r:UOp):
+  assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
+  new_axis = [i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u]
+  return src.r(r.arg[0], tuple(new_axis)).reshape(r.shape)
+
+def elementwise_view_right(root:UOp):
+  if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in ALWAYS_CONTIGUOUS]): return None
+  assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
+  # place view after applying the elementwise op
+  new_st = ShapeTracker.from_shape(swizzles[0].base.shape)
+  new_src = [x.base if x.base.shape==new_st.shape else apply_swizzle(x.view(new_st)) for x in root.src]
+  # reshape to match downstream shapes
+  return root.replace(src=tuple(new_src)).reshape(root.shape)
+
+# push VIEW to children
+view_right = merge_views+PatternMatcher([
+  # push a non contiguous ShapeTracker through reduceop
+  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
+  # apply view after reduceops
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
+  # apply view after elementwise ops
+  (UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
+  # merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
+   lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
+  # remove view from sink
+  (UPat(Ops.VIEW, name="v").sink(name="sink"), lambda v,sink: v.src[0].sink(arg=sink.arg)),
+])
+
+def check_load_st(glbl:UOp, view:UOp):
+  if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
+  # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
+  if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
+  # if it has a single view and it's equal when you shrink a contig, it's fine
+  if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
+  # otherwise, it's not fine
+  raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                     +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+
+fix_kernel_ops = view_left_through_load+PatternMatcher([
+  # add view to LOAD and STORE
+  (UPat(Ops.DEFINE_GLOBAL, name="g").load(), lambda g: g.view(g.st).load()),
+  (UPat(Ops.DEFINE_GLOBAL, name="g").store(UPat.var('x')), lambda g,x: g.view(g.st).store(x)),
+  # VALID
+  (UPat(Ops.VIEW, src=(UPat.cvar(),), name="self"),
+   lambda self: UOp.where(UOp(Ops.VALID, dtypes.bool, (UOp(Ops.VIEW, arg=self.st),)), self.const_like(self.base.arg), 0)),
+  # no ImageDType after index
+  (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL, Ops.VIEW}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
+  # if this kernel also assigns to the loaded buffer, ensure we can index it correctly
+  (UPat(Ops.LOAD, src=(UPat.var("glbl").view(name="view"),)), check_load_st),
+])
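Note (not part of the diff): a loose illustration of the ShapeTracker objects these VIEW rules move around. from_shape and permute are used by swizzle_reduceop above; reshape is assumed to exist on ShapeTracker as well:

```python
# hedged sketch: one ShapeTracker accumulates a chain of movement ops,
# which is what the merge/push rules above shuffle between UOps.
from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((4, 4)).permute((1, 0)).reshape((2, 8))
print(st.shape)   # (2, 8)
print(st.views)   # the underlying View(s) describing transpose-then-reshape
```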
--- /dev/null
+++ tinygrad/codegen/opt/tc.py
@@ -0,0 +1,127 @@
+import math, functools
+from dataclasses import dataclass
+from tinygrad.dtype import DType, dtypes
+from tinygrad.helpers import getenv
+
+@dataclass(frozen=True)
+class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
+  dims: tuple[int,int,int] # N, M, K
+  threads: int # number of threads that construct the warp
+  elements_per_thread: tuple[int, int, int] # elements per-thread to load/store from A/B/C
+  dtype_in: DType # dtype for A and B
+  dtype_out: DType # dtype for C and D
+  opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifying kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
+  # (local_swizzle, upcast_swizzle, reduce_swizzle)
+  # l<num> is the num axis of the locals, similar for u<num> and upcasts, r<num> and reduces
+  swizzle: tuple[tuple[tuple[str, ...], tuple[str, ...], tuple[str, ...]], tuple[tuple[str, ...], tuple[str, ...], tuple[str, ...]]]
+  @functools.cache # pylint: disable=method-cache-max-size-none
+  def _remaps(self) -> list[dict[str, str]]:
+    local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
+    fwd_st = [f"l{i}" for i in range(local_axes)] + [f"u{i}" for i in range(upcast_axes)] + [f"r{i}" for i in range(reduce_axes)]
+    return [dict(zip(fwd_st, sum(s, ()))) for s in self.swizzle]
+  def permutes_for_shape_str(self, shape_str:list[str]) -> tuple[tuple[int, ...], tuple[int, ...]]:
+    ret = [[shape_str.index(remap[ss]) if ss in remap else i for i,ss in enumerate(shape_str)] for remap in self._remaps()]
+    return tuple(ret[0]), tuple(ret[1])
+  def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
+  def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
+  def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]
+  def base_upcast_axes(self):
+    # this is defined in the swizzle. first we use the upcast axes, then the reduce
+    return ([f"r{i}" for i in range(len(self.get_reduce_axes()))] + [f"u{i}" for i in range(len(self.get_upcast_axes()))])[::-1]
+  def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
+  def __post_init__(self):
+    # all axes have size 2, <local> <reduce> <upcast> is the order
+    local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
+    assert self.dims[0] * self.dims[1] == 2**(local_axes + upcast_axes), \
+      f"N({self.dims[0]}) x M({self.dims[1]}) != local({2**local_axes}) x upcast({2**upcast_axes}) with opts({self.opts})"
+    assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
+    assert 2**upcast_axes == self.elements_per_thread[2], \
+      f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}"
+    # check dims match opts
+    assert self.dims[0] == 2**len(gd:=[x for x in self.opts if x[1] == '0']), f"opts wrong on dims[0], {self.dims[0]} vs {gd}"
+    assert self.dims[1] == 2**len(gd:=[x for x in self.opts if x[1] == '1']), f"opts wrong on dims[1], {self.dims[1]} vs {gd}"
+    # NOTE: the K opts is implictly set by the dim
+    # check swizzle
+    assert len(self.swizzle[0]) == 3 and len(self.swizzle[1]) == 3, "swizzle has wrong part count"
+    assert len(self.swizzle[0][0]) == len(self.swizzle[1][0]) == local_axes, "local swizzle size is wrong"
+    assert len(self.swizzle[0][1]) == len(self.swizzle[1][1]) == upcast_axes, "upcast swizzle size is wrong"
+    assert len(self.swizzle[0][2]) == len(self.swizzle[1][2]) == reduce_axes, "reduce swizzle size is wrong"
+    assert all(len(s) == local_axes+upcast_axes+reduce_axes for s in self._remaps()), "remaps are the wrong size"
+    # check elements_per_thread
+    un, ln = 0, 0
+    zero_stride_0 = []
+    zero_stride_1 = []
+    for o in self.opts:
+      if o[1] == '0': zero_stride_0.append(o[0] + str(un if o[0] == 'u' else ln))
+      if o[1] == '1': zero_stride_1.append(o[0] + str(un if o[0] == 'u' else ln))
+      if o[0] == 'u': un += 1
+      if o[0] == 'l': ln += 1
+    # NOTE: all the zero_stride dims can be placed in any order in the swizzle
+    upcasted_0 = [x for x in (self.swizzle[0][1] + self.swizzle[0][2]) if x not in zero_stride_0 and x[0] != 'l']
+    upcasted_1 = [x for x in (self.swizzle[1][1] + self.swizzle[1][2]) if x not in zero_stride_1 and x[0] != 'l']
+    assert 2**len(upcasted_0) == self.elements_per_thread[0], f"mismatch in elements_per_thread[0], {upcasted_0} vs {self.elements_per_thread[0]}"
+    assert 2**len(upcasted_1) == self.elements_per_thread[1], f"mismatch in elements_per_thread[1], {upcasted_1} vs {self.elements_per_thread[1]}"
+
+# ***** NVIDIA *****
+
+cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1") # shared by all shapes with M=16 N=8
+
+# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
+cuda_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+  swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('u1', 'r3'), ('l0', 'l1', 'u0', 'r0')),
+           (('r1', 'r2', 'u0', 'l0', 'l1'), ('r0', 'r3'), ('l2', 'l3', 'l4', 'u1'))))
+  for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float), (dtypes.half,dtypes.half)]]
+cuda_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+  swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('r0', 'u1'), ('l0', 'l1', 'u0')),
+           (('r1', 'r2', 'u0', 'l0', 'l1'), ('u1', 'r0'), ('l2', 'l3', 'l4'))))
+  for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
+cuda_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
+  swizzle=((('r0', 'r1', 'l2', 'l3', 'l4'), ('u1', 'r2'), ('l0', 'l1', 'u0')),
+           (('r0', 'r1', 'u0', 'l0', 'l1'), ('u1', 'r2'), ('l2', 'l3', 'l4'))))]
+cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16
+if getenv("ALLOW_TF32", 0): cuda_sm80 += cuda_8168_tf32
+cuda_sm75: list[TensorCore] = cuda_8168_f16
+
+# ***** AMD *****
+
+# https://gpuopen.com/learn/wmma_on_rdna3/
+amd_rdna3 = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
+  opts=("l0","l0","l0","l0","l1","u1","u1","u1"),
+  swizzle=((('l4', 'u0', 'u1', 'u2', 'l0'), ('r1', 'r2', 'r3'), ('l1', 'l2', 'l3', 'r0')),
+           (('l0', 'l1', 'l2', 'l3', 'l4'), ('r1', 'r2', 'r3'), ('u0', 'u1', 'u2', 'r0'))))
+  for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float)]]
+amd_rdna4 = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(8,8,8), dtype_in=di, dtype_out=do,
+  opts=("l0","l0","l0","l0","u1","u1","u1","l1"),
+  swizzle=((('u0', 'u1', 'u2', 'l4', 'r2'), ('r0', 'r1', 'r3'), ('l0', 'l1', 'l2', 'l3')),
+           (('l0', 'l1', 'l2', 'l3', 'r2'), ('r0', 'r1', 'r3'), ('l4', 'u0', 'u1', 'u2'))))
+  for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
+
+# https://gpuopen.com/learn/amd-lab-notes/amd-lab-notes-matrix-cores-readme
+amd_cdna = [TensorCore(dims=(16,16,16), threads=64, elements_per_thread=(4,4,4), dtype_in=di, dtype_out=do,
+  opts=("l0","l0","l0","l0","u1","u1","l1","l1"),
+  swizzle=((('u0', 'u1', 'l4', 'l5', 'r2', 'r3'), ('r0', 'r1'), ('l0', 'l1', 'l2', 'l3')),
+           (('l0', 'l1', 'l2', 'l3', 'r2', 'r3'), ('r0', 'r1'), ('l4', 'l5', 'u0', 'u1'))))
+  for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
+
+# ***** Apple Metal *****
+
+metal = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do,
+  opts=("u0","l0","l1","l1","l0","l1"),
+  swizzle=((('r1', 'l1', 'l2', 'r2', 'l4'), ('r0',), ('u0', 'l0', 'l3')),
+           (('l0', 'r0', 'r1', 'l3', 'r2'), ('u0',), ('l1', 'l2', 'l4'))))
+  for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
+                (dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
+
+# ***** Apple AMX *****
+
+amx = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt,
+  swizzle=(((), ('u0', 'u1', 'u2', 'u3', 'u4', 'u5', 'u6', 'u7'), ()),
+           ((), ('u4', 'u5', 'u6', 'u7', 'u0', 'u1', 'u2', 'u3'), ())),
+  opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
+
+# ***** Intel ****
+
+intel = [TensorCore(dims=(8,8,16), threads=8, elements_per_thread=(16,16,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
+  opts=("l0","l0","l0","u1","u1","u1"),
+  swizzle=((('r1', 'r2', 'r3'), ('u0', 'u1', 'u2'), ('l0', 'l1', 'l2', 'r0')),
+           (('l0', 'l1', 'l2'), ('r1', 'r2', 'r3'), ('u0', 'u1', 'u2', 'r0'))))]
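Note (not part of the diff): the tables above are plain frozen dataclasses and can be inspected directly. A small sketch, with the module path taken from the file list:

```python
# hedged sketch: print the Metal tensor-core descriptors defined above.
from tinygrad.codegen.opt.tc import metal

for tc in metal:
  # __str__ gives the WMMA name, e.g. WMMA_8_8_8_half_float
  print(tc, tc.dims, tc.dtype_in, tc.dtype_out, tc.elements_per_thread)
```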
--- /dev/null
+++ tinygrad/codegen/quantize.py
@@ -0,0 +1,67 @@
+from tinygrad.dtype import dtypes, least_upper_dtype
+from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
+from tinygrad.uop.symbolic import symbolic
+
+# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
+# this is badly tested and low quality. remove it?
+
+FP = (1 << 15)
+pm_quant = symbolic+PatternMatcher([
+  # cast after add/mul
+  (UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
+   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
+  (UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
+   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
+
+  # masked MUL after masked ADD
+  ((UPat.var("x") + UPat.var("v").where(UPat.var('cadd'), UPat(Ops.CONST, arg=0))) * UPat.var("v").where(UPat.var('cmul'), UPat(Ops.CONST, arg=0)),
+   lambda x,v,cadd,cmul: x*v.where(cmul, 0)+v.where(cadd*cmul, 0)),
+
+  # MUL after reduce
+  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c.arg),
+  # CAST after reduce (doesn't work if it's a size change)
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
+   lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
+
+  # x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
+  (UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
+   lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
+  # mul 0 * c1 is 0
+  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
+   UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
+  # mul (with plus) 0 * c1 is 0
+  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
+   (UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
+    UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
+   lambda ld,v,c1: ld*c1),
+
+  # const push through add
+  ((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),
+
+  # fixed point mult, replace (x.float()*c1+c2).int() with an int expression
+  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("cc")).cast(dtypes.int),
+   lambda x,c1,cc: ((x*(c1*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
+  # fixed point mult, replace (x.float()*c1 + y.float()*c2)*cc.int() with an int expression
+  ((UPat.var("x").cast(dtypes.float)*UPat.var("c1")+UPat.var("y").cast(dtypes.float)*UPat.var("c2")+UPat.var("cc")).cast(dtypes.int),
+   lambda x,c1,y,c2,cc: ((x*(c1*FP).cast(x.dtype) + y.cast(x.dtype)*(c2*FP).cast(x.dtype) + (cc*FP).cast(x.dtype)) // FP).cast(dtypes.int)),
+
+  # where move
+  (UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
+   (yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
+  ((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
+  (UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
+   (x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
+  ((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
+    UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
+   x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
+
+  # where on two adds
+  (UPat.var("x") + UPat.var("v").where(UPat.var("a0"), UPat.var("a1")) + UPat.var("v").where(UPat.var("b0"), UPat.var("b1")),
+   lambda x,v,a0,a1,b0,b1: x + v.where(a0+b0, a1+b1)),
+
+  # split REDUCE into multiple reduces (who remembers FOIL?)
+  (UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * UPat(Ops.CAST, name="v2"),), name="r"),
+   lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
+  (UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")),), name="r"),
+   lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
+])
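Note (not part of the diff): the fixed-point rules above rely on a simple identity: with FP = 1 << 15, int(x*c1 + c2) equals (x*int(c1*FP) + int(c2*FP)) // FP up to rounding, which lets the float multiply be replaced by integer-only arithmetic. A standalone arithmetic check (no tinygrad APIs):

```python
# standalone check of the fixed-point trick used by pm_quant
FP = 1 << 15
x, c1, c2 = 37, 0.02173, 0.5                        # x: int value, c1/c2: float constants
float_path = int(x * c1 + c2)                       # 1, via float math
int_path = (x * int(c1 * FP) + int(c2 * FP)) // FP  # 1, integer-only
print(float_path, int_path)
```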