tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71):
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,278 @@
1
+ from __future__ import annotations
2
+ import ctypes, functools, subprocess, io, atexit, collections, json
3
+ from typing import Tuple, TypeVar, List, Dict, Any
4
+ import tinygrad.runtime.autogen.hsa as hsa
5
+ from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv
6
+ from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator
7
+ from tinygrad.renderer.cstyle import HIPRenderer
8
+ from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool, AQLQueue
9
+ from tinygrad.runtime.driver.hip_comgr import compile_hip
10
+ if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401
11
+
12
# Profiling switch, read from the PROFILE environment variable (default 0).
# When truthy: kernel dispatches and SDMA copies get completion signals tracked by the module's HSAProfiler,
# async-copy profiling is enabled when the first HSADevice initializes HSA, and any collected events are
# written to /tmp/profile.json by hsa_terminate at exit.
PROFILE = getenv("PROFILE", 0)
13
+
14
class HSAProfiler:
  """Collects per-device dispatch/copy timings from HSA signals and writes them out as Chrome trace-event JSON."""
  def __init__(self):
    # device -> [(signal, name, is_copy), ...] whose timings have not been read yet
    self.tracked_signals = collections.defaultdict(list)
    # (device_id, queue_id (0 = AQL dispatch, 1 = SDMA copy), name, start, end) for every processed signal
    self.collected_events: List[Tuple[Any, ...]] = []
    # scratch structs reused for every timing query
    self.copy_timings = hsa.hsa_amd_profiling_async_copy_time_t()
    self.disp_timings = hsa.hsa_amd_profiling_dispatch_time_t()

  def track(self, signal, device, name, is_copy=False):
    """Remember `signal` so its timing is read by the next process(device)."""
    self.tracked_signals[device].append((signal, name, is_copy))

  def process(self, device):
    """Read the timings of every signal tracked for `device`, record them, then forget those signals.

    Must run before any tracked signal is reused, because the timings are queried from the signal itself.
    """
    pending = self.tracked_signals[device]
    for sig, name, is_copy in pending:
      if is_copy:
        timings = self.copy_timings
        check(hsa.hsa_amd_profiling_get_async_copy_time(sig, ctypes.byref(timings)))
      else:
        timings = self.disp_timings
        check(hsa.hsa_amd_profiling_get_dispatch_time(device.agent, sig, ctypes.byref(timings))) #type:ignore
      queue_id = 1 if is_copy else 0
      self.collected_events.append((device.device_id, queue_id, name, timings.start, timings.end))
    self.tracked_signals.pop(device)

  def save(self, path):
    """Write all collected events to `path` as {"traceEvents": [...]} JSON and report where it went."""
    trace = []
    # one named "process" per device with two named "threads": AQL (tid 0) and SDMA (tid 1)
    for pid in range(len(HSADevice.devices)):
      trace += [{"name": "process_name", "ph": "M", "pid": pid, "args": {"name": "HSA"}},
                {"name": "thread_name", "ph": "M", "pid": pid, "tid": 0, "args": {"name": "AQL"}},
                {"name": "thread_name", "ph": "M", "pid": pid, "tid": 1, "args": {"name": "SDMA"}}]

    # each event becomes a begin/end pair; timestamps are scaled by 1e-3 (presumably ns -> us)
    for dev_id, queue_id, name, st, et in self.collected_events:
      for phase, stamp in (("B", st), ("E", et)):
        trace.append({"name": name, "ph": phase, "pid": dev_id, "tid": queue_id, "ts": stamp*1e-3})
    with open(path, "w") as f: f.write(json.dumps({"traceEvents": trace}))
    print(f"Saved HSA profile to {path}")

# process-wide profiler shared by the programs, allocators and devices in this module
Profiler = HSAProfiler()
43
+
44
class HSACompiler(Compiler):
  """Compiles HIP source for a single AMD GPU architecture via compile_hip."""
  def __init__(self, arch:str):
    self.arch = arch
    # the string is per-arch (presumably the compiler cache key), so binaries for different GPUs never collide
    super().__init__(f"compile_hip_{self.arch}")
  def compile(self, src:str) -> bytes:
    """Return the compiled code object for `src`.

    Raises CompileError when compile_hip raises RuntimeError; the original error is kept as __cause__.
    """
    try: return compile_hip(src, self.arch)
    except RuntimeError as e: raise CompileError(e) from e
