tinygrad 0.8.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/__init__.py +0 -0
  3. tinygrad/codegen/kernel.py +253 -225
  4. tinygrad/codegen/linearizer.py +398 -436
  5. tinygrad/codegen/uops.py +451 -0
  6. tinygrad/device.py +268 -274
  7. tinygrad/dtype.py +56 -40
  8. tinygrad/engine/__init__.py +0 -0
  9. tinygrad/engine/graph.py +100 -0
  10. tinygrad/engine/jit.py +198 -0
  11. tinygrad/engine/realize.py +192 -0
  12. tinygrad/engine/schedule.py +370 -0
  13. tinygrad/engine/search.py +199 -0
  14. tinygrad/{mlops.py → function.py} +40 -32
  15. tinygrad/helpers.py +144 -46
  16. tinygrad/lazy.py +143 -242
  17. tinygrad/multi.py +173 -0
  18. tinygrad/nn/__init__.py +180 -9
  19. tinygrad/nn/datasets.py +8 -0
  20. tinygrad/nn/optim.py +106 -28
  21. tinygrad/nn/state.py +87 -19
  22. tinygrad/ops.py +104 -45
  23. tinygrad/renderer/__init__.py +65 -0
  24. tinygrad/renderer/assembly.py +269 -0
  25. tinygrad/renderer/cstyle.py +308 -210
  26. tinygrad/renderer/llvmir.py +119 -124
  27. tinygrad/runtime/__init__.py +0 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +13403 -0
  29. tinygrad/runtime/autogen/comgr.py +891 -0
  30. tinygrad/runtime/autogen/cuda.py +5923 -0
  31. tinygrad/runtime/autogen/hip.py +5909 -0
  32. tinygrad/runtime/autogen/hsa.py +5893 -0
  33. tinygrad/runtime/autogen/io_uring.py +1486 -0
  34. tinygrad/runtime/autogen/kfd.py +812 -0
  35. tinygrad/runtime/autogen/nv_gpu.py +33597 -0
  36. tinygrad/runtime/autogen/opencl.py +1795 -0
  37. tinygrad/runtime/driver/__init__.py +0 -0
  38. tinygrad/runtime/driver/hip_comgr.py +56 -0
  39. tinygrad/runtime/graph/__init__.py +0 -0
  40. tinygrad/runtime/graph/clang.py +39 -0
  41. tinygrad/runtime/graph/cuda.py +59 -54
  42. tinygrad/runtime/graph/hcq.py +187 -0
  43. tinygrad/runtime/graph/metal.py +37 -41
  44. tinygrad/runtime/ops_amd.py +550 -0
  45. tinygrad/runtime/ops_clang.py +16 -14
  46. tinygrad/runtime/ops_cuda.py +129 -37
  47. tinygrad/runtime/ops_disk.py +111 -43
  48. tinygrad/runtime/ops_gpu.py +52 -50
  49. tinygrad/runtime/ops_llvm.py +36 -56
  50. tinygrad/runtime/ops_metal.py +41 -24
  51. tinygrad/runtime/ops_npy.py +9 -0
  52. tinygrad/runtime/ops_nv.py +625 -0
  53. tinygrad/runtime/ops_python.py +208 -0
  54. tinygrad/shape/__init__.py +0 -0
  55. tinygrad/shape/shapetracker.py +46 -107
  56. tinygrad/shape/symbolic.py +99 -98
  57. tinygrad/shape/view.py +162 -45
  58. tinygrad/tensor.py +2492 -483
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +1 -1
  60. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +31 -13
  61. tinygrad-0.9.1.dist-info/RECORD +63 -0
  62. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  63. tinygrad/features/image.py +0 -93
  64. tinygrad/features/multi.py +0 -103
  65. tinygrad/features/search.py +0 -160
  66. tinygrad/graph.py +0 -106
  67. tinygrad/jit.py +0 -152
  68. tinygrad/realize.py +0 -50
  69. tinygrad/runtime/graph/hip.py +0 -24
  70. tinygrad/runtime/ops_cpu.py +0 -45
  71. tinygrad/runtime/ops_hip.py +0 -97
  72. tinygrad/runtime/ops_torch.py +0 -49
  73. tinygrad-0.8.0.dist-info/RECORD +0 -41
  74. {tinygrad-0.8.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/device.py CHANGED
@@ -1,326 +1,320 @@
1
1
  from __future__ import annotations
2
- import numpy as np
2
+ import multiprocessing
3
+ from dataclasses import dataclass
3
4
  from collections import defaultdict
4
- from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable
5
- import importlib, inspect, functools, pathlib, time, re, ctypes
6
- from tinygrad.dtype import DType, dtypes, ImageDType
7
- from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
8
- from tinygrad.shape.symbolic import Variable, sym_infer, sint
9
- from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, GlobalCounters
10
-
11
- if TYPE_CHECKING:
12
- from tinygrad.codegen.linearizer import Linearizer
13
- from tinygrad.codegen.kernel import LinearizerOptions
5
+ from typing import List, Optional, Dict, Tuple, Any, cast
6
+ import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib
7
+ from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
8
+ from tinygrad.dtype import DType, ImageDType
9
+ from tinygrad.renderer import Renderer
14
10
 
15
11
  # **************** Device ****************
16
12
 
17
13
  class _Device:
18
14
  def __init__(self) -> None: self._devices: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] # noqa: E501
