tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (77)
  1. tinygrad/__init__.py +6 -0
  2. tinygrad/codegen/kernel.py +572 -83
  3. tinygrad/codegen/linearizer.py +415 -395
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +183 -0
  6. tinygrad/dtype.py +113 -0
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +76 -55
  14. tinygrad/helpers.py +196 -89
  15. tinygrad/lazy.py +210 -371
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +202 -22
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +112 -32
  20. tinygrad/nn/state.py +136 -39
  21. tinygrad/ops.py +119 -202
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +353 -166
  25. tinygrad/renderer/llvmir.py +150 -138
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +81 -0
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +75 -0
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +24 -77
  43. tinygrad/runtime/ops_cuda.py +175 -89
  44. tinygrad/runtime/ops_disk.py +56 -33
  45. tinygrad/runtime/ops_gpu.py +92 -95
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +39 -60
  48. tinygrad/runtime/ops_metal.py +92 -74
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +86 -254
  53. tinygrad/shape/symbolic.py +166 -141
  54. tinygrad/shape/view.py +296 -0
  55. tinygrad/tensor.py +2619 -448
  56. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. tinygrad-0.9.0.dist-info/METADATA +227 -0
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/codegen/assembly.py +0 -190
  61. tinygrad/codegen/optimizer.py +0 -379
  62. tinygrad/codegen/search.py +0 -72
  63. tinygrad/graph.py +0 -83
  64. tinygrad/jit.py +0 -57
  65. tinygrad/nn/image.py +0 -100
  66. tinygrad/renderer/assembly_arm64.py +0 -169
  67. tinygrad/renderer/assembly_ptx.py +0 -98
  68. tinygrad/renderer/wgsl.py +0 -53
  69. tinygrad/runtime/lib.py +0 -113
  70. tinygrad/runtime/ops_cpu.py +0 -51
  71. tinygrad/runtime/ops_hip.py +0 -82
  72. tinygrad/runtime/ops_shm.py +0 -29
  73. tinygrad/runtime/ops_torch.py +0 -30
  74. tinygrad/runtime/ops_webgpu.py +0 -45
  75. tinygrad-0.7.0.dist-info/METADATA +0 -212
  76. tinygrad-0.7.0.dist-info/RECORD +0 -40
  77. {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/driver/hip_comgr.py
@@ -0,0 +1,47 @@
+ import ctypes
+ import tinygrad.runtime.autogen.comgr as comgr
+
+ def check(status):
+   if status != 0:
+     comgr.amd_comgr_status_string(status, ctypes.byref(status_str := ctypes.POINTER(ctypes.c_char)()))
+     raise RuntimeError(f"comgr fail {status}, {ctypes.string_at(status_str).decode()}")
+
+ def _get_comgr_data(data_set, data_type):
+   check(comgr.amd_comgr_action_data_get_data(data_set, data_type, 0, ctypes.byref(data_exec := comgr.amd_comgr_data_t())))
+   check(comgr.amd_comgr_get_data(data_exec, ctypes.byref(sz := ctypes.c_uint64()), None))
+   check(comgr.amd_comgr_get_data(data_exec, ctypes.byref(sz), (dat := ctypes.create_string_buffer(sz.value))))
+   check(comgr.amd_comgr_release_data(data_exec))
+   return bytes(dat)
+
+ # AMD_COMGR_SAVE_TEMPS=1 AMD_COMGR_REDIRECT_LOGS=stdout AMD_COMGR_EMIT_VERBOSE_LOGS=1
+ def compile_hip(prg:str, arch="gfx1100") -> bytes:
+   check(comgr.amd_comgr_create_action_info(ctypes.byref(action_info := comgr.amd_comgr_action_info_t())))
+   check(comgr.amd_comgr_action_info_set_language(action_info, comgr.AMD_COMGR_LANGUAGE_HIP))
+   check(comgr.amd_comgr_action_info_set_isa_name(action_info, b"amdgcn-amd-amdhsa--" + arch.encode()))
+   check(comgr.amd_comgr_action_info_set_logging(action_info, True))
+
+   check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_src := comgr.amd_comgr_data_set_t())))
+   check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_bc := comgr.amd_comgr_data_set_t())))
+   check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_reloc := comgr.amd_comgr_data_set_t())))
+   check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_exec := comgr.amd_comgr_data_set_t())))
+
+   check(comgr.amd_comgr_create_data(comgr.AMD_COMGR_DATA_KIND_SOURCE, ctypes.byref(data_src := comgr.amd_comgr_data_t())))
+   check(comgr.amd_comgr_set_data(data_src, len(rprg := prg.encode()), rprg))
+   check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))
+
+   check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
+   # -include hiprtc_runtime.h was removed
+   check(comgr.amd_comgr_action_info_set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes".encode())) # noqa: E501
+   status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
+   if status != 0:
+     print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
+     raise RuntimeError("compile failed")
+   check(comgr.amd_comgr_action_info_set_options(action_info, b"-O3 -mllvm -amdgpu-internalize-symbols"))
+   check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, action_info, data_set_bc, data_set_reloc))
+   check(comgr.amd_comgr_action_info_set_options(action_info, b""))
+   check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, action_info, data_set_reloc, data_set_exec))
+   ret = _get_comgr_data(data_set_exec, comgr.AMD_COMGR_DATA_KIND_EXECUTABLE)
+   check(comgr.amd_comgr_release_data(data_src))
+   for x in [data_set_src, data_set_bc, data_set_reloc, data_set_exec]: check(comgr.amd_comgr_destroy_data_set(x))
+   check(comgr.amd_comgr_destroy_action_info(action_info))
+   return ret
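
For a quick sanity check of the new comgr path, a minimal sketch (the kernel source and arch below are illustrative, not from the package):

from tinygrad.runtime.driver.hip_comgr import compile_hip

# compile an empty HIP kernel for gfx1100; the result should be an amdgcn ELF code object
lib = compile_hip('extern "C" __global__ void noop() {}', arch="gfx1100")
assert lib[:4] == b"\x7fELF"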
tinygrad/runtime/driver/hsa.py
@@ -0,0 +1,143 @@
+ import ctypes, collections
+ import tinygrad.runtime.autogen.hsa as hsa
+ from tinygrad.helpers import init_c_var
+
+ def check(status):
+   if status != 0:
+     hsa.hsa_status_string(status, ctypes.byref(status_str := ctypes.POINTER(ctypes.c_char)()))
+     raise RuntimeError(f"HSA Error {status}: {ctypes.string_at(status_str).decode()}")
+
+ # Precalculated AQL info
+ AQL_PACKET_SIZE = ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t)
+ EMPTY_SIGNAL = hsa.hsa_signal_t()
+
+ DISPATCH_KERNEL_SETUP = 3 << hsa.HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS
+ DISPATCH_KERNEL_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
+ DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
+ DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
+ DISPATCH_KERNEL_HEADER |= hsa.HSA_PACKET_TYPE_KERNEL_DISPATCH << hsa.HSA_PACKET_HEADER_TYPE
+
+ BARRIER_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
+ BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
+ BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
+ BARRIER_HEADER |= hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE
+
+ class AQLQueue:
+   def __init__(self, device, sz=-1):
+     self.device = device
+
+     check(hsa.hsa_agent_get_info(self.device.agent, hsa.HSA_AGENT_INFO_QUEUE_MAX_SIZE, ctypes.byref(max_queue_size := ctypes.c_uint32())))
+     queue_size = min(max_queue_size.value, sz) if sz != -1 else max_queue_size.value
+
+     null_func = ctypes.CFUNCTYPE(None, hsa.hsa_status_t, ctypes.POINTER(hsa.struct_hsa_queue_s), ctypes.c_void_p)()
+     self.hw_queue = init_c_var(ctypes.POINTER(hsa.hsa_queue_t)(), lambda x: check(
+       hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None, (1<<32)-1, (1<<32)-1, ctypes.byref(x))))
+
+     self.next_doorbell_index = 0
+     self.queue_base = self.hw_queue.contents.base_address
+     self.queue_size = self.hw_queue.contents.size * AQL_PACKET_SIZE # in bytes
+     self.write_addr = self.queue_base
+     self.write_addr_end = self.queue_base + self.queue_size - 1 # precalc saves some time
+     self.available_packet_slots = self.hw_queue.contents.size
+
+     check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH))
+     check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1))
+
+   def __del__(self):
+     if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue))
+
+   def submit_kernel(self, prg, global_size, local_size, kernargs, completion_signal=None):
+     if self.available_packet_slots == 0: self._wait_queue()
+
+     packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
+     packet.workgroup_size_x = local_size[0]
+     packet.workgroup_size_y = local_size[1]
+     packet.workgroup_size_z = local_size[2]
+     packet.reserved0 = 0
+     packet.grid_size_x = global_size[0] * local_size[0]
+     packet.grid_size_y = global_size[1] * local_size[1]
+     packet.grid_size_z = global_size[2] * local_size[2]
+     packet.private_segment_size = prg.private_segment_size
+     packet.group_segment_size = prg.group_segment_size
+     packet.kernel_object = prg.handle
+     packet.kernarg_address = kernargs
+     packet.reserved2 = 0
+     packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
+     packet.setup = DISPATCH_KERNEL_SETUP
+     packet.header = DISPATCH_KERNEL_HEADER
+     self._submit_packet()
+
+   def submit_barrier(self, wait_signals=None, completion_signal=None):
+     assert wait_signals is None or len(wait_signals) <= 5
+     if self.available_packet_slots == 0: self._wait_queue()
+
+     packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
+     packet.reserved0 = 0
+     packet.reserved1 = 0
+     for i in range(5):
+       packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL
+     packet.reserved2 = 0
+     packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
+     packet.header = BARRIER_HEADER
+     self._submit_packet()
+
+   def blit_packets(self, packet_addr, packet_cnt):
+     if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)
+
+     tail_blit_packets = min((self.queue_base + self.queue_size - self.write_addr) // AQL_PACKET_SIZE, packet_cnt)
+     rem_packet_cnt = packet_cnt - tail_blit_packets
+     ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets)
+     if rem_packet_cnt > 0: ctypes.memmove(self.queue_base, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt)
+
+     self._submit_packet(packet_cnt)
+
+   def wait(self):
+     self.submit_barrier([], finish_signal := self.device.alloc_signal(reusable=True))
+     hsa.hsa_signal_wait_scacquire(finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
+     self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE
+
+   def _wait_queue(self, need_packets=1):
+     while self.available_packet_slots < need_packets:
+       rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
+       self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE - (self.next_doorbell_index - rindex)
+
+   def _submit_packet(self, cnt=1):
+     self.available_packet_slots -= cnt
+     self.next_doorbell_index += cnt
+     hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index)
+     hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index-1)
+
+     self.write_addr += AQL_PACKET_SIZE * cnt
+     if self.write_addr > self.write_addr_end:
+       self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size
+
+ def scan_agents():
+   agents = collections.defaultdict(list)
+
+   @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p)
+   def __scan_agents(agent, data):
+     status = hsa.hsa_agent_get_info(agent, hsa.HSA_AGENT_INFO_DEVICE, ctypes.byref(device_type := hsa.hsa_device_type_t()))
+     if status == 0: agents[device_type.value].append(agent)
+     return hsa.HSA_STATUS_SUCCESS
+
+   hsa.hsa_iterate_agents(__scan_agents, None)
+   return agents
+
+ def find_memory_pool(agent, segtyp=-1, location=-1):
+   @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_amd_memory_pool_t, ctypes.c_void_p)
+   def __filter_amd_memory_pools(mem_pool, data):
+     check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SEGMENT, ctypes.byref(segment := hsa.hsa_amd_segment_t())))
+     if segtyp >= 0 and segment.value != segtyp: return hsa.HSA_STATUS_SUCCESS
+
+     check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_LOCATION, ctypes.byref(loc:=hsa.hsa_amd_memory_pool_location_t())))
+     if location >= 0 and loc.value != location: return hsa.HSA_STATUS_SUCCESS
+
+     check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SIZE, ctypes.byref(sz := ctypes.c_size_t())))
+     if sz.value == 0: return hsa.HSA_STATUS_SUCCESS
+
+     ret = ctypes.cast(data, ctypes.POINTER(hsa.hsa_amd_memory_pool_t))
+     ret[0] = mem_pool
+     return hsa.HSA_STATUS_INFO_BREAK
+
+   hsa.hsa_amd_agent_iterate_memory_pools(agent, __filter_amd_memory_pools, ctypes.byref(region := hsa.hsa_amd_memory_pool_t()))
+   return region
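
scan_agents and find_memory_pool are the discovery helpers the HSA backend builds on. A minimal sketch of how they compose, assuming a working ROCm/HSA install, at least one GPU agent, and that the usual HSA enum names are present in the autogen bindings:

import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool

check(hsa.hsa_init())
agents = scan_agents()  # {device_type: [agents]}
gpu = agents[hsa.HSA_DEVICE_TYPE_GPU][0]
# pick a non-empty, GPU-local, global-segment pool (the kind a VRAM allocator wants)
pool = find_memory_pool(gpu, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_GPU)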
tinygrad/runtime/graph/clang.py
@@ -0,0 +1,38 @@
+ from typing import List, Dict, cast
+ import ctypes
+ from tinygrad.helpers import dedup, cpu_time_execution, GraphException, DEBUG
+ from tinygrad.engine.jit import GraphRunner
+ from tinygrad.device import Buffer, Device
+ from tinygrad.engine.realize import ExecItem, CompiledRunner
+ from tinygrad.shape.symbolic import Variable
+ from tinygrad.runtime.ops_clang import ClangProgram
+ from tinygrad.renderer.cstyle import ClangRenderer
+ render_dtype = ClangRenderer().render_dtype
+
+ class ClangGraph(GraphRunner):
+   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+     super().__init__(jit_cache, input_rawbuffers, var_vals)
+     if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
+
+     prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
+     args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
+     args += [f"int {v.expr}" for v in var_vals]
+     code = ["void batched("+','.join(args)+") {"]
+     for ji in jit_cache:
+       args = []
+       for buf in ji.bufs:
+         assert buf is not None
+         if buf in input_rawbuffers:
+           args.append(f"arg{input_rawbuffers.index(buf)}")
+         else:
+           args.append(f"({render_dtype(buf.dtype)}*)0x{ctypes.addressof(buf._buf):X}")
+       args += [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
+       code.append(f" {cast(CompiledRunner, ji.prg).p.function_name}({','.join(args)});")
+     code.append("}")
+     if DEBUG >= 4: print("\n".join(code))
+     compiler = Device["CLANG"].compiler
+     assert compiler is not None
+     self.clprg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers
+
+   def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False):
+     return cpu_time_execution(lambda: self.clprg(*[x._buf for x in rawbufs], *[x for x in var_vals.values()]), enable=wait)
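
ClangGraph is what lets TinyJit replay a captured CLANG schedule as one compiled C call. A rough usage sketch, assuming CLANG is the active device (e.g. CLANG=1 in the environment):

from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor: return (x * 2 + 1).realize()

x = Tensor.arange(16.0).realize()
for _ in range(3): out = step(x)  # once captured, each call runs the single batched() function
print(out.numpy())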
tinygrad/runtime/graph/cuda.py
@@ -0,0 +1,81 @@
+ import ctypes
+ from typing import Any, Optional, Tuple, Dict, List, cast
+ import tinygrad.runtime.autogen.cuda as cuda
+ from tinygrad.helpers import init_c_var, GraphException
+ from tinygrad.device import Buffer, Device
+ from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
+ from tinygrad.shape.symbolic import Variable
+ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
+ from tinygrad.engine.jit import MultiGraphRunner
+
+ class CUDAGraph(MultiGraphRunner):
+   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+     super().__init__(jit_cache, input_rawbuffers, var_vals)
+
+     # Check all jit items are compatible.
+     if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException
+
+     self.jc_idx_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
+     self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
+
+     self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
+
+     for j,ji in enumerate(self.jit_cache):
+       if isinstance(ji.prg, CompiledRunner):
+         global_size, local_size = ji.prg.p.launch_dims(var_vals)
+
+         new_node = cuda.CUgraphNode()
+         deps = self._access_resources([x.base for x in ji.bufs[ji.prg.p.outcount:] if x is not None],
+                                       [x.base for x in ji.bufs[:ji.prg.p.outcount] if x is not None], new_dependency=new_node)
+         c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
+
+         c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
+         kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
+         check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
+
+         if j in self.jc_idx_with_updatable_launch_dims or j in self.jc_idx_with_updatable_var_vals or j in self.jc_idx_with_updatable_rawbufs:
+           self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
+       elif isinstance(ji.prg, BufferXfer):
+         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
+         src_dev = cast(CUDADevice, Device[src.device])
+         node_from = cuda.CUgraphNode()
+         deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from)
+         c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
+         cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
+                                           dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
+                                           WidthInBytes=dest.nbytes, Height=1, Depth=1)
+         check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
+         if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
+
+     self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
+
+   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+     # Update rawbuffers in the c_args struct.
+     for (j,i),input_idx in self.input_replace.items():
+       if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
+       else:
+         if i == 0: self.updatable_nodes[j][1].destDevice = input_rawbuffers[input_idx]._buf
+         elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf
+
+     # Update var_vals in the c_args struct.
+     for j in self.jc_idx_with_updatable_var_vals:
+       for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
+         setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])
+
+     # Update launch dims in the kern_params struct.
+     for j in self.jc_idx_with_updatable_launch_dims:
+       self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
+
+     # Update graph nodes with the updated structs.
+     for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
+       if not is_copy: check(cuda.cuGraphExecKernelNodeSetParams(self.instance, node, ctypes.byref(c_node_params)))
+       else: check(cuda.cuGraphExecMemcpyNodeSetParams(self.instance, node, ctypes.byref(c_node_params), c_args))
+
+     return cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)
+
+   def __del__(self):
+     if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
+     if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
+
+   def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
+     node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size
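
Note the f'f{i}'/f'v{i}' convention: encode_args lays kernel arguments out as a ctypes struct with pointer fields f0, f1, ... and int fields v0, v1, ..., so replaying with new buffers or variable values is plain setattr on the captured struct. A toy illustration of that pattern (the struct and values here are made up):

import ctypes

class Args(ctypes.Structure):  # stand-in for an encode_args-style struct
  _fields_ = [("f0", ctypes.c_void_p), ("f1", ctypes.c_void_p), ("v0", ctypes.c_int)]

args = Args(0x1000, 0x2000, 42)
setattr(args, "f1", 0x3000)   # what the rawbuffer update loop does
setattr(args, "v0", 7)        # what the var_vals update loop does
print(hex(args.f1), args.v0)  # 0x3000 7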
tinygrad/runtime/graph/hcq.py
@@ -0,0 +1,143 @@
+ import ctypes, collections, array, time
+ from typing import List, Any, Dict, cast, Optional, Tuple, Set
+ from tinygrad.helpers import GraphException, round_up, to_mv, init_c_struct_t
+ from tinygrad.device import Buffer, BufferOptions, Compiled, Device
+ from tinygrad.shape.symbolic import Variable
+ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
+ from tinygrad.engine.jit import MultiGraphRunner
+
+ class HCQGraph(MultiGraphRunner):
+   def __init__(self, device_t, comp_hcq_t, copy_hcq_t, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+     super().__init__(jit_cache, input_rawbuffers, var_vals)
+     self.device_t, self.comp_hcq_t, self.copy_hcq_t = device_t, comp_hcq_t, copy_hcq_t
+
+     # Check all jit items are compatible.
+     self.devices = list(set(cast(self.device_t, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs])) #type: ignore
+     if any(not isinstance(d, self.device_t) for d in self.devices): raise GraphException
+
+     # Allocate kernel args.
+     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
+     for ji in self.jit_cache:
+       if not isinstance(ji.prg, CompiledRunner): continue
+       kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)
+     kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)).va_addr for dev,sz in kernargs_size.items()}
+
+     # Fill initial arguments.
+     self.kargs_addrs: Dict[int, int] = {}
+     self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
+     for j,ji in enumerate(self.jit_cache):
+       if not isinstance(ji.prg, CompiledRunner): continue
+       self.kargs_addrs[j] = kernargs_ptrs[ji.prg.device]
+       kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)
+
+       args_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(ji.bufs))] +
+                                      [(f'v{i}', ctypes.c_int) for i in range(len(ji.prg.p.vars))]))
+       self.ji_kargs_structs[j] = args_t.from_address(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset)
+       for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf.va_addr)
+       for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
+
+       # NV needs constbuffer to be set
+       if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)
+
+     # Build queues.
+     self.comp_queues: Dict[Compiled, Any] = collections.defaultdict(self.comp_hcq_t)
+     self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
+     self.comp_signal_val = {dev: 0 for dev in self.devices}
+
+     self.copy_queues: Dict[Compiled, Any] = collections.defaultdict(self.copy_hcq_t)
+     self.copy_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
+     self.copy_signal_val = {dev: 0 for dev in self.devices}
+
+     self.kickoff_signal = self.devices[0]._get_signal(value=0)
+     self.kickoff_value = 0
+     self.graph_timeline = {dev: 0 for dev in self.devices}
+
+     self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
+     self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
+
+     for j,ji in enumerate(self.jit_cache):
+       if isinstance(ji.prg, CompiledRunner):
+         exec_params = {}
+         deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1))
+         deps = [x for x in deps if id(x[0]) != id(self.comp_signal[ji.prg.device])]
+
+         # On NV, to synchronize kernel execution, we must either issue a wait or chain executions to schedule them in order.
+         # Chaining executions is preferred when possible, as it is faster.
+         if ji.prg.device.dname.startswith("NV"):
+           if len(deps) == 0 and self.comp_signal_val[ji.prg.device] > 0:
+             exec_params['chain_exec_ptr'] = self.exec_ptrs[self.comp_signal_val[ji.prg.device] - 1][1]
+           else: deps.append((self.comp_signal[ji.prg.device], self.comp_signal_val[ji.prg.device]))
+
+         for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
+
+         self.exec_ptrs[j] = (self.comp_queues[ji.prg.device], self.comp_queues[ji.prg.device].ptr())
+         self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
+                                              signal=self.comp_signal[ji.prg.device], signal_value=sig_val, **exec_params)
+         self.comp_signal_val[ji.prg.device] = sig_val
+       elif isinstance(ji.prg, BufferXfer):
+         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
+         Device[src.device]._gpu_map(dest._buf) #type: ignore
+
+         deps = self.access_resources([src], [dest], (self.copy_signal[Device[src.device]], sig_val:=j+1))
+         deps.append((self.copy_signal[Device[src.device]], self.copy_signal_val[Device[src.device]]))
+         self.copy_signal_val[Device[src.device]] = sig_val
+
+         for sig,val in deps: self.copy_queues[Device[src.device]].wait(sig, val)
+         self.copy_queues[Device[src.device]].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) \
+                                             .signal(self.copy_signal[Device[src.device]], sig_val)
+         self.copy_to_devs[Device[dest.device]].add(Device[src.device])
+
+     for dev in self.devices:
+       if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
+       for dep_dev in self.copy_to_devs[dev]: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])
+
+       if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
+       if hasattr(self.copy_queues[dev], 'bind') and self.copy_signal_val[dev] > 0: self.copy_queues[dev].bind(dev)
+
+   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+     # Wait and restore signals
+     self.kickoff_value += 1
+     for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
+     for dev in self.devices:
+       dev._set_signal(self.comp_signal[dev], 0)
+       dev._set_signal(self.copy_signal[dev], 0)
+       dev._set_signal(self.kickoff_signal, self.kickoff_value)
+
+     # Update rawbuffers
+     for (j,i),input_idx in self.input_replace.items():
+       self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf.va_addr)
+
+     # Update var_vals
+     for j in self.jc_idx_with_updatable_var_vals:
+       for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
+         self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
+
+     for j in self.jc_idx_with_updatable_launch_dims:
+       queue, cmd_ptr = self.exec_ptrs[j]
+       queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
+
+     for dev in self.devices:
+       # Submit sync with world and queues.
+       self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
+                        .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
+       self.comp_queues[dev].submit(dev)
+
+       if self.copy_signal_val[dev] > 0:
+         self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
+                          .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
+         self.copy_queues[dev].submit(dev)
+
+       # Signal the final value
+       self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev)
+       self.graph_timeline[dev] = dev.timeline_value
+       dev.timeline_value += 1
+
+     if wait:
+       st = time.perf_counter()
+       for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
+       return time.perf_counter() - st
+     return None
+
+   def access_resources(self, read, write, new_dependency):
+     deps = self._access_resources(read, write, new_dependency)
+     return [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()]
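
The last line of access_resources is dense: it dedups the (signal, value) pairs returned by _access_resources, keeping only the highest wait value per signal object. The same pattern in isolation, with toy stand-ins for signals:

class Sig:  # stand-in for a driver signal
  def __init__(self, name): self.name = name

a, b = Sig("a"), Sig("b")
deps = [(a, 3), (b, 1), (a, 7)]
deduped = [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()]
print([(s.name, v) for s, v in deduped])  # [('a', 7), ('b', 1)]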
tinygrad/runtime/graph/hsa.py
@@ -0,0 +1,171 @@
+ import ctypes, collections, time, itertools
+ from typing import List, Any, Dict, cast, Optional, Tuple
+ from tinygrad.helpers import GraphException, init_c_var, round_up
+ from tinygrad.device import Buffer, BufferOptions
+ from tinygrad.device import Compiled, Device
+ from tinygrad.shape.symbolic import Variable
+ from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
+ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
+ from tinygrad.engine.jit import MultiGraphRunner
+ import tinygrad.runtime.autogen.hsa as hsa
+ from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
+
+ def dedup_signals(signals): return [hsa.hsa_signal_t(hndl) for hndl in set([x.handle for x in signals if isinstance(x, hsa.hsa_signal_t)])]
+
+ class VirtAQLQueue(AQLQueue):
+   def __init__(self, device, sz):
+     self.device = device
+     self.virt_queue = (hsa.hsa_kernel_dispatch_packet_t * sz)()
+     self.queue_base = self.write_addr = ctypes.addressof(self.virt_queue)
+     self.packets_count = 0
+     self.available_packet_slots = sz
+   def _wait_queue(self, need_packets=1): assert False, f"VirtQueue is too small to handle {self.packets_count+need_packets} packets!"
+   def _submit_packet(self):
+     self.write_addr += AQL_PACKET_SIZE
+     self.packets_count += 1
+     self.available_packet_slots -= 1
+
+ class HSAGraph(MultiGraphRunner):
+   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+     super().__init__(jit_cache, input_rawbuffers, var_vals)
+
+     # Check all jit items are compatible.
+     compiled_devices = set()
+     for ji in self.jit_cache:
+       if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
+       elif isinstance(ji.prg, BufferXfer):
+         for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
+       else: raise GraphException
+     if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException
+
+     self.devices: List[HSADevice] = list(compiled_devices) #type:ignore
+
+     # Allocate kernel args.
+     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
+     for ji in self.jit_cache:
+       if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+     kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}
+
+     # Fill initial arguments.
+     self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
+     for j,ji in enumerate(self.jit_cache):
+       if not isinstance(ji.prg, CompiledRunner): continue
+       self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
+       kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+       for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
+       for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
+
+     # Build queues.
+     self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
+     self.packets = {}
+     self.transfers = []
+     self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table.
+     self.signals_to_reset: List[hsa.hsa_signal_t] = []
+     self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
+     self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)
+
+     # Special packet to wait for the world.
+     self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices}
+     for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])
+
+     for j,ji in enumerate(self.jit_cache):
+       if isinstance(ji.prg, CompiledRunner):
+         wait_signals = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], new_dependency=j, sync_with_aql_packets=False)
+         for i in range(0, len(wait_signals), 5):
+           self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
+         self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
+
+         sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
+         self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
+                                                           ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
+         if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
+       elif isinstance(ji.prg, BufferXfer):
+         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
+         dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
+         sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])
+
+         wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
+         self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
+                                (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True])
+         self.ji_to_transfer[j] = len(self.transfers) - 1
+         if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))
+
+     # Wait for all active signals to finish the graph
+     wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
+     for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
+       for dev in self.signals_to_devices[v.handle]:
+         wait_signals_to_finish[dev].append(v)
+
+     self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
+     for dev in self.devices:
+       wait_signals = wait_signals_to_finish[dev]
+       for i in range(0, max(1, len(wait_signals)), 5):
+         self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None)
+
+     # Zero signals to allow graph to start and execute.
+     for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
+     hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
+
+   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+     # Wait and restore signals
+     hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
+     for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
+     hsa.hsa_signal_silent_store_relaxed(self.finish_signal, len(self.devices))
+
+     # Update rawbuffers
+     for (j,i),input_idx in self.input_replace.items():
+       if j in self.ji_kargs_structs:
+         self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
+       else:
+         if i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest
+         elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src
+
+     # Update var_vals
+     for j in self.jc_idx_with_updatable_var_vals:
+       for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
+         self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
+
+     # Update launch dims
+     for j in self.jc_idx_with_updatable_launch_dims:
+       gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
+       self.packets[j].workgroup_size_x = lc[0]
+       self.packets[j].workgroup_size_y = lc[1]
+       self.packets[j].workgroup_size_z = lc[2]
+       self.packets[j].grid_size_x = gl[0] * lc[0]
+       self.packets[j].grid_size_y = gl[1] * lc[1]
+       self.packets[j].grid_size_z = gl[2] * lc[2]
+
+     for dev in self.devices:
+       dev.flush_hdp()
+       dev.hw_queue.blit_packets(self.virt_aql_queues[dev].queue_base, self.virt_aql_queues[dev].packets_count)
+
+     for transfer_data in self.transfers:
+       check(hsa.hsa_amd_memory_async_copy_on_engine(*transfer_data))
+
+     et = None
+     if wait:
+       st = time.perf_counter()
+       hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
+       et = time.perf_counter() - st
+
+     for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
+     return et
+
+   def alloc_signal(self, reset_on_start=False, wait_on=None):
+     sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
+     if reset_on_start: self.signals_to_reset.append(sync_signal)
+     if wait_on is not None: self.signals_to_devices[sync_signal.handle] = wait_on
+     return sync_signal
+
+   def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
+     if isinstance(dep, hsa.hsa_signal_t): return dep
+     elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
+       if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True)
+       return packet.completion_signal
+     return None
+
+   def access_resources(self, read, write, new_dependency, sync_with_aql_packets=False):
+     rdeps = self._access_resources(read, write, new_dependency)
+     wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps]
+     if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in read+write]
+     return dedup_signals(wait_signals)
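
A barrier-AND packet carries at most five dep_signals (hence the assert in AQLQueue.submit_barrier), which is why every wait list above is submitted in chunks of five; schematically:

wait_signals = [f"sig{i}" for i in range(12)]  # stand-ins for hsa_signal_t handles
for i in range(0, len(wait_signals), 5):
  print("submit_barrier", wait_signals[i:i+5])  # three packets: 5 + 5 + 2 deps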