tinygrad 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
Files changed (62)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +35 -37
  4. tinygrad/codegen/linearize.py +19 -10
  5. tinygrad/codegen/lowerer.py +31 -8
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +10 -0
  8. tinygrad/device.py +28 -11
  9. tinygrad/dtype.py +12 -3
  10. tinygrad/engine/jit.py +3 -2
  11. tinygrad/engine/multi.py +0 -1
  12. tinygrad/engine/realize.py +7 -4
  13. tinygrad/engine/schedule.py +227 -255
  14. tinygrad/engine/search.py +20 -27
  15. tinygrad/gradient.py +3 -0
  16. tinygrad/helpers.py +7 -4
  17. tinygrad/nn/state.py +2 -2
  18. tinygrad/ops.py +64 -329
  19. tinygrad/renderer/__init__.py +19 -3
  20. tinygrad/renderer/cstyle.py +39 -18
  21. tinygrad/renderer/llvmir.py +55 -18
  22. tinygrad/renderer/ptx.py +6 -2
  23. tinygrad/renderer/wgsl.py +20 -12
  24. tinygrad/runtime/autogen/libc.py +404 -71
  25. tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
  26. tinygrad/runtime/autogen/webgpu.py +6985 -0
  27. tinygrad/runtime/graph/metal.py +28 -29
  28. tinygrad/runtime/ops_amd.py +37 -34
  29. tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
  30. tinygrad/runtime/ops_disk.py +1 -1
  31. tinygrad/runtime/ops_dsp.py +59 -33
  32. tinygrad/runtime/ops_llvm.py +14 -12
  33. tinygrad/runtime/ops_metal.py +78 -62
  34. tinygrad/runtime/ops_nv.py +9 -6
  35. tinygrad/runtime/ops_python.py +5 -5
  36. tinygrad/runtime/ops_webgpu.py +200 -38
  37. tinygrad/runtime/support/am/amdev.py +23 -11
  38. tinygrad/runtime/support/am/ip.py +10 -10
  39. tinygrad/runtime/support/elf.py +2 -0
  40. tinygrad/runtime/support/hcq.py +7 -5
  41. tinygrad/runtime/support/llvm.py +8 -14
  42. tinygrad/shape/shapetracker.py +3 -2
  43. tinygrad/shape/view.py +2 -3
  44. tinygrad/spec.py +21 -20
  45. tinygrad/tensor.py +150 -90
  46. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  47. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  48. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  49. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  50. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  51. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  52. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  53. tinygrad/viz/index.html +544 -0
  54. tinygrad/viz/perfetto.html +178 -0
  55. tinygrad/viz/serve.py +205 -0
  56. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
  57. tinygrad-0.10.2.dist-info/RECORD +99 -0
  58. tinygrad/codegen/rewriter.py +0 -516
  59. tinygrad-0.10.1.dist-info/RECORD +0 -86
  60. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  61. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
  62. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/engine/search.py CHANGED
@@ -1,11 +1,11 @@
 from typing import cast, Optional, Callable
-import itertools, functools, random, math, time, multiprocessing, traceback, signal
+import itertools, functools, random, math, time, multiprocessing, traceback, signal, atexit
 from collections import defaultdict
 from dataclasses import replace
 from tinygrad.ops import UOp, Ops, Variable, sym_infer
 from tinygrad.device import Device, Buffer, Compiler
-from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
-from tinygrad.helpers import IGNORE_BEAM_CACHE
+from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, time_to_str
+from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE
 from tinygrad.dtype import ImageDType, PtrDType
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
 from tinygrad.tensor import Tensor
@@ -103,9 +103,17 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
 # get dictionary of all possible actions
 def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
   acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
-  for i,a in enumerate(actions):
+  kernel_actions = actions.copy()
+
+  if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
+    for i, action in enumerate(kernel_actions):
+      if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
+        # replace every tc_action with default tc with one tc_action for each available tc
+        kernel_actions[i:i+1] = [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]
+
+  for i,a in enumerate(kernel_actions):
     if a.axis is not None and a.op is not OptOps.TC:
-      if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in actions): continue
+      if ((ax:=lin.real_axis(a)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
     lin2 = lin.copy()
     try:
       lin2.apply_opt(a)
@@ -133,10 +141,12 @@ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True,
   default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV", "METAL"} else 0
   if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
     beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
+    @atexit.register
+    def close_pool(): beam_pool.close()
 
   min_progress = getenv("BEAM_MIN_PROGRESS", 0.01)/1e6
   if BEAM_DEBUG: print(f"BEAM_SEARCH:\n{lin.ast}")
-  if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")
+  if DEBUG >= 2: print(f" 0.00s: from 1 -> 1 actions {lin.colored_shape()}")
 
   try:
     rawbufs = _ensure_buffer_alloc(rawbufs)
@@ -159,21 +169,21 @@ def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True,
         try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
         except RuntimeError: continue # for runtime issues
         timed_lins.append((acted_lins[i], min(tms)))
-        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
-        elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
+        if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(list, p.uops)):5d} uops {time_to_str(compile_et, w=12)} compile/{time_to_str(timed_lins[-1][1], w=12)} run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
+        elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {time_to_str(timed_lins[-1][1], w=12)} {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
 
       # done
       opts = sorted(timed_lins, key=lambda x: x[1])
       exiting = len(opts) == 0 or (opts[0][1] < min_progress) or (len(beam) > 0 and ((beam[0][1]-opts[0][1]) < min_progress))
       if not exiting: beam = opts[:amt]
       elif len(opts) > 0 and opts[0][1] < beam[0][1]: beam = opts[:1]
-      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
+      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(time_to_str(beam[0][1], w=12), "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape()) # noqa: E501
   except KeyboardInterrupt as e:
     if beam_pool is not None: beam_pool.terminate()
     raise e
 
   if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
-  if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={beam[0][1]*1e6:0.2f} us, applied_opts={beam[0][0].applied_opts}")
+  if BEAM_DEBUG: print(f"BEAM_SEARCH: final tm={time_to_str(beam[0][1], w=0)}, applied_opts={beam[0][0].applied_opts}")
   return beam[0][0]
 
 def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffer]) -> list[int]:
@@ -187,20 +197,3 @@ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffe
   ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
   assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
   return ret[1]
-
-def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
-  key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
-         "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
-  if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
-
-  dev = Device[lin.opts.device]
-  assert dev.compiler is not None
-
-  rawbufs = _ensure_buffer_alloc(rawbufs)
-  var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
-  p = lin.to_program()
-  tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
-                      max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
-
-  if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
-  return min(tms)
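Note on the get_kernel_actions change above: with TC_SEARCH_OVER_SHAPE enabled, the single "default tensor core" action (tc arg of -1) is expanded in place into one TC action per tensor core the backend exposes, so BEAM can search over all of them. A minimal, self-contained sketch of that slice-assignment pattern, using a simplified stand-in Opt and a made-up tensor-core list rather than tinygrad's real classes:

# standalone sketch of the TC_SEARCH_OVER_SHAPE expansion; "Opt" here is a simplified stand-in
from dataclasses import dataclass

@dataclass(frozen=True)
class Opt:
  op: str
  axis: int
  arg: tuple

actions = [Opt("TC", 0, (-1, 2)), Opt("UPCAST", 1, (4,))]  # hypothetical action list
tensor_cores = ["tc0", "tc1", "tc2"]                       # hypothetical available tensor cores

kernel_actions = actions.copy()
for i, action in enumerate(kernel_actions):
  if action.op == "TC" and action.arg[0] == -1:            # -1 means "use the default tensor core"
    # slice assignment replaces the single default-TC action with one action per tensor core
    kernel_actions[i:i+1] = [Opt("TC", action.axis, (tc_select, action.arg[1])) for tc_select, _ in enumerate(tensor_cores)]

print(kernel_actions)
# [Opt(op='TC', axis=0, arg=(0, 2)), Opt(op='TC', axis=0, arg=(1, 2)), Opt(op='TC', axis=0, arg=(2, 2)), Opt(op='UPCAST', axis=1, arg=(4,))]

Because the replacement is a slice assignment on the list being enumerated, the loop continues over the freshly inserted TC actions before reaching the remaining ones.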
tinygrad/gradient.py CHANGED
@@ -22,6 +22,9 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
   (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
   (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
+  (UPat(Ops.POW, name="ret"), lambda ctx, ret:
+    (ret.src[0].eq(0).where(ret.src[1].eq(0).where(ret.src[1], ret.src[1]*math.inf), ctx*ret*ret.src[1]/ret.src[0]),
+     ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ctx*ret*ret.src[0].log2()*math.log(2.0)))),
   (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
                                                 (ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
   (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
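For z = x**y, the new POW rule returns the textbook gradients when the base is non-zero: dz/dx = y*x**(y-1), expressed above as ctx*ret*ret.src[1]/ret.src[0], and dz/dy = z*ln(x), expressed as ctx*ret*ret.src[0].log2()*math.log(2.0); the .eq(0) selects only patch the x == 0 corner cases. A quick numeric check of the non-zero-base branch in plain Python (floats instead of UOps, values chosen arbitrarily):

# verify the non-zero-base POW gradients against central finite differences
import math

def pow_grads(x, y, upstream=1.0):
  z = x ** y
  dx = upstream * z * y / x                       # d(x**y)/dx = y * x**(y-1)
  dy = upstream * z * math.log2(x) * math.log(2)  # d(x**y)/dy = x**y * ln(x)
  return dx, dy

x, y, eps = 3.0, 2.5, 1e-6
dx, dy = pow_grads(x, y)
fd_dx = ((x+eps)**y - (x-eps)**y) / (2*eps)
fd_dy = (x**(y+eps) - x**(y-eps)) / (2*eps)
assert abs(dx - fd_dx) < 1e-3 and abs(dy - fd_dy) < 1e-3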
tinygrad/helpers.py CHANGED
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
+import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass
 import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
 from dataclasses import dataclass
 from typing import Union, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic
@@ -28,6 +28,7 @@ def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstan
 def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
 def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
 def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
+def time_to_str(t:float, w=8) -> str: return next((f"{t * d:{w}.2f}{pr}" for d,pr in [(1, "s "),(1e3, "ms")] if t > 10/d), f"{t * 1e6:{w}.2f}us")
 def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
 def ansilen(s:str): return len(ansistrip(s))
 def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
@@ -79,7 +80,7 @@ def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+str
 @functools.lru_cache(maxsize=None)
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
 def temp(x:str, append_user:bool=False) -> str:
-  return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{os.getenv('USERNAME', os.getlogin())}" if append_user else x)).as_posix()
+  return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()
 
 class Context(contextlib.ContextDecorator):
   def __init__(self, **kwargs): self.kwargs = kwargs
@@ -106,11 +107,12 @@ DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), Cont
 JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1)
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
 USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
-TRANSCENDENTAL = ContextVar("TRANSCENDENTAL", 1)
+TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE = ContextVar("TRANSCENDENTAL", 1), ContextVar("TC_SEARCH_OVER_SHAPE", 1)
 FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
 PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
-CACHELEVEL, IGNORE_BEAM_CACHE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0)
+CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
+DONT_REALIZE_EXPAND = ContextVar("DONT_REALIZE_EXPAND", 0)
 
 @dataclass(frozen=True)
 class Metadata:
@@ -276,6 +278,7 @@ def capstone_flatdump(lib: bytes):
     case 'x86_64' | 'AMD64': cs = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
     case 'aarch64' | 'arm64': cs = capstone.Cs(capstone.CS_ARCH_ARM64, capstone.CS_MODE_ARM)
     case machine: raise NotImplementedError(f"Capstone disassembly isn't supported for {machine}")
+  cs.skipdata = True
   for instr in cs.disasm(lib, 0):
     print(f"{instr.address:#08x}: {instr.mnemonic}\t{instr.op_str}")
   sys.stdout.flush()
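The new time_to_str helper picks the largest unit (seconds, then milliseconds) in which the value exceeds 10, falling back to microseconds, with the number right-aligned to width w; the BEAM logging in search.py calls it with w=12. Running the one-liner from the hunk above:

def time_to_str(t:float, w=8) -> str: return next((f"{t * d:{w}.2f}{pr}" for d,pr in [(1, "s "),(1e3, "ms")] if t > 10/d), f"{t * 1e6:{w}.2f}us")

print(repr(time_to_str(12.3)))      # '   12.30s '
print(repr(time_to_str(0.0123)))    # '   12.30ms'
print(repr(time_to_str(0.000123)))  # '  123.00us'

For context on the temp() hunk: os.getlogin() requires a controlling terminal and can raise OSError under cron, containers, or CI, whereas getpass.getuser() falls back to the LOGNAME/USER/USERNAME environment variables.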
tinygrad/nn/state.py CHANGED
@@ -195,8 +195,8 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
     if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
       intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
       assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
-      if DEBUG >= 3: print(f"WARNING: this torch load is slow. CLANG to permute {intermediate_shape} with {permute_indexes}")
-      assert storage[1] != dtypes.bfloat16, "can't CLANG permute BF16"
+      if DEBUG >= 3: print(f"WARNING: this torch load is slow. to permute {intermediate_shape} with {permute_indexes}")
+      assert storage[1] != dtypes.bfloat16, "can't permute BF16"
       # TODO: find a nice way to support all shapetracker on disktensors
       ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)
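The reshape-then-permute trick in this hunk reconstructs a permuted tensor from its contiguous on-disk storage: if the saved shape was produced by permuting axes with p, reading the raw data as the argsort(p)-ordered shape and then applying permute(p) reproduces it. A small illustration with made-up shapes (numpy and the example permutation are assumptions for the demo, not part of the diff):

import numpy as np

target_shape, p = (2, 3, 4), (2, 0, 1)                    # hypothetical permuted tensor
inv = np.argsort(p)                                       # inverse permutation
intermediate_shape = tuple(target_shape[k] for k in inv)  # shape the contiguous bytes are read as
data = np.arange(np.prod(target_shape))                   # stand-in for the raw storage
out = data.reshape(intermediate_shape).transpose(p)       # reshape + permute, as in torch_load
assert out.shape == target_shape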