tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -6
- tinygrad/codegen/kernel.py +230 -190
- tinygrad/codegen/linearizer.py +278 -384
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +132 -275
- tinygrad/dtype.py +53 -37
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +28 -14
- tinygrad/helpers.py +72 -43
- tinygrad/lazy.py +141 -240
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +179 -8
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +106 -28
- tinygrad/nn/state.py +86 -17
- tinygrad/ops.py +70 -44
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +299 -206
- tinygrad/renderer/llvmir.py +118 -123
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +59 -54
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +37 -41
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +16 -14
- tinygrad/runtime/ops_cuda.py +130 -38
- tinygrad/runtime/ops_disk.py +45 -42
- tinygrad/runtime/ops_gpu.py +52 -50
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +36 -56
- tinygrad/runtime/ops_metal.py +42 -24
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +41 -105
- tinygrad/shape/symbolic.py +98 -95
- tinygrad/shape/view.py +137 -35
- tinygrad/tensor.py +2367 -442
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/features/image.py +0 -93
- tinygrad/features/multi.py +0 -103
- tinygrad/features/search.py +0 -160
- tinygrad/graph.py +0 -106
- tinygrad/jit.py +0 -152
- tinygrad/realize.py +0 -50
- tinygrad/runtime/graph/hip.py +0 -24
- tinygrad/runtime/ops_cpu.py +0 -45
- tinygrad/runtime/ops_hip.py +0 -97
- tinygrad/runtime/ops_torch.py +0 -49
- tinygrad-0.8.0.dist-info/RECORD +0 -41
- {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,47 @@
|
|
1
|
+
import ctypes
|
2
|
+
import tinygrad.runtime.autogen.comgr as comgr
|
3
|
+
|
4
|
+
def check(status):
  """Raise RuntimeError when a comgr call returned a non-zero status; otherwise do nothing."""
  if status == 0: return
  # look up the message text that comgr associates with this status code
  msg_ptr = ctypes.POINTER(ctypes.c_char)()
  comgr.amd_comgr_status_string(status, ctypes.byref(msg_ptr))
  raise RuntimeError(f"comgr fail {status}, {ctypes.string_at(msg_ptr).decode()}")
|
8
|
+
|
9
|
+
def _get_comgr_data(data_set, data_type):
  """Return, as bytes, the contents of the index-0 object of kind `data_type` stored in `data_set`."""
  data_obj = comgr.amd_comgr_data_t()
  check(comgr.amd_comgr_action_data_get_data(data_set, data_type, 0, ctypes.byref(data_obj)))

  # first query passes no destination buffer; the size it reports is used to allocate one
  nbytes = ctypes.c_uint64()
  check(comgr.amd_comgr_get_data(data_obj, ctypes.byref(nbytes), None))

  # second query fills the freshly allocated buffer
  payload = ctypes.create_string_buffer(nbytes.value)
  check(comgr.amd_comgr_get_data(data_obj, ctypes.byref(nbytes), payload))

  check(comgr.amd_comgr_release_data(data_obj))
  return bytes(payload)
|
15
|
+
|
16
|
+
# AMD_COMGR_SAVE_TEMPS=1 AMD_COMGR_REDIRECT_LOGS=stdout AMD_COMGR_EMIT_VERBOSE_LOGS=1
|
17
|
+
def compile_hip(prg:str, arch="gfx1100") -> bytes:
  """Compile HIP source text with comgr and return the bytes of the resulting executable.

  The source is run through three comgr actions in sequence: source -> bitcode (with device
  libs), bitcode -> relocatable, relocatable -> executable.

  Args:
    prg: HIP source code.
    arch: architecture name; used both in the ISA name and in the --offload-arch option.

  Returns:
    The contents of the executable data object produced by the link action.

  Raises:
    RuntimeError: "compile failed" (after the comgr log has been printed) when the
      source -> bitcode action fails, or the error raised by `check` when any other
      comgr call returns a non-zero status.
  """
  import contextlib  # function-scope import keeps this definition self-contained

  # Each comgr handle registers its destructor the moment it exists. The stack unwinds on
  # every exit path, so a failing step (typically a compile error in `prg`) no longer leaks
  # the action info, data sets and source data that were created before it.
  with contextlib.ExitStack() as cleanup:
    check(comgr.amd_comgr_create_action_info(ctypes.byref(action_info := comgr.amd_comgr_action_info_t())))
    cleanup.callback(lambda: check(comgr.amd_comgr_destroy_action_info(action_info)))
    check(comgr.amd_comgr_action_info_set_language(action_info, comgr.AMD_COMGR_LANGUAGE_HIP))
    check(comgr.amd_comgr_action_info_set_isa_name(action_info, b"amdgcn-amd-amdhsa--" + arch.encode()))
    check(comgr.amd_comgr_action_info_set_logging(action_info, True))

    # one data set per pipeline stage: source, bitcode, relocatable, executable
    data_sets = []
    for _ in range(4):
      check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set := comgr.amd_comgr_data_set_t())))
      cleanup.callback(lambda ds=data_set: check(comgr.amd_comgr_destroy_data_set(ds)))
      data_sets.append(data_set)
    data_set_src, data_set_bc, data_set_reloc, data_set_exec = data_sets

    check(comgr.amd_comgr_create_data(comgr.AMD_COMGR_DATA_KIND_SOURCE, ctypes.byref(data_src := comgr.amd_comgr_data_t())))
    cleanup.callback(lambda: check(comgr.amd_comgr_release_data(data_src)))
    check(comgr.amd_comgr_set_data(data_src, len(rprg := prg.encode()), rprg))
    check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))

    check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
    # -include hiprtc_runtime.h was removed
    check(comgr.amd_comgr_action_info_set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes".encode())) # noqa: E501
    # this status is inspected by hand (not via check) so the comgr log can be printed first
    status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
    if status != 0:
      print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
      raise RuntimeError("compile failed")
    check(comgr.amd_comgr_action_info_set_options(action_info, b"-O3 -mllvm -amdgpu-internalize-symbols"))
    check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, action_info, data_set_bc, data_set_reloc))
    check(comgr.amd_comgr_action_info_set_options(action_info, b""))
    check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, action_info, data_set_reloc, data_set_exec))
    # the bytes are copied out before the stack releases the handles
    return _get_comgr_data(data_set_exec, comgr.AMD_COMGR_DATA_KIND_EXECUTABLE)
