tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those published versions.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/__init__.py +0 -0
- tinygrad/codegen/kernel.py +253 -225
- tinygrad/codegen/linearizer.py +398 -436
- tinygrad/codegen/uops.py +451 -0
- tinygrad/device.py +268 -274
- tinygrad/dtype.py +56 -40
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +198 -0
- tinygrad/engine/realize.py +192 -0
- tinygrad/engine/schedule.py +370 -0
- tinygrad/engine/search.py +199 -0
- tinygrad/{mlops.py → function.py} +40 -32
- tinygrad/helpers.py +144 -46
- tinygrad/lazy.py +143 -242
- tinygrad/multi.py +173 -0
- tinygrad/nn/__init__.py +180 -9
- tinygrad/nn/datasets.py +8 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +87 -19
- tinygrad/ops.py +104 -45
- tinygrad/renderer/__init__.py +65 -0
- tinygrad/renderer/assembly.py +269 -0
- tinygrad/renderer/cstyle.py +308 -210
- tinygrad/renderer/llvmir.py +119 -124
- tinygrad/runtime/__init__.py +0 -0
- tinygrad/runtime/autogen/amd_gpu.py +13403 -0
- tinygrad/runtime/autogen/comgr.py +891 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5893 -0
- tinygrad/runtime/autogen/io_uring.py +1486 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33597 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/__init__.py +0 -0
- tinygrad/runtime/driver/hip_comgr.py +56 -0
- tinygrad/runtime/graph/__init__.py +0 -0
- tinygrad/runtime/graph/clang.py +39 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +187 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +550 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +129 -37
- tinygrad/runtime/ops_disk.py +111 -43
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +41 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +625 -0
- tinygrad/runtime/ops_python.py +208 -0
- tinygrad/shape/__init__.py +0 -0
- tinygrad/shape/shapetracker.py +46 -107
- tinygrad/shape/symbolic.py +99 -98
- tinygrad/shape/view.py +162 -45
- tinygrad/tensor.py +2492 -483
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
- tinygrad-0.9.1.dist-info/RECORD +63 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/{mlops.py → function.py}
RENAMED

```diff
@@ -1,7 +1,8 @@
+"""This is where the forwards and backwards passes live."""
 import math
 from typing import Tuple, Optional
 from tinygrad.helpers import argsort
-from tinygrad.dtype import DType
+from tinygrad.dtype import dtypes, DType, sum_acc_dtype
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
```
```diff
@@ -24,21 +25,24 @@ class Cast(Function):
 
 # ************* unary ops *************
 
-class Zero(Function):
-  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.const(0)
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)
-
 class Neg(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG)
 
+class Reciprocal(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer:
+    self.ret = x.e(UnaryOps.RECIP)
+    return self.ret
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret)
+
 class Sin(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
     return x.e(UnaryOps.SIN)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
+    return self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
 
 # NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):
```
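Two details worth noting in this hunk: 0.9.1 drops `BinaryOps.SUB` from the op set, so `Sin.backward` now builds π/2 − x as `ADD(π/2, NEG(x))` while still using the identity cos(x) = sin(π/2 − x); and the new `Reciprocal.backward` applies d(1/x)/dx = −1/x², reusing the saved forward result so the gradient is just −g·ret·ret. A quick plain-Python check of both identities (not from the diff):

```python
import math

x, g = 0.7, 1.3  # arbitrary input and incoming gradient

# Sin.backward: g * cos(x), built as g * sin(pi/2 + (-x))
assert abs(g * math.sin(math.pi / 2 + (-x)) - g * math.cos(x)) < 1e-12

# Reciprocal.backward: -g * ret * ret, where ret = 1/x is saved from forward
ret = 1 / x
assert abs(-g * ret * ret - g * (-1 / x**2)) < 1e-12
```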
```diff
@@ -54,7 +58,7 @@ class Log(Function):
     self.x = x
     return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.DIV, self.x)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP))
 
 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
```
```diff
@@ -69,26 +73,35 @@ class Sqrt(Function):
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.e(BinaryOps.DIV, self.ret.e(BinaryOps.MUL, self.ret.const(2)))
+    return grad_output.e(BinaryOps.MUL, self.ret.e(BinaryOps.MUL, self.ret.const(2)).e(UnaryOps.RECIP))
 
 # NOTE: the implicit derivative of sigmoid is not stable
 # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
+    self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP)
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output)
+    return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.ADD, self.ret.e(UnaryOps.NEG))).e(BinaryOps.MUL, grad_output)
+
+class Sign(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer:
+    return x.e(BinaryOps.CMPNE, x.const(0)).e(
+      TernaryOps.WHERE, x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)), x.const(0))
+  # backward always return 0 to match torch
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)
 
 # ************* binary ops *************
 
 class Less(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
+  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
 
-class Eq(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPEQ, y)
+class Neq(Function):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPNE, y)
+  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
 
 class Xor(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y)
```
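The sigmoid forward keeps its numerically stable exp2 form, now phrased with `RECIP` instead of `DIV`: σ(x) = 1/(1 + 2^(−x/ln 2)), which equals 1/(1 + e^(−x)) because 2^(x·log₂e) = e^x. Likewise `Sqrt.backward` is g/(2√x), written as g · RECIP(2·ret). A plain-Python sketch of the sigmoid arithmetic:

```python
import math

def sigmoid_via_exp2(x: float) -> float:
    # mirrors the forward above: 1 / (1 + exp2(x * (-1/ln 2)))
    return 1 / (1 + 2 ** (x * (-1 / math.log(2))))

for x in (-6.0, -0.5, 0.0, 3.0):
    assert abs(sigmoid_via_exp2(x) - 1 / (1 + math.exp(-x))) < 1e-12
```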
```diff
@@ -100,13 +113,6 @@ class Add(Function):
     return grad_output if self.needs_input_grad[0] else None, \
            grad_output if self.needs_input_grad[1] else None
 
-class Sub(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.SUB, y)
-
-  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output if self.needs_input_grad[0] else None, \
-           grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None
-
 class Mul(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
```
```diff
@@ -119,11 +125,11 @@ class Mul(Function):
 class Div(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
-    return x.e(BinaryOps.DIV, y)
+    return x.e(BinaryOps.MUL, y.e(UnaryOps.RECIP)) if not dtypes.is_int(x.dtype) else x.e(BinaryOps.IDIV, y)
 
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
-           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
+    return grad_output.e(BinaryOps.MUL, self.y.e(UnaryOps.RECIP)) if self.needs_input_grad[0] else None, \
+           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.MUL, self.y.e(BinaryOps.MUL, self.y).e(UnaryOps.RECIP)) if self.needs_input_grad[1] else None # noqa: E501
 
 # ************* ternary ops *************
 
```
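`Div` still implements the usual quotient-rule gradients, ∂(x/y)/∂x = 1/y and ∂(x/y)/∂y = −x/y², only spelled with `MUL` and `RECIP`; the forward now routes integer dtypes to the new `IDIV` op instead. A finite-difference check of the float path (assumes scalar inputs and a unit incoming gradient):

```python
x, y, g, eps = 3.0, 4.0, 1.0, 1e-6
div = lambda a, b: a * (1 / b)   # forward as MUL(x, RECIP(y))

# d(x/y)/dx = 1/y
assert abs((div(x + eps, y) - div(x - eps, y)) / (2 * eps) - g / y) < 1e-6
# d(x/y)/dy = -x/y**2, i.e. NEG(g) * x * RECIP(y*y)
assert abs((div(x, y + eps) - div(x, y - eps)) / (2 * eps) - (-g * x / y**2)) < 1e-6
```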
```diff
@@ -140,32 +146,34 @@ class Where(Function):
 # ************* reduce ops *************
 
 class Sum(Function):
-  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.r(ReduceOps.SUM, new_shape)
+    return x.r(ReduceOps.SUM, axis)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
 
 class Max(Function):
-  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
+    self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(dtypes.float)
-    div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
-    return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
+    max_is_1s = self.x.const(1.0).cast(dtypes.float).e(BinaryOps.ADD, self.x.e(BinaryOps.CMPNE, \
+      self.ret.expand(self.x.shape)).cast(dtypes.float).e(UnaryOps.NEG))
+    div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
+    return max_is_1s.e(BinaryOps.MUL, div.e(UnaryOps.RECIP)).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
 
 # ************* movement ops *************
 
 # NOTE: this is sum in reverse
 class Expand(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
-    self.input_shape = x.shape
+    self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
     return x.expand(shape)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.r(ReduceOps.SUM, self.input_shape)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)
 
 class Reshape(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
```
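`Max.backward` splits the incoming gradient evenly across tied maxima: since `CMPEQ` no longer exists, `max_is_1s` is built as `1.0 + NEG(CMPNE(x, max))`, and dividing by its per-reduction sum gives each winner an equal share. The same logic in list form:

```python
xs = [3.0, 1.0, 3.0]                                        # two entries tie for the max
m = max(xs)
max_is_1s = [1.0 + -(1.0 if v != m else 0.0) for v in xs]   # 1.0 + NEG(CMPNE)
div = sum(max_is_1s)                                        # number of winners, here 2
assert [v / div for v in max_is_1s] == [0.5, 0.0, 0.5]
```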
```diff
@@ -197,7 +205,7 @@ class Shrink(Function):
 
 class Flip(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-    self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
+    self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))])
     return x.stride(self.arg)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)
```
tinygrad/helpers.py
CHANGED
```diff
@@ -1,22 +1,26 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes
-from urllib import request
-from tqdm import tqdm
+import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes, sys
+import itertools, urllib.request, subprocess, shutil, math, json
 from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
 if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
   from typing_extensions import TypeGuard
+  from tinygrad.shape.shapetracker import sint
 
 T = TypeVar("T")
 U = TypeVar("U")
 # NOTE: it returns int 1 if x is empty regardless of the type of x
-def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.__mul__, x, 1)
+def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x, 1)
 
 # NOTE: helpers is not allowed to import from anything else in tinygrad
 OSX = platform.system() == "Darwin"
 CI = os.getenv("CI", "") != ""
 
 def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order
-def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
+def argfix(*x):
+  if x and x[0].__class__ in (tuple, list):
+    if len(x) != 1: raise ValueError(f"bad arg {x}")
+    return tuple(x[0])
+  return x
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
 def all_same(items:List[T]): return all(x == items[0] for x in items)
 def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
```
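The expanded `argfix` keeps the old contract (accept either a single tuple/list or bare varargs) but now raises on ambiguous calls like `argfix((1, 2), 3)` instead of silently dropping arguments:

```python
from tinygrad.helpers import argfix

assert argfix((1, 2, 3)) == (1, 2, 3)   # one tuple: unpacked
assert argfix([1, 2, 3]) == (1, 2, 3)   # one list: unpacked
assert argfix(1, 2, 3) == (1, 2, 3)     # varargs: passed through
try: argfix((1, 2), 3)                  # mixed: now a ValueError
except ValueError: pass
```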
```diff
@@ -51,12 +55,27 @@ def get_child(obj, key):
     else: obj = getattr(obj, k)
   return obj
 
+def get_shape(x) -> Tuple[int, ...]:
+  if not isinstance(x, (list, tuple)): return ()
+  subs = [get_shape(xi) for xi in x]
+  if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
+  return (len(subs),) + (subs[0] if subs else ())
+
+# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
+def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
+  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
+  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
+  except ValueError: return None
+  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
+
 @functools.lru_cache(maxsize=None)
 def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
 @functools.lru_cache(maxsize=None)
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
 def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
 
+class GraphException(Exception): pass
+
 class Context(contextlib.ContextDecorator):
   stack: ClassVar[List[dict[str, int]]] = [{}]
   def __init__(self, **kwargs): self.kwargs = kwargs
```
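What the two new helpers compute: `get_shape` infers the shape of nested Python lists (rejecting ragged input), and `get_contraction` answers whether `new_shape` can be formed by merging adjacent axes of `old_shape`, returning the axis groups or None. Examples, assuming the definitions above:

```python
from tinygrad.helpers import get_shape, get_contraction

assert get_shape([[1, 2], [3, 4], [5, 6]]) == (3, 2)
assert get_shape(5) == ()                                     # scalars: empty shape

assert get_contraction((2, 3, 4), (6, 4)) == [[0, 1], [2]]    # 2*3 merges into 6
assert get_contraction((2, 3, 4), (24,)) == [[0, 1, 2]]
assert get_contraction((2, 3, 4), (4, 6)) is None             # no clean merge
```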
```diff
@@ -70,18 +89,34 @@ class Context(contextlib.ContextDecorator):
 class ContextVar:
   _cache: ClassVar[Dict[str, ContextVar]] = {}
   value: int
+  key: str
   def __new__(cls, key, default_value):
     if key in ContextVar._cache: return ContextVar._cache[key]
     instance = ContextVar._cache[key] = super().__new__(cls)
-    instance.value = getenv(key, default_value)
+    instance.value, instance.key = getenv(key, default_value), key
     return instance
   def __bool__(self): return bool(self.value)
   def __ge__(self, x): return self.value >= x
   def __gt__(self, x): return self.value > x
   def __lt__(self, x): return self.value < x
 
-DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
-GRAPH, GRAPHPATH = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
+DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
+WINO, THREEFRY, CAPTURING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1)
+GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
+MULTIOUTPUT, PROFILE = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0)
+
+# **************** global state Counters ****************
+
+class GlobalCounters:
+  global_ops: ClassVar[int] = 0
+  global_mem: ClassVar[int] = 0
+  time_sum_s: ClassVar[float] = 0.0
+  kernel_count: ClassVar[int] = 0
+  mem_used: ClassVar[int] = 0 # NOTE: this is not reset
+  @staticmethod
+  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
+
+# **************** timer and profiler ****************
 
 class Timing(contextlib.ContextDecorator):
   def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
```
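`ContextVar` instances are interned by key and read their initial value from the environment once, which is what makes `DEBUG`, `JIT`, and friends cheap module-level flags; `GlobalCounters` is the class-level accumulator behind the per-run op/memory/kernel statistics. A usage sketch:

```python
from tinygrad.helpers import ContextVar, GlobalCounters

JIT = ContextVar("JIT", 1)            # value comes from the JIT env var, else 1
assert ContextVar("JIT", 0) is JIT    # same key returns the cached instance
if JIT >= 1: print("jit enabled")     # comparisons read the stored int

GlobalCounters.reset()                # zeroes ops/mem/time/kernels, not mem_used
```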
```diff
@@ -90,16 +125,52 @@ class Timing(contextlib.ContextDecorator):
     self.et = time.perf_counter_ns() - self.st
     if self.enabled: print(f"{self.prefix}{self.et*1e-6:6.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
 
+def _format_fcn(fcn): return f"{fcn[0]}:{fcn[1]}:{fcn[2]}"
 class Profiling(contextlib.ContextDecorator):
-  def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None): self.enabled, self.sort, self.frac, self.fn = enabled, sort, frac, fn
+  def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None, ts=1):
+    self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3/ts
   def __enter__(self):
-    self.pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6)
+    self.pr = cProfile.Profile()
     if self.enabled: self.pr.enable()
   def __exit__(self, *exc):
     if self.enabled:
       self.pr.disable()
       if self.fn: self.pr.dump_stats(self.fn)
-      pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort).print_stats(self.frac)
+      stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
+      for fcn in stats.fcn_list[0:int(len(stats.fcn_list)*self.frac)]: # type: ignore[attr-defined]
+        (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn] # type: ignore[attr-defined]
+        scallers = sorted(callers.items(), key=lambda x: -x[1][2])
+        print(f"n:{num_calls:8d} tm:{tottime*self.time_scale:7.2f}ms tot:{cumtime*self.time_scale:7.2f}ms",
+              colored(_format_fcn(fcn), "yellow") + " "*(50-len(_format_fcn(fcn))),
+              colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else '')
+
+class ProfileLogger:
+  writers: int = 0
+  mjson: List[Dict] = []
+  actors: Dict[str, int] = {}
+  subactors: Dict[Tuple[str, str], int] = {}
+  path = getenv("PROFILE_OUTPUT_FILE", temp("tinygrad_profile.json"))
+
+  def __init__(self): self.events, ProfileLogger.writers = [], ProfileLogger.writers + 1
+
+  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
+
+  def __del__(self):
+    for name,st,et,actor_name,subactor_name in self.events:
+      if actor_name not in self.actors:
+        self.actors[actor_name] = (pid:=len(self.actors))
+        self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
+
+      if (subactor_key:=(actor_name,subactor_name)) not in self.subactors:
+        self.subactors[subactor_key] = (tid:=len(self.subactors))
+        self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
+
+      self.mjson.append({"name": name, "ph": "X", "pid": self.actors[actor_name], "tid": self.subactors.get(subactor_key, -1), "ts":st, "dur":et-st})
+
+    ProfileLogger.writers -= 1
+    if ProfileLogger.writers == 0:
+      with open(self.path, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
+      print(f"Saved profile to {self.path}. Use https://ui.perfetto.dev/ to open it.")
 
 # *** universal database cache ***
 
```
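`ProfileLogger` buffers `(name, start, end, actor, subactor)` tuples and, when the last writer is deleted, dumps them in the Chrome trace-event format (`"ph": "X"` complete events plus `"M"` process/thread-name metadata), which is why the saved file opens directly in Perfetto. A usage sketch with made-up event values:

```python
from tinygrad.helpers import ProfileLogger

pl = ProfileLogger()
pl.add_event("kernel_launch", 0, 125, "GPU 0", "compute")  # timestamps in us
del pl  # last writer flushes {"traceEvents": [...]} to PROFILE_OUTPUT_FILE
```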
@@ -107,7 +178,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
|
|
107
178
|
CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
|
108
179
|
CACHELEVEL = getenv("CACHELEVEL", 2)
|
109
180
|
|
110
|
-
VERSION =
|
181
|
+
VERSION = 16
|
111
182
|
_db_connection = None
|
112
183
|
def db_connection():
|
113
184
|
global _db_connection
|
```diff
@@ -117,13 +188,18 @@ def db_connection():
   if DEBUG >= 7: _db_connection.set_trace_callback(print)
   return _db_connection
 
+def diskcache_clear():
+  cur = db_connection().cursor()
+  drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
+  cur.executescript("\n".join([s[0] for s in drop_tables]))
+
 def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
   if CACHELEVEL == 0: return None
   if isinstance(key, (str,int)): key = {"key": key}
   conn = db_connection()
   cur = conn.cursor()
   try:
-    res = cur.execute(f"SELECT val FROM {table}_{VERSION} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
+    res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
   except sqlite3.OperationalError:
     return None # table doesn't exist
   if (val:=res.fetchone()) is not None: return pickle.loads(val[0])
```
```diff
@@ -138,27 +214,37 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
   if table not in _db_tables:
     TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
     ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
-    cur.execute(f"CREATE TABLE IF NOT EXISTS {table}_{VERSION} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
+    cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
     _db_tables.add(table)
-  cur.execute(f"REPLACE INTO {table}_{VERSION} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
+  cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
   conn.commit()
   cur.close()
   return val
 
+def diskcache(func):
+  def wrapper(*args, **kwargs) -> bytes:
+    table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
+    if (ret:=diskcache_get(table, key)): return ret
+    return diskcache_put(table, key, func(*args, **kwargs))
+  return wrapper
+
 # *** http support ***
 
-def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
+def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None,
+          allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
   if url.startswith(("/", ".")): return pathlib.Path(url)
-  fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest()) # noqa: E501
+  if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
+  else: fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / (name or hashlib.md5(url.encode('utf-8')).hexdigest())
   if not fp.is_file() or not allow_caching:
-    with request.urlopen(url, timeout=10) as r:
+    with urllib.request.urlopen(url, timeout=10) as r:
       assert r.status == 200
       total_length = int(r.headers.get('content-length', 0))
-      progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=url)
+      progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}: ", disable=CI)
       (path := fp.parent).mkdir(parents=True, exist_ok=True)
       with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
         while chunk := r.read(16384): progress_bar.update(f.write(chunk))
         f.close()
+        progress_bar.update(close=True)
       if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
       pathlib.Path(f.name).rename(fp)
   return fp
```
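The new `@diskcache` decorator composes the two primitives above: the return value is pickled into the sqlite cache, keyed by a sha256 of the pickled arguments, so repeat calls with equal arguments skip the work. A sketch, where `expensive_compile` stands in for any slow pure function:

```python
from tinygrad.helpers import diskcache

@diskcache
def expensive_compile(src: str) -> bytes:
    return src.encode()  # stand-in for slow work; runs only on a cache miss

out = expensive_compile("kernel source")  # first call computes and stores
out = expensive_compile("kernel source")  # second call is a sqlite lookup
```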
```diff
@@ -170,11 +256,18 @@ def cpu_time_execution(cb, enable):
   cb()
   if enable: return time.perf_counter()-st
 
+def cpu_objdump(lib):
+  with tempfile.NamedTemporaryFile(delete=True) as f:
+    pathlib.Path(f.name).write_bytes(lib)
+    print(subprocess.check_output(['objdump', '-d', f.name]).decode('utf-8'))
+
 # *** ctypes helpers
 
 # TODO: make this work with read only memoryviews (if possible)
-def from_mv(mv:memoryview, to_type=ctypes.c_char): return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
+def from_mv(mv:memoryview, to_type=ctypes.c_char):
+  return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
 def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
+def mv_address(mv:memoryview): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
 def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
 @functools.lru_cache(maxsize=None)
 def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
```
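`from_mv`, `to_mv`, and the new `mv_address` convert between memoryviews and raw pointers; the low-level runtimes added in this release (e.g. `ops_amd.py`, `ops_nv.py`) lean on such helpers for mapped buffers. A self-contained round-trip using the same ctypes calls:

```python
import ctypes

buf = bytearray(b"tinygrad")
mv = memoryview(buf)
addr = ctypes.addressof(ctypes.c_char.from_buffer(mv))  # what mv_address(mv) returns
back = memoryview(ctypes.cast(addr, ctypes.POINTER(ctypes.c_uint8 * len(mv))).contents).cast("B")  # to_mv(addr, sz)
assert bytes(back) == b"tinygrad"
```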
```diff
@@ -182,31 +275,36 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
   _pack_, _fields_ = 1, fields
   return CStruct
 def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
-def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1] # noqa: E501
-def flat_mv(mv:memoryview):
-  if len(mv) == 0: return mv
-  return mv.cast("B", shape=(mv.nbytes,))
-
-# *** Helpers for CUDA-like APIs.
-
-def compile_cuda_style(prg, compile_options, prog_t, create_prog, compile_prog, get_code, get_code_size, get_log, get_log_size, check) -> bytes:
-  check(create_prog(ctypes.byref(prog := prog_t()), prg.encode(), "<null>".encode(), 0, None, None))
-  status = compile_prog(prog, len(compile_options), to_char_p_p([o.encode() for o in compile_options]))
+def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
 
-  if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, get_log_size, get_log, check).decode()}")
-  return get_bytes(prog, get_code_size, get_code, check)
+class tqdm:
+  def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:int=-1, rate:int=100):
+    self.iter, self.desc, self.dis, self.unit, self.unit_scale, self.rate = iterable, f"{desc}: " if desc else "", disable, unit, unit_scale, rate
+    self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, len(iterable) if total==-1 else total
+    self.update(0)
+  def __iter__(self):
+    try:
+      for item in self.iter:
+        yield item
+        self.update(1)
+    finally: self.update(close=True)
+  def set_description(self, desc:str): self.desc = f"{desc}: " if desc else ""
+  def update(self, n:int=0, close:bool=False):
+    self.n, self.i = self.n+n, self.i+1
+    if (self.i % self.skip != 0 and not close) or self.dis: return
+    prog, dur, term = self.n/self.t if self.t else -1, time.perf_counter()-self.st, shutil.get_terminal_size().columns
+    if self.i/dur > self.rate and self.i: self.skip = max(int(self.i/dur)//self.rate,1) if self.i else 1
+    def fmt(t): return ':'.join([f'{x:02d}' for x in divmod(int(t), 60)]) if t!=-1 else '?'
+    def scl(x): return x/1000**int(math.log(x,1000))
+    def fn(x): return (f"{scl(x):.{3-math.ceil(math.log10(scl(x)))}f}"[:4]+(f"{[' ','k','M','G','T','P'][int(math.log(x,1000))]}") if x else '0.00')
+    if self.t: unit_text = f"{fn(self.n)}/{fn(self.t)}" if self.unit_scale else f"{self.n}/{self.t}"
+    else: unit_text = f"{fn(self.n)}{self.unit}" if self.unit_scale else f"{self.n}{self.unit}"
+    it_text = f"{fn(self.n/dur)}" if self.n and self.unit_scale else f"{self.n/dur:5.2f}" if self.n else "?"
+    if self.t: suf = f'| {unit_text} [{fmt(dur)}<{fmt(dur/self.n*self.t-dur if self.n else -1)}, {it_text}{self.unit}/s]'
+    else: suf = f'{unit_text} [{fmt(dur)}, {it_text}{self.unit}/s]'
+    sz = max(term-5-len(suf)-len(self.desc), 1)
+    bar = f'\r{self.desc}{round(100*prog):3}%|{"█"*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}'
+    print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr)
 
-def encode_args_cuda_style(bufs, vals, device_ptr_t, marks) -> Tuple[ctypes.Array, ctypes.Structure]:
-  c_args = init_c_struct_t(tuple([(f'f{i}', device_ptr_t) for i in range(len(bufs))] +
-                                 [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))(*bufs, *vals)
-  return (ctypes.c_void_p * 5)(ctypes.c_void_p(marks[0]), ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), ctypes.c_void_p(marks[1]), ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(marks[2])), c_args # noqa: E501
-
-def time_execution_cuda_style(cb, ev_t, evcreate, evrecord, evsync, evdestroy, evtime, enable=False) -> Optional[float]:
-  if not enable: return cb()
-  evs = [init_c_var(ev_t(), lambda x: evcreate(ctypes.byref(x), 0)) for _ in range(2)]
-  evrecord(evs[0], None)
-  cb()
-  evrecord(evs[1], None)
-  evsync(evs[1])
-  evtime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1])
-  for ev in evs: evdestroy(ev)
-  return ret.value * 1e-3
+class trange(tqdm):
+  def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
```