tinygrad 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. tinygrad/codegen/kernel.py +114 -172
  2. tinygrad/codegen/linearize.py +211 -81
  3. tinygrad/codegen/lowerer.py +30 -35
  4. tinygrad/codegen/{uopgraph.py → rewriter.py} +69 -59
  5. tinygrad/codegen/transcendental.py +12 -13
  6. tinygrad/device.py +170 -47
  7. tinygrad/dtype.py +28 -26
  8. tinygrad/engine/jit.py +80 -63
  9. tinygrad/engine/memory.py +4 -5
  10. tinygrad/engine/multi.py +162 -0
  11. tinygrad/engine/realize.py +58 -107
  12. tinygrad/engine/schedule.py +381 -314
  13. tinygrad/engine/search.py +40 -44
  14. tinygrad/gradient.py +70 -0
  15. tinygrad/helpers.py +77 -58
  16. tinygrad/nn/__init__.py +30 -32
  17. tinygrad/nn/datasets.py +1 -2
  18. tinygrad/nn/optim.py +22 -26
  19. tinygrad/nn/state.py +89 -64
  20. tinygrad/ops.py +562 -446
  21. tinygrad/renderer/__init__.py +79 -36
  22. tinygrad/renderer/cstyle.py +70 -84
  23. tinygrad/renderer/llvmir.py +32 -20
  24. tinygrad/renderer/ptx.py +79 -99
  25. tinygrad/renderer/wgsl.py +87 -0
  26. tinygrad/runtime/autogen/amd_gpu.py +39507 -12
  27. tinygrad/runtime/autogen/comgr.py +2 -0
  28. tinygrad/runtime/autogen/kfd.py +4 -3
  29. tinygrad/runtime/autogen/kgsl.py +1 -1
  30. tinygrad/runtime/autogen/libpciaccess.py +2023 -0
  31. tinygrad/runtime/autogen/llvm.py +11379 -0
  32. tinygrad/runtime/autogen/vfio.py +891 -0
  33. tinygrad/runtime/graph/cuda.py +8 -9
  34. tinygrad/runtime/graph/hcq.py +84 -79
  35. tinygrad/runtime/graph/metal.py +19 -21
  36. tinygrad/runtime/ops_amd.py +488 -327
  37. tinygrad/runtime/ops_clang.py +15 -28
  38. tinygrad/runtime/ops_cloud.py +34 -34
  39. tinygrad/runtime/ops_cuda.py +30 -27
  40. tinygrad/runtime/ops_disk.py +62 -63
  41. tinygrad/runtime/ops_dsp.py +129 -38
  42. tinygrad/runtime/ops_gpu.py +30 -30
  43. tinygrad/runtime/ops_hip.py +29 -31
  44. tinygrad/runtime/ops_llvm.py +45 -40
  45. tinygrad/runtime/ops_metal.py +93 -73
  46. tinygrad/runtime/ops_npy.py +2 -2
  47. tinygrad/runtime/ops_nv.py +232 -270
  48. tinygrad/runtime/ops_python.py +51 -46
  49. tinygrad/runtime/ops_qcom.py +129 -157
  50. tinygrad/runtime/ops_webgpu.py +63 -0
  51. tinygrad/runtime/support/allocator.py +94 -0
  52. tinygrad/runtime/support/am/__init__.py +0 -0
  53. tinygrad/runtime/support/am/amdev.py +384 -0
  54. tinygrad/runtime/support/am/ip.py +463 -0
  55. tinygrad/runtime/support/compiler_cuda.py +4 -2
  56. tinygrad/runtime/support/elf.py +26 -4
  57. tinygrad/runtime/support/hcq.py +254 -324
  58. tinygrad/runtime/support/llvm.py +32 -0
  59. tinygrad/shape/shapetracker.py +84 -53
  60. tinygrad/shape/view.py +103 -138
  61. tinygrad/spec.py +154 -0
  62. tinygrad/tensor.py +744 -496
  63. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/METADATA +32 -21
  64. tinygrad-0.10.1.dist-info/RECORD +86 -0
  65. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/WHEEL +1 -1
  66. tinygrad/engine/lazy.py +0 -228
  67. tinygrad/function.py +0 -212
  68. tinygrad/multi.py +0 -177
  69. tinygrad/runtime/graph/clang.py +0 -39
  70. tinygrad-0.10.0.dist-info/RECORD +0 -77
  71. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/LICENSE +0 -0
  72. {tinygrad-0.10.0.dist-info → tinygrad-0.10.1.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,7 @@
1
- from __future__ import annotations
2
- import os, subprocess, pathlib, ctypes, tempfile, functools
3
- from typing import List, Any, Tuple, Optional, cast
4
- from tinygrad.helpers import prod, getenv, T
5
- from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
1
+ import os, pathlib, struct, ctypes, tempfile, functools, decimal
2
+ from typing import Any, Union, cast
3
+ from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, T, init_c_struct_t, PROFILE
4
+ from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, cpu_profile, ProfileDeviceEvent, ProfileRangeEvent
6
5
  from tinygrad.renderer.cstyle import MetalRenderer
7
6
 
8
7
  class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response to plain int, and dict.fromkeys() can use it to dedup
@@ -22,14 +21,19 @@ class MTLResourceOptions:
22
21
  class MTLPipelineOption:
23
22
  MTLPipelineOptionNone = 0
24
23
 
24
+ # 13 is requestType that metal uses to compile source code into MTLB, there aren't any docs or symbols.
25
+ REQUEST_TYPE_COMPILE = 13
26
+
25
27
  libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
26
28
  libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
29
+ compiler = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
27
30
  # Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
28
31
  ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
29
32
  libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
30
33
  libobjc.objc_getClass.restype = objc_id
31
34
  libobjc.sel_registerName.restype = objc_id
32
35
  libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
36
+ compiler.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
33
37
  libdispatch.dispatch_data_create.restype = objc_instance
34
38
 
35
39
  # Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
@@ -39,46 +43,83 @@ def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id)
39
43
  return sender(ptr, sel(selector), *args)
