tinygrad 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. tinygrad/codegen/__init__.py +0 -0
  2. tinygrad/codegen/kernel.py +78 -90
  3. tinygrad/codegen/linearizer.py +237 -169
  4. tinygrad/codegen/uops.py +278 -242
  5. tinygrad/device.py +147 -10
  6. tinygrad/dtype.py +7 -7
  7. tinygrad/engine/graph.py +16 -16
  8. tinygrad/engine/jit.py +39 -36
  9. tinygrad/engine/realize.py +6 -5
  10. tinygrad/engine/schedule.py +15 -7
  11. tinygrad/engine/search.py +6 -3
  12. tinygrad/function.py +17 -23
  13. tinygrad/helpers.py +77 -8
  14. tinygrad/lazy.py +26 -26
  15. tinygrad/multi.py +13 -9
  16. tinygrad/nn/__init__.py +1 -1
  17. tinygrad/nn/datasets.py +2 -1
  18. tinygrad/nn/state.py +3 -4
  19. tinygrad/ops.py +49 -16
  20. tinygrad/renderer/__init__.py +8 -4
  21. tinygrad/renderer/assembly.py +93 -100
  22. tinygrad/renderer/cstyle.py +47 -42
  23. tinygrad/renderer/llvmir.py +30 -30
  24. tinygrad/runtime/__init__.py +0 -0
  25. tinygrad/runtime/autogen/amd_gpu.py +11504 -1
  26. tinygrad/runtime/autogen/comgr.py +36 -10
  27. tinygrad/runtime/autogen/hsa.py +146 -14
  28. tinygrad/runtime/autogen/io_uring.py +1486 -0
  29. tinygrad/runtime/autogen/nv_gpu.py +269 -0
  30. tinygrad/runtime/driver/__init__.py +0 -0
  31. tinygrad/runtime/driver/hip_comgr.py +20 -11
  32. tinygrad/runtime/graph/__init__.py +0 -0
  33. tinygrad/runtime/graph/clang.py +3 -2
  34. tinygrad/runtime/graph/cuda.py +2 -2
  35. tinygrad/runtime/graph/hcq.py +122 -78
  36. tinygrad/runtime/ops_amd.py +302 -316
  37. tinygrad/runtime/ops_cuda.py +3 -3
  38. tinygrad/runtime/ops_disk.py +70 -5
  39. tinygrad/runtime/ops_gpu.py +2 -2
  40. tinygrad/runtime/ops_metal.py +5 -6
  41. tinygrad/runtime/ops_npy.py +1 -1
  42. tinygrad/runtime/ops_nv.py +161 -166
  43. tinygrad/runtime/ops_python.py +20 -16
  44. tinygrad/shape/__init__.py +0 -0
  45. tinygrad/shape/shapetracker.py +5 -2
  46. tinygrad/shape/symbolic.py +1 -3
  47. tinygrad/shape/view.py +34 -19
  48. tinygrad/tensor.py +219 -135
  49. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
  50. tinygrad-0.9.1.dist-info/RECORD +63 -0
  51. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  52. tinygrad/runtime/driver/hsa.py +0 -143
  53. tinygrad/runtime/graph/hsa.py +0 -171
  54. tinygrad/runtime/ops_hsa.py +0 -278
  55. tinygrad-0.9.0.dist-info/RECORD +0 -60
  56. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
  57. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/function.py CHANGED
@@ -31,7 +31,7 @@ class Neg(Function):
31
31
 
32
32
  class Reciprocal(Function):
33
33
  def forward(self, x:LazyBuffer) -> LazyBuffer:
34
- self.ret = x.const(1).e(BinaryOps.DIV, x)
34
+ self.ret = x.e(UnaryOps.RECIP)
35
35
  return self.ret
36
36
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
37
37
  return grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.ret).e(BinaryOps.MUL, self.ret)
@@ -42,7 +42,7 @@ class Sin(Function):
42
42
  return x.e(UnaryOps.SIN)
43
43
 
44
44
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
45
- return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
45
+ return self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
46
46
 
47
47
  # NOTE: maximum(x, 0) behaves differently where x=0
