tinygrad-0.9.0-py3-none-any.whl → tinygrad-0.9.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. tinygrad/codegen/__init__.py +0 -0
  2. tinygrad/codegen/kernel.py +78 -90
  3. tinygrad/codegen/linearizer.py +237 -169
  4. tinygrad/codegen/uops.py +278 -242
  5. tinygrad/device.py +147 -10
  6. tinygrad/dtype.py +7 -7
  7. tinygrad/engine/graph.py +16 -16
  8. tinygrad/engine/jit.py +39 -36
  9. tinygrad/engine/realize.py +6 -5
  10. tinygrad/engine/schedule.py +15 -7
  11. tinygrad/engine/search.py +6 -3
  12. tinygrad/function.py +17 -23
  13. tinygrad/helpers.py +77 -8
  14. tinygrad/lazy.py +26 -26
  15. tinygrad/multi.py +13 -9
  16. tinygrad/nn/__init__.py +1 -1
  17. tinygrad/nn/datasets.py +2 -1
  18. tinygrad/nn/state.py +3 -4
  19. tinygrad/ops.py +49 -16
  20. tinygrad/renderer/__init__.py +8 -4
  21. tinygrad/renderer/assembly.py +93 -100
  22. tinygrad/renderer/cstyle.py +47 -42
  23. tinygrad/renderer/llvmir.py +30 -30
  24. tinygrad/runtime/__init__.py +0 -0
  25. tinygrad/runtime/autogen/amd_gpu.py +11504 -1
  26. tinygrad/runtime/autogen/comgr.py +36 -10
  27. tinygrad/runtime/autogen/hsa.py +146 -14
  28. tinygrad/runtime/autogen/io_uring.py +1486 -0
  29. tinygrad/runtime/autogen/nv_gpu.py +269 -0
  30. tinygrad/runtime/driver/__init__.py +0 -0
  31. tinygrad/runtime/driver/hip_comgr.py +20 -11
  32. tinygrad/runtime/graph/__init__.py +0 -0
  33. tinygrad/runtime/graph/clang.py +3 -2
  34. tinygrad/runtime/graph/cuda.py +2 -2
  35. tinygrad/runtime/graph/hcq.py +122 -78
  36. tinygrad/runtime/ops_amd.py +302 -316
  37. tinygrad/runtime/ops_cuda.py +3 -3
  38. tinygrad/runtime/ops_disk.py +70 -5
  39. tinygrad/runtime/ops_gpu.py +2 -2
  40. tinygrad/runtime/ops_metal.py +5 -6
  41. tinygrad/runtime/ops_npy.py +1 -1
  42. tinygrad/runtime/ops_nv.py +161 -166
  43. tinygrad/runtime/ops_python.py +20 -16
  44. tinygrad/shape/__init__.py +0 -0
  45. tinygrad/shape/shapetracker.py +5 -2
  46. tinygrad/shape/symbolic.py +1 -3
  47. tinygrad/shape/view.py +34 -19
  48. tinygrad/tensor.py +219 -135
  49. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
  50. tinygrad-0.9.1.dist-info/RECORD +63 -0
  51. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  52. tinygrad/runtime/driver/hsa.py +0 -143
  53. tinygrad/runtime/graph/hsa.py +0 -171
  54. tinygrad/runtime/ops_hsa.py +0 -278
  55. tinygrad-0.9.0.dist-info/RECORD +0 -60
  56. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
  57. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/device.py CHANGED
@@ -2,9 +2,9 @@ from __future__ import annotations
  import multiprocessing
  from dataclasses import dataclass
  from collections import defaultdict
- from typing import List, Optional, Dict, Tuple, Any
- import importlib, inspect, functools, pathlib, os, ctypes
- from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
+ from typing import List, Optional, Dict, Tuple, Any, cast
+ import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib
+ from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
  from tinygrad.dtype import DType, ImageDType
  from tinygrad.renderer import Renderer

@@ -19,15 +19,17 @@ class _Device:
  def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
  @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
  def __get_canonicalized_item(self, ix:str) -> Compiled:
- if DEBUG >= 1: print(f"opening device {ix} from pid:{os.getpid()}")
- assert multiprocessing.current_process().name == "MainProcess" or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent"
+ assert ((cpn:=multiprocessing.current_process().name) == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], \
+ f"can only open device {ix} from parent, not {cpn}"
  x = ix.split(":")[0].upper()
- return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
+ ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
+ if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
+ return ret
  @functools.cached_property
  def DEFAULT(self) -> str:
  device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore
  if device_from_env: return device_from_env
- for device in ["METAL", "HSA", "CUDA", "GPU", "CLANG", "LLVM"]:
+ for device in ["METAL", "AMD", "CUDA", "GPU", "CLANG", "LLVM"]:
  try:
  if self[device]:
  os.environ[device] = "1" # we set this in environment for spawned children
@@ -171,13 +173,148 @@ class Compiler:
  def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
  def compile_cached(self, src:str) -> bytes:
  if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
- assert not getenv("ASSERT_COMPILE"), "tried to compile with ASSERT_COMPILE set"
+ assert not getenv("ASSERT_COMPILE"), f"tried to compile with ASSERT_COMPILE set\n{src}"
  lib = self.compile(src)
  if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
  return lib

  class Compiled:
  def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
- self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler if compiler else Compiler(), runtime, graph
- self.renderer = renderer if renderer else Renderer()
+ self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
+ self.renderer = renderer or Renderer()
  def synchronize(self): pass # override this in your device
+
+ # **************** for HCQ Compatible Devices ****************
+
+ @contextlib.contextmanager
+ def hcq_profile(dev, queue_type, enabled, desc):
+ st, en = (dev._get_signal(), dev._get_signal()) if enabled else (None, None)
+ if enabled: queue_type().timestamp(st).submit(dev)
+ try: yield (st, en)
+ finally:
+ if enabled: queue_type().timestamp(en).submit(dev)
+ if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
+
+ class HCQCompatCompiled(Compiled):
+ def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, comp_queue_t, copy_queue_t, timeline_signals):
+ self.hw_compute_queue_t, self.hw_copy_queue_t = comp_queue_t, copy_queue_t
+ self.timeline_value: int = 1
+ self.timeline_signal, self._shadow_timeline_signal = timeline_signals
+ self.sig_prof_records: List[Tuple[Any, Any, str, bool]] = []
+ self.raw_prof_records: List[Tuple[int, int, str, bool]] = []
+ if PROFILE: self._prof_setup()
+
+ from tinygrad.runtime.graph.hcq import HCQGraph
+ super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
+
+ @classmethod
+ def _read_signal(self, sig): raise NotImplementedError("need _read_signal") # reads a value for a signal
+
+ @classmethod
+ def _read_timestamp(self, sig): raise NotImplementedError("need _read_timestamp") # reads a timestamp for a signal
+
+ @classmethod
+ def _set_signal(self, sig, value): raise NotImplementedError("need _set_signal") # sets a value for a signal
+
+ @classmethod
+ def _get_signal(self, value=0, **kwargs): raise NotImplementedError("need _get_signal") # allocates a new signal
+
+ @classmethod
+ def _wait_signal(self, signal, value=0, timeout=10000): raise NotImplementedError("need _wait_signal") # waits for a signal value
+
+ def _gpu2cpu_time(self, gpu_time, is_copy): raise NotImplementedError("need _gpu2cpu_time")
+
+ def _prof_setup(self):
+ self.profile_logger = ProfileLogger()
+
+ def _sync_queue(q_t):
+ q_t().timestamp(self.timeline_signal).signal(self.timeline_signal, self.timeline_value).submit(self)
+ self.timeline_value += 1
+ cpu_start_time = time.perf_counter_ns() / 1e3
+ self._wait_signal(self.timeline_signal, self.timeline_value - 1)
+ return cpu_start_time, self._read_timestamp(self.timeline_signal)
+ self.cpu_start_time, self.gpu_start_time = _sync_queue(self.hw_compute_queue_t)
+ self.copy_cpu_start_time, self.copy_gpu_start_time = _sync_queue(self.hw_copy_queue_t)
+
+ atexit.register(self._prof_finalize)
+
+ def _prof_process_events(self):
+ self.raw_prof_records += [(self._read_timestamp(st), self._read_timestamp(en), name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
+ for st, en, _, _ in self.sig_prof_records: self.signals_pool += [st, en] # type: ignore
+ self.sig_prof_records = []
+
+ def _prof_finalize(self):
+ for st, en, name, is_cp in self.raw_prof_records:
+ self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, ["COMPUTE", "DMA"][is_cp])]
+ del self.profile_logger
+
+ def _wrap_timeline_signal(self):
+ self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
+ self._set_signal(self.timeline_signal, 0)
+ cast(HCQCompatAllocator, self.allocator).b_timeline = [0] * len(cast(HCQCompatAllocator, self.allocator).b)
+
+ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
+ def __init__(self, device, batch_size=(2 << 20), batch_cnt=32):
+ self.device = device
+ self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
+ self.b_timeline, self.b_next = [0] * len(self.b), 0
+ super().__init__()
+
+ def copyin(self, dest, src: memoryview):
+ with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
+ for i in range(0, src.nbytes, self.b[0].size):
+ self.b_next = (self.b_next + 1) % len(self.b)
+ self.device._wait_signal(self.device.timeline_signal, self.b_timeline[self.b_next])
+ ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
+ self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
+ .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
+ .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
+ self.b_timeline[self.b_next] = self.device.timeline_value
+ self.device.timeline_value += 1
+
+ def copy_from_disk(self, dest, src, size):
+ def _get_temp_buf():
+ # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
+ if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device._read_signal(self.device.timeline_signal):
+ self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
+ return (self.b[self.b_next].va_addr, self.b_next)
+ return None
+
+ with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
+ for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
+ self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
+ .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
+ .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
+ self.b_timeline[batch_info[1]] = self.device.timeline_value
+ self.device.timeline_value += 1
+
+ def copyout(self, dest:memoryview, src):
+ self.device.synchronize()
+
+ with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
+ for i in range(0, dest.nbytes, self.b[0].size):
+ self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
+ .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
+ .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
+ self.device._wait_signal(self.device.timeline_signal, self.device.timeline_value)
+ self.device.timeline_value += 1
+
+ ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
+
+ def transfer(self, dest, src, sz: int, src_dev, dest_dev):
+ src_dev._gpu_map(dest)
+
+ with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
+ src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
+ .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
+ .copy(dest.va_addr, src.va_addr, sz) \
+ .signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
+ src_dev.timeline_value += 1
+
+ if src_dev != dest_dev:
+ dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
+ .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
+ .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
+ dest_dev.timeline_value += 1
+
+ def offset(self, buf, size:int, offset:int): return type(buf)(base=buf.base + offset, va_addr=buf.va_addr + offset, length=size, size=size)
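The HCQ allocator added above streams every host/device copy through a small ring of pinned staging buffers, each tagged with the timeline value of the copy that last used it, so a buffer is never rewritten before its signal has passed. A minimal, self-contained sketch of that ring discipline in plain Python (BounceRing, submit_copy and wait_value are illustrative stand-ins, not tinygrad APIs):

# Illustrative sketch of the bounce-buffer ring behind HCQCompatAllocator.copyin.
# submit_copy()/wait_value() stand in for the hardware copy queue and signal primitives.
class BounceRing:
    def __init__(self, nbufs=4, bufsize=2 << 20):
        self.bufs = [bytearray(bufsize) for _ in range(nbufs)]
        self.buf_timeline = [0] * nbufs        # timeline value each buffer was last queued at
        self.next_buf, self.timeline_value, self.completed = 0, 1, 0

    def wait_value(self, value):               # host-side wait: the signal has reached `value`
        self.completed = max(self.completed, value)

    def submit_copy(self, buf_idx, dst_off, nbytes):
        pass                                   # would enqueue wait -> copy -> signal on the copy queue

    def copyin(self, src: bytes):
        step = len(self.bufs[0])
        for off in range(0, len(src), step):
            self.next_buf = (self.next_buf + 1) % len(self.bufs)
            # never overwrite a staging buffer until its previous copy has signaled
            self.wait_value(self.buf_timeline[self.next_buf])
            chunk = src[off:off + step]
            self.bufs[self.next_buf][:len(chunk)] = chunk
            self.submit_copy(self.next_buf, off, len(chunk))
            self.buf_timeline[self.next_buf] = self.timeline_value
            self.timeline_value += 1

The real implementation does the same bookkeeping with hardware signals and ctypes.memmove into pinned memory, and HCQGraph reuses the same timeline values to order whole captured graphs.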
tinygrad/dtype.py CHANGED
@@ -1,6 +1,5 @@
  from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
  from dataclasses import dataclass
- import numpy as np # TODO: remove numpy
  import functools
  from tinygrad.helpers import getenv

@@ -18,9 +17,6 @@ class DType:
  assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
  return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
  def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
- # TODO: someday this will be removed with the "remove numpy" project
- @property
- def np(self) -> Optional[type]: return np.dtype(self.fmt).type if self.fmt is not None else None

  # dependent typing?
  @dataclass(frozen=True, repr=False)
@@ -47,9 +43,13 @@ class dtypes:
  @staticmethod
  def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
  @staticmethod
- def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name]
- @staticmethod # NOTE: isinstance(True, int) is True in python
- def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
+ def from_py(x) -> DType:
+ if x.__class__ is float: return dtypes.default_float
+ if x.__class__ is int: return dtypes.default_int
+ if x.__class__ is bool: return dtypes.bool
+ # put this in the last is faster because there are more items than lists/tuples to check
+ if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
+ raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
  @staticmethod
  def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
  @staticmethod
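With numpy removed from tinygrad/dtype.py, dtypes.from_py now infers dtypes from plain Python values, including nested lists and tuples (an empty container falls back to the default float, and unsupported types raise). A quick usage check based on the branches above; the concrete dtypes depend on dtypes.default_float / dtypes.default_int in your environment:

from tinygrad.dtype import dtypes

assert dtypes.from_py(3.0) == dtypes.default_float
assert dtypes.from_py(3) == dtypes.default_int
assert dtypes.from_py(True) == dtypes.bool       # bools are not treated as ints here
# containers take the highest-priority element dtype; one float promotes the whole list
assert dtypes.from_py([1, 2, 3.5]) == dtypes.default_float
assert dtypes.from_py([]) == dtypes.default_float
try:
    dtypes.from_py("hello")                      # non-numeric types are rejected
except RuntimeError as e:
    print(e)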
tinygrad/engine/graph.py CHANGED
@@ -1,15 +1,14 @@
- import os, atexit, functools
+ import os, atexit, functools, contextlib
  from collections import defaultdict
- from typing import List, Any, DefaultDict
+ from typing import List, Any, DefaultDict, Union
  from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps, BufferOps, TernaryOps, LazyOp
  from tinygrad.device import Device
  from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, getenv
- from tinygrad.codegen.linearizer import UOps, UOp
+ from tinygrad.codegen.uops import UOps, UOp, UPat
  from tinygrad.shape.symbolic import NumNode
  from tinygrad.lazy import LazyBuffer

- try: import networkx as nx
- except ImportError: pass
+ with contextlib.suppress(ImportError): import networkx as nx

  # **** debugging and graphing ****

@@ -61,7 +60,7 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
  for idx,x in enumerate(lb.srcs):
  if nm(x) not in G.nodes: log_lazybuffer(x)
  if x.base.realized is None and x.base.op is LoadOps.CONST:
- label_append.append(f"\nCONST{idx} {x.base.arg}")
+ label_append.append(f"\nCONST{idx} {x.base.arg:g}")
  else:
  G.add_edge(nm(x), nm(lb), color='#a0a0a0')
  label = '"' + \
@@ -75,18 +74,19 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
  # realized but unseen?
  G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")

- def _tree(lazyop:LazyOp, cycles, cnt, prefix=""):
+ def _tree(dag:Union[LazyOp, UOp, UPat], cycles, cnt):
  cnt[0] += 1
- if len(lazyop.src) == 0: return [f"━━ {prefix}{lazyop.op.name} {lazyop.arg if lazyop.arg else ''}"]
- if (lid := id(lazyop)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
- return [f"━⬆︎ goto {cycles[id(lazyop)][0]}: {lazyop.op.name}"]
+ src = dag.src if isinstance(dag.src, (list, tuple)) else [] if dag.src is None else [dag.src]
+ if len(src) == 0: return [f"━━ {dag.op} {dag.arg}"]
+ if (lid := id(dag)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
+ return [f"━⬆︎ goto {cycles[id(dag)][0]}: {dag.op}"]
  cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
- lines = [f"━┳ {prefix}{lazyop.op.name} {lazyop.arg if lazyop.arg else ''}"]
- childs = [_tree(c, cycles, cnt) for c in lazyop.src[:]]
+ lines = [f"━┳ {dag.op} {dag.arg}"]
+ childs = [_tree(c, cycles, cnt) for c in src]
  for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
  return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]

- def print_tree(lazyop:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazyop, {}, [-1]))]))
+ def print_tree(dag:Union[LazyOp, UOp, UPat]): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(dag, {}, [-1]))]))

  def graph_uops(uops:List[UOp]):
  colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
@@ -94,7 +94,7 @@ def graph_uops(uops:List[UOp]):
  UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
  G = nx.DiGraph()
  for u in uops:
- if u.uop in {UOps.ENDRANGE, UOps.ENDIF}: continue
- G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) # noqa: E501
- for v in u.vin: G.add_edge(uops.index(v), uops.index(u))
+ if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
+ G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg).replace(':', '')) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
+ for v in u.src: G.add_edge(uops.index(v), uops.index(u))
  save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')
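_tree and print_tree above now accept anything exposing .op, .arg, and .src, so one pretty-printer serves LazyOp, UOp, and UPat DAGs. A toy, self-contained sketch of the box-drawing recursion on a hypothetical Node type (not tinygrad code, cycle tracking omitted), to show the shape of the traversal:

from dataclasses import dataclass
from typing import Any, Tuple

@dataclass(frozen=True)
class Node:                       # stand-in for LazyOp/UOp: anything with op/arg/src works
    op: str
    arg: Any = None
    src: Tuple["Node", ...] = ()

def tree(n: Node):
    if not n.src: return [f"━━ {n.op} {n.arg}"]
    lines = [f"━┳ {n.op} {n.arg}"]
    childs = [tree(c) for c in n.src]
    for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
    return lines + [" ┗" + childs[-1][0]] + [" " + l for l in childs[-1][1:]]

add = Node("ADD", src=(Node("CONST", 2), Node("MUL", src=(Node("CONST", 3), Node("CONST", 4)))))
print("\n".join(f"{i:3d} {s}" for i, s in enumerate(tree(add))))   # prints an indexed box-drawing tree of the DAG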
tinygrad/engine/jit.py CHANGED
@@ -3,7 +3,7 @@ from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, O
  import functools, itertools, collections
  from tinygrad.tensor import Tensor
  from tinygrad.lazy import LazyBuffer
- from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
+ from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, ContextVar, GRAPH, BEAM, getenv, all_int, GraphException, colored, JIT
  from tinygrad.device import Buffer, Compiled, Device
  from tinygrad.dtype import DType
  from tinygrad.shape.shapetracker import ShapeTracker
@@ -41,7 +41,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
  if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
  ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
  if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
- elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA", "NV", "AMD"}:
+ elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
  ji_graph_dev = Device[ji.bufs[0].device]

  graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore
@@ -82,7 +82,7 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
  if ji.prg.p.vars: self.jc_idx_with_updatable_var_vals.append(j)
  if (ji.prg.p.global_size and not all_int(ji.prg.p.global_size)) or (ji.prg.p.local_size and not all_int(ji.prg.p.local_size)):
  self.jc_idx_with_updatable_launch_dims.append(j)
- self.vars = list(var_vals.keys())
+ self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
  super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0], op_estimate, mem_estimate)

  class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
@@ -97,15 +97,16 @@ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
  wait_nodes = []

  for rawbuf in read + write:
- if id(rawbuf._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf._buf)])
+ if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
  for rawbuf in write:
- if id(rawbuf._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf._buf)))
+ if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))

- for rawbuf in read: self.r_dependency_map[id(rawbuf._buf)].append(new_dependency)
- for rawbuf in write: self.w_dependency_map[id(rawbuf._buf)] = new_dependency
+ for rawbuf in read: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
+ for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
  return list({id(x):x for x in wait_nodes}.values())

  ReturnType = TypeVar('ReturnType')
+ IN_JIT = ContextVar('IN_JIT', 0)
  class TinyJit(Generic[ReturnType]):
  def __init__(self, fxn:Callable[..., ReturnType]):
  self.fxn = fxn
@@ -134,25 +135,26 @@ class TinyJit(Generic[ReturnType]):

  def __call__(self, *args, **kwargs) -> ReturnType:
  input_tensors: List[Tuple[Union[int, str], Tensor]] = \
- [(cast(Union[int, str], k),v) for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor]
- if len(input_tensors): Tensor.realize(*[x[1] for x in input_tensors])
- lbs: List[LazyBuffer] = flatten([v.lazydata.lbs for _,v in input_tensors])
- expected_sts_var_dtype_device = [(*x.st.unbind(), x.dtype, x.device) for x in lbs]
- input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None]
- assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"
- var_vals: Dict[Variable, int] = merge_dicts([x[1] for x in expected_sts_var_dtype_device] + \
- [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
-
- expected_names, expected_lbs = [x[0] for x in input_tensors], [(x[0], tuple(x[1].keys()), x[2], x[3]) for x in expected_sts_var_dtype_device]
- if self.cnt == 0:
+ [(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
+ if input_tensors: Tensor.realize(*[t for _,t in input_tensors])
+ names: List[Union[int, str]] = [name for name,_ in input_tensors]
+ lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for _,t in input_tensors])
+ st_varvals_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
+ input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
+ assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
+ var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
+ [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
+ st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
+ if not JIT or self.cnt == 0:
+ if IN_JIT: raise RuntimeError("having TinyJit inside another TinyJit is not supported")
  # jit ignore
- with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
+ with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value, IN_JIT=1):
  self.ret = self.fxn(*args, **kwargs)
  if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
  elif self.cnt == 1:
  # jit capture
- self.expected_names: List[Union[int, str]] = expected_names
- self.expected_lbs: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = expected_lbs
+ self.expected_names: List[Union[int, str]] = names
+ self.expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = st_vars_dtype_device
  with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
  capturing.append(self)
  self.ret = self.fxn(*args, **kwargs)
@@ -160,31 +162,32 @@ class TinyJit(Generic[ReturnType]):
  capturing.clear()
  del self.buffer_replace
  assert len(self.jit_cache), "didn't JIT anything!"
- if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
+ if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_buffers)} inputs")

  # track inputs that are views of buffers
- for ji in self.jit_cache:
- for b in ji.bufs:
- if b is not None and b._base is not None and b._base in input_rawbuffers:
- input_rawbuffers.append(b)
- self.extra_view_inputs.append((input_rawbuffers.index(b.base), b.offset, b.device, b.size, b.dtype))
+ for item in self.jit_cache:
+ for b in item.bufs:
+ if b is not None and b._base is not None and b._base in input_buffers:
+ input_buffers.append(b)
+ self.extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))

  # memory planning (optional)
- assigned = _internal_memory_planner([cast(List[Buffer], x.bufs) for x in self.jit_cache], debug_prefix="JIT ")
- self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.bufs if x is not None]) for ei in self.jit_cache]
+ assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in self.jit_cache], debug_prefix="JIT ")
+ self.jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in self.jit_cache]

  # Condense the items into a graph executor.
- if JIT < 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals)
+ if JIT < 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals)

- self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
- if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_rawbuffers): print("WARNING: some input tensors not found")
+ self.input_replace = get_input_replace(self.jit_cache, input_buffers)
+ if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
  elif self.cnt >= 2:
  # jit exec
- assert self.expected_names == expected_names, f"args mismatch in JIT: {self.expected_names=} != {expected_names}"
- assert self.expected_lbs == expected_lbs, f"args mismatch in JIT: {self.expected_lbs=} != {expected_lbs=}"
+ assert self.expected_names == names, f"args mismatch in JIT: {self.expected_names=} != {names}"
+ assert self.expected_st_vars_dtype_device == st_vars_dtype_device, \
+ f"args mismatch in JIT: {self.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
  for idx, offset, device, size, dtype in self.extra_view_inputs:
- input_rawbuffers.append(Buffer(device, size, dtype, base=input_rawbuffers[idx], offset=offset).ensure_allocated())
- for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_rawbuffers[input_idx]
+ input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
+ for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_buffers[input_idx]
  if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
  for ei in self.jit_cache: ei.run(var_vals, jit=True)

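The new IN_JIT ContextVar makes nesting one TinyJit inside another an explicit error on the first (ignored) call instead of a silent mis-capture. A hedged usage sketch, assuming a working tinygrad device; the error text comes from the guard added above:

from tinygrad import Tensor, TinyJit

@TinyJit
def inner(x: Tensor) -> Tensor:
    return (x * 2).realize()

@TinyJit
def outer(x: Tensor) -> Tensor:
    return inner(x) + 1              # calls a jitted function from inside a jitted function

outer(Tensor.rand(4))                # expected to raise RuntimeError:
                                     # "having TinyJit inside another TinyJit is not supported"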
tinygrad/engine/realize.py CHANGED
@@ -1,7 +1,7 @@
  from typing import List, Dict, Optional, cast, Generator, Tuple
  import time
  from dataclasses import dataclass, replace
- from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int
+ from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING
  from tinygrad.ops import BufferOps, LoadOps, LazyOp
  from tinygrad.device import Device, Buffer
  from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -38,7 +38,7 @@ def get_linearizer(renderer:Renderer, ast:Tuple[LazyOp, ...]) -> Linearizer:
  if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
  # TODO: check the correctness inline once compare_linearizer is in core
  if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
- if DEBUG >= 4: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
+ if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
  return k

  # **************** Runners ****************
@@ -101,8 +101,9 @@ class BufferCopy(Runner):
  else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
  super().__init__(colored(name, "yellow"), dest_device, 0, total_sz)
  def copy(self, dest, src):
- if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_fd') and src.nbytes >= 4096 and hasattr(src.allocator.device, 'fd'):
- dest.allocator.copy_from_fd(dest._buf, src.allocator.device.fd, src._buf.offset, src.nbytes)
+ disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.device, 'io_uring') and hasattr(src.allocator.device, 'fd')
+ if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_disk') and disk_supports_fast_copyout and src.nbytes >= 4096:
+ dest.allocator.copy_from_disk(dest._buf, src._buf, src.nbytes)
  elif src.device.startswith("DISK") and hasattr(dest.allocator, 'as_buffer'):
  # fast(ish) path, uses readinto in diskbuffers
  src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
@@ -187,5 +188,5 @@ capturing: List = [] # put classes with an add method in here

  def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None, do_update_stats=True):
  for ei in lower_schedule(schedule):
- if len(capturing): capturing[0].add(ei)
+ if len(capturing) and CAPTURING: capturing[0].add(ei)
  ei.run(var_vals, do_update_stats=do_update_stats)
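BufferCopy.copy above now picks the DISK read path by capability probing: io_uring-backed disk devices with a file descriptor get the batched copy_from_disk route for transfers of at least 4096 bytes, otherwise it falls back to as_buffer/readinto. A condensed sketch of that dispatch order (the helper name pick_disk_copy_path is hypothetical, and the final generic fallback is not shown in the diff):

def pick_disk_copy_path(dest, src) -> str:
    # mirrors the capability checks in BufferCopy.copy for DISK sources
    fast = (src.device.startswith("DISK")
            and hasattr(src.allocator.device, "io_uring")
            and hasattr(src.allocator.device, "fd"))
    if fast and hasattr(dest.allocator, "copy_from_disk") and src.nbytes >= 4096:
        return "copy_from_disk"        # batched io_uring reads straight into staging buffers
    if src.device.startswith("DISK") and hasattr(dest.allocator, "as_buffer"):
        return "readinto"              # fast(ish) path, reads into the destination buffer
    return "generic"                   # plain copyout/copyin fallback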
tinygrad/engine/schedule.py CHANGED
@@ -1,12 +1,12 @@
  import sys, pickle, atexit
  from collections import defaultdict, deque
  from dataclasses import dataclass
- from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union
+ from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, get_args
  from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
  from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
- from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv
+ from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv
  from tinygrad.shape.symbolic import Variable
- from tinygrad.dtype import ImageDType, dtypes, DType
+ from tinygrad.dtype import ConstType, ImageDType, dtypes, DType
  from tinygrad.lazy import LazyBuffer
  from tinygrad.shape.shapetracker import ShapeTracker
  from tinygrad.device import Buffer
@@ -56,8 +56,13 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
  if buf.op is LoadOps.CONST:
  unbound_st, st_var_vals = st.simplify().unbind()
  var_vals.update(st_var_vals)
- if isinstance(buf.arg, Variable): var_vals.__setitem__(*buf.arg.unbind())
- return LazyOp(BufferOps.CONST, (), ConstBuffer(buf.arg, buf.dtype, unbound_st))
+ if isinstance(buf.arg, Variable):
+ val, var_val = buf.arg.unbind()
+ var_vals.__setitem__(val, var_val)
+ else:
+ assert isinstance(buf.arg, get_args(ConstType)), f"cannot create ConstBuffer with value {buf.arg}"
+ val = buf.arg
+ return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))

  # if we aren't fusing it, it's a load and we add it to the inputs
  if buf.realized is not None or (buf in realizes and buf not in outputs):
@@ -69,7 +74,8 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
  # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
  if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
  ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
- raise RuntimeError(f"must be contiguous for assign {unbound_st}")
+ raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+ +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
  return LazyOp(BufferOps.LOAD, (), MemBuffer(outputs.index(assign_targets[buf]), buf.dtype, unbound_st))
  if buf not in inputs: inputs.append(buf)
  return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.index(buf), buf.dtype, unbound_st))
@@ -138,6 +144,8 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
  pass # don't realize image to image casts. this is part of a larger problem
  else:
  realizes[buf.base] = None
+ # check all other pads for safe fusion
+ elif any(v.mask is not None for v in buf.st.views): simple_pads.add(buf.base)
  return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
  # base
  allbufs[buf] = None
@@ -308,7 +316,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
  if SAVE_SCHEDULE:
  def _save():
  print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
- pickle.dump(SCHEDULES, open(fp, "wb"))
+ with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
  if len(SCHEDULES) == 0: atexit.register(_save)
  SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
  # confirm everything was scheduled correctly
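The scheduler now rejects an augmented assign whose self operand is read through a non-contiguous view with an actionable message rather than the old terse one. Roughly what that looks like from user code (error wording from the diff; scheduling, and therefore the error, happens at realize time):

from tinygrad import Tensor

a = Tensor.rand(3, 3).realize()
try:
    a += a.T                     # reads a transposed (non-contiguous) view of the assign target
    a.realize()
except RuntimeError as e:
    print(e)                     # "self operand of augmented assign must be contiguous. help: consider using .contiguous(): ..."

b = Tensor.rand(3, 3).realize()
b += b.T.contiguous()            # the fix the error message suggests
b.realize()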
tinygrad/engine/search.py CHANGED
@@ -45,7 +45,7 @@ def _time_program(p:Program, lib:bytes, var_vals, rawbufs, early_stop=None, max_
  input_bufs = [rawbufs[i] for i,_ in car.p.globals]
  for _ in range(cnt):
  if clear_l2:
- with Context(DEBUG=0, BEAM=0, CACHECOLLECTING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
+ with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
  tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
  if early_stop is not None and early_stop < tms[-1]: break
  return tms
@@ -70,6 +70,9 @@ def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) ->
  ret = None
  except TimeoutException:
  ret = None
+ except Exception as e:
+ if getenv("BEAM_STRICT_MODE"): raise e
+ ret = None
  finally:
  signal.alarm(0)
  return x[0], ret
@@ -115,7 +118,7 @@ beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
  def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=True) -> Linearizer:
  global beam_pool
  key = {"ast": lin.ast[0].key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
- if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
+ if not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None:
  ret = lin.copy()
  for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
  return ret
@@ -123,7 +126,7 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T
  beam: List[Tuple[Linearizer, float]] = [(lin, float("inf"))]
  seen_libs = set()

- default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "HSA", "AMD", "NV"} else 0
+ default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV"} else 0
  if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
  beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
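Previously BEAM search silently skipped any kernel whose compilation blew up; the new except branch re-raises those failures when BEAM_STRICT_MODE is set. A hedged sketch of turning that on from Python (set the variables before importing tinygrad so its ContextVars and cached getenv calls pick them up):

import os
os.environ["BEAM"] = "2"                  # enable BEAM search for this run
os.environ["BEAM_STRICT_MODE"] = "1"      # new in 0.9.1: re-raise kernel compile errors instead of skipping them

from tinygrad import Tensor
(Tensor.rand(256, 256) @ Tensor.rand(256, 256)).realize()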
129
132