|
@@ -0,0 +1,143 @@
|
|
1
|
+
import ctypes, collections
|
2
|
+
import tinygrad.runtime.autogen.hsa as hsa
|
3
|
+
from tinygrad.helpers import init_c_var
|
4
|
+
|
5
|
+
def check(status):
  """Raise RuntimeError when an HSA call returned a non-zero status; otherwise do nothing."""
  if status == 0: return
  # look up the message text that HSA associates with this status code
  msg_ptr = ctypes.POINTER(ctypes.c_char)()
  hsa.hsa_status_string(status, ctypes.byref(msg_ptr))
  raise RuntimeError(f"HSA Error {status}: {ctypes.string_at(msg_ptr).decode()}")
|
9
|
+
|
10
|
+
# Precalculated AQL info
|
11
|
+
# size in bytes of one kernel dispatch packet; used as the slot size of the queue
AQL_PACKET_SIZE = ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t)
# default-constructed signal, written wherever a packet has no real signal to reference
EMPTY_SIGNAL = hsa.hsa_signal_t()

# value 3 in the dimensions field of `setup` (presumably "3 grid dimensions" -- confirm against the HSA spec)
DISPATCH_KERNEL_SETUP = 3 << hsa.HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS

# header of a kernel dispatch packet: barrier bit, system scope for both fences, packet type
DISPATCH_KERNEL_HEADER = (
  (1 << hsa.HSA_PACKET_HEADER_BARRIER)
  | (hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE)
  | (hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE)
  | (hsa.HSA_PACKET_TYPE_KERNEL_DISPATCH << hsa.HSA_PACKET_HEADER_TYPE)
)