48
48
  class Relu(Function):
@@ -58,7 +58,7 @@ class Log(Function):
58
58
  self.x = x
59
59
  return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
60
60
 
61
- def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.DIV, self.x)
61
+ def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP))
62
62
 
63
63
  class Exp(Function):
64
64
  def forward(self, x:LazyBuffer) -> LazyBuffer:
@@ -73,23 +73,23 @@ class Sqrt(Function):
73
73
  return self.ret
74
74
 
75
75
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
76
- return grad_output.e(BinaryOps.DIV, self.ret.e(BinaryOps.MUL, self.ret.const(2)))
76
+ return grad_output.e(BinaryOps.MUL, self.ret.e(BinaryOps.MUL, self.ret.const(2)).e(UnaryOps.RECIP))
77
77
 
78
78
  # NOTE: the implicit derivative of sigmoid is not stable
79
79
  # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
80
80
  # TODO: have the backend automatically find this
81
81
  class Sigmoid(Function):
82
82
  def forward(self, x:LazyBuffer) -> LazyBuffer:
83
- self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
83
+ self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP)
84
84
  return self.ret
85
85
 
86
86
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
87
- return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output)
87
+ return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.ADD, self.ret.e(UnaryOps.NEG))).e(BinaryOps.MUL, grad_output)
88
88
 
89
89
  class Sign(Function):
90
90
  def forward(self, x:LazyBuffer) -> LazyBuffer:
91
- return x.e(BinaryOps.CMPEQ, x.const(0)).e(TernaryOps.WHERE, x.const(0),
92
- x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)))
91
+ return x.e(BinaryOps.CMPNE, x.const(0)).e(
92
+ TernaryOps.WHERE, x.e(BinaryOps.CMPLT, x.const(0)).e(TernaryOps.WHERE, x.const(-1), x.const(1)), x.const(0))
93
93
  # backward always return 0 to match torch
94
94
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0)
95
95
 
@@ -99,8 +99,8 @@ class Less(Function):
99
99
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y)
100
100
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
101
101
 
102
- class Eq(Function):
103
- def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPEQ, y)
102
+ class Neq(Function):
103
+ def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPNE, y)
104
104
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None
105
105
 
106
106
  class Xor(Function):
@@ -113,13 +113,6 @@ class Add(Function):
113
113
  return grad_output if self.needs_input_grad[0] else None, \
114
114
  grad_output if self.needs_input_grad[1] else None
115
115
 
116
- class Sub(Function):
117
- def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.SUB, y)
118
-
119
- def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
120
- return grad_output if self.needs_input_grad[0] else None, \
121
- grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None
122
-
123
116
  class Mul(Function):
124
117
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
125
118
  self.x, self.y = x, y
@@ -132,11 +125,11 @@ class Mul(Function):
132
125
  class Div(Function):
133
126
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
134
127
  self.x, self.y = x, y
135
- return x.e(BinaryOps.DIV, y)
128
+ return x.e(BinaryOps.MUL, y.e(UnaryOps.RECIP)) if not dtypes.is_int(x.dtype) else x.e(BinaryOps.IDIV, y)
136
129
 
137
130
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
138
- return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
139
- grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # noqa: E501
131
+ return grad_output.e(BinaryOps.MUL, self.y.e(UnaryOps.RECIP)) if self.needs_input_grad[0] else None, \
132
+ grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.MUL, self.y.e(BinaryOps.MUL, self.y).e(UnaryOps.RECIP)) if self.needs_input_grad[1] else None # noqa: E501
140
133
 
141
134
  # ************* ternary ops *************
142
135
 
@@ -166,9 +159,10 @@ class Max(Function):
166
159
 
167
160
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
168
161
  # 1s in locations where the max was chosen (can be two locations)
169
- max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(dtypes.float)
162
+ max_is_1s = self.x.const(1.0).cast(dtypes.float).e(BinaryOps.ADD, self.x.e(BinaryOps.CMPNE, \
163
+ self.ret.expand(self.x.shape)).cast(dtypes.float).e(UnaryOps.NEG))
170
164
  div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape)
