tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
 - tinygrad/codegen/kernel.py +572 -83
 - tinygrad/codegen/linearizer.py +415 -395
 - tinygrad/codegen/uops.py +415 -0
 - tinygrad/device.py +183 -0
 - tinygrad/dtype.py +113 -0
 - tinygrad/engine/__init__.py +0 -0
 - tinygrad/engine/graph.py +100 -0
 - tinygrad/engine/jit.py +195 -0
 - tinygrad/engine/realize.py +191 -0
 - tinygrad/engine/schedule.py +362 -0
 - tinygrad/engine/search.py +196 -0
 - tinygrad/{mlops.py → function.py} +76 -55
 - tinygrad/helpers.py +196 -89
 - tinygrad/lazy.py +210 -371
 - tinygrad/multi.py +169 -0
 - tinygrad/nn/__init__.py +202 -22
 - tinygrad/nn/datasets.py +7 -0
 - tinygrad/nn/optim.py +112 -32
 - tinygrad/nn/state.py +136 -39
 - tinygrad/ops.py +119 -202
 - tinygrad/renderer/__init__.py +61 -0
 - tinygrad/renderer/assembly.py +276 -0
 - tinygrad/renderer/cstyle.py +353 -166
 - tinygrad/renderer/llvmir.py +150 -138
 - tinygrad/runtime/autogen/amd_gpu.py +1900 -0
 - tinygrad/runtime/autogen/comgr.py +865 -0
 - tinygrad/runtime/autogen/cuda.py +5923 -0
 - tinygrad/runtime/autogen/hip.py +5909 -0
 - tinygrad/runtime/autogen/hsa.py +5761 -0
 - tinygrad/runtime/autogen/kfd.py +812 -0
 - tinygrad/runtime/autogen/nv_gpu.py +33328 -0
 - tinygrad/runtime/autogen/opencl.py +1795 -0
 - tinygrad/runtime/driver/hip_comgr.py +47 -0
 - tinygrad/runtime/driver/hsa.py +143 -0
 - tinygrad/runtime/graph/clang.py +38 -0
 - tinygrad/runtime/graph/cuda.py +81 -0
 - tinygrad/runtime/graph/hcq.py +143 -0
 - tinygrad/runtime/graph/hsa.py +171 -0
 - tinygrad/runtime/graph/metal.py +75 -0
 - tinygrad/runtime/ops_amd.py +564 -0
 - tinygrad/runtime/ops_clang.py +24 -77
 - tinygrad/runtime/ops_cuda.py +175 -89
 - tinygrad/runtime/ops_disk.py +56 -33
 - tinygrad/runtime/ops_gpu.py +92 -95
 - tinygrad/runtime/ops_hsa.py +278 -0
 - tinygrad/runtime/ops_llvm.py +39 -60
 - tinygrad/runtime/ops_metal.py +92 -74
 - tinygrad/runtime/ops_npy.py +9 -0
 - tinygrad/runtime/ops_nv.py +630 -0
 - tinygrad/runtime/ops_python.py +204 -0
 - tinygrad/shape/shapetracker.py +86 -254
 - tinygrad/shape/symbolic.py +166 -141
 - tinygrad/shape/view.py +296 -0
 - tinygrad/tensor.py +2619 -448
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
 - tinygrad-0.9.0.dist-info/METADATA +227 -0
 - tinygrad-0.9.0.dist-info/RECORD +60 -0
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
 - tinygrad/codegen/assembly.py +0 -190
 - tinygrad/codegen/optimizer.py +0 -379
 - tinygrad/codegen/search.py +0 -72
 - tinygrad/graph.py +0 -83
 - tinygrad/jit.py +0 -57
 - tinygrad/nn/image.py +0 -100
 - tinygrad/renderer/assembly_arm64.py +0 -169
 - tinygrad/renderer/assembly_ptx.py +0 -98
 - tinygrad/renderer/wgsl.py +0 -53
 - tinygrad/runtime/lib.py +0 -113
 - tinygrad/runtime/ops_cpu.py +0 -51
 - tinygrad/runtime/ops_hip.py +0 -82
 - tinygrad/runtime/ops_shm.py +0 -29
 - tinygrad/runtime/ops_torch.py +0 -30
 - tinygrad/runtime/ops_webgpu.py +0 -45
 - tinygrad-0.7.0.dist-info/METADATA +0 -212
 - tinygrad-0.7.0.dist-info/RECORD +0 -40
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
 
    
        tinygrad/runtime/ops_gpu.py
    CHANGED
    
    | 
         @@ -1,106 +1,103 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            from __future__ import annotations
         
     | 
| 
       2 
     | 
    
         
            -
            import  
     | 
| 
       3 
     | 
    
         
            -
            import  
     | 
| 
       4 
     | 
    
         
            -
            import  
     | 
| 
       5 
     | 
    
         
            -
            from  
     | 
| 
       6 
     | 
    
         
            -
            from tinygrad. 
     | 
| 
       7 
     | 
    
         
            -
            from tinygrad. 
     | 
| 
       8 
     | 
    
         
            -
            from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
         
     | 
| 
       9 
     | 
    
         
            -
            from tinygrad.codegen.linearizer import LinearizerOptions
         
     | 
| 
       10 
     | 
    
         
            -
            from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
         
     | 
| 
      
 2 
     | 
    
         
            +
            from typing import Tuple, Optional, List, cast
         
     | 
| 
      
 3 
     | 
    
         
            +
            import ctypes, functools, hashlib
         
     | 
| 
      
 4 
     | 
    
         
            +
            import tinygrad.runtime.autogen.opencl as cl
         
     | 
| 
      
 5 
     | 
    
         
            +
            from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG
         
     | 
| 
      
 6 
     | 
    
         
            +
            from tinygrad.renderer.cstyle import OpenCLRenderer
         
     | 
| 
      
 7 
     | 
    
         
            +
            from tinygrad.device import BufferOptions, LRUAllocator, Compiled, Compiler, CompileError
         
     | 
| 
       11 
8 
     | 
    
         | 
| 
       12 
     | 
    
         
            -
             
     | 
| 
      
 9 
     | 
    
         
            +
            # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
         
     | 
| 
      
 10 
     | 
    
         
            +
            OSX_TIMING_RATIO = (125/3) if OSX else 1.0
         
     | 
| 
       13 
11 
     | 
    
         | 
| 
       14 
     | 
    
         
            -
             
     | 
| 
       15 
     | 
    
         
            -
             
     | 
| 
       16 
     | 
    
         
            -
             
     | 
| 
       17 
     | 
    
         
            -
            if DEBUG >= 5:
         
     | 
| 
       18 
     | 
    
         
            -
              early_exec = fromimport("extra.helpers", "enable_early_exec")()
         
     | 
| 
      
 12 
     | 
    
         
            +
            def check(status):
         
     | 
| 
      
 13 
     | 
    
         
            +
              if status != 0: raise RuntimeError(f"OpenCL Error {status}")
         
     | 
| 
      
 14 
     | 
    
         
            +
            def checked(ret, status): return (check(status.value), ret)[1]
         
     | 
| 
       19 
15 
     | 
    
         | 
| 
       20 
     | 
    
         
            -
            class  
     | 
| 
       21 
     | 
    
         
            -
              def  
     | 
| 
       22 
     | 
    
         
            -
                 
     | 
| 
       23 
     | 
    
         
            -
             
     | 
| 
       24 
     | 
    
         
            -
             
     | 
| 
       25 
     | 
    
         
            -
             
     | 
| 
       26 
     | 
    
         
            -
             
     | 
| 
       27 
     | 
    
         
            -
                 
     | 
| 
       28 
     | 
    
         
            -
                   
     | 
| 
       29 
     | 
    
         
            -
             
     | 
| 
       30 
     | 
    
         
            -
             
     | 
| 
       31 
     | 
    
         
            -
             
     | 
| 
       32 
     | 
    
         
            -
             
     | 