51
+
52
class HSAProgram:
  """One compiled kernel loaded into an HSA executable for `device`; calling the object dispatches it on device.hw_queue."""
  def __init__(self, device:HSADevice, name:str, lib:bytes):
    self.device, self.name, self.lib = device, name, lib

    # DEBUG >= 6: disassemble the code object with ROCm's llvm-objdump (hard-coded /opt/rocm path), hiding s_code_end lines.
    if DEBUG >= 6:
      asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib)
      print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))

    # Build the executable: create it, load the code object for this agent, then freeze it before querying symbols.
    self.exec = init_c_var(hsa.hsa_executable_t(), lambda x: check(hsa.hsa_executable_create_alt(hsa.HSA_PROFILE_FULL, hsa.HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT, None, ctypes.byref(x)))) # noqa: E501
    self.code_reader = init_c_var(hsa.hsa_code_object_reader_t(),
                                  lambda x: check(hsa.hsa_code_object_reader_create_from_memory(lib, len(lib), ctypes.byref(x))))
    check(hsa.hsa_executable_load_agent_code_object(self.exec, self.device.agent, self.code_reader, None, None))
    check(hsa.hsa_executable_freeze(self.exec, None))

    # Look up the kernel descriptor symbol "<name>.kd", then read its kernel object handle and segment sizes
    # (presumably consumed by hw_queue.submit_kernel, which receives this object).
    self.kernel = init_c_var(hsa.hsa_executable_symbol_t(), lambda x: check(hsa.hsa_executable_get_symbol_by_name(self.exec, (name+".kd").encode("utf-8"), ctypes.byref(self.device.agent), ctypes.byref(x)))) # noqa: E501
    self.handle = init_c_var(ctypes.c_uint64(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT, ctypes.byref(x)))) # noqa: E501
    self.kernargs_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501
    self.group_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501
    self.private_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501

  def __del__(self):
    # Drain the device before destroying HSA objects; the hasattr guards cover an __init__ that failed part-way.
    self.device.synchronize()
    if hasattr(self, 'code_reader'): check(hsa.hsa_code_object_reader_destroy(self.code_reader))
    if hasattr(self, 'exec'): check(hsa.hsa_executable_destroy(self.exec))

  def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
    """Dispatch the kernel.

    args are packed as void* fields f0.. and vals as C ints v0.. of a kernarg struct whose layout is built on the
    first call and cached, so later calls must pass the same number of args and vals.
    Returns (end - start) * device.clocks_to_time for the dispatch when wait is true, otherwise None.
    Raises RuntimeError on the first call if the struct size differs from the kernel's kernarg segment size.
    """
    if not hasattr(self, "args_struct_t"):
      self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
                                                [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
      # NOTE(review): if this raises, args_struct_t stays cached and later calls skip this size check.
      if ctypes.sizeof(self.args_struct_t) != self.kernargs_segment_size:
        raise RuntimeError(f"HSAProgram.__call__: incorrect args struct size {ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}")

    kernargs = None
    if self.kernargs_segment_size > 0:
      # Fill the kernarg block in place, then flush HDP (presumably so the GPU observes these host writes).
      kernargs = self.device.alloc_kernargs(self.kernargs_segment_size)
      args_st = self.args_struct_t.from_address(kernargs)
      for i in range(len(args)): args_st.__setattr__(f'f{i}', args[i])
      for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i])
      self.device.flush_hdp()

    # A completion signal is only needed to wait on the dispatch or to profile it.
    signal = self.device.alloc_signal(reusable=True) if wait or PROFILE else None
    self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, completion_signal=signal)
    if PROFILE: Profiler.track(signal, self.device, self.name)
    if wait:
      hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
      check(hsa.hsa_amd_profiling_get_dispatch_time(self.device.agent, signal, ctypes.byref(timings := hsa.hsa_amd_profiling_dispatch_time_t())))
      return (timings.end - timings.start) * self.device.clocks_to_time