171
- return max_is_1s.e(BinaryOps.DIV, div).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
165
+ return max_is_1s.e(BinaryOps.MUL, div.e(UnaryOps.RECIP)).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
172
166
 
173
167
  # ************* movement ops *************
174
168
 
@@ -211,7 +205,7 @@ class Shrink(Function):
211
205
 
212
206
  class Flip(Function):
213
207
  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
214
- self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
208
+ self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))])
215
209
  return x.stride(self.arg)
216
210
 
217
211
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg)
tinygrad/helpers.py CHANGED
@@ -1,7 +1,6 @@
1
1
  from __future__ import annotations
2
- import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes
3
- import itertools, urllib.request, subprocess
4
- from tqdm import tqdm
2
+ import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes, sys
3
+ import itertools, urllib.request, subprocess, shutil, math, json
5
4
  from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
6
5
  if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
7
6
  from typing_extensions import TypeGuard
@@ -56,6 +55,12 @@ def get_child(obj, key):
56
55
  else: obj = getattr(obj, k)
57
56
  return obj
58
57
 
58
+ def get_shape(x) -> Tuple[int, ...]:
59
+ if not isinstance(x, (list, tuple)): return ()
60
+ subs = [get_shape(xi) for xi in x]
61
+ if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
62
+ return (len(subs),) + (subs[0] if subs else ())
63
+
59
64
  # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
60
65
  def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
61
66
  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
@@ -96,9 +101,9 @@ class ContextVar:
96
101
  def __lt__(self, x): return self.value < x
97
102
 
98
103
  DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
99
- WINO, THREEFRY, CACHECOLLECTING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CACHECOLLECTING", 1)
104
+ WINO, THREEFRY, CAPTURING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1)
100
105
  GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
101
- MULTIOUTPUT = ContextVar("MULTIOUTPUT", 1)
106
+ MULTIOUTPUT, PROFILE = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0)
102
107
 
103
108
  # **************** global state Counters ****************
104
109
 
@@ -139,6 +144,34 @@ class Profiling(contextlib.ContextDecorator):
139
144
  colored(_format_fcn(fcn), "yellow") + " "*(50-len(_format_fcn(fcn))),
140
145
  colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else '')
141
146
 
147
+ class ProfileLogger:
148
+ writers: int = 0
149
+ mjson: List[Dict] = []
150
+ actors: Dict[str, int] = {}
151
+ subactors: Dict[Tuple[str, str], int] = {}
152
+ path = getenv("PROFILE_OUTPUT_FILE", temp("tinygrad_profile.json"))
153
+
154
+ def __init__(self): self.events, ProfileLogger.writers = [], ProfileLogger.writers + 1
155
+
156
+ def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
157
+
158
+ def __del__(self):
159
+ for name,st,et,actor_name,subactor_name in self.events:
160
+ if actor_name not in self.actors:
161
+ self.actors[actor_name] = (pid:=len(self.actors))
162
+ self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
163
+
164
+ if (subactor_key:=(actor_name,subactor_name)) not in self.subactors:
165
+ self.subactors[subactor_key] = (tid:=len(self.subactors))
166
+ self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
167
+
168
+ self.mjson.append({"name": name, "ph": "X", "pid": self.actors[actor_name], "tid": self.subactors.get(subactor_key, -1), "ts":st, "dur":et-st})
169
+
170
+ ProfileLogger.writers -= 1
171
+ if ProfileLogger.writers == 0:
172
+ with open(self.path, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
173
+ print(f"Saved profile to {self.path}. Use https://ui.perfetto.dev/ to open it.")
174
+
142
175
  # *** universal database cache ***
143
176
 
144
177
  _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))
@@ -197,18 +230,21 @@ def diskcache(func):
197
230
 
198
231
  # *** http support ***
199
232
 
200
- def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
233
+ def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None,
234
+ allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
201
235
  if url.startswith(("/", ".")): return pathlib.Path(url)
202
- fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest()) # noqa: E501
236
+ if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
237
+ else: fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / (name or hashlib.md5(url.encode('utf-8')).hexdigest())
203
238
  if not fp.is_file() or not allow_caching:
