tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
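
Before digging into the per-file diffs below, a quick way to confirm which of the two wheels is actually installed in an environment (an illustrative sketch using only the standard library):

    # illustrative: report the installed tinygrad distribution version
    import importlib.metadata
    print(importlib.metadata.version("tinygrad"))   # expect "0.10.0" or "0.10.2"
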
tinygrad/engine/search.py
CHANGED
@@ -1,26 +1,27 @@
-from typing import
-import itertools, functools, random, math, time, multiprocessing, traceback, signal
+from typing import cast, Optional, Callable
+import itertools, functools, random, math, time, multiprocessing, traceback, signal, atexit
 from collections import defaultdict
 from dataclasses import replace
 from tinygrad.ops import UOp, Ops, Variable, sym_infer
 from tinygrad.device import Device, Buffer, Compiler
-from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored,
+from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
+from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE
 from tinygrad.dtype import ImageDType, PtrDType
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
 from tinygrad.tensor import Tensor
 from tinygrad.engine.realize import CompiledRunner
-from tinygrad.renderer import
-
-actions = [Opt(op=OptOps.UPCAST, axis=axis,
-actions += [Opt(op=OptOps.UNROLL, axis=axis,
-actions += [Opt(op=OptOps.LOCAL, axis=axis,
-actions += [Opt(op=OptOps.GROUPTOP, axis=axis,
-actions += [Opt(op=OptOps.GROUP, axis=axis,
-if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis,
-actions += [Opt(op=OptOps.LOCAL, axis=0,
-actions += [Opt(op=OptOps.
-actions += [Opt(op=OptOps.TC, axis=axis,
-actions += [Opt(op=OptOps.SWAP, axis=
+from tinygrad.renderer import ProgramSpec
+
+actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
+actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
+actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
+actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
+actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)]
+if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)]
+actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)]
+actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0))]
+actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2))) for axis in range(9)]  # covers resnet kernels (3 global * 3 reduce)
+actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
 
 def _get_test_global_size(global_size, max_global_size, var_vals):
@@ -33,8 +34,8 @@ def _get_test_global_size(global_size, max_global_size, var_vals):
         break
   return test_global_size, factor
 
-def _time_program(p:
-                  max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") ->
+def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:Optional[float]=None,
+                  max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
   factor = 1
   if p.global_size is not None and max_global_size is not None:
     global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
@@ -45,9 +46,9 @@ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:Li
   input_bufs = [rawbufs[i] for i in car.p.globals]
   for _ in range(cnt):
     if clear_l2:
-      if hasattr(dev:=Device[p.
+      if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
       else:
-        with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
+        with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
     tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
     if early_stop is not None and early_stop < min(tms): break
   return tms
@@ -55,7 +56,7 @@ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:Li
 class TimeoutException(Exception): pass
 def timeout_handler(signum, frame): raise TimeoutException()
 
-def _try_compile_linearized_w_idx(x:
+def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, Optional[tuple[ProgramSpec, bytes, float]]]:
   if hasattr(signal, "alarm"):
     signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
     # set timeout
@@ -80,16 +81,16 @@ def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tup
 # workers should ignore ctrl c
 def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
 
-def _ensure_buffer_alloc(bufs:
+def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() for buf in bufs]
 
 # *** external API ***
 
 # get (scrap) buffers for timing the linearizer
-def bufs_from_lin(lin:Kernel, allocate:bool=True) ->
-  bufsts:
+def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
+  bufsts: defaultdict[int, list[UOp]] = defaultdict(list)
   for x in lin.bufs:
     if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
-  rawbufs:
+  rawbufs: list[Optional[Buffer]] = [None]*len(bufsts)
   for k,lx in bufsts.items():
     buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
     assert isinstance(dtype, (PtrDType, ImageDType))
@@ -97,18 +98,26 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
     buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
     rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype)
   assert all(r is not None for r in rawbufs)
-  return cast(
+  return cast(list[Buffer], rawbufs)
 
 # get dictionary of all possible actions
-def get_kernel_actions(lin:Kernel, include_0=True) ->
+def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
   acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
-
+  kernel_actions = actions.copy()
+
+  if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0:  # tensor core opts must be first
+    for i, action in enumerate(kernel_actions):
+      if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
+        # replace every tc_action with default tc with one tc_action for each available tc
+        kernel_actions[i:i+1] = [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]
+
+  for i,a in enumerate(kernel_actions):
     if a.axis is not None and a.op is not OptOps.TC:
-      if ((ax:=
+      if ((ax:=lin.real_axis(a)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
     lin2 = lin.copy()
     try:
       lin2.apply_opt(a)
-      up, lcl, tc_up = 1, 1, prod(tc.dims)//
+      up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1
       for s,c in zip(lin2.full_shape, lin2.colors()):
         if c in {"magenta", "yellow"}: up *= s
         elif c in {"cyan", "green", "white"}: lcl *= s
@@ -117,8 +126,8 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
     except KernelOptError: pass
   return acted_lins
 
-beam_pool, BEAM_DEBUG
-def beam_search(lin:Kernel, rawbufs:
+beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
+def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value) -> Kernel:
   global beam_pool
   key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
   if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
@@ -126,25 +135,27 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
     for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
     return ret
 
-  beam:
+  beam: list[tuple[Kernel, float]] = [(lin, float("inf"))]
   seen_libs = set()
 
   default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL"} else 0
   if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
     beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
+    @atexit.register
+    def close_pool(): beam_pool.close()
 
   min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
   if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")
-  if DEBUG >= 2: print(f" 0.00s:
+  if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")
 
   try:
     rawbufs = _ensure_buffer_alloc(rawbufs)
-    var_vals:
+    var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
     exiting, st = False, time.perf_counter()
     dev = Device[lin.opts.device]
     while not exiting:
-      acted_lins:
-      timed_lins:
+      acted_lins: list[Kernel] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
+      timed_lins: list[tuple[Kernel, float]] = []
       _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
       least_compute_ops = math.inf
      for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
@@ -152,59 +163,37 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
         p, lib, compile_et = proc
         if lib in seen_libs: continue
         # filter out kernels that use 1000x more compute than the smallest
-        least_compute_ops = min(this_compute_ops:=sym_infer(p.
+        least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
         if least_compute_ops*1000 < this_compute_ops: continue
-        if len(CAPTURE_BEAM) > 0:
-          with open(CAPTURE_BEAM, 'a') as f: f.write(str(acted_lins[i].ast).replace('\n','')+f" :: {acted_lins[i].applied_opts}\n")
         seen_libs.add(lib)
         try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
-        except RuntimeError
-          if len(CAPTURE_BEAM) > 0:
-            with open(CAPTURE_BEAM, 'a') as f: f.write("# Upper ast finished with an error:" + str(e).replace('\n',' ')+ "\n")
-          continue # for runtime issues
+        except RuntimeError: continue # for runtime issues
         timed_lins.append((acted_lins[i], min(tms)))
-        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(
-        elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]
+        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
+        elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
 
       # done
       opts = sorted(timed_lins, key=lambda x: x[1])
       exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
       if not exiting: beam = opts[:amt]
       elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
-      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(
+      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
   except KeyboardInterrupt as e:
     if beam_pool is not None: beam_pool.terminate()
     raise e
 
   if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
-  if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={beam[0][1]
+  if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={time_to_str(beam[0][1], w=0)}, applied_opts={beam[0][0].applied_opts}")
   return beam[0][0]
 
-def optimize_local_size(
+def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
   test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
   MAX_WORKGROUP = 1024
   local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
   local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
   def try_exec(local_size):
-    try: return
+    try: return _prg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501
    except Exception: return float('inf')
   ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
   assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
   return ret[1]
-
-def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
-  key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
-         "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
-  if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
-
-  dev = Device[lin.opts.device]
-  assert dev.compiler is not None
-
-  rawbufs = _ensure_buffer_alloc(rawbufs)
-  var_vals: Dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
-  p = lin.to_program()
-  tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
-                      max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
-
-  if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
-  return min(tms)
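
For orientation, the beam search above is normally driven through the `BEAM` ContextVar rather than by calling `beam_search` directly. A minimal sketch (matrix sizes and device are illustrative; the search result is cached on disk via `diskcache_put("beam_search", ...)` as shown in the diff):

    # illustrative sketch: BEAM>0 routes kernel compilation through beam_search()/get_kernel_actions()
    from tinygrad import Tensor
    from tinygrad.helpers import Context

    a, b = Tensor.rand(1024, 1024), Tensor.rand(1024, 1024)
    with Context(BEAM=2):   # beam width 2
      (a @ b).realize()     # the first realize pays the search cost; later runs reuse the cache
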
tinygrad/gradient.py
ADDED
@@ -0,0 +1,73 @@
+from typing import cast, Iterator
+import math, functools, dataclasses
+from tinygrad.dtype import dtypes, sum_acc_dtype
+from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
+from tinygrad.helpers import argsort
+
+def reduce_gradient(ctx:UOp, ret:UOp):
+  if ret.arg[0] == Ops.ADD: return (ctx.expand(ret.src[0].shape),)
+  if ret.arg[0] == Ops.MAX:
+    max_is_1s = ret.src[0].ne(ret.expand(ret.src[0].shape)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype)
+    div = max_is_1s.r(Ops.ADD, ret.arg[1]).expand(ret.src[0].shape)
+    return ((max_is_1s/div) * ctx.expand(ret.src[0].shape),)
+  if ret.arg[0] == Ops.MUL: return ((ctx * ret).expand(ret.src[0].shape) / ret.src[0],)
+
+# ctx is grad_output
+pm_gradient = PatternMatcher([
+  (UPat(Ops.CAST, name="ret"), lambda ctx, ret: (ctx.cast(ret.src[0].dtype),)),
+  (UPat(Ops.RECIP, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)),
+  (UPat(Ops.SIN, name="ret"), lambda ctx, ret: ((math.pi/2 - ret.src[0]).sin() * ctx,)),
+  (UPat(Ops.LOG2, name="ret"), lambda ctx, ret: (ctx / (ret.src[0] * math.log(2)),)),
+  (UPat(Ops.EXP2, name="ret"), lambda ctx, ret: (ret * ctx * math.log(2),)),
+  (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
+  (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
+  (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
+  (UPat(Ops.POW, name="ret"), lambda ctx, ret:
+    (ret.src[0].eq(0).where(ret.src[1].eq(0).where(ret.src[1], ret.src[1]*math.inf), ctx*ret*ret.src[1]/ret.src[0]),
+     ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ctx*ret*ret.src[0].log2()*math.log(2.0)))),
+  (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
+    (ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
+  (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
+  (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
+  (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
+  (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
+  (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
+  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
+  (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)),
+  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
+  (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
+  (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.arg),)),
+  # TODO: this cast can be removed by putting the casts around the EXPAND
+  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
+    (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
+  (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
+  # there's no gradient for bitcast
+  (UPat(Ops.BITCAST), lambda ctx: (None,)),
+])
+
+# copied from tensor.py, get relevant toposort of gradients
+def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
+  @functools.lru_cache(None)
+  def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
+  def _walk(node:UOp, visited:set[UOp]) -> Iterator[UOp]:
+    visited.add(node)
+    if node.op is Ops.DETACH: return
+    if is_in_target_path(node):
+      for i in node.src:
+        if i not in visited: yield from _walk(i, visited)
+      yield node
+  return list(_walk(root, set()))
+
+def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
+  grads = {root: root_grad}
+  for t0 in reversed(_deepwalk(root, targets)):
+    if t0 not in grads: continue
+    lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
+    if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
+    assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
+    for k,v in zip(t0.src, lgrads):
+      if v is None: continue
+      if k in grads: grads[k] = grads[k] + v
+      else: grads[k] = v
+      if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True)
+  return grads
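
This new module, together with the removal of tinygrad/function.py listed above, appears to move backpropagation onto PatternMatcher rewrite rules over UOps. A minimal sketch of the behaviour it backs, assuming the standard Tensor API (compute_gradient is called internally by Tensor backward/gradient, not by user code):

    # illustrative sketch: exercising pm_gradient/compute_gradient indirectly through Tensor.backward
    from tinygrad import Tensor

    x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
    loss = (x * x).sum()    # forward graph of UOps (MUL, REDUCE_AXIS, ...)
    loss.backward()         # walks the graph and applies the gradient rewrite rules
    print(x.grad.numpy())   # d(sum(x*x))/dx = 2*x -> [2. 4. 6.]
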
tinygrad/helpers.py
CHANGED
@@ -1,8 +1,8 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
+import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass
 import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
 from dataclasses import dataclass
-from typing import
+from typing import Union, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic
 
 T = TypeVar("T")
 U = TypeVar("U")
@@ -23,76 +23,79 @@ def argfix(*x):
     return tuple(x[0])
   return x
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
-def all_same(items:Union[
-def all_int(t: Sequence[Any]) -> TypeGuard[
+def all_same(items:Union[tuple[T, ...], list[T]]): return all(x == items[0] for x in items)
+def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
 def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
 def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
 def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
+def time_to_str(t:float, w=8) -> str: return next((f"{t * d:{w}.2f}{pr}" for d,pr in [(1, "s "),(1e3, "ms")] if t > 10/d), f"{t * 1e6:{w}.2f}us")
 def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
 def ansilen(s:str): return len(ansistrip(s))
-def make_tuple(x:Union[int, Sequence[int]], cnt:int) ->
+def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
 def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
 def fully_flatten(l):
   if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
+    if hasattr(l, "shape") and l.shape == (): return [l[()]]
     flattened = []
-
-    else:
-      for i in range(len(l)): flattened.extend(fully_flatten(l[i]))
+    for li in l: flattened.extend(fully_flatten(li))
     return flattened
   return [l]
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
 def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
-def ceildiv(num, amt):
-  ret = -(num//-amt)
-  return ret if not isinstance(ret, float) else int(ret)
+def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
 def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
-def
-def
-def
+def lo32(x:Any) -> Any: return x & 0xFFFFFFFF # Any is sint
+def hi32(x:Any) -> Any: return x >> 32 # Any is sint
+def data64(data:Any) -> tuple[Any, Any]: return (data >> 32, data & 0xFFFFFFFF) # Any is sint
+def data64_le(data:Any) -> tuple[Any, Any]: return (data & 0xFFFFFFFF, data >> 32) # Any is sint
+def getbits(value: int, start: int, end: int): return (value >> start) & ((1 << end-start+1) - 1)
+def i2u(bits: int, value: int): return value if value >= 0 else (1<<bits)+value
+def merge_dicts(ds:Iterable[dict[T,U]]) -> dict[T,U]:
   kvs = set([(k,v) for d in ds for k,v in d.items()])
   assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
   return {k:v for d in ds for k,v in d.items()}
-def partition(itr:Iterable[T], fxn:Callable[[T],bool]) ->
-
-
-
-  return a,b
+def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> tuple[list[T], list[T]]:
+  ret:tuple[list[T], list[T]] = ([], [])
+  for s in itr: (ret[0] if fxn(s) else ret[1]).append(s)
+  return ret
 def unwrap(x:Optional[T]) -> T:
   assert x is not None
   return x
+def get_single_element(x:list[T]) -> T:
+  assert len(x) == 1, f"list {x} must only have 1 element"
+  return x[0]
 def get_child(obj, key):
   for k in key.split('.'):
     if k.isnumeric(): obj = obj[int(k)]
     elif isinstance(obj, dict): obj = obj[k]
     else: obj = getattr(obj, k)
   return obj
-def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
+def word_wrap(x, wrap=80): return x if len(x) <= wrap or '\n' in x[0:wrap] else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
 
 # for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
-def polyN(x:T, p:
+def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore
 
 @functools.lru_cache(maxsize=None)
 def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
 @functools.lru_cache(maxsize=None)
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
-def temp(x:str) -> str:
+def temp(x:str, append_user:bool=False) -> str:
+  return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()
 
 class Context(contextlib.ContextDecorator):
-  stack: ClassVar[List[dict[str, int]]] = [{}]
   def __init__(self, **kwargs): self.kwargs = kwargs
   def __enter__(self):
-
-    for k,v in self.kwargs.items(): ContextVar._cache[k].value = v
-    Context.stack.append(self.kwargs) # Store the temporary state so we know what to undo later.
+    self.old_context:dict[str, int] = {k:v.value for k,v in ContextVar._cache.items()}
+    for k,v in self.kwargs.items(): ContextVar._cache[k].value = v
  def __exit__(self, *args):
-    for k in
+    for k,v in self.old_context.items(): ContextVar._cache[k].value = v
 
 class ContextVar:
-  _cache: ClassVar[
+  _cache: ClassVar[dict[str, ContextVar]] = {}
  value: int
  key: str
  def __init__(self, key, default_value):
-
+    if key in ContextVar._cache: raise RuntimeError(f"attempt to recreate ContextVar {key}")
    ContextVar._cache[key] = self
    self.value, self.key = getenv(key, default_value), key
  def __bool__(self): return bool(self.value)
@@ -100,12 +103,16 @@ class ContextVar:
  def __gt__(self, x): return self.value > x
  def __lt__(self, x): return self.value < x
 
-DEBUG, IMAGE, BEAM, NOOPT
+DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
+JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1)
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
-
-
-FUSE_ARANGE, FUSE_CONV_BW
+USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
+TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE = ContextVar("TRANSCENDENTAL", 1), ContextVar("TC_SEARCH_OVER_SHAPE", 1)
+FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
+PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
+CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
+DONT_REALIZE_EXPAND = ContextVar("DONT_REALIZE_EXPAND", 0)
 
 @dataclass(frozen=True)
 class Metadata:
@@ -160,11 +167,10 @@ class Profiling(contextlib.ContextDecorator):
 
 # *** universal database cache ***
 
-
-CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(
-CACHELEVEL = getenv("CACHELEVEL", 2)
+cache_dir: str = os.path.join(getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")), "tinygrad")
+CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(cache_dir, "cache.db")))
 
-VERSION =
+VERSION = 19
 _db_connection = None
 def db_connection():
   global _db_connection
@@ -182,8 +188,8 @@ def diskcache_clear():
   drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
   cur.executescript("\n".join([s[0] for s in drop_tables] + ["VACUUM;"]))
 
-def diskcache_get(table:str, key:Union[
-  if CACHELEVEL
+def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:
+  if CACHELEVEL < 1: return None
  if isinstance(key, (str,int)): key = {"key": key}
  conn = db_connection()
  cur = conn.cursor()
@@ -195,8 +201,8 @@ def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
   return None
 
 _db_tables = set()
-def diskcache_put(table:str, key:Union[
-  if CACHELEVEL
+def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=False):
+  if CACHELEVEL < 1: return val
  if isinstance(key, (str,int)): key = {"key": key}
  conn = db_connection()
  cur = conn.cursor()
@@ -205,7 +211,7 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
     ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
    cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
    _db_tables.add(table)
-  cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
+  cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
  conn.commit()
  cur.close()
  return val
@@ -217,6 +223,10 @@ def diskcache(func):
     return diskcache_put(table, key, func(*args, **kwargs))
  return wrapper
 
+# *** process replay ***
+
+CAPTURE_PROCESS_REPLAY = getenv("RUN_PROCESS_REPLAY") or getenv("CAPTURE_PROCESS_REPLAY")
+
 # *** http support ***
 
 def _ensure_downloads_dir() -> pathlib.Path:
@@ -228,28 +238,26 @@ def _ensure_downloads_dir() -> pathlib.Path:
      subprocess.run(["sudo", "chown", "tiny:root", downloads_dir], check=True)
      subprocess.run(["sudo", "chmod", "775", downloads_dir], check=True)
    return downloads_dir
-  return pathlib.Path(
+  return pathlib.Path(cache_dir) / "downloads"
 
 def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None, gunzip:bool=False,
          allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
  if url.startswith(("/", ".")): return pathlib.Path(url)
  if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
-  else:
-    fp = _ensure_downloads_dir() / (subdir or "") / \
-    ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
+  else: fp = _ensure_downloads_dir() / (subdir or "") / ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
  if not fp.is_file() or not allow_caching:
+    (_dir := fp.parent).mkdir(parents=True, exist_ok=True)
    with urllib.request.urlopen(url, timeout=10) as r:
-      assert r.status == 200
+      assert r.status == 200, r.status
      length = int(r.headers.get('content-length', 0)) if not gunzip else None
-      progress_bar = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
-      (path := fp.parent).mkdir(parents=True, exist_ok=True)
      readfile = gzip.GzipFile(fileobj=r) if gunzip else r
-
+      progress_bar:tqdm = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
+      with tempfile.NamedTemporaryFile(dir=_dir, delete=False) as f:
        while chunk := readfile.read(16384): progress_bar.update(f.write(chunk))
        f.close()
-      progress_bar.update(close=True)
-      if length and (file_size:=os.stat(f.name).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
        pathlib.Path(f.name).rename(fp)
+      progress_bar.update(close=True)
+      if length and (file_size:=os.stat(fp).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
  return fp
 
 # *** Exec helpers
@@ -264,16 +272,28 @@ def cpu_objdump(lib, objdump_tool='objdump'):
    pathlib.Path(f.name).write_bytes(lib)
    print(subprocess.check_output([objdump_tool, '-d', f.name]).decode('utf-8'))
 
+def capstone_flatdump(lib: bytes):
+  import capstone
+  match platform.machine():
+    case 'x86_64' | 'AMD64': cs = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
+    case 'aarch64' | 'arm64': cs = capstone.Cs(capstone.CS_ARCH_ARM64, capstone.CS_MODE_ARM)
+    case machine: raise NotImplementedError(f"Capstone disassembly isn't supported for {machine}")
+  cs.skipdata = True
+  for instr in cs.disasm(lib, 0):
+    print(f"{instr.address:#08x}: {instr.mnemonic}\t{instr.op_str}")
+  sys.stdout.flush()
+
 # *** ctypes helpers
 
 # TODO: make this work with read only memoryviews (if possible)
 def from_mv(mv:memoryview, to_type=ctypes.c_char):
  return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
-def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
-def mv_address(mv
-def to_char_p_p(options:
+def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
+def mv_address(mv): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
+def to_char_p_p(options: list[bytes], to_type=ctypes.c_char):
+  return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])
 @functools.lru_cache(maxsize=None)
-def init_c_struct_t(fields:
+def init_c_struct_t(fields: tuple[tuple[str, ctypes._SimpleCData], ...]):
  class CStruct(ctypes.Structure):
    _pack_, _fields_ = 1, fields
  return CStruct
@@ -282,13 +302,15 @@ def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(m
 
 # *** tqdm
 
-class tqdm:
-  def __init__(self, iterable=None, desc:str='', disable:bool=False,
+class tqdm(Generic[T]):
+  def __init__(self, iterable:Iterable[T]|None=None, desc:str='', disable:bool=False,
               unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
    self.iterable, self.disable, self.unit, self.unit_scale, self.rate = iterable, disable, unit, unit_scale, rate
    self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total
    self.set_description(desc)
    self.update(0)
-  def __iter__(self):
+  def __iter__(self) -> Iterator[T]:
+    assert self.iterable is not None, "need an iterable to iterate"
    for item in self.iterable:
      yield item
      self.update(1)
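
A few of the helpers added or reworked above, exercised in isolation (an illustrative sketch; the expected values follow directly from the definitions in the diff):

    import ctypes
    from tinygrad.helpers import time_to_str, partition, getbits, Context, DEBUG, to_mv, mv_address

    print(time_to_str(0.0005))                      # "  500.00us": under the 10ms threshold, falls through to microseconds
    evens, odds = partition(range(6), lambda x: x % 2 == 0)
    print(evens, odds)                              # [0, 2, 4] [1, 3, 5]
    print(getbits(0b1101100, 2, 5))                 # 11 (0b1011): bits 2..5 inclusive
    with Context(DEBUG=2): assert DEBUG.value == 2  # previous value restored on __exit__

    buf = (ctypes.c_uint8 * 4)(1, 2, 3, 4)
    mv = to_mv(ctypes.addressof(buf), 4)            # zero-copy memoryview over the same bytes
    assert list(mv) == [1, 2, 3, 4] and mv_address(mv) == ctypes.addressof(buf)
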