tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
 - tinygrad/codegen/kernel.py +572 -83
 - tinygrad/codegen/linearizer.py +415 -395
 - tinygrad/codegen/uops.py +415 -0
 - tinygrad/device.py +183 -0
 - tinygrad/dtype.py +113 -0
 - tinygrad/engine/__init__.py +0 -0
 - tinygrad/engine/graph.py +100 -0
 - tinygrad/engine/jit.py +195 -0
 - tinygrad/engine/realize.py +191 -0
 - tinygrad/engine/schedule.py +362 -0
 - tinygrad/engine/search.py +196 -0
 - tinygrad/{mlops.py → function.py} +76 -55
 - tinygrad/helpers.py +196 -89
 - tinygrad/lazy.py +210 -371
 - tinygrad/multi.py +169 -0
 - tinygrad/nn/__init__.py +202 -22
 - tinygrad/nn/datasets.py +7 -0
 - tinygrad/nn/optim.py +112 -32
 - tinygrad/nn/state.py +136 -39
 - tinygrad/ops.py +119 -202
 - tinygrad/renderer/__init__.py +61 -0
 - tinygrad/renderer/assembly.py +276 -0
 - tinygrad/renderer/cstyle.py +353 -166
 - tinygrad/renderer/llvmir.py +150 -138
 - tinygrad/runtime/autogen/amd_gpu.py +1900 -0
 - tinygrad/runtime/autogen/comgr.py +865 -0
 - tinygrad/runtime/autogen/cuda.py +5923 -0
 - tinygrad/runtime/autogen/hip.py +5909 -0
 - tinygrad/runtime/autogen/hsa.py +5761 -0
 - tinygrad/runtime/autogen/kfd.py +812 -0
 - tinygrad/runtime/autogen/nv_gpu.py +33328 -0
 - tinygrad/runtime/autogen/opencl.py +1795 -0
 - tinygrad/runtime/driver/hip_comgr.py +47 -0
 - tinygrad/runtime/driver/hsa.py +143 -0
 - tinygrad/runtime/graph/clang.py +38 -0
 - tinygrad/runtime/graph/cuda.py +81 -0
 - tinygrad/runtime/graph/hcq.py +143 -0
 - tinygrad/runtime/graph/hsa.py +171 -0
 - tinygrad/runtime/graph/metal.py +75 -0
 - tinygrad/runtime/ops_amd.py +564 -0
 - tinygrad/runtime/ops_clang.py +24 -77
 - tinygrad/runtime/ops_cuda.py +175 -89
 - tinygrad/runtime/ops_disk.py +56 -33
 - tinygrad/runtime/ops_gpu.py +92 -95
 - tinygrad/runtime/ops_hsa.py +278 -0
 - tinygrad/runtime/ops_llvm.py +39 -60
 - tinygrad/runtime/ops_metal.py +92 -74
 - tinygrad/runtime/ops_npy.py +9 -0
 - tinygrad/runtime/ops_nv.py +630 -0
 - tinygrad/runtime/ops_python.py +204 -0
 - tinygrad/shape/shapetracker.py +86 -254
 - tinygrad/shape/symbolic.py +166 -141
 - tinygrad/shape/view.py +296 -0
 - tinygrad/tensor.py +2619 -448
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
 - tinygrad-0.9.0.dist-info/METADATA +227 -0
 - tinygrad-0.9.0.dist-info/RECORD +60 -0
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
 - tinygrad/codegen/assembly.py +0 -190
 - tinygrad/codegen/optimizer.py +0 -379
 - tinygrad/codegen/search.py +0 -72
 - tinygrad/graph.py +0 -83
 - tinygrad/jit.py +0 -57
 - tinygrad/nn/image.py +0 -100
 - tinygrad/renderer/assembly_arm64.py +0 -169
 - tinygrad/renderer/assembly_ptx.py +0 -98
 - tinygrad/renderer/wgsl.py +0 -53
 - tinygrad/runtime/lib.py +0 -113
 - tinygrad/runtime/ops_cpu.py +0 -51
 - tinygrad/runtime/ops_hip.py +0 -82
 - tinygrad/runtime/ops_shm.py +0 -29
 - tinygrad/runtime/ops_torch.py +0 -30
 - tinygrad/runtime/ops_webgpu.py +0 -45
 - tinygrad-0.7.0.dist-info/METADATA +0 -212
 - tinygrad-0.7.0.dist-info/RECORD +0 -40
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
 
| 
         @@ -0,0 +1,47 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import ctypes
         
     | 
| 
      
 2 
     | 
    
         
            +
            import tinygrad.runtime.autogen.comgr as comgr
         
     | 
| 
      
 3 
     | 
    
         
            +
             
     | 
| 
      
 4 
     | 
    
         
            +
            def check(status):
              """Raise ``RuntimeError`` when a comgr call reports a non-zero status.

              Args:
                status: status code returned by an ``amd_comgr_*`` function; ``0`` is treated as success.

              Raises:
                RuntimeError: for any non-zero status. The message carries the numeric code and the
                  description obtained from ``amd_comgr_status_string`` when one is available.
              """
              if status == 0: return
              status_str = ctypes.POINTER(ctypes.c_char)()
              comgr.amd_comgr_status_string(status, ctypes.byref(status_str))
              # If the lookup did not fill in the pointer it is still NULL; ctypes.string_at would then
              # dereference a NULL pointer, so fall back to a placeholder description instead.
              desc = ctypes.string_at(status_str).decode() if status_str else "unknown status"
              raise RuntimeError(f"comgr fail {status}, {desc}")
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            def _get_comgr_data(data_set, data_type):
              """Return the bytes of the ``data_type`` object at index 0 of ``data_set``.

              The payload is fetched with two ``amd_comgr_get_data`` calls: the first passes no buffer and
              only fills in the size, the second copies into a buffer of that size. The data handle is
              released before returning. Any failing comgr call raises through ``check``.
              """
              handle = comgr.amd_comgr_data_t()
              check(comgr.amd_comgr_action_data_get_data(data_set, data_type, 0, ctypes.byref(handle)))

              size = ctypes.c_uint64()
              check(comgr.amd_comgr_get_data(handle, ctypes.byref(size), None))  # size query only

              payload = ctypes.create_string_buffer(size.value)
              check(comgr.amd_comgr_get_data(handle, ctypes.byref(size), payload))

              check(comgr.amd_comgr_release_data(handle))
              return bytes(payload)
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            # AMD_COMGR_SAVE_TEMPS=1 AMD_COMGR_REDIRECT_LOGS=stdout AMD_COMGR_EMIT_VERBOSE_LOGS=1
         
     | 