| 
       33 
     | 
    
         
            -
             
     | 
| 
       34 
     | 
    
         
            -
                 
     | 
| 
       35 
     | 
    
         
            -
                platform_devices: List[List[cl.Device]] = [y for y in ([x.get_devices(device_type=cl.device_type.GPU) for x in cl_platforms] + [x.get_devices(device_type=cl.device_type.CPU) for x in cl_platforms]) if y]
         
     | 
| 
       36 
     | 
    
         
            -
                self.devices = [device for device in platform_devices[getenv('CL_PLATFORM', 0)] if device.name not in getenv('CL_EXCLUDE', "").split(",")]
         
     | 
| 
       37 
     | 
    
         
            -
                self.cl_platform = self.devices[0].platform
         
     | 
| 
       38 
     | 
    
         
            -
              def post_init(self, device=None):
         
     | 
| 
       39 
     | 
    
         
            -
                self.cl_ctxs: List[cl.Context] = [cl.Context(devices=[x]) for x in self.devices] if device is None else [cl.Context(devices=[self.devices[device]])]
         
     | 
| 
       40 
     | 
    
         
            -
                if DEBUG >= 1: print(f"using devices: {[ctx.devices[0].hashable_model_and_version_identifier for ctx in self.cl_ctxs]}")
         
     | 
| 
       41 
     | 
    
         
            -
                self.cl_queue: List[cl.CommandQueue] = [cl.CommandQueue(ctx, device=ctx.devices[0], properties=cl.command_queue_properties.PROFILING_ENABLE) for ctx in self.cl_ctxs]
         
     | 
| 
       42 
     | 
    
         
            -
                self.cl_allocator = CLAllocator(CL.cl_ctxs[0].devices[0].get_info(cl.device_info.GLOBAL_MEM_SIZE))
         
     | 
| 
       43 
     | 
    
         
            -
              def synchronize(self):
         
     | 
| 
       44 
     | 
    
         
            -
                for q in self.cl_queue: q.finish()
         
     | 
| 
       45 
     | 
    
         
            -
            CL = _CL()
         
     | 
| 
       46 
     | 
    
         
            -
            if not getenv("DELAYED_RUNTIME_INIT", False): CL.post_init()
         
     | 
| 
       47 
     | 
    
         
            -
             
     | 
| 
       48 
     | 
    
         
            -
            class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
         
     | 
| 
       49 
     | 
    
         
            -
              def __init__(self, size, dtype, device='0'): super().__init__(size, dtype, allocator=CL.cl_allocator, **{'device': device})
         
     | 
| 
       50 
     | 
    
         
            -
              def _copyin(self, x:np.ndarray):
         
     | 
| 
       51 
     | 
    
         
            -
                assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
         
     | 
| 
       52 
     | 
    
         
            -
                self.event = cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements='C'), is_blocking=False)
         
     | 
| 
       53 
     | 
    
         
            -
              def _copyout(self, x:np.ndarray):
         
     | 
| 
       54 
     | 
    
         
            -
                assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
         
     | 
| 
       55 
     | 
    
         
            -
                buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data)
         
     | 
| 
       56 
     | 
    
         
            -
                mapped, event = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np, is_blocking=False)
         
     | 
| 
       57 
     | 
    
         
            -
                with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([self.event] if hasattr(self, "event") else []))
         
     | 
| 
       58 
     | 
    
         
            -
              def _transfer(self, x):
         
     | 
| 
       59 
     | 
    
         
            -
                if "gfx" in CL.cl_ctxs[x._buf.device].devices[0].name:
         
     | 
| 
       60 
     | 
    
         
            -
                  cl.enqueue_copy_buffer_p2p_amd(CL.cl_platform, CL.cl_queue[x._buf.device], x._buf, self._buf, x.size * x.dtype.itemsize).wait()
         
     | 
| 
       61 
     | 
    
         
            -
                else: raise NotImplementedError("p2p transfer between devices not implemented on non-amd")
         
     | 
| 
      
 16 
     | 
    
         
            +
            class CLCompiler(Compiler):
         
     | 
| 
      
 17 
     | 
    
         
            +
              def __init__(self, device:CLDevice, compile_key:str):
         
     | 
| 
      
 18 
     | 
    
         
            +
                self.device = device
         
     | 
| 
      
 19 
     | 
    
         
            +
                super().__init__(f"compile_cl_{compile_key}")
         
     | 
| 
      
 20 
     | 
    
         
            +
              def compile(self, src:str) -> bytes:
         
     | 
| 
      
 21 
     | 
    
         
            +
                program = checked(cl.clCreateProgramWithSource(self.device.context, 1, to_char_p_p([src.encode()]), None, status := ctypes.c_int32()), status)
         
     | 
| 
      
 22 
     | 
    
         
            +
                build_status: int = cl.clBuildProgram(program, 1, self.device.device_id, None, cl.clBuildProgram.argtypes[4](), None)
         
     | 
| 
      
 23 
     | 
    
         
            +
                if build_status != 0:
         
     | 
| 
      
 24 
     | 
    
         
            +
                  cl.clGetProgramBuildInfo(program, self.device.device_id, cl.CL_PROGRAM_BUILD_LOG, 0, None, log_size := ctypes.c_size_t())
         
     | 
| 
      
 25 
     | 
    
         
            +
                  cl.clGetProgramBuildInfo(program, self.device.device_id, cl.CL_PROGRAM_BUILD_LOG, log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None)  # noqa: E501
         
     | 
| 
      
 26 
     | 
    
         
            +
                  raise CompileError(f"OpenCL Compile Error\n\n{mstr.value.decode()}")
         
     | 
| 
      
 27 
     | 
    
         
            +
                check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARY_SIZES, ctypes.sizeof(ctypes.c_size_t), binary_sizes := (ctypes.c_size_t * 1)(), None))
         
     | 
| 
      
 28 
     | 
    
         
            +
                check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARIES, ctypes.sizeof(ctypes.c_void_p), (ctypes.c_void_p * 1)(ctypes.addressof(binary := ctypes.create_string_buffer(binary_sizes[0]))), None))  # noqa: E501
         
     | 
| 
      
 29 
     | 
    
         
            +
                check(cl.clReleaseProgram(program))
         
     | 
| 
      
 30 
     | 
    
         
            +
                return bytes(binary)
         
     | 
| 
       62 
31 
     | 
    
         | 
| 
       63 
32 
     | 
    
         
             
            class CLProgram:
         
     | 
| 
       64 
     | 
    
         
            -
              def __init__(self,  
     | 
| 
       65 
     | 
    
         
            -
                self. 
     | 
| 
       66 
     | 
    
         
            -
                 
     | 
| 
       67 
     | 
    
         
            -
             
     | 
| 
       68 
     | 
    
         
            -
             
     | 
| 
       69 
     | 
    
         
            -
             
     | 
| 
       70 
     | 
    
         
            -
             
     | 
| 
       71 
     | 
    
         
            -
                self. 
     | 
| 
       72 
     | 
    
         
            -
                if DEBUG >= 5 and not OSX:
         
     | 
| 
       73 
     | 
    
         
            -
                  if 'Adreno' in CL.cl_ctxs[0].devices[0].name:
         
     | 
| 
       74 
     | 
    
         
            -
                    fromimport('disassemblers.adreno', 'disasm')(self.binary())
         
     | 
| 
       75 
     | 
    
         
            -
                  elif CL.cl_ctxs[0].devices[0].name.startswith('gfx'):
         
     | 
| 
       76 
     | 
    
         
            -
                    asm = early_exec(([ROCM_LLVM_PATH / "llvm-objdump", '-d', '-'], self.binary()))
         
     | 
| 
       77 
     | 
    
         
            -
                    print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))
         
     | 
| 
       78 
     | 
    
         
            -
                  else:
         
     | 
| 
       79 
     | 
    
         
            -
                    # print the PTX for NVIDIA. TODO: probably broken for everything else
         
     | 
