tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/{mlops.py → function.py} CHANGED
@@ -1,7 +1,8 @@
+"""This is where the forwards and backwards passes live."""
 import math
 from typing import Tuple, Optional
 from tinygrad.helpers import argsort
-from tinygrad.dtype import DType
+from tinygrad.dtype import dtypes, DType, sum_acc_dtype
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
@@ -24,21 +25,24 @@ class Cast(Function):
 
 # ************* unary ops *************
 
-class Zero(Function):
-  def forward(self, x:LazyBuffer) -> LazyBuffer: return x.const(0)
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)
-
 class Neg(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG)
 
+class Reciprocal(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer:
+    self.ret = x.e(UnaryOps.RECIP)
+    return self.ret
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret)
+
 class Sin(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
     return x.e(UnaryOps.SIN)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
+    return self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
 
 # NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):
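Worth flagging for readers of this hunk: 0.9.x removes BinaryOps.SUB from the primitive set, so Sin.backward now builds cos(x) as sin(π/2 + (−x)) from ADD and NEG rather than SUB. The identity behind the rewrite can be checked in plain Python (a minimal sketch, no tinygrad needed):

    import math
    # cos(x) == sin(pi/2 - x) == sin(pi/2 + (-x)), which is what ADD + NEG computes
    for x in [0.0, 0.5, -1.3, 2.7]:
      assert abs(math.cos(x) - math.sin(math.pi / 2 + (-x))) < 1e-12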
@@ -54,7 +58,7 @@ class Log(Function):
     self.x = x
     return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.DIV, self.x)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP))
 
 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
@@ -69,26 +73,35 @@ class Sqrt(Function):
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.e(BinaryOps.DIV, self.ret.e(BinaryOps.MUL, self.ret.const(2)))
+    return grad_output.e(BinaryOps.MUL, self.ret.e(BinaryOps.MUL, self.ret.const(2)).e(UnaryOps.RECIP))
 
 # NOTE: the implicit derivative of sigmoid is not stable
 # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
+    self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP)
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output)
+    return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.ADD, self.ret.e(UnaryOps.NEG))).e(BinaryOps.MUL, grad_output)
+
+class Sign(Function):
+  def forward(self, x:LazyBuffer) -> LazyBuffer:
+    return x.e(BinaryOps.CMPNE, x.const(0)).e(
+      TernaryOps.WHERE, x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)), x.const(0))
+  # backward always return 0 to match torch
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)
 
 # ************* binary ops *************
 
 class Less(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
+  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
 
-class Eq(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPEQ, y)
+class Neq(Function):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPNE, y)
+  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
 
 class Xor(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y)
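Sigmoid keeps its numerically stable EXP2 formulation and only swaps the outer DIV for a RECIP: the forward computes 1/(1 + 2^(−x/ln 2)), which is exactly 1/(1 + e^(−x)). A plain-Python check of that equivalence (a sketch, independent of tinygrad):

    import math
    # 2**(-x/ln 2) == e**(-x), so the EXP2 form is the ordinary sigmoid
    def sigmoid_exp2(x): return 1 / (1 + 2 ** (x * (-1 / math.log(2))))
    for x in [-4.0, 0.0, 0.25, 3.0]:
      assert abs(sigmoid_exp2(x) - 1 / (1 + math.exp(-x))) < 1e-12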
@@ -100,13 +113,6 @@ class Add(Function):
     return grad_output if self.needs_input_grad[0] else None, \
            grad_output if self.needs_input_grad[1] else None
 
-class Sub(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.SUB, y)
-
-  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output if self.needs_input_grad[0] else None, \
-           grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None
-
 class Mul(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
@@ -119,11 +125,11 @@ class Mul(Function):
 class Div(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
-    return x.e(BinaryOps.DIV, y)
+    return x.e(BinaryOps.MUL, y.e(UnaryOps.RECIP)) if not dtypes.is_int(x.dtype) else x.e(BinaryOps.IDIV, y)
 
   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
-           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # noqa: E501
+    return grad_output.e(BinaryOps.MUL, self.y.e(UnaryOps.RECIP)) if self.needs_input_grad[0] else None, \
+           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.MUL, self.y.e(BinaryOps.MUL, self.y).e(UnaryOps.RECIP)) if self.needs_input_grad[1] else None # noqa: E501
 
 # ************* ternary ops *************
 
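With BinaryOps.DIV gone as well, Div.forward dispatches on dtype: float tensors lower to MUL against a RECIP, while integer tensors use the new IDIV op. The two paths, sketched in plain Python:

    x, y = 7.0, 2.0
    assert x * (1 / y) == x / y  # float path: MUL + RECIP
    assert 7 // 2 == 3           # int path: IDIV (matches // for positive operands)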
@@ -140,32 +146,34 @@ class Where(Function):
 # ************* reduce ops *************
 
 class Sum(Function):
-  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.r(ReduceOps.SUM, new_shape)
+    return x.r(ReduceOps.SUM, axis)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
 
 class Max(Function):
-  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
+    self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(self.x.dtype)
-    div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
-    return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
+    max_is_1s = self.x.const(1.0).cast(dtypes.float).e(BinaryOps.ADD, self.x.e(BinaryOps.CMPNE, \
+      self.ret.expand(self.x.shape)).cast(dtypes.float).e(UnaryOps.NEG))
+    div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
+    return max_is_1s.e(BinaryOps.MUL, div.e(UnaryOps.RECIP)).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
 
 # ************* movement ops *************
 
 # NOTE: this is sum in reverse
 class Expand(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
-    self.input_shape = x.shape
+    self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
     return x.expand(shape)
 
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.r(ReduceOps.SUM, self.input_shape)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)
 
 class Reshape(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
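Two API shifts land in this hunk: r() now takes the axes to reduce instead of the target shape (Sum and Max both switch from new_shape to axis), and Expand records which axes it broadcast so backward can sum over exactly those, accumulating in sum_acc_dtype before casting back to the gradient's dtype. Illustrated at the Tensor level (assuming the usual front end; the LazyBuffer calls in the comments are what it lowers to):

    from tinygrad import Tensor
    t = Tensor([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)
    print(t.sum(axis=1).numpy())        # [ 6 15]
    # 0.8.0 lowered this to x.r(ReduceOps.SUM, (2, 1))  -- pass the result shape
    # 0.9.1 lowers it to    x.r(ReduceOps.SUM, (1,))    -- pass the reduced axes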
@@ -197,7 +205,7 @@ class Shrink(Function):
 
 class Flip(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
-    self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
+    self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))])
     return x.stride(self.arg)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)
tinygrad/helpers.py CHANGED
@@ -1,22 +1,26 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes
-from urllib import request # NOTE: this has to be imported specifically
-from tqdm import tqdm
+import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes, sys
+import itertools, urllib.request, subprocess, shutil, math, json
 from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
 if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
   from typing_extensions import TypeGuard
+  from tinygrad.shape.shapetracker import sint
 
 T = TypeVar("T")
 U = TypeVar("U")
 # NOTE: it returns int 1 if x is empty regardless of the type of x
-def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.__mul__, x, 1)
+def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x, 1)
 
 # NOTE: helpers is not allowed to import from anything else in tinygrad
 OSX = platform.system() == "Darwin"
 CI = os.getenv("CI", "") != ""
 
 def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order
-def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
+def argfix(*x):
+  if x and x[0].__class__ in (tuple, list):
+    if len(x) != 1: raise ValueError(f"bad arg {x}")
+    return tuple(x[0])
+  return x
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
 def all_same(items:List[T]): return all(x == items[0] for x in items)
 def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
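argfix is the helper behind APIs that accept both f(1, 2, 3) and f((1, 2, 3)). The 0.8.0 one-liner returned tuple(x[0]) whenever the first argument was a sequence, silently discarding any extra positionals; the rewrite turns that into an error:

    from tinygrad.helpers import argfix
    assert argfix(1, 2, 3) == (1, 2, 3)
    assert argfix((1, 2, 3)) == (1, 2, 3)
    # argfix((1, 2), 3) now raises ValueError("bad arg ((1, 2), 3)"),
    # where 0.8.0 silently returned (1, 2)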
@@ -51,12 +55,27 @@ def get_child(obj, key):
   else: obj = getattr(obj, k)
   return obj
 
+def get_shape(x) -> Tuple[int, ...]:
+  if not isinstance(x, (list, tuple)): return ()
+  subs = [get_shape(xi) for xi in x]
+  if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
+  return (len(subs),) + (subs[0] if subs else ())
+
+# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
+def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
+  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
+  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
+  except ValueError: return None
+  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
+
 @functools.lru_cache(maxsize=None)
 def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
 @functools.lru_cache(maxsize=None)
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
 def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
 
+class GraphException(Exception): pass
+
 class Context(contextlib.ContextDecorator):
   stack: ClassVar[List[dict[str, int]]] = [{}]
   def __init__(self, **kwargs): self.kwargs = kwargs
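Two new shape utilities arrive in helpers: get_shape infers a shape from nested lists and rejects ragged nesting, and get_contraction reports, for each axis of new_shape, which contiguous run of old_shape axes multiplies out to it (or None when no such grouping exists):

    from tinygrad.helpers import get_shape, get_contraction
    assert get_shape([[1, 2, 3], [4, 5, 6]]) == (2, 3)
    assert get_contraction((2, 3, 4), (6, 4)) == [[0, 1], [2]]  # 2*3 -> 6, 4 -> 4
    assert get_contraction((2, 3, 4), (5, 4)) is None           # no prefix multiplies to 5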
@@ -70,18 +89,34 @@ class Context(contextlib.ContextDecorator):
 class ContextVar:
   _cache: ClassVar[Dict[str, ContextVar]] = {}
   value: int
+  key: str
   def __new__(cls, key, default_value):
     if key in ContextVar._cache: return ContextVar._cache[key]
     instance = ContextVar._cache[key] = super().__new__(cls)
-    instance.value = getenv(key, default_value)
+    instance.value, instance.key = getenv(key, default_value), key
     return instance
   def __bool__(self): return bool(self.value)
   def __ge__(self, x): return self.value >= x
   def __gt__(self, x): return self.value > x
   def __lt__(self, x): return self.value < x
 
-DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
-GRAPH, GRAPHPATH = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
+DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
+WINO, THREEFRY, CAPTURING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1)
+GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
+MULTIOUTPUT, PROFILE = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0)
+
+# **************** global state Counters ****************
+
+class GlobalCounters:
+  global_ops: ClassVar[int] = 0
+  global_mem: ClassVar[int] = 0
+  time_sum_s: ClassVar[float] = 0.0
+  kernel_count: ClassVar[int] = 0
+  mem_used: ClassVar[int] = 0 # NOTE: this is not reset
+  @staticmethod
+  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
+
+# **************** timer and profiler ****************
 
 class Timing(contextlib.ContextDecorator):
   def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
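ContextVar now remembers its key, and the set of process-wide flags grows considerably (JIT, WINO, THREEFRY, PROFILE, ...). Each is seeded from the matching environment variable; per long-standing tinygrad behavior (not visible in this hunk), Context temporarily overrides them for a scope:

    from tinygrad.helpers import Context, DEBUG
    print(DEBUG.value)      # seeded from the DEBUG=... env var, default 0
    with Context(DEBUG=2):
      print(DEBUG.value)    # 2 inside the block
    print(DEBUG.value)      # restored on exit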
@@ -90,16 +125,52 @@ class Timing(contextlib.ContextDecorator):
     self.et = time.perf_counter_ns() - self.st
     if self.enabled: print(f"{self.prefix}{self.et*1e-6:6.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
 
+def _format_fcn(fcn): return f"{fcn[0]}:{fcn[1]}:{fcn[2]}"
 class Profiling(contextlib.ContextDecorator):
-  def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None): self.enabled, self.sort, self.frac, self.fn = enabled, sort, frac, fn
+  def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None, ts=1):
+    self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3/ts
   def __enter__(self):
-    self.pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6)
+    self.pr = cProfile.Profile()
     if self.enabled: self.pr.enable()
   def __exit__(self, *exc):
     if self.enabled:
       self.pr.disable()
       if self.fn: self.pr.dump_stats(self.fn)
-      pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort).print_stats(self.frac)
+      stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
+      for fcn in stats.fcn_list[0:int(len(stats.fcn_list)*self.frac)]: # type: ignore[attr-defined]
+        (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn] # type: ignore[attr-defined]
+        scallers = sorted(callers.items(), key=lambda x: -x[1][2])
+        print(f"n:{num_calls:8d} tm:{tottime*self.time_scale:7.2f}ms tot:{cumtime*self.time_scale:7.2f}ms",
+              colored(_format_fcn(fcn), "yellow") + " "*(50-len(_format_fcn(fcn))),
+              colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else '')
+
+class ProfileLogger:
+  writers: int = 0
+  mjson: List[Dict] = []
+  actors: Dict[str, int] = {}
+  subactors: Dict[Tuple[str, str], int] = {}
+  path = getenv("PROFILE_OUTPUT_FILE", temp("tinygrad_profile.json"))
+
+  def __init__(self): self.events, ProfileLogger.writers = [], ProfileLogger.writers + 1
+
+  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
+
+  def __del__(self):
+    for name,st,et,actor_name,subactor_name in self.events:
+      if actor_name not in self.actors:
+        self.actors[actor_name] = (pid:=len(self.actors))
+        self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
+
+      if (subactor_key:=(actor_name,subactor_name)) not in self.subactors:
+        self.subactors[subactor_key] = (tid:=len(self.subactors))
+        self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
+
+      self.mjson.append({"name": name, "ph": "X", "pid": self.actors[actor_name], "tid": self.subactors.get(subactor_key, -1), "ts":st, "dur":et-st})
+
+    ProfileLogger.writers -= 1
+    if ProfileLogger.writers == 0:
+      with open(self.path, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
+      print(f"Saved profile to {self.path}. Use https://ui.perfetto.dev/ to open it.")
 
 # *** universal database cache ***
 
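ProfileLogger pairs with the new PROFILE flag: it buffers events and, when the last live instance is deleted, serializes them as Chrome trace-event JSON (one "X" complete event per entry, plus "M" metadata events naming processes and threads) that Perfetto or chrome://tracing can open. A minimal sketch of the flow (note this writes PROFILE_OUTPUT_FILE when run):

    from tinygrad.helpers import ProfileLogger
    pl = ProfileLogger()
    # "dur" in the output is computed as ev_end - ev_start
    pl.add_event("kernel_0", ev_start=0, ev_end=1500, actor="DEV0", subactor="compute")
    del pl  # last writer: flushes {"traceEvents": [...]} and prints the path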
@@ -107,7 +178,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
 CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
 CACHELEVEL = getenv("CACHELEVEL", 2)
 
-VERSION = 10
+VERSION = 16
 _db_connection = None
 def db_connection():
   global _db_connection
@@ -117,13 +188,18 @@ def db_connection():
   if DEBUG >= 7: _db_connection.set_trace_callback(print)
   return _db_connection
 
+def diskcache_clear():
+  cur = db_connection().cursor()
+  drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
+  cur.executescript("\n".join([s[0] for s in drop_tables]))
+
 def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
   if CACHELEVEL == 0: return None
   if isinstance(key, (str,int)): key = {"key": key}
   conn = db_connection()
   cur = conn.cursor()
   try:
-    res = cur.execute(f"SELECT val FROM {table}_{VERSION} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
+    res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
   except sqlite3.OperationalError:
     return None # table doesn't exist
   if (val:=res.fetchone()) is not None: return pickle.loads(val[0])
@@ -138,27 +214,37 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
   if table not in _db_tables:
     TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
     ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
-    cur.execute(f"CREATE TABLE IF NOT EXISTS {table}_{VERSION} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
+    cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
     _db_tables.add(table)
-  cur.execute(f"REPLACE INTO {table}_{VERSION} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
+  cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
   conn.commit()
   cur.close()
   return val
 
+def diskcache(func):
+  def wrapper(*args, **kwargs) -> bytes:
+    table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
+    if (ret:=diskcache_get(table, key)): return ret
+    return diskcache_put(table, key, func(*args, **kwargs))
+  return wrapper
+
 # *** http support ***
 
-def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
+def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None,
+          allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
   if url.startswith(("/", ".")): return pathlib.Path(url)
-  fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest()) # noqa: E501
+  if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
+  else: fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / (name or hashlib.md5(url.encode('utf-8')).hexdigest())
   if not fp.is_file() or not allow_caching:
-    with request.urlopen(url, timeout=10) as r:
+    with urllib.request.urlopen(url, timeout=10) as r:
       assert r.status == 200
       total_length = int(r.headers.get('content-length', 0))
-      progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=url)
+      progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}: ", disable=CI)
      (path := fp.parent).mkdir(parents=True, exist_ok=True)
      with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
        while chunk := r.read(16384): progress_bar.update(f.write(chunk))
        f.close()
+        progress_bar.update(close=True)
        if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
        pathlib.Path(f.name).rename(fp)
   return fp
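The new diskcache decorator composes the two primitives above: the table name comes from the function name and the key is a sha256 of the pickled (args, kwargs). One caveat visible in the code: a falsy result (None, empty bytes) fails the walrus test and is recomputed on every call. Usage sketch with a made-up function:

    from tinygrad.helpers import diskcache

    @diskcache
    def slow_compile(src: str) -> bytes:
      return ("// compiled: " + src).encode()  # stand-in for an expensive step

    slow_compile("kernel_a")  # runs the body, stores the result in cache.db
    slow_compile("kernel_a")  # served from the sqlite cache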
@@ -170,11 +256,18 @@ def cpu_time_execution(cb, enable):
   cb()
   if enable: return time.perf_counter()-st
 
+def cpu_objdump(lib):
+  with tempfile.NamedTemporaryFile(delete=True) as f:
+    pathlib.Path(f.name).write_bytes(lib)
+    print(subprocess.check_output(['objdump', '-d', f.name]).decode('utf-8'))
+
 # *** ctypes helpers
 
 # TODO: make this work with read only memoryviews (if possible)
-def from_mv(mv:memoryview, to_type=ctypes.c_char): return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type))
+def from_mv(mv:memoryview, to_type=ctypes.c_char):
+  return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
 def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
+def mv_address(mv:memoryview): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
 def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
 @functools.lru_cache(maxsize=None)
 def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
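from_mv now returns a sized ctypes array (to_type * len(mv)) instead of a bare pointer, so callers can len() and index the result, and the new mv_address exposes the raw address of a memoryview's buffer. Round-trip sketch:

    from tinygrad.helpers import from_mv, to_mv, mv_address
    buf = memoryview(bytearray(b"tiny"))
    arr = from_mv(buf)  # a (c_char * 4) array over the same memory
    assert len(arr) == 4 and arr[0] == b"t"
    assert to_mv(mv_address(buf), 4).tobytes() == b"tiny"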
@@ -182,31 +275,36 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
     _pack_, _fields_ = 1, fields
   return CStruct
 def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
-def get_bytes(arg, get_sz, get_str, check) -> bytes: return (sz := init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x)))), ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value))[1] # noqa: E501
-def flat_mv(mv:memoryview):
-  if len(mv) == 0: return mv
-  return mv.cast("B", shape=(mv.nbytes,))
-
-# *** Helpers for CUDA-like APIs.
-
-def compile_cuda_style(prg, compile_options, prog_t, create_prog, compile_prog, get_code, get_code_size, get_log, get_log_size, check) -> bytes:
-  check(create_prog(ctypes.byref(prog := prog_t()), prg.encode(), "<null>".encode(), 0, None, None))
-  status = compile_prog(prog, len(compile_options), to_char_p_p([o.encode() for o in compile_options]))
+def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
 
-  if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, get_log_size, get_log, check).decode()}")
-  return get_bytes(prog, get_code_size, get_code, check)
+class tqdm:
+  def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:int=-1, rate:int=100):
+    self.iter, self.desc, self.dis, self.unit, self.unit_scale, self.rate = iterable, f"{desc}: " if desc else "", disable, unit, unit_scale, rate
+    self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, len(iterable) if total==-1 else total
+    self.update(0)
+  def __iter__(self):
+    try:
+      for item in self.iter:
+        yield item
+        self.update(1)
+    finally: self.update(close=True)
+  def set_description(self, desc:str): self.desc = f"{desc}: " if desc else ""
+  def update(self, n:int=0, close:bool=False):
+    self.n, self.i = self.n+n, self.i+1
+    if (self.i % self.skip != 0 and not close) or self.dis: return
+    prog, dur, term = self.n/self.t if self.t else -1, time.perf_counter()-self.st, shutil.get_terminal_size().columns
+    if self.i/dur > self.rate and self.i: self.skip = max(int(self.i/dur)//self.rate,1) if self.i else 1
+    def fmt(t): return ':'.join([f'{x:02d}' for x in divmod(int(t), 60)]) if t!=-1 else '?'
+    def scl(x): return x/1000**int(math.log(x,1000))
+    def fn(x): return (f"{scl(x):.{3-math.ceil(math.log10(scl(x)))}f}"[:4]+(f"{[' ','k','M','G','T','P'][int(math.log(x,1000))]}") if x else '0.00')
+    if self.t: unit_text = f"{fn(self.n)}/{fn(self.t)}" if self.unit_scale else f"{self.n}/{self.t}"
+    else: unit_text = f"{fn(self.n)}{self.unit}" if self.unit_scale else f"{self.n}{self.unit}"
+    it_text = f"{fn(self.n/dur)}" if self.n and self.unit_scale else f"{self.n/dur:5.2f}" if self.n else "?"
+    if self.t: suf = f'| {unit_text} [{fmt(dur)}<{fmt(dur/self.n*self.t-dur if self.n else -1)}, {it_text}{self.unit}/s]'
+    else: suf = f'{unit_text} [{fmt(dur)}, {it_text}{self.unit}/s]'
+    sz = max(term-5-len(suf)-len(self.desc), 1)
+    bar = f'\r{self.desc}{round(100*prog):3}%|{"█"*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}'
+    print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr)
 
-def encode_args_cuda_style(bufs, vals, device_ptr_t, marks) -> Tuple[ctypes.Array, ctypes.Structure]:
-  c_args = init_c_struct_t(tuple([(f'f{i}', device_ptr_t) for i in range(len(bufs))] + [(f'f{i}', ctypes.c_int) for i in range(len(bufs), len(bufs)+len(vals))]))(*bufs, *vals) # noqa: E501
-  return (ctypes.c_void_p * 5)(ctypes.c_void_p(marks[0]), ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p), ctypes.c_void_p(marks[1]), ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(c_args))), ctypes.c_void_p), ctypes.c_void_p(marks[2])), c_args # noqa: E501
-
-def time_execution_cuda_style(cb, ev_t, evcreate, evrecord, evsync, evdestroy, evtime, enable=False) -> Optional[float]:
-  if not enable: return cb()
-  evs = [init_c_var(ev_t(), lambda x: evcreate(ctypes.byref(x), 0)) for _ in range(2)]
-  evrecord(evs[0], None)
-  cb()
-  evrecord(evs[1], None)
-  evsync(evs[1])
-  evtime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1])
-  for ev in evs: evdestroy(ev)
-  return ret.value * 1e-3
+class trange(tqdm):
+  def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
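And the reason tqdm disappeared from the imports at the top of this file: helpers now vendors a small tqdm implementing just the subset tinygrad uses (desc, total, unit/unit_scale, disable, set_description, update — with update(close=True) standing in for the real library's close()), plus a trange shortcut. It prints to stderr and self-throttles redraws. Drop-in usage:

    from tinygrad.helpers import tqdm, trange
    for _ in trange(100, desc="step"): pass          # progress bar on stderr
    for x in tqdm([3, 1, 4, 1, 5], unit="it"): pass  # any sized iterable works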