tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +141 -201
  4. tinygrad/codegen/linearize.py +223 -84
  5. tinygrad/codegen/lowerer.py +60 -42
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +22 -13
  8. tinygrad/device.py +187 -47
  9. tinygrad/dtype.py +39 -28
  10. tinygrad/engine/jit.py +83 -65
  11. tinygrad/engine/memory.py +4 -5
  12. tinygrad/engine/multi.py +161 -0
  13. tinygrad/engine/realize.py +62 -108
  14. tinygrad/engine/schedule.py +396 -357
  15. tinygrad/engine/search.py +55 -66
  16. tinygrad/gradient.py +73 -0
  17. tinygrad/helpers.py +81 -59
  18. tinygrad/nn/__init__.py +30 -32
  19. tinygrad/nn/datasets.py +1 -2
  20. tinygrad/nn/optim.py +22 -26
  21. tinygrad/nn/state.py +91 -66
  22. tinygrad/ops.py +492 -641
  23. tinygrad/renderer/__init__.py +95 -36
  24. tinygrad/renderer/cstyle.py +99 -92
  25. tinygrad/renderer/llvmir.py +83 -34
  26. tinygrad/renderer/ptx.py +83 -99
  27. tinygrad/renderer/wgsl.py +95 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  29. tinygrad/runtime/autogen/comgr.py +2 -0
  30. tinygrad/runtime/autogen/kfd.py +4 -3
  31. tinygrad/runtime/autogen/kgsl.py +1 -1
  32. tinygrad/runtime/autogen/libc.py +404 -71
  33. tinygrad/runtime/autogen/llvm.py +11379 -0
  34. tinygrad/runtime/autogen/pci.py +1333 -0
  35. tinygrad/runtime/autogen/vfio.py +891 -0
  36. tinygrad/runtime/autogen/webgpu.py +6985 -0
  37. tinygrad/runtime/graph/cuda.py +8 -9
  38. tinygrad/runtime/graph/hcq.py +84 -79
  39. tinygrad/runtime/graph/metal.py +40 -43
  40. tinygrad/runtime/ops_amd.py +498 -334
  41. tinygrad/runtime/ops_cloud.py +34 -34
  42. tinygrad/runtime/ops_cpu.py +24 -0
  43. tinygrad/runtime/ops_cuda.py +30 -27
  44. tinygrad/runtime/ops_disk.py +62 -63
  45. tinygrad/runtime/ops_dsp.py +159 -42
  46. tinygrad/runtime/ops_gpu.py +30 -30
  47. tinygrad/runtime/ops_hip.py +29 -31
  48. tinygrad/runtime/ops_llvm.py +48 -41
  49. tinygrad/runtime/ops_metal.py +149 -113
  50. tinygrad/runtime/ops_npy.py +2 -2
  51. tinygrad/runtime/ops_nv.py +238 -273
  52. tinygrad/runtime/ops_python.py +55 -50
  53. tinygrad/runtime/ops_qcom.py +129 -157
  54. tinygrad/runtime/ops_webgpu.py +225 -0
  55. tinygrad/runtime/support/allocator.py +94 -0
  56. tinygrad/runtime/support/am/__init__.py +0 -0
  57. tinygrad/runtime/support/am/amdev.py +396 -0
  58. tinygrad/runtime/support/am/ip.py +463 -0
  59. tinygrad/runtime/support/compiler_cuda.py +4 -2
  60. tinygrad/runtime/support/elf.py +28 -4
  61. tinygrad/runtime/support/hcq.py +256 -324
  62. tinygrad/runtime/support/llvm.py +26 -0
  63. tinygrad/shape/shapetracker.py +85 -53
  64. tinygrad/shape/view.py +104 -140
  65. tinygrad/spec.py +155 -0
  66. tinygrad/tensor.py +835 -527
  67. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  68. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  69. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  70. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  71. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  72. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  73. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  74. tinygrad/viz/index.html +544 -0
  75. tinygrad/viz/perfetto.html +178 -0
  76. tinygrad/viz/serve.py +205 -0
  77. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
  78. tinygrad-0.10.2.dist-info/RECORD +99 -0
  79. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
  80. tinygrad/codegen/uopgraph.py +0 -506
  81. tinygrad/engine/lazy.py +0 -228
  82. tinygrad/function.py +0 -212
  83. tinygrad/multi.py +0 -177
  84. tinygrad/runtime/graph/clang.py +0 -39
  85. tinygrad/runtime/ops_clang.py +0 -35
  86. tinygrad-0.10.0.dist-info/RECORD +0 -77
  87. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  88. {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_metal.py

@@ -1,8 +1,7 @@
-from __future__ import annotations
-import os, subprocess, pathlib, ctypes, tempfile, functools
-from typing import List, Any, Tuple, Optional, cast
-from tinygrad.helpers import prod, getenv, T
-from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
+import os, pathlib, struct, ctypes, tempfile, functools, contextlib, decimal, platform
+from typing import Any, Union, cast
+from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, T, init_c_struct_t, PROFILE
+from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, cpu_profile, ProfileDeviceEvent, ProfileRangeEvent
 from tinygrad.renderer.cstyle import MetalRenderer
 
 class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response to plain int, and dict.fromkeys() can use it to dedup
@@ -10,10 +9,11 @@ class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response
   def __eq__(self, other): return self.value == other.value
 
 class objc_instance(objc_id): # method with name "new", "alloc" should be freed after use
-  def __del__(self): msg(self, "release")
-
-@functools.lru_cache(None)
-def sel(name: str): return libobjc.sel_registerName(name.encode())
+  def __del__(self):
+    # CPython doesn't make any guarantees about order in which globals (like `msg` or `libobjc`) are destroyed when the interpreter shuts down
+    # https://github.com/tinygrad/tinygrad/pull/8949 triggered the unlucky ordering which lead to a bunch of errors at exit
+    # TODO: Why isn't `sys.is_finalizing` working?
+    if msg is not None and libobjc is not None: msg("release")(self)
 
 class MTLResourceOptions:
   MTLResourceCPUCacheModeDefaultCache = 0
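Note: the call style visible in `__del__` above (`msg("release")(self)` rather than `msg(self, "release")`) comes from the reworked `msg` helper defined later in this diff: the selector and return type are bound once (and lru_cache'd), and the returned callable takes the receiver plus arguments. A minimal standalone sketch of that convention, simplified from the diff and not the packaged code itself (macOS only, since it loads libobjc):

# Sketch of the curried msg() convention: bind selector + restype once, call with the receiver.
import ctypes, functools

class objc_id(ctypes.c_void_p): pass  # keep pointers as objects so ctypes doesn't truncate them to a plain int

libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
libobjc.objc_getClass.restype = objc_id
libobjc.sel_registerName.restype = objc_id

@functools.lru_cache(None)
def msg(selector: str, restype=objc_id):
  sel = libobjc.sel_registerName(selector.encode())
  sender = libobjc["objc_msgSend"]  # indexing returns a fresh function object, so setting restype here is safe
  sender.restype = restype
  return lambda receiver, *args: sender(receiver, sel, *args)

# round-trip a string through NSString, the same shape as to_ns_str/from_ns_str later in this diff
ns = msg("stringWithUTF8String:")(libobjc.objc_getClass(b"NSString"), b"hello")
print(msg("UTF8String", ctypes.c_char_p)(ns).decode())  # -> hello

Binding and caching per (selector, restype) avoids re-registering the selector and re-setting restype on every send, which appears to be the motivation for the 0.10.2 refactor.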
@@ -22,6 +22,9 @@ class MTLResourceOptions:
 class MTLPipelineOption:
   MTLPipelineOptionNone = 0
 
+# 13 is requestType that metal uses to compile source code into MTLB, there aren't any docs or symbols.
+REQUEST_TYPE_COMPILE = 13
+
 libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
 libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
 # Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
@@ -32,53 +35,107 @@ libobjc.sel_registerName.restype = objc_id
 libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
 libdispatch.dispatch_data_create.restype = objc_instance
 
-# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
-def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id) -> T: # type: ignore [assignment]
+@functools.lru_cache(None)
+def msg(selector: str, restype: type[T] = objc_id): # type: ignore [assignment]
+  resname = libobjc.sel_registerName(selector.encode())
   sender = libobjc["objc_msgSend"] # Using attribute access returns a new reference so setting restype is safe
   sender.restype = restype
-  return sender(ptr, sel(selector), *args)
+  def _msg(ptr: objc_id, *args: Any) -> T: return sender(ptr, resname, *args)
+  return _msg
 
-def to_ns_str(s: str): return msg(libobjc.objc_getClass(b"NSString"), "stringWithUTF8String:", s.encode(), restype=objc_instance)
+@functools.lru_cache(None)
+def to_ns_str(s: str): return msg("stringWithUTF8String:", objc_instance)(libobjc.objc_getClass(b"NSString"), s.encode())
+def from_ns_str(s): return bytes(msg("UTF8String", ctypes.c_char_p)(s)).decode()
 
-def to_struct(*t: int, _type: type = ctypes.c_ulong):
-  class Struct(ctypes.Structure): pass
-  Struct._fields_ = [(f"field{i}", _type) for i in range(len(t))]
-  return Struct(*t)
+def to_struct(*t: int, _type: type = ctypes.c_ulong): return init_c_struct_t(tuple([(f"field{i}", _type) for i in range(len(t))]))(*t)
 
 def wait_check(cbuf: Any):
-  msg(cbuf, "waitUntilCompleted")
-  error_check(msg(cbuf, "error", restype=objc_instance))
+  msg("waitUntilCompleted")(cbuf)
+  error_check(msg("error", objc_instance)(cbuf))
 
-def elapsed_time(cbuf: objc_id):
-  return cast(float, msg(cbuf, "GPUEndTime", restype=ctypes.c_double)) - cast(float, msg(cbuf, "GPUStartTime", restype=ctypes.c_double))
+def cmdbuf_label(cbuf: objc_id) -> str|None: return from_ns_str(label) if (label:=msg("label", objc_id)(cbuf)).value is not None else None
+def cmdbuf_st_time(cbuf: objc_id) -> float: return cast(float, msg("GPUStartTime", ctypes.c_double)(cbuf))
+def cmdbuf_en_time(cbuf: objc_id) -> float: return cast(float, msg("GPUEndTime", ctypes.c_double)(cbuf))
 
 def error_check(error: objc_instance, error_constructor: type[Exception] = RuntimeError):
   if error.value is None: return None
-  raise error_constructor(bytes(msg(msg(error, "localizedDescription", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode())
+  raise error_constructor(from_ns_str(msg("localizedDescription", objc_instance)(error)))
+
+class MetalDevice(Compiled):
+  def __init__(self, device:str):
+    self.sysdevice = libmetal.MTLCreateSystemDefaultDevice()
+    self.mtl_queue = msg("newCommandQueueWithMaxCommandBufferCount:", objc_instance)(self.sysdevice, 1024)
+    if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
+    self.mtl_buffers_in_flight: list[Any] = []
+    self.timeline_signal = msg("newSharedEvent", objc_instance)(self.sysdevice)
+    self.timeline_value = 0
+
+    Compiled.profile_events += [ProfileDeviceEvent(device)]
+
+    from tinygrad.runtime.graph.metal import MetalGraph
+    super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_DIRECT", 1) else Compiler(),
+                     functools.partial(MetalProgram, self), MetalGraph)
+
+  def synchronize(self):
+    for cbuf in self.mtl_buffers_in_flight:
+      wait_check(cbuf)
+      st, en = decimal.Decimal(cmdbuf_st_time(cbuf)) * 1000000, decimal.Decimal(cmdbuf_en_time(cbuf)) * 1000000
+      if PROFILE and (lb:=cmdbuf_label(cbuf)) is not None:
+        Compiled.profile_events += [ProfileRangeEvent(self.device, lb, st, en, is_copy=lb.startswith("COPY"))]
+    self.mtl_buffers_in_flight.clear()
 
 def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
-  options = msg(libobjc.objc_getClass(b"MTLCompileOptions"), "new", restype=objc_instance)
-  msg(options, "setFastMathEnabled:", getenv("METAL_FAST_MATH"))
-  library = msg(device.device, "newLibraryWithSource:options:error:", to_ns_str(src), options,
-                ctypes.byref(compileError:=objc_instance()), restype=objc_instance)
+  options = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLCompileOptions"))
+  msg("setFastMathEnabled:")(options, getenv("METAL_FAST_MATH"))
+  library = msg("newLibraryWithSource:options:error:", objc_instance)(device.sysdevice, to_ns_str(src),
+                                                                      options, ctypes.byref(compileError:=objc_instance()))
   error_check(compileError, CompileError)
   return library
 
 class MetalCompiler(Compiler):
-  def __init__(self, device:Optional[MetalDevice]=None):
-    self.device = device
-    super().__init__("compile_metal_xcode" if self.device is None else "compile_metal")
+  # Opening METAL after LLVM doesn't fail because ctypes.CDLL opens with RTLD_LOCAL but MTLCompiler opens it's own llvm with RTLD_GLOBAL
+  # This means that MTLCompiler's llvm will create it's own instances of global state because RTLD_LOCAL doesn't export symbols, but if RTLD_GLOBAL
+  # library is loaded first then RTLD_LOCAL library will just use it's symbols. On linux there is RTLD_DEEPBIND to prevent that, but on macos there
+  # doesn't seem to be anything we can do.
+  with contextlib.suppress(FileNotFoundError):
+    import tinygrad.runtime.autogen.llvm # noqa: F401
+  support = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
+  support.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
+
+  def __init__(self):
+    self.cgs = ctypes.c_void_p(MetalCompiler.support.MTLCodeGenServiceCreate(b"tinygrad"))
+    super().__init__("compile_metal_direct")
+  def __reduce__(self): return (MetalCompiler,()) # force pickle to create new instance for each multiprocessing fork
   def compile(self, src:str) -> bytes:
-    if self.device is None:
-      # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
-      air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8'))
-      lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
-    else:
-      library = metal_src_to_library(self.device, src)
-      library_contents = msg(library, "libraryDataContents", restype=objc_instance)
-      lib = ctypes.string_at(msg(library_contents, "bytes"), cast(int, msg(library_contents, "length", restype=ctypes.c_ulong)))
-    assert lib[:4] == b"MTLB", "Invalid Metal library. Using conda? Corrupt XCode?"
-    return lib
+    ret: Union[Exception, bytes] = CompileError("MTLCodeGenServiceBuildRequest returned without calling the callback")
+    @ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int32, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_char_p)
+    def callback(blockptr, error, dataPtr, dataLen, errorMessage):
+      nonlocal ret
+      if error == 0:
+        reply = bytes(to_mv(dataPtr, dataLen))
+        # offset from beginning to data = header size + warning size
+        ret = reply[sum(struct.unpack('<LL', reply[8:16])):]
+      else:
+        ret = CompileError(errorMessage.decode())
+
+    # no changes for compute in 2.0 - 2.4 specs, use 2.0 as default for old versions.
+    macos_major = int(platform.mac_ver()[0].split('.')[0])
+    metal_version = "metal3.1" if macos_major >= 14 else "metal3.0" if macos_major >= 13 else "macos-metal2.0"
+
+    # llvm will create modules.timestamp in cache path and cache compilation of metal stdlib (250ms => 8ms compilation time)
+    # note that llvm won't necessarily create anything else here as apple has prebuilt versions of many standard libraries
+    params = f'-fno-fast-math -std={metal_version} --driver-mode=metal -x metal -fmodules-cache-path="{cache_dir}" -fno-caret-diagnostics'
+    # source blob has to be padded to multiple of 4 but at least one 'b\x00' should be added, params blob just has to be null terminated
+    src_padded, params_padded = src.encode() + b'\x00'*(round_up(len(src) + 1, 4) - len(src)), params.encode() + b'\x00'
+    request = struct.pack('<QQ', len(src_padded), len(params_padded)) + src_padded + params_padded
+    # The callback is actually not a callback but a block which is apple's non-standard extension to add closures to C.
+    # See https://clang.llvm.org/docs/Block-ABI-Apple.html#high-level for struct layout.
+    # Fields other than invoke are unused in this case so we can just use ctypes.byref with negative offset to invoke field, add blockptr as a first
+    # argument and pretend it's a normal callback
+    MetalCompiler.support.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10))
+    if isinstance(ret, Exception): raise ret
+    assert ret[:4] == b"MTLB" and ret[-4:] == b"ENDT", f"Invalid Metal library. {ret!r}"
+    return ret
   def disassemble(self, lib:bytes):
     with tempfile.NamedTemporaryFile(delete=True) as shader:
       shader.write(lib)
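Note: `compile()` above hands MTLCodeGenService one packed request blob and slices the MTLB payload out of the reply inside the callback. A small self-contained illustration of just that byte layout, using placeholder source/params strings and a fabricated reply buffer (no Metal or MTLCompiler involved; `round_up` is redefined here to mirror the helper imported from tinygrad.helpers):

# Request: two little-endian uint64 lengths, then NUL-padded source, then NUL-terminated params.
import struct

def round_up(n, m): return (n + m - 1) // m * m  # stand-in for tinygrad.helpers.round_up

src, params = "kernel void f() {}", "-std=metal3.1 --driver-mode=metal -x metal"
src_padded = src.encode() + b'\x00' * (round_up(len(src) + 1, 4) - len(src))  # multiple of 4, at least one NUL
params_padded = params.encode() + b'\x00'
request = struct.pack('<QQ', len(src_padded), len(params_padded)) + src_padded + params_padded
assert len(src_padded) % 4 == 0 and src_padded.endswith(b'\x00')

# Reply: bytes 8..16 hold two uint32s (header size, warning size); the MTLB payload starts after both.
# This reply is fabricated purely to show the slicing; real header fields are treated as opaque.
fake_reply = struct.pack('<QLL', 0, 24, 8) + b'\x00' * 8 + b'warning\n' + b'MTLB<bitcode>ENDT'
header_size, warning_size = struct.unpack('<LL', fake_reply[8:16])
print(fake_reply[header_size + warning_size:][:4])  # -> b'MTLB'

The only reply fields the diffed code relies on are the two sizes at bytes 8..16; everything else in the header is skipped over, as above.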
@@ -87,102 +144,81 @@ class MetalCompiler(Compiler):
       if ret: print("Disassembler Error: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu")
 
 class MetalProgram:
-  def __init__(self, device:MetalDevice, name:str, lib:bytes):
-    self.device, self.name, self.lib = device, name, lib
+  def __init__(self, dev:MetalDevice, name:str, lib:bytes):
+    self.dev, self.name, self.lib = dev, name, lib
     if lib[:4] == b"MTLB":
       # binary metal library
       data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
-      error_library_creation = objc_instance()
-      self.library = msg(self.device.device, "newLibraryWithData:error:", data, ctypes.byref(error_library_creation), restype=objc_instance)
-      error_check(error_library_creation)
+      self.library = msg("newLibraryWithData:error:", objc_instance)(self.dev.sysdevice, data, ctypes.byref(error_lib:=objc_instance()))
+      error_check(error_lib)
     else:
       # metal source. rely on OS caching
-      try: self.library = metal_src_to_library(self.device, lib.decode())
+      try: self.library = metal_src_to_library(self.dev, lib.decode())
       except CompileError as e: raise RuntimeError from e
-    self.fxn = msg(self.library, "newFunctionWithName:", to_ns_str(name), restype=objc_instance)
-    descriptor = msg(libobjc.objc_getClass(b"MTLComputePipelineDescriptor"), "new", restype=objc_instance)
-    msg(descriptor, "setComputeFunction:", self.fxn)
-    msg(descriptor, "setSupportIndirectCommandBuffers:", True)
-    self.pipeline_state = msg(self.device.device, "newComputePipelineStateWithDescriptor:options:reflection:error:",
-      descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation:=objc_instance()), restype=objc_instance)
+    self.fxn = msg("newFunctionWithName:", objc_instance)(self.library, to_ns_str(name))
+    descriptor = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLComputePipelineDescriptor"))
+    msg("setComputeFunction:")(descriptor, self.fxn)
+    msg("setSupportIndirectCommandBuffers:")(descriptor, True)
+    self.pipeline_state = msg("newComputePipelineStateWithDescriptor:options:reflection:error:", objc_instance)(self.dev.sysdevice,
+      descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation:=objc_instance()))
     error_check(error_pipeline_creation)
-
-  def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
-    max_total_threads = msg(self.pipeline_state, "maxTotalThreadsPerThreadgroup", restype=ctypes.c_ulong)
-    if prod(local_size) > cast(int, max_total_threads):
-      exec_width = msg(self.pipeline_state, "threadExecutionWidth", restype=ctypes.c_ulong)
-      memory_length = msg(self.pipeline_state, "staticThreadgroupMemoryLength", restype=ctypes.c_ulong)
-      raise RuntimeError(f"local size {local_size} bigger than {max_total_threads} with exec width {exec_width} memory length {memory_length}")
-    command_buffer = msg(self.device.mtl_queue, "commandBuffer", restype=objc_instance)
-    encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
-    msg(encoder, "setComputePipelineState:", self.pipeline_state)
-    for i,a in enumerate(bufs): msg(encoder, "setBuffer:offset:atIndex:", a.buf, a.offset, i)
-    for i,a in enumerate(vals,start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
-    msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
-    msg(encoder, "endEncoding")
-    msg(command_buffer, "commit")
+    # cache these msg calls
+    self.max_total_threads: int = cast(int, msg("maxTotalThreadsPerThreadgroup", ctypes.c_ulong)(self.pipeline_state))
+
+  def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+    if prod(local_size) > self.max_total_threads:
+      exec_width = msg("threadExecutionWidth", ctypes.c_ulong)(self.pipeline_state)
+      memory_length = msg("staticThreadgroupMemoryLength", ctypes.c_ulong)(self.pipeline_state)
+      raise RuntimeError(f"local size {local_size} bigger than {self.max_total_threads} with exec width {exec_width} memory length {memory_length}")
+    command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
+    encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)
+    msg("setComputePipelineState:")(encoder, self.pipeline_state)
+    for i,a in enumerate(bufs): msg("setBuffer:offset:atIndex:")(encoder, a.buf, a.offset, i)
+    for i,a in enumerate(vals, start=len(bufs)): msg("setBytes:length:atIndex:")(encoder, bytes(ctypes.c_int(a)), 4, i)
+    msg("dispatchThreadgroups:threadsPerThreadgroup:")(encoder, to_struct(*global_size), to_struct(*local_size))
+    msg("endEncoding")(encoder)
+    msg("setLabel:")(command_buffer, to_ns_str(self.name)) # TODO: is this always needed?
+    msg("commit")(command_buffer)
+    self.dev.mtl_buffers_in_flight.append(command_buffer)
     if wait:
       wait_check(command_buffer)
-      return elapsed_time(command_buffer)
-    self.device.mtl_buffers_in_flight.append(command_buffer)
+      return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
 
 class MetalBuffer:
   def __init__(self, buf:Any, size:int, offset=0): self.buf, self.size, self.offset = buf, size, offset
 
 class MetalAllocator(LRUAllocator):
-  def __init__(self, device:MetalDevice):
-    self.device:MetalDevice = device
+  def __init__(self, dev:MetalDevice):
+    self.dev:MetalDevice = dev
     super().__init__()
   def _alloc(self, size:int, options) -> MetalBuffer:
     # Buffer is explicitly released in _free() rather than garbage collected via reference count
-    ret = msg(self.device.device, "newBufferWithLength:options:", size, MTLResourceOptions.MTLResourceStorageModeShared, restype=objc_id)
+    ret = msg("newBufferWithLength:options:", objc_id)(self.dev.sysdevice, ctypes.c_ulong(size), MTLResourceOptions.MTLResourceStorageModeShared)
     if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
     return MetalBuffer(ret, size)
-  def _free(self, opaque:MetalBuffer, options): msg(opaque.buf, "release")
-  def transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
+  def _free(self, opaque:MetalBuffer, options): msg("release")(opaque.buf)
+  def _transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
     dest_dev.synchronize()
-    src_command_buffer = msg(src_dev.mtl_queue, "commandBuffer", restype=objc_instance)
-    encoder = msg(src_command_buffer, "blitCommandEncoder", restype=objc_instance)
-    msg(encoder, "copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:", src.buf, ctypes.c_ulong(src.offset),
+    src_command_buffer = msg("commandBuffer", objc_instance)(src_dev.mtl_queue)
+    encoder = msg("blitCommandEncoder", objc_instance)(src_command_buffer)
+    msg("copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:")(encoder, src.buf, ctypes.c_ulong(src.offset),
       dest.buf, ctypes.c_ulong(dest.offset), ctypes.c_ulong(sz))
-    msg(encoder, "endEncoding")
+    msg("endEncoding")(encoder)
     if src_dev != dest_dev:
-      msg(src_command_buffer, "encodeSignalEvent:value:", src_dev.timeline_signal, src_dev.timeline_value)
-      dest_command_buffer = msg(dest_dev.mtl_queue, "commandBuffer", restype=objc_instance)
-      msg(dest_command_buffer, "encodeWaitForEvent:value:", src_dev.timeline_signal, src_dev.timeline_value)
-      msg(dest_command_buffer, "commit")
+      msg("encodeSignalEvent:value:")(src_command_buffer, src_dev.timeline_signal, src_dev.timeline_value)
+      dest_command_buffer = msg("commandBuffer", objc_instance)(dest_dev.mtl_queue)
+      msg("encodeWaitForEvent:value:")(dest_command_buffer, src_dev.timeline_signal, src_dev.timeline_value)
+      msg("commit")(dest_command_buffer)
      dest_dev.mtl_buffers_in_flight.append(dest_command_buffer)
       src_dev.timeline_value += 1
-    msg(src_command_buffer, "commit")
+    msg("setLabel:")(src_command_buffer, to_ns_str(f"COPY {src_dev.device} -> {dest_dev.device}"))
+    msg("commit")(src_command_buffer)
     src_dev.mtl_buffers_in_flight.append(src_command_buffer)
-  def from_buffer(self, src:memoryview) -> Optional[Any]:
-    ptr = (ctypes.c_char * src.nbytes).from_buffer(src)
-    ret = msg(self.device.device, "newBufferWithBytesNoCopy:length:options:deallocator:", ptr, src.nbytes, 0, None, restype=objc_instance)
-    if ret: self.device.mv_in_metal.append(src)
-    return MetalBuffer(ret, src.nbytes)
-  def as_buffer(self, src:MetalBuffer) -> memoryview:
-    self.device.synchronize()
-    ptr = msg(src.buf, "contents", restype=objc_id) # Shared memory, do not release here
-    array = (ctypes.c_char * (src.offset + src.size)).from_address(ptr.value)
-    return memoryview(array).cast("B")[src.offset:]
-  def copyin(self, dest:MetalBuffer, src:memoryview): self.as_buffer(dest)[:] = src
-  def copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self.as_buffer(src)
-  def offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
-
-class MetalDevice(Compiled):
-  def __init__(self, device:str):
-    self.device = libmetal.MTLCreateSystemDefaultDevice()
-    self.mtl_queue = msg(self.device, "newCommandQueueWithMaxCommandBufferCount:", 1024, restype=objc_instance)
-    if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
-    self.mtl_buffers_in_flight: List[Any] = []
-    self.mv_in_metal: List[memoryview] = []
-    self.timeline_signal = msg(self.device, "newSharedEvent", restype=objc_instance)
-    self.timeline_value = 0
-
-    from tinygrad.runtime.graph.metal import MetalGraph
-    super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_XCODE") else Compiler(),
-                     functools.partial(MetalProgram, self), MetalGraph)
-  def synchronize(self):
-    for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
-    self.mv_in_metal.clear()
-    self.mtl_buffers_in_flight.clear()
+  def _cp_mv(self, dst, src, prof_desc):
+    with cpu_profile(prof_desc, self.dev.device, is_copy=True): dst[:] = src
+  def _as_buffer(self, src:MetalBuffer) -> memoryview:
+    self.dev.synchronize()
+    return to_mv(cast(int, msg("contents", objc_id)(src.buf).value), src.size + src.offset)[src.offset:]
+  def _copyin(self, dest:MetalBuffer, src:memoryview): self._cp_mv(self._as_buffer(dest), src, "CPU -> METAL")
+  def _copyout(self, dest:memoryview, src:MetalBuffer): self._cp_mv(dest, self._as_buffer(src), "METAL -> CPU")
+  def _offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
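Note: `dispatchThreadgroups:threadsPerThreadgroup:` above receives two MTLSize-shaped structs built by `to_struct`, which 0.10.2 delegates to `init_c_struct_t` from tinygrad.helpers. A standalone sketch equivalent to the inline version this diff removes, showing the value that crosses the ObjC boundary:

# Inline equivalent of to_struct(): a ctypes.Structure of N c_ulong fields (MTLSize is 3 NSUIntegers).
import ctypes

def to_struct(*t: int, _type: type = ctypes.c_ulong):
  class Struct(ctypes.Structure):
    _fields_ = [(f"field{i}", _type) for i in range(len(t))]
  return Struct(*t)

global_size, local_size = (4, 4, 1), (32, 1, 1)   # threadgroups per grid, threads per threadgroup
gsz, lsz = to_struct(*global_size), to_struct(*local_size)
print([getattr(gsz, f"field{i}") for i in range(3)], ctypes.sizeof(gsz))  # [4, 4, 1] 24 (3 x 8-byte ulong on 64-bit macOS)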
tinygrad/runtime/ops_npy.py

@@ -2,8 +2,8 @@ import numpy as np
 from tinygrad.helpers import flat_mv
 from tinygrad.device import Compiled, Allocator
 
-class NpyAllocator(Allocator): # pylint: disable=abstract-method
-  def copyout(self, dest:memoryview, src:np.ndarray): dest[:] = flat_mv(np.require(src, requirements='C').data)
+class NpyAllocator(Allocator):
+  def _copyout(self, dest:memoryview, src:np.ndarray): dest[:] = flat_mv(np.require(src, requirements='C').data)
 
 class NpyDevice(Compiled):
   def __init__(self, device:str): super().__init__(device, NpyAllocator(), None, None, None)