204
239
  with urllib.request.urlopen(url, timeout=10) as r:
205
240
  assert r.status == 200
206
241
  total_length = int(r.headers.get('content-length', 0))
207
- progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=url)
242
+ progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}: ", disable=CI)
208
243
  (path := fp.parent).mkdir(parents=True, exist_ok=True)
209
244
  with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
210
245
  while chunk := r.read(16384): progress_bar.update(f.write(chunk))
211
246
  f.close()
247
+ progress_bar.update(close=True)
212
248
  if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
213
249
  pathlib.Path(f.name).rename(fp)
214
250
  return fp
@@ -231,6 +267,7 @@ def cpu_objdump(lib):
231
267
  def from_mv(mv:memoryview, to_type=ctypes.c_char):
232
268
  return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
233
269
  def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
270
+ def mv_address(mv:memoryview): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
234
271
  def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
235
272
  @functools.lru_cache(maxsize=None)
236
273
  def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
@@ -239,3 +276,35 @@ def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
239
276
  return CStruct
240
277
  def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
241
278
  def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
279
+
280
+ class tqdm:
281
+ def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:int=-1, rate:int=100):
282
+ self.iter, self.desc, self.dis, self.unit, self.unit_scale, self.rate = iterable, f"{desc}: " if desc else "", disable, unit, unit_scale, rate
283
+ self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, len(iterable) if total==-1 else total
284
+ self.update(0)
285
+ def __iter__(self):
286
+ try:
287
+ for item in self.iter:
288
+ yield item
289
+ self.update(1)
290
+ finally: self.update(close=True)
291
+ def set_description(self, desc:str): self.desc = f"{desc}: " if desc else ""
292
+ def update(self, n:int=0, close:bool=False):
293
+ self.n, self.i = self.n+n, self.i+1
294
+ if (self.i % self.skip != 0 and not close) or self.dis: return
295
+ prog, dur, term = self.n/self.t if self.t else -1, time.perf_counter()-self.st, shutil.get_terminal_size().columns
296
+ if self.i/dur > self.rate and self.i: self.skip = max(int(self.i/dur)//self.rate,1) if self.i else 1
297
+ def fmt(t): return ':'.join([f'{x:02d}' for x in divmod(int(t), 60)]) if t!=-1 else '?'
298
+ def scl(x): return x/1000**int(math.log(x,1000))
299
+ def fn(x): return (f"{scl(x):.{3-math.ceil(math.log10(scl(x)))}f}"[:4]+(f"{[' ','k','M','G','T','P'][int(math.log(x,1000))]}") if x else '0.00')
300
+ if self.t: unit_text = f"{fn(self.n)}/{fn(self.t)}" if self.unit_scale else f"{self.n}/{self.t}"
301
+ else: unit_text = f"{fn(self.n)}{self.unit}" if self.unit_scale else f"{self.n}{self.unit}"
302
+ it_text = f"{fn(self.n/dur)}" if self.n and self.unit_scale else f"{self.n/dur:5.2f}" if self.n else "?"
303
+ if self.t: suf = f'| {unit_text} [{fmt(dur)}<{fmt(dur/self.n*self.t-dur if self.n else -1)}, {it_text}{self.unit}/s]'
304
+ else: suf = f'{unit_text} [{fmt(dur)}, {it_text}{self.unit}/s]'
305
+ sz = max(term-5-len(suf)-len(self.desc), 1)
306
+ bar = f'\r{self.desc}{round(100*prog):3}%|{"█"*round(sz*prog)}{" "*(sz-round(sz*prog))}{suf}' if self.t else f'\r{self.desc}{suf}{" "*term}'
307
+ print(bar[:term+1],flush=True,end='\n'*close,file=sys.stderr)
308
+
309
+ class trange(tqdm):
310
+ def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
tinygrad/lazy.py CHANGED
@@ -22,7 +22,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=
22
22
  if enable_cache: lazycache[cache_key] = ret
23
23
  return ret
24
24
 
25
- view_supported_devices = {"LLVM", "CLANG", "CUDA", "DISK"}
25
+ view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "DISK"}
26
26
  class LazyBuffer:
