tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/frontend/torch.py ADDED
@@ -0,0 +1,5 @@
+ # type: ignore
+ import sys, pathlib
+ sys.path.append(pathlib.Path(__file__).parent.parent.as_posix())
+ try: import extra.torch_backend.backend # noqa: F401 # pylint: disable=unused-import
+ except ImportError as e: raise ImportError("torch frontend not in release\nTo fix, install tinygrad from a git checkout with pip install -e .") from e
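For context, a hedged usage sketch of this shim: from a git checkout it registers tinygrad as a torch backend (the wheel raises the ImportError above). The device string "tiny" is how extra.torch_backend exposes itself in the repo's examples; treat it as an assumption of this sketch, not something shown in the diff.

```python
# Hypothetical usage, assuming a git checkout and that extra.torch_backend registers "tiny".
import torch
import tinygrad.frontend.torch  # noqa: F401  # side effect: registers the backend
x = torch.ones(4, device="tiny")
print((x + 1).cpu().numpy())
```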
tinygrad/gradient.py CHANGED
@@ -1,16 +1,16 @@
- from typing import cast, Iterator
- import math, functools, dataclasses
- from tinygrad.dtype import dtypes, sum_acc_dtype
- from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
+ from typing import cast
+ import math, dataclasses
+ from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
  from tinygrad.helpers import argsort
 
  def reduce_gradient(ctx:UOp, ret:UOp):
- if ret.arg[0] == Ops.ADD: return (ctx.expand(ret.src[0].shape),)
+ def to_inp_shape(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
+ if ret.arg[0] == Ops.ADD: return (to_inp_shape(ctx),)
  if ret.arg[0] == Ops.MAX:
- max_is_1s = ret.src[0].ne(ret.expand(ret.src[0].shape)).ne(ret.src[0].const_like(1).cast(dtypes.bool)).cast(ctx.dtype)
- div = max_is_1s.r(Ops.ADD, ret.arg[1]).expand(ret.src[0].shape)
- return ((max_is_1s/div) * ctx.expand(ret.src[0].shape),)
- if ret.arg[0] == Ops.MUL: return ((ctx * ret).expand(ret.src[0].shape) / ret.src[0],)
+ max_is_1s = ret.src[0].eq(to_inp_shape(ret)).cast(ctx.dtype)
+ div = to_inp_shape(max_is_1s.r(Ops.ADD, ret.arg[1]))
+ return ((max_is_1s/div) * to_inp_shape(ctx),)
+ if ret.arg[0] == Ops.MUL: return (to_inp_shape(ctx * ret) / ret.src[0],)
 
  # ctx is grad_output
  pm_gradient = PatternMatcher([
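The new `to_inp_shape` helper replaces the bare `expand` calls: it pads the gradient's shape with trailing 1s and then expands, making the broadcast back to the reduce input explicit. A minimal sketch of that reshape-then-expand pattern on plain Tensors (shapes here are illustrative, not taken from the diff):

```python
from tinygrad import Tensor

def to_inp_shape(x: Tensor, inp_shape: tuple[int, ...]) -> Tensor:
  # pad the missing trailing dims with 1s, then broadcast up to the input shape
  return x.reshape(x.shape + (1,) * (len(inp_shape) - len(x.shape))).expand(inp_shape)

grad_out = Tensor.ones(3)                    # upstream gradient of a reduce over the last axis
print(to_inp_shape(grad_out, (3, 4)).shape)  # (3, 4): every input element receives its row's gradient
```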
@@ -23,40 +23,32 @@ pm_gradient = PatternMatcher([
  (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
  (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
  (UPat(Ops.POW, name="ret"), lambda ctx, ret:
- (ret.src[0].eq(0).where(ret.src[1].eq(0).where(ret.src[1], ret.src[1]*math.inf), ctx*ret*ret.src[1]/ret.src[0]),
- ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ctx*ret*ret.src[0].log2()*math.log(2.0)))),
+ (ctx*(ret.src[0].eq(0) & ret.src[1].eq(0)).where(ret.src[1], ret.src[1]*ret.src[0].pow(ret.src[1]-1)),
+ ctx*ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ret*ret.src[0].log2()*math.log(2.0)))),
  (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
  (ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
  (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
  (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
  (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
- (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
+ (UPat((Ops.CONTIGUOUS, Ops.FUSE)), lambda ctx: (ctx,)),
  (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
  (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)),
  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
  (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
  (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.arg),)),
- # TODO: this cast can be removed by putting the casts around the EXPAND
- (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
- (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
+ (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)),)),
  (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
  # there's no gradient for bitcast
  (UPat(Ops.BITCAST), lambda ctx: (None,)),
  ])
 
- # copied from tensor.py, get relevant toposort of gradients
  def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
- @functools.lru_cache(None)
- def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
- def _walk(node:UOp, visited:set[UOp]) -> Iterator[UOp]:
- visited.add(node)
- if node.op is Ops.DETACH: return
- if is_in_target_path(node):
- for i in node.src:
- if i not in visited: yield from _walk(i, visited)
- yield node
- return list(_walk(root, set()))
+ # compute the target path (top down)
+ in_target_path: dict[UOp, bool] = {}
+ for u in root.toposort(): in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)
+ # don't flow through DETACH/ASSIGN or anything not in target path
+ return list(root.toposort(lambda node: node.op not in {Ops.DETACH, Ops.ASSIGN} and in_target_path[node]))
 
  def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
  grads = {root: root_grad}
@@ -69,5 +61,5 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
  if v is None: continue
  if k in grads: grads[k] = grads[k] + v
  else: grads[k] = v
- if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True)
+ if len(forward_metadata:=all_metadata.get(t0, ())): all_metadata[v] = tuple(dataclasses.replace(x, backward=True) for x in forward_metadata)
  return grads
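The rewritten `_deepwalk` leans on a gated toposort instead of a hand-rolled recursive generator: one pass marks which nodes lead to a gradient target, a second pass only descends into nodes that pass that gate. A self-contained sketch of the idea with a hypothetical `Node` class (tinygrad's real gate also skips `DETACH`/`ASSIGN`):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  name: str
  src: tuple["Node", ...] = ()

def toposort(root: Node, gate=lambda n: True) -> list[Node]:
  out, seen = [], set()
  def walk(n: Node):
    if n in seen or not gate(n): return
    seen.add(n)
    for s in n.src: walk(s)   # children first, so out is in topological order
    out.append(n)
  walk(root)
  return out

x, w = Node("x"), Node("w")
loss = Node("loss", (Node("mul", (x, w)),))
in_target_path: dict[Node, bool] = {}
for u in toposort(loss): in_target_path[u] = any(s in {w} or in_target_path[s] for s in u.src)
print([n.name for n in toposort(loss, lambda n: in_target_path[n])])  # ['mul', 'loss']
```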
tinygrad/helpers.py CHANGED
@@ -1,13 +1,13 @@
  from __future__ import annotations
  import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass
- import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
- from dataclasses import dataclass
- from typing import Union, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic
+ import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib, decimal
+ from dataclasses import dataclass, field
+ from typing import ClassVar, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic, Generator
 
  T = TypeVar("T")
  U = TypeVar("U")
  # NOTE: it returns int 1 if x is empty regardless of the type of x
- def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x, 1)
+ def prod(x:Iterable[T]) -> T|int: return functools.reduce(operator.mul, x, 1)
 
  # NOTE: helpers is not allowed to import from anything else in tinygrad
  OSX = platform.system() == "Darwin"
@@ -23,15 +23,14 @@ def argfix(*x):
  return tuple(x[0])
  return x
  def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
- def all_same(items:Union[tuple[T, ...], list[T]]): return all(x == items[0] for x in items)
+ def all_same(items:tuple[T, ...]|list[T]): return all(x == items[0] for x in items)
  def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
- def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
+ def colored(st, color:str|None, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
  def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
- def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
  def time_to_str(t:float, w=8) -> str: return next((f"{t * d:{w}.2f}{pr}" for d,pr in [(1, "s "),(1e3, "ms")] if t > 10/d), f"{t * 1e6:{w}.2f}us")
  def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
  def ansilen(s:str): return len(ansistrip(s))
- def make_tuple(x:Union[int, Sequence[int]], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
+ def make_tuple(x:int|Sequence[int], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
  def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
  def fully_flatten(l):
  if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
@@ -44,12 +43,17 @@ def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
  def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
  def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
  def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
+ def round_down(num:int, amt:int) -> int: return -round_up(-num, amt)
+ # cstyle div and mod
+ def cdiv(x:int, y:int) -> int: return abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0
+ def cmod(x:int, y:int) -> int: return x-cdiv(x,y)*y
  def lo32(x:Any) -> Any: return x & 0xFFFFFFFF # Any is sint
  def hi32(x:Any) -> Any: return x >> 32 # Any is sint
  def data64(data:Any) -> tuple[Any, Any]: return (data >> 32, data & 0xFFFFFFFF) # Any is sint
  def data64_le(data:Any) -> tuple[Any, Any]: return (data & 0xFFFFFFFF, data >> 32) # Any is sint
- def getbits(value: int, start: int, end: int): return (value >> start) & ((1 << end-start+1) - 1)
+ def getbits(value: int, start: int, end: int): return (value >> start) & ((1 << (end - start + 1)) - 1)
  def i2u(bits: int, value: int): return value if value >= 0 else (1<<bits)+value
+ def is_numpy_ndarray(x) -> bool: return str(type(x)) == "<class 'numpy.ndarray'>"
  def merge_dicts(ds:Iterable[dict[T,U]]) -> dict[T,U]:
  kvs = set([(k,v) for d in ds for k,v in d.items()])
  assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
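The new `cdiv`/`cmod` helpers implement C-style (truncate-toward-zero) division, which differs from Python's floor division when the operands have mixed signs. A quick standalone check of the definitions in the hunk above:

```python
def cdiv(x: int, y: int) -> int: return abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0
def cmod(x: int, y: int) -> int: return x - cdiv(x, y)*y

print(-7 // 2, -7 % 2)            # -4 1   Python floors toward -inf
print(cdiv(-7, 2), cmod(-7, 2))   # -3 -1  C truncates toward zero; remainder keeps the dividend's sign
```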
@@ -58,11 +62,11 @@ def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> tuple[list[T], list[T]
  ret:tuple[list[T], list[T]] = ([], [])
  for s in itr: (ret[0] if fxn(s) else ret[1]).append(s)
  return ret
- def unwrap(x:Optional[T]) -> T:
+ def unwrap(x:T|None) -> T:
  assert x is not None
  return x
- def get_single_element(x:list[T]) -> T:
- assert len(x) == 1, f"list {x} must only have 1 element"
+ def get_single_element(x:Sequence[T]) -> T:
+ assert len(x) == 1, f"{x} must only have 1 element"
  return x[0]
  def get_child(obj, key):
  for k in key.split('.'):
@@ -70,14 +74,32 @@ def get_child(obj, key):
  elif isinstance(obj, dict): obj = obj[k]
  else: obj = getattr(obj, k)
  return obj
- def word_wrap(x, wrap=80): return x if len(x) <= wrap or '\n' in x[0:wrap] else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
+ def word_wrap(x, wrap=80):
+ if len(ansistrip(x)) <= wrap: return x
+ if len(lines:=x.splitlines()) > 1: return "\n".join(word_wrap(line, wrap) for line in lines)
+ i = 0
+ while len(ansistrip(x[:i])) < wrap and i < len(x): i += 1
+ return x[:i] + "\n" + word_wrap(x[i:], wrap)
+
+ def suppress_finalizing(func):
+ def wrapper(*args, **kwargs):
+ try: return func(*args, **kwargs)
+ except (AttributeError, TypeError, ImportError):
+ if not getattr(sys, 'is_finalizing', lambda: True)(): raise # re-raise if not finalizing
+ return wrapper
+
+ def pluralize(st:str, cnt:int): return f"{cnt} {st}"+('' if cnt == 1 else 's')
+
+ class LazySeq(Generic[T]): # NOTE: Mapping requires __iter__ and __len__, Sequence requires supporting __len__ and slicing in __getitem__
+ def __init__(self, gen:Callable[[int], T]): self.gen = gen
+ def __getitem__(self, idx:int) -> T: return self.gen(idx)
 
  # for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
  def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore
 
- @functools.lru_cache(maxsize=None)
+ @functools.cache
  def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
- @functools.lru_cache(maxsize=None)
+ @functools.cache
  def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
  def temp(x:str, append_user:bool=False) -> str:
  return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()
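A minimal usage sketch of the new `LazySeq` helper, using only what the hunk above defines: indexing calls the generator on demand, with no length, no caching, and no iteration support.

```python
squares = LazySeq(lambda i: i * i)   # gen maps an index to a value
print(squares[4])                    # 16, computed on access
```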
@@ -105,14 +127,19 @@ class ContextVar:
 
  DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
  JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1)
+ JIT_BATCH_SIZE = ContextVar("JIT_BATCH_SIZE", 32)
  WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
  USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
- TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE = ContextVar("TRANSCENDENTAL", 1), ContextVar("TC_SEARCH_OVER_SHAPE", 1)
- FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
+ TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("TC_SEARCH_OVER_SHAPE", 1), ContextVar("NOLOCALS", 0)
+ FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 1), ContextVar("FUSE_CONV_BW", 0)
  SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
  PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
  CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
- DONT_REALIZE_EXPAND = ContextVar("DONT_REALIZE_EXPAND", 0)
+ DISABLE_COMPILER_CACHE = ContextVar("DISABLE_COMPILER_CACHE", 0)
+ DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
+ QUANTIZE, VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
+ CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
+ ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, AMD_LLVM = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0), ContextVar("AMD_LLVM", 1)
 
  @dataclass(frozen=True)
  class Metadata:
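These ContextVars take their defaults from the environment, so the flags added here (and the new `FUSE_ARANGE=1` default) can be flipped per run, e.g. `DEBUG=2 python train.py`. A small sketch of how they read in code, assuming the existing ContextVar comparison support that this diff itself relies on later (`DEBUG >= 8`):

```python
from tinygrad.helpers import DEBUG, FUSE_ARANGE

if DEBUG >= 2: print("verbose output enabled")   # ContextVar compares like an int
print(FUSE_ARANGE.value)                         # 1 unless overridden with FUSE_ARANGE=0 in the env
```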
@@ -122,7 +149,6 @@ class Metadata:
  def __hash__(self): return hash(self.name)
  def __repr__(self): return str(self) + (f" - {self.caller}" if self.caller else "")
  def __str__(self): return self.name + (" bw" if self.backward else "")
- _METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)
 
  # **************** global state Counters ****************
 
@@ -165,12 +191,37 @@ class Profiling(contextlib.ContextDecorator):
  colored(_format_fcn(fcn).ljust(50), "yellow"),
  colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if scallers else '')
 
+
+ @dataclass(frozen=True)
+ class TracingKey:
+ display_name:str # display name of this trace event
+ keys:tuple[str, ...]=() # optional keys to search for related traces
+ cat:str|None=None # optional category to color this by
+ ret:Any=None
+
+ class ProfileEvent: pass
+
+ @dataclass
+ class ProfileRangeEvent(ProfileEvent): device:str; name:str|TracingKey; st:decimal.Decimal; en:decimal.Decimal|None=None; is_copy:bool=False # noqa: E702
+
+ @dataclass(frozen=True)
+ class ProfilePointEvent(ProfileEvent): device:str; name:str; ts:decimal.Decimal; key:int; arg:dict=field(default_factory=dict) # noqa: E702
+
+ cpu_events:list[ProfileEvent] = []
+ @contextlib.contextmanager
+ def cpu_profile(name:str|TracingKey, device="CPU", is_copy=False, display=True) -> Generator[ProfileRangeEvent, None, None]:
+ res = ProfileRangeEvent(device, name, decimal.Decimal(time.perf_counter_ns()) / 1000, is_copy=is_copy)
+ try: yield res
+ finally:
+ res.en = decimal.Decimal(time.perf_counter_ns()) / 1000
+ if PROFILE and display: cpu_events.append(res)
+
  # *** universal database cache ***
 
  cache_dir: str = os.path.join(getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")), "tinygrad")
  CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(cache_dir, "cache.db")))
 
- VERSION = 19
+ VERSION = 22
  _db_connection = None
  def db_connection():
  global _db_connection
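A usage sketch of the new `cpu_profile` context manager defined above: it stamps a `ProfileRangeEvent` with start and end times in microseconds and, when `PROFILE` is set and `display` is true, appends it to the module-level `cpu_events` list.

```python
import time
from tinygrad.helpers import cpu_profile

with cpu_profile("host_side_work") as ev:
  time.sleep(0.001)                            # stand-in for some CPU-side work
print(ev.device, float(ev.en - ev.st), "us")   # e.g. CPU ~1000 us
```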
@@ -180,7 +231,7 @@ def db_connection():
  # another connection has set it already or is in the process of setting it
  # that connection will lock the database
  with contextlib.suppress(sqlite3.OperationalError): _db_connection.execute("PRAGMA journal_mode=WAL").fetchone()
- if DEBUG >= 7: _db_connection.set_trace_callback(print)
+ if DEBUG >= 8: _db_connection.set_trace_callback(print)
  return _db_connection
 
  def diskcache_clear():
@@ -188,11 +239,10 @@ def diskcache_clear():
  drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
  cur.executescript("\n".join([s[0] for s in drop_tables] + ["VACUUM;"]))
 
- def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:
+ def diskcache_get(table:str, key:dict|str|int) -> Any:
  if CACHELEVEL < 1: return None
  if isinstance(key, (str,int)): key = {"key": key}
- conn = db_connection()
- cur = conn.cursor()
+ cur = db_connection().cursor()
  try:
  res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
  except sqlite3.OperationalError:
@@ -201,7 +251,7 @@ def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:
  return None
 
  _db_tables = set()
- def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=False):
+ def diskcache_put(table:str, key:dict|str|int, val:Any, prepickled=False):
  if CACHELEVEL < 1: return val
  if isinstance(key, (str,int)): key = {"key": key}
  conn = db_connection()
@@ -211,22 +261,18 @@ def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=Fals
  ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
  cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
  _db_tables.add(table)
- cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
+ cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
  conn.commit()
  cur.close()
  return val
 
- def diskcache(func):
- def wrapper(*args, **kwargs) -> bytes:
+ def diskcache(func:Callable[..., T]):
+ def wrapper(*args, **kwargs) -> T:
  table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
- if (ret:=diskcache_get(table, key)): return ret
+ if (ret:=diskcache_get(table, key)) is not None: return ret
  return diskcache_put(table, key, func(*args, **kwargs))
  return wrapper
 
- # *** process replay ***
-
- CAPTURE_PROCESS_REPLAY = getenv("RUN_PROCESS_REPLAY") or getenv("CAPTURE_PROCESS_REPLAY")
-
  # *** http support ***
 
  def _ensure_downloads_dir() -> pathlib.Path:
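A usage sketch for the `diskcache` decorator above: the return value is pickled into the sqlite cache, keyed by a hash of the pickled arguments, so a repeat call with the same arguments skips the function body (assuming `CACHELEVEL >= 1`). The function below is hypothetical.

```python
from tinygrad.helpers import diskcache

@diskcache
def slow_square(x: int) -> int:
  print("computing...")                      # printed only on a cache miss
  return x * x

print(slow_square(12), slow_square(12))      # 144 144; the second call is served from cache.db
```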
@@ -240,14 +286,14 @@ def _ensure_downloads_dir() -> pathlib.Path:
  return downloads_dir
  return pathlib.Path(cache_dir) / "downloads"
 
- def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None, gunzip:bool=False,
+ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip:bool=False,
  allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
  if url.startswith(("/", ".")): return pathlib.Path(url)
  if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
  else: fp = _ensure_downloads_dir() / (subdir or "") / ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
  if not fp.is_file() or not allow_caching:
  (_dir := fp.parent).mkdir(parents=True, exist_ok=True)
- with urllib.request.urlopen(url, timeout=10) as r:
+ with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": "tinygrad 0.11.0"}), timeout=10) as r:
  assert r.status == 200, r.status
  length = int(r.headers.get('content-length', 0)) if not gunzip else None
  readfile = gzip.GzipFile(fileobj=r) if gunzip else r
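A usage sketch of `fetch` as changed above (it now sends a tinygrad User-Agent header): downloads land under the tinygrad cache's downloads directory and are reused on later calls. The URL and filename below are purely illustrative.

```python
from tinygrad.helpers import fetch

path = fetch("https://example.com/weights.bin", name="weights.bin")  # hypothetical URL
print(path)  # pathlib.Path under the tinygrad cache dir; a second call returns it without re-downloading
```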
@@ -262,11 +308,6 @@ def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional
 
  # *** Exec helpers
 
- def cpu_time_execution(cb, enable):
- if enable: st = time.perf_counter()
- cb()
- if enable: return time.perf_counter()-st
-
  def cpu_objdump(lib, objdump_tool='objdump'):
  with tempfile.NamedTemporaryFile(delete=True) as f:
  pathlib.Path(f.name).write_bytes(lib)
@@ -283,17 +324,23 @@ def capstone_flatdump(lib: bytes):
  print(f"{instr.address:#08x}: {instr.mnemonic}\t{instr.op_str}")
  sys.stdout.flush()
 
+ def wait_cond(cb, value=True, timeout_ms=10000, msg="") -> bool:
+ start_time = int(time.perf_counter() * 1000)
+ while int(time.perf_counter() * 1000) - start_time < timeout_ms:
+ if (val:=cb()) == value: return val
+ raise TimeoutError(f"{msg}. Timed out after {timeout_ms} ms, condition not met: {val} != {value}")
+
  # *** ctypes helpers
 
  # TODO: make this work with read only memoryviews (if possible)
- def from_mv(mv:memoryview, to_type=ctypes.c_char):
+ def from_mv(mv:memoryview, to_type:type[ctypes._SimpleCData]=ctypes.c_char) -> ctypes.Array:
  return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
- def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
+ def to_mv(ptr:int, sz:int) -> memoryview: return memoryview((ctypes.c_uint8 * sz).from_address(ptr)).cast("B")
  def mv_address(mv): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
  def to_char_p_p(options: list[bytes], to_type=ctypes.c_char):
  return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])
- @functools.lru_cache(maxsize=None)
- def init_c_struct_t(fields: tuple[tuple[str, ctypes._SimpleCData], ...]):
+ @functools.cache
+ def init_c_struct_t(fields: tuple[tuple[str, type[ctypes._SimpleCData]], ...]):
  class CStruct(ctypes.Structure):
  _pack_, _fields_ = 1, fields
  return CStruct
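A quick check of the reworked `to_mv`: it now builds the memoryview directly with `from_address`, giving a writable view over `sz` bytes at a raw pointer. The ctypes buffer below exists only for illustration; in real use `sz` must match the underlying allocation.

```python
import ctypes
from tinygrad.helpers import to_mv

buf = (ctypes.c_uint8 * 4)(1, 2, 3, 4)
mv = to_mv(ctypes.addressof(buf), 4)
mv[0] = 9            # the view is writable: this mutates the ctypes buffer in place
print(list(buf))     # [9, 2, 3, 4]
```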
@@ -304,7 +351,7 @@ def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(m
 
  class tqdm(Generic[T]):
  def __init__(self, iterable:Iterable[T]|None=None, desc:str='', disable:bool=False,
- unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
+ unit:str='it', unit_scale=False, total:int|None=None, rate:int=100):
  self.iterable, self.disable, self.unit, self.unit_scale, self.rate = iterable, disable, unit, unit_scale, rate
  self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total
  self.set_description(desc)
@@ -322,9 +369,10 @@ class tqdm(Generic[T]):
  self.n, self.i = self.n+n, self.i+1
  if self.disable or (not close and self.i % self.skip != 0): return
  prog, elapsed, ncols = self.n/self.t if self.t else 0, time.perf_counter()-self.st, shutil.get_terminal_size().columns
- if self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1)
+ if elapsed and self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1)
  def HMS(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x)
- def SI(x): return (f"{x/1000**int(g:=math.log(x,1000)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
+ def SI(x):
+ return (f"{x/1000**int(g:=round(math.log(x,1000),6)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
  prog_text = f'{SI(self.n)}{f"/{SI(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}'
  est_text = f'<{HMS(elapsed/prog-elapsed) if self.n else "?"}' if self.t else ''
  it_text = (SI(self.n/elapsed) if self.unit_scale else f"{self.n/elapsed:5.2f}") if self.n else "?"
tinygrad/nn/__init__.py CHANGED
@@ -10,7 +10,6 @@ class BatchNorm:
  """
  Applies Batch Normalization over a 2D or 3D input.
 
- - Described: https://paperswithcode.com/method/batch-normalization
  - Paper: https://arxiv.org/abs/1502.03167v3
 
  See: `Tensor.batchnorm`
@@ -182,7 +181,6 @@ class GroupNorm:
  """
  Applies Group Normalization over a mini-batch of inputs.
 
- - Described: https://paperswithcode.com/method/group-normalization
  - Paper: https://arxiv.org/abs/1803.08494v3
 
  ```python exec="true" source="above" session="tensor" result="python"
@@ -213,7 +211,6 @@ class InstanceNorm:
  """
  Applies Instance Normalization over a mini-batch of inputs.
 
- - Described: https://paperswithcode.com/method/instance-normalization
  - Paper: https://arxiv.org/abs/1607.08022v3
 
  ```python exec="true" source="above" session="tensor" result="python"
@@ -240,7 +237,6 @@ class LayerNorm:
  """
  Applies Layer Normalization over a mini-batch of inputs.
 
- - Described: https://paperswithcode.com/method/layer-normalization
  - Paper: https://arxiv.org/abs/1607.06450v1
 
  ```python exec="true" source="above" session="tensor" result="python"
@@ -287,7 +283,6 @@ class RMSNorm:
  """
  Applies Root Mean Square Normalization to input.
 
- - Described: https://paperswithcode.com/method/rmsnorm
  - Paper: https://arxiv.org/abs/1910.07467
 
  ```python exec="true" source="above" session="tensor" result="python"
@@ -299,11 +294,15 @@ class RMSNorm:
  print(norm(t).numpy())
  ```
  """
- def __init__(self, dim:int, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
+ def __init__(self, dim:int, eps=1e-6, elementwise_affine=True):
+ self.eps = eps
+ self.weight = Tensor.ones(dim) if elementwise_affine else None
 
  def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
 
- def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight
+ def __call__(self, x:Tensor) -> Tensor:
+ x = self._norm(x.float()).cast(x.dtype)
+ return x if self.weight is None else x * self.weight
 
  class Embedding:
  """
@@ -323,7 +322,7 @@ class Embedding:
  if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
  big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
  arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
- return (arange == idx).mul(vals).sum(-2, acc_dtype=vals.dtype)
+ return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)
 
  class LSTMCell:
  """