# header of a barrier-AND packet: same bits as above except for the packet type
BARRIER_HEADER = (
  (1 << hsa.HSA_PACKET_HEADER_BARRIER)
  | (hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE)
  | (hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE)
  | (hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE)
)
|
24
|
+
|
25
|
+
class AQLQueue:
  """Wrapper around an HSA hardware queue that writes AQL packets directly into the queue memory.

  The instance tracks its own write cursor (`write_addr`), the number of free packet slots
  (`available_packet_slots`) and the next write index to publish (`next_doorbell_index`).
  """
  def __init__(self, device, sz=-1):
    """Create the hardware queue on `device.agent`.

    `sz` caps the requested queue size; -1 means "use the maximum reported by the agent".
    """
    self.device = device

    agent_max = ctypes.c_uint32()
    check(hsa.hsa_agent_get_info(self.device.agent, hsa.HSA_AGENT_INFO_QUEUE_MAX_SIZE, ctypes.byref(agent_max)))
    queue_size = agent_max.value if sz == -1 else min(agent_max.value, sz)

    # function pointer instance constructed without a target, handed over as the queue callback
    no_callback = ctypes.CFUNCTYPE(None, hsa.hsa_status_t, ctypes.POINTER(hsa.struct_hsa_queue_s), ctypes.c_void_p)()
    def _create(queue_ptr):
      return check(hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, no_callback, None,
                                        (1<<32)-1, (1<<32)-1, ctypes.byref(queue_ptr)))
    self.hw_queue = init_c_var(ctypes.POINTER(hsa.hsa_queue_t)(), _create)

    queue_desc = self.hw_queue.contents
    self.next_doorbell_index = 0
    self.queue_base = queue_desc.base_address
    self.queue_size = queue_desc.size * AQL_PACKET_SIZE # in bytes
    self.write_addr = self.queue_base
    self.write_addr_end = self.queue_base + self.queue_size - 1 # precalc saves some time
    self.available_packet_slots = queue_desc.size

    check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH))
    check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1))

  def __del__(self):
    # hw_queue is absent when __init__ raised before the queue was created
    if not hasattr(self, 'hw_queue'): return
    check(hsa.hsa_queue_destroy(self.hw_queue))

  def submit_kernel(self, prg, global_size, local_size, kernargs, completion_signal=None):
    """Write one kernel dispatch packet at the write cursor and publish it.

    `global_size` is multiplied by `local_size` per axis to obtain the grid size written
    into the packet. Waits for a free slot first when none is available.
    """
    if self.available_packet_slots == 0: self._wait_queue()

    pkt = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
    pkt.workgroup_size_x = local_size[0]
    pkt.workgroup_size_y = local_size[1]
    pkt.workgroup_size_z = local_size[2]
    pkt.reserved0 = 0
    pkt.grid_size_x = global_size[0] * local_size[0]
    pkt.grid_size_y = global_size[1] * local_size[1]
    pkt.grid_size_z = global_size[2] * local_size[2]
    pkt.private_segment_size = prg.private_segment_size
    pkt.group_segment_size = prg.group_segment_size
    pkt.kernel_object = prg.handle
    pkt.kernarg_address = kernargs
    pkt.reserved2 = 0
    pkt.completion_signal = completion_signal or EMPTY_SIGNAL
    pkt.setup = DISPATCH_KERNEL_SETUP
    # header is stored after every other field of the packet
    pkt.header = DISPATCH_KERNEL_HEADER
    self._submit_packet()

  def submit_barrier(self, wait_signals=None, completion_signal=None):
    """Write one barrier-AND packet that depends on up to five `wait_signals` and publish it."""
    assert wait_signals is None or len(wait_signals) <= 5
    if self.available_packet_slots == 0: self._wait_queue()

    pkt = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
    pkt.reserved0 = 0
    pkt.reserved1 = 0
    # all five dependency slots are written; slots without a signal get EMPTY_SIGNAL
    n_deps = len(wait_signals) if wait_signals else 0
    for slot in range(5):
      pkt.dep_signal[slot] = wait_signals[slot] if slot < n_deps else EMPTY_SIGNAL
    pkt.reserved2 = 0
    pkt.completion_signal = completion_signal or EMPTY_SIGNAL
    pkt.header = BARRIER_HEADER
    self._submit_packet()

  def blit_packets(self, packet_addr, packet_cnt):
    """Copy `packet_cnt` ready-made packets from `packet_addr` into the queue and publish them."""
    if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)

    # packets that still fit between the write cursor and the end of the queue memory
    slots_to_end = (self.queue_base + self.queue_size - self.write_addr) // AQL_PACKET_SIZE
    first_cnt = min(slots_to_end, packet_cnt)
    wrapped_cnt = packet_cnt - first_cnt
    first_bytes = AQL_PACKET_SIZE * first_cnt
    ctypes.memmove(self.write_addr, packet_addr, first_bytes)
    # whatever did not fit continues at the start of the queue memory
    if wrapped_cnt > 0: ctypes.memmove(self.queue_base, packet_addr + first_bytes, AQL_PACKET_SIZE * wrapped_cnt)

    self._submit_packet(packet_cnt)

  def wait(self):
    """Submit a barrier with a fresh completion signal, wait on that signal, then mark every slot free."""
    done_signal = self.device.alloc_signal(reusable=True)
    self.submit_barrier([], done_signal)
    hsa.hsa_signal_wait_scacquire(done_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
    self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE

  def _wait_queue(self, need_packets=1):
    """Poll the queue's read index until at least `need_packets` slots are free."""
    total_slots = self.queue_size // AQL_PACKET_SIZE
    while self.available_packet_slots < need_packets:
      read_index = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
      in_flight = self.next_doorbell_index - read_index
      self.available_packet_slots = total_slots - in_flight

  def _submit_packet(self, cnt=1):
    """Publish `cnt` already written packets and advance the write cursor, wrapping at the end."""
    self.available_packet_slots -= cnt
    self.next_doorbell_index += cnt
    hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index)
    hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index-1)

    cursor = self.write_addr + AQL_PACKET_SIZE * cnt
    if cursor > self.write_addr_end:
      cursor = self.queue_base + (cursor - self.queue_base) % self.queue_size
    self.write_addr = cursor