| 
      
 17 
     | 
    
         
            +
            def compile_hip(prg:str, arch="gfx1100") -> bytes:
              """Compile HIP source text into an executable code object with comgr.

              Three comgr actions run in sequence: source -> bitcode (with device libs), bitcode ->
              relocatable, relocatable -> executable.

              Args:
                prg: HIP source code.
                arch: target name; appended to the ISA name and passed as ``--offload-arch``.

              Returns:
                The bytes of the executable data object produced by the link action.

              Raises:
                RuntimeError: "compile failed" when the source-to-bitcode action fails (the comgr log is
                  printed first); every other failing comgr call raises through ``check``.
              """
              def _new_data_set():
                # allocate one empty comgr data set
                ds = comgr.amd_comgr_data_set_t()
                check(comgr.amd_comgr_create_data_set(ctypes.byref(ds)))
                return ds

              action = comgr.amd_comgr_action_info_t()
              check(comgr.amd_comgr_create_action_info(ctypes.byref(action)))
              check(comgr.amd_comgr_action_info_set_language(action, comgr.AMD_COMGR_LANGUAGE_HIP))
              check(comgr.amd_comgr_action_info_set_isa_name(action, b"amdgcn-amd-amdhsa--" + arch.encode()))
              check(comgr.amd_comgr_action_info_set_logging(action, True))

              # one data set per pipeline stage, created in pipeline order
              src_set, bc_set, reloc_set, exec_set = (_new_data_set() for _ in range(4))

              source = comgr.amd_comgr_data_t()
              check(comgr.amd_comgr_create_data(comgr.AMD_COMGR_DATA_KIND_SOURCE, ctypes.byref(source)))
              encoded = prg.encode()
              check(comgr.amd_comgr_set_data(source, len(encoded), encoded))
              check(comgr.amd_comgr_set_data_name(source, b"<null>"))
              check(comgr.amd_comgr_data_set_add(src_set, source))

              # -include hiprtc_runtime.h was removed
              compile_opts = f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes".encode() # noqa: E501
              check(comgr.amd_comgr_action_info_set_options(action, compile_opts))

              # this step is not routed through check() so the compiler log can be shown on failure
              if comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action, src_set, bc_set) != 0:
                print(_get_comgr_data(bc_set, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
                raise RuntimeError("compile failed")

              check(comgr.amd_comgr_action_info_set_options(action, b"-O3 -mllvm -amdgpu-internalize-symbols"))
              check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, action, bc_set, reloc_set))

              check(comgr.amd_comgr_action_info_set_options(action, b""))
              check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, action, reloc_set, exec_set))

              binary = _get_comgr_data(exec_set, comgr.AMD_COMGR_DATA_KIND_EXECUTABLE)

              # tear down every comgr object created above
              check(comgr.amd_comgr_release_data(source))
              for ds in (src_set, bc_set, reloc_set, exec_set):
                check(comgr.amd_comgr_destroy_data_set(ds))
              check(comgr.amd_comgr_destroy_action_info(action))
              return binary
         
     | 
| 
         @@ -0,0 +1,143 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import ctypes, collections
         
     | 
| 
      
 2 
     | 
    
         
            +
            import tinygrad.runtime.autogen.hsa as hsa
         
     | 
| 
      
 3 
     | 
    
         
            +
            from tinygrad.helpers import init_c_var
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            def check(status):
              """Raise ``RuntimeError`` when an HSA call reports a non-zero status.

              Args:
                status: status code returned by an ``hsa_*`` function; ``0`` is treated as success.

              Raises:
                RuntimeError: for any non-zero status. The message carries the numeric code and the
                  description obtained from ``hsa_status_string`` when one is available.
              """
              if status == 0: return
              status_str = ctypes.POINTER(ctypes.c_char)()
              hsa.hsa_status_string(status, ctypes.byref(status_str))
              # If the lookup did not fill in the pointer it is still NULL; ctypes.string_at would then
              # dereference a NULL pointer, so fall back to a placeholder description instead.
              desc = ctypes.string_at(status_str).decode() if status_str else "unknown status"
              raise RuntimeError(f"HSA Error {status}: {desc}")
         
     | 
| 
      
 9 
     | 
    
         
            +
             
     | 
| 
      
 10 
     | 
    
         
            +
            # Precalculated AQL info
            AQL_PACKET_SIZE = ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t)
            EMPTY_SIGNAL = hsa.hsa_signal_t()  # default-constructed handle, used wherever no signal is supplied

            # setup word of a kernel dispatch packet: the value 3 placed at the dimensions bit offset
            DISPATCH_KERNEL_SETUP = 3 << hsa.HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS

            # header of a kernel dispatch packet: barrier bit, system-scope acquire/release fences, packet type
            DISPATCH_KERNEL_HEADER = (
              (1 << hsa.HSA_PACKET_HEADER_BARRIER)
              | (hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE)
              | (hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE)
              | (hsa.HSA_PACKET_TYPE_KERNEL_DISPATCH << hsa.HSA_PACKET_HEADER_TYPE)
            )

            # header of a barrier-AND packet: same barrier/fence bits, different packet type
            BARRIER_HEADER = (
              (1 << hsa.HSA_PACKET_HEADER_BARRIER)
              | (hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE)
              | (hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE)
              | (hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE)
            )
         
     | 
| 
      
 24 
     | 
    
         
            +
             
     | 