40
44
 
41
45
  def to_ns_str(s: str): return msg(libobjc.objc_getClass(b"NSString"), "stringWithUTF8String:", s.encode(), restype=objc_instance)
46
+ def from_ns_str(s): return bytes(msg(s, "UTF8String", restype=ctypes.c_char_p)).decode()
42
47
 
43
- def to_struct(*t: int, _type: type = ctypes.c_ulong):
44
- class Struct(ctypes.Structure): pass
45
- Struct._fields_ = [(f"field{i}", _type) for i in range(len(t))]
46
- return Struct(*t)
48
+ def to_struct(*t: int, _type: type = ctypes.c_ulong): return init_c_struct_t(tuple([(f"field{i}", _type) for i in range(len(t))]))(*t)
47
49
 
48
50
  def wait_check(cbuf: Any):
49
51
  msg(cbuf, "waitUntilCompleted")
50
52
  error_check(msg(cbuf, "error", restype=objc_instance))
51
53
 
52
- def elapsed_time(cbuf: objc_id):
53
- return cast(float, msg(cbuf, "GPUEndTime", restype=ctypes.c_double)) - cast(float, msg(cbuf, "GPUStartTime", restype=ctypes.c_double))
54
+ def cmdbuf_label(cbuf: objc_id) -> str|None: return from_ns_str(label) if (label:=msg(cbuf, "label", restype=objc_id)).value is not None else None
55
+ def cmdbuf_st_time(cbuf: objc_id) -> float: return cast(float, msg(cbuf, "GPUStartTime", restype=ctypes.c_double))
56
+ def cmdbuf_en_time(cbuf: objc_id) -> float: return cast(float, msg(cbuf, "GPUEndTime", restype=ctypes.c_double))
54
57
 
55
58
  def error_check(error: objc_instance, error_constructor: type[Exception] = RuntimeError):
56
59
  if error.value is None: return None