| 
       80 
     | 
    
         
            -
                    print(self.binary().decode('utf-8'))
         
     | 
| 
       81 
     | 
    
         
            -
                if argdtypes is not None: self.set_argdtypes(argdtypes)
         
     | 
| 
       82 
     | 
    
         
            -
             
     | 
| 
       83 
     | 
    
         
            -
              def binary(self): return self.clprograms[0].get_info(cl.program_info.BINARIES)[0]
         
     | 
| 
       84 
     | 
    
         
            -
              def set_argdtypes(self, argdtypes): self.argdtypes, _ = argdtypes, [clprg.set_scalar_arg_dtypes(argdtypes) for clprg in self.clprgs]
         
     | 
| 
      
 33 
     | 
    
         
            +
              def __init__(self, device:CLDevice, name:str, lib:bytes):
         
     | 
| 
      
 34 
     | 
    
         
            +
                self.device, self.name, self.lib = device, name, lib
         
     | 
| 
      
 35 
     | 
    
         
            +
                self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, device.device_id, (ctypes.c_size_t * 1)(len(lib)),
         
     | 
| 
      
 36 
     | 
    
         
            +
                                                                    to_char_p_p([lib], ctypes.c_ubyte), binary_status := ctypes.c_int32(),
         
     | 
| 
      
 37 
     | 
    
         
            +
                                                                    errcode_ret := ctypes.c_int32()), errcode_ret)
         
     | 
| 
      
 38 
     | 
    
         
            +
                check(binary_status.value)
         
     | 
| 
      
 39 
     | 
    
         
            +
                check(cl.clBuildProgram(self.program, 1, device.device_id, None, cl.clBuildProgram.argtypes[4](), None)) # NOTE: OSX requires this
         
     | 
| 
      
 40 
     | 
    
         
            +
                self.kernel = checked(cl.clCreateKernel(self.program, name.encode(), status := ctypes.c_int32()), status)
         
     | 
| 
       85 
41 
     | 
    
         | 
| 
       86 
     | 
    
         
            -
               
     | 
| 
       87 
     | 
    
         
            -
             
     | 
| 
      
 42 
     | 
    
         
            +
              def __del__(self):
         
     | 
| 
      
 43 
     | 
    
         
            +
                if hasattr(self, 'kernel'): check(cl.clReleaseKernel(self.kernel))
         
     | 
| 
      
 44 
     | 
    
         
            +
                if hasattr(self, 'program'): check(cl.clReleaseProgram(self.program))
         
     | 
| 
       88 
45 
     | 
    
         | 
| 
       89 
     | 
    
         
            -
              def __call__(self, global_size, local_size,  
     | 
| 
       90 
     | 
    
         
            -
                 
     | 
| 
       91 
     | 
    
         
            -
                 
     | 
| 
       92 
     | 
    
         
            -
                 
     | 
| 
      
 46 
     | 
    
         
            +
              def __call__(self, *bufs:ctypes._CData, global_size:Tuple[int,int,int]=(1,1,1), local_size:Optional[Tuple[int,int,int]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]:  # noqa: E501
         
     | 
| 
      
 47 
     | 
    
         
            +
                for i,b in enumerate(bufs): cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))
         
     | 
| 
      
 48 
     | 
    
         
            +
                for i,v in enumerate(vals,start=len(bufs)): cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v)))
         
     | 