| 
      
 25 
     | 
    
         
            +
            class AQLQueue:
              """Wrapper around one HSA hardware queue that writes AQL packets straight into its ring buffer.

              The object keeps its own bookkeeping: the address of the next packet slot, the next doorbell
              index, and a local count of free packet slots that is refreshed from the queue's read index
              only when it runs out.
              """

              def __init__(self, device, sz=-1):
                """Create the hardware queue on ``device.agent``.

                Args:
                  device: object exposing ``agent`` (and ``alloc_signal``, which ``wait`` uses).
                  sz: requested queue size in packets; ``-1`` selects the maximum the agent reports, and
                    any other value is clamped to that maximum.
                """
                self.device = device

                max_queue_size = ctypes.c_uint32()
                check(hsa.hsa_agent_get_info(self.device.agent, hsa.HSA_AGENT_INFO_QUEUE_MAX_SIZE, ctypes.byref(max_queue_size)))
                if sz == -1: queue_size = max_queue_size.value
                else: queue_size = min(max_queue_size.value, sz)

                # a NULL function pointer is passed in the callback position of hsa_queue_create
                null_func = ctypes.CFUNCTYPE(None, hsa.hsa_status_t, ctypes.POINTER(hsa.struct_hsa_queue_s), ctypes.c_void_p)()

                def _create(queue_ptr):
                  return check(hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None,
                                                    (1<<32)-1, (1<<32)-1, ctypes.byref(queue_ptr)))
                self.hw_queue = init_c_var(ctypes.POINTER(hsa.hsa_queue_t)(), _create)

                queue = self.hw_queue.contents
                self.next_doorbell_index = 0
                self.queue_base = queue.base_address
                self.queue_size = queue.size * AQL_PACKET_SIZE # in bytes
                self.write_addr = self.queue_base
                self.write_addr_end = self.queue_base + self.queue_size - 1 # precalc saves some time
                self.available_packet_slots = queue.size

                check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH))
                check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1))

              def __del__(self):
                """Destroy the hardware queue, if construction got far enough to create one."""
                if not hasattr(self, 'hw_queue'): return
                check(hsa.hsa_queue_destroy(self.hw_queue))

              def submit_kernel(self, prg, global_size, local_size, kernargs, completion_signal=None):
                """Write one kernel dispatch packet at the current slot and submit it.

                Args:
                  prg: object providing ``private_segment_size``, ``group_segment_size`` and ``handle``.
                  global_size: indexable with 0..2; multiplied by ``local_size`` to get the grid size.
                  local_size: indexable with 0..2; workgroup size per dimension.
                  kernargs: value stored as the packet's kernarg address.
                  completion_signal: signal stored in the packet; ``EMPTY_SIGNAL`` when falsy.
                """
                if self.available_packet_slots == 0: self._wait_queue()

                pkt = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
                pkt.workgroup_size_x = local_size[0]
                pkt.workgroup_size_y = local_size[1]
                pkt.workgroup_size_z = local_size[2]
                pkt.reserved0 = 0
                pkt.grid_size_x = global_size[0] * local_size[0]
                pkt.grid_size_y = global_size[1] * local_size[1]
                pkt.grid_size_z = global_size[2] * local_size[2]
                pkt.private_segment_size = prg.private_segment_size
                pkt.group_segment_size = prg.group_segment_size
                pkt.kernel_object = prg.handle
                pkt.kernarg_address = kernargs
                pkt.reserved2 = 0
                pkt.completion_signal = completion_signal or EMPTY_SIGNAL
                pkt.setup = DISPATCH_KERNEL_SETUP
                pkt.header = DISPATCH_KERNEL_HEADER # header is written after every other field
                self._submit_packet()

              def submit_barrier(self, wait_signals=None, completion_signal=None):
                """Write one barrier-AND packet at the current slot and submit it.

                Args:
                  wait_signals: up to 5 signals stored as the packet's dependencies; unused dependency
                    slots are filled with ``EMPTY_SIGNAL``.
                  completion_signal: signal stored in the packet; ``EMPTY_SIGNAL`` when falsy.
                """
                assert wait_signals is None or len(wait_signals) <= 5
                if self.available_packet_slots == 0: self._wait_queue()

                pkt = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
                pkt.reserved0 = 0
                pkt.reserved1 = 0
                deps = wait_signals or ()
                for slot in range(5):
                  pkt.dep_signal[slot] = deps[slot] if slot < len(deps) else EMPTY_SIGNAL
                pkt.reserved2 = 0
                pkt.completion_signal = completion_signal or EMPTY_SIGNAL
                pkt.header = BARRIER_HEADER # header is written after every other field
                self._submit_packet()

              def blit_packets(self, packet_addr, packet_cnt):
                """Copy ``packet_cnt`` prebuilt packets from ``packet_addr`` into the ring and submit them.

                The copy is split in two when it would run past the end of the ring: the first part fills
                the slots up to the end, the remainder continues from the ring's base address.
                """
                if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)

                slots_to_end = (self.queue_base + self.queue_size - self.write_addr) // AQL_PACKET_SIZE
                head_cnt = min(slots_to_end, packet_cnt)
                wrapped_cnt = packet_cnt - head_cnt
                head_bytes = AQL_PACKET_SIZE * head_cnt
                ctypes.memmove(self.write_addr, packet_addr, head_bytes)
                if wrapped_cnt > 0:
                  ctypes.memmove(self.queue_base, packet_addr + head_bytes, AQL_PACKET_SIZE * wrapped_cnt)

                self._submit_packet(packet_cnt)

              def wait(self):
                """Submit a barrier with a fresh completion signal and block until its value drops below 1.

                Afterwards the local free-slot counter is reset to the full queue capacity.
                """
                finish_signal = self.device.alloc_signal(reusable=True)
                self.submit_barrier([], finish_signal)
                hsa.hsa_signal_wait_scacquire(finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
                self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE

              def _wait_queue(self, need_packets=1):
                """Spin until at least ``need_packets`` slots are free, recomputing from the queue's read index."""
                while self.available_packet_slots < need_packets:
                  read_index = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
                  in_flight = self.next_doorbell_index - read_index
                  self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE - in_flight

              def _submit_packet(self, cnt=1):
                """Publish ``cnt`` already-written packets and advance the write address, wrapping at the ring end."""
                self.available_packet_slots -= cnt
                self.next_doorbell_index += cnt
                hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index)
                hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index-1)

                next_addr = self.write_addr + AQL_PACKET_SIZE * cnt
                if next_addr > self.write_addr_end:
                  next_addr = self.queue_base + (next_addr - self.queue_base) % self.queue_size
                self.write_addr = next_addr
         
     | 
| 
      
 113 
     | 
    
         
            +
             
     | 
| 
      
 114 
     | 
    
         
            +
            def scan_agents():
              """Enumerate HSA agents grouped by device type.

              Returns:
                A ``collections.defaultdict(list)`` mapping each device-type value to the agents that
                reported it, in iteration order. Agents whose device-type query returns a non-zero
                status are left out.
              """
              by_type = collections.defaultdict(list)

              def _visit(agent, data):
                kind = hsa.hsa_device_type_t()
                if hsa.hsa_agent_get_info(agent, hsa.HSA_AGENT_INFO_DEVICE, ctypes.byref(kind)) == 0:
                  by_type[kind.value].append(agent)
                # success is reported regardless of the query result
                return hsa.HSA_STATUS_SUCCESS

              callback = ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p)(_visit)
              hsa.hsa_iterate_agents(callback, None)
              return by_type
         
     | 
| 
      
 125 
     | 
    
         
            +
             
     | 
| 
      
 126 
     | 
    
         
            +
            def find_memory_pool(agent, segtyp=-1, location=-1):
              """Return the first memory pool of ``agent`` that passes the given filters.

              Args:
                agent: HSA agent whose memory pools are iterated.
                segtyp: required segment value; a negative value accepts any segment.
                location: required location value; a negative value accepts any location.

              Returns:
                An ``hsa_amd_memory_pool_t``. It is the default-constructed value when no pool matched.
                Pools that report a size of 0 are never selected.
              """
              def _query(pool, attribute, out):
                # fetch one pool attribute into ``out`` and hand back its Python value
                check(hsa.hsa_amd_memory_pool_get_info(pool, attribute, ctypes.byref(out)))
                return out.value

              def _pick(mem_pool, data):
                segment = _query(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SEGMENT, hsa.hsa_amd_segment_t())
                if segtyp >= 0 and segment != segtyp: return hsa.HSA_STATUS_SUCCESS

                loc = _query(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_LOCATION, hsa.hsa_amd_memory_pool_location_t())
                if location >= 0 and loc != location: return hsa.HSA_STATUS_SUCCESS

                if _query(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SIZE, ctypes.c_size_t()) == 0:
                  return hsa.HSA_STATUS_SUCCESS

                # ``data`` is the address of the caller's result struct; store the match there
                ctypes.cast(data, ctypes.POINTER(hsa.hsa_amd_memory_pool_t))[0] = mem_pool
                return hsa.HSA_STATUS_INFO_BREAK

              callback = ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_amd_memory_pool_t, ctypes.c_void_p)(_pick)
              found = hsa.hsa_amd_memory_pool_t()
              hsa.hsa_amd_agent_iterate_memory_pools(agent, callback, ctypes.byref(found))
              return found
         
     | 