19
- def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT # noqa: E501
20
- def __getitem__(self, ix:str) -> Union[Interpreted, Compiled]: return self.__get_canonicalized_item(self.canonicalize(ix))
21
15
  @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
22
- def __get_canonicalized_item(self, ix:str) -> Union[Interpreted, Compiled]:
16
+ def _canonicalize(self, device:str) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # noqa: E501
17
+ # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
18
+ def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
19
+ def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
20
+ @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
21
+ def __get_canonicalized_item(self, ix:str) -> Compiled:
22
+ assert ((cpn:=multiprocessing.current_process().name) == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], \
23
+ f"can only open device {ix} from parent, not {cpn}"
23
24
  x = ix.split(":")[0].upper()
24
- ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0] # noqa: E501
25
- if isinstance(ret, type): ret = ret(ix)
25
+ ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
26
+ if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
26
27
  return ret
27
28
  @functools.cached_property
28
29
  def DEFAULT(self) -> str:
29
30
  device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore
30
31
  if device_from_env: return device_from_env
31
- for device in ["METAL", "CUDA", "HIP", "GPU"]:
32
+ for device in ["METAL", "AMD", "CUDA", "GPU", "CLANG", "LLVM"]:
32
33
  try:
33
- if self[device]: return device
34
+ if self[device]:
35
+ os.environ[device] = "1" # we set this in environment for spawned children
36
+ return device
34
37
  except Exception: pass
35
- return "CPU"
38
+ raise RuntimeError("no usable devices")
36
39
  Device = _Device()
37
40
 
38
- # **************** base Runner + helpers ****************
39
-
40
- class JITRunner:
41
- def __init__(self):
42
- self.op_estimate, self.mem_estimate = 0, 0
43
- def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
44
- var_vals = var_vals if var_vals is not None else {}
45
- from tinygrad.jit import CacheCollector
46
- et = self(rawbufs, var_vals)
47
- CacheCollector.add(self, rawbufs, var_vals)
48
- return et
49
- def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
50
- raise NotImplementedError("override this")
51
-
52
- def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str="", first_run=False): # noqa: E501
53
- if var_vals is None: var_vals = {}
54
- op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
55
- GlobalCounters.kernel_count += num_kernels
56
- GlobalCounters.global_ops += op_estimate
57
- GlobalCounters.global_mem += mem_estimate
58
- if et is not None: GlobalCounters.time_sum_s += et
59
- if DEBUG >= 2:
60
- ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
61
- print(f"{colored(f'*** {device[:7]:7s} {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else ('green' if first_run else None))} {name+' '*(37-ansilen(name))} arg {buf_count:3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
62
- (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) # noqa: E501
63
-
64
- # **************** Buffer / Allocator ****************
41
+ # **************** Buffer + Allocators ****************
42
+
43
+ @dataclass(frozen=True, eq=True)
44
+ class BufferOptions:
45
+ image: Optional[ImageDType] = None
46
+ uncached: bool = False
47
+ cpu_access: bool = False
48
+ host: bool = False
49
+ nolru: bool = False
65
50
 
66
51
  class Buffer:
67
- def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None):
52
+ def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
53
+ initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
68
54
  assert isinstance(dtype, DType)