| 
      
 49 
     | 
    
         
            +
                if local_size is not None: global_size = cast(Tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
         
     | 
| 
      
 50 
     | 
    
         
            +
                event = cl.cl_event() if wait else None
         
     | 
| 
      
 51 
     | 
    
         
            +
                check(cl.clEnqueueNDRangeKernel(self.device.queue, self.kernel, len(global_size), None, (ctypes.c_size_t * len(global_size))(*global_size), (ctypes.c_size_t * len(local_size))(*local_size) if local_size else None, 0, None, event))  # noqa: E501
         
     | 
| 
       93 
52 
     | 
    
         
             
                if wait:
         
     | 
| 
       94 
     | 
    
         
            -
                   
     | 
| 
       95 
     | 
    
         
            -
                   
     | 
| 
       96 
     | 
    
         
            -
             
     | 
| 
       97 
     | 
    
         
            -
                   
     | 
| 
       98 
     | 
    
         
            -
             
     | 
| 
      
 53 
     | 
    
         
            +
                  assert event is not None
         
     | 
| 
      
 54 
     | 
    
         
            +
                  check(cl.clWaitForEvents(1, event))
         
     | 
| 
      
 55 
     | 
    
         
            +
                  check(cl.clGetEventProfilingInfo(event, cl.CL_PROFILING_COMMAND_START, 8, ctypes.byref(start := ctypes.c_uint64()), None))
         
     | 
| 
      
 56 
     | 
    
         
            +
                  check(cl.clGetEventProfilingInfo(event, cl.CL_PROFILING_COMMAND_END, 8, ctypes.byref(end := ctypes.c_uint64()), None))
         
     | 
| 
      
 57 
     | 
    
         
            +
                  return float(end.value-start.value) * OSX_TIMING_RATIO * 1e-9
         
     | 
| 
       99 
58 
     | 
    
         
             
                return None
         
     | 
| 
       100 
59 
     | 
    
         | 
| 
       101 
     | 
    
         
            -
             
     | 
| 
       102 
     | 
    
         
            -
               
     | 
| 
       103 
     | 
    
         
            -
             
     | 
| 
       104 
     | 
    
         
            -
             
     | 
| 
       105 
     | 
    
         
            -
               
     | 
| 
       106 
     | 
    
         
            -
             
     | 
| 
      
 60 
     | 
    
         
            +
            class CLAllocator(LRUAllocator):
         
     | 
| 
      
 61 
     | 
    
         
            +
              def __init__(self, device:CLDevice):
         
     | 
| 
      
 62 
     | 
    
         
            +
                self.device = device
         
     | 
| 
      
 63 
     | 
    
         
            +
                super().__init__()
         
     | 
| 
      
 64 
     | 
    
         
            +
              def _alloc(self, size:int, options:BufferOptions) -> ctypes._CData:
         
     | 
| 
      
 65 
     | 
    
         
            +
                if options.image is not None:
         
     | 
| 
      
 66 
     | 
    
         
            +
                  return checked(cl.clCreateImage2D(self.device.context, cl.CL_MEM_READ_WRITE,
         
     | 
| 
      
 67 
     | 
    
         
            +
                                                    cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[options.image.itemsize]),
         
     | 
| 
      
 68 
     | 
    
         
            +
                                                    options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status)
         
     | 
| 
      
 69 
     | 
    
         
            +
                else: return checked(cl.clCreateBuffer(self.device.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status)
         
     | 
| 
      
 70 
     | 
    
         
            +
              def _free(self, buf:ctypes._CData, options:BufferOptions): check(cl.clReleaseMemObject(buf))
         
     | 
| 
      
 71 
     | 
    
         
            +
              def copyin(self, dest:ctypes._CData, src:memoryview):
         
     | 
| 
      
 72 
     | 
    
         
            +
                check(cl.clEnqueueWriteBuffer(self.device.queue, dest, False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
         
     | 
| 
      
 73 
     | 
    
         
            +
                self.device.pending_copyin.append(src)    # NOTE: these can't be freed until the GPU actually executes this command
         
     | 
| 
      
 74 
     | 
    
         
            +
              def copyout(self, dest:memoryview, src:ctypes._CData):
         
     | 
| 
      
 75 
     | 
    
         
            +
                check(cl.clEnqueueReadBuffer(self.device.queue, src, False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None))
         
     | 
| 
      
 76 
     | 
    
         
            +
                self.device.synchronize()
         
     | 
| 
      
 77 
     | 
    
         
            +
             
     | 
| 
      
 78 
     | 
    
         
            +
            class CLDevice(Compiled):
         
     | 
| 
      
 79 
     | 
    
         
            +
              device_ids = None                 # this is global and only initted once
         
     | 
| 
      
 80 
     | 
    
         
            +
              def __init__(self, device:str=""):
         
     | 
| 
      
 81 
     | 
    
         
            +
                if CLDevice.device_ids is None:
         
     | 
| 
      
 82 
     | 
    
         
            +
                  check(cl.clGetPlatformIDs(0, None, num_platforms := ctypes.c_uint32()))
         
     | 
| 
      
 83 
     | 
    
         
            +
                  check(cl.clGetPlatformIDs(num_platforms.value, platform_ids := (cl.cl_platform_id * num_platforms.value)(), None))
         
     | 
| 
      
 84 
     | 
    
         
            +
                  for device_type in [cl.CL_DEVICE_TYPE_GPU, cl.CL_DEVICE_TYPE_DEFAULT]:
         
     | 
| 
      
 85 
     | 
    
         
            +
                    err = cl.clGetDeviceIDs(platform_ids[0], device_type, 0, None, num_devices := ctypes.c_uint32())
         
     | 
| 
      
 86 
     | 
    
         
            +
                    if err == 0 and num_devices.value != 0: break
         
     | 
| 
      
 87 
     | 
    
         
            +
                  if DEBUG >= 1: print(f"CLDevice: got {num_platforms.value} platforms and {num_devices.value} devices")
         
     | 
| 
      
 88 
     | 
    
         
            +
                  CLDevice.device_ids = init_c_var((cl.cl_device_id * num_devices.value)(), lambda x: check(cl.clGetDeviceIDs(platform_ids[0], device_type, num_devices, x, None)))  # noqa: E501
         
     | 
| 
      
 89 
     | 
    
         
            +
             
     | 
| 
      
 90 
     | 
    
         
            +
                self.device_id = CLDevice.device_ids[0 if ":" not in device else int(device.split(":")[1])]
         
     | 
| 
      
 91 
     | 
    
         
            +
                self.device_name = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_NAME, 256, buf := ctypes.create_string_buffer(256), None), buf.value.decode())[1]  # noqa: E501
         
     | 
| 
      
 92 
     | 
    
         
            +
                self.driver_version = (cl.clGetDeviceInfo(self.device_id, cl.CL_DRIVER_VERSION, 256, buf := ctypes.create_string_buffer(256), None), buf.value.decode())[1]  # noqa: E501
         
     | 
| 
      
 93 
     | 
    
         
            +
                self.context = checked(cl.clCreateContext(None, 1, self.device_id, cl.clCreateContext.argtypes[3](), None, status := ctypes.c_int32()), status)
         
     | 
| 
      
 94 
     | 
    
         
            +
                self.queue = checked(cl.clCreateCommandQueue(self.context, self.device_id, cl.CL_QUEUE_PROFILING_ENABLE, status), status)
         
     | 
| 
      
 95 
     | 
    
         
            +
                self.pending_copyin: List[memoryview] = []
         
     | 
| 
      
 96 
     | 
    
         
            +
             
     | 
| 
      
 97 
     | 
    
         
            +
                compile_key = hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()
         
     | 
| 
      
 98 
     | 
    
         
            +
                super().__init__(device, CLAllocator(self), OpenCLRenderer(), CLCompiler(self, f"compile_cl_{compile_key}"), functools.partial(CLProgram, self))
         
     | 
| 
      
 99 
     | 
    
         
            +
              def synchronize(self):
         
     | 
| 
      
 100 
     | 
    
         
            +
                check(cl.clFinish(self.queue))
         
     | 
| 
      
 101 
     | 
    
         
            +
                self.pending_copyin.clear()
         
     | 
| 
      
 102 
     | 
    
         
            +
             
     | 
| 
      
 103 
     | 
    
         
            +
            GPUDevice = CLDevice # for legacy reasons
         
     | 
| 
         @@ -0,0 +1,278 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from __future__ import annotations
         
     | 
| 
      
 2 
     | 
    
         
            +
            import ctypes, functools, subprocess, io, atexit, collections, json
         
     | 
| 
      
 3 
     | 
    
         
            +
            from typing import Tuple, TypeVar, List, Dict, Any
         
     | 
| 
      
 4 
     | 
    
         
            +
            import tinygrad.runtime.autogen.hsa as hsa
         
     | 
| 
      
 5 
     | 
    
         
            +
            from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv
         
     | 
| 
      
 6 
     | 
    
         
            +
            from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator
         
     | 
| 
      
 7 
     | 
    
         
            +
            from tinygrad.renderer.cstyle import HIPRenderer
         
     | 
| 
      
 8 
     | 
    
         
            +
            from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool, AQLQueue
         
     | 
| 
      
 9 
     | 
    
         
            +
            from tinygrad.runtime.driver.hip_comgr import compile_hip
         
     | 
| 
      
 10 
     | 
    
         
            +
            if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl  # noqa: F401
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
            PROFILE = getenv("PROFILE", 0)
         
     | 
| 
      
 13 
     | 
    
         
            +
             
     | 
| 
      
 14 
     | 
    
         
            +
            class HSAProfiler:
              """Collects timestamps from HSA completion signals and saves them as a Chrome trace-event JSON file."""
              def __init__(self):
                # device -> list of (signal, name, is_copy) whose timestamps have not been read yet
                self.tracked_signals = collections.defaultdict(list)
                # (device_id, queue_id, name, start, end) tuples already read back from signals
                self.collected_events: List[Tuple[Any, ...]] = []
                # scratch structs reused for every timestamp query; values are copied out immediately
                self.copy_timings = hsa.hsa_amd_profiling_async_copy_time_t()
                self.disp_timings = hsa.hsa_amd_profiling_dispatch_time_t()

              def track(self, signal, device, name, is_copy=False): self.tracked_signals[device].append((signal, name, is_copy))

              def process(self, device):
                """Read timestamps for every signal tracked on `device` and move them into collected_events."""
                # Process all tracked signals, should be called before any of tracked signals are reused.
                for sig,name,is_copy in self.tracked_signals[device]:
                  if is_copy: check(hsa.hsa_amd_profiling_get_async_copy_time(sig, ctypes.byref(timings := self.copy_timings)))
                  else: check(hsa.hsa_amd_profiling_get_dispatch_time(device.agent, sig, ctypes.byref(timings := self.disp_timings))) #type:ignore
                  # queue_id (trace tid): 1 = copy engine (SDMA), 0 = kernel queue (AQL)
                  self.collected_events.append((device.device_id, 1 if is_copy else 0, name, timings.start, timings.end))
                self.tracked_signals.pop(device)

              def save(self, path):
                """Write all collected events to `path` in Chrome trace-event format ("traceEvents" list)."""
                mjson = []
                # Name processes/threads by device_id: that is the pid every event below is emitted with,
                # so metadata and events line up even when only some devices (e.g. just HSA:1) were opened.
                for dev in HSADevice.devices:
                  mjson.append({"name": "process_name", "ph": "M", "pid": dev.device_id, "args": {"name": "HSA"}})
                  mjson.append({"name": "thread_name", "ph": "M", "pid": dev.device_id, "tid": 0, "args": {"name": "AQL"}})
                  mjson.append({"name": "thread_name", "ph": "M", "pid": dev.device_id, "tid": 1, "args": {"name": "SDMA"}})

                # timestamps are scaled by 1e-3 for the trace "ts" field (assumes HSA timestamps are ns -> us; TODO confirm)
                for dev_id,queue_id,name,st,et in self.collected_events:
                  mjson.append({"name": name, "ph": "B", "pid": dev_id, "tid": queue_id, "ts": st*1e-3})
                  mjson.append({"name": name, "ph": "E", "pid": dev_id, "tid": queue_id, "ts": et*1e-3})
                with open(path, "w") as f: f.write(json.dumps({"traceEvents": mjson}))
                print(f"Saved HSA profile to {path}")
         
     | 
| 
      
 42 
     | 
    
         
            +
            # Module-wide profiler instance, used by HSAProgram, HSAAllocator and HSADevice.synchronize.
            Profiler = HSAProfiler()
         
     | 
| 
      
 43 
     | 
    
         
            +
             
     | 
| 
      
 44 
     | 
    
         
            +
            class HSACompiler(Compiler):
              """Compiles HIP source to a code object for one AMD GPU architecture via compile_hip."""
              def __init__(self, arch:str):
                self.arch = arch
                # arch-specific key passed to Compiler (presumably its cache key, so different GPUs don't share binaries)
                super().__init__(f"compile_hip_{self.arch}")
              def compile(self, src:str) -> bytes:
                """Compile `src` for self.arch; a RuntimeError from the backend is re-raised as CompileError."""
                try: return compile_hip(src, self.arch)
                except RuntimeError as e: raise CompileError(e) from e
         
     | 
| 
      
 51 
     | 
    
         
            +
             
     | 
| 
      
 52 
     | 
    
         
            +
            class HSAProgram:
              """A compiled kernel loaded into an HSA executable on one device; calling it dispatches the kernel."""
              def __init__(self, device:HSADevice, name:str, lib:bytes):
                self.device, self.name, self.lib = device, name, lib

                if DEBUG >= 6:
                  # dump the code object's disassembly (needs llvm-objdump at this fixed ROCm path)
                  asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib)
                  print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))

                # create an executable, load the code object into it for this device's agent, then freeze it
                self.exec = init_c_var(hsa.hsa_executable_t(), lambda x: check(hsa.hsa_executable_create_alt(hsa.HSA_PROFILE_FULL, hsa.HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT, None, ctypes.byref(x)))) # noqa: E501
                self.code_reader = init_c_var(hsa.hsa_code_object_reader_t(),
                                              lambda x: check(hsa.hsa_code_object_reader_create_from_memory(lib, len(lib), ctypes.byref(x))))
                check(hsa.hsa_executable_load_agent_code_object(self.exec, self.device.agent, self.code_reader, None, None))
                check(hsa.hsa_executable_freeze(self.exec, None))

                # look up the kernel symbol ("<name>.kd") and query its kernel object handle and segment sizes
                self.kernel = init_c_var(hsa.hsa_executable_symbol_t(), lambda x: check(hsa.hsa_executable_get_symbol_by_name(self.exec, (name+".kd").encode("utf-8"), ctypes.byref(self.device.agent), ctypes.byref(x)))) # noqa: E501
                self.handle = init_c_var(ctypes.c_uint64(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT, ctypes.byref(x)))) # noqa: E501
                self.kernargs_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501
                self.group_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501
                self.private_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501

              def __del__(self):
                # wait for queued work on the device before destroying the executable
                self.device.synchronize()
                # hasattr guards: __init__ may have raised before these attributes were set
                if hasattr(self, 'code_reader'): check(hsa.hsa_code_object_reader_destroy(self.code_reader))
                if hasattr(self, 'exec'): check(hsa.hsa_executable_destroy(self.exec))

              def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
                """Dispatch the kernel on the device's hardware queue.

                `args` are packed into the kernarg buffer as c_void_p fields, followed by `vals` as c_int fields.
                Returns the elapsed kernel time (timestamp ticks * device.clocks_to_time) when wait=True, else None.
                Raises RuntimeError if the packed struct size differs from the kernel's kernarg segment size.
                """
                if not hasattr(self, "args_struct_t"):
                  # the struct layout is built once, from the number of args/vals in the first call
                  self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
                                                             [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
                  if ctypes.sizeof(self.args_struct_t) != self.kernargs_segment_size:
                    raise RuntimeError(f"HSAProgram.__call__: incorrect args struct size {ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}")

                kernargs = None
                if self.kernargs_segment_size > 0:
                  # write the arguments in place into freshly allocated kernarg memory
                  kernargs = self.device.alloc_kernargs(self.kernargs_segment_size)
                  args_st = self.args_struct_t.from_address(kernargs)
                  for i in range(len(args)): args_st.__setattr__(f'f{i}', args[i])
                  for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i])
                  self.device.flush_hdp()

                # a completion signal is only needed to wait on the kernel or to profile it
                signal = self.device.alloc_signal(reusable=True) if wait or PROFILE else None
                self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, completion_signal=signal)
                if PROFILE: Profiler.track(signal, self.device, self.name)
                if wait:
                  # block until the signal drops below 1 (kernel done), then convert the dispatch timestamps to time
                  hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
                  check(hsa.hsa_amd_profiling_get_dispatch_time(self.device.agent, signal, ctypes.byref(timings := hsa.hsa_amd_profiling_dispatch_time_t())))
                  return (timings.end - timings.start) * self.device.clocks_to_time
         
     | 
