tinygrad 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +35 -37
  4. tinygrad/codegen/linearize.py +19 -10
  5. tinygrad/codegen/lowerer.py +31 -8
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +10 -0
  8. tinygrad/device.py +28 -11
  9. tinygrad/dtype.py +12 -3
  10. tinygrad/engine/jit.py +3 -2
  11. tinygrad/engine/multi.py +0 -1
  12. tinygrad/engine/realize.py +7 -4
  13. tinygrad/engine/schedule.py +227 -255
  14. tinygrad/engine/search.py +20 -27
  15. tinygrad/gradient.py +3 -0
  16. tinygrad/helpers.py +7 -4
  17. tinygrad/nn/state.py +2 -2
  18. tinygrad/ops.py +64 -329
  19. tinygrad/renderer/__init__.py +19 -3
  20. tinygrad/renderer/cstyle.py +39 -18
  21. tinygrad/renderer/llvmir.py +55 -18
  22. tinygrad/renderer/ptx.py +6 -2
  23. tinygrad/renderer/wgsl.py +20 -12
  24. tinygrad/runtime/autogen/libc.py +404 -71
  25. tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
  26. tinygrad/runtime/autogen/webgpu.py +6985 -0
  27. tinygrad/runtime/graph/metal.py +28 -29
  28. tinygrad/runtime/ops_amd.py +37 -34
  29. tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
  30. tinygrad/runtime/ops_disk.py +1 -1
  31. tinygrad/runtime/ops_dsp.py +59 -33
  32. tinygrad/runtime/ops_llvm.py +14 -12
  33. tinygrad/runtime/ops_metal.py +78 -62
  34. tinygrad/runtime/ops_nv.py +9 -6
  35. tinygrad/runtime/ops_python.py +5 -5
  36. tinygrad/runtime/ops_webgpu.py +200 -38
  37. tinygrad/runtime/support/am/amdev.py +23 -11
  38. tinygrad/runtime/support/am/ip.py +10 -10
  39. tinygrad/runtime/support/elf.py +2 -0
  40. tinygrad/runtime/support/hcq.py +7 -5
  41. tinygrad/runtime/support/llvm.py +8 -14
  42. tinygrad/shape/shapetracker.py +3 -2
  43. tinygrad/shape/view.py +2 -3
  44. tinygrad/spec.py +21 -20
  45. tinygrad/tensor.py +150 -90
  46. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  47. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  48. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  49. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  50. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  51. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  52. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  53. tinygrad/viz/index.html +544 -0
  54. tinygrad/viz/perfetto.html +178 -0
  55. tinygrad/viz/serve.py +205 -0
  56. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
  57. tinygrad-0.10.2.dist-info/RECORD +99 -0
  58. tinygrad/codegen/rewriter.py +0 -516
  59. tinygrad-0.10.1.dist-info/RECORD +0 -86
  60. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  61. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
  62. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/device.py CHANGED
@@ -10,7 +10,7 @@ from tinygrad.renderer import Renderer
 
 # **************** Device ****************
 
-ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM", "DSP", "WEBGPU"]
+ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CPU", "LLVM", "DSP", "WEBGPU"]
 class _Device:
   def __init__(self) -> None:
     self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
@@ -44,6 +44,7 @@ class _Device:
       return device
     except StopIteration as exc: raise RuntimeError("no usable devices") from exc
 Device = _Device()
+atexit.register(lambda: [Device[dn].finalize() for dn in Device._opened_devices])
 
 # **************** Profile ****************
 
@@ -207,7 +208,10 @@ class LRUAllocator(Allocator):
 
 class _MallocAllocator(LRUAllocator):
   def _alloc(self, size:int, options:BufferSpec):
-    return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, 16)
+    # must be aligned to 0x20 for 256-bit ymm registers
+    # TODO: investigate if this is the cause of nondeterminism in speed
+    alignment = 0x1000 if size >= 0x1000 else 0x20
+    return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, alignment)
   def _alloc_aligned(self, size:int, alignment:int):
     buffer = (ctypes.c_uint8 * (size + alignment))()
     offset = round_up(ctypes.addressof(buffer), alignment) - ctypes.addressof(buffer)
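Note on the `_alloc` change above: the allocator now asks for 0x20-byte alignment (0x1000 once the buffer reaches a page) and funnels everything through `_alloc_aligned`, which over-allocates and offsets to the next aligned address. A minimal standalone sketch of that over-allocate-and-round-up technique, with illustrative names rather than tinygrad's API:

import ctypes

def round_up(n: int, align: int) -> int:
  # smallest multiple of `align` that is >= n
  return (n + align - 1) // align * align

def alloc_aligned(size: int, alignment: int):
  # over-allocate by `alignment` bytes so an aligned offset always exists
  buf = (ctypes.c_uint8 * (size + alignment))()
  offset = round_up(ctypes.addressof(buf), alignment) - ctypes.addressof(buf)
  # return the backing buffer too so it isn't garbage collected while the view is alive
  return buf, (ctypes.c_uint8 * size).from_buffer(buf, offset)

backing, view = alloc_aligned(4096, 0x20)
assert ctypes.addressof(view) % 0x20 == 0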
@@ -224,31 +228,36 @@ MAP_JIT = 0x0800
 
 # CPUProgram is a jit/shellcode program that can be just mmapped and jumped to
 class CPUProgram:
-  helper_handle = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or sys.platform == "win32" else 'libgcc_s.so.1')
+  rt_lib = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or sys.platform == "win32" else 'libgcc_s.so.1')
+  atomic_lib = ctypes.CDLL(ctypes.util.find_library('atomic')) if sys.platform == "linux" else None
+
   def __init__(self, name:str, lib:bytes):
     if sys.platform == "win32":
       PAGE_EXECUTE_READWRITE = 0x40
       MEM_COMMIT = 0x1000
       MEM_RESERVE = 0x2000
-      ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_uint64
-      ptr = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_int(0), ctypes.c_int(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE)
-      ctypes.memmove(ptr, lib, len(lib))
-      self.fxn = ctypes.CFUNCTYPE(None)(ptr)
+      ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p
+      self.mem = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_void_p(0), ctypes.c_size_t(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE)
+      ctypes.memmove(self.mem, lib, len(lib))
+      ctypes.windll.kernel32.GetCurrentProcess.restype = ctypes.c_void_p
+      proc = ctypes.windll.kernel32.GetCurrentProcess()
+      ctypes.windll.kernel32.FlushInstructionCache(ctypes.c_void_p(proc), ctypes.c_void_p(self.mem), ctypes.c_size_t(len(lib)))
+      self.fxn = ctypes.CFUNCTYPE(None)(self.mem)
     else:
       from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
       # On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
       # MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np)
       self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC)
 
-      if OSX: CPUProgram.helper_handle.pthread_jit_write_protect_np(False)
+      if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(False)
       self.mem.write(lib)
-      if OSX: CPUProgram.helper_handle.pthread_jit_write_protect_np(True)
+      if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(True)
 
       # __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang.
       # libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately
       # it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux
       # Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5
-      CPUProgram.helper_handle["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib)))
+      CPUProgram.rt_lib["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib)))
 
       self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem))
 
@@ -262,6 +271,9 @@ class CPUProgram:
     if platform.machine() == "arm64" and OSX: args = args[:8] + [ctypes.c_int64(a) if isinstance(a, int) else a for a in args[8:]]
     return cpu_time_execution(lambda: self.fxn(*args), enable=wait)
 
+  def __del__(self):
+    if sys.platform == 'win32': ctypes.windll.kernel32.VirtualFree(ctypes.c_void_p(self.mem), ctypes.c_size_t(0), 0x8000) #0x8000 - MEM_RELEASE
+
 
 # **************** for Compiled Devices ****************
 class CompileError(Exception): pass
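For context on the CPUProgram changes above: on the non-Windows path the class mmaps anonymous memory, writes raw machine code into it, flushes the instruction cache, and calls it through a ctypes function pointer; the reworked Windows path does the same via VirtualAlloc/FlushInstructionCache and now releases the allocation in __del__. A toy sketch of the underlying idea, assuming an x86-64 Linux host where anonymous RWX mappings are allowed (the six-byte blob is `mov eax, 42; ret`, so this will not run as written on arm64 or macOS):

import ctypes, mmap

code = bytes([0xB8, 0x2A, 0x00, 0x00, 0x00, 0xC3])  # mov eax, 42 ; ret

mem = mmap.mmap(-1, len(code), mmap.MAP_ANON | mmap.MAP_PRIVATE,
                mmap.PROT_READ | mmap.PROT_WRITE | mmap.PROT_EXEC)
mem.write(code)

# build a callable from the raw address of the mapping and jump to it
addr = ctypes.addressof(ctypes.c_char.from_buffer(mem))
fxn = ctypes.CFUNCTYPE(ctypes.c_int)(addr)
print(fxn())  # expected: 42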
@@ -295,6 +307,11 @@ class Compiled:
     Called at the end of profiling to allow the device to finalize any profiling.
     """
     # override this in your device implementation
+  def finalize(self):
+    """
+    Called at the end of process lifetime to allow the device to finalize.
+    """
+    # override this in your device implementation
 
 # TODO: move this to each Device
 def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
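The new `Compiled.finalize` hook pairs with the `atexit.register(...)` line added near the top of this file: every device opened during the process gets a chance to tear down driver state at interpreter exit. A rough sketch of that pattern in isolation (the registry below is hypothetical, not tinygrad's `_Device`):

import atexit

class FakeDevice:
  def __init__(self, name: str): self.name = name
  def finalize(self):
    # a real backend would flush queues, unmap buffers, close fds, ... here
    print(f"finalizing {self.name}")

_opened: dict[str, FakeDevice] = {}
def open_device(name: str) -> FakeDevice: return _opened.setdefault(name, FakeDevice(name))

# run every finalizer once, at the end of process lifetime
atexit.register(lambda: [dev.finalize() for dev in _opened.values()])

open_device("CPU")
open_device("LLVM")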
@@ -303,7 +320,7 @@ def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
     # NOTE: this requires bf16 buffer support
     return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
   if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
-                                          dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32]
+                                          dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
   # for CI GPU and OSX, cl_khr_fp16 isn't supported
   # for CI LLVM, it segfaults because it can't link to the casting function
   # CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
tinygrad/dtype.py CHANGED
@@ -54,7 +54,7 @@ class PtrDType(DType):
   def vec(self, sz:int) -> DType:
     assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
     if sz == 1: return self # sz=1 is a scalar
-    return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz)
+    return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz, self.size)
   def ptr(self, size=-1, local=False): raise RuntimeError("can't make a pointer from a pointer")
   @property
   def vcount(self): return self.v
@@ -80,6 +80,8 @@ class dtypes:
   @functools.lru_cache(None)
   def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
   @staticmethod
+  def is_bool(x: DType) -> bool: return x.scalar() == dtypes.bool
+  @staticmethod
   def from_py(x) -> DType:
     if x.__class__ is float: return dtypes.default_float
     if x.__class__ is int: return dtypes.default_int
@@ -181,9 +183,16 @@ def truncate_fp16(x):
   try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
   except OverflowError: return math.copysign(math.inf, x)
 
+def truncate_bf16(x):
+  max_bf16 = struct.unpack('f', struct.pack('I', 0x7f7f0000))[0]
+  if x > max_bf16 or x < -max_bf16: return math.copysign(math.inf, x)
+  f32_int = struct.unpack('I', struct.pack('f', x))[0]
+  bf = struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]
+  return bf
+
 truncate: dict[DType, Callable] = {dtypes.bool: bool,
-  # TODO: bfloat16
-  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+  dtypes.float16: truncate_fp16, dtypes.bfloat16: truncate_bf16,
+  dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
   dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
   dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
   dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value,
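The new `truncate_bf16` works by clearing the low 16 bits of the float32 bit pattern, since a bfloat16 is exactly the top 16 bits of a float32; values beyond the largest finite bf16 are mapped to infinity first. A quick standalone illustration of what the masking does (same struct trick, overflow guard omitted):

import struct

def bf16_mask(x: float) -> float:
  # keep only the top 16 bits of the float32 representation (truncates the magnitude toward zero)
  f32_int = struct.unpack('I', struct.pack('f', x))[0]
  return struct.unpack('f', struct.pack('I', f32_int & 0xFFFF0000))[0]

print(bf16_mask(3.14159))  # 3.140625 -- only 7 explicit mantissa bits survive
print(bf16_mask(1.0))      # 1.0 exactly: powers of two are unaffected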
tinygrad/engine/jit.py CHANGED
@@ -24,7 +24,8 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
   def flush_batch():
     nonlocal current_batch, current_device, max_batch_size
     try:
-      if len(current_batch) <= 1 or current_device is None: raise GraphException("only one kernel doesn't graph")
+      if current_device is None: raise GraphException("no device for graph")
+      if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
       graph_runner = current_device.graph(current_batch, input_rawbuffers, var_vals)
       # clear jit inputs to allow their memory to be freed/reused
       for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
@@ -193,7 +194,7 @@ class CapturedJit(Generic[ReturnType]):
 def _prepare_jit_inputs(args, kwargs):
   input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
   names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
-  if tensors: Tensor.realize(*tensors)
+  if len(unrealized_tensors := [x for x in tensors if not x.lazydata.is_realized]): Tensor.realize(*unrealized_tensors)
   # TODO: should we be unpacking multi here?
   lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
   input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
tinygrad/engine/multi.py CHANGED
@@ -1,4 +1,3 @@
-from __future__ import annotations
 import functools, itertools, operator
 from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
 from tinygrad.ops import Ops, UOp, sint
tinygrad/engine/realize.py CHANGED
@@ -2,6 +2,7 @@ from typing import Optional, cast, Generator
 import time, pprint
 from dataclasses import dataclass, replace
 from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
+from tinygrad.helpers import DEVECTORIZE, time_to_str
 from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
@@ -99,11 +100,13 @@ class BufferXfer(BufferCopy):
 
 # **************** method cache ****************
 
-method_cache: dict[tuple[str, bytes, int, int, bool], CompiledRunner] = {}
+method_cache: dict[tuple[str, bytes, tuple[int, ...], bool], CompiledRunner] = {}
 def get_runner(device:str, ast:UOp) -> CompiledRunner:
-  ckey = (device, ast.key, BEAM.value, NOOPT.value, False)
+  # TODO: this should be all context relevant to rendering
+  context = (BEAM.value, NOOPT.value, DEVECTORIZE.value)
+  ckey = (device, ast.key, context, False)
   if cret:=method_cache.get(ckey): return cret
-  bkey = (device.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
+  bkey = (device.split(":")[0], ast.key, context, True)
   if bret:=method_cache.get(bkey):
     method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
   else:
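The method cache rework above collapses the render-relevant flags (BEAM, NOOPT, and now DEVECTORIZE) into a single `context` tuple and keeps the two-level lookup: an exact per-device key and a device-class key (e.g. "CUDA" for "CUDA:1") whose compiled library can be rebound to the concrete device. A stripped-down sketch of that two-level memoization, with illustrative names and a string standing in for the compiled runner:

from typing import Callable

_cache: dict[tuple[str, str, tuple[int, ...], bool], str] = {}

def get_compiled(device: str, kernel_key: str, context: tuple[int, ...], compile_fn: Callable[[], str]) -> str:
  ckey = (device, kernel_key, context, False)               # exact device, e.g. "CUDA:1"
  if (cret := _cache.get(ckey)) is not None: return cret
  bkey = (device.split(":")[0], kernel_key, context, True)  # device class, e.g. "CUDA"
  if (bret := _cache.get(bkey)) is not None:
    ret = f"{bret}@{device}"                                # reuse the compiled lib on another device instance
  else:
    _cache[bkey] = ret = compile_fn()                       # compile once per device class and context
  _cache[ckey] = ret
  return ret

print(get_compiled("CUDA:1", "k0", (0, 0, 1), lambda: "lib<k0>"))
print(get_compiled("CUDA:0", "k0", (0, 0, 1), lambda: "lib<k0>"))  # hits the device-class entry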
@@ -130,7 +133,7 @@ class ExecItem:
     if DEBUG >= 2:
       lds_est = sym_infer(self.prg.estimates.lds, var_vals)
       mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
-      ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
+      ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
       print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
             (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501
             f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))