|
113
|
+
|
114
|
+
def scan_agents():
  """Return a defaultdict that maps each HSA device type value to the list of agents of that type."""
  by_type = collections.defaultdict(list)

  @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p)
  def _collect(agent, data):
    dev_type = hsa.hsa_device_type_t()
    # an agent whose device type cannot be queried is left out; the iteration still continues
    if hsa.hsa_agent_get_info(agent, hsa.HSA_AGENT_INFO_DEVICE, ctypes.byref(dev_type)) == 0:
      by_type[dev_type.value].append(agent)
    return hsa.HSA_STATUS_SUCCESS

  hsa.hsa_iterate_agents(_collect, None)
  return by_type
|
125
|
+
|
126
|
+
def find_memory_pool(agent, segtyp=-1, location=-1):
  """Return the first memory pool of `agent` that has a non-zero size and passes the filters.

  `segtyp` and `location` are only compared when they are >= 0. When no pool qualifies,
  the returned `hsa_amd_memory_pool_t` is the default-constructed one that was never written.
  """
  def _query(pool, attribute, out):
    # read one pool attribute into `out` and hand back its plain value
    check(hsa.hsa_amd_memory_pool_get_info(pool, attribute, ctypes.byref(out)))
    return out.value

  @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_amd_memory_pool_t, ctypes.c_void_p)
  def _first_match(mem_pool, data):
    # returning SUCCESS moves on to the next pool
    segment = _query(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SEGMENT, hsa.hsa_amd_segment_t())
    if segtyp >= 0 and segment != segtyp: return hsa.HSA_STATUS_SUCCESS

    pool_location = _query(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_LOCATION, hsa.hsa_amd_memory_pool_location_t())
    if location >= 0 and pool_location != location: return hsa.HSA_STATUS_SUCCESS

    if _query(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SIZE, ctypes.c_size_t()) == 0: return hsa.HSA_STATUS_SUCCESS

    # `data` points at the caller's result struct: store the pool there and stop iterating
    ctypes.cast(data, ctypes.POINTER(hsa.hsa_amd_memory_pool_t))[0] = mem_pool
    return hsa.HSA_STATUS_INFO_BREAK

  found = hsa.hsa_amd_memory_pool_t()
  hsa.hsa_amd_agent_iterate_memory_pools(agent, _first_match, ctypes.byref(found))
  return found