| 
      
 99 
     | 
    
         
            +
             
     | 
| 
      
 100 
     | 
    
         
            +
            # generic type variable used in HSAAllocator annotations for opaque buffer handles
            T = TypeVar("T")
            # CHUNK_SIZE: size of each pinned host staging buffer used by HSAAllocator.copy_from_fd (256 MiB)
            # PAGE_SIZE: alignment applied to file offsets in copy_from_fd (4 KiB)
            CHUNK_SIZE, PAGE_SIZE = 256*1024*1024, 0x1000
         
     | 
| 
      
 102 
     | 
    
         
            +
            class HSAAllocator(LRUAllocator):
              """LRU-cached allocator for HSA device memory and pinned host memory, plus the copies between them."""
              def __init__(self, device:HSADevice):
                self.device = device
                super().__init__()

              def _alloc(self, size:int, options:BufferOptions):
                """Allocate `size` bytes and return the raw address as an int.

                Host buffers come from the CPU memory pool and are made accessible to the CPU and this GPU.
                Device buffers come from this GPU's pool and are made accessible to every GPU agent.
                """
                if options.host:
                  check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, size, 0, ctypes.byref(mem := ctypes.c_void_p())))
                  check(hsa.hsa_amd_agents_allow_access(2, (hsa.hsa_agent_t*2)(HSADevice.cpu_agent, self.device.agent), None, mem))
                  return mem.value
                else:
                  c_agents = (hsa.hsa_agent_t * len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]))(*HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU])
                  check(hsa.hsa_amd_memory_pool_allocate(self.device.gpu_mempool, size, 0, ctypes.byref(buf := ctypes.c_void_p())))
                  check(hsa.hsa_amd_agents_allow_access(len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]), c_agents, None, buf))
                  return buf.value

              def _free(self, opaque:T, options:BufferOptions):
                # synchronize every device first (presumably because any device's queue may still reference the buffer)
                HSADevice.synchronize_system()
                check(hsa.hsa_amd_memory_pool_free(opaque))

              def copyin(self, dest:T, src: memoryview):
                """Asynchronously copy host data `src` into device buffer `dest`, ordered with the device's queue."""
                # Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets.
                self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
                # stage src in a pinned host buffer that the SDMA engine copies from
                mem = self._alloc(src.nbytes, BufferOptions(host=True))
                ctypes.memmove(mem, from_mv(src), src.nbytes)
                # the copy depends on sync_signal (the barrier above); the queue then waits on copy_signal
                check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes, 1, ctypes.byref(sync_signal),
                                                              copy_signal := self.device.alloc_signal(reusable=True), hsa.HSA_AMD_SDMA_ENGINE_0, True))
                self.device.hw_queue.submit_barrier([copy_signal])
                # the staging buffer is released by the device's next synchronize()
                self.device.delayed_free.append(mem)
                if PROFILE: Profiler.track(copy_signal, self.device, f"copyin: CPU -> HSA:{self.device.device_id}", is_copy=True)

              def copy_from_fd(self, dest, fd, offset, size):
                """Read `size` bytes at byte `offset` of file descriptor `fd` into device buffer `dest`.

                The data is streamed through two pinned CHUNK_SIZE host buffers used alternately (each with its own
                SDMA engine), so reading the next chunk from the file can overlap with copying the previous one.
                """
                self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))

                if not hasattr(self, 'hb'):
                  # lazily create the two staging buffers; their signals start at 0 so the first waits pass immediately
                  self.hb = [self._alloc(CHUNK_SIZE, BufferOptions(host=True)) for _ in range(2)]
                  self.hb_signals = [self.device.alloc_signal(reusable=False) for _ in range(2)]
                  self.hb_polarity = 0
                  self.sdma = [hsa.HSA_AMD_SDMA_ENGINE_0, hsa.HSA_AMD_SDMA_ENGINE_1]
                  for sig in self.hb_signals: hsa.hsa_signal_store_relaxed(sig, 0)

                # closefd=False: closing this wrapper (at the end of the with-block) leaves the caller's fd open
                with io.FileIO(fd, "a+b", closefd=False) as fo:
                  # start reading at the page boundary below offset; minor_offset is where the data begins in chunk 0
                  fo.seek(offset - (minor_offset:=offset % PAGE_SIZE))

                  copies_called = 0
                  copied_in = 0
                  for local_offset in range(0, size+minor_offset, CHUNK_SIZE):
                    local_size = min(round_up(size+minor_offset, PAGE_SIZE)-local_offset, CHUNK_SIZE)
                    copy_size = min(local_size-minor_offset, size-copied_in)
                    if copy_size == 0: break

                    # wait until the previous copy out of this staging buffer is done before overwriting it
                    hsa.hsa_signal_wait_scacquire(self.hb_signals[self.hb_polarity], hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
                    self.device.reusable_signals.append(self.hb_signals[self.hb_polarity]) # it's free now and can be reused
                    self.hb_signals[self.hb_polarity] = self.device.alloc_signal(reusable=False)

                    fo.readinto(to_mv(self.hb[self.hb_polarity], local_size))
                    check(hsa.hsa_amd_memory_async_copy_on_engine(dest+copied_in, self.device.agent, self.hb[self.hb_polarity]+minor_offset, HSADevice.cpu_agent,
                                                                  copy_size, 1, ctypes.byref(sync_signal), self.hb_signals[self.hb_polarity],
                                                                  self.sdma[self.hb_polarity], True))
                    copied_in += copy_size
                    self.hb_polarity = (self.hb_polarity + 1) % len(self.hb)
                    minor_offset = 0 # only on the first
                    copies_called += 1

                # the queue waits for the last chunk copy, and for the one before it if there were two or more
                wait_signals = [self.hb_signals[self.hb_polarity - 1]]
                if copies_called > 1: wait_signals.append(self.hb_signals[self.hb_polarity])
                self.device.hw_queue.submit_barrier(wait_signals)

              def copyout(self, dest:memoryview, src:T):
                """Copy device buffer `src` into host memory `dest`; blocking, and synchronizes all devices first."""
                HSADevice.synchronize_system()
                copy_signal = self.device.alloc_signal(reusable=True)
                c_agents = (hsa.hsa_agent_t*2)(self.device.agent, HSADevice.cpu_agent)
                # pin dest so the copy can target it; addr receives the address to use as the copy destination
                check(hsa.hsa_amd_memory_lock_to_pool(from_mv(dest), dest.nbytes, c_agents, 2, HSADevice.cpu_mempool, 0, ctypes.byref(addr:=ctypes.c_void_p())))
                try:
                  check(hsa.hsa_amd_memory_async_copy(addr, HSADevice.cpu_agent, src, self.device.agent, dest.nbytes, 0, None, copy_signal))
                  hsa.hsa_signal_wait_scacquire(copy_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
                finally:
                  # always unpin dest, even if starting the copy failed
                  check(hsa.hsa_amd_memory_unlock(from_mv(dest)))
                if PROFILE: Profiler.track(copy_signal, self.device, f"copyout: HSA:{self.device.device_id} -> CPU", is_copy=True)

              def transfer(self, dest:T, src:T, sz:int, src_dev=None, dest_dev=None):
                """Copy `sz` bytes from `src` on src_dev to `dest` on dest_dev, ordered against both devices' queues."""
                # the copy waits for the barriers on both queues, i.e. for all work already queued on either device
                src_dev.hw_queue.submit_barrier([], sync_signal_1 := src_dev.alloc_signal(reusable=True))
                dest_dev.hw_queue.submit_barrier([], sync_signal_2 := dest_dev.alloc_signal(reusable=True))
                c_wait_signal = (hsa.hsa_signal_t*2)(sync_signal_1, sync_signal_2)
                # copy_signal is waited on by two queues and is not reusable (presumably so one device's
                # synchronize() cannot reset it while the other queue still waits on it; confirm)
                check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal,
                                                              copy_signal := dest_dev.alloc_signal(reusable=False), hsa.HSA_AMD_SDMA_ENGINE_0, True))
                src_dev.hw_queue.submit_barrier([copy_signal])
                dest_dev.hw_queue.submit_barrier([copy_signal])
                if PROFILE: Profiler.track(copy_signal, src_dev, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", is_copy=True)
         
     | 
| 
      
 189 
     | 
    
         
            +
             
     | 
| 
      
 190 
     | 
    
         
            +
            class HSADevice(Compiled):
         
     | 
| 
      
 191 
     | 
    
         
            +
              devices: List[HSADevice] = []
         
     | 
| 
      
 192 
     | 
    
         
            +
              agents: Dict[int, List[hsa.hsa_agent_t]] = {}
         
     | 
| 
      
 193 
     | 
    
         
            +
              cpu_agent: hsa.hsa_agent_t
         
     | 
| 
      
 194 
     | 
    
         
            +
              cpu_mempool: hsa.hsa_amd_memory_pool_t
         
     | 
| 
      
 195 
     | 
    
         
            +
              def __init__(self, device:str=""):
                """Open an HSA GPU device.

                `device` is a string like "<name>:<index>"; the index selects the GPU agent (0 if there is no ':').
                """
                # one-time, process-wide HSA runtime setup on the first device construction
                if not HSADevice.agents:
                  check(hsa.hsa_init())
                  atexit.register(hsa_terminate)
                  HSADevice.agents = scan_agents()
                  HSADevice.cpu_agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_CPU][0]
                  HSADevice.cpu_mempool = find_memory_pool(HSADevice.cpu_agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_CPU)
                  if PROFILE: check(hsa.hsa_amd_profiling_async_copy_enable(1))

                self.device_id = int(device.split(":")[1]) if ":" in device else 0
                self.agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU][self.device_id]
                self.gpu_mempool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_GPU)
                self.hw_queue = AQLQueue(self)
                HSADevice.devices.append(self)

                # the agent's name is used as the compile target architecture (passed to HSACompiler below)
                check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AGENT_INFO_NAME, ctypes.byref(agent_name_buf := ctypes.create_string_buffer(256))))
                self.arch = ctypes.string_at(agent_name_buf).decode()

                # seconds per timestamp tick: 1 / system timestamp frequency
                check(hsa.hsa_system_get_info(hsa.HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY, ctypes.byref(gpu_freq := ctypes.c_uint64())))
                self.clocks_to_time: float = 1 / gpu_freq.value

                # HDP flush info for this agent (presumably consumed by flush_hdp(), which is not visible here)
                check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AMD_AGENT_INFO_HDP_FLUSH, ctypes.byref(hdp_flush := hsa.hsa_amd_hdp_flush_t())))
                self.hdp_flush = hdp_flush

                # host buffers to free and reusable signals to recycle on the next synchronize()
                self.delayed_free: List[int] = []
                self.reusable_signals: List[hsa.hsa_signal_t] = []

                # function-level import (presumably to avoid a circular import with the graph module; confirm)
                from tinygrad.runtime.graph.hsa import HSAGraph
                super().__init__(device, HSAAllocator(self), HIPRenderer(), HSACompiler(self.arch), functools.partial(HSAProgram, self), HSAGraph)

                # Finish init: preallocate some signals + space for kernargs
                # 4096 pooled signals, each created with initial value 1
                self.signal_pool = [init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_signal_create(1, 0, None, ctypes.byref(x)))) for _ in range(4096)]
                self._new_kernargs_region(16 << 20) # initial region size is 16mb
         
     | 
