tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/engine/search.py
CHANGED
```diff
@@ -1,26 +1,27 @@
-from typing import
+from typing import cast, Optional, Callable
 import itertools, functools, random, math, time, multiprocessing, traceback, signal
 from collections import defaultdict
 from dataclasses import replace
 from tinygrad.ops import UOp, Ops, Variable, sym_infer
 from tinygrad.device import Device, Buffer, Compiler
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
+from tinygrad.helpers import IGNORE_BEAM_CACHE
 from tinygrad.dtype import ImageDType, PtrDType
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
 from tinygrad.tensor import Tensor
 from tinygrad.engine.realize import CompiledRunner
-from tinygrad.renderer import
-
-actions = [Opt(op=OptOps.UPCAST, axis=axis,
-actions += [Opt(op=OptOps.UNROLL, axis=axis,
-actions += [Opt(op=OptOps.LOCAL, axis=axis,
-actions += [Opt(op=OptOps.GROUPTOP, axis=axis,
-actions += [Opt(op=OptOps.GROUP, axis=axis,
-if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis,
-actions += [Opt(op=OptOps.LOCAL, axis=0,
-actions += [Opt(op=OptOps.
-actions += [Opt(op=OptOps.TC, axis=axis,
-actions += [Opt(op=OptOps.SWAP, axis=
+from tinygrad.renderer import ProgramSpec
+
+actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
+actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
+actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
+actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
+actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)]
+if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)]
+actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)]
+actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0))]
+actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2))) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
+actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
 
 def _get_test_global_size(global_size, max_global_size, var_vals):
@@ -33,8 +34,8 @@ def _get_test_global_size(global_size, max_global_size, var_vals):
         break
   return test_global_size, factor
 
-def _time_program(p:
-                  max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") ->
+def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:Optional[float]=None,
+                  max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
   factor = 1
   if p.global_size is not None and max_global_size is not None:
     global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
@@ -45,9 +46,9 @@ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:Li
   input_bufs = [rawbufs[i] for i in car.p.globals]
   for _ in range(cnt):
     if clear_l2:
-      if hasattr(dev:=Device[p.
+      if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
       else:
-        with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
+        with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
     tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
     if early_stop is not None and early_stop < min(tms): break
   return tms
@@ -55,7 +56,7 @@ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:Li
 class TimeoutException(Exception): pass
 def timeout_handler(signum, frame): raise TimeoutException()
 
-def _try_compile_linearized_w_idx(x:
+def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, Optional[tuple[ProgramSpec, bytes, float]]]:
   if hasattr(signal, "alarm"):
     signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
     # set timeout
@@ -80,16 +81,16 @@ def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tup
 # workers should ignore ctrl c
 def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
 
-def _ensure_buffer_alloc(bufs:
+def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() for buf in bufs]
 
 # *** external API ***
 
 # get (scrap) buffers for timing the linearizer
-def bufs_from_lin(lin:Kernel, allocate:bool=True) ->
-  bufsts:
+def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
+  bufsts: defaultdict[int, list[UOp]] = defaultdict(list)
   for x in lin.bufs:
     if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
-  rawbufs:
+  rawbufs: list[Optional[Buffer]] = [None]*len(bufsts)
   for k,lx in bufsts.items():
     buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
     assert isinstance(dtype, (PtrDType, ImageDType))
@@ -97,18 +98,18 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
     buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
     rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype)
   assert all(r is not None for r in rawbufs)
-  return cast(
+  return cast(list[Buffer], rawbufs)
 
 # get dictionary of all possible actions
-def get_kernel_actions(lin:Kernel, include_0=True) ->
+def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
   acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
   for i,a in enumerate(actions):
     if a.axis is not None and a.op is not OptOps.TC:
-      if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.
+      if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in actions): continue
     lin2 = lin.copy()
     try:
       lin2.apply_opt(a)
-      up, lcl, tc_up = 1, 1, prod(tc.dims)//
+      up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1
       for s,c in zip(lin2.full_shape, lin2.colors()):
         if c in {"magenta", "yellow"}: up *= s
         elif c in {"cyan", "green", "white"}: lcl *= s
@@ -117,8 +118,8 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
     except KernelOptError: pass
   return acted_lins
 
-beam_pool, BEAM_DEBUG
-def beam_search(lin:Kernel, rawbufs:
+beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
+def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value) -> Kernel:
   global beam_pool
   key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
   if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
@@ -126,7 +127,7 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
     for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
     return ret
 
-  beam:
+  beam: list[tuple[Kernel, float]] = [(lin, float("inf"))]
   seen_libs = set()
 
   default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL"} else 0
@@ -139,12 +140,12 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
 
   try:
     rawbufs = _ensure_buffer_alloc(rawbufs)
-    var_vals:
+    var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
     exiting, st = False, time.perf_counter()
     dev = Device[lin.opts.device]
     while not exiting:
-      acted_lins:
-      timed_lins:
+      acted_lins: list[Kernel] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
+      timed_lins: list[tuple[Kernel, float]] = []
       _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
       least_compute_ops = math.inf
       for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
@@ -152,18 +153,13 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
         p, lib, compile_et = proc
         if lib in seen_libs: continue
         # filter out kernels that use 1000x more compute than the smallest
-        least_compute_ops = min(this_compute_ops:=sym_infer(p.
+        least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
         if least_compute_ops*1000 < this_compute_ops: continue
-        if len(CAPTURE_BEAM) > 0:
-          with open(CAPTURE_BEAM, 'a') as f: f.write(str(acted_lins[i].ast).replace('\n','')+f" :: {acted_lins[i].applied_opts}\n")
         seen_libs.add(lib)
         try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
-        except RuntimeError
-          if len(CAPTURE_BEAM) > 0:
-            with open(CAPTURE_BEAM, 'a') as f: f.write("# Upper ast finished with an error:" + str(e).replace('\n',' ')+ "\n")
-          continue # for runtime issues
+        except RuntimeError: continue # for runtime issues
         timed_lins.append((acted_lins[i], min(tms)))
-        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(
+        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
         elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
 
       # done
@@ -180,19 +176,19 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
   if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={beam[0][1]*1e6:0.2f} us, applied_opts={beam[0][0].applied_opts}")
   return beam[0][0]
 
-def optimize_local_size(
+def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
   test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
   MAX_WORKGROUP = 1024
   local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
   local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
   def try_exec(local_size):
-    try: return
+    try: return _prg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501
     except Exception: return float('inf')
   ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
   assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
   return ret[1]
 
-def time_linearizer(lin:Kernel, rawbufs:
+def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
   key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
          "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
   if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
@@ -201,7 +197,7 @@ def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_
   assert dev.compiler is not None
 
   rawbufs = _ensure_buffer_alloc(rawbufs)
-  var_vals:
+  var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
   p = lin.to_program()
   tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
                       max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
```
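One practical consequence of this file's changes: `beam_search` now takes its default `disable_cache` from the new `IGNORE_BEAM_CACHE` ContextVar, and the CAPTURE_BEAM logging path is gone. Below is a minimal, hedged usage sketch; the matmul workload and the specific sizes are illustrative only and not taken from the diff.

```python
# Hedged sketch: force a fresh BEAM search even when a prior result sits in the
# on-disk cache, by flipping the new IGNORE_BEAM_CACHE ContextVar.
from tinygrad import Tensor
from tinygrad.helpers import Context

with Context(BEAM=2, IGNORE_BEAM_CACHE=1):
  # kernels realized inside this block go through beam_search with disable_cache=True
  (Tensor.rand(256, 256) @ Tensor.rand(256, 256)).realize()
```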
tinygrad/gradient.py
ADDED
@@ -0,0 +1,70 @@
```python
from typing import cast, Iterator
import math, functools, dataclasses
from tinygrad.dtype import dtypes, sum_acc_dtype
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
from tinygrad.helpers import argsort

def reduce_gradient(ctx:UOp, ret:UOp):
  if ret.arg[0] == Ops.ADD: return (ctx.expand(ret.src[0].shape),)
  if ret.arg[0] == Ops.MAX:
    max_is_1s = ret.src[0].ne(ret.expand(ret.src[0].shape)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype)
    div = max_is_1s.r(Ops.ADD, ret.arg[1]).expand(ret.src[0].shape)
    return ((max_is_1s/div) * ctx.expand(ret.src[0].shape),)
  if ret.arg[0] == Ops.MUL: return ((ctx * ret).expand(ret.src[0].shape) / ret.src[0],)

# ctx is grad_output
pm_gradient = PatternMatcher([
  (UPat(Ops.CAST, name="ret"), lambda ctx, ret: (ctx.cast(ret.src[0].dtype),)),
  (UPat(Ops.RECIP, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)),
  (UPat(Ops.SIN, name="ret"), lambda ctx, ret: ((math.pi/2 - ret.src[0]).sin() * ctx,)),
  (UPat(Ops.LOG2, name="ret"), lambda ctx, ret: (ctx / (ret.src[0] * math.log(2)),)),
  (UPat(Ops.EXP2, name="ret"), lambda ctx, ret: (ret * ctx * math.log(2),)),
  (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
  (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
  (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
  (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
                                                (ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
  (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
  (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
  (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
  (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
  (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
  (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)),
  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
  (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
  (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.arg),)),
  # TODO: this cast can be removed by putting the casts around the EXPAND
  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
    (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
  (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
  # there's no gradient for bitcast
  (UPat(Ops.BITCAST), lambda ctx: (None,)),
])

# copied from tensor.py, get relevant toposort of gradients
def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
  @functools.lru_cache(None)
  def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
  def _walk(node:UOp, visited:set[UOp]) -> Iterator[UOp]:
    visited.add(node)
    if node.op is Ops.DETACH: return
    if is_in_target_path(node):
      for i in node.src:
        if i not in visited: yield from _walk(i, visited)
      yield node
  return list(_walk(root, set()))

def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
  grads = {root: root_grad}
  for t0 in reversed(_deepwalk(root, targets)):
    if t0 not in grads: continue
    lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
    if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
    assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
    for k,v in zip(t0.src, lgrads):
      if v is None: continue
      if k in grads: grads[k] = grads[k] + v
      else: grads[k] = v
      if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True)
  return grads
```
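With tinygrad/function.py removed in this release, the Tensor autograd path now flows through the `pm_gradient` PatternMatcher and `compute_gradient` above. A minimal, hedged sketch of the Tensor-level entry point that exercises it (the example values are illustrative, not from the diff):

```python
# Hedged sketch: Tensor.backward() now derives gradients through the
# pm_gradient / compute_gradient path added in this file.
from tinygrad import Tensor

x = Tensor([2.0, 3.0], requires_grad=True)
loss = (x * x).sum()     # d(loss)/dx = 2*x
loss.backward()
print(x.grad.numpy())    # expected: [4. 6.]
```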
tinygrad/helpers.py
CHANGED
```diff
@@ -2,7 +2,7 @@ from __future__ import annotations
 import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
 import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
 from dataclasses import dataclass
-from typing import
+from typing import Union, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic
 
 T = TypeVar("T")
 U = TypeVar("U")
@@ -23,76 +23,78 @@ def argfix(*x):
     return tuple(x[0])
   return x
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
-def all_same(items:Union[
-def all_int(t: Sequence[Any]) -> TypeGuard[
+def all_same(items:Union[tuple[T, ...], list[T]]): return all(x == items[0] for x in items)
+def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
 def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
 def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
 def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
 def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
 def ansilen(s:str): return len(ansistrip(s))
-def make_tuple(x:Union[int, Sequence[int]], cnt:int) ->
+def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
 def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
 def fully_flatten(l):
   if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
+    if hasattr(l, "shape") and l.shape == (): return [l[()]]
     flattened = []
-
-    else:
-      for i in range(len(l)): flattened.extend(fully_flatten(l[i]))
+    for li in l: flattened.extend(fully_flatten(li))
     return flattened
   return [l]
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
 def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
-def ceildiv(num, amt):
-  ret = -(num//-amt)
-  return ret if not isinstance(ret, float) else int(ret)
+def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
 def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
-def
-def
-def
+def lo32(x:Any) -> Any: return x & 0xFFFFFFFF # Any is sint
+def hi32(x:Any) -> Any: return x >> 32 # Any is sint
+def data64(data:Any) -> tuple[Any, Any]: return (data >> 32, data & 0xFFFFFFFF) # Any is sint
+def data64_le(data:Any) -> tuple[Any, Any]: return (data & 0xFFFFFFFF, data >> 32) # Any is sint
+def getbits(value: int, start: int, end: int): return (value >> start) & ((1 << end-start+1) - 1)
+def i2u(bits: int, value: int): return value if value >= 0 else (1<<bits)+value
+def merge_dicts(ds:Iterable[dict[T,U]]) -> dict[T,U]:
   kvs = set([(k,v) for d in ds for k,v in d.items()])
   assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
   return {k:v for d in ds for k,v in d.items()}
-def partition(itr:Iterable[T], fxn:Callable[[T],bool]) ->
-
-
-
-  return a,b
+def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> tuple[list[T], list[T]]:
+  ret:tuple[list[T], list[T]] = ([], [])
+  for s in itr: (ret[0] if fxn(s) else ret[1]).append(s)
+  return ret
 def unwrap(x:Optional[T]) -> T:
   assert x is not None
   return x
+def get_single_element(x:list[T]) -> T:
+  assert len(x) == 1, f"list {x} must only have 1 element"
+  return x[0]
 def get_child(obj, key):
   for k in key.split('.'):
     if k.isnumeric(): obj = obj[int(k)]
     elif isinstance(obj, dict): obj = obj[k]
     else: obj = getattr(obj, k)
   return obj
-def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
+def word_wrap(x, wrap=80): return x if len(x) <= wrap or '\n' in x[0:wrap] else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
 
 # for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
-def polyN(x:T, p:
+def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore
 
 @functools.lru_cache(maxsize=None)
 def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
 @functools.lru_cache(maxsize=None)
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
-def temp(x:str) -> str:
+def temp(x:str, append_user:bool=False) -> str:
+  return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{os.getenv('USERNAME', os.getlogin())}" if append_user else x)).as_posix()
 
 class Context(contextlib.ContextDecorator):
-  stack: ClassVar[List[dict[str, int]]] = [{}]
   def __init__(self, **kwargs): self.kwargs = kwargs
   def __enter__(self):
-
-    for k,v in self.kwargs.items(): ContextVar._cache[k].value = v
-    Context.stack.append(self.kwargs) # Store the temporary state so we know what to undo later.
+    self.old_context:dict[str, int] = {k:v.value for k,v in ContextVar._cache.items()}
+    for k,v in self.kwargs.items(): ContextVar._cache[k].value = v
   def __exit__(self, *args):
-    for k in
+    for k,v in self.old_context.items(): ContextVar._cache[k].value = v
 
 class ContextVar:
-  _cache: ClassVar[
+  _cache: ClassVar[dict[str, ContextVar]] = {}
   value: int
   key: str
   def __init__(self, key, default_value):
-
+    if key in ContextVar._cache: raise RuntimeError(f"attempt to recreate ContextVar {key}")
     ContextVar._cache[key] = self
     self.value, self.key = getenv(key, default_value), key
   def __bool__(self): return bool(self.value)
@@ -100,12 +102,15 @@ class ContextVar:
   def __gt__(self, x): return self.value > x
   def __lt__(self, x): return self.value < x
 
-DEBUG, IMAGE, BEAM, NOOPT
+DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
+JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1)
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
-
-
-FUSE_ARANGE, FUSE_CONV_BW
+USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
+TRANSCENDENTAL = ContextVar("TRANSCENDENTAL", 1)
+FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
+PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
+CACHELEVEL, IGNORE_BEAM_CACHE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0)
 
 @dataclass(frozen=True)
 class Metadata:
@@ -160,11 +165,10 @@ class Profiling(contextlib.ContextDecorator):
 
 # *** universal database cache ***
 
-
-CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(
-CACHELEVEL = getenv("CACHELEVEL", 2)
+cache_dir: str = os.path.join(getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")), "tinygrad")
+CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(cache_dir, "cache.db")))
 
-VERSION =
+VERSION = 19
 _db_connection = None
 def db_connection():
   global _db_connection
@@ -182,8 +186,8 @@ def diskcache_clear():
   drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
   cur.executescript("\n".join([s[0] for s in drop_tables] + ["VACUUM;"]))
 
-def diskcache_get(table:str, key:Union[
-  if CACHELEVEL
+def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:
+  if CACHELEVEL < 1: return None
   if isinstance(key, (str,int)): key = {"key": key}
   conn = db_connection()
   cur = conn.cursor()
@@ -195,8 +199,8 @@ def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
   return None
 
 _db_tables = set()
-def diskcache_put(table:str, key:Union[
-  if CACHELEVEL
+def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=False):
+  if CACHELEVEL < 1: return val
   if isinstance(key, (str,int)): key = {"key": key}
   conn = db_connection()
   cur = conn.cursor()
@@ -205,7 +209,7 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
     ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
     cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
     _db_tables.add(table)
-  cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
+  cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
   conn.commit()
   cur.close()
   return val
@@ -217,6 +221,10 @@ def diskcache(func):
     return diskcache_put(table, key, func(*args, **kwargs))
   return wrapper
 
+# *** process replay ***
+
+CAPTURE_PROCESS_REPLAY = getenv("RUN_PROCESS_REPLAY") or getenv("CAPTURE_PROCESS_REPLAY")
+
 # *** http support ***
 
 def _ensure_downloads_dir() -> pathlib.Path:
@@ -228,28 +236,26 @@ def _ensure_downloads_dir() -> pathlib.Path:
       subprocess.run(["sudo", "chown", "tiny:root", downloads_dir], check=True)
       subprocess.run(["sudo", "chmod", "775", downloads_dir], check=True)
     return downloads_dir
-  return pathlib.Path(
+  return pathlib.Path(cache_dir) / "downloads"
 
 def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None, gunzip:bool=False,
           allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
   if url.startswith(("/", ".")): return pathlib.Path(url)
   if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
-  else:
-    fp = _ensure_downloads_dir() / (subdir or "") / \
-      ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
+  else: fp = _ensure_downloads_dir() / (subdir or "") / ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
   if not fp.is_file() or not allow_caching:
+    (_dir := fp.parent).mkdir(parents=True, exist_ok=True)
     with urllib.request.urlopen(url, timeout=10) as r:
-      assert r.status == 200
+      assert r.status == 200, r.status
       length = int(r.headers.get('content-length', 0)) if not gunzip else None
-      progress_bar = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
-      (path := fp.parent).mkdir(parents=True, exist_ok=True)
       readfile = gzip.GzipFile(fileobj=r) if gunzip else r
-
+      progress_bar:tqdm = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
+      with tempfile.NamedTemporaryFile(dir=_dir, delete=False) as f:
        while chunk := readfile.read(16384): progress_bar.update(f.write(chunk))
        f.close()
-        progress_bar.update(close=True)
-        if length and (file_size:=os.stat(f.name).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
        pathlib.Path(f.name).rename(fp)
+      progress_bar.update(close=True)
+      if length and (file_size:=os.stat(fp).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
   return fp
 
 # *** Exec helpers
@@ -264,16 +270,27 @@ def cpu_objdump(lib, objdump_tool='objdump'):
     pathlib.Path(f.name).write_bytes(lib)
     print(subprocess.check_output([objdump_tool, '-d', f.name]).decode('utf-8'))
 
+def capstone_flatdump(lib: bytes):
+  import capstone
+  match platform.machine():
+    case 'x86_64' | 'AMD64': cs = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
+    case 'aarch64' | 'arm64': cs = capstone.Cs(capstone.CS_ARCH_ARM64, capstone.CS_MODE_ARM)
+    case machine: raise NotImplementedError(f"Capstone disassembly isn't supported for {machine}")
+  for instr in cs.disasm(lib, 0):
+    print(f"{instr.address:#08x}: {instr.mnemonic}\t{instr.op_str}")
+  sys.stdout.flush()
+
 # *** ctypes helpers
 
 # TODO: make this work with read only memoryviews (if possible)
 def from_mv(mv:memoryview, to_type=ctypes.c_char):
   return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
-def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
-def mv_address(mv
-def to_char_p_p(options:
+def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
+def mv_address(mv): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
+def to_char_p_p(options: list[bytes], to_type=ctypes.c_char):
+  return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])
 @functools.lru_cache(maxsize=None)
-def init_c_struct_t(fields:
+def init_c_struct_t(fields: tuple[tuple[str, ctypes._SimpleCData], ...]):
   class CStruct(ctypes.Structure):
     _pack_, _fields_ = 1, fields
   return CStruct
@@ -282,13 +299,15 @@ def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(m
 
 # *** tqdm
 
-class tqdm:
-  def __init__(self, iterable=None, desc:str='', disable:bool=False,
+class tqdm(Generic[T]):
+  def __init__(self, iterable:Iterable[T]|None=None, desc:str='', disable:bool=False,
+               unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
     self.iterable, self.disable, self.unit, self.unit_scale, self.rate = iterable, disable, unit, unit_scale, rate
     self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total
     self.set_description(desc)
    self.update(0)
-  def __iter__(self):
+  def __iter__(self) -> Iterator[T]:
+    assert self.iterable is not None, "need an iterable to iterate"
    for item in self.iterable:
      yield item
      self.update(1)
```
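For reference, a small sketch exercising the new bit-manipulation helpers (`lo32`, `hi32`, `data64`, `data64_le`, `getbits`, `i2u`) added to helpers.py above; the concrete values are illustrative only.

```python
# Hedged sketch of the new helpers.py bit utilities; the values are made up.
from tinygrad.helpers import lo32, hi32, data64, data64_le, getbits, i2u

addr = 0x1234_5678_9ABC_DEF0
assert (hi32(addr), lo32(addr)) == (0x12345678, 0x9ABCDEF0)
assert data64(addr) == (0x12345678, 0x9ABCDEF0)      # (hi, lo) split
assert data64_le(addr) == (0x9ABCDEF0, 0x12345678)   # (lo, hi) split
assert getbits(0b1011_0100, 2, 5) == 0b1101          # bits 2..5 inclusive
assert i2u(8, -1) == 255                             # two's-complement view of -1 in 8 bits
```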