tinygrad-0.10.1-py3-none-any.whl → tinygrad-0.10.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +35 -37
- tinygrad/codegen/linearize.py +19 -10
- tinygrad/codegen/lowerer.py +31 -8
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +10 -0
- tinygrad/device.py +28 -11
- tinygrad/dtype.py +12 -3
- tinygrad/engine/jit.py +3 -2
- tinygrad/engine/multi.py +0 -1
- tinygrad/engine/realize.py +7 -4
- tinygrad/engine/schedule.py +227 -255
- tinygrad/engine/search.py +20 -27
- tinygrad/gradient.py +3 -0
- tinygrad/helpers.py +7 -4
- tinygrad/nn/state.py +2 -2
- tinygrad/ops.py +64 -329
- tinygrad/renderer/__init__.py +19 -3
- tinygrad/renderer/cstyle.py +39 -18
- tinygrad/renderer/llvmir.py +55 -18
- tinygrad/renderer/ptx.py +6 -2
- tinygrad/renderer/wgsl.py +20 -12
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/metal.py +28 -29
- tinygrad/runtime/ops_amd.py +37 -34
- tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
- tinygrad/runtime/ops_disk.py +1 -1
- tinygrad/runtime/ops_dsp.py +59 -33
- tinygrad/runtime/ops_llvm.py +14 -12
- tinygrad/runtime/ops_metal.py +78 -62
- tinygrad/runtime/ops_nv.py +9 -6
- tinygrad/runtime/ops_python.py +5 -5
- tinygrad/runtime/ops_webgpu.py +200 -38
- tinygrad/runtime/support/am/amdev.py +23 -11
- tinygrad/runtime/support/am/ip.py +10 -10
- tinygrad/runtime/support/elf.py +2 -0
- tinygrad/runtime/support/hcq.py +7 -5
- tinygrad/runtime/support/llvm.py +8 -14
- tinygrad/shape/shapetracker.py +3 -2
- tinygrad/shape/view.py +2 -3
- tinygrad/spec.py +21 -20
- tinygrad/tensor.py +150 -90
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- tinygrad/codegen/rewriter.py +0 -516
- tinygrad-0.10.1.dist-info/RECORD +0 -86
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
- {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/search.py
CHANGED
@@ -1,11 +1,11 @@
 from typing import cast, Optional, Callable
-import itertools, functools, random, math, time, multiprocessing, traceback, signal
+import itertools, functools, random, math, time, multiprocessing, traceback, signal, atexit
 from collections import defaultdict
 from dataclasses import replace
 from tinygrad.ops import UOp, Ops, Variable, sym_infer
 from tinygrad.device import Device, Buffer, Compiler
-from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored,
-from tinygrad.helpers import IGNORE_BEAM_CACHE
+from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
+from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE
 from tinygrad.dtype import ImageDType, PtrDType
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
 from tinygrad.tensor import Tensor
@@ -103,9 +103,17 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
 # get dictionary of all possible actions
 def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
   acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
-  for i,a in enumerate(actions):
+  kernel_actions = actions.copy()
+
+  if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
+    for i, action in enumerate(kernel_actions):
+      if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
+        # replace every tc_action with default tc with one tc_action for each available tc
+        kernel_actions[i:i+1] = [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]
+
+  for i,a in enumerate(kernel_actions):
     if a.axis is not None and a.op is not OptOps.TC:
-      if ((ax:=
+      if ((ax:=lin.real_axis(a)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
     lin2 = lin.copy()
     try:
       lin2.apply_opt(a)
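The core of this hunk is an in-place list splice: when the shape search is enabled and no opts have been applied yet, the single default tensor-core action (tc_select of -1) is replaced by one concrete TC action per tensor core the backend offers, so BEAM can time each of them. A minimal standalone sketch of that splice pattern, with plain tuples standing in for Opt objects and made-up values (not tinygrad code):

# Sketch of the kernel_actions[i:i+1] = [...] splice above.
# Tuples stand in for Opt(op, axis, (tc_select, tc_opt)); all values are made up.
actions = [("UPCAST", 0, 4), ("TC", 0, (-1, 1)), ("LOCAL", 1, 16)]
tensor_cores = ["tc_a", "tc_b", "tc_c"]          # hypothetical list of available tensor cores

kernel_actions = actions.copy()                  # don't mutate the shared action list
for i, action in enumerate(kernel_actions):
  if action[0] == "TC" and action[2][0] == -1:   # tc_select of -1 means "default, search all"
    # splice: one concrete TC action per tensor core replaces the single default one
    kernel_actions[i:i+1] = [("TC", action[1], (tc_select, action[2][1])) for tc_select, _ in enumerate(tensor_cores)]

print(kernel_actions)
# [('UPCAST', 0, 4), ('TC', 0, (0, 1)), ('TC', 0, (1, 1)), ('TC', 0, (2, 1)), ('LOCAL', 1, 16)]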
@@ -133,10 +141,12 @@ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True,
   default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL"} else 0
   if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
     beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
+    @atexit.register
+    def close_pool(): beam_pool.close()

   min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
   if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")
-  if DEBUG >= 2: print(f"  0.00s:
+  if DEBUG >= 2: print(f"  0.00s: from 1 -> 1 actions {lin.colored_shape()}")

   try:
     rawbufs = _ensure_buffer_alloc(rawbufs)
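Registering the nested close_pool with atexit ties the worker pool's lifetime to the interpreter: the pool is closed on normal exit even though beam_search itself never tears it down. A self-contained sketch of the same pattern, independent of tinygrad:

import atexit, multiprocessing

def _square(x): return x * x

def get_pool():
  pool = multiprocessing.get_context("spawn").Pool(2)
  @atexit.register
  def _close_pool():        # runs once, at interpreter shutdown
    pool.close()
    pool.join()
  return pool

if __name__ == "__main__":
  p = get_pool()
  print(p.map(_square, range(5)))   # [0, 1, 4, 9, 16]
  # no explicit cleanup needed: the atexit hook closes the pool on exit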
@@ -159,21 +169,21 @@ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True,
         try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
         except RuntimeError: continue # for runtime issues
         timed_lins.append((acted_lins[i], min(tms)))
-        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {compile_et
-        elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]
+        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
+        elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501

       # done
       opts = sorted(timed_lins, key=lambda x: x[1])
       exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
       if not exiting: beam = opts[:amt]
       elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
-      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(
+      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
   except KeyboardInterrupt as e:
     if beam_pool is not None: beam_pool.terminate()
     raise e

   if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
-  if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={beam[0][1]
+  if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={time_to_str(beam[0][1], w=0)}, applied_opts={beam[0][0].applied_opts}")
   return beam[0][0]

 def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
@@ -187,20 +197,3 @@ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffe
   ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
   assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
   return ret[1]
-
-def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
-  key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
-         "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
-  if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
-
-  dev = Device[lin.opts.device]
-  assert dev.compiler is not None
-
-  rawbufs = _ensure_buffer_alloc(rawbufs)
-  var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
-  p = lin.to_program()
-  tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
-                      max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
-
-  if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
-  return min(tms)
tinygrad/gradient.py
CHANGED
@@ -22,6 +22,9 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
   (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
   (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
+  (UPat(Ops.POW, name="ret"), lambda ctx, ret:
+    (ret.src[0].eq(0).where(ret.src[1].eq(0).where(ret.src[1], ret.src[1]*math.inf), ctx*ret*ret.src[1]/ret.src[0]),
+     ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ctx*ret*ret.src[0].log2()*math.log(2.0)))),
   (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
                                                 (ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
   (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
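Away from x == 0, the two branches are the textbook partial derivatives of z = x**y scaled by the incoming gradient ctx: dz/dx = y*x**(y-1) = z*y/x and dz/dy = z*ln(x) = z*log2(x)*ln(2), with the x == 0 cases pinned to 0 or ±inf constants. A quick numerical sanity check of those closed forms in plain Python (not the UOp pattern itself):

import math

def pow_grads(x, y):
  z = x ** y
  return z * y / x, z * math.log2(x) * math.log(2.0)   # dz/dx, dz/dy for x > 0

x, y, eps = 3.0, 2.5, 1e-6
dzdx, dzdy = pow_grads(x, y)
num_dzdx = ((x + eps) ** y - (x - eps) ** y) / (2 * eps)   # central differences
num_dzdy = (x ** (y + eps) - x ** (y - eps)) / (2 * eps)
assert abs(dzdx - num_dzdx) < 1e-4 and abs(dzdy - num_dzdy) < 1e-4
print(dzdx, dzdy)   # ~12.99 (= y*x**(y-1)) and ~17.13 (= z*ln(x))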
tinygrad/helpers.py
CHANGED
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
+import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass
 import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
 from dataclasses import dataclass
 from typing import Union, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic
@@ -28,6 +28,7 @@ def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstan
 def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
 def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
 def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
+def time_to_str(t:float, w=8) -> str: return next((f"{t * d:{w}.2f}{pr}" for d,pr in [(1, "s "),(1e3, "ms")] if t > 10/d), f"{t * 1e6:{w}.2f}us")
 def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
 def ansilen(s:str): return len(ansistrip(s))
 def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
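The new time_to_str helper picks a unit that keeps the value readable: seconds above 10 s, milliseconds above 10 ms, microseconds otherwise, right-aligned to width w. Copying the one-liner from the hunk and exercising it:

def time_to_str(t:float, w=8) -> str: return next((f"{t * d:{w}.2f}{pr}" for d,pr in [(1, "s "),(1e3, "ms")] if t > 10/d), f"{t * 1e6:{w}.2f}us")

print(time_to_str(42.0))          # '   42.00s '  -> seconds, since 42 > 10
print(time_to_str(0.0371))        # '   37.10ms'  -> milliseconds, since 0.0371 > 0.01
print(time_to_str(0.0000052))     # '    5.20us'  -> everything smaller is microseconds
print(time_to_str(0.0371, w=12))  # wider field, as used for the aligned BEAM/DEBUG columns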
@@ -79,7 +80,7 @@ def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+str
 @functools.lru_cache(maxsize=None)
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
 def temp(x:str, append_user:bool=False) -> str:
-  return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{
+  return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()

 class Context(contextlib.ContextDecorator):
   def __init__(self, **kwargs): self.kwargs = kwargs
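With getpass imported above, temp() can now suffix the file name with the current username, so different users on one machine stop colliding on shared paths under the system temp directory. A small usage sketch (the file name is illustrative; output depends on platform and user):

import getpass, pathlib, tempfile

def temp(x:str, append_user:bool=False) -> str:
  return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()

print(temp("tinygrad_cache"))                    # e.g. /tmp/tinygrad_cache
print(temp("tinygrad_cache", append_user=True))  # e.g. /tmp/tinygrad_cache.alice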
@@ -106,11 +107,12 @@ DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), Cont
 JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1)
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
 USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
-TRANSCENDENTAL = ContextVar("TRANSCENDENTAL", 1)
+TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE = ContextVar("TRANSCENDENTAL", 1), ContextVar("TC_SEARCH_OVER_SHAPE", 1)
 FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
 PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
-CACHELEVEL, IGNORE_BEAM_CACHE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0)
+CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
+DONT_REALIZE_EXPAND = ContextVar("DONT_REALIZE_EXPAND", 0)

 @dataclass(frozen=True)
 class Metadata:
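The new flags follow the usual tinygrad ContextVar convention: each default is taken from the environment variable of the same name (so launching with TC_SEARCH_OVER_SHAPE=0 disables the per-tensor-core action expansion added in search.py), and they can also be overridden for a block of code with the Context helper. A usage sketch, with the matmul as a stand-in workload:

from tinygrad import Tensor
from tinygrad.helpers import Context

# equivalent to setting TC_SEARCH_OVER_SHAPE=0 in the environment, but scoped to this block
with Context(TC_SEARCH_OVER_SHAPE=0):
  # if BEAM search runs here, only the default tensor-core action is considered
  out = (Tensor.rand(64, 64) @ Tensor.rand(64, 64)).realize()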
@@ -276,6 +278,7 @@ def capstone_flatdump(lib: bytes):
     case 'x86_64' | 'AMD64': cs = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
     case 'aarch64' | 'arm64': cs = capstone.Cs(capstone.CS_ARCH_ARM64, capstone.CS_MODE_ARM)
     case machine: raise NotImplementedError(f"Capstone disassembly isn't supported for {machine}")
+  cs.skipdata = True
   for instr in cs.disasm(lib, 0):
     print(f"{instr.address:#08x}: {instr.mnemonic}\t{instr.op_str}")
   sys.stdout.flush()
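Setting cs.skipdata = True makes Capstone emit undecodable bytes as .byte pseudo-instructions and keep going, instead of silently stopping the listing at the first byte it cannot decode, which helps when a compiled blob mixes data islands into the code. A small standalone illustration (the byte string is arbitrary x86-64):

import capstone

# push rbp; mov rbp, rsp; two bytes that are invalid in 64-bit mode; ret
code = b"\x55\x48\x89\xe5" + b"\x06\x06" + b"\xc3"
cs = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
cs.skipdata = True  # without this, disasm stops at the invalid 0x06 bytes
for instr in cs.disasm(code, 0):
  print(f"{instr.address:#08x}: {instr.mnemonic}\t{instr.op_str}")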
tinygrad/nn/state.py
CHANGED
@@ -195,8 +195,8 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
       if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
         intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
         assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
-        if DEBUG >= 3: print(f"WARNING: this torch load is slow.
-        assert storage[1] != dtypes.bfloat16, "can't
+        if DEBUG >= 3: print(f"WARNING: this torch load is slow. to permute {intermediate_shape} with {permute_indexes}")
+        assert storage[1] != dtypes.bfloat16, "can't permute BF16"
         # TODO: find a nice way to support all shapetracker on disktensors
         ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)
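For context on the lines above: the reshape(intermediate_shape).permute(permute_indexes) pair recovers a tensor whose bytes were saved in a non-default memory order, by first viewing the flat buffer with its axes in on-disk stride order and then permuting back to the logical order. A small numpy illustration of the idea (not the torch_load code itself):

import numpy as np

# a 2x3 tensor whose storage was written column-major (strides permuted vs. the shape)
logical = np.arange(6, dtype=np.float32).reshape(2, 3)
raw = logical.tobytes(order="F")   # column-major bytes, as a transposed save would produce

# recover it: view the flat buffer with axes in storage order, then permute back
flat = np.frombuffer(raw, dtype=np.float32)
intermediate_shape, permute_indexes = (3, 2), (1, 0)
recovered = flat.reshape(intermediate_shape).transpose(permute_indexes)
assert (recovered == logical).all()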