tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/frontend/torch.py
ADDED
@@ -0,0 +1,5 @@
+# type: ignore
+import sys, pathlib
+sys.path.append(pathlib.Path(__file__).parent.parent.as_posix())
+try: import extra.torch_backend.backend # noqa: F401 # pylint: disable=unused-import
+except ImportError as e: raise ImportError("torch frontend not in release\nTo fix, install tinygrad from a git checkout with pip install -e .") from e
tinygrad/gradient.py
CHANGED
@@ -1,16 +1,16 @@
-from typing import cast
-import math,
-from tinygrad.
-from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
+from typing import cast
+import math, dataclasses
+from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
 from tinygrad.helpers import argsort

 def reduce_gradient(ctx:UOp, ret:UOp):
-
+  def to_inp_shape(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
+  if ret.arg[0] == Ops.ADD: return (to_inp_shape(ctx),)
   if ret.arg[0] == Ops.MAX:
-    max_is_1s = ret.src[0].
-    div = max_is_1s.r(Ops.ADD, ret.arg[1])
-    return ((max_is_1s/div) * ctx
-  if ret.arg[0] == Ops.MUL: return ((ctx * ret)
+    max_is_1s = ret.src[0].eq(to_inp_shape(ret)).cast(ctx.dtype)
+    div = to_inp_shape(max_is_1s.r(Ops.ADD, ret.arg[1]))
+    return ((max_is_1s/div) * to_inp_shape(ctx),)
+  if ret.arg[0] == Ops.MUL: return (to_inp_shape(ctx * ret) / ret.src[0],)

 # ctx is grad_output
 pm_gradient = PatternMatcher([
@@ -23,40 +23,32 @@ pm_gradient = PatternMatcher([
   (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
   (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
   (UPat(Ops.POW, name="ret"), lambda ctx, ret:
-    (ret.src[0].eq(0)
-     ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)),
+    (ctx*(ret.src[0].eq(0) & ret.src[1].eq(0)).where(ret.src[1], ret.src[1]*ret.src[0].pow(ret.src[1]-1)),
+     ctx*ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ret*ret.src[0].log2()*math.log(2.0)))),
   (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
                                                 (ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
   (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
   (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
   (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
-  (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
+  (UPat((Ops.CONTIGUOUS, Ops.FUSE)), lambda ctx: (ctx,)),
   (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
   (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
   (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)),
   (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
   (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
   (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.arg),)),
-
-  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
-    (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
+  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)),)),
   (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
   # there's no gradient for bitcast
   (UPat(Ops.BITCAST), lambda ctx: (None,)),
 ])

-# copied from tensor.py, get relevant toposort of gradients
 def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
-
-
-
-
-
-    if is_in_target_path(node):
-      for i in node.src:
-        if i not in visited: yield from _walk(i, visited)
-      yield node
-  return list(_walk(root, set()))
+  # compute the target path (top down)
+  in_target_path: dict[UOp, bool] = {}
+  for u in root.toposort(): in_target_path[u] = any(x in targets or in_target_path[x] for x in u.src)
+  # don't flow through DETACH/ASSIGN or anything not in target path
+  return list(root.toposort(lambda node: node.op not in {Ops.DETACH, Ops.ASSIGN} and in_target_path[node]))

 def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
   grads = {root: root_grad}
@@ -69,5 +61,5 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
     if v is None: continue
     if k in grads: grads[k] = grads[k] + v
     else: grads[k] = v
-    if (forward_metadata:=all_metadata.get(t0))
+    if len(forward_metadata:=all_metadata.get(t0, ())): all_metadata[v] = tuple(dataclasses.replace(x, backward=True) for x in forward_metadata)
   return grads
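The rewritten reduce gradients can be exercised through the public Tensor API. A minimal sketch (the values are arbitrary, chosen only to show how the MAX-reduce gradient splits evenly across tied maxima):

```python
# Illustrative only: drives the REDUCE_AXIS (Ops.MAX) gradient path above
# via tinygrad's public Tensor API. The input values are made up.
from tinygrad import Tensor

x = Tensor([[1.0, 3.0, 2.0], [4.0, 0.0, 4.0]], requires_grad=True)
y = x.max(axis=1).sum()   # forward MAX reduce, then sum to a scalar
y.backward()
# max_is_1s marks the argmax positions; the tie (two 4.0s) splits the gradient
print(x.grad.numpy())     # expected [[0., 1., 0.], [0.5, 0., 0.5]]
```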
tinygrad/helpers.py
CHANGED
@@ -1,13 +1,13 @@
 from __future__ import annotations
 import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass
-import urllib.request, subprocess, shutil, math,
-from dataclasses import dataclass
-from typing import
+import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib, decimal
+from dataclasses import dataclass, field
+from typing import ClassVar, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic, Generator

 T = TypeVar("T")
 U = TypeVar("U")
 # NOTE: it returns int 1 if x is empty regardless of the type of x
-def prod(x:Iterable[T]) ->
+def prod(x:Iterable[T]) -> T|int: return functools.reduce(operator.mul, x, 1)

 # NOTE: helpers is not allowed to import from anything else in tinygrad
 OSX = platform.system() == "Darwin"
@@ -23,15 +23,14 @@ def argfix(*x):
     return tuple(x[0])
   return x
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
-def all_same(items:
+def all_same(items:tuple[T, ...]|list[T]): return all(x == items[0] for x in items)
 def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
-def colored(st, color:
+def colored(st, color:str|None, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
 def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
-def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
 def time_to_str(t:float, w=8) -> str: return next((f"{t * d:{w}.2f}{pr}" for d,pr in [(1, "s "),(1e3, "ms")] if t > 10/d), f"{t * 1e6:{w}.2f}us")
 def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
 def ansilen(s:str): return len(ansistrip(s))
-def make_tuple(x:
+def make_tuple(x:int|Sequence[int], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
 def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
 def fully_flatten(l):
   if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
@@ -44,12 +43,17 @@ def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
 def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
 def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
 def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
+def round_down(num:int, amt:int) -> int: return -round_up(-num, amt)
+# cstyle div and mod
+def cdiv(x:int, y:int) -> int: return abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0
+def cmod(x:int, y:int) -> int: return x-cdiv(x,y)*y
 def lo32(x:Any) -> Any: return x & 0xFFFFFFFF # Any is sint
 def hi32(x:Any) -> Any: return x >> 32 # Any is sint
 def data64(data:Any) -> tuple[Any, Any]: return (data >> 32, data & 0xFFFFFFFF) # Any is sint
 def data64_le(data:Any) -> tuple[Any, Any]: return (data & 0xFFFFFFFF, data >> 32) # Any is sint
-def getbits(value: int, start: int, end: int): return (value >> start) & ((1 << end-start+1) - 1)
+def getbits(value: int, start: int, end: int): return (value >> start) & ((1 << (end - start + 1)) - 1)
 def i2u(bits: int, value: int): return value if value >= 0 else (1<<bits)+value
+def is_numpy_ndarray(x) -> bool: return str(type(x)) == "<class 'numpy.ndarray'>"
 def merge_dicts(ds:Iterable[dict[T,U]]) -> dict[T,U]:
   kvs = set([(k,v) for d in ds for k,v in d.items()])
   assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
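The new cdiv/cmod helpers implement C-style truncated division (rounding toward zero), which differs from Python's floored `//` and `%` when the signs differ. A standalone sketch of that difference, duplicating the definitions from the hunk above:

```python
# C-style (truncating) division vs Python's floored division, mirroring the
# cdiv/cmod helpers added above.
def cdiv(x: int, y: int) -> int: return abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0
def cmod(x: int, y: int) -> int: return x - cdiv(x, y)*y

assert (-7) // 2 == -4 and (-7) % 2 == 1        # Python rounds toward -inf
assert cdiv(-7, 2) == -3 and cmod(-7, 2) == -1  # C truncates toward zero
assert cdiv(7, -2) == -3 and cmod(7, -2) == 1
assert cdiv(5, 0) == 0                          # divide-by-zero folds to 0 here
```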
@@ -58,11 +62,11 @@ def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> tuple[list[T], list[T]]:
   ret:tuple[list[T], list[T]] = ([], [])
   for s in itr: (ret[0] if fxn(s) else ret[1]).append(s)
   return ret
-def unwrap(x:
+def unwrap(x:T|None) -> T:
   assert x is not None
   return x
-def get_single_element(x:
-  assert len(x) == 1, f"
+def get_single_element(x:Sequence[T]) -> T:
+  assert len(x) == 1, f"{x} must only have 1 element"
   return x[0]
 def get_child(obj, key):
   for k in key.split('.'):
@@ -70,14 +74,32 @@ def get_child(obj, key):
     elif isinstance(obj, dict): obj = obj[k]
     else: obj = getattr(obj, k)
   return obj
-def word_wrap(x, wrap=80):
+def word_wrap(x, wrap=80):
+  if len(ansistrip(x)) <= wrap: return x
+  if len(lines:=x.splitlines()) > 1: return "\n".join(word_wrap(line, wrap) for line in lines)
+  i = 0
+  while len(ansistrip(x[:i])) < wrap and i < len(x): i += 1
+  return x[:i] + "\n" + word_wrap(x[i:], wrap)
+
+def suppress_finalizing(func):
+  def wrapper(*args, **kwargs):
+    try: return func(*args, **kwargs)
+    except (AttributeError, TypeError, ImportError):
+      if not getattr(sys, 'is_finalizing', lambda: True)(): raise # re-raise if not finalizing
+  return wrapper
+
+def pluralize(st:str, cnt:int): return f"{cnt} {st}"+('' if cnt == 1 else 's')
+
+class LazySeq(Generic[T]): # NOTE: Mapping requires __iter__ and __len__, Sequence requires supporting __len__ and slicing in __getitem__
+  def __init__(self, gen:Callable[[int], T]): self.gen = gen
+  def __getitem__(self, idx:int) -> T: return self.gen(idx)

 # for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
 def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore

-@functools.
+@functools.cache
 def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
-@functools.
+@functools.cache
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
 def temp(x:str, append_user:bool=False) -> str:
   return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()
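LazySeq above simply defers element construction to a gen(idx) callable, so items are built on demand and never stored. A minimal standalone illustration (the squaring generator is just an example):

```python
# Minimal illustration of the LazySeq helper added above.
from typing import Callable, Generic, TypeVar
T = TypeVar("T")

class LazySeq(Generic[T]):
  def __init__(self, gen: Callable[[int], T]): self.gen = gen
  def __getitem__(self, idx: int) -> T: return self.gen(idx)

squares = LazySeq(lambda i: i*i)
assert squares[12] == 144   # computed on access, nothing is materialized
```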
@@ -105,14 +127,19 @@ class ContextVar:

 DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
 JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1)
+JIT_BATCH_SIZE = ContextVar("JIT_BATCH_SIZE", 32)
 WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
 USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
-TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE = ContextVar("TRANSCENDENTAL", 1), ContextVar("TC_SEARCH_OVER_SHAPE", 1)
-FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE",
+TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE, NOLOCALS = ContextVar("TRANSCENDENTAL", 1), ContextVar("TC_SEARCH_OVER_SHAPE", 1), ContextVar("NOLOCALS", 0)
+FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 1), ContextVar("FUSE_CONV_BW", 0)
 SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
 PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
-
+DISABLE_COMPILER_CACHE = ContextVar("DISABLE_COMPILER_CACHE", 0)
+DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
+QUANTIZE, VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
+CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
+ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, AMD_LLVM = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0), ContextVar("AMD_LLVM", 1)

 @dataclass(frozen=True)
 class Metadata:
@@ -122,7 +149,6 @@ class Metadata:
   def __hash__(self): return hash(self.name)
   def __repr__(self): return str(self) + (f" - {self.caller}" if self.caller else "")
   def __str__(self): return self.name + (" bw" if self.backward else "")
-_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)

 # **************** global state Counters ****************

@@ -165,12 +191,37 @@ class Profiling(contextlib.ContextDecorator):
             colored(_format_fcn(fcn).ljust(50), "yellow"),
             colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if scallers else '')

+
+@dataclass(frozen=True)
+class TracingKey:
+  display_name:str # display name of this trace event
+  keys:tuple[str, ...]=() # optional keys to search for related traces
+  cat:str|None=None # optional category to color this by
+  ret:Any=None
+
+class ProfileEvent: pass
+
+@dataclass
+class ProfileRangeEvent(ProfileEvent): device:str; name:str|TracingKey; st:decimal.Decimal; en:decimal.Decimal|None=None; is_copy:bool=False # noqa: E702
+
+@dataclass(frozen=True)
+class ProfilePointEvent(ProfileEvent): device:str; name:str; ts:decimal.Decimal; key:int; arg:dict=field(default_factory=dict) # noqa: E702
+
+cpu_events:list[ProfileEvent] = []
+@contextlib.contextmanager
+def cpu_profile(name:str|TracingKey, device="CPU", is_copy=False, display=True) -> Generator[ProfileRangeEvent, None, None]:
+  res = ProfileRangeEvent(device, name, decimal.Decimal(time.perf_counter_ns()) / 1000, is_copy=is_copy)
+  try: yield res
+  finally:
+    res.en = decimal.Decimal(time.perf_counter_ns()) / 1000
+    if PROFILE and display: cpu_events.append(res)
+
 # *** universal database cache ***

 cache_dir: str = os.path.join(getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")), "tinygrad")
 CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(cache_dir, "cache.db")))

-VERSION =
+VERSION = 22
 _db_connection = None
 def db_connection():
   global _db_connection
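cpu_profile above records a ProfileRangeEvent with microsecond Decimal timestamps and, when PROFILE is enabled, appends it to cpu_events on exit. A hedged sketch of how a caller might use it (the sleep is a placeholder workload):

```python
# Sketch of using the cpu_profile context manager defined above.
# The event is only appended to cpu_events when PROFILE is set.
import time
from tinygrad.helpers import cpu_profile, cpu_events

with cpu_profile("tokenize", device="CPU") as ev:
  time.sleep(0.01)                       # stand-in for real work
print(ev.device, ev.name, ev.en - ev.st) # duration in microseconds
```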
@@ -180,7 +231,7 @@ def db_connection():
   # another connection has set it already or is in the process of setting it
   # that connection will lock the database
   with contextlib.suppress(sqlite3.OperationalError): _db_connection.execute("PRAGMA journal_mode=WAL").fetchone()
-  if DEBUG >=
+  if DEBUG >= 8: _db_connection.set_trace_callback(print)
   return _db_connection

 def diskcache_clear():
@@ -188,11 +239,10 @@ def diskcache_clear():
   drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
   cur.executescript("\n".join([s[0] for s in drop_tables] + ["VACUUM;"]))

-def diskcache_get(table:str, key:
+def diskcache_get(table:str, key:dict|str|int) -> Any:
   if CACHELEVEL < 1: return None
   if isinstance(key, (str,int)): key = {"key": key}
-
-  cur = conn.cursor()
+  cur = db_connection().cursor()
   try:
     res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
   except sqlite3.OperationalError:
@@ -201,7 +251,7 @@ def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:
     return None

 _db_tables = set()
-def diskcache_put(table:str, key:
+def diskcache_put(table:str, key:dict|str|int, val:Any, prepickled=False):
   if CACHELEVEL < 1: return val
   if isinstance(key, (str,int)): key = {"key": key}
   conn = db_connection()
@@ -211,22 +261,18 @@ def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=False):
     ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
     cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
     _db_tables.add(table)
-  cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key
+  cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
   conn.commit()
   cur.close()
   return val

-def diskcache(func):
-  def wrapper(*args, **kwargs) ->
+def diskcache(func:Callable[..., T]):
+  def wrapper(*args, **kwargs) -> T:
     table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
-    if (ret:=diskcache_get(table, key)): return ret
+    if (ret:=diskcache_get(table, key)) is not None: return ret
     return diskcache_put(table, key, func(*args, **kwargs))
   return wrapper

-# *** process replay ***
-
-CAPTURE_PROCESS_REPLAY = getenv("RUN_PROCESS_REPLAY") or getenv("CAPTURE_PROCESS_REPLAY")
-
 # *** http support ***

 def _ensure_downloads_dir() -> pathlib.Path:
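diskcache memoizes a function's result in the sqlite cache keyed by a hash of its pickled arguments, and the `is not None` change above keeps falsy-but-valid cached values from being recomputed. An illustrative sketch (the decorated function is made up):

```python
# Sketch of the diskcache decorator defined above: the first call computes and
# stores the value; later calls with the same arguments load it from cache.db.
from tinygrad.helpers import diskcache

@diskcache
def slow_square(x: int) -> int:
  print("computing...")    # printed only on a cache miss
  return x * x

assert slow_square(12) == 144  # miss: computes and stores
assert slow_square(12) == 144  # hit: unpickled from the sqlite cache
```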
@@ -240,14 +286,14 @@ def _ensure_downloads_dir() -> pathlib.Path:
     return downloads_dir
   return pathlib.Path(cache_dir) / "downloads"

-def fetch(url:str, name:
+def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip:bool=False,
           allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
   if url.startswith(("/", ".")): return pathlib.Path(url)
   if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
   else: fp = _ensure_downloads_dir() / (subdir or "") / ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
   if not fp.is_file() or not allow_caching:
     (_dir := fp.parent).mkdir(parents=True, exist_ok=True)
-    with urllib.request.urlopen(url, timeout=10) as r:
+    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": "tinygrad 0.11.0"}), timeout=10) as r:
       assert r.status == 200, r.status
       length = int(r.headers.get('content-length', 0)) if not gunzip else None
       readfile = gzip.GzipFile(fileobj=r) if gunzip else r
@@ -262,11 +308,6 @@ def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional

 # *** Exec helpers

-def cpu_time_execution(cb, enable):
-  if enable: st = time.perf_counter()
-  cb()
-  if enable: return time.perf_counter()-st
-
 def cpu_objdump(lib, objdump_tool='objdump'):
   with tempfile.NamedTemporaryFile(delete=True) as f:
     pathlib.Path(f.name).write_bytes(lib)
@@ -283,17 +324,23 @@ def capstone_flatdump(lib: bytes):
     print(f"{instr.address:#08x}: {instr.mnemonic}\t{instr.op_str}")
   sys.stdout.flush()

+def wait_cond(cb, value=True, timeout_ms=10000, msg="") -> bool:
+  start_time = int(time.perf_counter() * 1000)
+  while int(time.perf_counter() * 1000) - start_time < timeout_ms:
+    if (val:=cb()) == value: return val
+  raise TimeoutError(f"{msg}. Timed out after {timeout_ms} ms, condition not met: {val} != {value}")
+
 # *** ctypes helpers

 # TODO: make this work with read only memoryviews (if possible)
-def from_mv(mv:memoryview, to_type=ctypes.c_char):
+def from_mv(mv:memoryview, to_type:type[ctypes._SimpleCData]=ctypes.c_char) -> ctypes.Array:
   return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
-def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(
+def to_mv(ptr:int, sz:int) -> memoryview: return memoryview((ctypes.c_uint8 * sz).from_address(ptr)).cast("B")
 def mv_address(mv): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
 def to_char_p_p(options: list[bytes], to_type=ctypes.c_char):
   return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])
-@functools.
-def init_c_struct_t(fields: tuple[tuple[str, ctypes._SimpleCData], ...]):
+@functools.cache
+def init_c_struct_t(fields: tuple[tuple[str, type[ctypes._SimpleCData]], ...]):
   class CStruct(ctypes.Structure):
     _pack_, _fields_ = 1, fields
   return CStruct
@@ -304,7 +351,7 @@ def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(m

 class tqdm(Generic[T]):
   def __init__(self, iterable:Iterable[T]|None=None, desc:str='', disable:bool=False,
-               unit:str='it', unit_scale=False, total:
+               unit:str='it', unit_scale=False, total:int|None=None, rate:int=100):
     self.iterable, self.disable, self.unit, self.unit_scale, self.rate = iterable, disable, unit, unit_scale, rate
     self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total
     self.set_description(desc)
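to_mv now builds the memoryview from a ctypes array at the given address; together with mv_address it can round-trip between a Python buffer and a raw pointer. A small sketch of that aliasing behavior:

```python
# Sketch of the to_mv / mv_address ctypes helpers changed above: take the
# address of a writable buffer, view the same memory again, and write to it.
from tinygrad.helpers import to_mv, mv_address

buf = bytearray(b"hello world")
mv = memoryview(buf)
ptr = mv_address(mv)          # raw address of the buffer
alias = to_mv(ptr, len(buf))  # a second view over the same memory
alias[0:5] = b"HELLO"
assert bytes(buf) == b"HELLO world"
```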
@@ -322,9 +369,10 @@ class tqdm(Generic[T]):
     self.n, self.i = self.n+n, self.i+1
     if self.disable or (not close and self.i % self.skip != 0): return
     prog, elapsed, ncols = self.n/self.t if self.t else 0, time.perf_counter()-self.st, shutil.get_terminal_size().columns
-    if self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1)
+    if elapsed and self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1)
     def HMS(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x)
-    def SI(x):
+    def SI(x):
+      return (f"{x/1000**int(g:=round(math.log(x,1000),6)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
     prog_text = f'{SI(self.n)}{f"/{SI(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}'
     est_text = f'<{HMS(elapsed/prog-elapsed) if self.n else "?"}' if self.t else ''
     it_text = (SI(self.n/elapsed) if self.unit_scale else f"{self.n/elapsed:5.2f}") if self.n else "?"
tinygrad/nn/__init__.py
CHANGED
@@ -10,7 +10,6 @@ class BatchNorm:
   """
   Applies Batch Normalization over a 2D or 3D input.

-  - Described: https://paperswithcode.com/method/batch-normalization
   - Paper: https://arxiv.org/abs/1502.03167v3

   See: `Tensor.batchnorm`
@@ -182,7 +181,6 @@ class GroupNorm:
   """
   Applies Group Normalization over a mini-batch of inputs.

-  - Described: https://paperswithcode.com/method/group-normalization
   - Paper: https://arxiv.org/abs/1803.08494v3

   ```python exec="true" source="above" session="tensor" result="python"
@@ -213,7 +211,6 @@ class InstanceNorm:
   """
   Applies Instance Normalization over a mini-batch of inputs.

-  - Described: https://paperswithcode.com/method/instance-normalization
   - Paper: https://arxiv.org/abs/1607.08022v3

   ```python exec="true" source="above" session="tensor" result="python"
@@ -240,7 +237,6 @@ class LayerNorm:
   """
   Applies Layer Normalization over a mini-batch of inputs.

-  - Described: https://paperswithcode.com/method/layer-normalization
   - Paper: https://arxiv.org/abs/1607.06450v1

   ```python exec="true" source="above" session="tensor" result="python"
@@ -287,7 +283,6 @@ class RMSNorm:
   """
   Applies Root Mean Square Normalization to input.

-  - Described: https://paperswithcode.com/method/rmsnorm
   - Paper: https://arxiv.org/abs/1910.07467

   ```python exec="true" source="above" session="tensor" result="python"
@@ -299,11 +294,15 @@ class RMSNorm:
     print(norm(t).numpy())
   ```
   """
-  def __init__(self, dim:int, eps=1e-6
+  def __init__(self, dim:int, eps=1e-6, elementwise_affine=True):
+    self.eps = eps
+    self.weight = Tensor.ones(dim) if elementwise_affine else None

   def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()

-  def __call__(self, x:Tensor) -> Tensor:
+  def __call__(self, x:Tensor) -> Tensor:
+    x = self._norm(x.float()).cast(x.dtype)
+    return x if self.weight is None else x * self.weight

 class Embedding:
   """
@@ -323,7 +322,7 @@ class Embedding:
     if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).unsqueeze(-1)
     big_shp = idx.shape+(self.vocab_sz, self.embed_sz)
     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1)).expand(big_shp), self.weight.expand(big_shp)
-    return (arange == idx).mul(vals).sum(-2,
+    return (arange == idx).mul(vals).sum(-2, dtype=vals.dtype)

 class LSTMCell:
   """