| 
      
 228 
     | 
    
         
            +
             
     | 
| 
      
 229 
     | 
    
         
            +
              def synchronize(self):
                """Block until this device's hardware queue is idle, then recycle resources held by earlier submissions."""
                self.hw_queue.wait()

                # reset reusable signals to 1 and put them back in the pool
                for sig in self.reusable_signals: hsa.hsa_signal_silent_store_relaxed(sig, 1)
                self.signal_pool.extend(self.reusable_signals)
                self.reusable_signals.clear()

                # free host staging buffers queued by copyin
                for opaque_to_free in self.delayed_free: check(hsa.hsa_amd_memory_pool_free(opaque_to_free))
                self.delayed_free.clear()

                # kernarg allocation restarts from the beginning of the region
                self.kernarg_next_addr = self.kernarg_start_addr
                # read back timestamps for signals tracked on this device by the profiler
                Profiler.process(self)
         
     | 
| 
      
 241 
     | 
    
         
            +
             
     | 
| 
      
 242 
     | 
    
         
            +
              @staticmethod
         
     | 
| 
      
 243 
     | 
    
         
            +
              def synchronize_system():
         
     | 
| 
      
 244 
     | 
    
         
            +
                for d in HSADevice.devices: d.synchronize()
         
     | 
| 
      
 245 
     | 
    
         
            +
             
     | 