|
@@ -0,0 +1,38 @@
|
|
1
|
+
from typing import List, Dict, cast
|
2
|
+
import ctypes
|
3
|
+
from tinygrad.helpers import dedup, cpu_time_execution, GraphException, DEBUG
|
4
|
+
from tinygrad.engine.jit import GraphRunner
|
5
|
+
from tinygrad.device import Buffer, Device
|
6
|
+
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
7
|
+
from tinygrad.shape.symbolic import Variable
|
8
|
+
from tinygrad.runtime.ops_clang import ClangProgram
|
9
|
+
from tinygrad.renderer.cstyle import ClangRenderer
|
10
|
+
# Bound `render_dtype` method of one ClangRenderer instance; ClangGraph calls it to spell the
# element type of buffer pointers in the C source it generates.
render_dtype = ClangRenderer().render_dtype
|
11
|
+
|
12
|
+
class ClangGraph(GraphRunner):
  """Runs a whole JIT cache as a single compiled C function named `batched`."""
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    """Generate, compile and load the `batched` wrapper that calls every kernel of `jit_cache` in order.

    Raises:
      GraphException: if any item of `jit_cache` is not run by a CompiledRunner.
    """
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    if any(not isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

    runners = [cast(CompiledRunner, ji.prg) for ji in jit_cache]
    kernel_srcs = '\n'.join(dedup([r.p.src for r in runners]))

    # signature: one typed pointer per input buffer, then one int per variable
    params = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
    params.extend(f"int {v.expr}" for v in var_vals)

    def _buffer_arg(buf):
      assert buf is not None
      # input buffers are passed through by parameter name ...
      if buf in input_rawbuffers: return f"arg{input_rawbuffers.index(buf)}"
      # ... every other buffer is baked in as the address of its current `_buf`
      return f"({render_dtype(buf.dtype)}*)0x{ctypes.addressof(buf._buf):X}"

    body = ["void batched("+','.join(params)+") {"]
    for ji, runner in zip(jit_cache, runners):
      call_args = [_buffer_arg(buf) for buf in ji.bufs]
      call_args += [x.expr for x in runner.p.vars]
      body.append(f" {runner.p.function_name}({','.join(call_args)});")
    body.append("}")
    wrapper_src = "\n".join(body)
    if DEBUG >= 4: print(wrapper_src)

    compiler = Device["CLANG"].compiler
    assert compiler is not None
    self.clprg = ClangProgram("batched", compiler.compile(kernel_srcs+"\n"+wrapper_src)) # no point in caching the pointers

  def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False):
    """Call `batched` with the `_buf` of every buffer followed by the values of `var_vals`."""
    def _run():
      return self.clprg(*[x._buf for x in rawbufs], *var_vals.values())
    return cpu_time_execution(_run, enable=wait)