| 
         @@ -0,0 +1,38 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from typing import List, Dict, cast
         
     | 
| 
      
 2 
     | 
    
         
            +
            import ctypes
         
     | 
| 
      
 3 
     | 
    
         
            +
            from tinygrad.helpers import dedup, cpu_time_execution, GraphException, DEBUG
         
     | 
| 
      
 4 
     | 
    
         
            +
            from tinygrad.engine.jit import GraphRunner
         
     | 
| 
      
 5 
     | 
    
         
            +
            from tinygrad.device import Buffer, Device
         
     | 
| 
      
 6 
     | 
    
         
            +
            from tinygrad.engine.realize import ExecItem, CompiledRunner
         
     | 
| 
      
 7 
     | 
    
         
            +
            from tinygrad.shape.symbolic import Variable
         
     | 
| 
      
 8 
     | 
    
         
            +
            from tinygrad.runtime.ops_clang import ClangProgram
         
     | 
| 
      
 9 
     | 
    
         
            +
            from tinygrad.renderer.cstyle import ClangRenderer
         
     | 
| 
      
 10 
     | 
    
         
            +
            render_dtype = ClangRenderer().render_dtype
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
            class ClangGraph(GraphRunner):
              """Runs a whole JIT cache as a single compiled C function.

              A `batched` function that calls every kernel of the JIT cache in order is generated, appended to the
              (deduplicated) kernel sources, compiled with the CLANG device's compiler and wrapped in a ClangProgram.
              """
              def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
                """Generate and compile the `batched` entry point.

                Raises GraphException if any jit item's prg is not a CompiledRunner.
                """
                super().__init__(jit_cache, input_rawbuffers, var_vals)
                if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

                # The int parameters of batched() are named after each variable's expr and are emitted in this order.
                # Remember the order so __call__ can pass the values in signature order regardless of how its dict is ordered.
                self.batched_var_exprs = [v.expr for v in var_vals]

                prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
                # Signature: one typed pointer per input buffer (arg0, arg1, ...) followed by one int per variable.
                args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
                args += [f"int {expr}" for expr in self.batched_var_exprs]
                code = ["void batched("+','.join(args)+") {"]
                for ji in jit_cache:
                  args = []
                  for buf in ji.bufs:
                    assert buf is not None
                    if buf in input_rawbuffers:
                      # Input buffers are forwarded through the matching batched() parameter.
                      args.append(f"arg{input_rawbuffers.index(buf)}")
                    else:
                      # Every other buffer is baked into the source as the literal address of its backing storage.
                      args.append(f"({render_dtype(buf.dtype)}*)0x{ctypes.addressof(buf._buf):X}")
                  # Kernel variables are forwarded by name; they resolve to the int parameters of batched().
                  args += [x.expr for x in cast(CompiledRunner, ji.prg).p.vars]
                  code.append(f"  {cast(CompiledRunner, ji.prg).p.function_name}({','.join(args)});")
                code.append("}")
                if DEBUG >= 4: print("\n".join(code))
                compiler = Device["CLANG"].compiler
                assert compiler is not None
                self.clprg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers

              def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False):
                """Run the compiled `batched` function on rawbufs and the current variable values.

                Values are looked up by variable expr and passed in the order used for the generated signature.
                Raises KeyError if var_vals lacks a variable that was present at construction.
                Returns the result of cpu_time_execution (called with enable=wait).
                """
                vals_by_expr = {v.expr: val for v,val in var_vals.items()}
                var_args = [vals_by_expr[expr] for expr in self.batched_var_exprs]
                return cpu_time_execution(lambda: self.clprg(*[x._buf for x in rawbufs], *var_args), enable=wait)
         
     | 
| 
         @@ -0,0 +1,81 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import ctypes
         
     | 
| 
      
 2 
     | 
    
         
            +
            from typing import Any, Optional, Tuple, Dict, List, cast
         
     | 
| 
      
 3 
     | 
    
         
            +
            import tinygrad.runtime.autogen.cuda as cuda
         
     | 
| 
      
 4 
     | 
    
         
            +
            from tinygrad.helpers import init_c_var, GraphException
         
     | 
| 
      
 5 
     | 
    
         
            +
            from tinygrad.device import Buffer, Device
         
     | 
| 
      
 6 
     | 
    
         
            +
            from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
         
     | 
| 
      
 7 
     | 
    
         
            +
            from tinygrad.shape.symbolic import Variable
         
     | 
| 
      
 8 
     | 
    
         
            +
            from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
         
     | 