69
- self.device, self.size, self.dtype = device, size, dtype
55
+ if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
56
+ self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
57
+ if base is None:
58
+ assert offset == 0, "base buffers can't have offset"
59
+ self._base = None
60
+ self._lb_refcount = lb_refcount
61
+ if opaque is not None: self.allocate(opaque)
62
+ if initial_value is not None:
63
+ self.allocate()
64
+ self.copyin(memoryview(initial_value))
65
+ else:
66
+ assert base._base is None, "base can't have a base"
67
+ assert device == base.device, "base must have the same device"
68
+ self._base = base
69
+ if preallocate: self.allocate()
70
+ @property
71
+ def base(self) -> Buffer: return self._base if self._base is not None else self
72
+ @property
73
+ def lb_refcount(self): return self.base._lb_refcount
74
+ def ref(self, cnt): self.base._lb_refcount += cnt
75
+ def is_allocated(self) -> bool: return hasattr(self, '_buf')
76
+ def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
77
+ def allocate(self, opaque=None) -> Buffer:
78
+ assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
70
79
  self.allocator = Device[self.device].allocator
71
- # TODO: image hack shouldn't be here. where should it be?
72
- self._buf = opaque if opaque is not None else self.allocator.alloc(dtype if isinstance(dtype, ImageDType) else size * dtype.itemsize)
73
- # TODO: mem_used for all devices
74
- if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.size * self.dtype.itemsize
80
+ if self._base is not None:
81
+ self._base.ensure_allocated()
82
+ assert hasattr(self.allocator, "offset"), "offset function required for view"
83
+ self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
84
+ else:
85
+ self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
86
+ if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
87
+ return self
88
+ def __reduce__(self):
89
+ buf = None
90
+ if self._base is not None:
91
+ return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
92
+ if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
93
+ if self.is_allocated():
94
+ buf = bytearray(self.nbytes)
95
+ self.copyout(memoryview(buf))
96
+ return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
97
+ @property
98
+ def nbytes(self): return self.size*self.dtype.itemsize
75
99
  def __del__(self):
76
- if not hasattr(self, '_buf'): return # happens when __init__ has raised exception
77
- if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.size * self.dtype.itemsize
78
- if isinstance(self.dtype, ImageDType): self.allocator.free(self._buf, self.dtype)
79
- else: self.allocator.free(self._buf, self.size * self.dtype.itemsize)
80
- def __repr__(self): return f"<buf device:{self.device} size:{self.size} dtype:{self.dtype}>"
100
+ if not hasattr(self, '_buf'): return
101
+ if self._base is None:
102
+ if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
103
+ self.allocator.free(self._buf, self.nbytes, self.options)
104
+ def __repr__(self):
105
+ return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
106
+ (f" offset:{self.offset}" if hasattr(self, "base") else "") + \
107
+ (">" if self.options is None else f" {self.options=}>")
108
+ def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
109
+ # zero copy with as_buffer (disabled by default due to use after free)
110
+ if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
111
+ assert not force_zero_copy, "force zero copy was passed, but copy is required"
112
+ return self.copyout(memoryview(bytearray(self.nbytes)))
81
113
  def copyin(self, mv:memoryview):
82
114
  mv = flat_mv(mv)
83
- assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
115
+ assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
116
+ assert self.is_allocated(), "can't copyin to unallocated buffer"
84
117
  self.allocator.copyin(self._buf, mv)
85
118
  return self