| 
      
 246 
     | 
    
         
            +
              def alloc_signal(self, reusable=False):
         
     | 
| 
      
 247 
     | 
    
         
            +
                if len(self.signal_pool): signal = self.signal_pool.pop()
         
     | 
| 
      
 248 
     | 
    
         
            +
                else: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(signal := hsa.hsa_signal_t())))
         
     | 
| 
      
 249 
     | 
    
         
            +
             
     | 
| 
      
 250 
     | 
    
         
            +
                # reusable means a signal could be reused after synchronize for the device it's allocated from is called.
         
     | 
| 
      
 251 
     | 
    
         
            +
                if reusable: self.reusable_signals.append(signal)
         
     | 
| 
      
 252 
     | 
    
         
            +
                return signal
         
     | 
| 
      
 253 
     | 
    
         
            +
             
     | 
| 
      
 254 
     | 
    
         
            +
              def alloc_kernargs(self, sz):
         
     | 
| 
      
 255 
     | 
    
         
            +
                if self.kernarg_next_addr + sz >= self.kernarg_start_addr + self.kernarg_pool_sz: self._new_kernargs_region(int(self.kernarg_pool_sz * 2))
         
     | 
| 
      
 256 
     | 
    
         
            +
                result = self.kernarg_next_addr
         
     | 
| 
      
 257 
     | 
    
         
            +
                self.kernarg_next_addr = round_up(self.kernarg_next_addr + sz, 16)
         
     | 
| 
      
 258 
     | 
    
         
            +
                return result
         
     | 
| 
      
 259 
     | 
    
         
            +
             
     | 
| 
      
 260 
     | 
    
         
            +
              def _new_kernargs_region(self, sz:int):
         
     | 
| 
      
 261 
     | 
    
         
            +
                if hasattr(self, 'kernarg_start_addr'): self.delayed_free.append(self.kernarg_start_addr)
         
     | 
| 
      
 262 
     | 
    
         
            +
                self.kernarg_start_addr: int = self.allocator._alloc(sz, BufferOptions())
         
     | 
| 
      
 263 
     | 
    
         
            +
                self.kernarg_next_addr = self.kernarg_start_addr
         
     | 
| 
      
 264 
     | 
    
         
            +
                self.kernarg_pool_sz: int = sz
         
     | 
| 
      
 265 
     | 
    
         
            +
             
     | 
| 
      
 266 
     | 
    
         
            +
              def flush_hdp(self): self.hdp_flush.HDP_MEM_FLUSH_CNTL[0] = 1
         
     | 
| 
      
 267 
     | 
    
         
            +
             
     | 
| 
      
 268 
     | 
    
         
            +
            def hsa_terminate():
         
     | 
| 
      
 269 
     | 
    
         
            +
              # Need to stop/delete aql queue before hsa shut down, this leads to gpu hangs.
         
     | 
| 
      
 270 
     | 
    
         
            +
              for dev in HSADevice.devices:
         
     | 
| 
      
 271 
     | 
    
         
            +
                Profiler.process(dev)
         
     | 
| 
      
 272 
     | 
    
         
            +
                del dev.hw_queue
         
     | 
| 
      
 273 
     | 
    
         
            +
             
     | 
| 
      
 274 
     | 
    
         
            +
              # hsa_shut_down cleans up all hsa-related resources.
         
     | 
| 
      
 275 
     | 
    
         
            +
              hsa.hsa_shut_down()
         
     | 
| 
      
 276 
     | 
    
         
            +
              HSADevice.synchronize = lambda: None #type:ignore
         
     | 
| 
      
 277 
     | 
    
         
            +
              HSAProgram.__del__ = lambda _: None #type:ignore
         
     | 
| 
      
 278 
     | 
    
         
            +
              if Profiler.collected_events: Profiler.save("/tmp/profile.json")
         
     | 
    
        tinygrad/runtime/ops_llvm.py
    CHANGED
    
    | 
         @@ -1,67 +1,46 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
             
     | 
| 
       2 
     | 
    
         
            -
             
     | 
| 
       3 
     | 
    
         
            -
            from  
     | 
| 
       4 
     | 
    
         
            -
            from tinygrad. 
     | 
| 
       5 
     | 
    
         
            -
            from  
     | 
| 
       6 
     | 
    
         
            -
            from tinygrad. 
     | 
| 
       7 
     | 
    
         
            -
             
     | 
| 
       8 
     | 
    
         
            -
             
     | 
| 
      
 1 
     | 
    
         
            +
            from __future__ import annotations
         
     | 
| 
      
 2 
     | 
    
         
            +
            import ctypes, functools
         
     | 
| 
      
 3 
     | 
    
         
            +
            from typing import Tuple
         
     | 
| 
      
 4 
     | 
    
         
            +
            from tinygrad.device import Compiled, Compiler, MallocAllocator
         
     | 
| 
      
 5 
     | 
    
         
            +
            from tinygrad.helpers import DEBUG, cpu_time_execution, cpu_objdump
         
     | 
| 
      
 6 
     | 
    
         
            +
            from tinygrad.renderer.llvmir import LLVMRenderer
         
     | 
| 
      
 7 
     | 
    
         
            +
            import llvmlite.binding as llvm
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            class LLVMCompiler(Compiler):
         
     | 
| 
      
 10 
     | 
    
         
            +
              def __init__(self, device:LLVMDevice):
         
     | 
| 
      
 11 
     | 
    
         
            +
                self.device = device
         
     | 
| 
      
 12 
     | 
    
         
            +
                super().__init__("compile_llvm")
         
     | 
| 
      
 13 
     | 
    
         
            +
              def compile(self, src:str) -> bytes:
         
     | 
| 
      
 14 
     | 
    
         
            +
                mod = llvm.parse_assembly(src)
         
     | 
| 
      
 15 
     | 
    
         
            +
                mod.verify()
         
     | 
| 
      
 16 
     | 
    
         
            +
                self.device.optimizer.run(mod)
         
     | 
| 
      
 17 
     | 
    
         
            +
                if DEBUG >= 5: print(self.device.target_machine.emit_assembly(mod))
         
     | 
| 
      
 18 
     | 
    
         
            +
                return self.device.target_machine.emit_object(mod)
         
     | 
| 
       9 
19 
     | 
    
         | 
| 
       10 
     | 
    
         
            -
             
     | 
| 
       11 
     | 
    
         
            -
             
     | 
| 
       12 
     | 
    
         
            -
             
     | 