| 
      
 9 
     | 
    
         
            +
            from tinygrad.engine.jit import MultiGraphRunner
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
      
 11 
     | 
    
         
            +
            class CUDAGraph(MultiGraphRunner):
              """Executes a JIT cache as a CUDA graph made of kernel nodes and device-to-device memcpy nodes.

              Nodes whose buffers, variable values or launch dims may change between calls are kept in
              `updatable_nodes`, so `__call__` can patch their parameters before launching the instantiated graph.
              """
              def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
                """Build and instantiate the CUDA graph.

                Raises GraphException if any jit item's prg is neither a CompiledRunner nor a BufferXfer.
                """
                super().__init__(jit_cache, input_rawbuffers, var_vals)

                # Check all jit items are compatible.
                if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException

                # input_replace is keyed by (jit item index, buffer index); collect the distinct jit item indices.
                self.jc_idx_with_updatable_rawbufs = list({j for j,_ in self.input_replace})
                self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)

                self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))

                for j,ji in enumerate(self.jit_cache):
                  if isinstance(ji.prg, CompiledRunner):
                    global_size, local_size = ji.prg.p.launch_dims(var_vals)

                    new_node = cuda.CUgraphNode()
                    # The first `outcount` buffers are passed as writes, the remaining ones as reads.
                    deps = self._access_resources([x.base for x in ji.bufs[ji.prg.p.outcount:] if x is not None],
                                                  [x.base for x in ji.bufs[:ji.prg.p.outcount] if x is not None], new_dependency=new_node)
                    c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None

                    c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars])
                    kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
                    check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))

                    # Keep the node and its parameter structs only if something about it can change between calls.
                    if j in self.jc_idx_with_updatable_launch_dims or j in self.jc_idx_with_updatable_var_vals or j in self.jc_idx_with_updatable_rawbufs:
                      self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
                  elif isinstance(ji.prg, BufferXfer):
                    dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
                    src_dev = cast(CUDADevice, Device[src.device])
                    node_from = cuda.CUgraphNode()
                    deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from)
                    c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
                    # The copy is described as a single row of dest.nbytes bytes (Height=1, Depth=1).
                    cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
                                                      dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
                                                      WidthInBytes=dest.nbytes, Height=1, Depth=1)
                    check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
                    # For memcpy nodes the third tuple slot holds the source device context instead of kernel args.
                    if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)

                self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))

              def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
                """Patch the updatable nodes with the current buffers, variable values and launch dims, then launch the graph.

                Returns the result of cu_time_execution (called with enable=wait).
                """
                # Update rawbuffers in the c_args struct.
                for (j,i),input_idx in self.input_replace.items():
                  node_params, c_args, is_copy = self.updatable_nodes[j][1:]
                  new_buf = input_rawbuffers[input_idx]._buf
                  if not is_copy: setattr(c_args, f'f{i}', new_buf)
                  # Memcpy nodes: buffer 0 is the destination and buffer 1 the source. The field names must match the
                  # struct exactly (dstDevice/srcDevice); an unknown attribute name would be stored on the Python object
                  # only and never reach the CUDA_MEMCPY3D parameters.
                  elif i == 0: node_params.dstDevice = new_buf
                  elif i == 1: node_params.srcDevice = new_buf

                # Update var_vals in the c_args struct.
                for j in self.jc_idx_with_updatable_var_vals:
                  for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
                    setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])

                # Update launch dims in the kern_params struct.
                for j in self.jc_idx_with_updatable_launch_dims:
                  self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))

                # Update graph nodes with the updated structs.
                for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
                  if not is_copy: check(cuda.cuGraphExecKernelNodeSetParams(self.instance, node, ctypes.byref(c_node_params)))
                  else: check(cuda.cuGraphExecMemcpyNodeSetParams(self.instance, node, ctypes.byref(c_node_params), c_args))

                return cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)

              def __del__(self):
                """Destroy the CUDA graph and its executable instance, if they were created."""
                # hasattr guards cover the case where __init__ raised before the attributes were assigned.
                if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
                if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))

              def set_kernel_node_launch_dims(self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
                """Write local_size into node.blockDim{X,Y,Z} and global_size into node.gridDim{X,Y,Z}."""
                node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size
         
     | 
| 
         @@ -0,0 +1,143 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import ctypes, collections, array, time
         
     | 
| 
      
 2 
     | 
    
         
            +
            from typing import List, Any, Dict, cast, Optional, Tuple, Set
         
     | 
| 
      
 3 
     | 
    
         
            +
            from tinygrad.helpers import GraphException, round_up, to_mv, init_c_struct_t
         
     | 
| 
      
 4 
     | 
    
         
            +
            from tinygrad.device import Buffer, BufferOptions, Compiled, Device
         
     | 
| 
      
 5 
     | 
    
         
            +
            from tinygrad.shape.symbolic import Variable
         
     | 
| 
      
 6 
     | 
    
         
            +
            from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
         
     | 
| 
      
 7 
     | 
    
         
            +
            from tinygrad.engine.jit import MultiGraphRunner
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            class HCQGraph(MultiGraphRunner):
         
     | 
| 
      
 10 
     | 
    
         
            +
              def __init__(self, device_t, comp_hcq_t, copy_hcq_t, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
         
     | 
| 
      
 11 
     | 
    
         
            +
                super().__init__(jit_cache, input_rawbuffers, var_vals)
         
     | 
| 
      
 12 
     | 
    
         
            +
                self.device_t, self.comp_hcq_t, self.copy_hcq_t = device_t, comp_hcq_t, copy_hcq_t
         
     | 
| 
      
 13 
     | 
    
         
            +
             
     | 
| 
      
 14 
     | 
    
         
            +
                # Check all jit items are compatible.
         
     | 
| 
      
 15 
     | 
    
         
            +
                self.devices = list(set(cast(self.device_t, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs])) #type: ignore
         
     | 
| 
      
 16 
     | 
    
         
            +
                if any(not isinstance(d, self.device_t) for d in self.devices): raise GraphException
         
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
                # Allocate kernel args.
         
     | 
| 
      
 19 
     | 
    
         
            +
                kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
         
     | 
| 
      
 20 
     | 
    
         
            +
                for ji in self.jit_cache:
         
     | 
| 
      
 21 
     | 
    
         
            +
                  if not isinstance(ji.prg, CompiledRunner): continue
         
     | 
| 
      
 22 
     | 
    
         
            +
                  kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)
         
     | 
| 
      
 23 
     | 
    
         
            +
                kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)).va_addr for dev,sz in kernargs_size.items()}
         
     | 
| 
      
 24 
     | 
    
         
            +
             
     | 
| 
      
 25 
     | 
    
         
            +
                # Fill initial arguments.
         
     | 
| 
      
 26 
     | 
    
         
            +
                self.kargs_addrs: Dict[int, int] = {}
         
     | 
| 
      
 27 
     | 
    
         
            +
                self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
         
     | 
| 
      
 28 
     | 
    
         
            +
                for j,ji in enumerate(self.jit_cache):
         
     | 
| 
      
 29 
     | 
    
         
            +
                  if not isinstance(ji.prg, CompiledRunner): continue
         
     | 
| 
      
 30 
     | 
    
         
            +
                  self.kargs_addrs[j] = kernargs_ptrs[ji.prg.device]
         
     | 
| 
      
 31 
     | 
    
         
            +
                  kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)
         
     | 
| 
      
 32 
     | 
    
         
            +
             
     | 
| 
      
 33 
     | 
    
         
            +
                  args_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(ji.bufs))] +
         
     | 
| 
      
 34 
     | 
    
         
            +
                                                 [(f'v{i}', ctypes.c_int) for i in range(len(ji.prg.p.vars))]))
         
     | 
| 
      
 35 
     | 
    
         
            +
                  self.ji_kargs_structs[j] = args_t.from_address(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset)
         
     | 
| 
      
 36 
     | 
    
         
            +
                  for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf.va_addr)
         
     | 
| 
      
 37 
     | 
    
         
            +
                  for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
         
     | 
| 
      
 38 
     | 
    
         
            +
             
     | 
| 
      
 39 
     | 
    
         
            +
                  # NV needs constbuffer to be set
         
     | 
| 
      
 40 
     | 
    
         
            +
                  if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)
         
     | 
| 
      
 41 
     | 
    
         
            +
             
     | 
| 
      
 42 
     | 
    
         
            +
                # Build queues.
         
     | 
| 
      
 43 
     | 
    
         
            +
                self.comp_queues: Dict[Compiled, Any] = collections.defaultdict(self.comp_hcq_t)
         
     | 
| 
      
 44 
     | 
    
         
            +
                self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
         
     | 
| 
      
 45 
     | 
    
         
            +
                self.comp_signal_val = {dev: 0 for dev in self.devices}
         
     | 
| 
      
 46 
     | 
    
         
            +
             
     | 
| 
      
 47 
     | 
    
         
            +
                self.copy_queues: Dict[Compiled, Any] = collections.defaultdict(self.copy_hcq_t)
         
     | 
| 
      
 48 
     | 
    
         
            +
                self.copy_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
         
     | 
| 
      
 49 
     | 
    
         
            +
                self.copy_signal_val = {dev: 0 for dev in self.devices}
         
     | 
| 
      
 50 
     | 
    
         
            +
             
     | 
| 
      
 51 
     | 
    
         
            +
                self.kickoff_signal = self.devices[0]._get_signal(value=0)
         
     | 
| 
      
 52 
     | 
    
         
            +
                self.kickoff_value = 0
         
     | 