99
+
100
# Placeholder type for the opaque buffer handles used in allocator signatures.
T = TypeVar("T")
# Size of each host staging buffer used by copy_from_fd (256 MiB), and the page granularity its file reads are aligned to (4 KiB).
CHUNK_SIZE = 256 << 20
PAGE_SIZE = 4096
102
class HSAAllocator(LRUAllocator):
  """LRU-cached allocator for one HSADevice, with async copy paths host->device, file->device, device->host and device->device."""
  def __init__(self, device:HSADevice):
    self.device = device
    super().__init__()

  def _alloc(self, size:int, options:BufferOptions):
    """Allocate `size` bytes and return the raw address (c_void_p.value).

    options.host: allocate from the CPU memory pool, accessible to the CPU agent and this device.
    Otherwise: allocate from this device's GPU pool, accessible to every GPU agent.
    """
    if options.host:
      check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, size, 0, ctypes.byref(mem := ctypes.c_void_p())))
      check(hsa.hsa_amd_agents_allow_access(2, (hsa.hsa_agent_t*2)(HSADevice.cpu_agent, self.device.agent), None, mem))
      return mem.value
    else:
      c_agents = (hsa.hsa_agent_t * len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]))(*HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU])
      check(hsa.hsa_amd_memory_pool_allocate(self.device.gpu_mempool, size, 0, ctypes.byref(buf := ctypes.c_void_p())))
      check(hsa.hsa_amd_agents_allow_access(len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]), c_agents, None, buf))
      return buf.value

  def _free(self, opaque:T, options:BufferOptions):
    # Synchronize every device first: GPU allocations are accessible to all GPU agents (see _alloc).
    HSADevice.synchronize_system()
    check(hsa.hsa_amd_memory_pool_free(opaque))

  def copyin(self, dest:T, src: memoryview):
    """Asynchronously copy host `src` into device memory `dest` through a host staging buffer on SDMA engine 0."""
    # Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets.
    self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
    mem = self._alloc(src.nbytes, BufferOptions(host=True))
    # Register the staging buffer for deferred free right away so it is not leaked if issuing the copy fails.
    self.device.delayed_free.append(mem)
    ctypes.memmove(mem, from_mv(src), src.nbytes)
    check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes, 1, ctypes.byref(sync_signal),
                                                  copy_signal := self.device.alloc_signal(reusable=True), hsa.HSA_AMD_SDMA_ENGINE_0, True))
    # later packets on the hw queue wait for the copy to finish
    self.device.hw_queue.submit_barrier([copy_signal])
    if PROFILE: Profiler.track(copy_signal, self.device, f"copyin: CPU -> HSA:{self.device.device_id}", is_copy=True)

  def copy_from_fd(self, dest, fd, offset, size):
    """Copy `size` bytes at byte `offset` of file descriptor `fd` into device memory at `dest`.

    Data goes through two CHUNK_SIZE host staging buffers (self.hb, created on first use) used alternately, each
    with its own SDMA engine and completion signal, so reading the next chunk overlaps the previous DMA. The file
    is read from `offset` rounded down to PAGE_SIZE, and the leading bytes of the first chunk are skipped.
    """
    # every chunk copy waits on sync_signal, which fires once previously queued hw-queue work is done
    self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))

    if not hasattr(self, 'hb'):
      self.hb = [self._alloc(CHUNK_SIZE, BufferOptions(host=True)) for _ in range(2)]
      self.hb_signals = [self.device.alloc_signal(reusable=False) for _ in range(2)]
      self.hb_polarity = 0
      self.sdma = [hsa.HSA_AMD_SDMA_ENGINE_0, hsa.HSA_AMD_SDMA_ENGINE_1]
      # mark both staging buffers free: value 0 satisfies the "< 1" wait below
      for sig in self.hb_signals: hsa.hsa_signal_store_relaxed(sig, 0)

    copies_called = 0
    copied_in = 0
    # closefd=False: closing this wrapper leaves the caller's fd open
    with io.FileIO(fd, "a+b", closefd=False) as fo:
      fo.seek(offset - (minor_offset:=offset % PAGE_SIZE))
      for local_offset in range(0, size+minor_offset, CHUNK_SIZE):
        local_size = min(round_up(size+minor_offset, PAGE_SIZE)-local_offset, CHUNK_SIZE)
        copy_size = min(local_size-minor_offset, size-copied_in)
        if copy_size == 0: break

        # wait until the previous DMA out of this staging buffer finished before overwriting it
        hsa.hsa_signal_wait_scacquire(self.hb_signals[self.hb_polarity], hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
        self.device.reusable_signals.append(self.hb_signals[self.hb_polarity]) # it's free now and can be reused
        self.hb_signals[self.hb_polarity] = self.device.alloc_signal(reusable=False)

        fo.readinto(to_mv(self.hb[self.hb_polarity], local_size))
        check(hsa.hsa_amd_memory_async_copy_on_engine(dest+copied_in, self.device.agent, self.hb[self.hb_polarity]+minor_offset, HSADevice.cpu_agent,
                                                      copy_size, 1, ctypes.byref(sync_signal), self.hb_signals[self.hb_polarity],
                                                      self.sdma[self.hb_polarity], True))
        copied_in += copy_size
        self.hb_polarity = (self.hb_polarity + 1) % len(self.hb)
        minor_offset = 0 # only on the first
        copies_called += 1

    # later hw-queue work waits for the last copy, and for the other buffer's copy if both were used
    wait_signals = [self.hb_signals[self.hb_polarity - 1]]
    if copies_called > 1: wait_signals.append(self.hb_signals[self.hb_polarity])
    self.device.hw_queue.submit_barrier(wait_signals)

  def copyout(self, dest:memoryview, src:T):
    """Synchronously copy device memory `src` into host `dest`, pinning `dest` for the duration of the DMA."""
    HSADevice.synchronize_system()
    copy_signal = self.device.alloc_signal(reusable=True)
    c_agents = (hsa.hsa_agent_t*2)(self.device.agent, HSADevice.cpu_agent)
    check(hsa.hsa_amd_memory_lock_to_pool(from_mv(dest), dest.nbytes, c_agents, 2, HSADevice.cpu_mempool, 0, ctypes.byref(addr:=ctypes.c_void_p())))
    try: check(hsa.hsa_amd_memory_async_copy(addr, HSADevice.cpu_agent, src, self.device.agent, dest.nbytes, 0, None, copy_signal))
    except Exception:
      # the copy was not issued: unpin the host buffer instead of leaving it locked, then propagate
      check(hsa.hsa_amd_memory_unlock(from_mv(dest)))
      raise
    hsa.hsa_signal_wait_scacquire(copy_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
    check(hsa.hsa_amd_memory_unlock(from_mv(dest)))
    if PROFILE: Profiler.track(copy_signal, self.device, f"copyout: HSA:{self.device.device_id} -> CPU", is_copy=True)

  def transfer(self, dest:T, src:T, sz:int, src_dev=None, dest_dev=None):
    """Copy `sz` bytes from `src` on src_dev to `dest` on dest_dev on SDMA engine 0, ordered against both devices' queues."""
    # the copy starts only after prior work on both queues is done ...
    src_dev.hw_queue.submit_barrier([], sync_signal_1 := src_dev.alloc_signal(reusable=True))
    dest_dev.hw_queue.submit_barrier([], sync_signal_2 := dest_dev.alloc_signal(reusable=True))
    c_wait_signal = (hsa.hsa_signal_t*2)(sync_signal_1, sync_signal_2)
    check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal,
                                                  copy_signal := dest_dev.alloc_signal(reusable=False), hsa.HSA_AMD_SDMA_ENGINE_0, True))
    # ... and later work on both queues waits for the copy
    src_dev.hw_queue.submit_barrier([copy_signal])
    dest_dev.hw_queue.submit_barrier([copy_signal])
    if PROFILE: Profiler.track(copy_signal, src_dev, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", is_copy=True)
189
+
190
class HSADevice(Compiled):
  """A GPU agent driven through the ROCm HSA runtime.

  Class-level state (agents, cpu_agent, cpu_mempool, devices) is shared by all instances; HSA is initialized by
  the first instance, which also registers hsa_terminate to run at exit.
  """
  devices: List[HSADevice] = []  # every constructed device, in creation order
  agents: Dict[int, List[hsa.hsa_agent_t]] = {}  # scan_agents() result, keyed by HSA device type
  cpu_agent: hsa.hsa_agent_t
  cpu_mempool: hsa.hsa_amd_memory_pool_t
  def __init__(self, device:str=""):
    if not HSADevice.agents:
      check(hsa.hsa_init())
      atexit.register(hsa_terminate)
      HSADevice.agents = scan_agents()
      HSADevice.cpu_agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_CPU][0]
      HSADevice.cpu_mempool = find_memory_pool(HSADevice.cpu_agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_CPU)
      if PROFILE: check(hsa.hsa_amd_profiling_async_copy_enable(1))

    # "<name>:<n>" selects GPU agent n; a name without ":" selects GPU 0
    self.device_id = int(device.split(":")[1]) if ":" in device else 0
    self.agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU][self.device_id]
    self.gpu_mempool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_GPU)
    self.hw_queue = AQLQueue(self)
    HSADevice.devices.append(self)

    # the agent name is used as the target arch for HSACompiler
    check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AGENT_INFO_NAME, ctypes.byref(agent_name_buf := ctypes.create_string_buffer(256))))
    self.arch = ctypes.string_at(agent_name_buf).decode()

    # converts profiling timestamp ticks to seconds (1 / system timestamp frequency)
    check(hsa.hsa_system_get_info(hsa.HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY, ctypes.byref(gpu_freq := ctypes.c_uint64())))
    self.clocks_to_time: float = 1 / gpu_freq.value

    check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AMD_AGENT_INFO_HDP_FLUSH, ctypes.byref(hdp_flush := hsa.hsa_amd_hdp_flush_t())))
    self.hdp_flush = hdp_flush

    self.delayed_free: List[int] = []  # addresses released at the next synchronize()
    self.reusable_signals: List[hsa.hsa_signal_t] = []  # signals reset and returned to signal_pool at the next synchronize()

    from tinygrad.runtime.graph.hsa import HSAGraph
    super().__init__(device, HSAAllocator(self), HIPRenderer(), HSACompiler(self.arch), functools.partial(HSAProgram, self), HSAGraph)

    # Finish init: preallocate some signals + space for kernargs
    self.signal_pool = [init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_signal_create(1, 0, None, ctypes.byref(x)))) for _ in range(4096)]
    self._new_kernargs_region(16 << 20) # initial region size is 16mb

  def synchronize(self):
    """Wait for the hw queue to drain, then recycle signals, free deferred allocations, rewind kernargs and collect profile data."""
    self.hw_queue.wait()

    # reset recycled signals to 1 before handing them back to the pool
    for sig in self.reusable_signals: hsa.hsa_signal_silent_store_relaxed(sig, 1)
    self.signal_pool.extend(self.reusable_signals)
    self.reusable_signals.clear()

    for opaque_to_free in self.delayed_free: check(hsa.hsa_amd_memory_pool_free(opaque_to_free))
    self.delayed_free.clear()

    # nothing in flight references kernargs anymore, so the bump allocator starts over
    self.kernarg_next_addr = self.kernarg_start_addr
    Profiler.process(self)

  @staticmethod
  def synchronize_system():
    """Synchronize every HSADevice created so far."""
    for d in HSADevice.devices: d.synchronize()

  def alloc_signal(self, reusable=False):
    """Return a signal from the pool, creating a new one (initial value 1) if the pool is empty.

    reusable=True means the signal is reset and returned to the pool at this device's next synchronize().
    """
    if self.signal_pool: signal = self.signal_pool.pop()
    else: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(signal := hsa.hsa_signal_t())))

    if reusable: self.reusable_signals.append(signal)
    return signal

  def alloc_kernargs(self, sz):
    """Bump-allocate `sz` bytes of kernarg memory; the next allocation starts 16-byte aligned."""
    if self.kernarg_next_addr + sz >= self.kernarg_start_addr + self.kernarg_pool_sz:
      # grow geometrically, but never to a region too small to hold this request
      self._new_kernargs_region(max(int(self.kernarg_pool_sz * 2), sz))
    result = self.kernarg_next_addr
    self.kernarg_next_addr = round_up(self.kernarg_next_addr + sz, 16)
    return result

  def _new_kernargs_region(self, sz:int):
    """Switch to a fresh `sz`-byte kernarg region; the old one is freed at the next synchronize()."""
    if hasattr(self, 'kernarg_start_addr'): self.delayed_free.append(self.kernarg_start_addr)
    self.kernarg_start_addr: int = self.allocator._alloc(sz, BufferOptions())
    self.kernarg_next_addr = self.kernarg_start_addr
    self.kernarg_pool_sz: int = sz

  def flush_hdp(self):
    """Trigger an HDP flush by writing 1 to the agent's HDP_MEM_FLUSH_CNTL register."""
    self.hdp_flush.HDP_MEM_FLUSH_CNTL[0] = 1