86
- @staticmethod
87
- def fromCPU(device:str, x:np.ndarray): return Buffer(device, x.size, dtypes.from_np(x.dtype)).copyin(x.data)
88
- def toCPU(self) -> np.ndarray:
89
- # zero copy with as_buffer
90
- if hasattr(self.allocator, 'as_buffer'):
91
- return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf})) # type: ignore
92
- ret = np.empty(self.size, self.dtype.np)
93
- if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
94
- return ret
95
-
96
- def _internal_buffer_copy(dest:Buffer, src:Buffer):
97
- if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator): # noqa: E721
98
- # fast path, used on HIP between GPUs
99
- # NOTE: it's important we use the dest device here to ensure the transfer is ready
100
- Device[src.device].synchronize() # TODO: async this
101
- dest.allocator.transfer(dest._buf, src._buf, dest.size*dest.dtype.itemsize)
102
- return
103
- if getenv("FROM_BUFFER") and hasattr(dest.allocator, 'from_buffer') and hasattr(dest.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
104
- # fast path, used on Metal in OS X Sonoma
105
- # NOTE: this is *only* faster if the pages from disk are already loaded into memory
106
- fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf))
107
- if fb:
108
- dest.allocator.transfer(dest._buf, fb, dest.size*dest.dtype.itemsize)
109
- return
110
- if hasattr(dest.allocator, 'copy_from_fd') and src.device.startswith("DISK") and src.size*src.dtype.itemsize >= 4096 and src._buf.ud.fd is not None:
111
- dest.allocator.copy_from_fd(dest._buf, src._buf.ud.fd, src._buf.offset, src.size*src.dtype.itemsize)
112
- elif hasattr(dest.allocator, 'as_buffer'):
113
- # fast(ish) path, uses readinto in diskbuffers
114
- src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
115
- elif hasattr(src.allocator, 'as_buffer'):
116
- dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
117
- else:
118
- # slow path, allocates a CPU buffer
119
- dest.copyin(src.toCPU().data)
120
-
121
- class _BufferCopy(JITRunner):
122
- # TODO: make wait work
123
- def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
124
- dest, src = rawbufs
125
- assert dest.size == src.size, f"buffer copy size mismatch, {dest.size} != {src.size}"
126
- assert dest.dtype == src.dtype, f"buffer copy dtype mismatch, {dest.dtype} != {src.dtype}"
127
- st = time.perf_counter()
128
- _internal_buffer_copy(dest, src)
129
- et = None
130
- if wait or DEBUG >= 2:
131
- Device[dest.device].synchronize()
132
- et = time.perf_counter() - st
133
- update_stats(colored(f"copy {dest.size:8d}, {dest.device[:7]:>7s} <- {src.device[:7]:7s}", "yellow"), 0, dest.size*dest.dtype.itemsize, {}, et, 2, jit, device=dest.device) # noqa: E501
134
- BufferCopy = _BufferCopy()
119
+ def copyout(self, mv:memoryview) -> memoryview:
120
+ mv = flat_mv(mv)
121
+ assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
122
+ assert self.is_allocated(), "can't copyout unallocated buffer"
123
+ self.allocator.copyout(mv, self._buf)
124
+ return mv
125
+ def view(self, size:int, dtype:DType, offset:int) -> Buffer:
126
+ assert offset < self.nbytes, "offset must be less than nbytes"
127
+ if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
128
+ return Buffer(self.device, size, dtype, base=self, offset=offset)
135
129
 
136
130
  # TODO: size, dest, src are the same type. can we enforce this?
137
- sz_type = Union[ImageDType, int]
138
131
  class Allocator:
139
- def alloc(self, size:sz_type):
132
+ def alloc(self, size:int, options:Optional[BufferOptions]=None):
140
133
  assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
141
- return self._alloc_image(size) if isinstance(size, ImageDType) else self._alloc(size)
142
- def _alloc(self, size:int): raise NotImplementedError("need alloc")
143
- def _alloc_image(self, dtype:ImageDType): raise RuntimeError("need alloc image")
144
- def free(self, opaque, size:sz_type): self._free(opaque) # if you are returning a Python object, you don't need a free
145
- def _free(self, opaque): pass
134
+ return self._alloc(size, options if options is not None else BufferOptions())
135
+ def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
136
+ def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
137
+ self._free(opaque, options if options is not None else BufferOptions())
138
+ def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
146
139
  def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
147
140
  def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
148
141
 
149
142
  class LRUAllocator(Allocator): # pylint: disable=abstract-method
150
- def __init__(self): self.cache: Dict[sz_type, Any] = defaultdict(list)
151
- def alloc(self, size:sz_type):
152
- if len(c := self.cache[size]): return c.pop()
153
- try:
154
- return super().alloc(size)
155
- except MemoryError:
143
+ def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
144
+ def alloc(self, size:int, options:Optional[BufferOptions]=None):
145
+ if len(c := self.cache[(size, options)]): return c.pop()
146
+ try: return super().alloc(size, options)
147
+ except (RuntimeError, MemoryError):
156
148
  self.free_cache()