| 
      
 53 
     | 
    
         
            +
                self.graph_timeline = {dev: 0 for dev in self.devices}
         
     | 
| 
      
 54 
     | 
    
         
            +
             
     | 
| 
      
 55 
     | 
    
         
            +
                self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
         
     | 
| 
      
 56 
     | 
    
         
            +
                self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
         
     | 
| 
      
 57 
     | 
    
         
            +
             
     | 
| 
      
 58 
     | 
    
         
            +
                for j,ji in enumerate(self.jit_cache):
         
     | 
| 
      
 59 
     | 
    
         
            +
                  if isinstance(ji.prg, CompiledRunner):
         
     | 
| 
      
 60 
     | 
    
         
            +
                    exec_params = {}
         
     | 
| 
      
 61 
     | 
    
         
            +
                    deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1))
         
     | 
| 
      
 62 
     | 
    
         
            +
                    deps = [x for x in deps if id(x[0]) != id(self.comp_signal[ji.prg.device])]
         
     | 
| 
      
 63 
     | 
    
         
            +
             
     | 
| 
      
 64 
     | 
    
         
            +
                    # On NV, to synchronize kernel execution, we must either issue a wait or chain executions to schedule them in order.
         
     | 
| 
      
 65 
     | 
    
         
            +
                    # Chaining executions is preferred when possible, as it is faster.
         
     | 
| 
      
 66 
     | 
    
         
            +
                    if ji.prg.device.dname.startswith("NV"):
         
     | 
| 
      
 67 
     | 
    
         
            +
                      if len(deps) == 0 and self.comp_signal_val[ji.prg.device] > 0:
         
     | 
| 
      
 68 
     | 
    
         
            +
                        exec_params['chain_exec_ptr'] = self.exec_ptrs[self.comp_signal_val[ji.prg.device] - 1][1]
         
     | 
| 
      
 69 
     | 
    
         
            +
                      else: deps.append((self.comp_signal[ji.prg.device], self.comp_signal_val[ji.prg.device]))
         
     | 
| 
      
 70 
     | 
    
         
            +
             
     | 
| 
      
 71 
     | 
    
         
            +
                    for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
         
     | 
| 
      
 72 
     | 
    
         
            +
             
     | 
| 
      
 73 
     | 
    
         
            +
                    self.exec_ptrs[j] = (self.comp_queues[ji.prg.device], self.comp_queues[ji.prg.device].ptr())
         
     | 
| 
      
 74 
     | 
    
         
            +
                    self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
         
     | 
| 
      
 75 
     | 
    
         
            +
                                                         signal=self.comp_signal[ji.prg.device], signal_value=sig_val, **exec_params)
         
     | 
| 
      
 76 
     | 
    
         
            +
                    self.comp_signal_val[ji.prg.device] = sig_val
         
     | 
| 
      
 77 
     | 
    
         
            +
                  elif isinstance(ji.prg, BufferXfer):
         
     | 
| 
      
 78 
     | 
    
         
            +
                    dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
         
     | 
| 
      
 79 
     | 
    
         
            +
                    Device[src.device]._gpu_map(dest._buf) #type: ignore
         
     | 
| 
      
 80 
     | 
    
         
            +
             
     | 
| 
      
 81 
     | 
    
         
            +
                    deps = self.access_resources([src], [dest], (self.copy_signal[Device[src.device]], sig_val:=j+1))
         
     | 
| 
      
 82 
     | 
    
         
            +
                    deps.append((self.copy_signal[Device[src.device]], self.copy_signal_val[Device[src.device]]))
         
     | 
| 
      
 83 
     | 
    
         
            +
                    self.copy_signal_val[Device[src.device]] = sig_val
         
     | 
| 
      
 84 
     | 
    
         
            +
             
     | 
| 
      
 85 
     | 
    
         
            +
                    for sig,val in deps: self.copy_queues[Device[src.device]].wait(sig, val)
         
     | 
| 
      
 86 
     | 
    
         
            +
                    self.copy_queues[Device[src.device]].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) \
         
     | 
| 
      
 87 
     | 
    
         
            +
                                                        .signal(self.copy_signal[Device[src.device]], sig_val)
         
     | 
| 
      
 88 
     | 
    
         
            +
                    self.copy_to_devs[Device[dest.device]].add(Device[src.device])
         
     | 
| 
      
 89 
     | 
    
         
            +
             
     | 
| 
      
 90 
     | 
    
         
            +
                for dev in self.devices:
         
     | 
| 
      
 91 
     | 
    
         
            +
                  if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
         
     | 
| 
      
 92 
     | 
    
         
            +
                  for dep_dev in self.copy_to_devs[dev]: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])
         
     | 
| 
      
 93 
     | 
    
         
            +
             
     | 
| 
      
 94 
     | 
    
         
            +
                  if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
         
     | 
| 
      
 95 
     | 
    
         
            +
                  if hasattr(self.copy_queues[dev], 'bind') and self.copy_signal_val[dev] > 0: self.copy_queues[dev].bind(dev)
         
     | 
| 
      
 96 
     | 
    
         
            +
             
     | 
| 
      
 97 
     | 
    
         
            +
              def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
         
     | 
| 
      
 98 
     | 
    
         
            +
                # Wait and restore signals
         
     | 
| 
      
 99 
     | 
    
         
            +
                self.kickoff_value += 1
         
     | 
| 
      
 100 
     | 
    
         
            +
                for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
         
     | 
| 
      
 101 
     | 
    
         
            +
                for dev in self.devices:
         
     | 
| 
      
 102 
     | 
    
         
            +
                  dev._set_signal(self.comp_signal[dev], 0)
         
     | 
| 
      
 103 
     | 
    
         
            +
                  dev._set_signal(self.copy_signal[dev], 0)
         
     | 
| 
      
 104 
     | 
    
         
            +
                dev._set_signal(self.kickoff_signal, self.kickoff_value)
         
     | 
| 
      
 105 
     | 
    
         
            +
             
     | 
| 
      
 106 
     | 
    
         
            +
                # Update rawbuffers
         
     | 
| 
      
 107 
     | 
    
         
            +
                for (j,i),input_idx in self.input_replace.items():
         
     | 
| 
      
 108 
     | 
    
         
            +
                  self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf.va_addr)
         
     | 
| 
      
 109 
     | 
    
         
            +
             
     | 
| 
      
 110 
     | 
    
         
            +
                # Update var_vals
         
     | 
| 
      
 111 
     | 
    
         
            +
                for j in self.jc_idx_with_updatable_var_vals:
         
     | 
| 
      
 112 
     | 
    
         
            +
                  for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
         
     | 
| 
      
 113 
     | 
    
         
            +
                    self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
         
     | 
| 
      
 114 
     | 
    
         
            +
             
     | 
| 
      
 115 
     | 
    
         
            +
                for j in self.jc_idx_with_updatable_launch_dims:
         
     | 
| 
      
 116 
     | 
    
         
            +
                  queue, cmd_ptr = self.exec_ptrs[j]
         
     | 
| 
      
 117 
     | 
    
         
            +
                  queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))
         
     | 