267
+
268
def hsa_terminate():
  """atexit hook registered by the first HSADevice: release queues, shut HSA down, then save any recorded profile."""
  # AQL queues must be stopped/deleted before HSA shuts down; otherwise the GPU can hang.
  for dev in HSADevice.devices:
    Profiler.process(dev)
    del dev.hw_queue

  # hsa_shut_down cleans up all hsa-related resources.
  hsa.hsa_shut_down()
  # The runtime is gone, so later synchronize() and HSAProgram destructor calls become no-ops.
  # synchronize is replaced on the class and is called on instances (e.g. by synchronize_system), so the stub
  # must accept the bound instance; a zero-argument lambda would raise TypeError there.
  HSADevice.synchronize = lambda *_: None #type:ignore
  HSAProgram.__del__ = lambda _: None #type:ignore
  if Profiler.collected_events: Profiler.save("/tmp/profile.json")
@@ -1,66 +1,46 @@
1
- import ctypes
2
- from typing import ClassVar, Tuple
3
- from tinygrad.device import Compiled, MallocAllocator
4
- from tinygrad.helpers import getenv, DEBUG, cpu_time_execution
5
- from ctypes import CFUNCTYPE
6
- from tinygrad.codegen.kernel import LinearizerOptions
7
- from tinygrad.renderer.llvmir import uops_to_llvm_ir
8
-
1
+ from __future__ import annotations
2
+ import ctypes, functools
3
+ from typing import Tuple
4
+ from tinygrad.device import Compiled, Compiler, MallocAllocator
5
+ from tinygrad.helpers import DEBUG, cpu_time_execution, cpu_objdump
6
+ from tinygrad.renderer.llvmir import LLVMRenderer
9
7
  import llvmlite.binding as llvm
