tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/codegen/kernel.py +114 -172
- tinygrad/codegen/linearize.py +211 -81
- tinygrad/codegen/lowerer.py +30 -35
- tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
- tinygrad/codegen/transcendental.py +12 -13
- tinygrad/device.py +170 -47
- tinygrad/dtype.py +28 -26
- tinygrad/engine/jit.py +80 -63
- tinygrad/engine/memory.py +4 -5
- tinygrad/engine/multi.py +162 -0
- tinygrad/engine/realize.py +58 -107
- tinygrad/engine/schedule.py +381 -314
- tinygrad/engine/search.py +40 -44
- tinygrad/gradient.py +70 -0
- tinygrad/helpers.py +77 -58
- tinygrad/nn/__init__.py +30 -32
- tinygrad/nn/datasets.py +1 -2
- tinygrad/nn/optim.py +22 -26
- tinygrad/nn/state.py +89 -64
- tinygrad/ops.py +562 -446
- tinygrad/renderer/__init__.py +79 -36
- tinygrad/renderer/cstyle.py +70 -84
- tinygrad/renderer/llvmir.py +32 -20
- tinygrad/renderer/ptx.py +79 -99
- tinygrad/renderer/wgsl.py +87 -0
- tinygrad/runtime/autogen/amd_gpu.py +39507 -12
- tinygrad/runtime/autogen/comgr.py +2 -0
- tinygrad/runtime/autogen/kfd.py +4 -3
- tinygrad/runtime/autogen/kgsl.py +1 -1
- tinygrad/runtime/autogen/libpciaccess.py +2023 -0
- tinygrad/runtime/autogen/llvm.py +11379 -0
- tinygrad/runtime/autogen/vfio.py +891 -0
- tinygrad/runtime/graph/cuda.py +8 -9
- tinygrad/runtime/graph/hcq.py +84 -79
- tinygrad/runtime/graph/metal.py +19 -21
- tinygrad/runtime/ops_amd.py +488 -327
- tinygrad/runtime/ops_clang.py +15 -28
- tinygrad/runtime/ops_cloud.py +34 -34
- tinygrad/runtime/ops_cuda.py +30 -27
- tinygrad/runtime/ops_disk.py +62 -63
- tinygrad/runtime/ops_dsp.py +129 -38
- tinygrad/runtime/ops_gpu.py +30 -30
- tinygrad/runtime/ops_hip.py +29 -31
- tinygrad/runtime/ops_llvm.py +45 -40
- tinygrad/runtime/ops_metal.py +93 -73
- tinygrad/runtime/ops_npy.py +2 -2
- tinygrad/runtime/ops_nv.py +232 -270
- tinygrad/runtime/ops_python.py +51 -46
- tinygrad/runtime/ops_qcom.py +129 -157
- tinygrad/runtime/ops_webgpu.py +63 -0
- tinygrad/runtime/support/allocator.py +94 -0
- tinygrad/runtime/support/am/__init__.py +0 -0
- tinygrad/runtime/support/am/amdev.py +384 -0
- tinygrad/runtime/support/am/ip.py +463 -0
- tinygrad/runtime/support/compiler_cuda.py +4 -2
- tinygrad/runtime/support/elf.py +26 -4
- tinygrad/runtime/support/hcq.py +254 -324
- tinygrad/runtime/support/llvm.py +32 -0
- tinygrad/shape/shapetracker.py +84 -53
- tinygrad/shape/view.py +103 -138
- tinygrad/spec.py +154 -0
- tinygrad/tensor.py +744 -496
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
- tinygrad-0.10.1.dist-info/RECORD +86 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
- tinygrad/engine/lazy.py +0 -228
- tinygrad/function.py +0 -212
- tinygrad/multi.py +0 -177
- tinygrad/runtime/graph/clang.py +0 -39
- tinygrad-0.10.0.dist-info/RECORD +0 -77
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
- {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_metal.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
from
|
4
|
-
from tinygrad.
|
5
|
-
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
|
1
|
+
import os, pathlib, struct, ctypes, tempfile, functools, decimal
|
2
|
+
from typing import Any, Union, cast
|
3
|
+
from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, T, init_c_struct_t, PROFILE
|
4
|
+
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, cpu_profile, ProfileDeviceEvent, ProfileRangeEvent
|
6
5
|
from tinygrad.renderer.cstyle import MetalRenderer
|
7
6
|
|
8
7
|
class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response to plain int, and dict.fromkeys() can use it to dedup
|
@@ -22,14 +21,19 @@ class MTLResourceOptions:
|
|
22
21
|
class MTLPipelineOption:
|
23
22
|
MTLPipelineOptionNone = 0
|
24
23
|
|
24
|
+
# 13 is requestType that metal uses to compile source code into MTLB, there aren't any docs or symbols.
|
25
|
+
REQUEST_TYPE_COMPILE = 13
|
26
|
+
|
25
27
|
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
|
26
28
|
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
|
29
|
+
compiler = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
|
27
30
|
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
|
28
31
|
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
|
29
32
|
libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
|
30
33
|
libobjc.objc_getClass.restype = objc_id
|
31
34
|
libobjc.sel_registerName.restype = objc_id
|
32
35
|
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
|
36
|
+
compiler.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
|
33
37
|
libdispatch.dispatch_data_create.restype = objc_instance
|
34
38
|
|
35
39
|
# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
|
@@ -39,46 +43,83 @@ def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id)
|
|
39
43
|
return sender(ptr, sel(selector), *args)
|
40
44
|
|
41
45
|
def to_ns_str(s: str): return msg(libobjc.objc_getClass(b"NSString"), "stringWithUTF8String:", s.encode(), restype=objc_instance)
|
46
|
+
def from_ns_str(s): return bytes(msg(s, "UTF8String", restype=ctypes.c_char_p)).decode()
|
42
47
|
|
43
|
-
def to_struct(*t: int, _type: type = ctypes.c_ulong):
|
44
|
-
class Struct(ctypes.Structure): pass
|
45
|
-
Struct._fields_ = [(f"field{i}", _type) for i in range(len(t))]
|
46
|
-
return Struct(*t)
|
48
|
+
def to_struct(*t: int, _type: type = ctypes.c_ulong): return init_c_struct_t(tuple([(f"field{i}", _type) for i in range(len(t))]))(*t)
|
47
49
|
|
48
50
|
def wait_check(cbuf: Any):
|
49
51
|
msg(cbuf, "waitUntilCompleted")
|
50
52
|
error_check(msg(cbuf, "error", restype=objc_instance))
|
51
53
|
|
52
|
-
def
|
53
|
-
|
54
|
+
def cmdbuf_label(cbuf: objc_id) -> str|None: return from_ns_str(label) if (label:=msg(cbuf, "label", restype=objc_id)).value is not None else None
|
55
|
+
def cmdbuf_st_time(cbuf: objc_id) -> float: return cast(float, msg(cbuf, "GPUStartTime", restype=ctypes.c_double))
|
56
|
+
def cmdbuf_en_time(cbuf: objc_id) -> float: return cast(float, msg(cbuf, "GPUEndTime", restype=ctypes.c_double))
|
54
57
|
|
55
58
|
def error_check(error: objc_instance, error_constructor: type[Exception] = RuntimeError):
|
56
59
|
if error.value is None: return None
|
57
|
-
raise error_constructor(
|
60
|
+
raise error_constructor(from_ns_str(msg(error, "localizedDescription", restype=objc_instance)))
|
61
|
+
|
62
|
+
class MetalDevice(Compiled):
|
63
|
+
def __init__(self, device:str):
|
64
|
+
self.sysdevice = libmetal.MTLCreateSystemDefaultDevice()
|
65
|
+
self.mtl_queue = msg(self.sysdevice, "newCommandQueueWithMaxCommandBufferCount:", 1024, restype=objc_instance)
|
66
|
+
if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
|
67
|
+
self.mtl_buffers_in_flight: list[Any] = []
|
68
|
+
self.timeline_signal = msg(self.sysdevice, "newSharedEvent", restype=objc_instance)
|
69
|
+
self.timeline_value = 0
|
70
|
+
|
71
|
+
Compiled.profile_events += [ProfileDeviceEvent(device)]
|
72
|
+
|
73
|
+
from tinygrad.runtime.graph.metal import MetalGraph
|
74
|
+
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_DIRECT", 1) else Compiler(),
|
75
|
+
functools.partial(MetalProgram, self), MetalGraph)
|
76
|
+
|
77
|
+
def synchronize(self):
|
78
|
+
for cbuf in self.mtl_buffers_in_flight:
|
79
|
+
wait_check(cbuf)
|
80
|
+
st, en = decimal.Decimal(cmdbuf_st_time(cbuf)) * 1000000, decimal.Decimal(cmdbuf_en_time(cbuf)) * 1000000
|
81
|
+
if PROFILE and (lb:=cmdbuf_label(cbuf)) is not None:
|
82
|
+
Compiled.profile_events += [ProfileRangeEvent(self.device, lb, st, en, is_copy=lb.startswith("COPY"))]
|
83
|
+
self.mtl_buffers_in_flight.clear()
|
58
84
|
|
59
85
|
def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
|
60
86
|
options = msg(libobjc.objc_getClass(b"MTLCompileOptions"), "new", restype=objc_instance)
|
61
87
|
msg(options, "setFastMathEnabled:", getenv("METAL_FAST_MATH"))
|
62
|
-
library = msg(device.
|
88
|
+
library = msg(device.sysdevice, "newLibraryWithSource:options:error:", to_ns_str(src), options,
|
63
89
|
ctypes.byref(compileError:=objc_instance()), restype=objc_instance)
|
64
90
|
error_check(compileError, CompileError)
|
65
91
|
return library
|
66
92
|
|
67
93
|
class MetalCompiler(Compiler):
|
68
|
-
def __init__(self
|
69
|
-
self.
|
70
|
-
super().__init__("
|
94
|
+
def __init__(self):
|
95
|
+
self.cgs = ctypes.c_void_p(compiler.MTLCodeGenServiceCreate(b"tinygrad"))
|
96
|
+
super().__init__("compile_metal_direct")
|
97
|
+
def __reduce__(self): return (MetalCompiler,()) # force pickle to create new instance for each multiprocessing fork
|
71
98
|
def compile(self, src:str) -> bytes:
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
99
|
+
ret: Union[Exception, bytes] = CompileError("MTLCodeGenServiceBuildRequest returned without calling the callback")
|
100
|
+
@ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int32, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_char_p)
|
101
|
+
def callback(blockptr, error, dataPtr, dataLen, errorMessage):
|
102
|
+
nonlocal ret
|
103
|
+
if error == 0:
|
104
|
+
reply = bytes(to_mv(dataPtr, dataLen))
|
105
|
+
# offset from beginning to data = header size + warning size
|
106
|
+
ret = reply[sum(struct.unpack('<LL', reply[8:16])):]
|
107
|
+
else:
|
108
|
+
ret = CompileError(errorMessage.decode())
|
109
|
+
# llvm will create modules.timestamp in cache path and cache compilation of metal stdlib (250ms => 8ms compilation time)
|
110
|
+
# note that llvm won't necessarily create anything else here as apple has prebuilt versions of many standard libraries
|
111
|
+
params = f'-fno-fast-math -std=metal3.1 --driver-mode=metal -x metal -fmodules-cache-path="{cache_dir}" -fno-caret-diagnostics'
|
112
|
+
# source blob has to be padded to multiple of 4 but at least one 'b\x00' should be added, params blob just has to be null terminated
|
113
|
+
src_padded, params_padded = src.encode() + b'\x00'*(round_up(len(src) + 1, 4) - len(src)), params.encode() + b'\x00'
|
114
|
+
request = struct.pack('<QQ', len(src_padded), len(params_padded)) + src_padded + params_padded
|
115
|
+
# The callback is actually not a callback but a block which is apple's non-standard extension to add closures to C.
|
116
|
+
# See https://clang.llvm.org/docs/Block-ABI-Apple.html#high-level for struct layout.
|
117
|
+
# Fields other than invoke are unused in this case so we can just use ctypes.byref with negative offset to invoke field, add blockptr as a first
|
118
|
+
# argument and pretend it's a normal callback
|
119
|
+
compiler.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10))
|
120
|
+
if isinstance(ret, Exception): raise ret
|
121
|
+
assert ret[:4] == b"MTLB" and ret[-4:] == b"ENDT", f"Invalid Metal library. {ret!r}"
|
122
|
+
return ret
|
82
123
|
def disassemble(self, lib:bytes):
|
83
124
|
with tempfile.NamedTemporaryFile(delete=True) as shader:
|
84
125
|
shader.write(lib)
|
@@ -87,59 +128,60 @@ class MetalCompiler(Compiler):
|
|
87
128
|
if ret: print("Disassembler Error: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu")
|
88
129
|
|
89
130
|
class MetalProgram:
|
90
|
-
def __init__(self,
|
91
|
-
self.
|
131
|
+
def __init__(self, dev:MetalDevice, name:str, lib:bytes):
|
132
|
+
self.dev, self.name, self.lib = dev, name, lib
|
92
133
|
if lib[:4] == b"MTLB":
|
93
134
|
# binary metal library
|
94
135
|
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
|
95
|
-
|
96
|
-
|
97
|
-
error_check(error_library_creation)
|
136
|
+
self.library = msg(self.dev.sysdevice, "newLibraryWithData:error:", data, ctypes.byref(error_lib:=objc_instance()), restype=objc_instance)
|
137
|
+
error_check(error_lib)
|
98
138
|
else:
|
99
139
|
# metal source. rely on OS caching
|
100
|
-
try: self.library = metal_src_to_library(self.
|
140
|
+
try: self.library = metal_src_to_library(self.dev, lib.decode())
|
101
141
|
except CompileError as e: raise RuntimeError from e
|
102
142
|
self.fxn = msg(self.library, "newFunctionWithName:", to_ns_str(name), restype=objc_instance)
|
103
143
|
descriptor = msg(libobjc.objc_getClass(b"MTLComputePipelineDescriptor"), "new", restype=objc_instance)
|
104
144
|
msg(descriptor, "setComputeFunction:", self.fxn)
|
105
145
|
msg(descriptor, "setSupportIndirectCommandBuffers:", True)
|
106
|
-
self.pipeline_state = msg(self.
|
146
|
+
self.pipeline_state = msg(self.dev.sysdevice, "newComputePipelineStateWithDescriptor:options:reflection:error:",
|
107
147
|
descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation:=objc_instance()), restype=objc_instance)
|
108
148
|
error_check(error_pipeline_creation)
|
109
149
|
|
110
|
-
def __call__(self, *bufs, global_size:
|
150
|
+
def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
|
111
151
|
max_total_threads = msg(self.pipeline_state, "maxTotalThreadsPerThreadgroup", restype=ctypes.c_ulong)
|
112
152
|
if prod(local_size) > cast(int, max_total_threads):
|
113
153
|
exec_width = msg(self.pipeline_state, "threadExecutionWidth", restype=ctypes.c_ulong)
|
114
154
|
memory_length = msg(self.pipeline_state, "staticThreadgroupMemoryLength", restype=ctypes.c_ulong)
|
115
155
|
raise RuntimeError(f"local size {local_size} bigger than {max_total_threads} with exec width {exec_width} memory length {memory_length}")
|
116
|
-
command_buffer = msg(self.
|
156
|
+
command_buffer = msg(self.dev.mtl_queue, "commandBuffer", restype=objc_instance)
|
117
157
|
encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
|
118
158
|
msg(encoder, "setComputePipelineState:", self.pipeline_state)
|
119
159
|
for i,a in enumerate(bufs): msg(encoder, "setBuffer:offset:atIndex:", a.buf, a.offset, i)
|
120
|
-
for i,a in enumerate(vals,start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
|
160
|
+
for i,a in enumerate(vals, start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
|
121
161
|
msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
|
122
162
|
msg(encoder, "endEncoding")
|
163
|
+
msg(command_buffer, "setLabel:", to_ns_str(self.name))
|
123
164
|
msg(command_buffer, "commit")
|
165
|
+
self.dev.mtl_buffers_in_flight.append(command_buffer)
|
124
166
|
if wait:
|
125
167
|
wait_check(command_buffer)
|
126
|
-
return
|
127
|
-
self.device.mtl_buffers_in_flight.append(command_buffer)
|
168
|
+
return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
|
128
169
|
|
129
170
|
class MetalBuffer:
|
130
171
|
def __init__(self, buf:Any, size:int, offset=0): self.buf, self.size, self.offset = buf, size, offset
|
131
172
|
|
132
173
|
class MetalAllocator(LRUAllocator):
|
133
|
-
def __init__(self,
|
134
|
-
self.
|
174
|
+
def __init__(self, dev:MetalDevice):
|
175
|
+
self.dev:MetalDevice = dev
|
135
176
|
super().__init__()
|
136
177
|
def _alloc(self, size:int, options) -> MetalBuffer:
|
137
178
|
# Buffer is explicitly released in _free() rather than garbage collected via reference count
|
138
|
-
ret = msg(self.
|
179
|
+
ret = msg(self.dev.sysdevice, "newBufferWithLength:options:", ctypes.c_ulong(size), MTLResourceOptions.MTLResourceStorageModeShared,
|
180
|
+
restype=objc_id)
|
139
181
|
if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
|
140
182
|
return MetalBuffer(ret, size)
|
141
183
|
def _free(self, opaque:MetalBuffer, options): msg(opaque.buf, "release")
|
142
|
-
def
|
184
|
+
def _transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
|
143
185
|
dest_dev.synchronize()
|
144
186
|
src_command_buffer = msg(src_dev.mtl_queue, "commandBuffer", restype=objc_instance)
|
145
187
|
encoder = msg(src_command_buffer, "blitCommandEncoder", restype=objc_instance)
|
@@ -153,36 +195,14 @@ class MetalAllocator(LRUAllocator):
|
|
153
195
|
msg(dest_command_buffer, "commit")
|
154
196
|
dest_dev.mtl_buffers_in_flight.append(dest_command_buffer)
|
155
197
|
src_dev.timeline_value += 1
|
198
|
+
msg(src_command_buffer, "setLabel:", to_ns_str(f"COPY {src_dev.device} -> {dest_dev.device}"))
|
156
199
|
msg(src_command_buffer, "commit")
|
157
200
|
src_dev.mtl_buffers_in_flight.append(src_command_buffer)
|
158
|
-
def
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
return
|
163
|
-
def
|
164
|
-
|
165
|
-
|
166
|
-
array = (ctypes.c_char * (src.offset + src.size)).from_address(ptr.value)
|
167
|
-
return memoryview(array).cast("B")[src.offset:]
|
168
|
-
def copyin(self, dest:MetalBuffer, src:memoryview): self.as_buffer(dest)[:] = src
|
169
|
-
def copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self.as_buffer(src)
|
170
|
-
def offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
|
171
|
-
|
172
|
-
class MetalDevice(Compiled):
|
173
|
-
def __init__(self, device:str):
|
174
|
-
self.device = libmetal.MTLCreateSystemDefaultDevice()
|
175
|
-
self.mtl_queue = msg(self.device, "newCommandQueueWithMaxCommandBufferCount:", 1024, restype=objc_instance)
|
176
|
-
if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
|
177
|
-
self.mtl_buffers_in_flight: List[Any] = []
|
178
|
-
self.mv_in_metal: List[memoryview] = []
|
179
|
-
self.timeline_signal = msg(self.device, "newSharedEvent", restype=objc_instance)
|
180
|
-
self.timeline_value = 0
|
181
|
-
|
182
|
-
from tinygrad.runtime.graph.metal import MetalGraph
|
183
|
-
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_XCODE") else Compiler(),
|
184
|
-
functools.partial(MetalProgram, self), MetalGraph)
|
185
|
-
def synchronize(self):
|
186
|
-
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
|
187
|
-
self.mv_in_metal.clear()
|
188
|
-
self.mtl_buffers_in_flight.clear()
|
201
|
+
def _cp_mv(self, dst, src, prof_desc):
|
202
|
+
with cpu_profile(prof_desc, self.dev.device, is_copy=True): dst[:] = src
|
203
|
+
def _as_buffer(self, src:MetalBuffer) -> memoryview:
|
204
|
+
self.dev.synchronize()
|
205
|
+
return to_mv(cast(int, msg(src.buf, "contents", restype=objc_id).value), src.size + src.offset)[src.offset:]
|
206
|
+
def _copyin(self, dest:MetalBuffer, src:memoryview): self._cp_mv(self._as_buffer(dest), src, "CPU -> METAL")
|
207
|
+
def _copyout(self, dest:memoryview, src:MetalBuffer): self._cp_mv(dest, self._as_buffer(src), "METAL -> CPU")
|
208
|
+
def _offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
|
tinygrad/runtime/ops_npy.py
CHANGED
@@ -2,8 +2,8 @@ import numpy as np
|
|
2
2
|
from tinygrad.helpers import flat_mv
|
3
3
|
from tinygrad.device import Compiled, Allocator
|
4
4
|
|
5
|
-
class NpyAllocator(Allocator):
|
6
|
-
def
|
5
|
+
class NpyAllocator(Allocator):
|
6
|
+
def _copyout(self, dest:memoryview, src:np.ndarray): dest[:] = flat_mv(np.require(src, requirements='C').data)
|
7
7
|
|
8
8
|
class NpyDevice(Compiled):
|
9
9
|
def __init__(self, device:str): super().__init__(device, NpyAllocator(), None, None, None)
|