| 
      
 118 
     | 
    
         
            +
             
     | 
| 
      
 119 
     | 
    
         
            +
                for dev in self.devices:
         
     | 
| 
      
 120 
     | 
    
         
            +
                  # Submit sync with world and queues.
         
     | 
| 
      
 121 
     | 
    
         
            +
                  self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
         
     | 
| 
      
 122 
     | 
    
         
            +
                                   .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
         
     | 
| 
      
 123 
     | 
    
         
            +
                  self.comp_queues[dev].submit(dev)
         
     | 
| 
      
 124 
     | 
    
         
            +
             
     | 
| 
      
 125 
     | 
    
         
            +
                  if self.copy_signal_val[dev] > 0:
         
     | 
| 
      
 126 
     | 
    
         
            +
                    self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
         
     | 
| 
      
 127 
     | 
    
         
            +
                                     .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
         
     | 
| 
      
 128 
     | 
    
         
            +
                    self.copy_queues[dev].submit(dev)
         
     | 
| 
      
 129 
     | 
    
         
            +
             
     | 
| 
      
 130 
     | 
    
         
            +
                  # Signal the final value
         
     | 
| 
      
 131 
     | 
    
         
            +
                  self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev)
         
     | 
| 
      
 132 
     | 
    
         
            +
                  self.graph_timeline[dev] = dev.timeline_value
         
     | 
| 
      
 133 
     | 
    
         
            +
                  dev.timeline_value += 1
         
     | 
| 
      
 134 
     | 
    
         
            +
             
     | 
| 
      
 135 
     | 
    
         
            +
                if wait:
         
     | 
| 
      
 136 
     | 
    
         
            +
                  st = time.perf_counter()
         
     | 
| 
      
 137 
     | 
    
         
            +
                  for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
         
     | 
| 
      
 138 
     | 
    
         
            +
                  return time.perf_counter() - st
         
     | 
| 
      
 139 
     | 
    
         
            +
                return None
         
     | 
| 
      
 140 
     | 
    
         
            +
             
     | 
| 
      
 141 
     | 
    
         
            +
              def access_resources(self, read, write, new_dependency):
                """Record a resource access and return the dependencies to wait on, one entry per signal.

                Delegates to ``self._access_resources(read, write, new_dependency)`` and collapses
                its result so that each distinct first element (compared by object identity)
                appears exactly once, paired with the largest second element seen for it.
                Entries keep the order in which each signal first appeared.
                """
                # presumably each entry is a (signal, timeline value) pair — confirm against _access_resources
                deps = self._access_resources(read, write, new_dependency)

                # One pass keyed by identity; the original rescanned `deps` once per unique signal.
                # Strict `>` keeps the first maximal value, matching what max() returns.
                best = {}
                for sig, val in deps:
                  key = id(sig)
                  if key not in best or val > best[key][1]: best[key] = (sig, val)
                return list(best.values())
         
     | 
| 
         @@ -0,0 +1,171 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import ctypes, collections, time, itertools
         
     | 
| 
      
 2 
     | 
    
         
            +
            from typing import List, Any, Dict, cast, Optional, Tuple
         
     | 
| 
      
 3 
     | 
    
         
            +
            from tinygrad.helpers import GraphException, init_c_var, round_up
         
     | 
| 
      
 4 
     | 
    
         
            +
            from tinygrad.device import Buffer, BufferOptions
         
     | 
| 
      
 5 
     | 
    
         
            +
            from tinygrad.device import Compiled, Device
         
     | 
| 
      
 6 
     | 
    
         
            +
            from tinygrad.shape.symbolic import Variable
         
     | 
| 
      
 7 
     | 
    
         
            +
            from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
         
     | 
| 
      
 8 
     | 
    
         
            +
            from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
         
     | 
| 
      
 9 
     | 
    
         
            +
            from tinygrad.engine.jit import MultiGraphRunner
         
     | 
| 
      
 10 
     | 
    
         
            +
            import tinygrad.runtime.autogen.hsa as hsa
         
     | 
| 
      
 11 
     | 
    
         
            +
            from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
         
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
            def dedup_signals(signals):
              """Return one fresh ``hsa.hsa_signal_t`` per distinct handle found in ``signals``.

              Entries that are not ``hsa.hsa_signal_t`` instances (for example ``None``) are
              skipped. The result order follows set iteration, not the input order.
              """
              unique_handles = {sig.handle for sig in signals if isinstance(sig, hsa.hsa_signal_t)}
              return [hsa.hsa_signal_t(handle) for handle in unique_handles]
         
     | 
| 
      
 14 
     | 
    
         
            +
             
     | 
| 
      
 15 
     | 
    
         
            +
            class VirtAQLQueue(AQLQueue):
              """AQL queue whose packets are recorded into a host-side ctypes array.

              The array holds ``sz`` ``hsa_kernel_dispatch_packet_t`` slots. ``queue_base`` is the
              address of the first slot and ``write_addr`` the address of the next free one.
              """
              def __init__(self, device, sz):
                # NOTE(review): AQLQueue.__init__ is not called here; presumably it would set up a
                # real hardware queue, which this recording-only queue does not need — confirm.
                self.device = device
                self.virt_queue = (hsa.hsa_kernel_dispatch_packet_t * sz)()
                self.queue_base = self.write_addr = ctypes.addressof(self.virt_queue)
                self.packets_count = 0
                self.available_packet_slots = sz

              def _wait_queue(self, need_packets=1):
                """Fail loudly: this fixed-size queue cannot make room for more packets.

                Raises:
                  AssertionError: always. Raised explicitly rather than via ``assert`` so the
                    guard is not stripped under ``python -O``, where it would return silently.
                """
                raise AssertionError(f"VirtQueue is too small to handle {self.packets_count+need_packets} packets!")

              def _submit_packet(self):
                """Account for one recorded packet: advance the write address and the counters."""
                self.write_addr += AQL_PACKET_SIZE
                self.packets_count += 1
                self.available_packet_slots -= 1
         
     | 
| 
      
 27 
     | 
    
         
            +
             
     | 