157
- return super().alloc(size)
149
+ return super().alloc(size, options)
158
150
  def free_cache(self):
159
- for opaques in self.cache.values():
160
- for opaque in opaques: self._free(opaque)
151
+ for (sz,options),opaques in self.cache.items():
152
+ for opaque in opaques: super().free(opaque, sz, options)
161
153
  opaques.clear()
162
- def free(self, opaque:Any, size:sz_type):
163
- if getenv("LRU", 1): self.cache[size].append(opaque)
164
- else: self._free(opaque)
154
+ def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
155
+ if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
156
+ else: super().free(opaque, size, options)
165
157
 
166
158
  class _MallocAllocator(LRUAllocator):
167
- def _alloc(self, size:int): return (ctypes.c_uint8 * size)()
159
+ def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
168
160
  def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
169
161
  def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
170
162
  def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
171
- MallocAllocator = _MallocAllocator()
172
-
173
- # **************** for Interpreted Devices ****************
174
-
175
- class InterpretedASTRunner(JITRunner):
176
- def __init__(self, ast:LazyOp, fxn:Callable):
177
- super().__init__()
178
- self.fxn = fxn
179
- info = get_lazyop_info(ast)
180
- self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
181
-
182
- def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float:
183
- st = time.perf_counter()
184
- rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs[1:]], var_vals)
185
- et = time.perf_counter() - st
186
- update_stats(f"<interpreted {rawbufs[0].size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, device=rawbufs[0].device)
187
- return et
188
-
189
- class Interpreted:
190
- def __init__(self, allocator: Allocator, fxn_for_op:Dict[Op, Callable]):
191
- self.allocator, self.fxn_for_op = allocator, fxn_for_op
192
- self.synchronize, self.codegen, self.graph = lambda: None, None, None
193
-
194
- @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
195
- def get_runner(self, ast:LazyOp) -> InterpretedASTRunner: return _get_interpreted_fxn(self.fxn_for_op, ast)
196
-
197
- def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner:
198
- if DEBUG >= 3:
199
- from tinygrad.graph import print_tree
200
- print_tree(ast)
201
- tglob: Dict[str, Any] = {"Variable": Variable}
202
-
203
- @functools.lru_cache(None)
204
- def gstr(x:Any, nm=None) -> str:
205
- if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
206
- str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
207
- # TODO: (Variable - Variable) might create NumNode. can we remove it?
208
- return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
209
- ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
210
- tglob[ret] = x
211
- return ret
212
-
213
- lines: List[str] = []
214
- @functools.lru_cache(None)
215
- def _interpret_ast(ast:LazyOp) -> str:
216
- # TODO: shortcutted store won't work with strides
217
- if ast.op == BufferOps.STORE: return _interpret_ast(ast.src[0])
218
- if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM:
219
- if ast.src[0].op == BinaryOps.MUL: ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
220
- if (castop:=ast.src[0]).op == UnaryOps.CAST and (mulop:=castop.src[0]).op == BinaryOps.MUL:
221
- # MULACC with acc cast rewrite: MUL -> CAST -> SUM => CAST -> MULACC
222
- ast = LazyOp(TernaryOps.MULACC, tuple(LazyOp(UnaryOps.CAST, (s, ), castop.arg) for s in mulop.src), ast.arg)
223
-
224
- if ast.op in BufferOps:
225
- if ast.op == BufferOps.CONST: tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})"
226
- else: tmp = f"{gstr(fxn_for_op[UnaryOps.CAST], UnaryOps.CAST)}(inputs[{ast.arg.idx-1}], ({gstr(ast.arg.dtype)}, True))"
227
- for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
228
- else:
229
- tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})"
230
-
231
- ret = f"a{len(lines)}"
232
- lines.append(f" {ret} = {tmp}")
233
- return ret
163
+ def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])
234
164
 
235
- ret = _interpret_ast(ast)
236
- src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {ret}"])
237
- if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
238
- exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
239
- return InterpretedASTRunner(ast, tglob['run'])
165
+ MallocAllocator = _MallocAllocator()
240
166
 
241
167
  # **************** for Compiled Devices ****************
242
168
 