10
8
 
11
- LLVMOPT = bool(getenv("LLVMOPT"))
9
class LLVMCompiler(Compiler):
  """Turns LLVM IR text into a native object file using the owning LLVMDevice's pass manager and target machine."""
  def __init__(self, device:LLVMDevice):
    self.device = device
    super().__init__("compile_llvm")
  def compile(self, src:str) -> bytes:
    """Parse and verify `src`, run the device's module optimizer on it, and emit an object file."""
    module = llvm.parse_assembly(src)
    module.verify()
    self.device.optimizer.run(module)
    if DEBUG >= 5:
      print(self.device.target_machine.emit_assembly(module))
    return self.device.target_machine.emit_object(module)
12
19
 
13
- class LLVM:
14
- target_machine: ClassVar[llvm.targets.TargetMachine] = None
15
- engine: ClassVar[llvm.executionengine.ExecutionEngine] = None
16
- optimizer: ClassVar[llvm.passmanagers.ModulePassManager] = None
20
+ class LLVMProgram:
21
+ def __init__(self, device:LLVMDevice, name:str, lib:bytes):
22
+ if DEBUG >= 6: cpu_objdump(lib)
23
+ self.name, self.lib = name, lib
24
+ device.engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib))
25
+ self.fxn = device.engine.get_function_address(name)
17
26
 