|
tinygrad/runtime/graph/cuda.py
CHANGED
@@ -1,76 +1,81 @@
|
|
1
1
|
import ctypes
|
2
2
|
from typing import Any, Optional, Tuple, Dict, List, cast
|
3
|
-
import
|
4
|
-
from tinygrad.helpers import init_c_var,
|
5
|
-
from tinygrad.device import
|
6
|
-
from tinygrad.runtime.ops_cuda import check, cu_time_execution
|
3
|
+
import tinygrad.runtime.autogen.cuda as cuda
|
4
|
+
from tinygrad.helpers import init_c_var, GraphException
|
5
|
+
from tinygrad.device import Buffer, Device
|
6
|
+
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
|
7
7
|
from tinygrad.shape.symbolic import Variable
|
8
|
-
from tinygrad.
|
8
|
+
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
9
|
+
from tinygrad.engine.jit import MultiGraphRunner
|
9
10
|
|
10
|
-
class CUDAGraph:
|
11
|
-
def __init__(self, jit_cache: List[
|
12
|
-
|
11
|
+
class CUDAGraph(MultiGraphRunner):
  """Replays a JIT cache of kernels and buffer transfers as one instantiated CUDA graph."""
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    """Build the CUDA graph (one node per jit item) and instantiate it.

    Raises:
      GraphException: if any jit item is neither a CompiledRunner nor a BufferXfer.
    """
    super().__init__(jit_cache, input_rawbuffers, var_vals)

    # Check all jit items are compatible.
    if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException

    self.jc_idx_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
    self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)

    self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))

    for j,ji in enumerate(self.jit_cache):
      if isinstance(ji.prg, CompiledRunner):
        global_size, local_size = ji.prg.p.launch_dims(var_vals)

        new_node = cuda.CUgraphNode()
        # buffers after the first `outcount` are passed as reads, the first `outcount` as writes
        deps = self._access_resources([x.base for x in ji.bufs[ji.prg.p.outcount:] if x is not None],
                                      [x.base for x in ji.bufs[:ji.prg.p.outcount] if x is not None], new_dependency=new_node)
        c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None

        c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
        kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
        check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))

        # only nodes whose parameters can change between calls are remembered
        if j in self.jc_idx_with_updatable_launch_dims or j in self.jc_idx_with_updatable_var_vals or j in self.jc_idx_with_updatable_rawbufs:
          self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
      elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        src_dev = cast(CUDADevice, Device[src.device])
        node_from = cuda.CUgraphNode()
        deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from)
        c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
        cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
                                          dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
                                          WidthInBytes=dest.nbytes, Height=1, Depth=1)
        check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
        # for memcpy nodes the third tuple slot holds the source device context, not kernel args
        if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)

    self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
    """Patch the updatable nodes with the new buffers, variables and launch dims, then launch the graph.

    Returns whatever `cu_time_execution` returns for the launch (called with enable=wait).
    """
    # Update rawbuffers in the c_args struct.
    for (j,i),input_idx in self.input_replace.items():
      if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
      else:
        # The memcpy params field is `dstDevice` (same name as in the constructor call in __init__).
        # Assigning any other name only creates a Python attribute on the ctypes object, so the
        # replaced destination buffer would never reach the node.
        if i == 0: self.updatable_nodes[j][1].dstDevice = input_rawbuffers[input_idx]._buf
        elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf

    # Update var_vals in the c_args struct.
    for j in self.jc_idx_with_updatable_var_vals:
      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
        setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])

    # Update launch dims in the kern_params struct.
    for j in self.jc_idx_with_updatable_launch_dims:
      self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))

    # Update graph nodes with the updated structs.
    for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
      if not is_copy: check(cuda.cuGraphExecKernelNodeSetParams(self.instance, node, ctypes.byref(c_node_params)))
      else: check(cuda.cuGraphExecMemcpyNodeSetParams(self.instance, node, ctypes.byref(c_node_params), c_args))

    return cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)

  def __del__(self):
    # either attribute can be missing when __init__ raised before creating it
    if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
    if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))

  def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
    """Write `local_size` into the block dims and `global_size` into the grid dims of `node`."""
    node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size