243
- class CompiledASTRunner(JITRunner):
244
- def __init__(self, ast:Optional[LazyOp], name:str, prg:str, lib:bytes, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None): # noqa: E501
245
- super().__init__()
246
- if DEBUG >= 4: print(prg)
247
- if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
248
- if local_size is not None: local_size = local_size + [1]*(3-len(local_size))
249
- self.name, self.display_name, self.prg, self.lib, self.global_size, self.local_size, self.first_run = \
250
- to_function_name(name), name, prg, lib, global_size, local_size, True
251
- self.vars: List[Variable] = []
252
- if ast:
253
- info = get_lazyop_info(ast)
254
- self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
255
- self.vars = ast.vars()
256
- assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
257
-
258
- def build(self, runtime):
259
- self.clprg = runtime(self.name, self.lib)
260
- return self
169
+ class CompileError(Exception): pass
261
170
 
262
- def launch_dims(self, var_vals):
263
- global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
264
- local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size
265
- return global_size, local_size
266
-
267
- def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False, do_update_stats=True) -> Optional[float]:
268
- global_size, local_size = self.launch_dims(var_vals)
269
- if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
270
- # TODO: this is copied from get_program
271
- from tinygrad.features.search import optimize_local_size
272
- local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
273
- global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
274
- lra = {}
275
- if global_size: lra['global_size'] = global_size
276
- if local_size: lra['local_size'] = local_size
277
- et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
278
- if do_update_stats: update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device, first_run=self.first_run) # noqa: E501
279
- self.first_run = False
280
- return et
171
+ class Compiler:
172
+ def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
173
+ def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
174
+ def compile_cached(self, src:str) -> bytes:
175
+ if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
176
+ assert not getenv("ASSERT_COMPILE"), f"tried to compile with ASSERT_COMPILE set\n{src}"
177
+ lib = self.compile(src)
178
+ if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
179
+ return lib
281
180
 
282
181
  class Compiled:
283
- def __init__(self, allocator:Allocator, linearizer_opts:LinearizerOptions, renderer, compiler, runtime, graph=None):
284
- self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph = allocator, linearizer_opts, renderer, compiler, runtime, graph # noqa: E501
182
+ def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
183
+ self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
184
+ self.renderer = renderer or Renderer()
285
185
  def synchronize(self): pass # override this in your device
286
186
 
287
- def to_program(self, k:Linearizer) -> CompiledASTRunner:
288
- assert self.compiler is not None, f"compiler is None, can't build {k.ast}"
289
- k.linearize()
290
- src = self.renderer(to_function_name(k.name), k.uops)
291
- if getenv("DISABLE_COMPILER_CACHE") or '<' in self.compiler.__name__:
292
- lib = self.compiler(src)
293
- else:
294
- lib = diskcache_get(self.compiler.__name__, src)
295
- if lib is None:
296
- lib = self.compiler(src)
297
- diskcache_put(self.compiler.__name__, src, lib)
298
- return CompiledASTRunner(k.ast, k.name, src, lib, k.global_size, k.local_size).build(self.runtime)
299
-
300
- def get_linearizer(self, ast:LazyOp) -> Linearizer:
301
- if DEBUG >= 3:
302
- from tinygrad.graph import print_tree
303
- print_tree(ast)
304
- from tinygrad.codegen.linearizer import Linearizer
305
- k = Linearizer(ast, self.linearizer_opts)
306
- k.required_optimizations()
307
- if not NOOPT:
308
- if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
309
- if BEAM >= 1:
310
- lins = [(("tc" if used_tensor_cores else "hc"), k)]
311
- if used_tensor_cores:
312
- lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
313
- lins[-1][1].hand_coded_optimizations()
314
- kb = Linearizer(ast, self.linearizer_opts)
315
- kb.required_optimizations()
316
- from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
317
- # TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions
318
- test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization
319
- lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
320
- timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
321
- if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
322
- k = timed[0][1]
323
- return k
324
-
325
- @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
326
- def get_runner(self, ast:LazyOp) -> CompiledASTRunner: return self.to_program(self.get_linearizer(ast))
187
+ # **************** for HCQ Compatible Devices ****************
188
+
189
@contextlib.contextmanager
def hcq_profile(dev, queue_type, enabled, desc):
  """
  Bracket the enclosed device work with a pair of GPU timestamps for profiling.

  When `enabled`, two signals are allocated from `dev`. A timestamp is written into the first before the body runs and
  into the second after it finishes. Each timestamp submission first waits for the current timeline value and then
  signals a new one. The stamps are therefore ordered after all previously submitted work, and waiting on the
  timeline's latest value also guarantees the stamps have been written. If PROFILE is set, the pair is appended to
  `dev.sig_prof_records` as (start_signal, end_signal, desc, is_copy_queue).

  Yields `(start_signal, end_signal)`, or `(None, None)` when disabled (nothing is submitted in that case).
  """
  st, en = (dev._get_signal(), dev._get_signal()) if enabled else (None, None)

  def _submit_timestamp(sig):
    # Order the stamp after everything already on the timeline, then publish a new timeline value that covers it.
    queue_type().wait(dev.timeline_signal, dev.timeline_value - 1) \
                .timestamp(sig) \
                .signal(dev.timeline_signal, dev.timeline_value).submit(dev)
    dev.timeline_value += 1

  if enabled: _submit_timestamp(st)
  try: yield (st, en)
  finally:
    if enabled: _submit_timestamp(en)
    if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