27
27
  def __init__(self, device:str, st:ShapeTracker, dtype:DType,
28
28
  op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
@@ -34,11 +34,9 @@ class LazyBuffer:
34
34
  self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps
35
35
  assert self.op is not LoadOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized"
36
36
 
37
- if (self.op is LoadOps.CONTIGUOUS or self.op is UnaryOps.BITCAST) and srcs[0].st.consecutive and \
38
- not srcs[0].is_unrealized_const() and device.split(":")[0] in view_supported_devices:
37
+ if self.op is LoadOps.VIEW:
39
38
  # some LazyBuffers can be processed with only a view, no AST required
40
39
  self.buffer: Buffer = srcs[0].base.buffer.view(st.size, dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize)
41
- self.op = LoadOps.VIEW
42
40
  else:
43
41
  self.buffer = srcs[1].base.buffer if self.op is LoadOps.ASSIGN else Buffer(device, self.size, dtype)
44
42
  self.buffer.ref(1)
@@ -74,6 +72,7 @@ class LazyBuffer:
74
72
  return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)
75
73
 
76
74
  def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
75
+ assert isinstance(val, (int,float,bool)), f"{val=} has {type(val)=}, not a ConstType"
77
76
  shape = self.shape if shape is None else shape
78
77
  return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
79
78
 
@@ -83,9 +82,11 @@ class LazyBuffer:
83
82
  assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}"
84
83
  return LazyBuffer.loadop(LoadOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(x, self.base))
85
84
 
86
- def contiguous(self):
85
+ def can_view(self): return self.st.consecutive and not self.is_unrealized_const() and self.device.split(":")[0] in view_supported_devices
86
+
87
+ def contiguous(self, allow_buffer_view=True):
87
88
  if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
88
- ret = self.e(LoadOps.CONTIGUOUS)
89
+ ret = self.e(LoadOps.VIEW) if allow_buffer_view and self.can_view() else self.e(LoadOps.CONTIGUOUS)
89
90
  if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti
90
91
  return ret
91
92
  self.base.forced_realize = True
@@ -96,9 +97,6 @@ class LazyBuffer:
96
97
  if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
97
98
  if self.is_unrealized_unmasked_const() and not bitcast:
98
99
  return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
99
- # TODO: applying this makes gpt2 slower
100
- if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
101
- return self.base.cast(dtype, bitcast)._view(self.st)
102
100
  new_shape = self.shape
103
101
  if bitcast and self.dtype.itemsize != dtype.itemsize:
104
102
  if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
@@ -106,7 +104,10 @@ class LazyBuffer:
106
104
  # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
107
105
  if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
108
106
  new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
109
- cast_op = UnaryOps.BITCAST if bitcast else UnaryOps.CAST
107
+ elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
108
+ # TODO: applying this makes gpt2 slower
109
+ return self.base.cast(dtype, bitcast)._view(self.st)
110
+ cast_op: Union[LoadOps, UnaryOps] = (LoadOps.VIEW if self.can_view() else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST
110
111
  return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
111
112
 
112
113
  def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)
@@ -145,24 +146,22 @@ class LazyBuffer:
145
146
  if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
146
147
  if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
147
148
 
148
- out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
149
+ out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else srcs[-1].dtype
149
150
 
150
151
  # const folding
151
152
  if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
