tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/device.py CHANGED
@@ -1,41 +1,44 @@
 from __future__ import annotations
-import multiprocessing
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from collections import defaultdict
-from typing import List, Optional, Dict, Tuple, Any, cast
-import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib
-from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
-from tinygrad.dtype import DType, ImageDType
+from typing import Optional, Dict, Tuple, Any, Iterator
+import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys
+from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
+from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
 from tinygrad.renderer import Renderer
 
 # **************** Device ****************
 
 class _Device:
-  def __init__(self) -> None: self._devices: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] # noqa: E501
+  def __init__(self) -> None:
+    self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
   @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
-  def _canonicalize(self, device:str) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # noqa: E501
+  def _canonicalize(self, device:str) -> str: return ((d:=device.split(":", 1)[0].upper()) + device[len(d):]).replace(":0", "")
   # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
   def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
   def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
   @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
   def __get_canonicalized_item(self, ix:str) -> Compiled:
-    assert ((cpn:=multiprocessing.current_process().name) == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], \
-           f"can only open device {ix} from parent, not {cpn}"
+    cpn = multiprocessing.current_process().name
+    assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent, not {cpn}"
     x = ix.split(":")[0].upper()
-    ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
+    ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \
+           if (cname.lower() == x.lower() + "device")][0](ix)
     if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
     return ret
+  @property
+  def default(self) -> Compiled: return self[self.DEFAULT]
+  def get_available_devices(self) -> Iterator[str]:
+    for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]:
+      with contextlib.suppress(Exception): yield self[device].dname
   @functools.cached_property
   def DEFAULT(self) -> str:
-    device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore
-    if device_from_env: return device_from_env
-    for device in ["METAL", "AMD", "CUDA", "GPU", "CLANG", "LLVM"]:
-      try:
-        if self[device]:
-          os.environ[device] = "1" # we set this in environment for spawned children
-          return device
-      except Exception: pass
-    raise RuntimeError("no usable devices")
+    if (from_env:=next((d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1), None)): return from_env
+    try:
+      device = next(self.get_available_devices())
+      os.environ[device] = "1" # we set this in environment for spawned children
+      return device
+    except StopIteration as exc: raise RuntimeError("no usable devices") from exc
 Device = _Device()
 
 # **************** Buffer + Allocators ****************
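For context on the new selection flow: DEFAULT now probes the backend list above in order and returns the first one that opens, and _canonicalize simply uppercases the backend prefix and strips a trailing ":0"; a minimal usage sketch (the printed default depends on your hardware and environment):

    from tinygrad.device import Device

    assert Device.canonicalize("metal:0") == "METAL"  # ":0" is dropped
    assert Device.canonicalize("cuda:1") == "CUDA:1"  # other indices are kept
    print(Device.DEFAULT)  # first of METAL, AMD, NV, CUDA, QCOM, GPU, CLANG, LLVM that opens,
                           # unless e.g. CLANG=1 is already set in the environment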
@@ -47,12 +50,13 @@ class BufferOptions:
   cpu_access: bool = False
   host: bool = False
   nolru: bool = False
+  external_ptr: Optional[int] = None
 
 class Buffer:
   def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None,
                initial_value:Optional[bytes]=None, lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
-    assert isinstance(dtype, DType)
     if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
+    else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
     self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
     if base is None:
       assert offset == 0, "base buffers can't have offset"
@@ -73,10 +77,12 @@ class Buffer:
   def lb_refcount(self): return self.base._lb_refcount
   def ref(self, cnt): self.base._lb_refcount += cnt
   def is_allocated(self) -> bool: return hasattr(self, '_buf')
-  def ensure_allocated(self) -> Buffer: return self.allocate() if not hasattr(self, '_buf') else self
-  def allocate(self, opaque=None) -> Buffer:
-    assert not hasattr(self, '_buf'), "can't allocate already allocated buffer"
+  def ensure_allocated(self) -> Buffer: return self.allocate() if not self.is_allocated() else self
+  def allocate(self, opaque=None, external_ptr=None) -> Buffer:
+    assert not self.is_allocated(), "can't allocate already allocated buffer"
     self.allocator = Device[self.device].allocator
+    if external_ptr is not None:
+      self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferOptions(external_ptr=external_ptr)
     if self._base is not None:
       self._base.ensure_allocated()
       assert hasattr(self.allocator, "offset"), "offset function required for view"
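For context on external_ptr: allocate() can now wrap memory the caller already owns, and __del__ skips the free when options.external_ptr is set; a hedged sketch assuming a CLANG device backed by the _MallocAllocator shown further down:

    import ctypes
    from tinygrad.device import Buffer
    from tinygrad.dtype import dtypes

    backing = (ctypes.c_uint8 * 16)()  # memory owned by the caller, not the allocator
    buf = Buffer("CLANG", 4, dtypes.float32).allocate(external_ptr=ctypes.addressof(backing))
    buf.copyin(memoryview(bytearray(16)))  # writes land in `backing`
    # options.external_ptr is set, so __del__ will not free the caller's memory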
@@ -88,7 +94,7 @@ class Buffer:
   def __reduce__(self):
     buf = None
     if self._base is not None:
-      return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, hasattr(self, '_buf'))
+      return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, self.is_allocated())
     if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
     if self.is_allocated():
       buf = bytearray(self.nbytes)
@@ -97,17 +103,17 @@ class Buffer:
   @property
   def nbytes(self): return self.size*self.dtype.itemsize
   def __del__(self):
-    if not hasattr(self, '_buf'): return
-    if self._base is None:
+    if not self.is_allocated(): return
+    if self._base is None and (self.options is None or self.options.external_ptr is None):
       if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
       self.allocator.free(self._buf, self.nbytes, self.options)
   def __repr__(self):
-    return f"<buf real:{hasattr(self, '_buf')} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
-           (f" offset:{self.offset}" if hasattr(self, "base") else "") + \
-           (">" if self.options is None else f" {self.options=}>")
+    return f"<buf real:{self.is_allocated()} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
+           (f" offset:{self.offset}" if hasattr(self, "base") else "") + (f" {self.options=}" if self.options is not None else "") + ">"
   def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
     # zero copy with as_buffer (disabled by default due to use after free)
-    if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
+    if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer') and (self.options is None or self.options.image is None):
+      return self.allocator.as_buffer(self._buf)
     assert not force_zero_copy, "force zero copy was passed, but copy is required"
     return self.copyout(memoryview(bytearray(self.nbytes)))
   def copyin(self, mv:memoryview):
@@ -133,13 +139,16 @@ class Allocator:
     assert not isinstance(size, int) or size > 0, f"alloc size must be positve, getting {size}"
     return self._alloc(size, options if options is not None else BufferOptions())
   def _alloc(self, size:int, options:BufferOptions): raise NotImplementedError("need alloc")
-  def free(self, opaque, size:int, options:Optional[BufferOptions]=None):
-    self._free(opaque, options if options is not None else BufferOptions())
+  def free(self, opaque, size:int, options:Optional[BufferOptions]=None): self._free(opaque, options if options is not None else BufferOptions())
   def _free(self, opaque, options:BufferOptions): pass # if opaque is a Python object, you don't need a free
   def copyin(self, dest, src:memoryview): raise NotImplementedError("need copyin")
   def copyout(self, dest:memoryview, src): raise NotImplementedError("need copyout")
 
 class LRUAllocator(Allocator): # pylint: disable=abstract-method
+  """
+  The LRU Allocator is responsible for caching buffers.
+  It ensures that buffers are not freed until it is absolutely necessary, optimizing performance.
+  """
   def __init__(self): self.cache: Dict[Tuple[int, Optional[BufferOptions]], Any] = defaultdict(list)
   def alloc(self, size:int, options:Optional[BufferOptions]=None):
     if len(c := self.cache[(size, options)]): return c.pop()
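The new docstring summarizes existing behavior: freed opaques are parked in a dict keyed by (size, options) and reused on the next matching alloc; a rough standalone sketch of that policy (illustrative class, not tinygrad code):

    from collections import defaultdict

    class FreeListCache:
      def __init__(self): self.cache = defaultdict(list)
      def alloc(self, size, options=None):
        bucket = self.cache[(size, options)]
        return bucket.pop() if bucket else bytearray(size)  # reuse if possible, else really allocate
      def free(self, opaque, size, options=None):
        self.cache[(size, options)].append(opaque)  # keep it around instead of releasing it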
@@ -156,7 +165,8 @@ class LRUAllocator(Allocator): # pylint: disable=abstract-method
     else: super().free(opaque, size, options)
 
 class _MallocAllocator(LRUAllocator):
-  def _alloc(self, size:int, options:BufferOptions): return (ctypes.c_uint8 * size)()
+  def _alloc(self, size:int, options:BufferOptions):
+    return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else (ctypes.c_uint8 * size)()
   def as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
   def copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
   def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
@@ -170,151 +180,42 @@ class CompileError(Exception): pass
 
 class Compiler:
   def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
-  def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
+  def compile(self, src:str) -> bytes: return src.encode() # NOTE: empty compiler is the default
   def compile_cached(self, src:str) -> bytes:
     if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
       assert not getenv("ASSERT_COMPILE"), f"tried to compile with ASSERT_COMPILE set\n{src}"
       lib = self.compile(src)
       if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
     return lib
+  def disassemble(self, lib:bytes): pass
 
 class Compiled:
   def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
     self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
     self.renderer = renderer or Renderer()
-  def synchronize(self): pass # override this in your device
-
-# **************** for HCQ Compatible Devices ****************
-
-@contextlib.contextmanager
-def hcq_profile(dev, queue_type, enabled, desc):
-  st, en = (dev._get_signal(), dev._get_signal()) if enabled else (None, None)
-  if enabled: queue_type().timestamp(st).submit(dev)
-  try: yield (st, en)
-  finally:
-    if enabled: queue_type().timestamp(en).submit(dev)
-    if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
-
-class HCQCompatCompiled(Compiled):
-  def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, comp_queue_t, copy_queue_t, timeline_signals):
-    self.hw_compute_queue_t, self.hw_copy_queue_t = comp_queue_t, copy_queue_t
-    self.timeline_value: int = 1
-    self.timeline_signal, self._shadow_timeline_signal = timeline_signals
-    self.sig_prof_records: List[Tuple[Any, Any, str, bool]] = []
-    self.raw_prof_records: List[Tuple[int, int, str, bool]] = []
-    if PROFILE: self._prof_setup()
-
-    from tinygrad.runtime.graph.hcq import HCQGraph
-    super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
-
-  @classmethod
-  def _read_signal(self, sig): raise NotImplementedError("need _read_signal") # reads a value for a signal
-
-  @classmethod
-  def _read_timestamp(self, sig): raise NotImplementedError("need _read_timestamp") # reads a timestamp for a signal
-
-  @classmethod
-  def _set_signal(self, sig, value): raise NotImplementedError("need _set_signal") # sets a value for a signal
-
-  @classmethod
-  def _get_signal(self, value=0, **kwargs): raise NotImplementedError("need _get_signal") # allocates a new signal
-
-  @classmethod
-  def _wait_signal(self, signal, value=0, timeout=10000): raise NotImplementedError("need _wait_signal") # waits for a signal value
-
-  def _gpu2cpu_time(self, gpu_time, is_copy): raise NotImplementedError("need _gpu2cpu_time")
-
-  def _prof_setup(self):
-    self.profile_logger = ProfileLogger()
-
-    def _sync_queue(q_t):
-      q_t().timestamp(self.timeline_signal).signal(self.timeline_signal, self.timeline_value).submit(self)
-      self.timeline_value += 1
-      cpu_start_time = time.perf_counter_ns() / 1e3
-      self._wait_signal(self.timeline_signal, self.timeline_value - 1)
-      return cpu_start_time, self._read_timestamp(self.timeline_signal)
-    self.cpu_start_time, self.gpu_start_time = _sync_queue(self.hw_compute_queue_t)
-    self.copy_cpu_start_time, self.copy_gpu_start_time = _sync_queue(self.hw_copy_queue_t)
-
-    atexit.register(self._prof_finalize)
-
-  def _prof_process_events(self):
-    self.raw_prof_records += [(self._read_timestamp(st), self._read_timestamp(en), name, is_cp) for st, en, name, is_cp in self.sig_prof_records]
-    for st, en, _, _ in self.sig_prof_records: self.signals_pool += [st, en] # type: ignore
-    self.sig_prof_records = []
-
-  def _prof_finalize(self):
-    for st, en, name, is_cp in self.raw_prof_records:
-      self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, ["COMPUTE", "DMA"][is_cp])]
-    del self.profile_logger
-
-  def _wrap_timeline_signal(self):
-    self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
-    self._set_signal(self.timeline_signal, 0)
-    cast(HCQCompatAllocator, self.allocator).b_timeline = [0] * len(cast(HCQCompatAllocator, self.allocator).b)
-
-class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
-  def __init__(self, device, batch_size=(2 << 20), batch_cnt=32):
-    self.device = device
-    self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
-    self.b_timeline, self.b_next = [0] * len(self.b), 0
-    super().__init__()
-
-  def copyin(self, dest, src: memoryview):
-    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
-      for i in range(0, src.nbytes, self.b[0].size):
-        self.b_next = (self.b_next + 1) % len(self.b)
-        self.device._wait_signal(self.device.timeline_signal, self.b_timeline[self.b_next])
-        ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
-        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
-                                     .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
-                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
-        self.b_timeline[self.b_next] = self.device.timeline_value
-        self.device.timeline_value += 1
-
-  def copy_from_disk(self, dest, src, size):
-    def _get_temp_buf():
-      # Check if the next buffer is safe to be used (its signal has passed) and reserve it.
-      if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device._read_signal(self.device.timeline_signal):
-        self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
-        return (self.b[self.b_next].va_addr, self.b_next)
-      return None
-
-    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
-      for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
-        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
-                                     .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
-                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
-        self.b_timeline[batch_info[1]] = self.device.timeline_value
-        self.device.timeline_value += 1
-
-  def copyout(self, dest:memoryview, src):
-    self.device.synchronize()
-
-    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
-      for i in range(0, dest.nbytes, self.b[0].size):
-        self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
-                                     .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
-                                     .signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
-        self.device._wait_signal(self.device.timeline_signal, self.device.timeline_value)
-        self.device.timeline_value += 1
-
-        ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
-
-  def transfer(self, dest, src, sz: int, src_dev, dest_dev):
-    src_dev._gpu_map(dest)
-
-    with hcq_profile(self.device, self.device.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
-      src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
-                               .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
-                               .copy(dest.va_addr, src.va_addr, sz) \
-                               .signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
-      src_dev.timeline_value += 1
-
-      if src_dev != dest_dev:
-        dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
-                                     .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
-                                     .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
-        dest_dev.timeline_value += 1
-
-  def offset(self, buf, size:int, offset:int): return type(buf)(base=buf.base + offset, va_addr=buf.va_addr + offset, length=size, size=size)
+  def synchronize(self):
+    """
+    Synchronize all pending operations on the device.
+
+    This method ensures that all previously queued operations on the device have been completed before proceeding.
+    """
+    # override this in your device implementation
+
+# TODO: move this to each Device
+def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
+  if device is None: device = Device.DEFAULT
+  if dtype == dtypes.bfloat16:
+    # NOTE: this requires bf16 buffer support
+    return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
+  if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
+  # for CI GPU and OSX, cl_khr_fp16 isn't supported
+  # for CI LLVM, it segfaults because it can't link to the casting function
+  # CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
+  # PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
+  if dtype == dtypes.half:
+    if device == "GPU": return not CI and not OSX
+    if device in ["CUDA", "NV"]: return not CI
+    if device == "LLVM": return OSX
+    if device == "PYTHON": return sys.version_info >= (3, 12)
+  if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
+  return True
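is_dtype_supported now centralizes the "can this backend hold this dtype" question; the examples below follow directly from the branches above (answers for other device/dtype pairs also depend on CI, OSX and PTX state):

    from tinygrad.device import is_dtype_supported
    from tinygrad.dtype import dtypes

    is_dtype_supported(dtypes.bfloat16, "AMD")   # True: AMD has bf16 buffer support
    is_dtype_supported(dtypes.float64, "METAL")  # False: Metal has no fp64
    is_dtype_supported(dtypes.half, "PYTHON")    # True only on Python 3.12+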
tinygrad/dtype.py CHANGED
@@ -1,47 +1,79 @@
-from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
-from dataclasses import dataclass
-import functools
+from __future__ import annotations
+from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
+import math, struct, ctypes, functools
+from dataclasses import dataclass, fields
 from tinygrad.helpers import getenv
 
 ConstType = Union[float, int, bool]
 
-@dataclass(frozen=True, order=True)
-class DType:
+# all DTypes should only be created once
+class DTypeMetaClass(type):
+  dcache: Dict[Tuple, DType] = {}
+  def __call__(cls, *args, **kwargs):
+    if (ret:=DTypeMetaClass.dcache.get(args, None)) is not None: return ret
+    DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
+    return ret
+
+@dataclass(frozen=True, eq=False)
+class DType(metaclass=DTypeMetaClass):
   priority: int # this determines when things get upcasted
   itemsize: int
   name: str
   fmt: Optional[str]
   count: int
-  def __repr__(self): return f"dtypes.{'_'*(c:=self.count!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.count)*c}"
-  def vec(self, sz:int):
-    assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
-    return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
-  def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
-
-# dependent typing?
-@dataclass(frozen=True, repr=False)
-class ImageDType(DType):
-  shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
-  base: DType
-  def scalar(self): return self.base
-  def vec(self, sz:int): return self.base.vec(sz)
-  def __repr__(self): return f"dtypes.{self.name}({self.shape})"
-
-# @dataclass(frozen=True, init=False, repr=False, eq=False)
+  _scalar: Optional[DType]
+  @staticmethod
+  def new(priority:int, itemsize:int, name:str, fmt:Optional[str]): return DType(priority, itemsize, name, fmt, 1, None)
+  def __reduce__(self): return type(self), tuple(getattr(self, f.name) for f in fields(self))
+  def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"+(f".vec({self.count})" if self.count > 1 else "")
+  def __lt__(self, o:DType): return (self.priority, self.itemsize, self.name, self.fmt, self.count) < (o.priority, o.itemsize, o.name, o.fmt, o.count)
+  @property
+  def base(self): return self
+  @property
+  def vcount(self): return self.count
+  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  def vec(self, sz:int) -> DType:
+    assert self.count == 1, f"can't vectorize {self} with size {sz}"
+    if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
+    return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
+  def ptr(self, local=False) -> PtrDType: return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1)
+  def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
+
+@dataclass(frozen=True, eq=False)
 class PtrDType(DType):
-  def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
-  def __repr__(self): return f"ptr.{super().__repr__()}"
-  def __hash__(self): return super().__hash__()
-  def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
-  def __ne__(self, dt): return not (self == dt)
+  _base: DType
+  local: bool
+  v: int
+  @property
+  def base(self): return self._base
+  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  def vec(self, sz:int) -> DType:
+    assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
+    if sz == 1: return self # sz=1 is a scalar
+    return type(self)(*tuple(sz if f.name == 'v' else (self if f.name == '_scalar' else getattr(self, f.name)) for f in fields(self)))
+  def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
+  @property
+  def vcount(self): return self.v
+  def __repr__(self): return f"{self.base.__repr__()}.ptr({'local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
+
+@dataclass(frozen=True, eq=False)
+class ImageDType(PtrDType):
+  shape: Tuple[int, ...] = () # shape of the Image
+  def ptr(self, local=False) -> PtrDType:
+    assert not local, "images can't be local"
+    return self
+  def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
 
 class dtypes:
   @staticmethod
-  def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64)
+  @functools.lru_cache(None)
+  def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
   @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
-  def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x)
+  @functools.lru_cache(None)
+  def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
   @staticmethod
-  def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
+  @functools.lru_cache(None)
+  def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
   @staticmethod
   def from_py(x) -> DType:
     if x.__class__ is float: return dtypes.default_float
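DTypeMetaClass interns every DType by its constructor arguments, so dtypes compare by identity, and vec()/ptr() return cached derived types that remember their scalar base; a small sketch of the resulting behavior:

    from tinygrad.dtype import DType, dtypes

    assert DType(11, 4, "float", 'f', 1, None) is dtypes.float32  # same args -> same interned object
    v4 = dtypes.float32.vec(4)
    assert v4 is dtypes.float32.vec(4) and v4.scalar() is dtypes.float32
    p = dtypes.int32.ptr(local=True)                              # PtrDType carrying its base dtype
    assert p.base is dtypes.int32 and p.vcount == 1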
@@ -51,23 +83,44 @@ class dtypes:
     if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
     raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
   @staticmethod
-  def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
+  def as_const(val: Tuple[ConstType, ...]|ConstType, dtype:DType):
+    if isinstance(val, tuple):
+      assert len(val) == dtype.count, f"mismatch {val} {dtype}"
+      return tuple(dtypes.as_const(x, dtype) for x in val)
+    # TODO: should truncate here
+    return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
+  @staticmethod
+  @functools.lru_cache(None)
+  def min(dtype:DType):
+    if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
+    return -float("inf") if dtypes.is_float(dtype) else False
+  @staticmethod
+  @functools.lru_cache(None)
+  def max(dtype:DType):
+    if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
+    return float("inf") if dtypes.is_float(dtype) else True
+  @staticmethod
+  def finfo(dtype:DType) -> Tuple[int, int]:
+    """(exponent, mantissa)"""
+    if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
+    return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype]
   @staticmethod
   def fields() -> Dict[str, DType]: return DTYPES_DICT
-  bool: Final[DType] = DType(0, 1, "bool", '?', 1)
-  int8: Final[DType] = DType(1, 1, "char", 'b', 1)
-  uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
-  int16: Final[DType] = DType(3, 2, "short", 'h', 1)
-  uint16: Final[DType] = DType(4, 2, "unsigned short", 'H', 1)
-  int32: Final[DType] = DType(5, 4, "int", 'i', 1)
-  uint32: Final[DType] = DType(6, 4, "unsigned int", 'I', 1)
-  int64: Final[DType] = DType(7, 8, "long", 'l', 1)
-  uint64: Final[DType] = DType(8, 8, "unsigned long", 'L', 1)
-  float16: Final[DType] = DType(9, 2, "half", 'e', 1)
+  void: Final[DType] = DType.new(-1, 0, "void", None)
+  bool: Final[DType] = DType.new(0, 1, "bool", '?')
+  int8: Final[DType] = DType.new(1, 1, "char", 'b')
+  uint8: Final[DType] = DType.new(2, 1, "unsigned char", 'B')
+  int16: Final[DType] = DType.new(3, 2, "short", 'h')
+  uint16: Final[DType] = DType.new(4, 2, "unsigned short", 'H')
+  int32: Final[DType] = DType.new(5, 4, "int", 'i')
+  uint32: Final[DType] = DType.new(6, 4, "unsigned int", 'I')
+  int64: Final[DType] = DType.new(7, 8, "long", 'q')
+  uint64: Final[DType] = DType.new(8, 8, "unsigned long", 'Q')
+  float16: Final[DType] = DType.new(9, 2, "half", 'e')
   # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
-  bfloat16: Final[DType] = DType(10, 2, "__bf16", None, 1)
-  float32: Final[DType] = DType(11, 4, "float", 'f', 1)
-  float64: Final[DType] = DType(12, 8, "double", 'd', 1)
+  bfloat16: Final[DType] = DType.new(10, 2, "__bf16", None)
+  float32: Final[DType] = DType.new(11, 4, "float", 'f')
+  float64: Final[DType] = DType.new(12, 8, "double", 'd')
 
   # dtype aliases
   half = float16; float = float32; double = float64 # noqa: E702
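The new min/max/finfo helpers are pure functions of the dtype's width and signedness, and as_const now maps element-wise over tuples for vector dtypes; expected values from the definitions above:

    from tinygrad.dtype import dtypes

    assert (dtypes.min(dtypes.int8), dtypes.max(dtypes.int8)) == (-128, 127)
    assert (dtypes.min(dtypes.uint8), dtypes.max(dtypes.uint8)) == (0, 255)
    assert dtypes.finfo(dtypes.bfloat16) == (8, 7)                       # (exponent bits, mantissa bits)
    assert dtypes.as_const((1, 0), dtypes.float32.vec(2)) == (1.0, 0.0)  # element-wise for vector dtypes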
@@ -76,17 +129,25 @@ class dtypes:
 
   # NOTE: these are image dtypes
   @staticmethod
-  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, shape=shp, base=dtypes.float32)
+  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, shp)
   @staticmethod
-  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, shape=shp, base=dtypes.float32)
+  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, shp)
 
   default_float: ClassVar[DType] = float32
   default_int: ClassVar[DType] = int32
 
+  floats = (float16, bfloat16, float32, float64)
+  uints = (uint8, uint16, uint32, uint64)
+  sints = (int8, int16, int32, int64)
+  ints = uints + sints
+
 if (env_default_float := getenv("DEFAULT_FLOAT", "")):
   dtypes.default_float = getattr(dtypes, env_default_float.lower())
   assert dtypes.is_float(dtypes.default_float), f"{env_default_float} is not a float dtype"
 
+DTypeLike = Union[str, DType]
+def to_dtype(dtype:DTypeLike) -> DType: return dtype if isinstance(dtype, DType) else getattr(dtypes, dtype)
+
 # https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
 # we don't support weak type and complex type
 promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
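DTypeLike and to_dtype let APIs accept either a DType or its attribute name on dtypes:

    from tinygrad.dtype import dtypes, to_dtype

    assert to_dtype("float32") is dtypes.float32   # string names are looked up on dtypes
    assert to_dtype(dtypes.int64) is dtypes.int64  # DType instances pass through unchanged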
@@ -103,11 +164,25 @@ def least_upper_dtype(*ds:DType) -> DType:
 def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
 
 # HACK: staticmethods are not callable in 3.8 so we have to compare the class
-DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default')) or v.__class__ is staticmethod)}
+DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default', 'void'))
+                                                                or v.__class__ is staticmethod or isinstance(v, tuple))}
 INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
+INVERSE_DTYPES_DICT['void'] = 'void'
 
 def sum_acc_dtype(dt:DType):
   # default acc dtype for sum
   if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
   if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
-  return least_upper_dtype(dt, dtypes.float)
+  return least_upper_dtype(dt, dtypes.float)
+
+def truncate_fp16(x):
+  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
+  except OverflowError: return math.copysign(math.inf, x)
+
+truncate: Dict[DType, Callable] = {dtypes.bool: bool,
+  # TODO: bfloat16
+  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
+  dtypes.int64: lambda x: ctypes.c_int64(x).value}
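The truncate table maps each dtype to a Python-level callable that reproduces that dtype's wraparound or rounding; expected values, derived from ctypes/struct semantics (the fp16 cases assume standard IEEE half rounding):

    from tinygrad.dtype import dtypes, truncate, truncate_fp16

    assert truncate[dtypes.uint8](300) == 44       # 300 mod 256
    assert truncate[dtypes.int8](200) == -56       # two's-complement wraparound
    assert truncate_fp16(65519) == 65504.0         # rounds to the largest finite fp16
    assert truncate_fp16(70000) == float("inf")    # overflow becomes signed infinity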