197
+
198
class HCQCompatCompiled(Compiled):
  """
  Base class for devices driven by hardware command queues (HCQ).

  Holds the compute/copy queue types, the timeline signal that orders submissions (plus a shadow signal swapped in
  when the timeline value is reset), and the profiling state. Subclasses provide the signal primitives below.
  """
  def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, comp_queue_t, copy_queue_t, timeline_signals):
    self.hw_compute_queue_t, self.hw_copy_queue_t = comp_queue_t, copy_queue_t
    self.timeline_value: int = 1
    self.timeline_signal, self._shadow_timeline_signal = timeline_signals
    # (start_signal, end_signal, name, is_copy) records whose timestamps have not been read back yet
    self.sig_prof_records: List[Tuple[Any, Any, str, bool]] = []
    # (start_ts, end_ts, name, is_copy) records whose raw GPU timestamps were already read back
    self.raw_prof_records: List[Tuple[int, int, str, bool]] = []
    if PROFILE: self._prof_setup()

    from tinygrad.runtime.graph.hcq import HCQGraph
    super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)

  @classmethod
  def _read_signal(cls, sig): raise NotImplementedError("need _read_signal") # reads a value for a signal

  @classmethod
  def _read_timestamp(cls, sig): raise NotImplementedError("need _read_timestamp") # reads a timestamp for a signal

  @classmethod
  def _set_signal(cls, sig, value): raise NotImplementedError("need _set_signal") # sets a value for a signal

  @classmethod
  def _get_signal(cls, value=0, **kwargs): raise NotImplementedError("need _get_signal") # allocates a new signal

  @classmethod
  def _wait_signal(cls, signal, value=0, timeout=10000): raise NotImplementedError("need _wait_signal") # waits for a signal value

  def _gpu2cpu_time(self, gpu_time, is_copy): raise NotImplementedError("need _gpu2cpu_time")

  def _prof_setup(self):
    """Create the profile logger, record matching CPU/GPU start times for both queues, and flush records at exit."""
    self.profile_logger = ProfileLogger()

    def _sync_queue(q_t):
      # Stamp and signal on the queue, wait for it, and pair the CPU time with the GPU timestamp.
      q_t().timestamp(self.timeline_signal).signal(self.timeline_signal, self.timeline_value).submit(self)
      self.timeline_value += 1
      cpu_start_time = time.perf_counter_ns() / 1e3
      self._wait_signal(self.timeline_signal, self.timeline_value - 1)
      return cpu_start_time, self._read_timestamp(self.timeline_signal)
    self.cpu_start_time, self.gpu_start_time = _sync_queue(self.hw_compute_queue_t)
    self.copy_cpu_start_time, self.copy_gpu_start_time = _sync_queue(self.hw_copy_queue_t)

    atexit.register(self._prof_finalize)

  def _prof_process_events(self):
    """Read back timestamps of pending signal records, then recycle their signals."""
    self.raw_prof_records += [(self._read_timestamp(st), self._read_timestamp(en), name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
    # signals_pool is expected to be provided by the concrete device subclass
    for st, en, _, _ in self.sig_prof_records: self.signals_pool += [st, en] # type: ignore
    self.sig_prof_records = []

  def _prof_finalize(self):
    """Flush all profiling records into the profile logger (registered with atexit)."""
    # Records still in sig_prof_records would otherwise be lost, e.g. the one a final copyout appends after its own
    # synchronize. Wait for outstanding work first so their timestamps are valid, then read them back.
    self.synchronize()
    self._prof_process_events()
    for st, en, name, is_cp in self.raw_prof_records:
      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, ["COMPUTE", "DMA"][is_cp])]
    del self.profile_logger

  def _wrap_timeline_signal(self):
    """Swap to the shadow timeline signal and restart the timeline at 1, resetting the allocator's bounce-buffer marks."""
    self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
    self._set_signal(self.timeline_signal, 0)
    cast(HCQCompatAllocator, self.allocator).b_timeline = [0] * len(cast(HCQCompatAllocator, self.allocator).b)