152
153
  return self.cast(out_dtype).const(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
153
- if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG: return self.base.srcs[0]
154
- if op in BinaryOps: x, y = self, in_srcs[0]
155
- if op is BinaryOps.ADD:
156
- if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x # pylint: disable=possibly-used-before-assignment
157
- if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y # pylint: disable=possibly-used-before-assignment
158
- if op is BinaryOps.SUB and y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
159
- if op is BinaryOps.MUL:
160
- if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
161
- return y if val == 1 else y.const(0) if val == 0 else y.e(UnaryOps.NEG)
162
- if y.is_unrealized_unmasked_const() and (val := float(y.base.arg)) in (1, 0, -1):
163
- return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
164
- if op is BinaryOps.DIV and dtypes.is_float(x.dtype) and y.is_unrealized_unmasked_const() and y.base.arg != 0:
165
- return x.e(BinaryOps.MUL, x.const(1 / y.base.arg))
154
+ if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG and self.base.realized is None: return self.base.srcs[0]
155
+ if op in BinaryOps:
156
+ x, y = self, in_srcs[0]
157
+ if op is BinaryOps.ADD:
158
+ if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
159
+ if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
160
+ if op is BinaryOps.MUL:
161
+ if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
162
+ return y if val == 1 else y.const(0) if val == 0 else y.e(UnaryOps.NEG)
163
+ if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0, -1):
164
+ return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
166
165
 
167
166
  return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
168
167
 
@@ -170,7 +169,7 @@ class LazyBuffer:
170
169
 
171
170
  def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
172
171
  assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
173
- axis = tuple(x for x in axis if self.shape[x] != 1)
172
+ axis = tuple(sorted([x for x in axis if self.shape[x] != 1]))
174
173
  if len(axis) == 0: return self
175
174
  new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
176
175
  return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
@@ -181,7 +180,8 @@ class LazyBuffer:
181
180
  if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
182
181
 
183
182
  # const folding
184
- if self.is_unrealized_unmasked_const():
183
+ # TODO: fold this for symbolic?
184
+ if self.is_unrealized_unmasked_const() and all_int(self.shape):
185
185
  return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
186
186
 
187
187
  # TODO: can we split symbolic shape if the reduce axis is not symbolic?
tinygrad/multi.py CHANGED
@@ -45,7 +45,7 @@ def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:
45
45
  def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
46
46
  if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
47
47
  sz = round_up(lbs[0].shape[axis], len(lbs)) // len(lbs)
48
- return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
48
+ return [lb.shrink(tuple((0,s) if a != axis else (min(s,sz*i),min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
49
49
 
50
50
  class MultiLazyBuffer:
51
51
  def __init__(self, lbs:List[LazyBuffer], axis:Optional[int], real:Optional[List[bool]]=None):
@@ -54,7 +54,7 @@ class MultiLazyBuffer:
54
54
  self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
55
55
  if axis is not None:
56
56
  splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
57
- self.bounds = [(st,ed) for st,ed in zip(splits, splits[1:])]
57
+ self.bounds = list(zip(splits, splits[1:]))
58
58
 
59
59
  @property
60
60
  def shape(self):
@@ -73,15 +73,19 @@ class MultiLazyBuffer:
73
73
  def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
74
74
  lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unmasked_const() else lb] * len(devices)
75
75
  sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
76
- return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous() for lb in sharded_lbs], axis)
76
+ return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)
77
77
 
78
78
  def copy_to_device(self, device:str) -> LazyBuffer:
79
- if self.axis is None: return self.lbs[self.real.index(True)].copy_to_device(device)
80
- sz = self.lbs[0].shape[self.axis]
81
- llbs = []
82
- for i,lb in enumerate([lb.copy_to_device(device) for lb in self.real_lbs]):
83
- pad_arg = tuple((0,0) if a != self.axis else (sz*i, max(0, self.shape[self.axis]-sz*(i+1))) for a in range(len(lb.shape)))
84
- llbs.append(lb.pad(pad_arg))
79
+ if self.axis is None:
80
+ # if we already have a copy on the device, return that
81
+ for lb in self.real_lbs:
82
+ if lb.device == device: return lb
83
+ return self.lbs[self.real.index(True)].copy_to_device(device)
84
+ llbs:List[LazyBuffer] = []
85
+ for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
86
+ if not real: continue
87
+ pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape)))
88
+ llbs.append(lb.copy_to_device(device).pad(pad_arg))
85
89
  return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
86
90
 
87
91
  # passthroughs
tinygrad/nn/__init__.py CHANGED
@@ -2,7 +2,7 @@ import math
2
2
  from typing import Optional, Union, Tuple, cast
3
3
  from tinygrad.tensor import Tensor
4
4
  from tinygrad.helpers import prod
