tinygrad 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/devectorizer.py +247 -0
- tinygrad/codegen/expander.py +121 -0
- tinygrad/codegen/kernel.py +141 -201
- tinygrad/codegen/linearize.py +223 -84
- tinygrad/codegen/lowerer.py +60 -42
- tinygrad/codegen/symbolic.py +476 -0
- tinygrad/codegen/transcendental.py +22 -13
- tinygrad/device.py +187 -47
- tinygrad/dtype.py +39 -28
- tinygrad/engine/jit.py +83 -65
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +161 -0
- tinygrad/engine/realize.py +62 -108
- tinygrad/engine/schedule.py +396 -357
- tinygrad/engine/search.py +55 -66
- tinygrad/gradient.py +73 -0
- tinygrad/helpers.py +81 -59
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +91 -66
- tinygrad/ops.py +492 -641
- tinygrad/renderer/__init__.py +95 -36
- tinygrad/renderer/cstyle.py +99 -92
- tinygrad/renderer/llvmir.py +83 -34
- tinygrad/renderer/ptx.py +83 -99
- tinygrad/renderer/wgsl.py +95 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libc.py +404 -71
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/pci.py +1333 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/autogen/webgpu.py +6985 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +40 -43
- tinygrad/runtime/ops_amd.py +498 -334
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cpu.py +24 -0
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +159 -42
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +48 -41
- tinygrad/runtime/ops_metal.py +149 -113
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +238 -273
- tinygrad/runtime/ops_python.py +55 -50
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +225 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +396 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +28 -4
- tinygrad/runtime/support/hcq.py +256 -324
- tinygrad/runtime/support/llvm.py +26 -0
- tinygrad/shape/shapetracker.py +85 -53
- tinygrad/shape/view.py +104 -140
- tinygrad/spec.py +155 -0
- tinygrad/tensor.py +835 -527
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
- tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
- tinygrad/viz/index.html +544 -0
- tinygrad/viz/perfetto.html +178 -0
- tinygrad/viz/serve.py +205 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/METADATA +48 -25
- tinygrad-0.10.2.dist-info/RECORD +99 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +1 -1
- tinygrad/codegen/uopgraph.py +0 -506
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad/runtime/ops_clang.py +0 -35
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_metal.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
from
|
4
|
-
from tinygrad.
|
5
|
-
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
|
1
|
+
import os, pathlib, struct, ctypes, tempfile, functools, contextlib, decimal, platform
|
2
|
+
from typing import Any, Union, cast
|
3
|
+
from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, T, init_c_struct_t, PROFILE
|
4
|
+
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, cpu_profile, ProfileDeviceEvent, ProfileRangeEvent
|
6
5
|
from tinygrad.renderer.cstyle import MetalRenderer
|
7
6
|
|
8
7
|
class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response to plain int, and dict.fromkeys() can use it to dedup
|
@@ -10,10 +9,11 @@ class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response
|
|
10
9
|
def __eq__(self, other): return self.value == other.value
|
11
10
|
|
12
11
|
class objc_instance(objc_id): # method with name "new", "alloc" should be freed after use
|
13
|
-
def __del__(self):
|
14
|
-
|
15
|
-
|
16
|
-
|
12
|
+
def __del__(self):
|
13
|
+
# CPython doesn't make any guarantees about order in which globals (like `msg` or `libobjc`) are destroyed when the interpreter shuts down
|
14
|
+
# https://github.com/tinygrad/tinygrad/pull/8949 triggered the unlucky ordering which lead to a bunch of errors at exit
|
15
|
+
# TODO: Why isn't `sys.is_finalizing` working?
|
16
|
+
if msg is not None and libobjc is not None: msg("release")(self)
|
17
17
|
|
18
18
|
class MTLResourceOptions:
|
19
19
|
MTLResourceCPUCacheModeDefaultCache = 0
|
@@ -22,6 +22,9 @@ class MTLResourceOptions:
|
|
22
22
|
class MTLPipelineOption:
|
23
23
|
MTLPipelineOptionNone = 0
|
24
24
|
|
25
|
+
# 13 is requestType that metal uses to compile source code into MTLB, there aren't any docs or symbols.
|
26
|
+
REQUEST_TYPE_COMPILE = 13
|
27
|
+
|
25
28
|
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
|
26
29
|
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
|
27
30
|
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
|
@@ -32,53 +35,107 @@ libobjc.sel_registerName.restype = objc_id
|
|
32
35
|
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
|
33
36
|
libdispatch.dispatch_data_create.restype = objc_instance
|
34
37
|
|
35
|
-
|
36
|
-
def msg(
|
38
|
+
@functools.lru_cache(None)
|
39
|
+
def msg(selector: str, restype: type[T] = objc_id): # type: ignore [assignment]
|
40
|
+
resname = libobjc.sel_registerName(selector.encode())
|
37
41
|
sender = libobjc["objc_msgSend"] # Using attribute access returns a new reference so setting restype is safe
|
38
42
|
sender.restype = restype
|
39
|
-
return sender(ptr,
|
43
|
+
def _msg(ptr: objc_id, *args: Any) -> T: return sender(ptr, resname, *args)
|
44
|
+
return _msg
|
40
45
|
|
41
|
-
|
46
|
+
@functools.lru_cache(None)
|
47
|
+
def to_ns_str(s: str): return msg("stringWithUTF8String:", objc_instance)(libobjc.objc_getClass(b"NSString"), s.encode())
|
48
|
+
def from_ns_str(s): return bytes(msg("UTF8String", ctypes.c_char_p)(s)).decode()
|
42
49
|
|
43
|
-
def to_struct(*t: int, _type: type = ctypes.c_ulong):
|
44
|
-
class Struct(ctypes.Structure): pass
|
45
|
-
Struct._fields_ = [(f"field{i}", _type) for i in range(len(t))]
|
46
|
-
return Struct(*t)
|
50
|
+
def to_struct(*t: int, _type: type = ctypes.c_ulong): return init_c_struct_t(tuple([(f"field{i}", _type) for i in range(len(t))]))(*t)
|
47
51
|
|
48
52
|
def wait_check(cbuf: Any):
|
49
|
-
msg(
|
50
|
-
error_check(msg(
|
53
|
+
msg("waitUntilCompleted")(cbuf)
|
54
|
+
error_check(msg("error", objc_instance)(cbuf))
|
51
55
|
|
52
|
-
def
|
53
|
-
|
56
|
+
def cmdbuf_label(cbuf: objc_id) -> str|None: return from_ns_str(label) if (label:=msg("label", objc_id)(cbuf)).value is not None else None
|
57
|
+
def cmdbuf_st_time(cbuf: objc_id) -> float: return cast(float, msg("GPUStartTime", ctypes.c_double)(cbuf))
|
58
|
+
def cmdbuf_en_time(cbuf: objc_id) -> float: return cast(float, msg("GPUEndTime", ctypes.c_double)(cbuf))
|
54
59
|
|
55
60
|
def error_check(error: objc_instance, error_constructor: type[Exception] = RuntimeError):
|
56
61
|
if error.value is None: return None
|
57
|
-
raise error_constructor(
|
62
|
+
raise error_constructor(from_ns_str(msg("localizedDescription", objc_instance)(error)))
|
63
|
+
|
64
|
+
class MetalDevice(Compiled):
|
65
|
+
def __init__(self, device:str):
|
66
|
+
self.sysdevice = libmetal.MTLCreateSystemDefaultDevice()
|
67
|
+
self.mtl_queue = msg("newCommandQueueWithMaxCommandBufferCount:", objc_instance)(self.sysdevice, 1024)
|
68
|
+
if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
|
69
|
+
self.mtl_buffers_in_flight: list[Any] = []
|
70
|
+
self.timeline_signal = msg("newSharedEvent", objc_instance)(self.sysdevice)
|
71
|
+
self.timeline_value = 0
|
72
|
+
|
73
|
+
Compiled.profile_events += [ProfileDeviceEvent(device)]
|
74
|
+
|
75
|
+
from tinygrad.runtime.graph.metal import MetalGraph
|
76
|
+
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_DIRECT", 1) else Compiler(),
|
77
|
+
functools.partial(MetalProgram, self), MetalGraph)
|
78
|
+
|
79
|
+
def synchronize(self):
|
80
|
+
for cbuf in self.mtl_buffers_in_flight:
|
81
|
+
wait_check(cbuf)
|
82
|
+
st, en = decimal.Decimal(cmdbuf_st_time(cbuf)) * 1000000, decimal.Decimal(cmdbuf_en_time(cbuf)) * 1000000
|
83
|
+
if PROFILE and (lb:=cmdbuf_label(cbuf)) is not None:
|
84
|
+
Compiled.profile_events += [ProfileRangeEvent(self.device, lb, st, en, is_copy=lb.startswith("COPY"))]
|
85
|
+
self.mtl_buffers_in_flight.clear()
|
58
86
|
|
59
87
|
def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
|
60
|
-
options = msg(libobjc.objc_getClass(b"MTLCompileOptions")
|
61
|
-
msg(
|
62
|
-
library = msg(
|
63
|
-
|
88
|
+
options = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLCompileOptions"))
|
89
|
+
msg("setFastMathEnabled:")(options, getenv("METAL_FAST_MATH"))
|
90
|
+
library = msg("newLibraryWithSource:options:error:", objc_instance)(device.sysdevice, to_ns_str(src),
|
91
|
+
options, ctypes.byref(compileError:=objc_instance()))
|
64
92
|
error_check(compileError, CompileError)
|
65
93
|
return library
|
66
94
|
|
67
95
|
class MetalCompiler(Compiler):
|
68
|
-
|
69
|
-
|
70
|
-
|
96
|
+
# Opening METAL after LLVM doesn't fail because ctypes.CDLL opens with RTLD_LOCAL but MTLCompiler opens it's own llvm with RTLD_GLOBAL
|
97
|
+
# This means that MTLCompiler's llvm will create it's own instances of global state because RTLD_LOCAL doesn't export symbols, but if RTLD_GLOBAL
|
98
|
+
# library is loaded first then RTLD_LOCAL library will just use it's symbols. On linux there is RTLD_DEEPBIND to prevent that, but on macos there
|
99
|
+
# doesn't seem to be anything we can do.
|
100
|
+
with contextlib.suppress(FileNotFoundError):
|
101
|
+
import tinygrad.runtime.autogen.llvm # noqa: F401
|
102
|
+
support = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
|
103
|
+
support.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
|
104
|
+
|
105
|
+
def __init__(self):
|
106
|
+
self.cgs = ctypes.c_void_p(MetalCompiler.support.MTLCodeGenServiceCreate(b"tinygrad"))
|
107
|
+
super().__init__("compile_metal_direct")
|
108
|
+
def __reduce__(self): return (MetalCompiler,()) # force pickle to create new instance for each multiprocessing fork
|
71
109
|
def compile(self, src:str) -> bytes:
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
110
|
+
ret: Union[Exception, bytes] = CompileError("MTLCodeGenServiceBuildRequest returned without calling the callback")
|
111
|
+
@ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int32, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_char_p)
|
112
|
+
def callback(blockptr, error, dataPtr, dataLen, errorMessage):
|
113
|
+
nonlocal ret
|
114
|
+
if error == 0:
|
115
|
+
reply = bytes(to_mv(dataPtr, dataLen))
|
116
|
+
# offset from beginning to data = header size + warning size
|
117
|
+
ret = reply[sum(struct.unpack('<LL', reply[8:16])):]
|
118
|
+
else:
|
119
|
+
ret = CompileError(errorMessage.decode())
|
120
|
+
|
121
|
+
# no changes for compute in 2.0 - 2.4 specs, use 2.0 as default for old versions.
|
122
|
+
macos_major = int(platform.mac_ver()[0].split('.')[0])
|
123
|
+
metal_version = "metal3.1" if macos_major >= 14 else "metal3.0" if macos_major >= 13 else "macos-metal2.0"
|
124
|
+
|
125
|
+
# llvm will create modules.timestamp in cache path and cache compilation of metal stdlib (250ms => 8ms compilation time)
|
126
|
+
# note that llvm won't necessarily create anything else here as apple has prebuilt versions of many standard libraries
|
127
|
+
params = f'-fno-fast-math -std={metal_version} --driver-mode=metal -x metal -fmodules-cache-path="{cache_dir}" -fno-caret-diagnostics'
|
128
|
+
# source blob has to be padded to multiple of 4 but at least one 'b\x00' should be added, params blob just has to be null terminated
|
129
|
+
src_padded, params_padded = src.encode() + b'\x00'*(round_up(len(src) + 1, 4) - len(src)), params.encode() + b'\x00'
|
130
|
+
request = struct.pack('<QQ', len(src_padded), len(params_padded)) + src_padded + params_padded
|
131
|
+
# The callback is actually not a callback but a block which is apple's non-standard extension to add closures to C.
|
132
|
+
# See https://clang.llvm.org/docs/Block-ABI-Apple.html#high-level for struct layout.
|
133
|
+
# Fields other than invoke are unused in this case so we can just use ctypes.byref with negative offset to invoke field, add blockptr as a first
|
134
|
+
# argument and pretend it's a normal callback
|
135
|
+
MetalCompiler.support.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10))
|
136
|
+
if isinstance(ret, Exception): raise ret
|
137
|
+
assert ret[:4] == b"MTLB" and ret[-4:] == b"ENDT", f"Invalid Metal library. {ret!r}"
|
138
|
+
return ret
|
82
139
|
def disassemble(self, lib:bytes):
|
83
140
|
with tempfile.NamedTemporaryFile(delete=True) as shader:
|
84
141
|
shader.write(lib)
|
@@ -87,102 +144,81 @@ class MetalCompiler(Compiler):
|
|
87
144
|
if ret: print("Disassembler Error: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu")
|
88
145
|
|
89
146
|
class MetalProgram:
|
90
|
-
def __init__(self,
|
91
|
-
self.
|
147
|
+
def __init__(self, dev:MetalDevice, name:str, lib:bytes):
|
148
|
+
self.dev, self.name, self.lib = dev, name, lib
|
92
149
|
if lib[:4] == b"MTLB":
|
93
150
|
# binary metal library
|
94
151
|
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
|
95
|
-
|
96
|
-
|
97
|
-
error_check(error_library_creation)
|
152
|
+
self.library = msg("newLibraryWithData:error:", objc_instance)(self.dev.sysdevice, data, ctypes.byref(error_lib:=objc_instance()))
|
153
|
+
error_check(error_lib)
|
98
154
|
else:
|
99
155
|
# metal source. rely on OS caching
|
100
|
-
try: self.library = metal_src_to_library(self.
|
156
|
+
try: self.library = metal_src_to_library(self.dev, lib.decode())
|
101
157
|
except CompileError as e: raise RuntimeError from e
|
102
|
-
self.fxn = msg(
|
103
|
-
descriptor = msg(libobjc.objc_getClass(b"MTLComputePipelineDescriptor")
|
104
|
-
msg(
|
105
|
-
msg(
|
106
|
-
self.pipeline_state = msg(
|
107
|
-
descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation:=objc_instance())
|
158
|
+
self.fxn = msg("newFunctionWithName:", objc_instance)(self.library, to_ns_str(name))
|
159
|
+
descriptor = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLComputePipelineDescriptor"))
|
160
|
+
msg("setComputeFunction:")(descriptor, self.fxn)
|
161
|
+
msg("setSupportIndirectCommandBuffers:")(descriptor, True)
|
162
|
+
self.pipeline_state = msg("newComputePipelineStateWithDescriptor:options:reflection:error:", objc_instance)(self.dev.sysdevice,
|
163
|
+
descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation:=objc_instance()))
|
108
164
|
error_check(error_pipeline_creation)
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
msg(
|
119
|
-
|
120
|
-
for i,a in enumerate(
|
121
|
-
|
122
|
-
msg(encoder,
|
123
|
-
msg(
|
165
|
+
# cache these msg calls
|
166
|
+
self.max_total_threads: int = cast(int, msg("maxTotalThreadsPerThreadgroup", ctypes.c_ulong)(self.pipeline_state))
|
167
|
+
|
168
|
+
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
|
169
|
+
if prod(local_size) > self.max_total_threads:
|
170
|
+
exec_width = msg("threadExecutionWidth", ctypes.c_ulong)(self.pipeline_state)
|
171
|
+
memory_length = msg("staticThreadgroupMemoryLength", ctypes.c_ulong)(self.pipeline_state)
|
172
|
+
raise RuntimeError(f"local size {local_size} bigger than {self.max_total_threads} with exec width {exec_width} memory length {memory_length}")
|
173
|
+
command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
|
174
|
+
encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)
|
175
|
+
msg("setComputePipelineState:")(encoder, self.pipeline_state)
|
176
|
+
for i,a in enumerate(bufs): msg("setBuffer:offset:atIndex:")(encoder, a.buf, a.offset, i)
|
177
|
+
for i,a in enumerate(vals, start=len(bufs)): msg("setBytes:length:atIndex:")(encoder, bytes(ctypes.c_int(a)), 4, i)
|
178
|
+
msg("dispatchThreadgroups:threadsPerThreadgroup:")(encoder, to_struct(*global_size), to_struct(*local_size))
|
179
|
+
msg("endEncoding")(encoder)
|
180
|
+
msg("setLabel:")(command_buffer, to_ns_str(self.name)) # TODO: is this always needed?
|
181
|
+
msg("commit")(command_buffer)
|
182
|
+
self.dev.mtl_buffers_in_flight.append(command_buffer)
|
124
183
|
if wait:
|
125
184
|
wait_check(command_buffer)
|
126
|
-
return
|
127
|
-
self.device.mtl_buffers_in_flight.append(command_buffer)
|
185
|
+
return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
|
128
186
|
|
129
187
|
class MetalBuffer:
|
130
188
|
def __init__(self, buf:Any, size:int, offset=0): self.buf, self.size, self.offset = buf, size, offset
|
131
189
|
|
132
190
|
class MetalAllocator(LRUAllocator):
|
133
|
-
def __init__(self,
|
134
|
-
self.
|
191
|
+
def __init__(self, dev:MetalDevice):
|
192
|
+
self.dev:MetalDevice = dev
|
135
193
|
super().__init__()
|
136
194
|
def _alloc(self, size:int, options) -> MetalBuffer:
|
137
195
|
# Buffer is explicitly released in _free() rather than garbage collected via reference count
|
138
|
-
ret = msg(
|
196
|
+
ret = msg("newBufferWithLength:options:", objc_id)(self.dev.sysdevice, ctypes.c_ulong(size), MTLResourceOptions.MTLResourceStorageModeShared)
|
139
197
|
if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
|
140
198
|
return MetalBuffer(ret, size)
|
141
|
-
def _free(self, opaque:MetalBuffer, options): msg(
|
142
|
-
def
|
199
|
+
def _free(self, opaque:MetalBuffer, options): msg("release")(opaque.buf)
|
200
|
+
def _transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
|
143
201
|
dest_dev.synchronize()
|
144
|
-
src_command_buffer = msg(
|
145
|
-
encoder = msg(
|
146
|
-
msg(
|
202
|
+
src_command_buffer = msg("commandBuffer", objc_instance)(src_dev.mtl_queue)
|
203
|
+
encoder = msg("blitCommandEncoder", objc_instance)(src_command_buffer)
|
204
|
+
msg("copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:")(encoder, src.buf, ctypes.c_ulong(src.offset),
|
147
205
|
dest.buf, ctypes.c_ulong(dest.offset), ctypes.c_ulong(sz))
|
148
|
-
msg(
|
206
|
+
msg("endEncoding")(encoder)
|
149
207
|
if src_dev != dest_dev:
|
150
|
-
msg(
|
151
|
-
dest_command_buffer = msg(
|
152
|
-
msg(
|
153
|
-
msg(
|
208
|
+
msg("encodeSignalEvent:value:")(src_command_buffer, src_dev.timeline_signal, src_dev.timeline_value)
|
209
|
+
dest_command_buffer = msg("commandBuffer", objc_instance)(dest_dev.mtl_queue)
|
210
|
+
msg("encodeWaitForEvent:value:")(dest_command_buffer, src_dev.timeline_signal, src_dev.timeline_value)
|
211
|
+
msg("commit")(dest_command_buffer)
|
154
212
|
dest_dev.mtl_buffers_in_flight.append(dest_command_buffer)
|
155
213
|
src_dev.timeline_value += 1
|
156
|
-
msg(src_command_buffer, "
|
214
|
+
msg("setLabel:")(src_command_buffer, to_ns_str(f"COPY {src_dev.device} -> {dest_dev.device}"))
|
215
|
+
msg("commit")(src_command_buffer)
|
157
216
|
src_dev.mtl_buffers_in_flight.append(src_command_buffer)
|
158
|
-
def
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
return
|
163
|
-
def
|
164
|
-
|
165
|
-
|
166
|
-
array = (ctypes.c_char * (src.offset + src.size)).from_address(ptr.value)
|
167
|
-
return memoryview(array).cast("B")[src.offset:]
|
168
|
-
def copyin(self, dest:MetalBuffer, src:memoryview): self.as_buffer(dest)[:] = src
|
169
|
-
def copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self.as_buffer(src)
|
170
|
-
def offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
|
171
|
-
|
172
|
-
class MetalDevice(Compiled):
|
173
|
-
def __init__(self, device:str):
|
174
|
-
self.device = libmetal.MTLCreateSystemDefaultDevice()
|
175
|
-
self.mtl_queue = msg(self.device, "newCommandQueueWithMaxCommandBufferCount:", 1024, restype=objc_instance)
|
176
|
-
if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
|
177
|
-
self.mtl_buffers_in_flight: List[Any] = []
|
178
|
-
self.mv_in_metal: List[memoryview] = []
|
179
|
-
self.timeline_signal = msg(self.device, "newSharedEvent", restype=objc_instance)
|
180
|
-
self.timeline_value = 0
|
181
|
-
|
182
|
-
from tinygrad.runtime.graph.metal import MetalGraph
|
183
|
-
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_XCODE") else Compiler(),
|
184
|
-
functools.partial(MetalProgram, self), MetalGraph)
|
185
|
-
def synchronize(self):
|
186
|
-
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
|
187
|
-
self.mv_in_metal.clear()
|
188
|
-
self.mtl_buffers_in_flight.clear()
|
217
|
+
def _cp_mv(self, dst, src, prof_desc):
|
218
|
+
with cpu_profile(prof_desc, self.dev.device, is_copy=True): dst[:] = src
|
219
|
+
def _as_buffer(self, src:MetalBuffer) -> memoryview:
|
220
|
+
self.dev.synchronize()
|
221
|
+
return to_mv(cast(int, msg("contents", objc_id)(src.buf).value), src.size + src.offset)[src.offset:]
|
222
|
+
def _copyin(self, dest:MetalBuffer, src:memoryview): self._cp_mv(self._as_buffer(dest), src, "CPU -> METAL")
|
223
|
+
def _copyout(self, dest:memoryview, src:MetalBuffer): self._cp_mv(dest, self._as_buffer(src), "METAL -> CPU")
|
224
|
+
def _offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
|
tinygrad/runtime/ops_npy.py
CHANGED
@@ -2,8 +2,8 @@ import numpy as np
|
|
2
2
|
from tinygrad.helpers import flat_mv
|
3
3
|
from tinygrad.device import Compiled, Allocator
|
4
4
|
|
5
|
-
class NpyAllocator(Allocator):
|
6
|
-
def
|
5
|
+
class NpyAllocator(Allocator):
|
6
|
+
def _copyout(self, dest:memoryview, src:np.ndarray): dest[:] = flat_mv(np.require(src, requirements='C').data)
|
7
7
|
|
8
8
|
class NpyDevice(Compiled):
|
9
9
|
def __init__(self, device:str): super().__init__(device, NpyAllocator(), None, None, None)
|