tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/search.py CHANGED
@@ -1,26 +1,27 @@
- from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
- import itertools, functools, random, math, time, multiprocessing, traceback, signal
+ from typing import cast, Optional, Callable
+ import itertools, functools, random, math, time, multiprocessing, traceback, signal, atexit
  from collections import defaultdict
  from dataclasses import replace
  from tinygrad.ops import UOp, Ops, Variable, sym_infer
  from tinygrad.device import Device, Buffer, Compiler
- from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
+ from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
+ from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE
  from tinygrad.dtype import ImageDType, PtrDType
  from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
  from tinygrad.tensor import Tensor
  from tinygrad.engine.realize import CompiledRunner
- from tinygrad.renderer import Program
-
- actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
- actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)]
- actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
- actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
- actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
- if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
- actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=6, amt=2)]
- actions += [Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
- actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
- actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)]
+ from tinygrad.renderer import ProgramSpec
+
+ actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
+ actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)]
+ actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)]
+ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)]
+ actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)]
+ if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)]
+ actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)]
+ actions += [Opt(op=OptOps.TC, axis=0, arg=(-1, 0))]
+ actions += [Opt(op=OptOps.TC, axis=axis, arg=(-1, getenv("TC_OPT", 2))) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce)
+ actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)]
  if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

  def _get_test_global_size(global_size, max_global_size, var_vals):
@@ -33,8 +34,8 @@ def _get_test_global_size(global_size, max_global_size, var_vals):
  break
  return test_global_size, factor

- def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:List[Buffer], early_stop:Optional[float]=None,
- max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> List[float]:
+ def _time_program(p:ProgramSpec, lib:bytes, var_vals:dict[Variable, int], rawbufs:list[Buffer], early_stop:Optional[float]=None,
+ max_global_size:Optional[int]=65536, clear_l2=False, cnt=3, name="test") -> list[float]:
  factor = 1
  if p.global_size is not None and max_global_size is not None:
  global_size, factor = _get_test_global_size(p.global_size, max_global_size, var_vals)
@@ -45,9 +46,9 @@ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:Li
  input_bufs = [rawbufs[i] for i in car.p.globals]
  for _ in range(cnt):
  if clear_l2:
- if hasattr(dev:=Device[p.dname], 'invalidate_caches'): dev.invalidate_caches()
+ if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
  else:
- with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
+ with Context(DEBUG=0, BEAM=0, CAPTURING=0, TRACK_MATCH_STATS=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
  tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
  if early_stop is not None and early_stop < min(tms): break
  return tms
@@ -55,7 +56,7 @@ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:Li
  class TimeoutException(Exception): pass
  def timeout_handler(signum, frame): raise TimeoutException()

- def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tuple[int, Optional[Tuple[Program, bytes, float]]]:
+ def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tuple[int, Optional[tuple[ProgramSpec, bytes, float]]]:
  if hasattr(signal, "alarm"):
  signal.signal(getattr(signal, 'SIGALRM'), timeout_handler)
  # set timeout
@@ -80,16 +81,16 @@ def _try_compile_linearized_w_idx(x:Tuple[int,Kernel], compiler:Compiler) -> Tup
  # workers should ignore ctrl c
  def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)

- def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_allocated() for buf in bufs]
+ def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() for buf in bufs]

  # *** external API ***

  # get (scrap) buffers for timing the linearizer
- def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
- bufsts: DefaultDict[int, List[UOp]] = defaultdict(list)
+ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
+ bufsts: defaultdict[int, list[UOp]] = defaultdict(list)
  for x in lin.bufs:
  if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
- rawbufs: List[Optional[Buffer]] = [None]*len(bufsts)
+ rawbufs: list[Optional[Buffer]] = [None]*len(bufsts)
  for k,lx in bufsts.items():
  buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
  assert isinstance(dtype, (PtrDType, ImageDType))
@@ -97,18 +98,26 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
  buf_dtype = dtype if isinstance(dtype, ImageDType) else dtype.base
  rawbufs[k] = Buffer(lin.opts.device, buf_size, buf_dtype).allocate() if allocate else Buffer(lin.opts.device, buf_size, buf_dtype)
  assert all(r is not None for r in rawbufs)
- return cast(List[Buffer], rawbufs)
+ return cast(list[Buffer], rawbufs)

  # get dictionary of all possible actions
- def get_kernel_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
+ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
  acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
- for i,a in enumerate(actions):
+ kernel_actions = actions.copy()
+
+ if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
+ for i, action in enumerate(kernel_actions):
+ if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
+ # replace every tc_action with default tc with one tc_action for each available tc
+ kernel_actions[i:i+1] = [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]
+
+ for i,a in enumerate(kernel_actions):
  if a.axis is not None and a.op is not OptOps.TC:
- if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.amt and Opt(a.op, ax, 0) in actions): continue
+ if ((ax:=lin.real_axis(a)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
  lin2 = lin.copy()
  try:
  lin2.apply_opt(a)
- up, lcl, tc_up = 1, 1, prod(tc.dims)//prod([x[1] for x in tc.threads]) if (tc:=lin2.tensor_core) else 1
+ up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1
  for s,c in zip(lin2.full_shape, lin2.colors()):
  if c in {"magenta", "yellow"}: up *= s
  elif c in {"cyan", "green", "white"}: lcl *= s
@@ -117,8 +126,8 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
  except KernelOptError: pass
  return acted_lins

- beam_pool, BEAM_DEBUG, CAPTURE_BEAM = None, getenv("BEAM_DEBUG"), getenv("CAPTURE_BEAM", "")
- def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True, disable_cache=getenv("IGNORE_BEAM_CACHE")) -> Kernel:
+ beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
+ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value) -> Kernel:
  global beam_pool
  key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
@@ -126,25 +135,27 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
  for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
  return ret

- beam: List[Tuple[Kernel, float]] = [(lin, float("inf"))]
+ beam: list[tuple[Kernel, float]] = [(lin, float("inf"))]
  seen_libs = set()

  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL"} else 0
  if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
  beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
+ @atexit.register
+ def close_pool(): beam_pool.close()

  min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
  if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")
- if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")
+ if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")

  try:
  rawbufs = _ensure_buffer_alloc(rawbufs)
- var_vals: Dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
+ var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
  exiting, st = False, time.perf_counter()
  dev = Device[lin.opts.device]
  while not exiting:
- acted_lins: List[Kernel] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
- timed_lins: List[Tuple[Kernel, float]] = []
+ acted_lins: list[Kernel] = flatten([get_kernel_actions(lin, include_0=False).values() for lin,_ in beam])
+ timed_lins: list[tuple[Kernel, float]] = []
  _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
  least_compute_ops = math.inf
  for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
@@ -152,59 +163,37 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
  p, lib, compile_et = proc
  if lib in seen_libs: continue
  # filter out kernels that use 1000x more compute than the smallest
- least_compute_ops = min(this_compute_ops:=sym_infer(p.op_estimate, var_vals), least_compute_ops)
+ least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
  if least_compute_ops*1000 < this_compute_ops: continue
- if len(CAPTURE_BEAM) > 0:
- with open(CAPTURE_BEAM, 'a') as f: f.write(str(acted_lins[i].ast).replace('\n','')+f" :: {acted_lins[i].applied_opts}\n")
  seen_libs.add(lib)
  try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
- except RuntimeError as e:
- if len(CAPTURE_BEAM) > 0:
- with open(CAPTURE_BEAM, 'a') as f: f.write("# Upper ast finished with an error:" + str(e).replace('\n',' ')+ "\n")
- continue # for runtime issues
+ except RuntimeError: continue # for runtime issues
  timed_lins.append((acted_lins[i], min(tms)))
- if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(List, p.uops)):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
- elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
+ if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
+ elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501

  # done
  opts = sorted(timed_lins, key=lambda x: x[1])
  exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
  if not exiting: beam = opts[:amt]
  elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
- if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
+ if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
  except KeyboardInterrupt as e:
  if beam_pool is not None: beam_pool.terminate()
  raise e

  if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
- if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={beam[0][1]*1e6:0.2f} us, applied_opts={beam[0][0].applied_opts}")
+ if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={time_to_str(beam[0][1], w=0)}, applied_opts={beam[0][0].applied_opts}")
  return beam[0][0]

- def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
+ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
  test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype).allocate(), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
  MAX_WORKGROUP = 1024
  local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
  local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
  def try_exec(local_size):
- try: return clprg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501
+ try: return _prg(*[x._buf for x in test_rawbuffers], global_size=[g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size=local_size, wait=True) # noqa: E501
  except Exception: return float('inf')
  ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
  assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
  return ret[1]
-
- def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
- key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
- "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
- if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
-
- dev = Device[lin.opts.device]
- assert dev.compiler is not None
-
- rawbufs = _ensure_buffer_alloc(rawbufs)
- var_vals: Dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
- p = lin.to_program()
- tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
- max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
-
- if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
- return min(tms)
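
Note on the search API change above: the Opt field formerly called amt is now arg, and tensor-core actions carry a (tc_select, tc_opt) tuple rather than a single integer. A minimal sketch of building actions against the 0.10.2 form, using only the Opt/OptOps constructors that appear in this diff:

from tinygrad.codegen.kernel import Opt, OptOps

# 0.10.0 wrote Opt(op=OptOps.UPCAST, axis=0, amt=4); 0.10.2 uses the arg keyword instead
upcast = Opt(op=OptOps.UPCAST, axis=0, arg=4)
# TC actions now take (tc_select, tc_opt); per get_kernel_actions above, a tc_select of -1
# is expanded into one action per available tensor core when TC_SEARCH_OVER_SHAPE is set
tc_default = Opt(op=OptOps.TC, axis=0, arg=(-1, 0))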
tinygrad/gradient.py ADDED
@@ -0,0 +1,73 @@
+ from typing import cast, Iterator
+ import math, functools, dataclasses
+ from tinygrad.dtype import dtypes, sum_acc_dtype
+ from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
+ from tinygrad.helpers import argsort
+
+ def reduce_gradient(ctx:UOp, ret:UOp):
+ if ret.arg[0] == Ops.ADD: return (ctx.expand(ret.src[0].shape),)
+ if ret.arg[0] == Ops.MAX:
+ max_is_1s = ret.src[0].ne(ret.expand(ret.src[0].shape)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype)
+ div = max_is_1s.r(Ops.ADD, ret.arg[1]).expand(ret.src[0].shape)
+ return ((max_is_1s/div) * ctx.expand(ret.src[0].shape),)
+ if ret.arg[0] == Ops.MUL: return ((ctx * ret).expand(ret.src[0].shape) / ret.src[0],)
+
+ # ctx is grad_output
+ pm_gradient = PatternMatcher([
+ (UPat(Ops.CAST, name="ret"), lambda ctx, ret: (ctx.cast(ret.src[0].dtype),)),
+ (UPat(Ops.RECIP, name="ret"), lambda ctx, ret: (-ctx * ret * ret,)),
+ (UPat(Ops.SIN, name="ret"), lambda ctx, ret: ((math.pi/2 - ret.src[0]).sin() * ctx,)),
+ (UPat(Ops.LOG2, name="ret"), lambda ctx, ret: (ctx / (ret.src[0] * math.log(2)),)),
+ (UPat(Ops.EXP2, name="ret"), lambda ctx, ret: (ret * ctx * math.log(2),)),
+ (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
+ (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
+ (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
+ (UPat(Ops.POW, name="ret"), lambda ctx, ret:
+ (ret.src[0].eq(0).where(ret.src[1].eq(0).where(ret.src[1], ret.src[1]*math.inf), ctx*ret*ret.src[1]/ret.src[0]),
+ ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ctx*ret*ret.src[0].log2()*math.log(2.0)))),
+ (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
+ (ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
+ (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
+ (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
+ (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
+ (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
+ (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
+ (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
+ (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)),
+ (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
+ (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
+ (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.arg),)),
+ # TODO: this cast can be removed by putting the casts around the EXPAND
+ (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
+ (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
+ (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
+ # there's no gradient for bitcast
+ (UPat(Ops.BITCAST), lambda ctx: (None,)),
+ ])
+
+ # copied from tensor.py, get relevant toposort of gradients
+ def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
+ @functools.lru_cache(None)
+ def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
+ def _walk(node:UOp, visited:set[UOp]) -> Iterator[UOp]:
+ visited.add(node)
+ if node.op is Ops.DETACH: return
+ if is_in_target_path(node):
+ for i in node.src:
+ if i not in visited: yield from _walk(i, visited)
+ yield node
+ return list(_walk(root, set()))
+
+ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
+ grads = {root: root_grad}
+ for t0 in reversed(_deepwalk(root, targets)):
+ if t0 not in grads: continue
+ lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
+ if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
+ assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
+ for k,v in zip(t0.src, lgrads):
+ if v is None: continue
+ if k in grads: grads[k] = grads[k] + v
+ else: grads[k] = v
+ if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True)
+ return grads
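
The new gradient.py drives reverse-mode differentiation with a PatternMatcher of per-op backward rules and a reverse toposort that sums contributions into each source node. A self-contained toy analogue of that accumulation loop follows; Node and the rules table are illustrative only and are not tinygrad APIs:

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str
  src: tuple["Node", ...] = ()
  val: float = 0.0

def toposort(root: Node) -> list[Node]:
  seen: set[int] = set()
  order: list[Node] = []
  def walk(n: Node) -> None:
    if id(n) in seen: return
    seen.add(id(n))
    for s in n.src: walk(s)
    order.append(n)
  walk(root)
  return order

# per-op backward rules: (grad_output, node) -> gradient for each source
rules = {
  "add": lambda g, n: (g, g),
  "mul": lambda g, n: (g * n.src[1].val, g * n.src[0].val),
}

def compute_gradient(root: Node, root_grad: float) -> dict[int, float]:
  grads = {id(root): root_grad}
  for n in reversed(toposort(root)):              # walk the graph outputs-first
    if id(n) not in grads or n.op not in rules: continue
    for src, g in zip(n.src, rules[n.op](grads[id(n)], n)):
      grads[id(src)] = grads.get(id(src), 0.0) + g  # sum gradients from every use
  return grads

# d(x*y + x)/dx = y + 1 = 5.0 and d(x*y + x)/dy = x = 3.0
x, y = Node("leaf", val=3.0), Node("leaf", val=4.0)
out = Node("add", (Node("mul", (x, y), val=12.0), x), val=15.0)
g = compute_gradient(out, 1.0)
assert g[id(x)] == 5.0 and g[id(y)] == 3.0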
tinygrad/helpers.py CHANGED
@@ -1,8 +1,8 @@
  from __future__ import annotations
- import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
+ import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass
  import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
  from dataclasses import dataclass
- from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard
+ from typing import Union, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic

  T = TypeVar("T")
  U = TypeVar("U")
@@ -23,76 +23,79 @@ def argfix(*x):
  return tuple(x[0])
  return x
  def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
- def all_same(items:Union[Tuple[T, ...], List[T]]): return all(x == items[0] for x in items)
- def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
+ def all_same(items:Union[tuple[T, ...], list[T]]): return all(x == items[0] for x in items)
+ def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
  def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
  def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
  def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
+ def time_to_str(t:float, w=8) -> str: return next((f"{t * d:{w}.2f}{pr}" for d,pr in [(1, "s "),(1e3, "ms")] if t > 10/d), f"{t * 1e6:{w}.2f}us")
  def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
  def ansilen(s:str): return len(ansistrip(s))
- def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
+ def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
  def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
  def fully_flatten(l):
  if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
+ if hasattr(l, "shape") and l.shape == (): return [l[()]]
  flattened = []
- if hasattr(l, "shape") and l.shape == (): flattened.append(l[()])
- else:
- for i in range(len(l)): flattened.extend(fully_flatten(l[i]))
+ for li in l: flattened.extend(fully_flatten(li))
  return flattened
  return [l]
  def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
  def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
- def ceildiv(num, amt):
- ret = -(num//-amt)
- return ret if not isinstance(ret, float) else int(ret)
+ def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
  def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
- def data64(data: int) -> Tuple[int, int]: return (data >> 32, data & 0xFFFFFFFF)
- def data64_le(data: int) -> Tuple[int, int]: return (data & 0xFFFFFFFF, data >> 32)
- def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
+ def lo32(x:Any) -> Any: return x & 0xFFFFFFFF # Any is sint
+ def hi32(x:Any) -> Any: return x >> 32 # Any is sint
+ def data64(data:Any) -> tuple[Any, Any]: return (data >> 32, data & 0xFFFFFFFF) # Any is sint
+ def data64_le(data:Any) -> tuple[Any, Any]: return (data & 0xFFFFFFFF, data >> 32) # Any is sint
+ def getbits(value: int, start: int, end: int): return (value >> start) & ((1 << end-start+1) - 1)
+ def i2u(bits: int, value: int): return value if value >= 0 else (1<<bits)+value
+ def merge_dicts(ds:Iterable[dict[T,U]]) -> dict[T,U]:
  kvs = set([(k,v) for d in ds for k,v in d.items()])
  assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
  return {k:v for d in ds for k,v in d.items()}
- def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
- a:List[T] = []
- b:List[T] = []
- for s in itr: (a if fxn(s) else b).append(s)
- return a,b
+ def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> tuple[list[T], list[T]]:
+ ret:tuple[list[T], list[T]] = ([], [])
+ for s in itr: (ret[0] if fxn(s) else ret[1]).append(s)
+ return ret
  def unwrap(x:Optional[T]) -> T:
  assert x is not None
  return x
+ def get_single_element(x:list[T]) -> T:
+ assert len(x) == 1, f"list {x} must only have 1 element"
+ return x[0]
  def get_child(obj, key):
  for k in key.split('.'):
  if k.isnumeric(): obj = obj[int(k)]
  elif isinstance(obj, dict): obj = obj[k]
  else: obj = getattr(obj, k)
  return obj
- def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
+ def word_wrap(x, wrap=80): return x if len(x) <= wrap or '\n' in x[0:wrap] else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))

  # for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
- def polyN(x:T, p:List[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore
+ def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore

  @functools.lru_cache(maxsize=None)
  def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
  @functools.lru_cache(maxsize=None)
  def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
- def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
+ def temp(x:str, append_user:bool=False) -> str:
+ return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()

  class Context(contextlib.ContextDecorator):
- stack: ClassVar[List[dict[str, int]]] = [{}]
  def __init__(self, **kwargs): self.kwargs = kwargs
  def __enter__(self):
- Context.stack[-1] = {k:o.value for k,o in ContextVar._cache.items()} # Store current state.
- for k,v in self.kwargs.items(): ContextVar._cache[k].value = v # Update to new temporary state.
- Context.stack.append(self.kwargs) # Store the temporary state so we know what to undo later.
+ self.old_context:dict[str, int] = {k:v.value for k,v in ContextVar._cache.items()}
+ for k,v in self.kwargs.items(): ContextVar._cache[k].value = v
  def __exit__(self, *args):
- for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, ContextVar._cache[k].value)
+ for k,v in self.old_context.items(): ContextVar._cache[k].value = v

  class ContextVar:
- _cache: ClassVar[Dict[str, ContextVar]] = {}
+ _cache: ClassVar[dict[str, ContextVar]] = {}
  value: int
  key: str
  def __init__(self, key, default_value):
- assert key not in ContextVar._cache, f"attempt to recreate ContextVar {key}"
+ if key in ContextVar._cache: raise RuntimeError(f"attempt to recreate ContextVar {key}")
  ContextVar._cache[key] = self
  self.value, self.key = getenv(key, default_value), key
  def __bool__(self): return bool(self.value)
@@ -100,12 +103,16 @@ class ContextVar:
  def __gt__(self, x): return self.value > x
  def __lt__(self, x): return self.value < x

- DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
+ DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
+ JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1)
  WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
- PROFILE, PROFILEPATH = ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
- USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
- FUSE_ARANGE, FUSE_CONV_BW, LAZYCACHE = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0), ContextVar("LAZYCACHE", 1)
+ USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
+ TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE = ContextVar("TRANSCENDENTAL", 1), ContextVar("TC_SEARCH_OVER_SHAPE", 1)
+ FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
  SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
+ PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
+ CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
+ DONT_REALIZE_EXPAND = ContextVar("DONT_REALIZE_EXPAND", 0)

  @dataclass(frozen=True)
  class Metadata:
@@ -160,11 +167,10 @@ class Profiling(contextlib.ContextDecorator):

  # *** universal database cache ***

- _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))
- CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
- CACHELEVEL = getenv("CACHELEVEL", 2)
+ cache_dir: str = os.path.join(getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")), "tinygrad")
+ CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(cache_dir, "cache.db")))

- VERSION = 16
+ VERSION = 19
  _db_connection = None
  def db_connection():
  global _db_connection
@@ -182,8 +188,8 @@ def diskcache_clear():
  drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
  cur.executescript("\n".join([s[0] for s in drop_tables] + ["VACUUM;"]))

- def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
- if CACHELEVEL == 0: return None
+ def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:
+ if CACHELEVEL < 1: return None
  if isinstance(key, (str,int)): key = {"key": key}
  conn = db_connection()
  cur = conn.cursor()
@@ -195,8 +201,8 @@ def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
  return None

  _db_tables = set()
- def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
- if CACHELEVEL == 0: return val
+ def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=False):
+ if CACHELEVEL < 1: return val
  if isinstance(key, (str,int)): key = {"key": key}
  conn = db_connection()
  cur = conn.cursor()
@@ -205,7 +211,7 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
  ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
  cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
  _db_tables.add(table)
- cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
+ cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
  conn.commit()
  cur.close()
  return val
@@ -217,6 +223,10 @@ def diskcache(func):
  return diskcache_put(table, key, func(*args, **kwargs))
  return wrapper

+ # *** process replay ***
+
+ CAPTURE_PROCESS_REPLAY = getenv("RUN_PROCESS_REPLAY") or getenv("CAPTURE_PROCESS_REPLAY")
+
  # *** http support ***

  def _ensure_downloads_dir() -> pathlib.Path:
@@ -228,28 +238,26 @@ def _ensure_downloads_dir() -> pathlib.Path:
  subprocess.run(["sudo", "chown", "tiny:root", downloads_dir], check=True)
  subprocess.run(["sudo", "chmod", "775", downloads_dir], check=True)
  return downloads_dir
- return pathlib.Path(_cache_dir) / "tinygrad" / "downloads"
+ return pathlib.Path(cache_dir) / "downloads"

  def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None, gunzip:bool=False,
  allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
  if url.startswith(("/", ".")): return pathlib.Path(url)
  if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
- else:
- fp = _ensure_downloads_dir() / (subdir or "") / \
- ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
+ else: fp = _ensure_downloads_dir() / (subdir or "") / ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
  if not fp.is_file() or not allow_caching:
+ (_dir := fp.parent).mkdir(parents=True, exist_ok=True)
  with urllib.request.urlopen(url, timeout=10) as r:
- assert r.status == 200
+ assert r.status == 200, r.status
  length = int(r.headers.get('content-length', 0)) if not gunzip else None
- progress_bar = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
- (path := fp.parent).mkdir(parents=True, exist_ok=True)
  readfile = gzip.GzipFile(fileobj=r) if gunzip else r
- with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
+ progress_bar:tqdm = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
+ with tempfile.NamedTemporaryFile(dir=_dir, delete=False) as f:
  while chunk := readfile.read(16384): progress_bar.update(f.write(chunk))
  f.close()
- progress_bar.update(close=True)
- if length and (file_size:=os.stat(f.name).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
  pathlib.Path(f.name).rename(fp)
+ progress_bar.update(close=True)
+ if length and (file_size:=os.stat(fp).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
  return fp

  # *** Exec helpers
@@ -264,16 +272,28 @@ def cpu_objdump(lib, objdump_tool='objdump'):
  pathlib.Path(f.name).write_bytes(lib)
  print(subprocess.check_output([objdump_tool, '-d', f.name]).decode('utf-8'))

+ def capstone_flatdump(lib: bytes):
+ import capstone
+ match platform.machine():
+ case 'x86_64' | 'AMD64': cs = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
+ case 'aarch64' | 'arm64': cs = capstone.Cs(capstone.CS_ARCH_ARM64, capstone.CS_MODE_ARM)
+ case machine: raise NotImplementedError(f"Capstone disassembly isn't supported for {machine}")
+ cs.skipdata = True
+ for instr in cs.disasm(lib, 0):
+ print(f"{instr.address:#08x}: {instr.mnemonic}\t{instr.op_str}")
+ sys.stdout.flush()
+
  # *** ctypes helpers

  # TODO: make this work with read only memoryviews (if possible)
  def from_mv(mv:memoryview, to_type=ctypes.c_char):
  return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
- def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
- def mv_address(mv:memoryview): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
- def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
+ def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
+ def mv_address(mv): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
+ def to_char_p_p(options: list[bytes], to_type=ctypes.c_char):
+ return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])
  @functools.lru_cache(maxsize=None)
- def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
+ def init_c_struct_t(fields: tuple[tuple[str, ctypes._SimpleCData], ...]):
  class CStruct(ctypes.Structure):
  _pack_, _fields_ = 1, fields
  return CStruct
@@ -282,13 +302,15 @@ def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(m

  # *** tqdm

- class tqdm:
- def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
+ class tqdm(Generic[T]):
+ def __init__(self, iterable:Iterable[T]|None=None, desc:str='', disable:bool=False,
+ unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
  self.iterable, self.disable, self.unit, self.unit_scale, self.rate = iterable, disable, unit, unit_scale, rate
  self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total
  self.set_description(desc)
  self.update(0)
- def __iter__(self):
+ def __iter__(self) -> Iterator[T]:
+ assert self.iterable is not None, "need an iterable to iterate"
  for item in self.iterable:
  yield item
  self.update(1)
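
For reference, a few of the helpers added in the hunks above behave as follows; the return values are worked out by hand from the one-line definitions in this diff, so treat them as illustrative:

from tinygrad.helpers import time_to_str, getbits, i2u

time_to_str(12.3)          # '   12.30s '  (>= 10 s prints seconds)
time_to_str(0.0123)        # '   12.30ms'  (>= 10 ms prints milliseconds)
time_to_str(1.23e-5)       # '   12.30us'  (everything smaller prints microseconds)
getbits(0b10110110, 2, 5)  # 13 == 0b1101, bits 2..5 inclusive
i2u(8, -1)                 # 255, two's-complement reinterpretation in 8 bits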