tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/{mlops.py → function.py}
RENAMED

```diff
@@ -1,31 +1,48 @@
+"""This is where the forwards and backwards passes live."""
 import math
 from typing import Tuple, Optional
-from tinygrad.helpers import argsort
+from tinygrad.helpers import argsort
+from tinygrad.dtype import dtypes, DType, sum_acc_dtype
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
+from tinygrad.shape.symbolic import sint

 class Contiguous(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output

+class ContiguousBackward(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer: return x
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous()
+
 class Cast(Function):
   def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
     self.input_dtype, self.bitcast = x.dtype, bitcast
-    return x.
+    return x.cast(dtype, bitcast)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.e(UnaryOps.CAST, arg=(self.input_dtype, self.bitcast))
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype, self.bitcast)

 # ************* unary ops *************

+class Neg(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG)
+
+class Reciprocal(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer:
+    self.ret = x.const(1).e(BinaryOps.DIV, x)
+    return self.ret
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret)
+
 class Sin(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
     return x.e(UnaryOps.SIN)

-  def backward(self,
-    return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL,
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)

 # NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):
```
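Two of the gradients added above are easy to sanity-check outside tinygrad. `Sin.backward` leans on the identity sin(π/2 − x) = cos(x), since `UnaryOps.SIN` is the only trig primitive the backends expose, and `Reciprocal.backward` uses d(1/x)/dx = −(1/x)² = −ret·ret. A minimal plain-Python check of both identities (not tinygrad code):

```python
import math

# Sin.backward: cos(x) written with only a sine primitive
x = 0.7
assert abs(math.sin(math.pi / 2 - x) - math.cos(x)) < 1e-12

# Reciprocal.backward: d(1/x)/dx is -(1/x)**2, i.e. -ret*ret
ret = 1 / x
assert abs(-(ret * ret) - (-1 / x ** 2)) < 1e-12
```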
```diff
@@ -34,23 +51,21 @@ class Relu(Function):
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output)
+    return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output)

 class Log(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
     return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.e(BinaryOps.DIV, self.x)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.DIV, self.x)

 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
     return self.ret

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.e(BinaryOps.MUL, grad_output)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output)

 class Sqrt(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
```
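The behavioral change in this hunk is in `Relu.backward`: the `CMPLT` mask is now a boolean under 0.9.0's stricter dtype rules, so it is cast to the gradient's dtype before the multiply. A plain-Python sketch of the same computation (lists standing in for LazyBuffers):

```python
# mask = (0 < ret), cast to float, then multiplied into the incoming gradient
def relu_grad(ret, grad_output):
    return [float(0 < r) * g for r, g in zip(ret, grad_output)]

assert relu_grad([0.0, 2.5], [1.0, 1.0]) == [0.0, 1.0]  # no gradient where relu output was 0
```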
```diff
@@ -71,48 +86,39 @@ class Sigmoid(Function):
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output)

-
-
-
-
-
-
-
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.expand(self.input_shape)
-
-class Max(Function):
-  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    self.x, self.ret = x, x.reduce_op(ReduceOps.MAX, new_shape)
-    return self.ret
-
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
-    div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
-    return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
+class Sign(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer:
+    return x.e(BinaryOps.CMPEQ, x.const(0)).e(TernaryOps.WHERE, x.const(0),
+                                              x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)))
+  # backward always return 0 to match torch
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)

 # ************* binary ops *************

 class Less(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
+  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
+
+class Eq(Function):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPEQ, y)
+  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
+
+class Xor(Function):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y)

 class Add(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.ADD, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.ADD, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
            grad_output if self.needs_input_grad[1] else None

 class Sub(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    return x.e(BinaryOps.SUB, y)
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.SUB, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output if self.needs_input_grad[0] else None, \
-           grad_output.
+           grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None

 class Mul(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
```
```diff
@@ -130,67 +136,82 @@ class Div(Function):

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
     return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
-           grad_output.
+           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None  # noqa: E501

 # ************* ternary ops *************

 class Where(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.e(TernaryOps.WHERE, y, z)
+    return self.x.e(TernaryOps.WHERE, y, z)

   def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
     return None, \
-
-
+      self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
+      self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None
+
+# ************* reduce ops *************
+
+class Sum(Function):
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
+    self.input_shape = x.shape
+    return x.r(ReduceOps.SUM, axis)
+
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
+
+class Max(Function):
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
+    self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
+    return self.ret
+
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    # 1s in locations where the max was chosen (can be two locations)
+    max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(dtypes.float)
+    div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
+    return max_is_1s.e(BinaryOps.DIV, div).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))

 # ************* movement ops *************

 # NOTE: this is sum in reverse
 class Expand(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
-    self.
+    self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
     return x.expand(shape)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.
+    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)

 class Reshape(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
     return x.reshape(shape)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.reshape(self.input_shape)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reshape(self.input_shape)

 class Permute(Function):
   def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
     self.input_order = order
     return x.permute(order)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.permute(argsort(self.input_order))
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.permute(argsort(self.input_order))

 class Pad(Function):
   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
     return x.pad(arg)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.shrink(self.narg)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.shrink(self.narg)

 class Shrink(Function):
-  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[
+  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
     return x.shrink(arg)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.pad(self.narg)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.pad(self.narg)

 class Flip(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
     self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
     return x.stride(self.arg)

-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.stride(self.arg)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)
```
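The new `Max.backward` splits the incoming gradient evenly across ties: the `CMPEQ` mask marks every position equal to the max, and dividing the mask by its reduce-sum normalizes it. A plain-Python sketch of that tie handling along a single axis:

```python
xs, g = [1.0, 3.0, 3.0], 1.0           # forward input and incoming gradient
m = max(xs)                            # what x.r(ReduceOps.MAX, axis) produced
mask = [float(v == m) for v in xs]     # CMPEQ mask, cast to float
div = sum(mask)                        # reduce-sum of the mask (2 ties here)
assert [mk / div * g for mk in mask] == [0.0, 0.5, 0.5]
```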
tinygrad/helpers.py
CHANGED
```diff
@@ -1,35 +1,75 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib
-import
-from
-from
+import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes
+import itertools, urllib.request, subprocess
+from tqdm import tqdm
+from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
+if TYPE_CHECKING:  # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
+  from typing_extensions import TypeGuard
+  from tinygrad.shape.shapetracker import sint
+
+T = TypeVar("T")
+U = TypeVar("U")
+# NOTE: it returns int 1 if x is empty regardless of the type of x
+def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x, 1)

 # NOTE: helpers is not allowed to import from anything else in tinygrad
 OSX = platform.system() == "Darwin"
 CI = os.getenv("CI", "") != ""

-def dedup(x): return list(dict.fromkeys(x)) # retains list order
-def argfix(*x):
+def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order
+def argfix(*x):
+  if x and x[0].__class__ in (tuple, list):
+    if len(x) != 1: raise ValueError(f"bad arg {x}")
+    return tuple(x[0])
+  return x
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
-def all_same(items): return all(x == items[0] for x in items)
-def
-def
+def all_same(items:List[T]): return all(x == items[0] for x in items)
+def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
+def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st  # replace the termcolor library with one line # noqa: E501
+def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
+def ansilen(s:str): return len(ansistrip(s))
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
-def flatten(l:
-def
+def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
+def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(sublist) if isinstance(sublist, (tuple, list)) else [sublist])]
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
-def
-
-
-
-
-
-
+def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
+def round_up(num, amt:int): return (num+amt-1)//amt * amt
+def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
+  assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"  # noqa: E501
+  return {k:v for d in ds for k,v in d.items()}
+def partition(lst:List[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
+  a:List[T] = []
+  b:List[T] = []
   for s in lst: (a if fxn(s) else b).append(s)
   return a,b
+def unwrap(x:Optional[T]) -> T:
+  assert x is not None
+  return x
+def unwrap2(x:Tuple[T,Any]) -> T:
+  ret, err = x
+  assert err is None, str(err)
+  return ret
+def get_child(obj, key):
+  for k in key.split('.'):
+    if k.isnumeric(): obj = obj[int(k)]
+    elif isinstance(obj, dict): obj = obj[k]
+    else: obj = getattr(obj, k)
+  return obj
+
+# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
+def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
+  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
+  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
+  except ValueError: return None
+  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]

 @functools.lru_cache(maxsize=None)
-def
+def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
+@functools.lru_cache(maxsize=None)
+def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
+def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
+
+class GraphException(Exception): pass

 class Context(contextlib.ContextDecorator):
   stack: ClassVar[List[dict[str, int]]] = [{}]
```
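Of the new helpers, `get_contraction` is the least obvious: it compares running products of the two shapes and reports which old axes merge to form each new axis, or `None` when no grouping works. Expected behavior, assuming this wheel's `tinygrad.helpers` is on the path:

```python
from tinygrad.helpers import get_contraction

# axes 0 and 1 of (2, 3, 4) fuse into the 6 of (6, 4); axis 2 carries over
assert get_contraction((2, 3, 4), (6, 4)) == [[0, 1], [2]]
# 5 is not a running product of (2, 3, 4), so there is no valid contraction
assert get_contraction((2, 3, 4), (5, 4)) is None
```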
```diff
@@ -44,83 +84,23 @@ class Context(contextlib.ContextDecorator):
 class ContextVar:
   _cache: ClassVar[Dict[str, ContextVar]] = {}
   value: int
+  key: str
   def __new__(cls, key, default_value):
     if key in ContextVar._cache: return ContextVar._cache[key]
     instance = ContextVar._cache[key] = super().__new__(cls)
-    instance.value = getenv(key, default_value)
+    instance.value, instance.key = getenv(key, default_value), key
     return instance
   def __bool__(self): return bool(self.value)
   def __ge__(self, x): return self.value >= x
   def __gt__(self, x): return self.value > x
   def __lt__(self, x): return self.value < x

-DEBUG, IMAGE = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0)
-
+DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
+WINO, THREEFRY, CACHECOLLECTING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CACHECOLLECTING", 1)
+GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
+MULTIOUTPUT = ContextVar("MULTIOUTPUT", 1)

-
-  def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
-  def __enter__(self): self.st = time.perf_counter_ns()
-  def __exit__(self, exc_type, exc_val, exc_tb):
-    self.et = time.perf_counter_ns() - self.st
-    if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
-
-# **** tinygrad now supports dtypes! *****
-
-class DType(NamedTuple):
-  priority: int  # this determines when things get upcasted
-  itemsize: int
-  name: str
-  np: Optional[type]  # TODO: someday this will be removed with the "remove numpy" project
-  sz: int = 1
-  def __repr__(self): return f"dtypes.{self.name}"
-
-# dependent typing?
-class ImageDType(DType):
-  def __new__(cls, priority, itemsize, name, np, shape):
-    return super().__new__(cls, priority, itemsize, name, np)
-  def __init__(self, priority, itemsize, name, np, shape):
-    self.shape: Tuple[int, ...] = shape  # arbitrary arg for the dtype, used in image for the shape
-    super().__init__()
-  def __repr__(self): return f"dtypes.{self.name}({self.shape})"
-
-class dtypes:
-  @staticmethod  # static methds on top, or bool in the type info will refer to dtypes.bool
-  def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
-  @staticmethod
-  def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float2, dtypes._float4)
-  @staticmethod
-  def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
-  @staticmethod
-  def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
-  @staticmethod
-  def fields() -> Dict[str, DType]: return DTYPES_DICT
-  bool: Final[DType] = DType(0, 1, "bool", np.bool_)
-  float16: Final[DType] = DType(0, 2, "half", np.float16)
-  half = float16
-  float32: Final[DType] = DType(4, 4, "float", np.float32)
-  float = float32
-  float64: Final[DType] = DType(0, 8, "double", np.float64)
-  double = float64
-  int8: Final[DType] = DType(0, 1, "char", np.int8)
-  int16: Final[DType] = DType(1, 2, "short", np.int16)
-  int32: Final[DType] = DType(2, 4, "int", np.int32)
-  int64: Final[DType] = DType(3, 8, "long", np.int64)
-  uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
-  uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
-  uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
-  uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)
-
-  # NOTE: bfloat16 isn't supported in numpy
-  bfloat16: Final[DType] = DType(0, 2, "__bf16", None)
-
-  # NOTE: these are internal dtypes, should probably check for that
-  _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
-  _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
-  _float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
-  _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
-
-  # HACK: staticmethods are not callable in 3.8 so we have to compare the class
-  DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
+# **************** global state Counters ****************

 class GlobalCounters:
   global_ops: ClassVar[int] = 0
```
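`ContextVar` instances are cached per key by `__new__` and read their default from the environment via `getenv`, so flags like `DEBUG` can be compared directly at call sites. A short usage sketch (`VERBOSE` is a made-up flag for illustration, not one of the flags defined above):

```python
from tinygrad.helpers import DEBUG, ContextVar

if DEBUG >= 2: print("debug-only path")      # ContextVar.__ge__ compares the int value

VERBOSE = ContextVar("VERBOSE", 0)           # hypothetical flag; default comes from $VERBOSE
assert ContextVar("VERBOSE", 5) is VERBOSE   # __new__ returns the cached instance per key
```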
```diff
@@ -128,7 +108,134 @@ class GlobalCounters:
   time_sum_s: ClassVar[float] = 0.0
   kernel_count: ClassVar[int] = 0
   mem_used: ClassVar[int] = 0 # NOTE: this is not reset
-  mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
-  cache: ClassVar[Optional[List[Tuple[Callable, Any, Dict[Any, int]]]]] = None # List[Tuple[Callable, List[RawBuffer], Dict[Variable, int]]]
   @staticmethod
-  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count
+  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
+
+# **************** timer and profiler ****************
+
+class Timing(contextlib.ContextDecorator):
+  def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
+  def __enter__(self): self.st = time.perf_counter_ns()
+  def __exit__(self, *exc):
+    self.et = time.perf_counter_ns() - self.st
+    if self.enabled: print(f"{self.prefix}{self.et*1e-6:6.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
+
+def _format_fcn(fcn): return f"{fcn[0]}:{fcn[1]}:{fcn[2]}"
+class Profiling(contextlib.ContextDecorator):
+  def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None, ts=1):
+    self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3/ts
+  def __enter__(self):
+    self.pr = cProfile.Profile()
+    if self.enabled: self.pr.enable()
+  def __exit__(self, *exc):
+    if self.enabled:
+      self.pr.disable()
+      if self.fn: self.pr.dump_stats(self.fn)
+      stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
+      for fcn in stats.fcn_list[0:int(len(stats.fcn_list)*self.frac)]: # type: ignore[attr-defined]
+        (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn] # type: ignore[attr-defined]
+        scallers = sorted(callers.items(), key=lambda x: -x[1][2])
+        print(f"n:{num_calls:8d} tm:{tottime*self.time_scale:7.2f}ms tot:{cumtime*self.time_scale:7.2f}ms",
+              colored(_format_fcn(fcn), "yellow") + " "*(50-len(_format_fcn(fcn))),
+              colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else '')
+
+# *** universal database cache ***
+
+_cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))
+CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
+CACHELEVEL = getenv("CACHELEVEL", 2)
+
+VERSION = 16
+_db_connection = None
+def db_connection():
+  global _db_connection
+  if _db_connection is None:
+    os.makedirs(CACHEDB.rsplit(os.sep, 1)[0], exist_ok=True)
+    _db_connection = sqlite3.connect(CACHEDB)
+    if DEBUG >= 7: _db_connection.set_trace_callback(print)
+  return _db_connection
+
+def diskcache_clear():
+  cur = db_connection().cursor()
+  drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
+  cur.executescript("\n".join([s[0] for s in drop_tables]))
+
+def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
+  if CACHELEVEL == 0: return None
+  if isinstance(key, (str,int)): key = {"key": key}
+  conn = db_connection()
+  cur = conn.cursor()
+  try:
+    res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
+  except sqlite3.OperationalError:
+    return None # table doesn't exist
+  if (val:=res.fetchone()) is not None: return pickle.loads(val[0])
+  return None
+
+_db_tables = set()
+def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
+  if CACHELEVEL == 0: return val
+  if isinstance(key, (str,int)): key = {"key": key}
+  conn = db_connection()
+  cur = conn.cursor()
+  if table not in _db_tables:
+    TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
+    ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
+    cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
+    _db_tables.add(table)
+  cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
+  conn.commit()
+  cur.close()
+  return val
+
+def diskcache(func):
+  def wrapper(*args, **kwargs) -> bytes:
+    table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
+    if (ret:=diskcache_get(table, key)): return ret
+    return diskcache_put(table, key, func(*args, **kwargs))
+  return wrapper
+
+# *** http support ***
+
+def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
+  if url.startswith(("/", ".")): return pathlib.Path(url)
+  fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest()) # noqa: E501
+  if not fp.is_file() or not allow_caching:
+    with urllib.request.urlopen(url, timeout=10) as r:
+      assert r.status == 200
+      total_length = int(r.headers.get('content-length', 0))
+      progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=url)
+      (path := fp.parent).mkdir(parents=True, exist_ok=True)
+      with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
+        while chunk := r.read(16384): progress_bar.update(f.write(chunk))
+        f.close()
+        if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
+        pathlib.Path(f.name).rename(fp)
+  return fp
+
+# *** Exec helpers
+
+def cpu_time_execution(cb, enable):
+  if enable: st = time.perf_counter()
+  cb()
+  if enable: return time.perf_counter()-st
+
+def cpu_objdump(lib):
+  with tempfile.NamedTemporaryFile(delete=True) as f:
+    pathlib.Path(f.name).write_bytes(lib)
+    print(subprocess.check_output(['objdump', '-d', f.name]).decode('utf-8'))
+
+# *** ctypes helpers
+
+# TODO: make this work with read only memoryviews (if possible)
+def from_mv(mv:memoryview, to_type=ctypes.c_char):
+  return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
+def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
+def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
+@functools.lru_cache(maxsize=None)
+def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
+  class CStruct(ctypes.Structure):
+    _pack_, _fields_ = 1, fields
+  return CStruct
+def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
+def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
```