tinygrad-0.7.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (77)
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
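Two moves in this list affect most downstream imports: the dtype machinery left tinygrad/helpers.py for the new tinygrad/dtype.py (visible in the helpers.py diff below), and the JIT/scheduling plumbing moved under tinygrad/engine/. A hedged sketch of the resulting import updates; the TinyJit paths are an assumption based on the jit.py → engine/jit.py move above, not something shown in this diff:

  # 0.7.0-era imports
  # from tinygrad.helpers import dtypes, DType
  # from tinygrad.jit import TinyJit           # tinygrad/jit.py is deleted in 0.9.0

  # 0.9.0 layout
  from tinygrad.dtype import dtypes, DType     # new home, per tinygrad/dtype.py above
  from tinygrad.engine.jit import TinyJit      # assumed path, per tinygrad/engine/jit.py above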
tinygrad/{mlops.py → function.py} RENAMED
@@ -1,31 +1,48 @@
+ """This is where the forwards and backwards passes live."""
  import math
  from typing import Tuple, Optional
- from tinygrad.helpers import argsort, DType
+ from tinygrad.helpers import argsort
+ from tinygrad.dtype import dtypes, DType, sum_acc_dtype
  from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
  from tinygrad.tensor import Function
  from tinygrad.lazy import LazyBuffer
+ from tinygrad.shape.symbolic import sint

  class Contiguous(Function):
    def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
    def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output

+ class ContiguousBackward(Function):
+   def forward(self, x:LazyBuffer) -> LazyBuffer: return x
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.contiguous()
+
  class Cast(Function):
    def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer:
      self.input_dtype, self.bitcast = x.dtype, bitcast
-     return x.e(UnaryOps.CAST, arg=(dtype, bitcast))
+     return x.cast(dtype, bitcast)

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.e(UnaryOps.CAST, arg=(self.input_dtype, self.bitcast))
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype, self.bitcast)

  # ************* unary ops *************

+ class Neg(Function):
+   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG)
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG)
+
+ class Reciprocal(Function):
+   def forward(self, x:LazyBuffer) -> LazyBuffer:
+     self.ret = x.const(1).e(BinaryOps.DIV, x)
+     return self.ret
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+     return grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret)
+
  class Sin(Function):
    def forward(self, x:LazyBuffer) -> LazyBuffer:
      self.x = x
      return x.e(UnaryOps.SIN)

-   def backward(self, grad:LazyBuffer) -> LazyBuffer:
-     return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad)
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+     return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)

  # NOTE: maximum(x, 0) behaves differently where x=0
  class Relu(Function):
@@ -34,23 +51,21 @@ class Relu(Function):
      return self.ret

    def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output)
+     return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output)

  class Log(Function):
    def forward(self, x:LazyBuffer) -> LazyBuffer:
      self.x = x
      return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.e(BinaryOps.DIV, self.x)
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.DIV, self.x)

  class Exp(Function):
    def forward(self, x:LazyBuffer) -> LazyBuffer:
      self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
      return self.ret

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return self.ret.e(BinaryOps.MUL, grad_output)
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output)

  class Sqrt(Function):
    def forward(self, x:LazyBuffer) -> LazyBuffer:
@@ -71,48 +86,39 @@ class Sigmoid(Function):
    def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
      return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output)

- # ************* reduce ops *************
-
- class Sum(Function):
-   def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
-     self.input_shape = x.shape
-     return x.reduce_op(ReduceOps.SUM, new_shape)
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.expand(self.input_shape)
-
- class Max(Function):
-   def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
-     self.x, self.ret = x, x.reduce_op(ReduceOps.MAX, new_shape)
-     return self.ret
-
-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     # 1s in locations where the max was chosen (can be two locations)
-     max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
-     div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
-     return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
+ class Sign(Function):
+   def forward(self, x:LazyBuffer) -> LazyBuffer:
+     return x.e(BinaryOps.CMPEQ, x.const(0)).e(TernaryOps.WHERE, x.const(0),
+                                               x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)))
+   # backward always return 0 to match torch
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)

  # ************* binary ops *************

  class Less(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-     return x.e(BinaryOps.CMPLT, y)
+   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
+   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
+
+ class Eq(Function):
+   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPEQ, y)
+   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
+
+ class Xor(Function):
+   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y)

  class Add(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-     return x.e(BinaryOps.ADD, y)
+   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.ADD, y)

    def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
      return grad_output if self.needs_input_grad[0] else None, \
             grad_output if self.needs_input_grad[1] else None

  class Sub(Function):
-   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-     return x.e(BinaryOps.SUB, y)
+   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.SUB, y)

    def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
      return grad_output if self.needs_input_grad[0] else None, \
-            grad_output.const(0).e(BinaryOps.SUB, grad_output) if self.needs_input_grad[1] else None
+            grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None

  class Mul(Function):
    def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
@@ -130,67 +136,82 @@ class Div(Function):

    def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
      return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
-            grad_output.const(0).e(BinaryOps.SUB, grad_output).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None
+            grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # noqa: E501

  # ************* ternary ops *************

  class Where(Function):
    def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
      self.x = x
-     return x.e(TernaryOps.WHERE, y, z)
+     return self.x.e(TernaryOps.WHERE, y, z)

    def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
      return None, \
-       self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
-       self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None
+            self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
+            self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None
+
+ # ************* reduce ops *************
+
+ class Sum(Function):
+   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
+     self.input_shape = x.shape
+     return x.r(ReduceOps.SUM, axis)
+
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
+
+ class Max(Function):
+   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
+     self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
+     return self.ret
+
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+     # 1s in locations where the max was chosen (can be two locations)
+     max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(dtypes.float)
+     div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
+     return max_is_1s.e(BinaryOps.DIV, div).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))

  # ************* movement ops *************

  # NOTE: this is sum in reverse
  class Expand(Function):
    def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
-     self.input_shape = x.shape
+     self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
      return x.expand(shape)

    def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.reduce_op(ReduceOps.SUM, self.input_shape)
+     return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(ReduceOps.SUM, self.expanded_axis).cast(grad_output.dtype)

  class Reshape(Function):
    def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
      self.input_shape = x.shape
      return x.reshape(shape)

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.reshape(self.input_shape)
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reshape(self.input_shape)

  class Permute(Function):
    def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
      self.input_order = order
      return x.permute(order)

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.permute(argsort(self.input_order))
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.permute(argsort(self.input_order))

  class Pad(Function):
    def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
      self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
      return x.pad(arg)

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.shrink(self.narg)
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.shrink(self.narg)

  class Shrink(Function):
-   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
      self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
      return x.shrink(arg)

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.pad(self.narg)
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.pad(self.narg)

  class Flip(Function):
    def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
      self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
      return x.stride(self.arg)

-   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-     return grad_output.stride(self.arg)
+   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)
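The reduce-op Functions above now take an axis tuple and call x.r(...) where 0.7.0 passed a target new_shape to x.reduce_op(...). The tie-splitting math in Max.backward is easier to see outside of lazy ops; a minimal NumPy sketch (illustration only, not tinygrad code) of the same computation:

  import numpy as np

  x = np.array([3.0, 1.0, 3.0, 2.0])          # the max appears in two locations
  ret = x.max(keepdims=True)                  # forward result: 3.0
  max_is_1s = (x == ret).astype(np.float32)   # 1s where the max was chosen -> [1, 0, 1, 0]
  div = max_is_1s.sum(keepdims=True)          # count of tied maxima -> 2
  grad_x = (max_is_1s / div) * 1.0            # upstream grad of 1.0 split evenly -> [0.5, 0, 0.5, 0]

0.9.0 also builds the mask with a single CMPEQ rather than 1.0 - CMPLT, and casts through float so the division and the final multiply stay well-typed.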
tinygrad/helpers.py CHANGED
@@ -1,35 +1,75 @@
  from __future__ import annotations
- import os, functools, platform, time, re, contextlib
- import numpy as np
- from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any, Iterable
- from math import prod # noqa: F401 # pylint:disable=unused-import
+ import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes
+ import itertools, urllib.request, subprocess
+ from tqdm import tqdm
+ from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
+ if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
+   from typing_extensions import TypeGuard
+   from tinygrad.shape.shapetracker import sint
+
+ T = TypeVar("T")
+ U = TypeVar("U")
+ # NOTE: it returns int 1 if x is empty regardless of the type of x
+ def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x, 1)

  # NOTE: helpers is not allowed to import from anything else in tinygrad
  OSX = platform.system() == "Darwin"
  CI = os.getenv("CI", "") != ""

- def dedup(x): return list(dict.fromkeys(x)) # retains list order
- def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
+ def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order
+ def argfix(*x):
+   if x and x[0].__class__ in (tuple, list):
+     if len(x) != 1: raise ValueError(f"bad arg {x}")
+     return tuple(x[0])
+   return x
  def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
- def all_same(items): return all(x == items[0] for x in items)
- def colored(st, color, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line
- def ansilen(s): return len(re.sub('\x1b\\[(K|.*?m)', '', s))
+ def all_same(items:List[T]): return all(x == items[0] for x in items)
+ def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
+ def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
+ def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
+ def ansilen(s:str): return len(ansistrip(s))
  def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
- def flatten(l:Iterator): return [item for sublist in l for item in sublist]
- def mnum(i) -> str: return str(i) if i >= 0 else f"m{-i}"
+ def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
+ def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(sublist) if isinstance(sublist, (tuple, list)) else [sublist])]
  def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
- def merge_dicts(ds:Iterable[Dict]) -> Dict:
-   kvs = set([(k,v) for d in ds for k,v in d.items()])
-   assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
-   return {k:v for k,v in kvs}
- def partition(lst, fxn):
-   a: list[Any] = []
-   b: list[Any] = []
+ def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
+ def round_up(num, amt:int): return (num+amt-1)//amt * amt
+ def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
+   assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501
+   return {k:v for d in ds for k,v in d.items()}
+ def partition(lst:List[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
+   a:List[T] = []
+   b:List[T] = []
    for s in lst: (a if fxn(s) else b).append(s)
    return a,b
+ def unwrap(x:Optional[T]) -> T:
+   assert x is not None
+   return x
+ def unwrap2(x:Tuple[T,Any]) -> T:
+   ret, err = x
+   assert err is None, str(err)
+   return ret
+ def get_child(obj, key):
+   for k in key.split('.'):
+     if k.isnumeric(): obj = obj[int(k)]
+     elif isinstance(obj, dict): obj = obj[k]
+     else: obj = getattr(obj, k)
+   return obj
+
+ # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
+ def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
+   acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
+   try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
+   except ValueError: return None
+   return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]

  @functools.lru_cache(maxsize=None)
- def getenv(key, default=0): return type(default)(os.getenv(key, default))
+ def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
+ @functools.lru_cache(maxsize=None)
+ def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
+ def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
+
+ class GraphException(Exception): pass

  class Context(contextlib.ContextDecorator):
    stack: ClassVar[List[dict[str, int]]] = [{}]
@@ -44,83 +84,23 @@ class Context(contextlib.ContextDecorator):
  class ContextVar:
    _cache: ClassVar[Dict[str, ContextVar]] = {}
    value: int
+   key: str
    def __new__(cls, key, default_value):
      if key in ContextVar._cache: return ContextVar._cache[key]
      instance = ContextVar._cache[key] = super().__new__(cls)
-     instance.value = getenv(key, default_value)
+     instance.value, instance.key = getenv(key, default_value), key
      return instance
    def __bool__(self): return bool(self.value)
    def __ge__(self, x): return self.value >= x
    def __gt__(self, x): return self.value > x
    def __lt__(self, x): return self.value < x

- DEBUG, IMAGE = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0)
- GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
+ DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
+ WINO, THREEFRY, CACHECOLLECTING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CACHECOLLECTING", 1)
+ GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
+ MULTIOUTPUT = ContextVar("MULTIOUTPUT", 1)

- class Timing(contextlib.ContextDecorator):
-   def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
-   def __enter__(self): self.st = time.perf_counter_ns()
-   def __exit__(self, exc_type, exc_val, exc_tb):
-     self.et = time.perf_counter_ns() - self.st
-     if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
-
- # **** tinygrad now supports dtypes! *****
-
- class DType(NamedTuple):
-   priority: int # this determines when things get upcasted
-   itemsize: int
-   name: str
-   np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
-   sz: int = 1
-   def __repr__(self): return f"dtypes.{self.name}"
-
- # dependent typing?
- class ImageDType(DType):
-   def __new__(cls, priority, itemsize, name, np, shape):
-     return super().__new__(cls, priority, itemsize, name, np)
-   def __init__(self, priority, itemsize, name, np, shape):
-     self.shape: Tuple[int, ...] = shape # arbitrary arg for the dtype, used in image for the shape
-     super().__init__()
-   def __repr__(self): return f"dtypes.{self.name}({self.shape})"
-
- class dtypes:
-   @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
-   def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
-   @staticmethod
-   def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float2, dtypes._float4)
-   @staticmethod
-   def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
-   @staticmethod
-   def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
-   @staticmethod
-   def fields() -> Dict[str, DType]: return DTYPES_DICT
-   bool: Final[DType] = DType(0, 1, "bool", np.bool_)
-   float16: Final[DType] = DType(0, 2, "half", np.float16)
-   half = float16
-   float32: Final[DType] = DType(4, 4, "float", np.float32)
-   float = float32
-   float64: Final[DType] = DType(0, 8, "double", np.float64)
-   double = float64
-   int8: Final[DType] = DType(0, 1, "char", np.int8)
-   int16: Final[DType] = DType(1, 2, "short", np.int16)
-   int32: Final[DType] = DType(2, 4, "int", np.int32)
-   int64: Final[DType] = DType(3, 8, "long", np.int64)
-   uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
-   uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
-   uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
-   uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)
-
-   # NOTE: bfloat16 isn't supported in numpy
-   bfloat16: Final[DType] = DType(0, 2, "__bf16", None)
-
-   # NOTE: these are internal dtypes, should probably check for that
-   _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
-   _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
-   _float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
-   _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
-
- # HACK: staticmethods are not callable in 3.8 so we have to compare the class
- DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
+ # **************** global state Counters ****************

  class GlobalCounters:
    global_ops: ClassVar[int] = 0
@@ -128,7 +108,134 @@ class GlobalCounters:
    time_sum_s: ClassVar[float] = 0.0
    kernel_count: ClassVar[int] = 0
    mem_used: ClassVar[int] = 0 # NOTE: this is not reset
-   mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
-   cache: ClassVar[Optional[List[Tuple[Callable, Any, Dict[Any, int]]]]] = None # List[Tuple[Callable, List[RawBuffer], Dict[Variable, int]]]
    @staticmethod
-   def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
+   def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
+
+ # **************** timer and profiler ****************
+
+ class Timing(contextlib.ContextDecorator):
+   def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
+   def __enter__(self): self.st = time.perf_counter_ns()
+   def __exit__(self, *exc):
+     self.et = time.perf_counter_ns() - self.st
+     if self.enabled: print(f"{self.prefix}{self.et*1e-6:6.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
+
+ def _format_fcn(fcn): return f"{fcn[0]}:{fcn[1]}:{fcn[2]}"
+ class Profiling(contextlib.ContextDecorator):
+   def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None, ts=1):
+     self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3/ts
+   def __enter__(self):
+     self.pr = cProfile.Profile()
+     if self.enabled: self.pr.enable()
+   def __exit__(self, *exc):
+     if self.enabled:
+       self.pr.disable()
+       if self.fn: self.pr.dump_stats(self.fn)
+       stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
+       for fcn in stats.fcn_list[0:int(len(stats.fcn_list)*self.frac)]: # type: ignore[attr-defined]
+         (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn] # type: ignore[attr-defined]
+         scallers = sorted(callers.items(), key=lambda x: -x[1][2])
+         print(f"n:{num_calls:8d} tm:{tottime*self.time_scale:7.2f}ms tot:{cumtime*self.time_scale:7.2f}ms",
+               colored(_format_fcn(fcn), "yellow") + " "*(50-len(_format_fcn(fcn))),
+               colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else '')
+
+ # *** universal database cache ***
+
+ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))
+ CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
+ CACHELEVEL = getenv("CACHELEVEL", 2)
+
+ VERSION = 16
+ _db_connection = None
+ def db_connection():
+   global _db_connection
+   if _db_connection is None:
+     os.makedirs(CACHEDB.rsplit(os.sep, 1)[0], exist_ok=True)
+     _db_connection = sqlite3.connect(CACHEDB)
+     if DEBUG >= 7: _db_connection.set_trace_callback(print)
+   return _db_connection
+
+ def diskcache_clear():
+   cur = db_connection().cursor()
+   drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
+   cur.executescript("\n".join([s[0] for s in drop_tables]))
+
+ def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
+   if CACHELEVEL == 0: return None
+   if isinstance(key, (str,int)): key = {"key": key}
+   conn = db_connection()
+   cur = conn.cursor()
+   try:
+     res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
+   except sqlite3.OperationalError:
+     return None # table doesn't exist
+   if (val:=res.fetchone()) is not None: return pickle.loads(val[0])
+   return None
+
+ _db_tables = set()
+ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
+   if CACHELEVEL == 0: return val
+   if isinstance(key, (str,int)): key = {"key": key}
+   conn = db_connection()
+   cur = conn.cursor()
+   if table not in _db_tables:
+     TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
+     ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
+     cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
+     _db_tables.add(table)
+   cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
+   conn.commit()
+   cur.close()
+   return val
+
+ def diskcache(func):
+   def wrapper(*args, **kwargs) -> bytes:
+     table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
+     if (ret:=diskcache_get(table, key)): return ret
+     return diskcache_put(table, key, func(*args, **kwargs))
+   return wrapper
+
+ # *** http support ***
+
+ def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
+   if url.startswith(("/", ".")): return pathlib.Path(url)
+   fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest()) # noqa: E501
+   if not fp.is_file() or not allow_caching:
+     with urllib.request.urlopen(url, timeout=10) as r:
+       assert r.status == 200
+       total_length = int(r.headers.get('content-length', 0))
+       progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=url)
+       (path := fp.parent).mkdir(parents=True, exist_ok=True)
+       with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
+         while chunk := r.read(16384): progress_bar.update(f.write(chunk))
+         f.close()
+         if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
+         pathlib.Path(f.name).rename(fp)
+   return fp
+
+ # *** Exec helpers
+
+ def cpu_time_execution(cb, enable):
+   if enable: st = time.perf_counter()
+   cb()
+   if enable: return time.perf_counter()-st
+
+ def cpu_objdump(lib):
+   with tempfile.NamedTemporaryFile(delete=True) as f:
+     pathlib.Path(f.name).write_bytes(lib)
+     print(subprocess.check_output(['objdump', '-d', f.name]).decode('utf-8'))
+
+ # *** ctypes helpers
+
+ # TODO: make this work with read only memoryviews (if possible)
+ def from_mv(mv:memoryview, to_type=ctypes.c_char):
+   return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
+ def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
+ def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
+ @functools.lru_cache(maxsize=None)
+ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
+   class CStruct(ctypes.Structure):
+     _pack_, _fields_ = 1, fields
+   return CStruct
+ def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
+ def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
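The universal database cache added above memoizes any picklable function result into a sqlite file (CACHEDB). A minimal usage sketch, assuming only the helpers defined in this diff; compile_slow and the URL are hypothetical stand-ins:

  from tinygrad.helpers import diskcache, diskcache_clear, fetch

  @diskcache
  def compile_slow(src: str) -> bytes:
    return src.encode()  # pretend this is an expensive compile step

  compile_slow("kernel_a")  # first call runs the body; result is pickled into CACHEDB,
                            # keyed on a sha256 of the pickled (args, kwargs)
  compile_slow("kernel_a")  # later calls, even across processes, hit the sqlite cache

  path = fetch("https://example.com/weights.bin")  # downloads once into the cache dir,
                                                   # then returns the cached pathlib.Path

Setting CACHELEVEL=0 bypasses the cache entirely, and diskcache_clear() drops all cache tables.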