57
- raise error_constructor(bytes(msg(msg(error, "localizedDescription", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode())
60
+ raise error_constructor(from_ns_str(msg(error, "localizedDescription", restype=objc_instance)))
61
+
62
+ class MetalDevice(Compiled):
63
+ def __init__(self, device:str):
64
+ self.sysdevice = libmetal.MTLCreateSystemDefaultDevice()
65
+ self.mtl_queue = msg(self.sysdevice, "newCommandQueueWithMaxCommandBufferCount:", 1024, restype=objc_instance)
66
+ if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
67
+ self.mtl_buffers_in_flight: list[Any] = []
68
+ self.timeline_signal = msg(self.sysdevice, "newSharedEvent", restype=objc_instance)
69
+ self.timeline_value = 0
70
+
71
+ Compiled.profile_events += [ProfileDeviceEvent(device)]
72
+
73
+ from tinygrad.runtime.graph.metal import MetalGraph
74
+ super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_DIRECT", 1) else Compiler(),
75
+ functools.partial(MetalProgram, self), MetalGraph)
76
+
77
+ def synchronize(self):
78
+ for cbuf in self.mtl_buffers_in_flight:
79
+ wait_check(cbuf)
80
+ st, en = decimal.Decimal(cmdbuf_st_time(cbuf)) * 1000000, decimal.Decimal(cmdbuf_en_time(cbuf)) * 1000000
81
+ if PROFILE and (lb:=cmdbuf_label(cbuf)) is not None:
82
+ Compiled.profile_events += [ProfileRangeEvent(self.device, lb, st, en, is_copy=lb.startswith("COPY"))]
83
+ self.mtl_buffers_in_flight.clear()
58
84
 
59
85
  def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
60
86
  options = msg(libobjc.objc_getClass(b"MTLCompileOptions"), "new", restype=objc_instance)
61
87
  msg(options, "setFastMathEnabled:", getenv("METAL_FAST_MATH"))
62
- library = msg(device.device, "newLibraryWithSource:options:error:", to_ns_str(src), options,
88
+ library = msg(device.sysdevice, "newLibraryWithSource:options:error:", to_ns_str(src), options,
63
89
  ctypes.byref(compileError:=objc_instance()), restype=objc_instance)
64
90
  error_check(compileError, CompileError)
65
91
  return library
66
92
 
67
93
  class MetalCompiler(Compiler):
68
- def __init__(self, device:Optional[MetalDevice]=None):
69
- self.device = device
70
- super().__init__("compile_metal_xcode" if self.device is None else "compile_metal")
94
+ def __init__(self):
95
+ self.cgs = ctypes.c_void_p(compiler.MTLCodeGenServiceCreate(b"tinygrad"))
96
+ super().__init__("compile_metal_direct")
97
+ def __reduce__(self): return (MetalCompiler,()) # force pickle to create new instance for each multiprocessing fork
71
98
  def compile(self, src:str) -> bytes:
72
- if self.device is None:
73
- # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
74
- air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8'))
75
- lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
76
- else:
77
- library = metal_src_to_library(self.device, src)
78
- library_contents = msg(library, "libraryDataContents", restype=objc_instance)
79
- lib = ctypes.string_at(msg(library_contents, "bytes"), cast(int, msg(library_contents, "length", restype=ctypes.c_ulong)))
80
- assert lib[:4] == b"MTLB", "Invalid Metal library. Using conda? Corrupt XCode?"
81
- return lib
99
+ ret: Union[Exception, bytes] = CompileError("MTLCodeGenServiceBuildRequest returned without calling the callback")
100
+ @ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_int32, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_char_p)
101
+ def callback(blockptr, error, dataPtr, dataLen, errorMessage):
102
+ nonlocal ret
103
+ if error == 0:
104
+ reply = bytes(to_mv(dataPtr, dataLen))
105
+ # offset from beginning to data = header size + warning size
106
+ ret = reply[sum(struct.unpack('<LL', reply[8:16])):]
107
+ else:
108
+ ret = CompileError(errorMessage.decode())
109
+ # llvm will create modules.timestamp in cache path and cache compilation of metal stdlib (250ms => 8ms compilation time)
110
+ # note that llvm won't necessarily create anything else here as apple has prebuilt versions of many standard libraries
111
+ params = f'-fno-fast-math -std=metal3.1 --driver-mode=metal -x metal -fmodules-cache-path="{cache_dir}" -fno-caret-diagnostics'
112
+ # source blob has to be padded to multiple of 4 but at least one 'b\x00' should be added, params blob just has to be null terminated
113
+ src_padded, params_padded = src.encode() + b'\x00'*(round_up(len(src) + 1, 4) - len(src)), params.encode() + b'\x00'
114
+ request = struct.pack('<QQ', len(src_padded), len(params_padded)) + src_padded + params_padded
115
+ # The callback is actually not a callback but a block which is apple's non-standard extension to add closures to C.
116
+ # See https://clang.llvm.org/docs/Block-ABI-Apple.html#high-level for struct layout.
117
+ # Fields other than invoke are unused in this case so we can just use ctypes.byref with negative offset to invoke field, add blockptr as a first
118
+ # argument and pretend it's a normal callback
119
+ compiler.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10))
120
+ if isinstance(ret, Exception): raise ret
121
+ assert ret[:4] == b"MTLB" and ret[-4:] == b"ENDT", f"Invalid Metal library. {ret!r}"
122
+ return ret
82
123
  def disassemble(self, lib:bytes):
83
124
  with tempfile.NamedTemporaryFile(delete=True) as shader:
84
125
  shader.write(lib)
@@ -87,59 +128,60 @@ class MetalCompiler(Compiler):
87
128
  if ret: print("Disassembler Error: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu")