18
- def __init__(self):
19
- if LLVM.engine is not None: return
27
+ def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False):
28
+ if not hasattr(self, 'cfunc'):
29
+ self.cfunc = ctypes.CFUNCTYPE(ctypes.c_int, *([ctypes.c_void_p]*len(bufs)), *([ctypes.c_int32]*len(vals)))(self.fxn)
30
+ return cpu_time_execution(lambda: self.cfunc(*bufs, *vals), enable=wait)
31
+
32
+ class LLVMDevice(Compiled):
33
+ def __init__(self, device:str):
20
34
  llvm.initialize()
21
35
  llvm.initialize_native_target()
22
36
  llvm.initialize_native_asmprinter()
23
37
  llvm.initialize_native_asmparser()
24
- target = llvm.Target.from_triple(llvm.get_process_triple())
25
- LLVM.optimizer = llvm.create_module_pass_manager()
26
- LLVM.target_machine = target.create_target_machine(opt=2) # this opt actually can change things. ex: opt=3 means no FMA, opt=2 means FMA
27
- LLVM.target_machine.add_analysis_passes(LLVM.optimizer)
28
-
29
- # TODO: this makes compile times so much faster
30
- if LLVMOPT:
31
- llvm.set_option(str(), '-force-vector-interleave=4') # this makes sum the same speed as torch, it also doubles the (slow) conv speed
32
- if DEBUG >= 4: llvm.set_option(str(), '--debug-only=loop-vectorize')
33
- #llvm.set_option(str(), '--debug')
34
-
35
- # does this do anything?
36
- builder = llvm.create_pass_manager_builder()
37
- builder.opt_level = 3
38
- builder.size_level = 0
39
- builder.loop_vectorize = True
40
- builder.slp_vectorize = True
41
- builder.populate(LLVM.optimizer)
42
-
43
- LLVM.target_machine.set_asm_verbosity(True)
38
+ self.optimizer: llvm.passmanagers.ModulePassManager = llvm.create_module_pass_manager()
39
+ # this opt actually can change things. ex: opt=3 means no FMA, opt=2 means FMA
40
+ self.target_machine: llvm.targets.TargetMachine = llvm.Target.from_triple(llvm.get_process_triple()).create_target_machine(opt=2)
41
+ self.target_machine.add_analysis_passes(self.optimizer)
42
+ self.target_machine.set_asm_verbosity(True)
44
43
  backing_mod = llvm.parse_assembly(str())
45
44
  backing_mod.triple = llvm.get_process_triple()
