tinygrad-0.10.0-py3-none-any.whl → tinygrad-0.10.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/engine/search.py CHANGED
@@ -1,26 +1,27 @@
- from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
+ from typing import cast, Optional, Callable
  import itertools, functools, random, math, time, multiprocessing, traceback, signal
  from collections import defaultdict
  from dataclasses import replace
  from tinygrad.ops import UOp, Ops, Variable, sym_infer
  from tinygrad.device import Device, Buffer, Compiler
  from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
+ from tinygrad.helpers import IGNORE_BEAM_CACHE
  from tinygrad.dtype import ImageDType, PtrDType
  from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
  from tinygrad.tensor import Tensor
  from tinygrad.engine.realize import CompiledRunner
- from tinygrad.renderer import Program
-
- actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
- actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)]
- actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
- actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
- actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
- if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
- actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=6, amt=2)]
- actions += [Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
- actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
- actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)]
+ from tinygrad.renderer import ProgramSpec
+
+ actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
+ actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
+ actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
+ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
+ actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)]
+ if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)]
+ actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)]
+ actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0))]
+ actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2))) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
+ actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
  if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
 
  def _get_test_global_size(global_size, max_global_size, var_vals):
@@ -33,8 +34,8 @@ def _get_test_global_size(global_size, max_global_size, var_vals):
  break
  return test_global_size, factor
 
- def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:List[Buffer], early_stop:Optional[float]=None,
- max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> List[float]:
+ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:Optional[float]=None,
+ max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
  factor = 1
  if p.global_size is not None and max_global_size is not None:
  global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
@@ -45,9 +46,9 @@ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:Li
  input_bufs = [rawbufs[i] for i in car.p.globals]
  for _ in range(cnt):
  if clear_l2:
- if hasattr(dev:=Device[p.dname], 'invalidate_caches'): dev.invalidate_caches()
+ if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
  else:
- with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
+ with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
  tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
  if early_stop is not None and early_stop < min(tms): break
  return tms
@@ -55,7 +56,7 @@ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:Li
  class TimeoutException(Exception): pass
  def timeout_handler(signum, frame): raise TimeoutException()
 
- def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
+ def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, Optional[tuple[ProgramSpec, bytes, float]]]:
  if hasattr(signal, "alarm"):
  signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
  # set timeout
@@ -80,16 +81,16 @@ def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tup
  # workers should ignore ctrl c
  def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
 
- def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_allocated() for buf in bufs]
+ def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() for buf in bufs]
 
  # *** external API ***
 
  # get (scrap) buffers for timing the linearizer
- def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
- bufsts: DefaultDict[int, List[UOp]] = defaultdict(list)
+ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
+ bufsts: defaultdict[int, list[UOp]] = defaultdict(list)
  for x in lin.bufs:
  if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
- rawbufs: List[Optional[Buffer]] = [None]*len(bufsts)
+ rawbufs: list[Optional[Buffer]] = [None]*len(bufsts)
  for k,lx in bufsts.items():
  buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
  assert isinstance(dtype, (PtrDType, ImageDType))
@@ -97,18 +98,18 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
  buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
  rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype)
  assert all(r is not None for r in rawbufs)
- return cast(List[Buffer], rawbufs)
+ return cast(list[Buffer], rawbufs)
 
  # get dictionary of all possible actions
- def get_kernel_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
+ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
  acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
  for i,a in enumerate(actions):
  if a.axis is not None and a.op is not OptOps.TC:
- if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.amt and Opt(a.op, ax, 0) in actions): continue
+ if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in actions): continue
  lin2 = lin.copy()
  try:
  lin2.apply_opt(a)
- up, lcl, tc_up = 1, 1, prod(tc.dims)//prod([x[1] for x in tc.threads]) if (tc:=lin2.tensor_core) else 1
+ up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1
  for s,c in zip(lin2.full_shape, lin2.colors()):
  if c in {"magenta", "yellow"}: up *= s
  elif c in {"cyan", "green", "white"}: lcl *= s
@@ -117,8 +118,8 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
  except KernelOptError: pass
  return acted_lins
 
- beam_pool, BEAM_DEBUG, CAPTURE_BEAM = None, getenv("BEAM_DEBUG"), getenv("CAPTURE_BEAM", "")
- def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True, disable_cache=getenv("IGNORE_BEAM_CACHE")) -> Kernel:
+ beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
+ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value) -> Kernel:
  global beam_pool
  key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
@@ -126,7 +127,7 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
  for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
  return ret
 
- beam: List[Tuple[Kernel, float]] = [(lin, float("inf"))]
+ beam: list[tuple[Kernel, float]] = [(lin, float("inf"))]
  seen_libs = set()
 
  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL"} else 0
@@ -139,12 +140,12 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
 
  try:
  rawbufs = _ensure_buffer_alloc(rawbufs)
- var_vals: Dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
+ var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
  exiting, st = False, time.perf_counter()
  dev = Device[lin.opts.device]
  while not exiting:
- acted_lins: List[Kernel] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
- timed_lins: List[Tuple[Kernel, float]] = []
+ acted_lins: list[Kernel] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
+ timed_lins: list[tuple[Kernel, float]] = []
  _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
  least_compute_ops = math.inf
  for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
@@ -152,18 +153,13 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
  p, lib, compile_et = proc
  if lib in seen_libs: continue
  # filter out kernels that use 1000x more compute than the smallest
- least_compute_ops = min(this_compute_ops:=sym_infer(p.op_estimate, var_vals), least_compute_ops)
+ least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
  if least_compute_ops*1000 < this_compute_ops: continue
- if len(CAPTURE_BEAM) > 0:
- with open(CAPTURE_BEAM, 'a') as f: f.write(str(acted_lins[i].ast).replace('\n','')+f" :: {acted_lins[i].applied_opts}\n")
  seen_libs.add(lib)
  try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
- except RuntimeError as e:
- if len(CAPTURE_BEAM) > 0:
- with open(CAPTURE_BEAM, 'a') as f: f.write("# Upper ast finished with an error:" + str(e).replace('\n',' ')+ "\n")
- continue # for runtime issues
+ except RuntimeError: continue # for runtime issues
  timed_lins.append((acted_lins[i], min(tms)))
- if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(List, p.uops)):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
+ if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
  elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
 
  # done
@@ -180,19 +176,19 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
  if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={beam[0][1]*1e6:0.2f} us, applied_opts={beam[0][0].applied_opts}")
  return beam[0][0]
 
- def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
+ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
  test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
  MAX_WORKGROUP = 1024
  local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
  local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
  def try_exec(local_size):
- try: return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501
+ try: return _prg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501
  except Exception: return float('inf')
  ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
  assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
  return ret[1]
 
- def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
+ def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
  key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
  "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
@@ -201,7 +197,7 @@ def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_
  assert dev.compiler is not None
 
  rawbufs = _ensure_buffer_alloc(rawbufs)
- var_vals: Dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
+ var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
  p = lin.to_program()
  tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
  max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
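
The search API keeps its shape in 0.10.1, but Opt now takes arg instead of amt and programs are described by ProgramSpec. A minimal sketch of hand-timing one kernel with these helpers, assuming a Kernel k has already been built from a scheduled AST (the opt values here are illustrative, not from the diff):

    from tinygrad.codegen.kernel import Opt, OptOps
    from tinygrad.engine.search import bufs_from_lin, time_linearizer

    k.apply_opt(Opt(op=OptOps.UPCAST, axis=0, arg=4))  # 0.10.0 spelled this amt=4
    rawbufs = bufs_from_lin(k)                         # scratch buffers sized from the kernel's loads/stores
    print(time_linearizer(k, rawbufs, cnt=3))          # best of 3 timed runs, in seconds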
tinygrad/gradient.py ADDED
@@ -0,0 +1,70 @@
+ from typing import cast, Iterator
+ import math, functools, dataclasses
+ from tinygrad.dtype import dtypes, sum_acc_dtype
+ from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
+ from tinygrad.helpers import argsort
+
+ def reduce_gradient(ctx:UOp, ret:UOp):
+ if ret.arg[0] == Ops.ADD: return (ctx.expand(ret.src[0].shape),)
+ if ret.arg[0] == Ops.MAX:
+ max_is_1s = ret.src[0].ne(ret.expand(ret.src[0].shape)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype)
+ div = max_is_1s.r(Ops.ADD, ret.arg[1]).expand(ret.src[0].shape)
+ return ((max_is_1s/div) * ctx.expand(ret.src[0].shape),)
+ if ret.arg[0] == Ops.MUL: return ((ctx * ret).expand(ret.src[0].shape) / ret.src[0],)
+
+ # ctx is grad_output
+ pm_gradient = PatternMatcher([
+ (UPat(Ops.CAST, name="ret"), lambda ctx, ret: (ctx.cast(ret.src[0].dtype),)),
+ (UPat(Ops.RECIP, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)),
+ (UPat(Ops.SIN, name="ret"), lambda ctx, ret: ((math.pi/2 - ret.src[0]).sin() * ctx,)),
+ (UPat(Ops.LOG2, name="ret"), lambda ctx, ret: (ctx / (ret.src[0] * math.log(2)),)),
+ (UPat(Ops.EXP2, name="ret"), lambda ctx, ret: (ret * ctx * math.log(2),)),
+ (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
+ (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
+ (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
+ (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
+ (ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
+ (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
+ (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
+ (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
+ (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
+ (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
+ (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
+ (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)),
+ (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
+ (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
+ (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.arg),)),
+ # TODO: this cast can be removed by putting the casts around the EXPAND
+ (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
+ (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
+ (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
+ # there's no gradient for bitcast
+ (UPat(Ops.BITCAST), lambda ctx: (None,)),
+ ])
+
+ # copied from tensor.py, get relevant toposort of gradients
+ def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
+ @functools.lru_cache(None)
+ def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
+ def _walk(node:UOp, visited:set[UOp]) -> Iterator[UOp]:
+ visited.add(node)
+ if node.op is Ops.DETACH: return
+ if is_in_target_path(node):
+ for i in node.src:
+ if i not in visited: yield from _walk(i, visited)
+ yield node
+ return list(_walk(root, set()))
+
+ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
+ grads = {root: root_grad}
+ for t0 in reversed(_deepwalk(root, targets)):
+ if t0 not in grads: continue
+ lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
+ if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
+ assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
+ for k,v in zip(t0.src, lgrads):
+ if v is None: continue
+ if k in grads: grads[k] = grads[k] + v
+ else: grads[k] = v
+ if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True)
+ return grads
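
The new gradient.py takes over from the removed tinygrad/function.py: backward rules live in the pm_gradient PatternMatcher over UOps, and compute_gradient walks the graph in reverse topological order, accumulating a {UOp: UOp} gradient map. A small sketch of how this surfaces at the Tensor level (that Tensor.backward routes through compute_gradient is an assumption about 0.10.1 internals; the math simply follows the MUL and REDUCE_AXIS rules above):

    from tinygrad import Tensor

    x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
    y = (x * x).sum()      # forward graph: MUL feeding REDUCE_AXIS(ADD)
    y.backward()           # MUL rule yields b*ctx and a*ctx, so d/dx sum(x*x) = 2*x
    print(x.grad.numpy())  # [2. 4. 6.]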
tinygrad/helpers.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
  import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
  import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
  from dataclasses import dataclass
- from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard
+ from typing import Union, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic
 
  T = TypeVar("T")
  U = TypeVar("U")
@@ -23,76 +23,78 @@ def argfix(*x):
  return tuple(x[0])
  return x
  def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
- def all_same(items:Union[Tuple[T, ...], List[T]]): return all(x == items[0] for x in items)
- def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
+ def all_same(items:Union[tuple[T, ...], list[T]]): return all(x == items[0] for x in items)
+ def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
  def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
  def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
  def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
  def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
  def ansilen(s:str): return len(ansistrip(s))
- def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
+ def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
  def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
  def fully_flatten(l):
  if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
+ if hasattr(l, "shape") and l.shape == (): return [l[()]]
  flattened = []
- if hasattr(l, "shape") and l.shape == (): flattened.append(l[()])
- else:
- for i in range(len(l)): flattened.extend(fully_flatten(l[i]))
+ for li in l: flattened.extend(fully_flatten(li))
  return flattened
  return [l]
  def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
  def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
- def ceildiv(num, amt):
- ret = -(num//-amt)
- return ret if not isinstance(ret, float) else int(ret)
+ def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
  def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
- def data64(data: int) -> Tuple[int, int]: return (data >> 32, data & 0xFFFFFFFF)
- def data64_le(data: int) -> Tuple[int, int]: return (data & 0xFFFFFFFF, data >> 32)
- def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
+ def lo32(x:Any) -> Any: return x & 0xFFFFFFFF # Any is sint
+ def hi32(x:Any) -> Any: return x >> 32 # Any is sint
+ def data64(data:Any) -> tuple[Any, Any]: return (data >> 32, data & 0xFFFFFFFF) # Any is sint
+ def data64_le(data:Any) -> tuple[Any, Any]: return (data & 0xFFFFFFFF, data >> 32) # Any is sint
+ def getbits(value: int, start: int, end: int): return (value >> start) & ((1 << end-start+1) - 1)
+ def i2u(bits: int, value: int): return value if value >= 0 else (1<<bits)+value
+ def merge_dicts(ds:Iterable[dict[T,U]]) -> dict[T,U]:
  kvs = set([(k,v) for d in ds for k,v in d.items()])
  assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
  return {k:v for d in ds for k,v in d.items()}
- def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
- a:List[T] = []
- b:List[T] = []
- for s in itr: (a if fxn(s) else b).append(s)
- return a,b
+ def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> tuple[list[T], list[T]]:
+ ret:tuple[list[T], list[T]] = ([], [])
+ for s in itr: (ret[0] if fxn(s) else ret[1]).append(s)
+ return ret
  def unwrap(x:Optional[T]) -> T:
  assert x is not None
  return x
+ def get_single_element(x:list[T]) -> T:
+ assert len(x) == 1, f"list {x} must only have 1 element"
+ return x[0]
  def get_child(obj, key):
  for k in key.split('.'):
  if k.isnumeric(): obj = obj[int(k)]
  elif isinstance(obj, dict): obj = obj[k]
  else: obj = getattr(obj, k)
  return obj
- def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
+ def word_wrap(x, wrap=80): return x if len(x) <= wrap or '\n' in x[0:wrap] else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
 
  # for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
- def polyN(x:T, p:List[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore
+ def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore
 
  @functools.lru_cache(maxsize=None)
  def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
  @functools.lru_cache(maxsize=None)
  def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
- def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
+ def temp(x:str, append_user:bool=False) -> str:
+ return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{os.getenv('USERNAME', os.getlogin())}" if append_user else x)).as_posix()
 
  class Context(contextlib.ContextDecorator):
- stack: ClassVar[List[dict[str, int]]] = [{}]
  def __init__(self, **kwargs): self.kwargs = kwargs
  def __enter__(self):
- Context.stack[-1] = {k:o.value for k,o in ContextVar._cache.items()} # Store current state.
- for k,v in self.kwargs.items(): ContextVar._cache[k].value = v # Update to new temporary state.
- Context.stack.append(self.kwargs) # Store the temporary state so we know what to undo later.
+ self.old_context:dict[str, int] = {k:v.value for k,v in ContextVar._cache.items()}
+ for k,v in self.kwargs.items(): ContextVar._cache[k].value = v
  def __exit__(self, *args):
- for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, ContextVar._cache[k].value)
+ for k,v in self.old_context.items(): ContextVar._cache[k].value = v
 
  class ContextVar:
- _cache: ClassVar[Dict[str, ContextVar]] = {}
+ _cache: ClassVar[dict[str, ContextVar]] = {}
  value: int
  key: str
  def __init__(self, key, default_value):
- assert key not in ContextVar._cache, f"attempt to recreate ContextVar {key}"
+ if key in ContextVar._cache: raise RuntimeError(f"attempt to recreate ContextVar {key}")
  ContextVar._cache[key] = self
  self.value, self.key = getenv(key, default_value), key
  def __bool__(self): return bool(self.value)
@@ -100,12 +102,15 @@ class ContextVar:
  def __gt__(self, x): return self.value > x
  def __lt__(self, x): return self.value < x
 
- DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
+ DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
+ JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1)
  WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
- PROFILE, PROFILEPATH = ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
- USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
- FUSE_ARANGE, FUSE_CONV_BW, LAZYCACHE = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0), ContextVar("LAZYCACHE", 1)
+ USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
+ TRANSCENDENTAL = ContextVar("TRANSCENDENTAL", 1)
+ FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
  SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
+ PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
+ CACHELEVEL, IGNORE_BEAM_CACHE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0)
 
  @dataclass(frozen=True)
  class Metadata:
@@ -160,11 +165,10 @@ class Profiling(contextlib.ContextDecorator):
 
  # *** universal database cache ***
 
- _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))
- CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
- CACHELEVEL = getenv("CACHELEVEL", 2)
+ cache_dir: str = os.path.join(getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")), "tinygrad")
+ CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(cache_dir, "cache.db")))
 
- VERSION = 16
+ VERSION = 19
  _db_connection = None
  def db_connection():
  global _db_connection
@@ -182,8 +186,8 @@ def diskcache_clear():
  drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
  cur.executescript("\n".join([s[0] for s in drop_tables] + ["VACUUM;"]))
 
- def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
- if CACHELEVEL == 0: return None
+ def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:
+ if CACHELEVEL < 1: return None
  if isinstance(key, (str,int)): key = {"key": key}
  conn = db_connection()
  cur = conn.cursor()
@@ -195,8 +199,8 @@ def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
  return None
 
  _db_tables = set()
- def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
- if CACHELEVEL == 0: return val
+ def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=False):
+ if CACHELEVEL < 1: return val
  if isinstance(key, (str,int)): key = {"key": key}
  conn = db_connection()
  cur = conn.cursor()
@@ -205,7 +209,7 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
  ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
  cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
  _db_tables.add(table)
- cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
+ cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
  conn.commit()
  cur.close()
  return val
@@ -217,6 +221,10 @@ def diskcache(func):
  return diskcache_put(table, key, func(*args, **kwargs))
  return wrapper
 
+ # *** process replay ***
+
+ CAPTURE_PROCESS_REPLAY = getenv("RUN_PROCESS_REPLAY") or getenv("CAPTURE_PROCESS_REPLAY")
+
  # *** http support ***
 
  def _ensure_downloads_dir() -> pathlib.Path:
@@ -228,28 +236,26 @@ def _ensure_downloads_dir() -> pathlib.Path:
  subprocess.run(["sudo", "chown", "tiny:root", downloads_dir], check=True)
  subprocess.run(["sudo", "chmod", "775", downloads_dir], check=True)
  return downloads_dir
- return pathlib.Path(_cache_dir) / "tinygrad" / "downloads"
+ return pathlib.Path(cache_dir) / "downloads"
 
  def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None, gunzip:bool=False,
  allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
  if url.startswith(("/", ".")): return pathlib.Path(url)
  if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
- else:
- fp = _ensure_downloads_dir() / (subdir or "") / \
- ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
+ else: fp = _ensure_downloads_dir() / (subdir or "") / ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
  if not fp.is_file() or not allow_caching:
+ (_dir := fp.parent).mkdir(parents=True, exist_ok=True)
  with urllib.request.urlopen(url, timeout=10) as r:
- assert r.status == 200
+ assert r.status == 200, r.status
  length = int(r.headers.get('content-length', 0)) if not gunzip else None
- progress_bar = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
- (path := fp.parent).mkdir(parents=True, exist_ok=True)
  readfile = gzip.GzipFile(fileobj=r) if gunzip else r
- with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
+ progress_bar:tqdm = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
+ with tempfile.NamedTemporaryFile(dir=_dir, delete=False) as f:
  while chunk := readfile.read(16384): progress_bar.update(f.write(chunk))
  f.close()
- progress_bar.update(close=True)
- if length and (file_size:=os.stat(f.name).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
  pathlib.Path(f.name).rename(fp)
+ progress_bar.update(close=True)
+ if length and (file_size:=os.stat(fp).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
  return fp
 
  # *** Exec helpers
@@ -264,16 +270,27 @@ def cpu_objdump(lib, objdump_tool='objdump'):
  pathlib.Path(f.name).write_bytes(lib)
  print(subprocess.check_output([objdump_tool, '-d', f.name]).decode('utf-8'))
 
+ def capstone_flatdump(lib: bytes):
+ import capstone
+ match platform.machine():
+ case 'x86_64' | 'AMD64': cs = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
+ case 'aarch64' | 'arm64': cs = capstone.Cs(capstone.CS_ARCH_ARM64, capstone.CS_MODE_ARM)
+ case machine: raise NotImplementedError(f"Capstone disassembly isn't supported for {machine}")
+ for instr in cs.disasm(lib, 0):
+ print(f"{instr.address:#08x}: {instr.mnemonic}\t{instr.op_str}")
+ sys.stdout.flush()
+
  # *** ctypes helpers
 
  # TODO: make this work with read only memoryviews (if possible)
  def from_mv(mv:memoryview, to_type=ctypes.c_char):
  return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
- def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
- def mv_address(mv:memoryview): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
- def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
+ def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
+ def mv_address(mv): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
+ def to_char_p_p(options: list[bytes], to_type=ctypes.c_char):
+ return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])
  @functools.lru_cache(maxsize=None)
- def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
+ def init_c_struct_t(fields: tuple[tuple[str, ctypes._SimpleCData], ...]):
  class CStruct(ctypes.Structure):
  _pack_, _fields_ = 1, fields
  return CStruct
@@ -282,13 +299,15 @@ def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(m
 
  # *** tqdm
 
- class tqdm:
- def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
+ class tqdm(Generic[T]):
+ def __init__(self, iterable:Iterable[T]|None=None, desc:str='', disable:bool=False,
+ unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
  self.iterable, self.disable, self.unit, self.unit_scale, self.rate = iterable, disable, unit, unit_scale, rate
  self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total
  self.set_description(desc)
  self.update(0)
- def __iter__(self):
+ def __iter__(self) -> Iterator[T]:
+ assert self.iterable is not None, "need an iterable to iterate"
  for item in self.iterable:
  yield item
  self.update(1)
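
The reworked Context drops the class-level stack: __enter__ snapshots every ContextVar value and __exit__ restores that snapshot, so nested or decorated uses unwind cleanly. A short usage sketch of the diffed helpers (the DEBUG values are illustrative):

    from tinygrad.helpers import Context, DEBUG

    print(DEBUG.value)      # whatever the DEBUG env var was at import time
    with Context(DEBUG=4):
      print(DEBUG.value)    # 4 inside the block
    print(DEBUG.value)      # snapshot restored on exit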