tinygrad 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
Files changed (71)
  1. tinygrad/__init__.py +6 -6
  2. tinygrad/codegen/kernel.py +230 -190
  3. tinygrad/codegen/linearizer.py +278 -384
  4. tinygrad/codegen/uops.py +415 -0
  5. tinygrad/device.py +132 -275
  6. tinygrad/dtype.py +53 -37
  7. tinygrad/engine/__init__.py +0 -0
  8. tinygrad/engine/graph.py +100 -0
  9. tinygrad/engine/jit.py +195 -0
  10. tinygrad/engine/realize.py +191 -0
  11. tinygrad/engine/schedule.py +362 -0
  12. tinygrad/engine/search.py +196 -0
  13. tinygrad/{mlops.py → function.py} +28 -14
  14. tinygrad/helpers.py +72 -43
  15. tinygrad/lazy.py +141 -240
  16. tinygrad/multi.py +169 -0
  17. tinygrad/nn/__init__.py +179 -8
  18. tinygrad/nn/datasets.py +7 -0
  19. tinygrad/nn/optim.py +106 -28
  20. tinygrad/nn/state.py +86 -17
  21. tinygrad/ops.py +70 -44
  22. tinygrad/renderer/__init__.py +61 -0
  23. tinygrad/renderer/assembly.py +276 -0
  24. tinygrad/renderer/cstyle.py +299 -206
  25. tinygrad/renderer/llvmir.py +118 -123
  26. tinygrad/runtime/autogen/amd_gpu.py +1900 -0
  27. tinygrad/runtime/autogen/comgr.py +865 -0
  28. tinygrad/runtime/autogen/cuda.py +5923 -0
  29. tinygrad/runtime/autogen/hip.py +5909 -0
  30. tinygrad/runtime/autogen/hsa.py +5761 -0
  31. tinygrad/runtime/autogen/kfd.py +812 -0
  32. tinygrad/runtime/autogen/nv_gpu.py +33328 -0
  33. tinygrad/runtime/autogen/opencl.py +1795 -0
  34. tinygrad/runtime/driver/hip_comgr.py +47 -0
  35. tinygrad/runtime/driver/hsa.py +143 -0
  36. tinygrad/runtime/graph/clang.py +38 -0
  37. tinygrad/runtime/graph/cuda.py +59 -54
  38. tinygrad/runtime/graph/hcq.py +143 -0
  39. tinygrad/runtime/graph/hsa.py +171 -0
  40. tinygrad/runtime/graph/metal.py +37 -41
  41. tinygrad/runtime/ops_amd.py +564 -0
  42. tinygrad/runtime/ops_clang.py +16 -14
  43. tinygrad/runtime/ops_cuda.py +130 -38
  44. tinygrad/runtime/ops_disk.py +45 -42
  45. tinygrad/runtime/ops_gpu.py +52 -50
  46. tinygrad/runtime/ops_hsa.py +278 -0
  47. tinygrad/runtime/ops_llvm.py +36 -56
  48. tinygrad/runtime/ops_metal.py +42 -24
  49. tinygrad/runtime/ops_npy.py +9 -0
  50. tinygrad/runtime/ops_nv.py +630 -0
  51. tinygrad/runtime/ops_python.py +204 -0
  52. tinygrad/shape/shapetracker.py +41 -105
  53. tinygrad/shape/symbolic.py +98 -95
  54. tinygrad/shape/view.py +137 -35
  55. tinygrad/tensor.py +2367 -442
  56. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
  57. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/METADATA +19 -9
  58. tinygrad-0.9.0.dist-info/RECORD +60 -0
  59. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
  60. tinygrad/features/image.py +0 -93
  61. tinygrad/features/multi.py +0 -103
  62. tinygrad/features/search.py +0 -160
  63. tinygrad/graph.py +0 -106
  64. tinygrad/jit.py +0 -152
  65. tinygrad/realize.py +0 -50
  66. tinygrad/runtime/graph/hip.py +0 -24
  67. tinygrad/runtime/ops_cpu.py +0 -45
  68. tinygrad/runtime/ops_hip.py +0 -97
  69. tinygrad/runtime/ops_torch.py +0 -49
  70. tinygrad-0.8.0.dist-info/RECORD +0 -41
  71. {tinygrad-0.8.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
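The file list shows the shape of the 0.9.0 reorganization: the JIT, scheduler, realizer, and search moved under tinygrad/engine/, mlops.py became function.py, the old gpuctypes-backed runtimes (ops_hip.py, ops_cpu.py, ops_torch.py) are gone in favor of autogenerated bindings under tinygrad/runtime/autogen/ (a small ops_npy.py remains), and new userspace drivers (ops_amd.py, ops_nv.py) arrive together with an HCQ graph runner. A hedged sketch of how imports move across the rename (not part of the diff; assumes the top-level re-exports track the file moves):

  # tinygrad 0.8.0:
  #   from tinygrad.jit import TinyJit
  # tinygrad 0.9.0: engine code now lives under tinygrad/engine/
  from tinygrad.engine.jit import TinyJit
  # the package-root re-export should work on both versions
  from tinygrad import Tensor, TinyJit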
tinygrad/runtime/driver/hip_comgr.py (new file)
@@ -0,0 +1,47 @@
+ import ctypes
+ import tinygrad.runtime.autogen.comgr as comgr
+
+ def check(status):
+   if status != 0:
+     comgr.amd_comgr_status_string(status, ctypes.byref(status_str := ctypes.POINTER(ctypes.c_char)()))
+     raise RuntimeError(f"comgr fail {status}, {ctypes.string_at(status_str).decode()}")
+
+ def _get_comgr_data(data_set, data_type):
+   check(comgr.amd_comgr_action_data_get_data(data_set, data_type, 0, ctypes.byref(data_exec := comgr.amd_comgr_data_t())))
+   check(comgr.amd_comgr_get_data(data_exec, ctypes.byref(sz := ctypes.c_uint64()), None))
+   check(comgr.amd_comgr_get_data(data_exec, ctypes.byref(sz), (dat := ctypes.create_string_buffer(sz.value))))
+   check(comgr.amd_comgr_release_data(data_exec))
+   return bytes(dat)
+
+ # AMD_COMGR_SAVE_TEMPS=1 AMD_COMGR_REDIRECT_LOGS=stdout AMD_COMGR_EMIT_VERBOSE_LOGS=1
+ def compile_hip(prg:str, arch="gfx1100") -> bytes:
+   check(comgr.amd_comgr_create_action_info(ctypes.byref(action_info := comgr.amd_comgr_action_info_t())))
+   check(comgr.amd_comgr_action_info_set_language(action_info, comgr.AMD_COMGR_LANGUAGE_HIP))
+   check(comgr.amd_comgr_action_info_set_isa_name(action_info, b"amdgcn-amd-amdhsa--" + arch.encode()))
+   check(comgr.amd_comgr_action_info_set_logging(action_info, True))
+
+   check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_src := comgr.amd_comgr_data_set_t())))
+   check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_bc := comgr.amd_comgr_data_set_t())))
+   check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_reloc := comgr.amd_comgr_data_set_t())))
+   check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_exec := comgr.amd_comgr_data_set_t())))
+
+   check(comgr.amd_comgr_create_data(comgr.AMD_COMGR_DATA_KIND_SOURCE, ctypes.byref(data_src := comgr.amd_comgr_data_t())))
+   check(comgr.amd_comgr_set_data(data_src, len(rprg := prg.encode()), rprg))
+   check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))
+
+   check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
+   # -include hiprtc_runtime.h was removed
+   check(comgr.amd_comgr_action_info_set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes".encode())) # noqa: E501
+   status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
+   if status != 0:
+     print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
+     raise RuntimeError("compile failed")
+   check(comgr.amd_comgr_action_info_set_options(action_info, b"-O3 -mllvm -amdgpu-internalize-symbols"))
+   check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, action_info, data_set_bc, data_set_reloc))
+   check(comgr.amd_comgr_action_info_set_options(action_info, b""))
+   check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, action_info, data_set_reloc, data_set_exec))
+   ret = _get_comgr_data(data_set_exec, comgr.AMD_COMGR_DATA_KIND_EXECUTABLE)
+   check(comgr.amd_comgr_release_data(data_src))
+   for x in [data_set_src, data_set_bc, data_set_reloc, data_set_exec]: check(comgr.amd_comgr_destroy_data_set(x))
+   check(comgr.amd_comgr_destroy_action_info(action_info))
+   return ret
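Usage sketch for the new comgr-backed compiler (not part of the diff; the kernel source is illustrative, and a ROCm install providing libamd_comgr is assumed):

  from tinygrad.runtime.driver.hip_comgr import compile_hip

  src = '''extern "C" __global__ void add(float* a, const float* b) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    a[i] += b[i];
  }'''
  lib = compile_hip(src, arch="gfx1100")  # HIP source -> BC -> relocatable -> linked executable
  assert isinstance(lib, bytes)           # ELF code object, ready for the AMD/HSA runtime to load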
tinygrad/runtime/driver/hsa.py (new file)
@@ -0,0 +1,143 @@
+ import ctypes, collections
+ import tinygrad.runtime.autogen.hsa as hsa
+ from tinygrad.helpers import init_c_var
+
+ def check(status):
+   if status != 0:
+     hsa.hsa_status_string(status, ctypes.byref(status_str := ctypes.POINTER(ctypes.c_char)()))
+     raise RuntimeError(f"HSA Error {status}: {ctypes.string_at(status_str).decode()}")
+
+ # Precalculated AQL info
+ AQL_PACKET_SIZE = ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t)
+ EMPTY_SIGNAL = hsa.hsa_signal_t()
+
+ DISPATCH_KERNEL_SETUP = 3 << hsa.HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS
+ DISPATCH_KERNEL_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
+ DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
+ DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
+ DISPATCH_KERNEL_HEADER |= hsa.HSA_PACKET_TYPE_KERNEL_DISPATCH << hsa.HSA_PACKET_HEADER_TYPE
+
+ BARRIER_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
+ BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
+ BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
+ BARRIER_HEADER |= hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE
+
+ class AQLQueue:
+   def __init__(self, device, sz=-1):
+     self.device = device
+
+     check(hsa.hsa_agent_get_info(self.device.agent, hsa.HSA_AGENT_INFO_QUEUE_MAX_SIZE, ctypes.byref(max_queue_size := ctypes.c_uint32())))
+     queue_size = min(max_queue_size.value, sz) if sz != -1 else max_queue_size.value
+
+     null_func = ctypes.CFUNCTYPE(None, hsa.hsa_status_t, ctypes.POINTER(hsa.struct_hsa_queue_s), ctypes.c_void_p)()
+     self.hw_queue = init_c_var(ctypes.POINTER(hsa.hsa_queue_t)(), lambda x: check(
+       hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None, (1<<32)-1, (1<<32)-1, ctypes.byref(x))))
+
+     self.next_doorbell_index = 0
+     self.queue_base = self.hw_queue.contents.base_address
+     self.queue_size = self.hw_queue.contents.size * AQL_PACKET_SIZE # in bytes
+     self.write_addr = self.queue_base
+     self.write_addr_end = self.queue_base + self.queue_size - 1 # precalc saves some time
+     self.available_packet_slots = self.hw_queue.contents.size
+
+     check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH))
+     check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1))
+
+   def __del__(self):
+     if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue))
+
+   def submit_kernel(self, prg, global_size, local_size, kernargs, completion_signal=None):
+     if self.available_packet_slots == 0: self._wait_queue()
+
+     packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
+     packet.workgroup_size_x = local_size[0]
+     packet.workgroup_size_y = local_size[1]
+     packet.workgroup_size_z = local_size[2]
+     packet.reserved0 = 0
+     packet.grid_size_x = global_size[0] * local_size[0]
+     packet.grid_size_y = global_size[1] * local_size[1]
+     packet.grid_size_z = global_size[2] * local_size[2]
+     packet.private_segment_size = prg.private_segment_size
+     packet.group_segment_size = prg.group_segment_size
+     packet.kernel_object = prg.handle
+     packet.kernarg_address = kernargs
+     packet.reserved2 = 0
+     packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
+     packet.setup = DISPATCH_KERNEL_SETUP
+     packet.header = DISPATCH_KERNEL_HEADER
+     self._submit_packet()
+
+   def submit_barrier(self, wait_signals=None, completion_signal=None):
+     assert wait_signals is None or len(wait_signals) <= 5
+     if self.available_packet_slots == 0: self._wait_queue()
+
+     packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
+     packet.reserved0 = 0
+     packet.reserved1 = 0
+     for i in range(5):
+       packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL
+     packet.reserved2 = 0
+     packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
+     packet.header = BARRIER_HEADER
+     self._submit_packet()
+
+   def blit_packets(self, packet_addr, packet_cnt):
+     if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)
+
+     tail_blit_packets = min((self.queue_base + self.queue_size - self.write_addr) // AQL_PACKET_SIZE, packet_cnt)
+     rem_packet_cnt = packet_cnt - tail_blit_packets
+     ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets)
+     if rem_packet_cnt > 0: ctypes.memmove(self.queue_base, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt)
+
+     self._submit_packet(packet_cnt)
+
+   def wait(self):
+     self.submit_barrier([], finish_signal := self.device.alloc_signal(reusable=True))
+     hsa.hsa_signal_wait_scacquire(finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
+     self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE
+
+   def _wait_queue(self, need_packets=1):
+     while self.available_packet_slots < need_packets:
+       rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
+       self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE - (self.next_doorbell_index - rindex)
+
+   def _submit_packet(self, cnt=1):
+     self.available_packet_slots -= cnt
+     self.next_doorbell_index += cnt
+     hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index)
+     hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index-1)
+
+     self.write_addr += AQL_PACKET_SIZE * cnt
+     if self.write_addr > self.write_addr_end:
+       self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size
+
+ def scan_agents():
+   agents = collections.defaultdict(list)
+
+   @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p)
+   def __scan_agents(agent, data):
+     status = hsa.hsa_agent_get_info(agent, hsa.HSA_AGENT_INFO_DEVICE, ctypes.byref(device_type := hsa.hsa_device_type_t()))
+     if status == 0: agents[device_type.value].append(agent)
+     return hsa.HSA_STATUS_SUCCESS
+
+   hsa.hsa_iterate_agents(__scan_agents, None)
+   return agents
+
+ def find_memory_pool(agent, segtyp=-1, location=-1):
+   @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_amd_memory_pool_t, ctypes.c_void_p)
+   def __filter_amd_memory_pools(mem_pool, data):
+     check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SEGMENT, ctypes.byref(segment := hsa.hsa_amd_segment_t())))
+     if segtyp >= 0 and segment.value != segtyp: return hsa.HSA_STATUS_SUCCESS
+
+     check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_LOCATION, ctypes.byref(loc:=hsa.hsa_amd_memory_pool_location_t())))
+     if location >= 0 and loc.value != location: return hsa.HSA_STATUS_SUCCESS
+
+     check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SIZE, ctypes.byref(sz := ctypes.c_size_t())))
+     if sz.value == 0: return hsa.HSA_STATUS_SUCCESS
+
+     ret = ctypes.cast(data, ctypes.POINTER(hsa.hsa_amd_memory_pool_t))
+     ret[0] = mem_pool
+     return hsa.HSA_STATUS_INFO_BREAK
+
+   hsa.hsa_amd_agent_iterate_memory_pools(agent, __filter_amd_memory_pools, ctypes.byref(region := hsa.hsa_amd_memory_pool_t()))
+   return region
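AQLQueue treats the HSA queue as a ring of 64-byte AQL packets: _submit_packet advances the write index, rings the doorbell at index-1, and wraps write_addr back into the ring once it runs past write_addr_end. A standalone sketch of that wrap arithmetic with made-up values (not part of the diff):

  AQL_PACKET_SIZE = 64                           # sizeof(hsa_kernel_dispatch_packet_t)
  queue_base, slots = 0x1000, 4
  queue_size = slots * AQL_PACKET_SIZE           # in bytes, like AQLQueue.queue_size

  write_addr = queue_base + 3 * AQL_PACKET_SIZE  # currently at the last slot
  write_addr += 2 * AQL_PACKET_SIZE              # submit two packets
  if write_addr > queue_base + queue_size - 1:   # same test as against write_addr_end
    write_addr = queue_base + (write_addr - queue_base) % queue_size
  assert write_addr == queue_base + AQL_PACKET_SIZE  # wrapped around to slot 1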
tinygrad/runtime/graph/clang.py (new file)
@@ -0,0 +1,38 @@
+ from typing import List, Dict, cast
+ import ctypes
+ from tinygrad.helpers import dedup, cpu_time_execution, GraphException, DEBUG
+ from tinygrad.engine.jit import GraphRunner
+ from tinygrad.device import Buffer, Device
+ from tinygrad.engine.realize import ExecItem, CompiledRunner
+ from tinygrad.shape.symbolic import Variable
+ from tinygrad.runtime.ops_clang import ClangProgram
+ from tinygrad.renderer.cstyle import ClangRenderer
+ render_dtype = ClangRenderer().render_dtype
+
+ class ClangGraph(GraphRunner):
+   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+     super().__init__(jit_cache, input_rawbuffers, var_vals)
+     if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
+
+     prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
+     args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
+     args += [f"int {v.expr}" for v in var_vals]
+     code = ["void batched("+','.join(args)+") {"]
+     for ji in jit_cache:
+       args = []
+       for buf in ji.bufs:
+         assert buf is not None
+         if buf in input_rawbuffers:
+           args.append(f"arg{input_rawbuffers.index(buf)}")
+         else:
+           args.append(f"({render_dtype(buf.dtype)}*)0x{ctypes.addressof(buf._buf):X}")
+       args += [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
+       code.append(f"  {cast(CompiledRunner, ji.prg).p.function_name}({','.join(args)});")
+     code.append("}")
+     if DEBUG >= 4: print("\n".join(code))
+     compiler = Device["CLANG"].compiler
+     assert compiler is not None
+     self.clprg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers
+
+   def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False):
+     return cpu_time_execution(lambda: self.clprg(*[x._buf for x in rawbufs], *[x for x in var_vals.values()]), enable=wait)
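For a sense of what ClangGraph emits: with one input buffer and one symbolic variable, the DEBUG>=4 dump of the generated wrapper would look roughly like this (kernel names and the baked-in buffer address are hypothetical; real names come from each kernel's p.function_name):

  void batched(float* arg0,int start_pos) {
    E_16(arg0,(float*)0x7F1234560000);
    r_16_4((float*)0x7F1234560000,arg0,start_pos);
  }

Non-input buffers are embedded as raw pointer constants, which is why the compiled batch is tied to those allocations and is not cached.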
tinygrad/runtime/graph/cuda.py
@@ -1,76 +1,81 @@
  import ctypes
  from typing import Any, Optional, Tuple, Dict, List, cast
- import gpuctypes.cuda as cuda
- from tinygrad.helpers import init_c_var, encode_args_cuda_style
- from tinygrad.device import CompiledASTRunner, update_stats, Buffer
- from tinygrad.runtime.ops_cuda import check, cu_time_execution
+ import tinygrad.runtime.autogen.cuda as cuda
+ from tinygrad.helpers import init_c_var, GraphException
+ from tinygrad.device import Buffer, Device
+ from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
  from tinygrad.shape.symbolic import Variable
- from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals, GraphException # noqa: E501
+ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
+ from tinygrad.engine.jit import MultiGraphRunner

- class CUDAGraph:
-   def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
-     if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException
+ class CUDAGraph(MultiGraphRunner):
+   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+     super().__init__(jit_cache, input_rawbuffers, var_vals)

-     self.jit_cache = jit_cache
-     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
-     self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
-     self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
-     self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
-     self.jc_idxs_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
-     self.updatable_nodes: Dict[int, Tuple[Any, Any, Any]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params)
+     # Check all jit items are compatible.
+     if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException

-     self.graph = self.graph_create()
-     graph_node: Optional[ctypes._CData] = None
+     self.jc_idx_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
+     self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
+
+     self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))

-     for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
      for j,ji in enumerate(self.jit_cache):
-       prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
+       if isinstance(ji.prg, CompiledRunner):
+         global_size, local_size = ji.prg.p.launch_dims(var_vals)
+
+         new_node = cuda.CUgraphNode()
+         deps = self._access_resources([x.base for x in ji.bufs[ji.prg.p.outcount:] if x is not None],
+                                       [x.base for x in ji.bufs[:ji.prg.p.outcount] if x is not None], new_dependency=new_node)
+         c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None

-       c_deps = (type(graph_node)*1)(*(graph_node,)) if graph_node is not None else None
-       c_kernel_input_config, c_input_params = encode_args_cuda_style([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars], *self.encode_args_info()) # noqa: E501
-       c_node_params = self.build_kernel_node_params(prg, *cast(Tuple[List[int], List[int]], prg.launch_dims(var_vals)), c_kernel_input_config)
-       graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params)
+         c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
+         kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
+         check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))

-       if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs:
-         self.updatable_nodes[j] = (graph_node, c_node_params, c_input_params)
+         if j in self.jc_idx_with_updatable_launch_dims or j in self.jc_idx_with_updatable_var_vals or j in self.jc_idx_with_updatable_rawbufs:
+           self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
+       elif isinstance(ji.prg, BufferXfer):
+         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
+         src_dev = cast(CUDADevice, Device[src.device])
+         node_from = cuda.CUgraphNode()
+         deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from)
+         c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
+         cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
+                                           dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
+                                           WidthInBytes=dest.nbytes, Height=1, Depth=1)
+         check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
+         if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)

-     self.instance = self.graph_instantiate(self.graph)
+     self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))

-   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
-     # Update rawbuffers in the c_input_params struct.
+   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+     # Update rawbuffers in the c_args struct.
      for (j,i),input_idx in self.input_replace.items():
-       setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
+       if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
+       else:
+         if i == 0: self.updatable_nodes[j][1].destDevice = input_rawbuffers[input_idx]._buf
+         elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf

-     # Update var_vals in the c_input_params struct.
-     for j in self.jc_idxs_with_updatable_var_vals:
-       for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars):
-         setattr(self.updatable_nodes[j][2], f'f{len(self.jit_cache[j].rawbufs) + i}', var_vals[v])
+     # Update var_vals in the c_args struct.
+     for j in self.jc_idx_with_updatable_var_vals:
+       for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
+         setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])

-     # Update launch dims in the c_node_params struct.
-     for j in self.jc_idxs_with_updatable_launch_dims:
-       self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals))
+     # Update launch dims in the kern_params struct.
+     for j in self.jc_idx_with_updatable_launch_dims:
+       self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))

      # Update graph nodes with the updated structs.
-     for node, c_node_params, _ in self.updatable_nodes.values():
-       self.graph_exec_kernel_node_set_params(self.instance, node, ctypes.byref(c_node_params))
+     for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
+       if not is_copy: check(cuda.cuGraphExecKernelNodeSetParams(self.instance, node, ctypes.byref(c_node_params)))
+       else: check(cuda.cuGraphExecMemcpyNodeSetParams(self.instance, node, ctypes.byref(c_node_params), c_args))

-     et = self.graph_launch(self.instance, None, wait=wait)
-     update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache)) # noqa: E501
-     return et
+     return cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)

    def __del__(self):
-     check(cuda.cuGraphDestroy(self.graph))
-     check(cuda.cuGraphExecDestroy(self.instance))
-
-   def encode_args_info(self): return (cuda.CUdeviceptr_v2, (1,2,0))
-   def graph_create(self): return init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
-   def graph_instantiate(self, graph):
-     return init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0)))
-   def graph_add_kernel_node(self, graph, c_deps, c_node_params):
-     return init_c_var(cuda.CUgraphNode(), lambda x: check(cuda.cuGraphAddKernelNode(ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps)//8 if c_deps else 0, ctypes.byref(c_node_params)))) # noqa: E501
-   def graph_launch(self, *args, wait=False): return cu_time_execution(lambda: check(cuda.cuGraphLaunch(*args)), enable=wait)
-   def graph_exec_kernel_node_set_params(self, *args): return check(cuda.cuGraphExecKernelNodeSetParams(*args))
-   def build_kernel_node_params(self, prg, global_size, local_size, c_kernel_config):
-     return cuda.CUDA_KERNEL_NODE_PARAMS(prg.clprg.prg, *global_size, *local_size, 0, None, c_kernel_config)
+     if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
+     if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
+
    def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
      node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size
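A hedged sketch of how the JIT drives this runner once a kernel sequence has been captured (not part of the diff; TinyJit builds jit_cache internally):

  graph = CUDAGraph(jit_cache, input_rawbuffers, var_vals)  # build nodes + dependency edges once
  et = graph(input_rawbuffers, var_vals, wait=True)         # replay; only updatable_nodes are patched
  # per replay: kernel nodes get fresh buffer pointers, var values, and launch dims via
  # cuGraphExecKernelNodeSetParams; memcpy nodes (is_copy=True) get fresh src/dest device pointers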
tinygrad/runtime/graph/hcq.py (new file)
@@ -0,0 +1,143 @@
+ import ctypes, collections, array, time
+ from typing import List, Any, Dict, cast, Optional, Tuple, Set
+ from tinygrad.helpers import GraphException, round_up, to_mv, init_c_struct_t
+ from tinygrad.device import Buffer, BufferOptions, Compiled, Device
+ from tinygrad.shape.symbolic import Variable
+ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
+ from tinygrad.engine.jit import MultiGraphRunner
+
+ class HCQGraph(MultiGraphRunner):
+   def __init__(self, device_t, comp_hcq_t, copy_hcq_t, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+     super().__init__(jit_cache, input_rawbuffers, var_vals)
+     self.device_t, self.comp_hcq_t, self.copy_hcq_t = device_t, comp_hcq_t, copy_hcq_t
+
+     # Check all jit items are compatible.
+     self.devices = list(set(cast(self.device_t, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs])) #type: ignore
+     if any(not isinstance(d, self.device_t) for d in self.devices): raise GraphException
+
+     # Allocate kernel args.
+     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
+     for ji in self.jit_cache:
+       if not isinstance(ji.prg, CompiledRunner): continue
+       kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)
+     kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)).va_addr for dev,sz in kernargs_size.items()}
+
+     # Fill initial arguments.
+     self.kargs_addrs: Dict[int, int] = {}
+     self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
+     for j,ji in enumerate(self.jit_cache):
+       if not isinstance(ji.prg, CompiledRunner): continue
+       self.kargs_addrs[j] = kernargs_ptrs[ji.prg.device]
+       kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)
+
+       args_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(ji.bufs))] +
+                                      [(f'v{i}', ctypes.c_int) for i in range(len(ji.prg.p.vars))]))
+       self.ji_kargs_structs[j] = args_t.from_address(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset)
+       for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf.va_addr)
+       for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
+
+       # NV needs constbuffer to be set
+       if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)
+
+     # Build queues.
+     self.comp_queues: Dict[Compiled, Any] = collections.defaultdict(self.comp_hcq_t)
+     self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
+     self.comp_signal_val = {dev: 0 for dev in self.devices}
+
+     self.copy_queues: Dict[Compiled, Any] = collections.defaultdict(self.copy_hcq_t)
+     self.copy_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
+     self.copy_signal_val = {dev: 0 for dev in self.devices}
+
+     self.kickoff_signal = self.devices[0]._get_signal(value=0)
+     self.kickoff_value = 0
+     self.graph_timeline = {dev: 0 for dev in self.devices}
+
+     self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
+     self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
+
+     for j,ji in enumerate(self.jit_cache):
+       if isinstance(ji.prg, CompiledRunner):
+         exec_params = {}
+         deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1))
+         deps = [x for x in deps if id(x[0]) != id(self.comp_signal[ji.prg.device])]
+
+         # On NV, to synchronize kernel execution, we must either issue a wait or chain executions to schedule them in order.
+         # Chaining executions is preferred when possible, as it is faster.
+         if ji.prg.device.dname.startswith("NV"):
+           if len(deps) == 0 and self.comp_signal_val[ji.prg.device] > 0:
+             exec_params['chain_exec_ptr'] = self.exec_ptrs[self.comp_signal_val[ji.prg.device] - 1][1]
+           else: deps.append((self.comp_signal[ji.prg.device], self.comp_signal_val[ji.prg.device]))
+
+         for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
+
+         self.exec_ptrs[j] = (self.comp_queues[ji.prg.device], self.comp_queues[ji.prg.device].ptr())
+         self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
+                                              signal=self.comp_signal[ji.prg.device], signal_value=sig_val, **exec_params)
+         self.comp_signal_val[ji.prg.device] = sig_val
+       elif isinstance(ji.prg, BufferXfer):
+         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
+         Device[src.device]._gpu_map(dest._buf) #type: ignore
+
+         deps = self.access_resources([src], [dest], (self.copy_signal[Device[src.device]], sig_val:=j+1))
+         deps.append((self.copy_signal[Device[src.device]], self.copy_signal_val[Device[src.device]]))
+         self.copy_signal_val[Device[src.device]] = sig_val
+
+         for sig,val in deps: self.copy_queues[Device[src.device]].wait(sig, val)
+         self.copy_queues[Device[src.device]].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) \
+                                             .signal(self.copy_signal[Device[src.device]], sig_val)
+         self.copy_to_devs[Device[dest.device]].add(Device[src.device])
+
+     for dev in self.devices:
+       if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
+       for dep_dev in self.copy_to_devs[dev]: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])
+
+       if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
+       if hasattr(self.copy_queues[dev], 'bind') and self.copy_signal_val[dev] > 0: self.copy_queues[dev].bind(dev)
+
+   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
+     # Wait and restore signals
+     self.kickoff_value += 1
+     for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
+     for dev in self.devices:
+       dev._set_signal(self.comp_signal[dev], 0)
+       dev._set_signal(self.copy_signal[dev], 0)
+       dev._set_signal(self.kickoff_signal, self.kickoff_value)
+
+     # Update rawbuffers
+     for (j,i),input_idx in self.input_replace.items():
+       self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf.va_addr)
+
+     # Update var_vals
+     for j in self.jc_idx_with_updatable_var_vals:
+       for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
+         self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
+
+     for j in self.jc_idx_with_updatable_launch_dims:
+       queue, cmd_ptr = self.exec_ptrs[j]
+       queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
+
+     for dev in self.devices:
+       # Submit sync with world and queues.
+       self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
+                        .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
+       self.comp_queues[dev].submit(dev)
+
+       if self.copy_signal_val[dev] > 0:
+         self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
+                          .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
+         self.copy_queues[dev].submit(dev)
+
+       # Signal the final value
+       self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev)
+       self.graph_timeline[dev] = dev.timeline_value
+       dev.timeline_value += 1
+
+     if wait:
+       st = time.perf_counter()
+       for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
+       return time.perf_counter() - st
+     return None
+
+   def access_resources(self, read, write, new_dependency):
+     deps = self._access_resources(read, write, new_dependency)
+     return [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()]
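The closing access_resources override collapses duplicate (signal, value) waits, keeping the highest wait value per signal. A standalone worked example of that reduction, with plain objects standing in for device signals (not part of the diff):

  class Sig:
    def __init__(self, name): self.name = name  # stand-in for an HCQ signal

  a, b = Sig("a"), Sig("b")
  deps = [(a, 3), (b, 1), (a, 5)]
  # the same expression HCQGraph.access_resources returns
  reduced = [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()]
  assert [(s.name, v) for s, v in reduced] == [("a", 5), ("b", 1)]  # one wait per signal, at the max value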