46
- LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine)
47
-
48
- def compile_llvm(prg) -> bytes:
49
- mod = llvm.parse_assembly(prg)
50
- mod.verify()
51
- LLVM().optimizer.run(mod)
52
- if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(mod))
53
- return LLVM.target_machine.emit_object(mod)
54
-
55
- class LLVMProgram:
56
- def __init__(self, name:str, lib:bytes):
57
- self.name, self.lib = name, lib
58
- LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib))
59
- self.fxn = LLVM.engine.get_function_address(name)
60
-
61
- def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False):
62
- self.cfunc = CFUNCTYPE(ctypes.c_int, *([ctypes.c_void_p]*len(bufs)), *([ctypes.c_int32]*len(vals)))(self.fxn)
63
- return cpu_time_execution(lambda: self.cfunc(*bufs, *vals), enable=wait)
64
-
65
- LLVMDevice = Compiled(MallocAllocator, LinearizerOptions(supports_float4=False, has_local=False, has_shared=False),
66
- uops_to_llvm_ir, compile_llvm, LLVMProgram)
45
+ self.engine: llvm.executionengine.ExecutionEngine = llvm.create_mcjit_compiler(backing_mod, self.target_machine)
46
+ super().__init__(device, MallocAllocator, LLVMRenderer(), LLVMCompiler(self), functools.partial(LLVMProgram, self))
@@ -1,21 +1,31 @@
1
1
  from __future__ import annotations
2
2
  import os, subprocess, pathlib, ctypes, tempfile, functools
3
3
  import Metal, libdispatch
4
- from typing import List, Any, Tuple, Optional
5
- from tinygrad.codegen.kernel import LinearizerOptions
4
+ from typing import List, Set, Any, Tuple, Optional
6
5
  from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
7
- from tinygrad.device import Compiled, LRUAllocator
6
+ from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
8
7
  from tinygrad.renderer.cstyle import MetalRenderer
9
8
 
10
- def compile_metal(prg, use_xcode=bool(getenv("METAL_XCODE"))) -> bytes:
11
- assert MetalDevice.compiler_device, "metal device creation is required for metal compile"
12
- if use_xcode:
13
- # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
14
- air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
15
- return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
16
- options = Metal.MTLCompileOptions.new()
17
- library = unwrap2(MetalDevice.compiler_device.newLibraryWithSource_options_error_(prg, options, None))
18
- return library.libraryDataContents().bytes().tobytes()
9
def wait_check(cbuf: Any):
  """Block until Metal command buffer `cbuf` completes; raise RuntimeError carrying its error if it reported one."""
  cbuf.waitUntilCompleted()
  failure = cbuf.error()
  if failure is None: return
  raise RuntimeError(failure)
13
+
14
class MetalCompiler(Compiler):
  """Compiles Metal shading-language source to Metal library bytes.

  device=None selects the offline xcrun path (metal -> metallib); otherwise the source is compiled in-process
  with device.device.
  """
  def __init__(self, device:Optional[MetalDevice]):
    self.device = device
    super().__init__("compile_metal")
  def compile(self, src:str) -> bytes:
    """Return the compiled library bytes for `src`.

    Raises CompileError, with the original exception as __cause__, when the in-process compile raises
    AssertionError. The xcrun path lets subprocess.CalledProcessError propagate unchanged.
    """
    if self.device is None:
      # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
      air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8'))
      return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
    options = Metal.MTLCompileOptions.new()
    options.setFastMathEnabled_(getenv("METAL_FAST_MATH"))
    try: library = unwrap2(self.device.device.newLibraryWithSource_options_error_(src, options, None))
    except AssertionError as e: raise CompileError(e) from e
    return library.libraryDataContents().bytes().tobytes()
19
29
 
20
30
  class MetalProgram:
21
31
  def __init__(self, device:MetalDevice, name:str, lib:bytes):
@@ -24,14 +34,15 @@ class MetalProgram:
24
34
  with tempfile.NamedTemporaryFile(delete=True) as shader:
25
35
  shader.write(lib)
26
36
  shader.flush()
27
- os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
37
+ os.system(f"cd {pathlib.Path(__file__).parents[2]}/extra/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
38
+ assert lib[:4] == b"MTLB", "Invalid Metal library. Could be due to using conda. Try system python or METAL_XCODE=1 DISABLE_COMPILER_CACHE=1."
28
39
  data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
29
40
  self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None))
30
41
  self.fxn = self.library.newFunctionWithName_(name)
31
42
  self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
32
43
 