88
129
 
89
130
  class MetalProgram:
90
- def __init__(self, device:MetalDevice, name:str, lib:bytes):
91
- self.device, self.name, self.lib = device, name, lib
131
+ def __init__(self, dev:MetalDevice, name:str, lib:bytes):
132
+ self.dev, self.name, self.lib = dev, name, lib
92
133
  if lib[:4] == b"MTLB":
93
134
  # binary metal library
94
135
  data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
95
- error_library_creation = objc_instance()
96
- self.library = msg(self.device.device, "newLibraryWithData:error:", data, ctypes.byref(error_library_creation), restype=objc_instance)
97
- error_check(error_library_creation)
136
+ self.library = msg(self.dev.sysdevice, "newLibraryWithData:error:", data, ctypes.byref(error_lib:=objc_instance()), restype=objc_instance)
137
+ error_check(error_lib)
98
138
  else:
99
139
  # metal source. rely on OS caching
100
- try: self.library = metal_src_to_library(self.device, lib.decode())
140
+ try: self.library = metal_src_to_library(self.dev, lib.decode())
101
141
  except CompileError as e: raise RuntimeError from e
102
142
  self.fxn = msg(self.library, "newFunctionWithName:", to_ns_str(name), restype=objc_instance)
103
143
  descriptor = msg(libobjc.objc_getClass(b"MTLComputePipelineDescriptor"), "new", restype=objc_instance)
104
144
  msg(descriptor, "setComputeFunction:", self.fxn)
105
145
  msg(descriptor, "setSupportIndirectCommandBuffers:", True)
106
- self.pipeline_state = msg(self.device.device, "newComputePipelineStateWithDescriptor:options:reflection:error:",
146
+ self.pipeline_state = msg(self.dev.sysdevice, "newComputePipelineStateWithDescriptor:options:reflection:error:",
107
147
  descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation:=objc_instance()), restype=objc_instance)
108
148
  error_check(error_pipeline_creation)
109
149
 
110
- def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
150
+ def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
111
151
  max_total_threads = msg(self.pipeline_state, "maxTotalThreadsPerThreadgroup", restype=ctypes.c_ulong)
112
152
  if prod(local_size) > cast(int, max_total_threads):
113
153
  exec_width = msg(self.pipeline_state, "threadExecutionWidth", restype=ctypes.c_ulong)
114
154
  memory_length = msg(self.pipeline_state, "staticThreadgroupMemoryLength", restype=ctypes.c_ulong)
115
155
  raise RuntimeError(f"local size {local_size} bigger than {max_total_threads} with exec width {exec_width} memory length {memory_length}")
116
- command_buffer = msg(self.device.mtl_queue, "commandBuffer", restype=objc_instance)
156
+ command_buffer = msg(self.dev.mtl_queue, "commandBuffer", restype=objc_instance)
117
157
  encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
118
158
  msg(encoder, "setComputePipelineState:", self.pipeline_state)
119
159
  for i,a in enumerate(bufs): msg(encoder, "setBuffer:offset:atIndex:", a.buf, a.offset, i)