|
@@ -0,0 +1,143 @@
|
|
1
|
+
import ctypes, collections, array, time
|
2
|
+
from typing import List, Any, Dict, cast, Optional, Tuple, Set
|
3
|
+
from tinygrad.helpers import GraphException, round_up, to_mv, init_c_struct_t
|
4
|
+
from tinygrad.device import Buffer, BufferOptions, Compiled, Device
|
5
|
+
from tinygrad.shape.symbolic import Variable
|
6
|
+
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
7
|
+
from tinygrad.engine.jit import MultiGraphRunner
|
8
|
+
|
9
|
+
class HCQGraph(MultiGraphRunner):
  """Graph runner that pre-records one compute queue and one copy queue per device and resubmits them on every call."""
  def __init__(self, device_t, comp_hcq_t, copy_hcq_t, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    """Record every jit item into the per-device queues.

    `device_t` is the device class every buffer must live on; `comp_hcq_t` / `copy_hcq_t`
    are the queue classes instantiated for kernels / buffer transfers.

    Raises:
      GraphException: if a buffer of the jit cache lives on a device that is not a `device_t`.
    """
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    self.device_t, self.comp_hcq_t, self.copy_hcq_t = device_t, comp_hcq_t, copy_hcq_t

    # Check all jit items are compatible.
    self.devices = list({cast(self.device_t, Device[cast(Buffer, buf).device]) for ji in jit_cache for buf in ji.bufs}) #type: ignore
    if not all(isinstance(d, self.device_t) for d in self.devices): raise GraphException

    # Allocate kernel args: one allocation per device, big enough for all of its kernels.
    kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
    for ji in self.jit_cache:
      if isinstance(ji.prg, CompiledRunner):
        kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)
    kernargs_ptrs: Dict[Compiled, int] = {}
    for dev, nbytes in kernargs_size.items():
      kernargs_ptrs[dev] = dev.allocator._alloc(nbytes, BufferOptions(cpu_access=True)).va_addr

    # Fill initial arguments.
    self.kargs_addrs: Dict[int, int] = {}
    self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
    for j,ji in enumerate(self.jit_cache):
      if not isinstance(ji.prg, CompiledRunner): continue
      # carve this kernel's 16-byte-rounded slice out of its device's allocation
      self.kargs_addrs[j] = kernargs_ptrs[ji.prg.device]
      kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)

      buf_fields = [(f'f{i}', ctypes.c_void_p) for i in range(len(ji.bufs))]
      var_fields = [(f'v{i}', ctypes.c_int) for i in range(len(ji.prg.p.vars))]
      args_t = init_c_struct_t(tuple(buf_fields + var_fields))
      kargs = args_t.from_address(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset)
      self.ji_kargs_structs[j] = kargs
      for i, buf in enumerate(ji.bufs): setattr(kargs, f'f{i}', cast(Buffer, buf)._buf.va_addr)
      for i, var in enumerate(ji.prg.p.vars): setattr(kargs, f'v{i}', var_vals[var])

      # NV needs constbuffer to be set
      if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)

    # Build queues.
    self.comp_queues: Dict[Compiled, Any] = collections.defaultdict(self.comp_hcq_t)
    self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
    self.comp_signal_val = {dev: 0 for dev in self.devices}

    self.copy_queues: Dict[Compiled, Any] = collections.defaultdict(self.copy_hcq_t)
    self.copy_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
    self.copy_signal_val = {dev: 0 for dev in self.devices}

    self.kickoff_signal = self.devices[0]._get_signal(value=0)
    self.kickoff_value = 0
    self.graph_timeline = {dev: 0 for dev in self.devices}

    self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
    self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}

    for j,ji in enumerate(self.jit_cache):
      if isinstance(ji.prg, CompiledRunner):
        exec_params = {}
        dev, outs, sig_val = ji.prg.device, ji.prg.p.outcount, j+1
        own_signal = self.comp_signal[dev]
        # dependencies on this device's own compute signal are dropped from the list
        deps = [d for d in self.access_resources(ji.bufs[outs:], ji.bufs[:outs], (own_signal, sig_val)) if d[0] is not own_signal]

        # On NV, to synchronize kernel execution, we must either issue a wait or chain executions to schedule them in order.
        # Chaining executions is preferred when possible, as it is faster.
        if dev.dname.startswith("NV"):
          if not deps and self.comp_signal_val[dev] > 0:
            exec_params['chain_exec_ptr'] = self.exec_ptrs[self.comp_signal_val[dev] - 1][1]
          else: deps.append((own_signal, self.comp_signal_val[dev]))

        comp_queue = self.comp_queues[dev]
        for sig, val in deps: comp_queue.wait(sig, val)

        # remember where this exec command sits so its launch dims can be updated later
        self.exec_ptrs[j] = (comp_queue, comp_queue.ptr())
        comp_queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
                        signal=own_signal, signal_value=sig_val, **exec_params)
        self.comp_signal_val[dev] = sig_val
      elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        src_dev = Device[src.device]
        src_dev._gpu_map(dest._buf) #type: ignore

        sig_val = j+1
        deps = self.access_resources([src], [dest], (self.copy_signal[src_dev], sig_val))
        # additionally wait for the value this device's copy signal had after its previous copy
        deps.append((self.copy_signal[src_dev], self.copy_signal_val[src_dev]))
        self.copy_signal_val[src_dev] = sig_val

        for sig,val in deps: self.copy_queues[src_dev].wait(sig, val)
        copied = self.copy_queues[src_dev].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes)
        copied.signal(self.copy_signal[src_dev], sig_val)
        self.copy_to_devs[Device[dest.device]].add(src_dev)

    for dev in self.devices:
      # each compute queue ends with waits on the copies recorded for this device and on copies into it
      if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
      for dep_dev in self.copy_to_devs[dev]: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])

      if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
      if hasattr(self.copy_queues[dev], 'bind') and self.copy_signal_val[dev] > 0: self.copy_queues[dev].bind(dev)

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
    """Patch the kernel arguments and launch dims, then submit the recorded queues on every device.

    Returns the seconds spent waiting for all devices when `wait` is true, otherwise None.
    """
    # Wait and restore signals
    self.kickoff_value += 1
    for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
    for dev in self.devices:
      dev._set_signal(self.comp_signal[dev], 0)
      dev._set_signal(self.copy_signal[dev], 0)
      dev._set_signal(self.kickoff_signal, self.kickoff_value)

    # Update rawbuffers
    for (j,i),input_idx in self.input_replace.items():
      setattr(self.ji_kargs_structs[j], f'f{i}', input_rawbuffers[input_idx]._buf.va_addr)

    # Update var_vals
    for j in self.jc_idx_with_updatable_var_vals:
      kargs = self.ji_kargs_structs[j]
      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
        setattr(kargs, f'v{i}', var_vals[v])

    for j in self.jc_idx_with_updatable_launch_dims:
      queue, cmd_ptr = self.exec_ptrs[j]
      queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))

    for dev in self.devices:
      # Submit sync with world and queues.
      comp_gate = self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1)
      comp_gate.wait(self.kickoff_signal, self.kickoff_value).submit(dev)
      self.comp_queues[dev].submit(dev)

      if self.copy_signal_val[dev] > 0:
        copy_gate = self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1)
        copy_gate.wait(self.kickoff_signal, self.kickoff_value).submit(dev)
        self.copy_queues[dev].submit(dev)

      # Signal the final value
      self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev)
      self.graph_timeline[dev] = dev.timeline_value
      dev.timeline_value += 1

    if not wait: return None
    st = time.perf_counter()
    for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
    return time.perf_counter() - st

  def access_resources(self, read, write, new_dependency):
    """Return the dependencies from `_access_resources` reduced to one (signal, highest value) pair per signal object."""
    newest: Dict[int, Tuple[Any, Any]] = {}
    for sig, val in self._access_resources(read, write, new_dependency):
      key = id(sig)
      # signals are grouped by object identity; only the largest value per signal is kept
      if key not in newest or val > newest[key][1]: newest[key] = (sig, val)
    return list(newest.values())
|