tinygrad-0.7.0-py3-none-any.whl → tinygrad-0.9.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- tinygrad/__init__.py +6 -0
- tinygrad/codegen/kernel.py +572 -83
- tinygrad/codegen/linearizer.py +415 -395
- tinygrad/codegen/uops.py +415 -0
- tinygrad/device.py +183 -0
- tinygrad/dtype.py +113 -0
- tinygrad/engine/__init__.py +0 -0
- tinygrad/engine/graph.py +100 -0
- tinygrad/engine/jit.py +195 -0
- tinygrad/engine/realize.py +191 -0
- tinygrad/engine/schedule.py +362 -0
- tinygrad/engine/search.py +196 -0
- tinygrad/{mlops.py → function.py} +76 -55
- tinygrad/helpers.py +196 -89
- tinygrad/lazy.py +210 -371
- tinygrad/multi.py +169 -0
- tinygrad/nn/__init__.py +202 -22
- tinygrad/nn/datasets.py +7 -0
- tinygrad/nn/optim.py +112 -32
- tinygrad/nn/state.py +136 -39
- tinygrad/ops.py +119 -202
- tinygrad/renderer/__init__.py +61 -0
- tinygrad/renderer/assembly.py +276 -0
- tinygrad/renderer/cstyle.py +353 -166
- tinygrad/renderer/llvmir.py +150 -138
- tinygrad/runtime/autogen/amd_gpu.py +1900 -0
- tinygrad/runtime/autogen/comgr.py +865 -0
- tinygrad/runtime/autogen/cuda.py +5923 -0
- tinygrad/runtime/autogen/hip.py +5909 -0
- tinygrad/runtime/autogen/hsa.py +5761 -0
- tinygrad/runtime/autogen/kfd.py +812 -0
- tinygrad/runtime/autogen/nv_gpu.py +33328 -0
- tinygrad/runtime/autogen/opencl.py +1795 -0
- tinygrad/runtime/driver/hip_comgr.py +47 -0
- tinygrad/runtime/driver/hsa.py +143 -0
- tinygrad/runtime/graph/clang.py +38 -0
- tinygrad/runtime/graph/cuda.py +81 -0
- tinygrad/runtime/graph/hcq.py +143 -0
- tinygrad/runtime/graph/hsa.py +171 -0
- tinygrad/runtime/graph/metal.py +75 -0
- tinygrad/runtime/ops_amd.py +564 -0
- tinygrad/runtime/ops_clang.py +24 -77
- tinygrad/runtime/ops_cuda.py +175 -89
- tinygrad/runtime/ops_disk.py +56 -33
- tinygrad/runtime/ops_gpu.py +92 -95
- tinygrad/runtime/ops_hsa.py +278 -0
- tinygrad/runtime/ops_llvm.py +39 -60
- tinygrad/runtime/ops_metal.py +92 -74
- tinygrad/runtime/ops_npy.py +9 -0
- tinygrad/runtime/ops_nv.py +630 -0
- tinygrad/runtime/ops_python.py +204 -0
- tinygrad/shape/shapetracker.py +86 -254
- tinygrad/shape/symbolic.py +166 -141
- tinygrad/shape/view.py +296 -0
- tinygrad/tensor.py +2619 -448
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
- tinygrad-0.9.0.dist-info/METADATA +227 -0
- tinygrad-0.9.0.dist-info/RECORD +60 -0
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/assembly.py +0 -190
- tinygrad/codegen/optimizer.py +0 -379
- tinygrad/codegen/search.py +0 -72
- tinygrad/graph.py +0 -83
- tinygrad/jit.py +0 -57
- tinygrad/nn/image.py +0 -100
- tinygrad/renderer/assembly_arm64.py +0 -169
- tinygrad/renderer/assembly_ptx.py +0 -98
- tinygrad/renderer/wgsl.py +0 -53
- tinygrad/runtime/lib.py +0 -113
- tinygrad/runtime/ops_cpu.py +0 -51
- tinygrad/runtime/ops_hip.py +0 -82
- tinygrad/runtime/ops_shm.py +0 -29
- tinygrad/runtime/ops_torch.py +0 -30
- tinygrad/runtime/ops_webgpu.py +0 -45
- tinygrad-0.7.0.dist-info/METADATA +0 -212
- tinygrad-0.7.0.dist-info/RECORD +0 -40
- {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
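The listing above is the 0.9.0 restructure in one view: scheduling, JIT, and realization move into the new tinygrad/engine/ package, mlops.py becomes function.py, lazy.py shrinks by half, and the hand-written HIP/CPU/Torch/WebGPU backends give way to autogenerated bindings (runtime/autogen/) plus AMD/NV/HSA runtimes. As orientation, a minimal smoke test of the resulting surface; this assumes the six added lines of tinygrad/__init__.py are top-level re-exports, which is how 0.9.0 behaves:

    # minimal smoke test of the reshaped 0.9.0 surface; which backend
    # Device.DEFAULT resolves to depends on the host
    from tinygrad import Tensor, dtypes, Device

    t = Tensor([1, 2, 3], dtype=dtypes.float32)
    print(Device.DEFAULT)          # e.g. "METAL", "CUDA", "AMD", "CLANG", ...
    print((t * 2).sum().item())    # 12.0, scheduled through tinygrad/engine/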
tinygrad/runtime/driver/hip_comgr.py
@@ -0,0 +1,47 @@
+import ctypes
+import tinygrad.runtime.autogen.comgr as comgr
+
+def check(status):
+  if status != 0:
+    comgr.amd_comgr_status_string(status, ctypes.byref(status_str := ctypes.POINTER(ctypes.c_char)()))
+    raise RuntimeError(f"comgr fail {status}, {ctypes.string_at(status_str).decode()}")
+
+def _get_comgr_data(data_set, data_type):
+  check(comgr.amd_comgr_action_data_get_data(data_set, data_type, 0, ctypes.byref(data_exec := comgr.amd_comgr_data_t())))
+  check(comgr.amd_comgr_get_data(data_exec, ctypes.byref(sz := ctypes.c_uint64()), None))
+  check(comgr.amd_comgr_get_data(data_exec, ctypes.byref(sz), (dat := ctypes.create_string_buffer(sz.value))))
+  check(comgr.amd_comgr_release_data(data_exec))
+  return bytes(dat)
+
+# AMD_COMGR_SAVE_TEMPS=1 AMD_COMGR_REDIRECT_LOGS=stdout AMD_COMGR_EMIT_VERBOSE_LOGS=1
+def compile_hip(prg:str, arch="gfx1100") -> bytes:
+  check(comgr.amd_comgr_create_action_info(ctypes.byref(action_info := comgr.amd_comgr_action_info_t())))
+  check(comgr.amd_comgr_action_info_set_language(action_info, comgr.AMD_COMGR_LANGUAGE_HIP))
+  check(comgr.amd_comgr_action_info_set_isa_name(action_info, b"amdgcn-amd-amdhsa--" + arch.encode()))
+  check(comgr.amd_comgr_action_info_set_logging(action_info, True))
+
+  check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_src := comgr.amd_comgr_data_set_t())))
+  check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_bc := comgr.amd_comgr_data_set_t())))
+  check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_reloc := comgr.amd_comgr_data_set_t())))
+  check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_exec := comgr.amd_comgr_data_set_t())))
+
+  check(comgr.amd_comgr_create_data(comgr.AMD_COMGR_DATA_KIND_SOURCE, ctypes.byref(data_src := comgr.amd_comgr_data_t())))
+  check(comgr.amd_comgr_set_data(data_src, len(rprg := prg.encode()), rprg))
+  check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))
+
+  check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
+  # -include hiprtc_runtime.h was removed
+  check(comgr.amd_comgr_action_info_set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes".encode())) # noqa: E501
+  status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
+  if status != 0:
+    print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
+    raise RuntimeError("compile failed")
+  check(comgr.amd_comgr_action_info_set_options(action_info, b"-O3 -mllvm -amdgpu-internalize-symbols"))
+  check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, action_info, data_set_bc, data_set_reloc))
+  check(comgr.amd_comgr_action_info_set_options(action_info, b""))
+  check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, action_info, data_set_reloc, data_set_exec))
+  ret = _get_comgr_data(data_set_exec, comgr.AMD_COMGR_DATA_KIND_EXECUTABLE)
+  check(comgr.amd_comgr_release_data(data_src))
+  for x in [data_set_src, data_set_bc, data_set_reloc, data_set_exec]: check(comgr.amd_comgr_destroy_data_set(x))
+  check(comgr.amd_comgr_destroy_action_info(action_info))
+  return ret
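compile_hip is this module's entire surface: HIP source in, linked code-object bytes out, with comgr driving the source → bitcode → relocatable → executable pipeline in-process instead of shelling out to hipcc. A usage sketch, assuming a host where ROCm's libamd_comgr is loadable and arch matches the target GPU (gfx1100 is RDNA3):

    # usage sketch only; errors in any stage surface through check()
    from tinygrad.runtime.driver.hip_comgr import compile_hip

    src = 'extern "C" __global__ void add1(float* a) { a[0] += 1.0f; }'
    lib = compile_hip(src, arch="gfx1100")  # bytes of the linked code object
    print(f"compiled {len(lib)} bytes")     # ready for the HSA loader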
tinygrad/runtime/driver/hsa.py
@@ -0,0 +1,143 @@
+import ctypes, collections
+import tinygrad.runtime.autogen.hsa as hsa
+from tinygrad.helpers import init_c_var
+
+def check(status):
+  if status != 0:
+    hsa.hsa_status_string(status, ctypes.byref(status_str := ctypes.POINTER(ctypes.c_char)()))
+    raise RuntimeError(f"HSA Error {status}: {ctypes.string_at(status_str).decode()}")
+
+# Precalculated AQL info
+AQL_PACKET_SIZE = ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t)
+EMPTY_SIGNAL = hsa.hsa_signal_t()
+
+DISPATCH_KERNEL_SETUP = 3 << hsa.HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS
+DISPATCH_KERNEL_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
+DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
+DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
+DISPATCH_KERNEL_HEADER |= hsa.HSA_PACKET_TYPE_KERNEL_DISPATCH << hsa.HSA_PACKET_HEADER_TYPE
+
+BARRIER_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
+BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
+BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
+BARRIER_HEADER |= hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE
+
+class AQLQueue:
+  def __init__(self, device, sz=-1):
+    self.device = device
+
+    check(hsa.hsa_agent_get_info(self.device.agent, hsa.HSA_AGENT_INFO_QUEUE_MAX_SIZE, ctypes.byref(max_queue_size := ctypes.c_uint32())))
+    queue_size = min(max_queue_size.value, sz) if sz != -1 else max_queue_size.value
+
+    null_func = ctypes.CFUNCTYPE(None, hsa.hsa_status_t, ctypes.POINTER(hsa.struct_hsa_queue_s), ctypes.c_void_p)()
+    self.hw_queue = init_c_var(ctypes.POINTER(hsa.hsa_queue_t)(), lambda x: check(
+      hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None, (1<<32)-1, (1<<32)-1, ctypes.byref(x))))
+
+    self.next_doorbell_index = 0
+    self.queue_base = self.hw_queue.contents.base_address
+    self.queue_size = self.hw_queue.contents.size * AQL_PACKET_SIZE # in bytes
+    self.write_addr = self.queue_base
+    self.write_addr_end = self.queue_base + self.queue_size - 1 # precalc saves some time
+    self.available_packet_slots = self.hw_queue.contents.size
+
+    check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH))
+    check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1))
+
+  def __del__(self):
+    if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue))
+
+  def submit_kernel(self, prg, global_size, local_size, kernargs, completion_signal=None):
+    if self.available_packet_slots == 0: self._wait_queue()
+
+    packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
+    packet.workgroup_size_x = local_size[0]
+    packet.workgroup_size_y = local_size[1]
+    packet.workgroup_size_z = local_size[2]
+    packet.reserved0 = 0
+    packet.grid_size_x = global_size[0] * local_size[0]
+    packet.grid_size_y = global_size[1] * local_size[1]
+    packet.grid_size_z = global_size[2] * local_size[2]
+    packet.private_segment_size = prg.private_segment_size
+    packet.group_segment_size = prg.group_segment_size
+    packet.kernel_object = prg.handle
+    packet.kernarg_address = kernargs
+    packet.reserved2 = 0
+    packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
+    packet.setup = DISPATCH_KERNEL_SETUP
+    packet.header = DISPATCH_KERNEL_HEADER
+    self._submit_packet()
+
+  def submit_barrier(self, wait_signals=None, completion_signal=None):
+    assert wait_signals is None or len(wait_signals) <= 5
+    if self.available_packet_slots == 0: self._wait_queue()
+
+    packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
+    packet.reserved0 = 0
+    packet.reserved1 = 0
+    for i in range(5):
+      packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL
+    packet.reserved2 = 0
+    packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
+    packet.header = BARRIER_HEADER
+    self._submit_packet()
+
+  def blit_packets(self, packet_addr, packet_cnt):
+    if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)
+
+    tail_blit_packets = min((self.queue_base + self.queue_size - self.write_addr) // AQL_PACKET_SIZE, packet_cnt)
+    rem_packet_cnt = packet_cnt - tail_blit_packets
+    ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets)
+    if rem_packet_cnt > 0: ctypes.memmove(self.queue_base, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt)
+
+    self._submit_packet(packet_cnt)
+
+  def wait(self):
+    self.submit_barrier([], finish_signal := self.device.alloc_signal(reusable=True))
+    hsa.hsa_signal_wait_scacquire(finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
+    self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE
+
+  def _wait_queue(self, need_packets=1):
+    while self.available_packet_slots < need_packets:
+      rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
+      self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE - (self.next_doorbell_index - rindex)
+
+  def _submit_packet(self, cnt=1):
+    self.available_packet_slots -= cnt
+    self.next_doorbell_index += cnt
+    hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index)
+    hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index-1)
+
+    self.write_addr += AQL_PACKET_SIZE * cnt
+    if self.write_addr > self.write_addr_end:
+      self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size
+
+def scan_agents():
+  agents = collections.defaultdict(list)
+
+  @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p)
+  def __scan_agents(agent, data):
+    status = hsa.hsa_agent_get_info(agent, hsa.HSA_AGENT_INFO_DEVICE, ctypes.byref(device_type := hsa.hsa_device_type_t()))
+    if status == 0: agents[device_type.value].append(agent)
+    return hsa.HSA_STATUS_SUCCESS
+
+  hsa.hsa_iterate_agents(__scan_agents, None)
+  return agents
+
+def find_memory_pool(agent, segtyp=-1, location=-1):
+  @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_amd_memory_pool_t, ctypes.c_void_p)
+  def __filter_amd_memory_pools(mem_pool, data):
+    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SEGMENT, ctypes.byref(segment := hsa.hsa_amd_segment_t())))
+    if segtyp >= 0 and segment.value != segtyp: return hsa.HSA_STATUS_SUCCESS
+
+    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_LOCATION, ctypes.byref(loc:=hsa.hsa_amd_memory_pool_location_t())))
+    if location >= 0 and loc.value != location: return hsa.HSA_STATUS_SUCCESS
+
+    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SIZE, ctypes.byref(sz := ctypes.c_size_t())))
+    if sz.value == 0: return hsa.HSA_STATUS_SUCCESS
+
+    ret = ctypes.cast(data, ctypes.POINTER(hsa.hsa_amd_memory_pool_t))
+    ret[0] = mem_pool
+    return hsa.HSA_STATUS_INFO_BREAK
+
+  hsa.hsa_amd_agent_iterate_memory_pools(agent, __filter_amd_memory_pools, ctypes.byref(region := hsa.hsa_amd_memory_pool_t()))
+  return region
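The bookkeeping in _submit_packet and _wait_queue is the subtle part of AQLQueue: write_addr walks the ring in bytes, the doorbell is rung with a packet index, and free slots are only recomputed from the HSA read pointer when the ring looks full. A standalone model of that wraparound arithmetic (plain Python, no HSA required, purely illustrative):

    # toy model of the ring arithmetic AQLQueue maintains; real AQL packets
    # are 64 bytes, and the ring wraps modulo queue_size
    AQL_PACKET_SIZE, slots = 64, 4
    queue_base = 0x1000
    queue_size = slots * AQL_PACKET_SIZE
    write_addr, write_addr_end = queue_base, queue_base + queue_size - 1

    for doorbell in range(6):              # submit 6 packets into a 4-slot ring
      write_addr += AQL_PACKET_SIZE
      if write_addr > write_addr_end:      # same wraparound check as above
        write_addr = queue_base + (write_addr - queue_base) % queue_size
      print(doorbell, hex(write_addr))     # address wraps back to 0x1000 after slot 3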
tinygrad/runtime/graph/clang.py
@@ -0,0 +1,38 @@
+from typing import List, Dict, cast
+import ctypes
+from tinygrad.helpers import dedup, cpu_time_execution, GraphException, DEBUG
+from tinygrad.engine.jit import GraphRunner
+from tinygrad.device import Buffer, Device
+from tinygrad.engine.realize import ExecItem, CompiledRunner
+from tinygrad.shape.symbolic import Variable
+from tinygrad.runtime.ops_clang import ClangProgram
+from tinygrad.renderer.cstyle import ClangRenderer
+render_dtype = ClangRenderer().render_dtype
+
+class ClangGraph(GraphRunner):
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+    super().__init__(jit_cache, input_rawbuffers, var_vals)
+    if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
+
+    prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
+    args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
+    args += [f"int {v.expr}" for v in var_vals]
+    code = ["void batched("+','.join(args)+") {"]
+    for ji in jit_cache:
+      args = []
+      for buf in ji.bufs:
+        assert buf is not None
+        if buf in input_rawbuffers:
+          args.append(f"arg{input_rawbuffers.index(buf)}")
+        else:
+          args.append(f"({render_dtype(buf.dtype)}*)0x{ctypes.addressof(buf._buf):X}")
+      args += [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
+      code.append(f" {cast(CompiledRunner, ji.prg).p.function_name}({','.join(args)});")
+    code.append("}")
+    if DEBUG >= 4: print("\n".join(code))
+    compiler = Device["CLANG"].compiler
+    assert compiler is not None
+    self.clprg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers
+
+  def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False):
+    return cpu_time_execution(lambda: self.clprg(*[x._buf for x in rawbufs], *[x for x in var_vals.values()]), enable=wait)
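ClangGraph batches an entire jit cache into one C translation unit: every kernel's source is deduped, then a single batched entry point calls them in order, with JIT input buffers passed as arguments and every other buffer's address baked in as a pointer literal. With DEBUG >= 4 the class prints that wrapper before compiling it; for a two-kernel cache over one input buffer and one symbolic variable the printout would look roughly like this (E_16, r_16, start_pos, and the intermediate-buffer address are all hypothetical, made up for illustration):

    void batched(float* arg0,int start_pos) {
     E_16((float*)0x7F5C2A800000,arg0);
     r_16(arg0,(float*)0x7F5C2A800000,start_pos);
    }

Baking the non-input pointers in as literals is why the comment says there is no point caching them: the batched object code is only valid for this exact set of intermediate buffers.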
tinygrad/runtime/graph/cuda.py
@@ -0,0 +1,81 @@
+import ctypes
+from typing import Any, Optional, Tuple, Dict, List, cast
+import tinygrad.runtime.autogen.cuda as cuda
+from tinygrad.helpers import init_c_var, GraphException
+from tinygrad.device import Buffer, Device
+from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
+from tinygrad.shape.symbolic import Variable
+from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
+from tinygrad.engine.jit import MultiGraphRunner
+
+class CUDAGraph(MultiGraphRunner):
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+    super().__init__(jit_cache, input_rawbuffers, var_vals)
+
+    # Check all jit items are compatible.
+    if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException
+
+    self.jc_idx_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
+    self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
+
+    self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
+
+    for j,ji in enumerate(self.jit_cache):
+      if isinstance(ji.prg, CompiledRunner):
+        global_size, local_size = ji.prg.p.launch_dims(var_vals)
+
+        new_node = cuda.CUgraphNode()
+        deps = self._access_resources([x.base for x in ji.bufs[ji.prg.p.outcount:] if x is not None],
+                                      [x.base for x in ji.bufs[:ji.prg.p.outcount] if x is not None], new_dependency=new_node)
+        c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
+
+        c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
+        kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
+        check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
+
+        if j in self.jc_idx_with_updatable_launch_dims or j in self.jc_idx_with_updatable_var_vals or j in self.jc_idx_with_updatable_rawbufs:
+          self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
+      elif isinstance(ji.prg, BufferXfer):
+        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
+        src_dev = cast(CUDADevice, Device[src.device])
+        node_from = cuda.CUgraphNode()
+        deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from)
+        c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
+        cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
+                                          dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
+                                          WidthInBytes=dest.nbytes, Height=1, Depth=1)
+        check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
+        if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
+
+    self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
+
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+    # Update rawbuffers in the c_args struct.
+    for (j,i),input_idx in self.input_replace.items():
+      if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
+      else:
+        if i == 0: self.updatable_nodes[j][1].dstDevice = input_rawbuffers[input_idx]._buf
+        elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf
+
+    # Update var_vals in the c_args struct.
+    for j in self.jc_idx_with_updatable_var_vals:
+      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
+        setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])
+
+    # Update launch dims in the kern_params struct.
+    for j in self.jc_idx_with_updatable_launch_dims:
+      self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
+
+    # Update graph nodes with the updated structs.
+    for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
+      if not is_copy: check(cuda.cuGraphExecKernelNodeSetParams(self.instance, node, ctypes.byref(c_node_params)))
+      else: check(cuda.cuGraphExecMemcpyNodeSetParams(self.instance, node, ctypes.byref(c_node_params), c_args))
+
+    return cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)
+
+  def __del__(self):
+    if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
+    if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
+
+  def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
+    node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size
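CUDAGraph is not constructed by hand: the TinyJit machinery instantiates it once a jit cache has been captured, and from then on __call__ only patches pointers, var_vals, and launch dims into the prebuilt cuGraphExec before relaunching it. A hedged end-to-end sketch that would exercise this path, relying on tinygrad's CUDA=1 env-var device selection (whether graph batching actually engages also depends on the jit cache contents):

    # sketch: call 1 runs eagerly, call 2 captures the jit cache, call 3+
    # replay through CUDAGraph.__call__ (patching + cuGraphLaunch)
    import os; os.environ["CUDA"] = "1"   # must be set before importing tinygrad
    from tinygrad import Tensor, TinyJit

    @TinyJit
    def f(a: Tensor, b: Tensor) -> Tensor: return (a * b + 1).realize()

    x, y = Tensor.rand(64), Tensor.rand(64)
    for _ in range(4): out = f(x, y)      # later iterations hit the graph path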
tinygrad/runtime/graph/hcq.py
@@ -0,0 +1,143 @@
+import ctypes, collections, array, time
+from typing import List, Any, Dict, cast, Optional, Tuple, Set
+from tinygrad.helpers import GraphException, round_up, to_mv, init_c_struct_t
+from tinygrad.device import Buffer, BufferOptions, Compiled, Device
+from tinygrad.shape.symbolic import Variable
+from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
+from tinygrad.engine.jit import MultiGraphRunner
+
+class HCQGraph(MultiGraphRunner):
+  def __init__(self, device_t, comp_hcq_t, copy_hcq_t, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+    super().__init__(jit_cache, input_rawbuffers, var_vals)
+    self.device_t, self.comp_hcq_t, self.copy_hcq_t = device_t, comp_hcq_t, copy_hcq_t
+
+    # Check all jit items are compatible.
+    self.devices = list(set(cast(self.device_t, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs])) #type: ignore
+    if any(not isinstance(d, self.device_t) for d in self.devices): raise GraphException
+
+    # Allocate kernel args.
+    kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
+    for ji in self.jit_cache:
+      if not isinstance(ji.prg, CompiledRunner): continue
+      kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)
+    kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)).va_addr for dev,sz in kernargs_size.items()}
+
+    # Fill initial arguments.
+    self.kargs_addrs: Dict[int, int] = {}
+    self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
+    for j,ji in enumerate(self.jit_cache):
+      if not isinstance(ji.prg, CompiledRunner): continue
+      self.kargs_addrs[j] = kernargs_ptrs[ji.prg.device]
+      kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)
+
+      args_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(ji.bufs))] +
+                                     [(f'v{i}', ctypes.c_int) for i in range(len(ji.prg.p.vars))]))
+      self.ji_kargs_structs[j] = args_t.from_address(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset)
+      for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf.va_addr)
+      for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
+
+      # NV needs constbuffer to be set
+      if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)
+
+    # Build queues.
+    self.comp_queues: Dict[Compiled, Any] = collections.defaultdict(self.comp_hcq_t)
+    self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
+    self.comp_signal_val = {dev: 0 for dev in self.devices}
+
+    self.copy_queues: Dict[Compiled, Any] = collections.defaultdict(self.copy_hcq_t)
+    self.copy_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
+    self.copy_signal_val = {dev: 0 for dev in self.devices}
+
+    self.kickoff_signal = self.devices[0]._get_signal(value=0)
+    self.kickoff_value = 0
+    self.graph_timeline = {dev: 0 for dev in self.devices}
+
+    self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
+    self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
+
+    for j,ji in enumerate(self.jit_cache):
+      if isinstance(ji.prg, CompiledRunner):
+        exec_params = {}
+        deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1))
+        deps = [x for x in deps if id(x[0]) != id(self.comp_signal[ji.prg.device])]
+
+        # On NV, to synchronize kernel execution, we must either issue a wait or chain executions to schedule them in order.
+        # Chaining executions is preferred when possible, as it is faster.
+        if ji.prg.device.dname.startswith("NV"):
+          if len(deps) == 0 and self.comp_signal_val[ji.prg.device] > 0:
+            exec_params['chain_exec_ptr'] = self.exec_ptrs[self.comp_signal_val[ji.prg.device] - 1][1]
+          else: deps.append((self.comp_signal[ji.prg.device], self.comp_signal_val[ji.prg.device]))
+
+        for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
+
+        self.exec_ptrs[j] = (self.comp_queues[ji.prg.device], self.comp_queues[ji.prg.device].ptr())
+        self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
+                                             signal=self.comp_signal[ji.prg.device], signal_value=sig_val, **exec_params)
+        self.comp_signal_val[ji.prg.device] = sig_val
+      elif isinstance(ji.prg, BufferXfer):
+        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
+        Device[src.device]._gpu_map(dest._buf) #type: ignore
+
+        deps = self.access_resources([src], [dest], (self.copy_signal[Device[src.device]], sig_val:=j+1))
+        deps.append((self.copy_signal[Device[src.device]], self.copy_signal_val[Device[src.device]]))
+        self.copy_signal_val[Device[src.device]] = sig_val
+
+        for sig,val in deps: self.copy_queues[Device[src.device]].wait(sig, val)
+        self.copy_queues[Device[src.device]].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) \
+                                            .signal(self.copy_signal[Device[src.device]], sig_val)
+        self.copy_to_devs[Device[dest.device]].add(Device[src.device])
+
+    for dev in self.devices:
+      if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
+      for dep_dev in self.copy_to_devs[dev]: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])
+
+      if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
+      if hasattr(self.copy_queues[dev], 'bind') and self.copy_signal_val[dev] > 0: self.copy_queues[dev].bind(dev)
+
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+    # Wait and restore signals
+    self.kickoff_value += 1
+    for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
+    for dev in self.devices:
+      dev._set_signal(self.comp_signal[dev], 0)
+      dev._set_signal(self.copy_signal[dev], 0)
+      dev._set_signal(self.kickoff_signal, self.kickoff_value)
+
+    # Update rawbuffers
+    for (j,i),input_idx in self.input_replace.items():
+      self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf.va_addr)
+
+    # Update var_vals
+    for j in self.jc_idx_with_updatable_var_vals:
+      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
+        self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
+
+    for j in self.jc_idx_with_updatable_launch_dims:
+      queue, cmd_ptr = self.exec_ptrs[j]
+      queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
+
+    for dev in self.devices:
+      # Submit sync with world and queues.
+      self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
+                       .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
+      self.comp_queues[dev].submit(dev)
+
+      if self.copy_signal_val[dev] > 0:
+        self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
+                         .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
+        self.copy_queues[dev].submit(dev)
+
+      # Signal the final value
+      self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev)
+      self.graph_timeline[dev] = dev.timeline_value
+      dev.timeline_value += 1
+
+    if wait:
+      st = time.perf_counter()
+      for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
+      return time.perf_counter() - st
+    return None
+
+  def access_resources(self, read, write, new_dependency):
+    deps = self._access_resources(read, write, new_dependency)
+    return [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()]
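The replay protocol above leans on two kinds of signals: per-device comp/copy signals that count up to j+1 within a single replay, and a monotonically increasing timeline signal that orders whole replays against everything else on the device. A toy model of just the timeline handshake, plain Python with no GPU, to make the counter discipline visible (real signals live in device memory):

    # toy model of the timeline-signal handshake in HCQGraph.__call__
    class Dev:
      def __init__(self): self.timeline_signal, self.timeline_value = 0, 1
      def signal(self, v): self.timeline_signal = v          # device-side completion
      def wait(self, v): assert self.timeline_signal >= v    # host/queue-side gate

    dev, graph_timeline = Dev(), 0
    for replay in range(3):
      dev.wait(graph_timeline)            # _wait_signal: previous replay finished
      # ... queues submitted here are gated on (timeline_value - 1) ...
      dev.signal(dev.timeline_value)      # final packet signals this replay
      graph_timeline = dev.timeline_value
      dev.timeline_value += 1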
tinygrad/runtime/graph/hsa.py
@@ -0,0 +1,171 @@
+import ctypes, collections, time, itertools
+from typing import List, Any, Dict, cast, Optional, Tuple
+from tinygrad.helpers import GraphException, init_c_var, round_up
+from tinygrad.device import Buffer, BufferOptions
+from tinygrad.device import Compiled, Device
+from tinygrad.shape.symbolic import Variable
+from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
+from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
+from tinygrad.engine.jit import MultiGraphRunner
+import tinygrad.runtime.autogen.hsa as hsa
+from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
+
+def dedup_signals(signals): return [hsa.hsa_signal_t(hndl) for hndl in set([x.handle for x in signals if isinstance(x, hsa.hsa_signal_t)])]
+
+class VirtAQLQueue(AQLQueue):
+  def __init__(self, device, sz):
+    self.device = device
+    self.virt_queue = (hsa.hsa_kernel_dispatch_packet_t * sz)()
+    self.queue_base = self.write_addr = ctypes.addressof(self.virt_queue)
+    self.packets_count = 0
+    self.available_packet_slots = sz
+  def _wait_queue(self, need_packets=1): assert False, f"VirtQueue is too small to handle {self.packets_count+need_packets} packets!"
+  def _submit_packet(self):
+    self.write_addr += AQL_PACKET_SIZE
+    self.packets_count += 1
+    self.available_packet_slots -= 1
+
+class HSAGraph(MultiGraphRunner):
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+    super().__init__(jit_cache, input_rawbuffers, var_vals)
+
+    # Check all jit items are compatible.
+    compiled_devices = set()
+    for ji in self.jit_cache:
+      if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
+      elif isinstance(ji.prg, BufferXfer):
+        for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
+      else: raise GraphException
+    if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException
+
+    self.devices: List[HSADevice] = list(compiled_devices) #type:ignore
+
+    # Allocate kernel args.
+    kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
+    for ji in self.jit_cache:
+      if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+    kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}
+
+    # Fill initial arguments.
+    self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
+    for j,ji in enumerate(self.jit_cache):
+      if not isinstance(ji.prg, CompiledRunner): continue
+      self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
+      kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+      for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
+      for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
+
+    # Build queues.
+    self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
+    self.packets = {}
+    self.transfers = []
+    self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table.
+    self.signals_to_reset: List[hsa.hsa_signal_t] = []
+    self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
+    self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)
+
+    # Special packet to wait for the world.
+    self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices}
+    for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])
+
+    for j,ji in enumerate(self.jit_cache):
+      if isinstance(ji.prg, CompiledRunner):
+        wait_signals = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], new_dependency=j, sync_with_aql_packets=False)
+        for i in range(0, len(wait_signals), 5):
+          self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
+        self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
+
+        sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
+        self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
+                                                          ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
+        if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
+      elif isinstance(ji.prg, BufferXfer):
+        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
+        dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
+        sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])
+
+        wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
+        self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
+                               (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True])
+        self.ji_to_transfer[j] = len(self.transfers) - 1
+        if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))
+
+    # Wait for all active signals to finish the graph
+    wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
+    for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
+      for dev in self.signals_to_devices[v.handle]:
+        wait_signals_to_finish[dev].append(v)
+
+    self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
+    for dev in self.devices:
+      wait_signals = wait_signals_to_finish[dev]
+      for i in range(0, max(1, len(wait_signals)), 5):
+        self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None)
+
+    # Zero signals to allow graph to start and execute.
+    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
+    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
+
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+    # Wait and restore signals
+    hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
+    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
+    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, len(self.devices))
+
+    # Update rawbuffers
+    for (j,i),input_idx in self.input_replace.items():
+      if j in self.ji_kargs_structs:
+        self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
+      else:
+        if i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest
+        elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src
+
+    # Update var_vals
+    for j in self.jc_idx_with_updatable_var_vals:
+      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
+        self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
+
+    # Update launch dims
+    for j in self.jc_idx_with_updatable_launch_dims:
+      gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
+      self.packets[j].workgroup_size_x = lc[0]
+      self.packets[j].workgroup_size_y = lc[1]
+      self.packets[j].workgroup_size_z = lc[2]
+      self.packets[j].grid_size_x = gl[0] * lc[0]
+      self.packets[j].grid_size_y = gl[1] * lc[1]
+      self.packets[j].grid_size_z = gl[2] * lc[2]
+
+    for dev in self.devices:
+      dev.flush_hdp()
+      dev.hw_queue.blit_packets(self.virt_aql_queues[dev].queue_base, self.virt_aql_queues[dev].packets_count)
+
+    for transfer_data in self.transfers:
+      check(hsa.hsa_amd_memory_async_copy_on_engine(*transfer_data))
+
+    et = None
+    if wait:
+      st = time.perf_counter()
+      hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
+      et = time.perf_counter() - st
+
+    for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
+    return et
+
+  def alloc_signal(self, reset_on_start=False, wait_on=None):
+    sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
+    if reset_on_start: self.signals_to_reset.append(sync_signal)
+    if wait_on is not None: self.signals_to_devices[sync_signal.handle] = wait_on
+    return sync_signal
+
+  def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
+    if isinstance(dep, hsa.hsa_signal_t): return dep
+    elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
+      if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True)
+      return packet.completion_signal
+    return None
+
+  def access_resources(self, read, write, new_dependency, sync_with_aql_packets=False):
+    rdeps = self._access_resources(read, write, new_dependency)
+    wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps]
+    if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in read+write]
+    return dedup_signals(wait_signals)
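One HSA constraint shapes both barrier loops in __init__ above: a barrier-AND packet carries at most five dep_signal slots, so any longer wait list must be split into a chain of barriers, and the finish barrier attaches its completion signal only to the last chunk. The chunking in isolation, as a runnable illustration:

    # stand-alone illustration of the 5-signal barrier-AND chunking
    wait_signals = [f"sig{i}" for i in range(12)]   # pretend signal handles
    for i in range(0, max(1, len(wait_signals)), 5):
      chunk, is_last = wait_signals[i:i+5], i + 5 >= len(wait_signals)
      print(chunk, "<- completion signal here" if is_last else "")

The max(1, ...) matters: a device with no outstanding signals still gets one empty barrier, so the finish signal is always decremented exactly once per device.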