tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
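Note on the layout changes above: the JIT, scheduler, realize, and search code moved into the new tinygrad/engine package, mlops.py was renamed to function.py, and the interpreted CPU and TORCH backends (ops_cpu.py, ops_torch.py) were removed. A minimal sketch of what that means for downstream imports, assuming the package root keeps re-exporting the public names (tinygrad/__init__.py is only touched +6 -6 above); the step function is a made-up example:

from tinygrad import Tensor, TinyJit   # still works; TinyJit now lives in tinygrad/engine/jit.py
# from tinygrad.jit import TinyJit     # 0.8.0 path, removed in 0.9.0
# from tinygrad.features.search import beam_search   # 0.8.0 path, now tinygrad/engine/search.py

@TinyJit
def step(x: Tensor) -> Tensor: return (x * 2).realize()   # hypothetical jitted function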
tinygrad/device.py CHANGED
@@ -1,326 +1,183 @@
  from __future__ import annotations
- import numpy as np
+ import multiprocessing
+ from dataclasses import dataclass
  from collections import defaultdict
- from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable
- import importlib, inspect, functools, pathlib, time, re, ctypes
- from tinygrad.dtype import DType, dtypes, ImageDType
- from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
- from tinygrad.shape.symbolic import Variable, sym_infer, sint
- from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, UnaryOps, Op, GlobalCounters
-
- if TYPE_CHECKING:
-   from tinygrad.codegen.linearizer import Linearizer
-   from tinygrad.codegen.kernel import LinearizerOptions
+ from typing import List, Optional, Dict, Tuple, Any
+ import importlib, inspect, functools, pathlib, os, ctypes
+ from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
+ from tinygrad.dtype import DType, ImageDType
+ from tinygrad.renderer import Renderer

  # **************** Device ****************

  class _Device:
    def __init__(self) -> None: self._devices: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] # noqa: E501
-   def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT # noqa: E501
-   def __getitem__(self, ix:str) -> Union[Interpreted, Compiled]: return self.__get_canonicalized_item(self.canonicalize(ix))
    @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
-   def __get_canonicalized_item(self, ix:str) -> Union[Interpreted, Compiled]:
+   def _canonicalize(self, device:str) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # noqa: E501
+   # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
+   def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
+   def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
+   @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
+   def __get_canonicalized_item(self, ix:str) -> Compiled:
+     if DEBUG >= 1: print(f"opening device {ix} from pid:{os.getpid()}")
+     assert multiprocessing.current_process().name == "MainProcess" or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent"
      x = ix.split(":")[0].upper()
-     ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0] # noqa: E501
-     if isinstance(ret, type): ret = ret(ix)
-     return ret
+     return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
    @functools.cached_property
    def DEFAULT(self) -> str:
      device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore
      if device_from_env: return device_from_env
-     for device in ["METAL", "CUDA", "HIP", "GPU"]:
+     for device in ["METAL", "HSA", "CUDA", "GPU", "CLANG", "LLVM"]:
        try:
-         if self[device]: return device
+         if self[device]:
+           os.environ[device] = "1" # we set this in environment for spawned children
+           return device
        except Exception: pass
-     return "CPU"
+     raise RuntimeError("no usable devices")
  Device = _Device()

- # **************** base Runner + helpers ****************
-
- class JITRunner:
-   def __init__(self):
-     self.op_estimate, self.mem_estimate = 0, 0
-   def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
-     var_vals = var_vals if var_vals is not None else {}
-     from tinygrad.jit import CacheCollector
-     et = self(rawbufs, var_vals)
-     CacheCollector.add(self, rawbufs, var_vals)
-     return et
-   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
-     raise NotImplementedError("override this")
-
- def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str="", first_run=False): # noqa: E501
-   if var_vals is None: var_vals = {}
-   op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
-   GlobalCounters.kernel_count += num_kernels
-   GlobalCounters.global_ops += op_estimate
-   GlobalCounters.global_mem += mem_estimate
-   if et is not None: GlobalCounters.time_sum_s += et
-   if DEBUG >= 2:
-     ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
-     print(f"{colored(f'*** {device[:7]:7s} {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else ('green' if first_run else None))} {name+' '*(37-ansilen(name))} arg {buf_count:3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
-           (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) # noqa: E501
+ # **************** Buffer + Allocators ****************

- # **************** Buffer / Allocator ****************
+ @dataclass(frozen=True, eq=True)
+ class BufferOptions:
+   image: Optional[ImageDType] = None
+   uncached: bool = False
+   cpu_access: bool = False
+   host: bool = False
+   nolru: bool = False

  class Buffer:
-   def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None):
+   def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
+                initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
      assert isinstance(dtype, DType)
-     self.device, self.size, self.dtype = device, size, dtype
+     if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
+     self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
+     if base is None:
+       assert offset == 0, "base buffers can't have offset"
+       self._base = None
+       self._lb_refcount = lb_refcount
+       if opaque is not None: self.allocate(opaque)
+       if initial_value is not None:
+         self.allocate()
+         self.copyin(memoryview(initial_value))
+     else:
+       assert base._base is None, "base can't have a base"
+       assert device == base.device, "base must have the same device"
+       self._base = base
+     if preallocate: self.allocate()
+   @property
+   def base(self) -> Buffer: return self._base if self._base is not None else self
+   @property
+   def lb_refcount(self): return self.base._lb_refcount
+   def ref(self, cnt): self.base._lb_refcount += cnt
+   def is_allocated(self) -> bool: return hasattr(self, '_buf')
+   def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
+   def allocate(self, opaque=None) -> Buffer:
+     assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
      self.allocator = Device[self.device].allocator
-     # TODO: image hack shouldn't be here. where should it be?
-     self._buf = opaque if opaque is not None else self.allocator.alloc(dtype if isinstance(dtype, ImageDType) else size * dtype.itemsize)
-     # TODO: mem_used for all devices
-     if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.size * self.dtype.itemsize
+     if self._base is not None:
+       self._base.ensure_allocated()
+       assert hasattr(self.allocator, "offset"), "offset function required for view"
+       self._buf: Any = self.allocator.offset(self.base._buf, self.nbytes, self.offset)
+     else:
+       self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
+       if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
+     return self
+   def __reduce__(self):
+     buf = None
+     if self._base is not None:
+       return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
+     if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
+     if self.is_allocated():
+       buf = bytearray(self.nbytes)
+       self.copyout(memoryview(buf))
+     return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
+   @property
+   def nbytes(self): return self.size*self.dtype.itemsize
    def __del__(self):
-     if not hasattr(self, '_buf'): return # happens when __init__ has raised exception
-     if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.size * self.dtype.itemsize
-     if isinstance(self.dtype, ImageDType): self.allocator.free(self._buf, self.dtype)
-     else: self.allocator.free(self._buf, self.size * self.dtype.itemsize)
-   def __repr__(self): return f"<buf device:{self.device} size:{self.size} dtype:{self.dtype}>"
+     if not hasattr(self, '_buf'): return
+     if self._base is None:
+       if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
+       self.allocator.free(self._buf, self.nbytes, self.options)
+   def __repr__(self):
+     return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
+       (f" offset:{self.offset}" if hasattr(self, "base") else "") + \
+       (">" if self.options is None else f" {self.options=}>")
+   def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
+     # zero copy with as_buffer (disabled by default due to use after free)
+     if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
+     assert not force_zero_copy, "force zero copy was passed, but copy is required"
+     return self.copyout(memoryview(bytearray(self.nbytes)))
    def copyin(self, mv:memoryview):
      mv = flat_mv(mv)
-     assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
+     assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
+     assert self.is_allocated(), "can't copyin to unallocated buffer"
      self.allocator.copyin(self._buf, mv)
      return self
-   @staticmethod
-   def fromCPU(device:str, x:np.ndarray): return Buffer(device, x.size, dtypes.from_np(x.dtype)).copyin(x.data)
-   def toCPU(self) -> np.ndarray:
-     # zero copy with as_buffer
-     if hasattr(self.allocator, 'as_buffer'):
-       return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf})) # type: ignore
-     ret = np.empty(self.size, self.dtype.np)
-     if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
-     return ret
-
- def _internal_buffer_copy(dest:Buffer, src:Buffer):
-   if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator): # noqa: E721
-     # fast path, used on HIP between GPUs
-     # NOTE: it's important we use the dest device here to ensure the transfer is ready
-     Device[src.device].synchronize() # TODO: async this
-     dest.allocator.transfer(dest._buf, src._buf, dest.size*dest.dtype.itemsize)
-     return
-   if getenv("FROM_BUFFER") and hasattr(dest.allocator, 'from_buffer') and hasattr(dest.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
-     # fast path, used on Metal in OS X Sonoma
-     # NOTE: this is *only* faster if the pages from disk are already loaded into memory
-     fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf))
-     if fb:
-       dest.allocator.transfer(dest._buf, fb, dest.size*dest.dtype.itemsize)
-       return
-   if hasattr(dest.allocator, 'copy_from_fd') and src.device.startswith("DISK") and src.size*src.dtype.itemsize >= 4096 and src._buf.ud.fd is not None:
-     dest.allocator.copy_from_fd(dest._buf, src._buf.ud.fd, src._buf.offset, src.size*src.dtype.itemsize)
-   elif hasattr(dest.allocator, 'as_buffer'):
-     # fast(ish) path, uses readinto in diskbuffers
-     src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
-   elif hasattr(src.allocator, 'as_buffer'):
-     dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
-   else:
-     # slow path, allocates a CPU buffer
-     dest.copyin(src.toCPU().data)
-
- class _BufferCopy(JITRunner):
-   # TODO: make wait work
-   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
-     dest, src = rawbufs
-     assert dest.size == src.size, f"buffer copy size mismatch, {dest.size} != {src.size}"
-     assert dest.dtype == src.dtype, f"buffer copy dtype mismatch, {dest.dtype} != {src.dtype}"
-     st = time.perf_counter()
-     _internal_buffer_copy(dest, src)
-     et = None
-     if wait or DEBUG >= 2:
-       Device[dest.device].synchronize()
-       et = time.perf_counter() - st
-     update_stats(colored(f"copy {dest.size:8d}, {dest.device[:7]:>7s} <- {src.device[:7]:7s}", "yellow"), 0, dest.size*dest.dtype.itemsize, {}, et, 2, jit, device=dest.device) # noqa: E501
- BufferCopy = _BufferCopy()
+   def copyout(self, mv:memoryview) -> memoryview:
+     mv = flat_mv(mv)
+     assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
+     assert self.is_allocated(), "can't copyout unallocated buffer"
+     self.allocator.copyout(mv, self._buf)
+     return mv
+   def view(self, size:int, dtype:DType, offset:int) -> Buffer:
+     assert offset < self.nbytes, "offset must be less than nbytes"
+     if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
+     return Buffer(self.device, size, dtype, base=self, offset=offset)

  # TODO: size, dest, src are the same type. can we enforce this?
- sz_type = Union[ImageDType, int]
  class Allocator:
-   def alloc(self, size:sz_type):
+   def alloc(self, size:int, options:Optional[BufferOptions]=None):
      assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
-     return self._alloc_image(size) if isinstance(size, ImageDType) else self._alloc(size)
-   def _alloc(self, size:int): raise NotImplementedError("need alloc")
-   def _alloc_image(self, dtype:ImageDType): raise RuntimeError("need alloc image")
-   def free(self, opaque, size:sz_type): self._free(opaque) # if you are returning a Python object, you don't need a free
-   def _free(self, opaque): pass
+     return self._alloc(size, options if options is not None else BufferOptions())
+   def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
+   def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
+     self._free(opaque, options if options is not None else BufferOptions())
+   def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
    def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
    def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")

  class LRUAllocator(Allocator): # pylint: disable=abstract-method
-   def __init__(self): self.cache: Dict[sz_type, Any] = defaultdict(list)
-   def alloc(self, size:sz_type):
-     if len(c := self.cache[size]): return c.pop()
-     try:
-       return super().alloc(size)
-     except MemoryError:
+   def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
+   def alloc(self, size:int, options:Optional[BufferOptions]=None):
+     if len(c := self.cache[(size, options)]): return c.pop()
+     try: return super().alloc(size, options)
+     except (RuntimeError, MemoryError):
        self.free_cache()
-       return super().alloc(size)
+       return super().alloc(size, options)
    def free_cache(self):
-     for opaques in self.cache.values():
-       for opaque in opaques: self._free(opaque)
+     for (sz,options),opaques in self.cache.items():
+       for opaque in opaques: super().free(opaque, sz, options)
        opaques.clear()
-   def free(self, opaque:Any, size:sz_type):
-     if getenv("LRU", 1): self.cache[size].append(opaque)
-     else: self._free(opaque)
+   def free(self, opaque:Any, size:int, options:Optional[BufferOptions]=None):
+     if getenv("LRU", 1) and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
+     else: super().free(opaque, size, options)

  class _MallocAllocator(LRUAllocator):
-   def _alloc(self, size:int): return (ctypes.c_uint8 * size)()
+   def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
    def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
    def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
    def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
- MallocAllocator = _MallocAllocator()
-
- # **************** for Interpreted Devices ****************
-
- class InterpretedASTRunner(JITRunner):
-   def __init__(self, ast:LazyOp, fxn:Callable):
-     super().__init__()
-     self.fxn = fxn
-     info = get_lazyop_info(ast)
-     self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
-
-   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float:
-     st = time.perf_counter()
-     rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs[1:]], var_vals)
-     et = time.perf_counter() - st
-     update_stats(f"<interpreted {rawbufs[0].size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, device=rawbufs[0].device)
-     return et
-
- class Interpreted:
-   def __init__(self, allocator: Allocator, fxn_for_op:Dict[Op, Callable]):
-     self.allocator, self.fxn_for_op = allocator, fxn_for_op
-     self.synchronize, self.codegen, self.graph = lambda: None, None, None
+   def offset(self, buf, size:int, offset:int): return from_mv(self.as_buffer(buf)[offset:offset+size])

-   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
-   def get_runner(self, ast:LazyOp) -> InterpretedASTRunner: return _get_interpreted_fxn(self.fxn_for_op, ast)
-
- def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner:
-   if DEBUG >= 3:
-     from tinygrad.graph import print_tree
-     print_tree(ast)
-   tglob: Dict[str, Any] = {"Variable": Variable}
-
-   @functools.lru_cache(None)
-   def gstr(x:Any, nm=None) -> str:
-     if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
-       str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
-       # TODO: (Variable - Variable) might create NumNode. can we remove it?
-       return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
-     ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
-     tglob[ret] = x
-     return ret
-
-   lines: List[str] = []
-   @functools.lru_cache(None)
-   def _interpret_ast(ast:LazyOp) -> str:
-     # TODO: shortcutted store won't work with strides
-     if ast.op == BufferOps.STORE: return _interpret_ast(ast.src[0])
-     if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM:
-       if ast.src[0].op == BinaryOps.MUL: ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
-       if (castop:=ast.src[0]).op == UnaryOps.CAST and (mulop:=castop.src[0]).op == BinaryOps.MUL:
-         # MULACC with acc cast rewrite: MUL -> CAST -> SUM => CAST -> MULACC
-         ast = LazyOp(TernaryOps.MULACC, tuple(LazyOp(UnaryOps.CAST, (s, ), castop.arg) for s in mulop.src), ast.arg)
-
-     if ast.op in BufferOps:
-       if ast.op == BufferOps.CONST: tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})"
-       else: tmp = f"{gstr(fxn_for_op[UnaryOps.CAST], UnaryOps.CAST)}(inputs[{ast.arg.idx-1}], ({gstr(ast.arg.dtype)}, True))"
-       for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
-     else:
-       tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})"
-
-     ret = f"a{len(lines)}"
-     lines.append(f" {ret} = {tmp}")
-     return ret
-
-   ret = _interpret_ast(ast)
-   src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {ret}"])
-   if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
-   exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
-   return InterpretedASTRunner(ast, tglob['run'])
+ MallocAllocator = _MallocAllocator()

  # **************** for Compiled Devices ****************

- class CompiledASTRunner(JITRunner):
-   def __init__(self, ast:Optional[LazyOp], name:str, prg:str, lib:bytes, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None): # noqa: E501
-     super().__init__()
-     if DEBUG >= 4: print(prg)
-     if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
-     if local_size is not None: local_size = local_size + [1]*(3-len(local_size))
-     self.name, self.display_name, self.prg, self.lib, self.global_size, self.local_size, self.first_run = \
-       to_function_name(name), name, prg, lib, global_size, local_size, True
-     self.vars: List[Variable] = []
-     if ast:
-       info = get_lazyop_info(ast)
-       self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
-       self.vars = ast.vars()
-       assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
+ class CompileError(Exception): pass

-   def build(self, runtime):
-     self.clprg = runtime(self.name, self.lib)
-     return self
-
-   def launch_dims(self, var_vals):
-     global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
-     local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size
-     return global_size, local_size
-
-   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False, do_update_stats=True) -> Optional[float]:
-     global_size, local_size = self.launch_dims(var_vals)
-     if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
-       # TODO: this is copied from get_program
-       from tinygrad.features.search import optimize_local_size
-       local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
-       global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
-     lra = {}
-     if global_size: lra['global_size'] = global_size
-     if local_size: lra['local_size'] = local_size
-     et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
-     if do_update_stats: update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device, first_run=self.first_run) # noqa: E501
-     self.first_run = False
-     return et
+ class Compiler:
+   def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
+   def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
+   def compile_cached(self, src:str) -> bytes:
+     if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
+       assert not getenv("ASSERT_COMPILE"), "tried to compile with ASSERT_COMPILE set"
+       lib = self.compile(src)
+       if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
+     return lib

  class Compiled:
-   def __init__(self, allocator:Allocator, linearizer_opts:LinearizerOptions, renderer, compiler, runtime, graph=None):
-     self.allocator, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.graph = allocator, linearizer_opts, renderer, compiler, runtime, graph # noqa: E501
+   def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
+     self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler if compiler else Compiler(), runtime, graph
+     self.renderer = renderer if renderer else Renderer()
    def synchronize(self): pass # override this in your device
-
-   def to_program(self, k:Linearizer) -> CompiledASTRunner:
-     assert self.compiler is not None, f"compiler is None, can't build {k.ast}"
-     k.linearize()
-     src = self.renderer(to_function_name(k.name), k.uops)
-     if getenv("DISABLE_COMPILER_CACHE") or '<' in self.compiler.__name__:
-       lib = self.compiler(src)
-     else:
-       lib = diskcache_get(self.compiler.__name__, src)
-       if lib is None:
-         lib = self.compiler(src)
-         diskcache_put(self.compiler.__name__, src, lib)
-     return CompiledASTRunner(k.ast, k.name, src, lib, k.global_size, k.local_size).build(self.runtime)
-
-   def get_linearizer(self, ast:LazyOp) -> Linearizer:
-     if DEBUG >= 3:
-       from tinygrad.graph import print_tree
-       print_tree(ast)
-     from tinygrad.codegen.linearizer import Linearizer
-     k = Linearizer(ast, self.linearizer_opts)
-     k.required_optimizations()
-     if not NOOPT:
-       if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
-       if BEAM >= 1:
-         lins = [(("tc" if used_tensor_cores else "hc"), k)]
-         if used_tensor_cores:
-           lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
-           lins[-1][1].hand_coded_optimizations()
-         kb = Linearizer(ast, self.linearizer_opts)
-         kb.required_optimizations()
-         from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
-         # TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions
-         test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization
-         lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
-         timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
-         if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
-         k = timed[0][1]
-     return k
-
-   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
-   def get_runner(self, ast:LazyOp) -> CompiledASTRunner: return self.to_program(self.get_linearizer(ast))
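The rewrite above drops the Interpreted/JITRunner machinery (kernel execution and stats now live under tinygrad/engine/) and leaves device.py with a Buffer/Allocator/Compiler split. A hedged sketch of the new Buffer API as it reads in this diff, assuming the CLANG backend is usable on the host:

from tinygrad.device import Buffer
from tinygrad.dtype import dtypes

# allocation is now a separate step unless opaque, initial_value, or preallocate is passed
buf = Buffer("CLANG", 4, dtypes.int32).allocate()
buf.copyin(memoryview(bytearray(b"\x01\x00\x00\x00" * 4)))    # length must equal buf.nbytes (16)
out = buf.copyout(memoryview(bytearray(buf.nbytes)))          # explicit copyout replaces the old numpy-based toCPU()
view = buf.view(2, dtypes.int32, offset=8).ensure_allocated() # zero-copy view; needs the allocator to implement offset()
assert bytes(view.as_buffer()) == bytes(out[8:16])

Backends now subclass Compiler rather than passing a bare compile function, which is what routes compilation through the disk cache; a sketch with a hypothetical subclass and cache key:

from tinygrad.device import Compiler

class MyCompiler(Compiler):                                   # hypothetical, for illustration only
  def __init__(self): super().__init__(cachekey="compile_mybackend")
  def compile(self, src:str) -> bytes: return src.encode()    # a real backend returns machine code here

lib = MyCompiler().compile_cached("kernel source")            # a second call with the same src hits the disk cache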
tinygrad/dtype.py CHANGED
@@ -1,39 +1,43 @@
- from typing import NamedTuple, Final, Optional, ClassVar, Set, Tuple, Dict
+ from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
+ from dataclasses import dataclass
  import numpy as np # TODO: remove numpy
  import functools
+ from tinygrad.helpers import getenv

- # TODO: migrate this from NamedTuple -> dataclass
- class DType(NamedTuple):
+ ConstType = Union[float, int, bool]
+
+ @dataclass(frozen=True, order=True)
+ class DType:
    priority: int # this determines when things get upcasted
    itemsize: int
    name: str
-   np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
-   sz: int = 1
-   def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"
+   fmt: Optional[str]
+   count: int
+   def __repr__(self): return f"dtypes.{'_'*(c:=self.count!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.count)*c}"
    def vec(self, sz:int):
-     assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
-     return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self]}{sz}", None, sz)
-   def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self
+     assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
+     return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
+   def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
+   # TODO: someday this will be removed with the "remove numpy" project
+   @property
+   def np(self) -> Optional[type]: return np.dtype(self.fmt).type if self.fmt is not None else None

  # dependent typing?
+ @dataclass(frozen=True, repr=False)
  class ImageDType(DType):
-   def __new__(cls, priority, itemsize, name, np, shape, base):
-     return super().__new__(cls, priority, itemsize, name, np)
-   def __init__(self, priority, itemsize, name, np, shape, base):
-     self.shape: Tuple[int, ...] = shape # arbitrary arg for the dtype, used in image for the shape
-     self.base: DType = base
-     super().__init__()
+   shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
+   base: DType
    def scalar(self): return self.base
    def vec(self, sz:int): return self.base.vec(sz)
    def __repr__(self): return f"dtypes.{self.name}({self.shape})"
-   # TODO: fix this to not need these
-   def __hash__(self): return hash((super().__hash__(), self.shape))
-   def __eq__(self, x): return super().__eq__(x) and self.shape == x.shape
-   def __ne__(self, x): return super().__ne__(x) or self.shape != x.shape

+ # @dataclass(frozen=True, init=False, repr=False, eq=False)
  class PtrDType(DType):
-   def __new__(cls, dt:DType): return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz)
+   def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
    def __repr__(self): return f"ptr.{super().__repr__()}"
+   def __hash__(self): return super().__hash__()
+   def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
+   def __ne__(self, dt): return not (self == dt)

  class dtypes:
    @staticmethod
@@ -43,25 +47,27 @@ class dtypes:
    @staticmethod
    def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
    @staticmethod
-   def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
+   def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name]
    @staticmethod # NOTE: isinstance(True, int) is True in python
    def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
    @staticmethod
+   def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
+   @staticmethod
    def fields() -> Dict[str, DType]: return DTYPES_DICT
-   bool: Final[DType] = DType(0, 1, "bool", np.bool_)
-   int8: Final[DType] = DType(1, 1, "char", np.int8)
-   uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
-   int16: Final[DType] = DType(3, 2, "short", np.int16)
-   uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
-   int32: Final[DType] = DType(5, 4, "int", np.int32)
-   uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
-   int64: Final[DType] = DType(7, 8, "long", np.int64)
-   uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)
-   float16: Final[DType] = DType(9, 2, "half", np.float16)
+   bool: Final[DType] = DType(0, 1, "bool", '?', 1)
+   int8: Final[DType] = DType(1, 1, "char", 'b', 1)
+   uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
+   int16: Final[DType] = DType(3, 2, "short", 'h', 1)
+   uint16: Final[DType] = DType(4, 2, "unsigned short", 'H', 1)
+   int32: Final[DType] = DType(5, 4, "int", 'i', 1)
+   uint32: Final[DType] = DType(6, 4, "unsigned int", 'I', 1)
+   int64: Final[DType] = DType(7, 8, "long", 'l', 1)
+   uint64: Final[DType] = DType(8, 8, "unsigned long", 'L', 1)
+   float16: Final[DType] = DType(9, 2, "half", 'e', 1)
    # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
-   bfloat16: Final[DType] = DType(10, 2, "__bf16", None)
-   float32: Final[DType] = DType(11, 4, "float", np.float32)
-   float64: Final[DType] = DType(12, 8, "double", np.float64)
+   bfloat16: Final[DType] = DType(10, 2, "__bf16", None, 1)
+   float32: Final[DType] = DType(11, 4, "float", 'f', 1)
+   float64: Final[DType] = DType(12, 8, "double", 'd', 1)

    # dtype aliases
    half = float16; float = float32; double = float64 # noqa: E702
@@ -70,13 +76,17 @@ class dtypes:

    # NOTE: these are image dtypes
    @staticmethod
-   def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp, dtypes.float32)
+   def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, shape=shp, base=dtypes.float32)
    @staticmethod
-   def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp, dtypes.float32)
+   def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, shape=shp, base=dtypes.float32)

    default_float: ClassVar[DType] = float32
    default_int: ClassVar[DType] = int32

+ if (env_default_float := getenv("DEFAULT_FLOAT", "")):
+   dtypes.default_float = getattr(dtypes, env_default_float.lower())
+   assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
+
  # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
  # we don't support weak type and complex type
  promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
@@ -94,4 +104,10 @@ def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else

  # HACK: staticmethods are not callable in 3.8 so we have to compare the class
  DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default')) or v.__class__ is staticmethod)}
- INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}
+ INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
+
+ def sum_acc_dtype(dt:DType):
+   # default acc dtype for sum
+   if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
+   if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
+   return least_upper_dtype(dt, dtypes.float)
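A short sketch exercising the reworked DType above (dataclass fields fmt/count replace np/sz, plus the new as_const and sum_acc_dtype helpers); the assertions follow from the definitions in this diff:

from tinygrad.dtype import dtypes, sum_acc_dtype

assert dtypes.float32.fmt == 'f' and dtypes.float32.count == 1
vec4 = dtypes.float32.vec(4)                            # vectorized dtype: itemsize 16, scalar() is float32
assert vec4.itemsize == 16 and vec4.scalar() == dtypes.float32
assert dtypes.as_const(3.7, dtypes.int32) == 3          # constants fold to the dtype's Python type
assert sum_acc_dtype(dtypes.bool) == dtypes.int32       # bool/int sums accumulate in at least the default int
assert sum_acc_dtype(dtypes.float16) == dtypes.float32  # float sums accumulate in at least float32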