| 
      
 28 
     | 
    
         
            +
            class HSAGraph(MultiGraphRunner):
              """Replays a JIT cache of kernels and buffer transfers on HSA devices.

              At construction, kernel dispatch and barrier packets are recorded into one
              VirtAQLQueue per device, and every BufferXfer is stored as an argument list for
              ``hsa_amd_memory_async_copy_on_engine``. Each call patches buffer pointers,
              variable values and launch dimensions in place, blits the recorded packets into
              each device's hardware queue and issues the copies.
              """
              def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
                """Record the whole graph.

                Raises:
                  GraphException: if a jit item is neither a CompiledRunner nor a BufferXfer,
                    or if any device involved is not an HSADevice.
                """
                super().__init__(jit_cache, input_rawbuffers, var_vals)

                # Check all jit items are compatible.
                compiled_devices = set()
                for ji in self.jit_cache:
                  if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
                  elif isinstance(ji.prg, BufferXfer):
                    # A transfer involves the devices of its first two buffers (dest, src).
                    for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
                  else: raise GraphException
                if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException

                self.devices: List[HSADevice] = list(compiled_devices) #type:ignore

                # Allocate kernel args.
                # One allocation per device, sized as the sum of every kernel's args struct
                # rounded up to 16 bytes.
                kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
                for ji in self.jit_cache:
                  if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
                kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}

                # Fill initial arguments.
                # Each kernel's args struct is a view at the device's current kernarg offset;
                # the offset is then bumped by the same 16-byte-rounded size used above.
                self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
                for j,ji in enumerate(self.jit_cache):
                  if not isinstance(ji.prg, CompiledRunner): continue
                  self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
                  kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
                  # Buffer pointers go into fields f0..fN, variable values into v0..vM.
                  for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
                  for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])

                # Build queues.
                # NOTE(review): the 2*len+16 capacity presumably covers a barrier plus a dispatch
                # per item with some slack; VirtAQLQueue fails if it is exceeded — confirm sizing.
                self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
                self.packets = {} # jit index -> view of that kernel's dispatch packet inside the virtual queue
                self.transfers = [] # argument lists later splatted into hsa_amd_memory_async_copy_on_engine
                self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table.
                self.signals_to_reset: List[hsa.hsa_signal_t] = [] # signals re-armed at the start of every call
                self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {} # signal handle -> devices that wait on it at graph end
                self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)

                # Special packet to wait for the world.
                # Each device's queue starts with a dependency-free barrier tied to its kickoff signal.
                self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices}
                for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])

                for j,ji in enumerate(self.jit_cache):
                  if isinstance(ji.prg, CompiledRunner):
                    # Buffers after the first `outcount` are read; the first `outcount` are written.
                    wait_signals = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], new_dependency=j, sync_with_aql_packets=False)
                    # Barriers are submitted with at most 5 signals each; presumably the per-packet
                    # dependency limit of an AQL barrier packet — confirm against submit_barrier.
                    for i in range(0, len(wait_signals), 5):
                      self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
                    # View taken at the current write address before submit_kernel runs; presumably
                    # submit_kernel writes the dispatch packet there, so this aliases it for patching.
                    self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)

                    # Kernels only get a completion signal up front when profiling is enabled.
                    sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
                    self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
                                                                      ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
                    if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
                  elif isinstance(ji.prg, BufferXfer):
                    dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
                    dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
                    # The transfer's completion signal is registered as waited-on by both devices.
                    sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])

                    wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
                    # Positional arguments for hsa_amd_memory_async_copy_on_engine; indices 0 (dest
                    # pointer) and 2 (src pointer) are patched in __call__ when inputs change.
                    self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
                                          (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True])
                    self.ji_to_transfer[j] = len(self.transfers) - 1
                    if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))

                # Wait for all active signals to finish the graph
                # Only hsa_signal_t entries in the dependency maps survive dedup_signals; each is
                # attached to the devices recorded for it in signals_to_devices.
                wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
                for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
                  for dev in self.signals_to_devices[v.handle]:
                    wait_signals_to_finish[dev].append(v)

                self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
                for dev in self.devices:
                  wait_signals = wait_signals_to_finish[dev]
                  # max(1, ...) guarantees at least one barrier per device even with nothing to wait
                  # on; only the last chunk carries finish_signal as its completion signal.
                  for i in range(0, max(1, len(wait_signals)), 5):
                    self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None)

                # Zero signals to allow graph to start and execute.
                # With finish_signal at 0 the wait at the top of the first __call__ returns immediately.
                for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
                hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)

              def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
                """Launch the recorded graph with the given inputs and variable values.

                Returns the wall-clock seconds spent waiting for ``finish_signal`` when ``wait``
                is true, otherwise ``None``.
                """
                # Wait and restore signals
                # Block until finish_signal is below 1, then re-arm: every tracked signal back to 1
                # and finish_signal to the device count (one final barrier per device completes on it).
                hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
                for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
                hsa.hsa_signal_silent_store_relaxed(self.finish_signal, len(self.devices))

                # Update rawbuffers
                for (j,i),input_idx in self.input_replace.items():
                  if j in self.ji_kargs_structs:
                    # Kernel item: patch the buffer pointer field in its args struct.
                    self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
                  else:
                    # Transfer item: patch the stored copy arguments; other buffer indices are ignored.
                    if i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest
                    elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src

                # Update var_vals
                for j in self.jc_idx_with_updatable_var_vals:
                  for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
                    self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])

                # Update launch dims
                # Written straight into the recorded dispatch packets; grid size is global * local per axis.
                for j in self.jc_idx_with_updatable_launch_dims:
                  gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
                  self.packets[j].workgroup_size_x = lc[0]
                  self.packets[j].workgroup_size_y = lc[1]
                  self.packets[j].workgroup_size_z = lc[2]
                  self.packets[j].grid_size_x = gl[0] * lc[0]
                  self.packets[j].grid_size_y = gl[1] * lc[1]
                  self.packets[j].grid_size_z = gl[2] * lc[2]

                # Copy each device's recorded packets into its hardware queue.
                for dev in self.devices:
                  dev.flush_hdp()
                  dev.hw_queue.blit_packets(self.virt_aql_queues[dev].queue_base, self.virt_aql_queues[dev].packets_count)

                # Issue every recorded copy with its (possibly patched) argument list.
                for transfer_data in self.transfers:
                  check(hsa.hsa_amd_memory_async_copy_on_engine(*transfer_data))

                et = None
                if wait:
                  st = time.perf_counter()
                  hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
                  et = time.perf_counter() - st

                # profile_info is empty unless PROFILE was set while the graph was built.
                for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
                return et

              def alloc_signal(self, reset_on_start=False, wait_on=None):
                """Create an HSA signal and optionally register it with the graph's bookkeeping.

                Args:
                  reset_on_start: if true, add the signal to ``signals_to_reset``.
                  wait_on: if not None, stored in ``signals_to_devices`` under the signal's handle.
                """
                sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
                if reset_on_start: self.signals_to_reset.append(sync_signal)
                if wait_on is not None: self.signals_to_devices[sync_signal.handle] = wait_on
                return sync_signal

              def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
                """Translate a dependency into a signal, or ``None`` when no signal is needed.

                A dependency that is already an ``hsa_signal_t`` is returned as is. Otherwise, when
                ``sync_with_aql_packets`` is set and ``dep`` keys a recorded dispatch packet, that
                packet's completion signal is returned, allocating one first if the packet still
                has the empty signal.
                """
                if isinstance(dep, hsa.hsa_signal_t): return dep
                elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
                  if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True)
                  return packet.completion_signal
                return None

              def access_resources(self, read, write, new_dependency, sync_with_aql_packets=False):
                """Record a resource access and return the deduplicated signals to wait on.

                When ``sync_with_aql_packets`` is set, the kickoff signal of the device of every
                buffer in ``read`` and ``write`` is added to the result.
                """
                rdeps = self._access_resources(read, write, new_dependency)
                # May contain None entries; dedup_signals keeps only hsa_signal_t instances.
                wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps]
                if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in read+write]
                return dedup_signals(wait_signals)
         
     |