5
- from tinygrad.nn import optim, state # noqa: F401
5
+ from tinygrad.nn import optim, state, datasets # noqa: F401
6
6
 
7
7
  class BatchNorm2d:
8
8
  """
tinygrad/nn/datasets.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import gzip
2
- from tinygrad import Tensor, fetch
2
+ from tinygrad.tensor import Tensor
3
+ from tinygrad.helpers import fetch
3
4
 
4
5
  def _fetch_mnist(file, offset): return Tensor(gzip.open(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/"+file)).read()[offset:])
5
6
  def mnist():
tinygrad/nn/state.py CHANGED
@@ -1,9 +1,8 @@
1
1
  import os, json, pathlib, zipfile, pickle, tarfile, struct
2
- from tqdm import tqdm
3
2
  from typing import Dict, Union, List, Optional, Any, Tuple
4
3
  from tinygrad.tensor import Tensor
5
4
  from tinygrad.dtype import dtypes
6
- from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters
5
+ from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
7
6
  from tinygrad.shape.view import strides_for_shape
8
7
  from tinygrad.multi import MultiLazyBuffer
9
8
 
@@ -41,7 +40,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any
41
40
  Saves a state_dict to disk in a .safetensor file with optional metadata.
42
41
 
43
42
  ```python
44
- t = nn.Tensor([1, 2, 3])
43
+ t = Tensor([1, 2, 3])
45
44
  nn.state.safe_save({'t':t}, "test.safetensor")
46
45
  ```
47
46
  """
@@ -120,7 +119,7 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
120
119
  if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
121
120
  print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
122
121
  for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
123
- t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}")
122
+ t.desc = f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}: "
124
123
  if k not in state_dict and not strict:
125
124
  if DEBUG >= 1: print(f"WARNING: not loading {k}")
126
125
  continue
tinygrad/ops.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
  from typing import Union, Tuple, Any, List, Dict, Callable
3
- import functools, hashlib, math, operator, ctypes
3
+ import functools, hashlib, math, operator, ctypes, struct
4
4
  from enum import Enum, auto
5
5
  from dataclasses import dataclass
6
6
  from tinygrad.helpers import prod, dedup
@@ -14,10 +14,11 @@ from tinygrad.shape.shapetracker import ShapeTracker
14
14
  # NOTE: many GPUs don't have DIV, but UnaryOps.RECIP doesn't work for integer division
15
15
  class UnaryOps(Enum):
16
16
  """A -> A (elementwise)"""
17
- EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702
17
+ EXP2 = auto(); LOG2 = auto(); CAST = auto(); BITCAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto(); RECIP = auto() # noqa: E702
18
18
  class BinaryOps(Enum):
19
19
  """A + A -> A (elementwise)"""
20
- ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702
20
+ ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
21
+ SHR = auto(); SHL = auto() # noqa: E702
21
22
  class TernaryOps(Enum):
22
23
  """A + A + A -> A (elementwise)"""
23
24
  WHERE = auto(); MULACC = auto() # noqa: E702
@@ -30,7 +31,7 @@ class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS =
30
31
  Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
31
32
 
32
33
  # do not preserve f(0) = 0
33
- UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
34
+ UNSAFE_PAD_OPS = {UnaryOps.RECIP, UnaryOps.LOG2, UnaryOps.EXP2, BinaryOps.IDIV}
34
35
 
35
36
  @dataclass(frozen=True)
36
37
  class MemBuffer:
@@ -40,7 +41,7 @@ class MemBuffer:
40
41
 
41
42
  @dataclass(frozen=True)
42
43
  class ConstBuffer:
43
- val: ConstType
44
+ val: ConstType | Variable
44
45
  dtype: DType
45
46
  st: ShapeTracker
46
47
 
@@ -61,7 +62,7 @@ class LazyOp:
61
62
  def dtype(self) -> DType:
62
63
  if self.op in BufferOps: return self.arg.dtype
63
64
  if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg
64
- return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else self.src[-1].dtype
65
+ return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPNE} else self.src[-1].dtype
65
66
 
66
67
  @functools.cached_property
