tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +11 -6
- tinygrad/codegen/kernel.py +308 -175
- tinygrad/codegen/linearize.py +95 -0
- tinygrad/codegen/lowerer.py +143 -0
- tinygrad/codegen/transcendental.py +257 -0
- tinygrad/codegen/uopgraph.py +506 -0
- tinygrad/device.py +72 -171
- tinygrad/dtype.py +122 -47
- tinygrad/engine/jit.py +184 -87
- tinygrad/{lazy.py → engine/lazy.py} +74 -66
- tinygrad/engine/memory.py +51 -0
- tinygrad/engine/realize.py +86 -61
- tinygrad/engine/schedule.py +366 -317
- tinygrad/engine/search.py +58 -47
- tinygrad/function.py +59 -58
- tinygrad/helpers.py +120 -102
- tinygrad/multi.py +82 -78
- tinygrad/nn/__init__.py +116 -67
- tinygrad/nn/datasets.py +12 -5
- tinygrad/nn/optim.py +1 -1
- tinygrad/nn/state.py +91 -6
- tinygrad/ops.py +1126 -143
- tinygrad/renderer/__init__.py +47 -23
- tinygrad/renderer/cstyle.py +338 -265
- tinygrad/renderer/llvmir.py +125 -143
- tinygrad/renderer/ptx.py +225 -0
- tinygrad/runtime/autogen/adreno.py +17904 -0
- tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
- tinygrad/runtime/autogen/cuda.py +6 -162
- tinygrad/runtime/autogen/io_uring.py +97 -63
- tinygrad/runtime/autogen/kfd.py +60 -47
- tinygrad/runtime/autogen/kgsl.py +1386 -0
- tinygrad/runtime/autogen/libc.py +5462 -0
- tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
- tinygrad/runtime/autogen/nvrtc.py +579 -0
- tinygrad/runtime/autogen/opencl.py +11 -11
- tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
- tinygrad/runtime/graph/clang.py +3 -3
- tinygrad/runtime/graph/cuda.py +11 -15
- tinygrad/runtime/graph/hcq.py +120 -107
- tinygrad/runtime/graph/metal.py +71 -43
- tinygrad/runtime/ops_amd.py +244 -323
- tinygrad/runtime/ops_clang.py +12 -5
- tinygrad/runtime/ops_cloud.py +220 -0
- tinygrad/runtime/ops_cuda.py +42 -99
- tinygrad/runtime/ops_disk.py +25 -26
- tinygrad/runtime/ops_dsp.py +181 -0
- tinygrad/runtime/ops_gpu.py +29 -16
- tinygrad/runtime/ops_hip.py +68 -0
- tinygrad/runtime/ops_llvm.py +15 -10
- tinygrad/runtime/ops_metal.py +147 -64
- tinygrad/runtime/ops_nv.py +356 -397
- tinygrad/runtime/ops_python.py +78 -79
- tinygrad/runtime/ops_qcom.py +405 -0
- tinygrad/runtime/support/__init__.py +0 -0
- tinygrad/runtime/support/compiler_cuda.py +77 -0
- tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
- tinygrad/runtime/support/elf.py +38 -0
- tinygrad/runtime/support/hcq.py +539 -0
- tinygrad/shape/shapetracker.py +40 -50
- tinygrad/shape/view.py +102 -63
- tinygrad/tensor.py +1109 -365
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
- tinygrad-0.10.0.dist-info/RECORD +77 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/linearizer.py +0 -528
- tinygrad/codegen/uops.py +0 -451
- tinygrad/engine/graph.py +0 -100
- tinygrad/renderer/assembly.py +0 -269
- tinygrad/shape/symbolic.py +0 -327
- tinygrad-0.9.1.dist-info/RECORD +0 -63
- /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
- {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_metal.py
CHANGED
@@ -1,105 +1,188 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
import os, subprocess, pathlib, ctypes, tempfile, functools
|
3
|
-
import
|
4
|
-
from
|
5
|
-
from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
|
3
|
+
from typing import List, Any, Tuple, Optional, cast
|
4
|
+
from tinygrad.helpers import prod, getenv, T
|
6
5
|
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
|
7
6
|
from tinygrad.renderer.cstyle import MetalRenderer
|
8
7
|
|
8
|
+
class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response to plain int, and dict.fromkeys() can use it to dedup
|
9
|
+
def __hash__(self): return hash(self.value)
|
10
|
+
def __eq__(self, other): return self.value == other.value
|
11
|
+
|
12
|
+
class objc_instance(objc_id): # method with name "new", "alloc" should be freed after use
|
13
|
+
def __del__(self): msg(self, "release")
|
14
|
+
|
15
|
+
@functools.lru_cache(None)
|
16
|
+
def sel(name: str): return libobjc.sel_registerName(name.encode())
|
17
|
+
|
18
|
+
class MTLResourceOptions:
|
19
|
+
MTLResourceCPUCacheModeDefaultCache = 0
|
20
|
+
MTLResourceStorageModeShared = 0 << 4
|
21
|
+
|
22
|
+
class MTLPipelineOption:
|
23
|
+
MTLPipelineOptionNone = 0
|
24
|
+
|
25
|
+
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
|
26
|
+
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
|
27
|
+
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
|
28
|
+
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
|
29
|
+
libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
|
30
|
+
libobjc.objc_getClass.restype = objc_id
|
31
|
+
libobjc.sel_registerName.restype = objc_id
|
32
|
+
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
|
33
|
+
libdispatch.dispatch_data_create.restype = objc_instance
|
34
|
+
|
35
|
+
# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
|
36
|
+
def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id) -> T: # type: ignore [assignment]
|
37
|
+
sender = libobjc["objc_msgSend"] # Using attribute access returns a new reference so setting restype is safe
|
38
|
+
sender.restype = restype
|
39
|
+
return sender(ptr, sel(selector), *args)
|
40
|
+
|
41
|
+
def to_ns_str(s: str): return msg(libobjc.objc_getClass(b"NSString"), "stringWithUTF8String:", s.encode(), restype=objc_instance)
|
42
|
+
|
43
|
+
def to_struct(*t: int, _type: type = ctypes.c_ulong):
|
44
|
+
class Struct(ctypes.Structure): pass
|
45
|
+
Struct._fields_ = [(f"field{i}", _type) for i in range(len(t))]
|
46
|
+
return Struct(*t)
|
47
|
+
|
9
48
|
def wait_check(cbuf: Any):
|
10
|
-
cbuf
|
11
|
-
|
12
|
-
|
49
|
+
msg(cbuf, "waitUntilCompleted")
|
50
|
+
error_check(msg(cbuf, "error", restype=objc_instance))
|
51
|
+
|
52
|
+
def elapsed_time(cbuf: objc_id):
|
53
|
+
return cast(float, msg(cbuf, "GPUEndTime", restype=ctypes.c_double)) - cast(float, msg(cbuf, "GPUStartTime", restype=ctypes.c_double))
|
54
|
+
|
55
|
+
def error_check(error: objc_instance, error_constructor: type[Exception] = RuntimeError):
|
56
|
+
if error.value is None: return None
|
57
|
+
raise error_constructor(bytes(msg(msg(error, "localizedDescription", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode())
|
58
|
+
|
59
|
+
def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
|
60
|
+
options = msg(libobjc.objc_getClass(b"MTLCompileOptions"), "new", restype=objc_instance)
|
61
|
+
msg(options, "setFastMathEnabled:", getenv("METAL_FAST_MATH"))
|
62
|
+
library = msg(device.device, "newLibraryWithSource:options:error:", to_ns_str(src), options,
|
63
|
+
ctypes.byref(compileError:=objc_instance()), restype=objc_instance)
|
64
|
+
error_check(compileError, CompileError)
|
65
|
+
return library
|
13
66
|
|
14
67
|
class MetalCompiler(Compiler):
|
15
|
-
def __init__(self, device:Optional[MetalDevice]):
|
68
|
+
def __init__(self, device:Optional[MetalDevice]=None):
|
16
69
|
self.device = device
|
17
|
-
super().__init__("compile_metal")
|
70
|
+
super().__init__("compile_metal_xcode" if self.device is None else "compile_metal")
|
18
71
|
def compile(self, src:str) -> bytes:
|
19
72
|
if self.device is None:
|
20
73
|
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
|
21
74
|
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8'))
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
75
|
+
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
|
76
|
+
else:
|
77
|
+
library = metal_src_to_library(self.device, src)
|
78
|
+
library_contents = msg(library, "libraryDataContents", restype=objc_instance)
|
79
|
+
lib = ctypes.string_at(msg(library_contents, "bytes"), cast(int, msg(library_contents, "length", restype=ctypes.c_ulong)))
|
80
|
+
assert lib[:4] == b"MTLB", "Invalid Metal library. Using conda? Corrupt XCode?"
|
81
|
+
return lib
|
82
|
+
def disassemble(self, lib:bytes):
|
83
|
+
with tempfile.NamedTemporaryFile(delete=True) as shader:
|
84
|
+
shader.write(lib)
|
85
|
+
shader.flush()
|
86
|
+
ret = os.system(f"cd {pathlib.Path(__file__).parents[2]}/extra/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
|
87
|
+
if ret: print("Disassembler Error: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu")
|
28
88
|
|
29
89
|
class MetalProgram:
|
30
90
|
def __init__(self, device:MetalDevice, name:str, lib:bytes):
|
31
91
|
self.device, self.name, self.lib = device, name, lib
|
32
|
-
if
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
92
|
+
if lib[:4] == b"MTLB":
|
93
|
+
# binary metal library
|
94
|
+
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
|
95
|
+
error_library_creation = objc_instance()
|
96
|
+
self.library = msg(self.device.device, "newLibraryWithData:error:", data, ctypes.byref(error_library_creation), restype=objc_instance)
|
97
|
+
error_check(error_library_creation)
|
98
|
+
else:
|
99
|
+
# metal source. rely on OS caching
|
100
|
+
try: self.library = metal_src_to_library(self.device, lib.decode())
|
101
|
+
except CompileError as e: raise RuntimeError from e
|
102
|
+
self.fxn = msg(self.library, "newFunctionWithName:", to_ns_str(name), restype=objc_instance)
|
103
|
+
descriptor = msg(libobjc.objc_getClass(b"MTLComputePipelineDescriptor"), "new", restype=objc_instance)
|
104
|
+
msg(descriptor, "setComputeFunction:", self.fxn)
|
105
|
+
msg(descriptor, "setSupportIndirectCommandBuffers:", True)
|
106
|
+
self.pipeline_state = msg(self.device.device, "newComputePipelineStateWithDescriptor:options:reflection:error:",
|
107
|
+
descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation:=objc_instance()), restype=objc_instance)
|
108
|
+
error_check(error_pipeline_creation)
|
42
109
|
|
43
110
|
def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
encoder
|
51
|
-
encoder.
|
52
|
-
|
111
|
+
max_total_threads = msg(self.pipeline_state, "maxTotalThreadsPerThreadgroup", restype=ctypes.c_ulong)
|
112
|
+
if prod(local_size) > cast(int, max_total_threads):
|
113
|
+
exec_width = msg(self.pipeline_state, "threadExecutionWidth", restype=ctypes.c_ulong)
|
114
|
+
memory_length = msg(self.pipeline_state, "staticThreadgroupMemoryLength", restype=ctypes.c_ulong)
|
115
|
+
raise RuntimeError(f"local size {local_size} bigger than {max_total_threads} with exec width {exec_width} memory length {memory_length}")
|
116
|
+
command_buffer = msg(self.device.mtl_queue, "commandBuffer", restype=objc_instance)
|
117
|
+
encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
|
118
|
+
msg(encoder, "setComputePipelineState:", self.pipeline_state)
|
119
|
+
for i,a in enumerate(bufs): msg(encoder, "setBuffer:offset:atIndex:", a.buf, a.offset, i)
|
120
|
+
for i,a in enumerate(vals,start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
|
121
|
+
msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
|
122
|
+
msg(encoder, "endEncoding")
|
123
|
+
msg(command_buffer, "commit")
|
53
124
|
if wait:
|
54
125
|
wait_check(command_buffer)
|
55
|
-
return
|
126
|
+
return elapsed_time(command_buffer)
|
56
127
|
self.device.mtl_buffers_in_flight.append(command_buffer)
|
57
128
|
|
129
|
+
class MetalBuffer:
|
130
|
+
def __init__(self, buf:Any, size:int, offset=0): self.buf, self.size, self.offset = buf, size, offset
|
131
|
+
|
58
132
|
class MetalAllocator(LRUAllocator):
|
59
133
|
def __init__(self, device:MetalDevice):
|
60
134
|
self.device:MetalDevice = device
|
61
|
-
self.track_cross_device: Set[MetalDevice] = set()
|
62
135
|
super().__init__()
|
63
|
-
def
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
return
|
68
|
-
def
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
encoder
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
136
|
+
def _alloc(self, size:int, options) -> MetalBuffer:
|
137
|
+
# Buffer is explicitly released in _free() rather than garbage collected via reference count
|
138
|
+
ret = msg(self.device.device, "newBufferWithLength:options:", size, MTLResourceOptions.MTLResourceStorageModeShared, restype=objc_id)
|
139
|
+
if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
|
140
|
+
return MetalBuffer(ret, size)
|
141
|
+
def _free(self, opaque:MetalBuffer, options): msg(opaque.buf, "release")
|
142
|
+
def transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
|
143
|
+
dest_dev.synchronize()
|
144
|
+
src_command_buffer = msg(src_dev.mtl_queue, "commandBuffer", restype=objc_instance)
|
145
|
+
encoder = msg(src_command_buffer, "blitCommandEncoder", restype=objc_instance)
|
146
|
+
msg(encoder, "copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:", src.buf, ctypes.c_ulong(src.offset),
|
147
|
+
dest.buf, ctypes.c_ulong(dest.offset), ctypes.c_ulong(sz))
|
148
|
+
msg(encoder, "endEncoding")
|
149
|
+
if src_dev != dest_dev:
|
150
|
+
msg(src_command_buffer, "encodeSignalEvent:value:", src_dev.timeline_signal, src_dev.timeline_value)
|
151
|
+
dest_command_buffer = msg(dest_dev.mtl_queue, "commandBuffer", restype=objc_instance)
|
152
|
+
msg(dest_command_buffer, "encodeWaitForEvent:value:", src_dev.timeline_signal, src_dev.timeline_value)
|
153
|
+
msg(dest_command_buffer, "commit")
|
154
|
+
dest_dev.mtl_buffers_in_flight.append(dest_command_buffer)
|
155
|
+
src_dev.timeline_value += 1
|
156
|
+
msg(src_command_buffer, "commit")
|
157
|
+
src_dev.mtl_buffers_in_flight.append(src_command_buffer)
|
80
158
|
def from_buffer(self, src:memoryview) -> Optional[Any]:
|
81
|
-
|
159
|
+
ptr = (ctypes.c_char * src.nbytes).from_buffer(src)
|
160
|
+
ret = msg(self.device.device, "newBufferWithBytesNoCopy:length:options:deallocator:", ptr, src.nbytes, 0, None, restype=objc_instance)
|
82
161
|
if ret: self.device.mv_in_metal.append(src)
|
83
|
-
return ret
|
84
|
-
def
|
85
|
-
def as_buffer(self, src:Any) -> memoryview:
|
162
|
+
return MetalBuffer(ret, src.nbytes)
|
163
|
+
def as_buffer(self, src:MetalBuffer) -> memoryview:
|
86
164
|
self.device.synchronize()
|
87
|
-
|
88
|
-
|
89
|
-
|
165
|
+
ptr = msg(src.buf, "contents", restype=objc_id) # Shared memory, do not release here
|
166
|
+
array = (ctypes.c_char * (src.offset + src.size)).from_address(ptr.value)
|
167
|
+
return memoryview(array).cast("B")[src.offset:]
|
168
|
+
def copyin(self, dest:MetalBuffer, src:memoryview): self.as_buffer(dest)[:] = src
|
169
|
+
def copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self.as_buffer(src)
|
170
|
+
def offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
|
90
171
|
|
91
172
|
class MetalDevice(Compiled):
|
92
173
|
def __init__(self, device:str):
|
93
|
-
self.device =
|
94
|
-
self.mtl_queue = self.device
|
174
|
+
self.device = libmetal.MTLCreateSystemDefaultDevice()
|
175
|
+
self.mtl_queue = msg(self.device, "newCommandQueueWithMaxCommandBufferCount:", 1024, restype=objc_instance)
|
176
|
+
if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
|
95
177
|
self.mtl_buffers_in_flight: List[Any] = []
|
96
178
|
self.mv_in_metal: List[memoryview] = []
|
97
|
-
self.
|
179
|
+
self.timeline_signal = msg(self.device, "newSharedEvent", restype=objc_instance)
|
180
|
+
self.timeline_value = 0
|
181
|
+
|
98
182
|
from tinygrad.runtime.graph.metal import MetalGraph
|
99
|
-
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler(
|
183
|
+
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_XCODE") else Compiler(),
|
100
184
|
functools.partial(MetalProgram, self), MetalGraph)
|
101
185
|
def synchronize(self):
|
102
186
|
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
|
103
187
|
self.mv_in_metal.clear()
|
104
188
|
self.mtl_buffers_in_flight.clear()
|
105
|
-
self.track_cross_buffer.clear()
|