33
- def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False):
34
- assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(),f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}" # noqa: E501
44
+ def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
45
+ if prod(local_size) > self.pipeline_state.maxTotalThreadsPerThreadgroup(): raise RuntimeError(f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}") # noqa: E501
35
46
  command_buffer = self.device.mtl_queue.commandBuffer()
36
47
  encoder = command_buffer.computeCommandEncoder()
37
48
  encoder.setComputePipelineState_(self.pipeline_state)
@@ -41,19 +52,26 @@ class MetalProgram:
41
52
  encoder.endEncoding()
42
53
  command_buffer.commit()
43
54
  if wait:
44
- command_buffer.waitUntilCompleted()
55
+ wait_check(command_buffer)
45
56
  return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
46
57
  self.device.mtl_buffers_in_flight.append(command_buffer)
47
58
 
48
59
  class MetalAllocator(LRUAllocator):
49
60
  def __init__(self, device:MetalDevice):
50
61
  self.device:MetalDevice = device
62
+ self.track_cross_device: Set[MetalDevice] = set()
51
63
  super().__init__()
52
- def _alloc(self, size:int) -> Any:
64
+ def free_cache(self):
65
+ self.device.synchronize()
66
+ for x in self.track_cross_device: x.synchronize()
67
+ self.track_cross_device.clear()
68
+ return super().free_cache()
69
+ def _alloc(self, size:int, options) -> Any:
53
70
  ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
54
71
  if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}")
55
72
  return ret
56
- def transfer(self, dest:Any, src:Any, sz:int):
73
+ def transfer(self, dest:Any, src:Any, sz:int, src_dev: MetalDevice, **kwargs):
74
+ src_dev.synchronize()
57
75
  command_buffer = self.device.mtl_queue.commandBuffer()
58
76
  encoder = command_buffer.blitCommandEncoder()
59
77
  encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src, 0, dest, 0, sz)
@@ -64,7 +82,7 @@ class MetalAllocator(LRUAllocator):
64
82
  ret = self.device.device.newBufferWithBytesNoCopy_length_options_deallocator_(src, len(src), Metal.MTLResourceStorageModeShared, None)
65
83
  if ret: self.device.mv_in_metal.append(src)
66
84
  return ret
67
- def _free(self, opaque:Any): opaque.release()
85
+ def _free(self, opaque:Any, options): opaque.release()
68
86
  def as_buffer(self, src:Any) -> memoryview:
69
87
  self.device.synchronize()
70
88
  return src.contents().as_buffer(src.length())
@@ -72,17 +90,17 @@ class MetalAllocator(LRUAllocator):
72
90
  def copyout(self, dest:memoryview, src:Any): dest[:] = self.as_buffer(src)
73
91
 
74
92
  class MetalDevice(Compiled):
75
- compiler_device = None
76
93
  def __init__(self, device:str):
77
94
  self.device = Metal.MTLCreateSystemDefaultDevice()
78
- if MetalDevice.compiler_device is None: MetalDevice.compiler_device = self.device
79
95
  self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
80
96
  self.mtl_buffers_in_flight: List[Any] = []
81
97
  self.mv_in_metal: List[memoryview] = []
98
+ self.track_cross_buffer: List[Any] = []
82
99
  from tinygrad.runtime.graph.metal import MetalGraph
83
- super().__init__(MetalAllocator(self), LinearizerOptions(device="METAL"), MetalRenderer,
84
- compile_metal, functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
100
+ super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler(None if getenv("METAL_XCODE") else self),
101
+ functools.partial(MetalProgram, self), MetalGraph)
85
102
  def synchronize(self):
86
- for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
103
+ for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
87
104
  self.mv_in_metal.clear()
88
105
  self.mtl_buffers_in_flight.clear()
106
+ self.track_cross_buffer.clear()
@@ -0,0 +1,9 @@
1
+ import numpy as np
2
+ from tinygrad.helpers import flat_mv
3
+ from tinygrad.device import Compiled, Allocator
4
+
5
class NpyAllocator(Allocator):
  """Allocator whose buffers are numpy arrays; this class overrides only copyout."""
  def copyout(self, dest:memoryview, src:np.ndarray):
    # np.require returns src unchanged when already C-contiguous, otherwise a C-contiguous copy
    contiguous = np.require(src, requirements='C')
    dest[:] = flat_mv(contiguous.data)
7
+
8
class NpyDevice(Compiled):
  """Device backed by numpy arrays: it supplies an NpyAllocator and passes None for the renderer, compiler and runtime."""
  def __init__(self, device:str):
    allocator = NpyAllocator()
    super().__init__(device, allocator, None, None, None)