| 
       13 
     | 
    
         
            -
             
     | 
| 
       14 
     | 
    
         
            -
             
     | 
| 
       15 
     | 
    
         
            -
             
     | 
| 
       16 
     | 
    
         
            -
             
     | 
| 
       17 
     | 
    
         
            -
              def  
     | 
| 
       18 
     | 
    
         
            -
                if  
     | 
| 
      
 20 
     | 
    
         
            +
            class LLVMProgram:
         
     | 
| 
      
 21 
     | 
    
         
            +
              def __init__(self, device:LLVMDevice, name:str, lib:bytes):
         
     | 
| 
      
 22 
     | 
    
         
            +
                if DEBUG >= 6: cpu_objdump(lib)
         
     | 
| 
      
 23 
     | 
    
         
            +
                self.name, self.lib = name, lib
         
     | 
| 
      
 24 
     | 
    
         
            +
                device.engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib))
         
     | 
| 
      
 25 
     | 
    
         
            +
                self.fxn = device.engine.get_function_address(name)
         
     | 
| 
      
 26 
     | 
    
         
            +
             
     | 
| 
      
 27 
     | 
    
         
            +
              def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False):
         
     | 
| 
      
 28 
     | 
    
         
            +
                if not hasattr(self, 'cfunc'):
         
     | 
| 
      
 29 
     | 
    
         
            +
                  self.cfunc = ctypes.CFUNCTYPE(ctypes.c_int, *([ctypes.c_void_p]*len(bufs)), *([ctypes.c_int32]*len(vals)))(self.fxn)
         
     | 
| 
      
 30 
     | 
    
         
            +
                return cpu_time_execution(lambda: self.cfunc(*bufs, *vals), enable=wait)
         
     | 
| 
      
 31 
     | 
    
         
            +
             
     | 
| 
      
 32 
     | 
    
         
            +
            class LLVMDevice(Compiled):
         
     | 
| 
      
 33 
     | 
    
         
            +
              def __init__(self, device:str):
         
     | 
| 
       19 
34 
     | 
    
         
             
                llvm.initialize()
         
     | 
| 
       20 
35 
     | 
    
         
             
                llvm.initialize_native_target()
         
     | 
| 
       21 
36 
     | 
    
         
             
                llvm.initialize_native_asmprinter()
         
     | 
| 
       22 
37 
     | 
    
         
             
                llvm.initialize_native_asmparser()
         
     | 
| 
       23 
     | 
    
         
            -
                 
     | 
| 
       24 
     | 
    
         
            -
                 
     | 
| 
       25 
     | 
    
         
            -
                 
     | 
| 
       26 
     | 
    
         
            -
                 
     | 
| 
       27 
     | 
    
         
            -
             
     | 
| 
       28 
     | 
    
         
            -
                # TODO: this makes compile times so much faster
         
     | 
| 
       29 
     | 
    
         
            -
                if getenv("LLVMOPT"):
         
     | 
| 
       30 
     | 
    
         
            -
                  llvm.set_option(str(), '-force-vector-interleave=4')  # this makes sum the same speed as torch, it also doubles the (slow) conv speed
         
     | 
| 
       31 
     | 
    
         
            -
                  if DEBUG >= 4: llvm.set_option(str(), '--debug-only=loop-vectorize')
         
     | 
| 
       32 
     | 
    
         
            -
                  #llvm.set_option(str(), '--debug')
         
     | 
| 
       33 
     | 
    
         
            -
             
     | 
| 
       34 
     | 
    
         
            -
                  # does this do anything?
         
     | 
| 
       35 
     | 
    
         
            -
                  builder = llvm.create_pass_manager_builder()
         
     | 
| 
       36 
     | 
    
         
            -
                  builder.opt_level = 3
         
     | 
| 
       37 
     | 
    
         
            -
                  builder.size_level = 0
         
     | 
| 
       38 
     | 
    
         
            -
                  builder.loop_vectorize = True
         
     | 
| 
       39 
     | 
    
         
            -
                  builder.slp_vectorize = True
         
     | 
| 
       40 
     | 
    
         
            -
                  builder.populate(LLVM.optimizer)
         
     | 
| 
       41 
     | 
    
         
            -
             
     | 
| 
       42 
     | 
    
         
            -
                LLVM.target_machine.set_asm_verbosity(True)
         
     | 
| 
      
 38 
     | 
    
         
            +
                self.optimizer: llvm.passmanagers.ModulePassManager = llvm.create_module_pass_manager()
         
     | 
| 
      
 39 
     | 
    
         
            +
                # this opt actually can change things. ex: opt=3 means no FMA, opt=2 means FMA
         
     | 
| 
      
 40 
     | 
    
         
            +
                self.target_machine: llvm.targets.TargetMachine = llvm.Target.from_triple(llvm.get_process_triple()).create_target_machine(opt=2)
         
     | 
| 
      
 41 
     | 
    
         
            +
                self.target_machine.add_analysis_passes(self.optimizer)
         
     | 
| 
      
 42 
     | 
    
         
            +
                self.target_machine.set_asm_verbosity(True)
         
     | 
| 
       43 
43 
     | 
    
         
             
                backing_mod = llvm.parse_assembly(str())
         
     | 
| 
       44 
44 
     | 
    
         
             
                backing_mod.triple = llvm.get_process_triple()
         
     | 
| 
       45 
     | 
    
         
            -
                 
     | 
| 
       46 
     | 
    
         
            -
             
     | 
| 
       47 
     | 
    
         
            -
            class LLVMProgram:
         
     | 
| 
       48 
     | 
    
         
            -
              def __init__(self, name:str, prg:str, binary=False):
         
     | 
| 
       49 
     | 
    
         
            -
                self.mod = llvm.parse_assembly(prg)
         
     | 
| 
       50 
     | 
    
         
            -
                self.mod.verify()
         
     | 
| 
       51 
     | 
    
         
            -
                LLVM().optimizer.run(self.mod)
         
     | 
| 
       52 
     | 
    
         
            -
                self.mod.name = hashlib.sha1(prg.encode('utf-8')).hexdigest()
         
     | 
| 
       53 
     | 
    
         
            -
                if DEBUG >= 5: print(LLVM.target_machine.emit_assembly(self.mod))
         
     | 
| 
       54 
     | 
    
         
            -
                LLVM.engine.add_module(self.mod)
         
     | 
| 
       55 
     | 
    
         
            -
                LLVM.engine.finalize_object()
         
     | 
| 
       56 
     | 
    
         
            -
                self.fxn = LLVM.engine.get_function_address(name)
         
     | 
| 
       57 
     | 
    
         
            -
             
     | 
| 
       58 
     | 
    
         
            -
              def __del__(self):
         
     | 
| 
       59 
     | 
    
         
            -
                if hasattr(self, 'mod'): LLVM.engine.remove_module(self.mod)
         
     | 
| 
       60 
     | 
    
         
            -
             
     | 
| 
       61 
     | 
    
         
            -
              def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
         
     | 
| 
       62 
     | 
    
         
            -
                cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
         
     | 
| 
       63 
     | 
    
         
            -
                if wait: st = time.monotonic()
         
     | 
| 
       64 
     | 
    
         
            -
                cfunc(*[x._buf for x in bufs])
         
     | 
| 
       65 
     | 
    
         
            -
                if wait: return time.monotonic()-st
         
     | 
| 
       66 
     | 
    
         
            -
             
     | 
| 
       67 
     | 
    
         
            -
            LLVMBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), uops_to_llvm_ir, LLVMProgram)
         
     | 
| 
      
 45 
     | 
    
         
            +
                self.engine: llvm.executionengine.ExecutionEngine = llvm.create_mcjit_compiler(backing_mod, self.target_machine)
         
     | 
| 
      
 46 
     | 
    
         
            +
                super().__init__(device, MallocAllocator, LLVMRenderer(), LLVMCompiler(self), functools.partial(LLVMProgram, self))
         
     |