67
68
  def key(self) -> bytes:
@@ -73,8 +74,8 @@ class LazyOp:
73
74
  def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
74
75
  def vars(self) -> List[Variable]:
75
76
  extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
76
- const_vars = [x.arg.val.unbind()[0] for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
77
- return sorted(set.union(*extract_vars, set(const_vars)), key=lambda x: str(x.expr))
77
+ const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
78
+ return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)
78
79
 
79
80
  # **************** independent FlopCounter ****************
80
81
 
@@ -116,21 +117,53 @@ def hook_overflow(dv, fxn):
116
117
 
117
118
  python_alu = {
118
119
  UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
119
- UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
120
- UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
120
+ UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x),
121
+ UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan,
122
+ UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
123
+ UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
121
124
  UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
122
- BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
123
- BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt,
124
- BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0],
125
- BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
125
+ BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
126
+ BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
127
+ BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
128
+ BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
126
129
  TernaryOps.WHERE: lambda x,y,z: y if x else z}
127
130
 
131
+ def truncate_fp16(x):
132
+ try:
133
+ x = float(x)
134
+ struct.pack("@e", x)
135
+ return x
136
+ except OverflowError: return math.copysign(math.inf, x)
137
+
128
138
  truncate: Dict[DType, Callable] = {dtypes.bool: bool,
129
- # TODO: float16 and bfloat16?
130
- dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
139
+ # TODO: bfloat16
140
+ dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
131
141
  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
132
142
  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
133
143
  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
134
144
  dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
135
145
 
136
146
  def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
147
+
148
+ # the living definition of LazyOps
149
+ def verify_lazyop(*ast:LazyOp):
150
+ sts: Dict[LazyOp, ShapeTracker] = {}
151
+ def dfs(op:LazyOp, st:ShapeTracker):
152
+ if op in sts: return
153
+ for x in op.src: dfs(x, st)
154
+ # only reduceop is allowed to change shape, limited to turning n to 1
155
+ if op.op in ReduceOps:
156
+ expected_shape = tuple(1 if i in op.arg else s for i,s in enumerate(sts[op.src[0]].shape))
157
+ assert st.shape == expected_shape, f"unexpected reduceop shape {st.shape} != {expected_shape}"
158
+ st = ShapeTracker.from_shape(expected_shape)
159
+ else:
160
+ # movementops are pushed to the edges with LOAD
161
+ if op.op in BufferOps: st = op.arg.st
162
+ else: st = sts[op.src[0]]
163
+ for x in op.src: assert sts[x].shape == st.shape, f"found implicit movement op {x.op} {sts[x].shape} != {op.op} {st.shape}"
164
+ sts[op] = st
165
+ for i, out in enumerate(ast):
166
+ assert out.arg.idx == i, f"unexpected output buffer idx {out.arg.idx} != {i}"
167
+ assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
168
+ assert out.arg.st.size == ast[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
169
+ dfs(out, out.arg.st)
@@ -1,7 +1,7 @@
1
1
  from typing import Optional, List, Tuple, Dict
2
2
  import functools
3
3
  from dataclasses import dataclass
4
- from tinygrad.helpers import to_function_name
4
+ from tinygrad.helpers import getenv, to_function_name
5
5
  from tinygrad.codegen.uops import UOpGraph
6
6
  from tinygrad.shape.symbolic import sym_infer, sint, Variable
7
7
  from tinygrad.dtype import DType
@@ -52,10 +52,14 @@ class Renderer:
52
52
  supports_float4: bool = True
53
53
  has_local: bool = True
54
54
  has_shared: bool = True
55
- # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
56
- global_max: Optional[List[int]] = None
57
- local_max: Optional[List[int]] = None
55
+ # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
56
+ global_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
57
+ local_max: Optional[Tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
58
58
  shared_max: int = 32768
59
59
  tensor_cores: List[TensorCore] = []
60
+ @functools.cached_property
61
+ def tc_opt(self): return getenv("TC_OPT")
62
+ @functools.cached_property
63
+ def tc(self): return getenv("TC", 1)
60
64
 
61
65
  def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")