255
+
256
class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
  """
  LRU allocator for HCQ devices; host<->device copies are staged through a ring of host-visible bounce buffers.

  `b` holds `batch_cnt` host buffers of `batch_size` bytes. `b_timeline[i]` is the device timeline value that must be
  reached before `b[i]` may be reused, and `b_next` is the index of the most recently used buffer.
  """
  def __init__(self, device, batch_size=(2 << 20), batch_cnt=32):
    self.device = device
    self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
    self.b_timeline, self.b_next = [0] * len(self.b), 0
    super().__init__()

  def copyin(self, dest, src: memoryview):
    """Copy host memory `src` into device buffer `dest`, chunk by chunk through the bounce-buffer ring."""
    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
      for i in range(0, src.nbytes, self.b[0].size):
        self.b_next = (self.b_next + 1) % len(self.b)
        # don't overwrite a bounce buffer until the device copy that last read from it has completed
        self.device._wait_signal(self.device.timeline_signal, self.b_timeline[self.b_next])
        ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
                                     .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
        self.b_timeline[self.b_next] = self.device.timeline_value
        self.device.timeline_value += 1

  def copy_from_disk(self, dest, src, size):
    """Copy `size` bytes from disk-backed buffer `src` into device buffer `dest`, staging through the bounce buffers."""
    def _get_temp_buf():
      # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
      # Reserving sets its mark to 1 << 64 until the device copy that reads it is queued below.
      # Returns None when the buffer is still busy; how that is handled is up to _copyout_sharded.
      if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device._read_signal(self.device.timeline_signal):
        self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
        return (self.b[self.b_next].va_addr, self.b_next)
      return None

    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
      for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
                                     .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
        self.b_timeline[batch_info[1]] = self.device.timeline_value
        self.device.timeline_value += 1

  def copyout(self, dest:memoryview, src):
    """Copy device buffer `src` into host memory `dest`, one bounce-buffer-sized chunk at a time."""
    self.device.synchronize()

    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
      for i in range(0, dest.nbytes, self.b[0].size):
        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
                                     .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
        # block until this chunk has landed in b[0] before reading it on the host
        self.device._wait_signal(self.device.timeline_signal, self.device.timeline_value)
        self.device.timeline_value += 1

        ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)

  def transfer(self, dest, src, sz: int, src_dev, dest_dev):
    """Device-to-device copy of `sz` bytes from `src` (on `src_dev`) into `dest` (on `dest_dev`)."""
    src_dev._gpu_map(dest)

    # The copy runs on src_dev's copy queue, so the profiling timestamps must be placed on that queue as well;
    # stamping on self.device's queue would not bracket the work being measured.
    with hcq_profile(src_dev, src_dev.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
      src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
                               .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
                               .copy(dest.va_addr, src.va_addr, sz) \
                               .signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
      src_dev.timeline_value += 1

    if src_dev != dest_dev:
      # make dest_dev's timeline advance only after the copy signaled on src_dev's timeline
      dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
                                   .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
                                   .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
      dest_dev.timeline_value += 1

  def offset(self, buf, size:int, offset:int):
    """Return a same-typed view of `buf` covering `size` bytes starting `offset` bytes in."""
    return type(buf)(base=buf.base + offset, va_addr=buf.va_addr + offset, length=size, size=size)