120
- for i,a in enumerate(vals,start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
160
+ for i,a in enumerate(vals, start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
121
161
  msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
122
162
  msg(encoder, "endEncoding")
163
+ msg(command_buffer, "setLabel:", to_ns_str(self.name))
123
164
  msg(command_buffer, "commit")
165
+ self.dev.mtl_buffers_in_flight.append(command_buffer)
124
166
  if wait:
125
167
  wait_check(command_buffer)
126
- return elapsed_time(command_buffer)
127
- self.device.mtl_buffers_in_flight.append(command_buffer)
168
+ return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
128
169
 
129
170
  class MetalBuffer:
130
171
  def __init__(self, buf:Any, size:int, offset=0): self.buf, self.size, self.offset = buf, size, offset
131
172
 
132
173
  class MetalAllocator(LRUAllocator):
133
- def __init__(self, device:MetalDevice):
134
- self.device:MetalDevice = device
174
+ def __init__(self, dev:MetalDevice):
175
+ self.dev:MetalDevice = dev
135
176
  super().__init__()
136
177
  def _alloc(self, size:int, options) -> MetalBuffer:
137
178
  # Buffer is explicitly released in _free() rather than garbage collected via reference count
138
- ret = msg(self.device.device, "newBufferWithLength:options:", size, MTLResourceOptions.MTLResourceStorageModeShared, restype=objc_id)
179
+ ret = msg(self.dev.sysdevice, "newBufferWithLength:options:", ctypes.c_ulong(size), MTLResourceOptions.MTLResourceStorageModeShared,
180
+ restype=objc_id)
139
181
  if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
140
182
  return MetalBuffer(ret, size)
141
183
  def _free(self, opaque:MetalBuffer, options): msg(opaque.buf, "release")
142
- def transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
184
+ def _transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
143
185
  dest_dev.synchronize()
144
186
  src_command_buffer = msg(src_dev.mtl_queue, "commandBuffer", restype=objc_instance)
145
187
  encoder = msg(src_command_buffer, "blitCommandEncoder", restype=objc_instance)
@@ -153,36 +195,14 @@ class MetalAllocator(LRUAllocator):
153
195
  msg(dest_command_buffer, "commit")
154
196
  dest_dev.mtl_buffers_in_flight.append(dest_command_buffer)
155
197
  src_dev.timeline_value += 1
198
+ msg(src_command_buffer, "setLabel:", to_ns_str(f"COPY {src_dev.device} -> {dest_dev.device}"))
156
199
  msg(src_command_buffer, "commit")
157
200
  src_dev.mtl_buffers_in_flight.append(src_command_buffer)
158
- def from_buffer(self, src:memoryview) -> Optional[Any]:
159
- ptr = (ctypes.c_char * src.nbytes).from_buffer(src)
160
- ret = msg(self.device.device, "newBufferWithBytesNoCopy:length:options:deallocator:", ptr, src.nbytes, 0, None, restype=objc_instance)
161
- if ret: self.device.mv_in_metal.append(src)
162
- return MetalBuffer(ret, src.nbytes)
163
- def as_buffer(self, src:MetalBuffer) -> memoryview:
164
- self.device.synchronize()
165
- ptr = msg(src.buf, "contents", restype=objc_id) # Shared memory, do not release here
166
- array = (ctypes.c_char * (src.offset + src.size)).from_address(ptr.value)
167
- return memoryview(array).cast("B")[src.offset:]
168
- def copyin(self, dest:MetalBuffer, src:memoryview): self.as_buffer(dest)[:] = src
169
- def copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self.as_buffer(src)
170
- def offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
171
-
172
- class MetalDevice(Compiled):
173
- def __init__(self, device:str):
174
- self.device = libmetal.MTLCreateSystemDefaultDevice()
175
- self.mtl_queue = msg(self.device, "newCommandQueueWithMaxCommandBufferCount:", 1024, restype=objc_instance)
176
- if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
177
- self.mtl_buffers_in_flight: List[Any] = []
178
- self.mv_in_metal: List[memoryview] = []
179
- self.timeline_signal = msg(self.device, "newSharedEvent", restype=objc_instance)
180
- self.timeline_value = 0
181
-
182
- from tinygrad.runtime.graph.metal import MetalGraph
183
- super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_XCODE") else Compiler(),
184
- functools.partial(MetalProgram, self), MetalGraph)
185
- def synchronize(self):
186
- for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
187
- self.mv_in_metal.clear()
188
- self.mtl_buffers_in_flight.clear()
201
+ def _cp_mv(self, dst, src, prof_desc):
202
+ with cpu_profile(prof_desc, self.dev.device, is_copy=True): dst[:] = src
203
+ def _as_buffer(self, src:MetalBuffer) -> memoryview:
204
+ self.dev.synchronize()
205
+ return to_mv(cast(int, msg(src.buf, "contents", restype=objc_id).value), src.size + src.offset)[src.offset:]
206
+ def _copyin(self, dest:MetalBuffer, src:memoryview): self._cp_mv(self._as_buffer(dest), src, "CPU -> METAL")
207
+ def _copyout(self, dest:memoryview, src:MetalBuffer): self._cp_mv(dest, self._as_buffer(src), "METAL -> CPU")
208
+ def _offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
@@ -2,8 +2,8 @@ import numpy as np
2
2
  from tinygrad.helpers import flat_mv
3
3
  from tinygrad.device import Compiled, Allocator
4
4
 
5
- class NpyAllocator(Allocator): # pylint: disable=abstract-method
6
- def copyout(self, dest:memoryview, src:np.ndarray): dest[:] = flat_mv(np.require(src, requirements='C').data)
5
+ class NpyAllocator(Allocator):
6
+ def _copyout(self, dest:memoryview, src:np.ndarray): dest[:] = flat_mv(np.require(src, requirements='C').data)
7
7
 
8
8
  class NpyDevice(Compiled):
9
9
  def __init__(self, device:str): super().__init__(device, NpyAllocator(), None, None, None)