tinygrad-0.9.1-py3-none-any.whl → tinygrad-0.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_dsp.py (new file)
@@ -0,0 +1,181 @@
+ from __future__ import annotations
+ from typing import Tuple, Any
+ import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys
+ assert sys.platform != 'win32'
+ from tinygrad.device import BufferOptions, Compiled, Allocator
+ from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv
+ from tinygrad.runtime.ops_clang import ClangCompiler
+ from tinygrad.renderer.cstyle import DSPRenderer
+ from tinygrad.runtime.autogen import libc, qcom_dsp
+ if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
+
+ def rpc_sc(method=0, ins=0, outs=0, fds=0): return (method << 24) | (ins << 16) | (outs << 8) | fds
+ def rpc_prep_args(ins=None, outs=None, in_fds=None):
+ ins, outs, in_fds = ins or list(), outs or list(), in_fds or list()
+
+ pra = (qcom_dsp.union_remote_arg * (len(ins) + len(outs) + len(in_fds)))()
+ fds = (ctypes.c_int32 * (len(ins) + len(outs) + len(in_fds)))(*([-1] * (len(ins) + len(outs))), *in_fds)
+ attrs = (ctypes.c_uint32 * (len(ins) + len(outs) + len(in_fds)))(*([0] * (len(ins) + len(outs))), *([1] * (len(in_fds))))
+
+ for i, mv in enumerate(ins + outs): pra[i].buf.pv, pra[i].buf.len = mv_address(mv) if mv.nbytes > 0 else 0, mv.nbytes
+ return pra, fds, attrs, (ins, outs)
+
+ class DSPProgram:
+ def __init__(self, device:DSPDevice, name:str, lib:bytes):
+ self.device, self.lib = device, lib
+
+ def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False):
+ if len(bufs) >= 16: raise RuntimeError(f"Too many buffers to execute: {len(bufs)}")
+
+ pra, fds, attrs, _ = rpc_prep_args(ins=[var_vals_mv:=memoryview(bytearray((len(bufs)+len(vals))*4)), off_mv:=memoryview(bytearray(len(bufs)*4))],
+ outs=[timer:=memoryview(bytearray(8)).cast('Q')], in_fds=[b.share_info.fd for b in bufs])
+ var_vals_mv.cast('i')[:] = array.array('i', tuple(b.size for b in bufs) + vals)
+ off_mv.cast('I')[:] = array.array('I', tuple(b.offset for b in bufs))
+ self.device.exec_lib(self.lib, rpc_sc(method=2, ins=2, outs=1, fds=len(bufs)), pra, fds, attrs)
+ return timer[0] / 1e6
+
+ class DSPBuffer:
+ def __init__(self, va_addr:int, size:int, share_info:Any, offset:int=0):
+ self.va_addr, self.size, self.share_info, self.offset = va_addr, size, share_info, offset
+
+ class DSPAllocator(Allocator):
+ def __init__(self, device:DSPDevice):
+ self.device = device
+ super().__init__()
+
+ def _alloc(self, size:int, options:BufferOptions):
+ b = qcom_dsp.ION_IOC_ALLOC(self.device.ion_fd, len=size, align=0x200, heap_id_mask=1<<qcom_dsp.ION_SYSTEM_HEAP_ID, flags=qcom_dsp.ION_FLAG_CACHED)
+ share_info = qcom_dsp.ION_IOC_SHARE(self.device.ion_fd, handle=b.handle)
+ va_addr = libc.mmap(0, size, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED, share_info.fd, 0)
+ return DSPBuffer(va_addr, size, share_info, offset=0)
+
+ def _free(self, opaque:DSPBuffer, options:BufferOptions):
+ libc.munmap(opaque.va_addr, opaque.size)
+ os.close(opaque.share_info.fd)
+ qcom_dsp.ION_IOC_FREE(self.device.ion_fd, handle=opaque.share_info.handle)
+
+ def as_buffer(self, src:DSPBuffer) -> memoryview: return to_mv(src.va_addr, src.size)
+ def copyin(self, dest:DSPBuffer, src:memoryview): ctypes.memmove(dest.va_addr, from_mv(src), src.nbytes)
+ def copyout(self, dest:memoryview, src:DSPBuffer): ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
+ def offset(self, buf, size:int, offset:int): return DSPBuffer(buf.va_addr+offset, size, buf.share_info, buf.offset+offset)
+
+ class DSPDevice(Compiled):
+ def __init__(self, device:str=""):
+ self.ion_fd = os.open('/dev/ion', os.O_RDONLY)
+
+ # Generate link script to pass into clang. Aligning all used sections to 4k fixes invoke problem.
+ sections = ['hash', 'text', 'rela.plt', 'got', 'got.plt', 'dynamic', 'dynsym', 'dynstr', 'plt', 'data', 'bss']
+ sections_link = '\n'.join([f'.{n} : ALIGN(4096) {{ *(.{n}) }}' for n in sections])
+ with tempfile.NamedTemporaryFile(delete=False) as self.link_ld:
+ self.link_ld.write(f"SECTIONS {{ . = 0x0; {sections_link}\n /DISCARD/ : {{ *(.note .note.* .gnu.hash .comment) }} }}".encode())
+ self.link_ld.flush()
+
+ compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b", f"-T{self.link_ld.name}"]
+ super().__init__(device, DSPAllocator(self), DSPRenderer(),
+ ClangCompiler("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump'), functools.partial(DSPProgram, self))
+
+ fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes()))
+ self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferOptions(nolru=True))
+ ctypes.memmove(self.shell_buf.va_addr, mv_address(fastrpc_shell), fastrpc_shell.nbytes)
+
+ self.init_dsp()
+ RPCListner(self).start()
+
+ def open_lib(self, lib):
+ self.binded_lib, self.binded_lib_off = lib, 0
+ fp = "file:///tinylib?entry&_modver=1.0&_dom=cdsp\0"
+ pra, _, _, _ = rpc_prep_args(ins=[memoryview(array.array('I', [len(fp), 0xff])), memoryview(bytearray(fp.encode()))],
+ outs=[o1:=memoryview(bytearray(0x8)), o2:=memoryview(bytearray(0xff))])
+ qcom_dsp.FASTRPC_IOCTL_INVOKE(self.rpc_fd, handle=0, sc=rpc_sc(method=0, ins=2, outs=2), pra=pra)
+ if o1.cast('i')[1] < 0: raise RuntimeError(f"Cannot open lib: {o2.tobytes().decode()}")
+ return o1.cast('I')[0]
+
+ def close_lib(self, handle):
+ pra, _, _, _ = rpc_prep_args(ins=[memoryview(array.array('I', [handle, 0xff]))], outs=[memoryview(bytearray(0x8)), memoryview(bytearray(0xff))])
+ qcom_dsp.FASTRPC_IOCTL_INVOKE(self.rpc_fd, handle=0, sc=rpc_sc(method=1, ins=1, outs=2), pra=pra)
+
+ def exec_lib(self, lib, sc, args, fds, attrs):
+ def _exec_lib():
+ handle = self.open_lib(lib)
+ qcom_dsp.FASTRPC_IOCTL_INVOKE_ATTRS(self.rpc_fd, fds=fds, attrs=attrs, inv=qcom_dsp.struct_fastrpc_ioctl_invoke(handle=handle, sc=sc, pra=args))
+ self.close_lib(handle)
+ try: _exec_lib()
+ except (OSError, PermissionError):
+ # DSP might ask for a connection reset or just fail with operation not permitted, try to reset connection.
+ self.init_dsp()
+ _exec_lib()
+
+ def init_dsp(self):
+ if hasattr(self, 'rpc_fd'):
+ with contextlib.suppress(OSError):
+ qcom_dsp.FASTRPC_IOCTL_INVOKE(self.rpc_fd, handle=4, sc=rpc_sc(method=2, ins=0, outs=0)) # pylint: disable=access-member-before-definition
+ os.close(self.rpc_fd) # pylint: disable=access-member-before-definition
+
+ self.rpc_fd: int = os.open('/dev/adsprpc-smd', os.O_RDONLY | os.O_NONBLOCK)
+ qcom_dsp.FASTRPC_IOCTL_GETINFO(self.rpc_fd, 3)
+ qcom_dsp.FASTRPC_IOCTL_CONTROL(self.rpc_fd, req=0x3)
+ qcom_dsp.FASTRPC_IOCTL_INIT(self.rpc_fd, flags=0x1, file=self.shell_buf.va_addr, filelen=self.shell_buf.size, filefd=self.shell_buf.share_info.fd)
+ qcom_dsp.FASTRPC_IOCTL_INVOKE(self.rpc_fd, handle=3, sc=rpc_sc(method=3, ins=0, outs=0))
+
+ class RPCListner(threading.Thread):
+ def __init__(self, device:DSPDevice):
+ super().__init__()
+ self.device, self.daemon = device, True
+
+ def run(self):
+ # Setup initial request arguments.
+ context, status, TINYFD = 0, 0xffffffff, 0xffff
+ req_args, _, _, _ = rpc_prep_args(ins=[msg_send:=memoryview(bytearray(0x10)).cast('I'), out_buf:=memoryview(bytearray(0x10000)).cast('I')],
+ outs=[msg_recv:=memoryview(bytearray(0x10)).cast('I'), in_buf:=memoryview(bytearray(0x10000)).cast('I')])
+ req_args[1].buf.len = 0
+
+ while True:
+ # Update message request and send it.
+ msg_send[:] = array.array('I', [context, status, req_args[1].buf.len, in_buf.nbytes])
+
+ try: qcom_dsp.FASTRPC_IOCTL_INVOKE(self.device.rpc_fd, handle=0x3, sc=0x04020200, pra=req_args)
+ except OSError: continue # retry
+
+ context, inbufs, outbufs = msg_recv[0], ((sc:=msg_recv[2]) >> 16) & 0xff, (msg_recv[2] >> 8) & 0xff
+
+ in_ptr, out_ptr, objs = mv_address(in_buf), mv_address(out_buf), []
+ for i in range(inbufs + outbufs):
+ obj_ptr = round_up(in_ptr + 4, 8) if i < inbufs else round_up(out_ptr + 4, 8)
+ objs.append(to_mv(obj_ptr, obj_size:=to_mv(in_ptr, 4).cast('I')[0]))
+ if i < inbufs: in_ptr = obj_ptr + obj_size
+ else:
+ to_mv(out_ptr, 4).cast('I')[0] = obj_size
+ out_ptr = obj_ptr + obj_size
+ in_ptr += 4
+
+ in_args, out_args = objs[:inbufs], objs[inbufs:]
+ req_args[1].buf.len = out_ptr - mv_address(out_buf)
+
+ status = 0 # reset status, will set if error
+ if sc == 0x20200: pass # greating
+ elif sc == 0x13050100: # open
+ try: out_args[0].cast('I')[0] = TINYFD if (name:=in_args[3].tobytes()[:-1].decode()) == "tinylib" else os.open(name, os.O_RDONLY)
+ except OSError: status = 1
+ elif sc == 0x3010000:
+ if (fd:=in_args[0].cast('I')[0]) != TINYFD: os.close(fd)
+ elif sc == 0x9010000: # seek
+ if (fd:=in_args[0].cast('I')[0]) == TINYFD:
+ assert in_args[0].cast('I')[2] == qcom_dsp.APPS_STD_SEEK_SET, "Supported only SEEK_SET"
+ res, self.device.binded_lib_off = 0, in_args[0].cast('I')[1]
+ else: res = os.lseek(fd, in_args[0].cast('I')[1], in_args[0].cast('I')[2])
+ status = 0 if res >= 0 else res
+ elif sc == 0x4010200: # read
+ if (fd:=in_args[0].cast('I')[0]) == TINYFD:
+ buf = self.device.binded_lib[self.device.binded_lib_off:self.device.binded_lib_off+in_args[0].cast('I')[1]]
+ self.device.binded_lib_off += len(buf)
+ else: buf = os.read(fd, in_args[0].cast('I')[1])
+ out_args[1][:len(buf)] = buf
+ out_args[0].cast('I')[0:2] = array.array('I', [len(buf), int(len(buf) == 0)])
+ elif sc == 0x1f020100: # stat
+ stat = os.stat(in_args[1].tobytes()[:-1].decode())
+ out_stat = qcom_dsp.struct_apps_std_STAT.from_address(mv_address(out_args[0]))
+ for f in out_stat._fields_: out_stat.__setattr__(f[0], int(getattr(stat, f"st_{f[0]}", 0)))
+ elif sc == 0x2010100: # mmap
+ st = qcom_dsp.FASTRPC_IOCTL_MMAP(self.device.rpc_fd, fd=-1, flags=in_args[0].cast('I')[2], vaddrin=0, size=in_args[0].cast('Q')[3])
+ out_args[0].cast('Q')[0:2] = array.array('Q', [0, st.vaddrout])
+ else: raise RuntimeError(f"Unknown op: {sc=:X}")
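
Note (not part of the diff): rpc_sc above packs the FastRPC "scalar" call descriptor, one byte each for the method index and the counts of input buffers, output buffers and file descriptors. A minimal sketch of the packing used by DSPProgram.__call__ (the fd count of 3 is only illustrative):

# Sketch only; mirrors the rpc_sc definition shown in the hunk above.
def rpc_sc(method=0, ins=0, outs=0, fds=0): return (method << 24) | (ins << 16) | (outs << 8) | fds

# DSPProgram.__call__ issues method 2 with two input args (sizes+vals, offsets),
# one output arg (the timer) and one fd per tensor buffer:
assert rpc_sc(method=2, ins=2, outs=1, fds=3) == 0x02020103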
tinygrad/runtime/ops_gpu.py
@@ -1,16 +1,17 @@
  from __future__ import annotations
  from typing import Tuple, Optional, List, cast
  import ctypes, functools, hashlib
- import tinygrad.runtime.autogen.opencl as cl
- from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG
- from tinygrad.renderer.cstyle import OpenCLRenderer
+ from tinygrad.runtime.autogen import opencl as cl
+ from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, getenv, mv_address
+ from tinygrad.renderer.cstyle import OpenCLRenderer, IntelRenderer
  from tinygrad.device import BufferOptions, LRUAllocator, Compiled, Compiler, CompileError

  # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
  OSX_TIMING_RATIO = (125/3) if OSX else 1.0

+ cl_errors = {attr: k for k in dir(cl) if k.startswith("CL_") and isinstance(attr:=getattr(cl, k), int) and attr <= 0}
  def check(status):
- if status != 0: raise RuntimeError(f"OpenCL Error {status}")
+ if status != 0: raise RuntimeError(f"OpenCL Error {status}: {cl_errors.get(status, 'Unknown error')}")
  def checked(ret, status): return (check(status.value), ret)[1]

  class CLCompiler(Compiler):
@@ -43,8 +44,8 @@ class CLProgram:
  if hasattr(self, 'kernel'): check(cl.clReleaseKernel(self.kernel))
  if hasattr(self, 'program'): check(cl.clReleaseProgram(self.program))

- def __call__(self, *bufs:ctypes._CData, global_size:Tuple[int,int,int]=(1,1,1), local_size:Optional[Tuple[int,int,int]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]: # noqa: E501
- for i,b in enumerate(bufs): cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))
+ def __call__(self, *bufs:Tuple[ctypes._CData, BufferOptions], global_size:Tuple[int,int,int]=(1,1,1), local_size:Optional[Tuple[int,int,int]]=None, vals:Tuple[int, ...]=(), wait=False) -> Optional[float]: # noqa: E501
+ for i,(b,_) in enumerate(bufs): cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))
  for i,v in enumerate(vals,start=len(bufs)): cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v)))
  if local_size is not None: global_size = cast(Tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
  event = cl.cl_event() if wait else None
@@ -61,18 +62,27 @@ class CLAllocator(LRUAllocator):
  def __init__(self, device:CLDevice):
  self.device = device
  super().__init__()
- def _alloc(self, size:int, options:BufferOptions) -> ctypes._CData:
+ def _alloc(self, size:int, options:BufferOptions) -> Tuple[ctypes._CData, BufferOptions]:
  if options.image is not None:
- return checked(cl.clCreateImage2D(self.device.context, cl.CL_MEM_READ_WRITE,
+ return (checked(cl.clCreateImage2D(self.device.context, cl.CL_MEM_READ_WRITE,
  cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[options.image.itemsize]),
- options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status)
- return checked(cl.clCreateBuffer(self.device.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status)
- def _free(self, opaque:ctypes._CData, options:BufferOptions): check(cl.clReleaseMemObject(opaque))
- def copyin(self, dest:ctypes._CData, src:memoryview):
- check(cl.clEnqueueWriteBuffer(self.device.queue, dest, False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
+ options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status), options)
+ return (checked(cl.clCreateBuffer(self.device.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status), options)
+ def _free(self, opaque:Tuple[ctypes._CData, BufferOptions], options:BufferOptions): check(cl.clReleaseMemObject(opaque[0]))
+ def copyin(self, dest:Tuple[ctypes._CData, BufferOptions], src:memoryview):
+ if dest[1].image is not None:
+ check(cl.clEnqueueWriteImage(self.device.queue, dest[0], False, (ctypes.c_size_t * 3)(0,0,0),
+ (ctypes.c_size_t * 3)(dest[1].image.shape[1],dest[1].image.shape[0],1), 0, 0, from_mv(src), 0, None, None))
+ else:
+ if mv_address(src) % 16: src = memoryview(bytearray(src))
+ check(cl.clEnqueueWriteBuffer(self.device.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
  self.device.pending_copyin.append(src) # NOTE: these can't be freed until the GPU actually executes this command
- def copyout(self, dest:memoryview, src:ctypes._CData):
- check(cl.clEnqueueReadBuffer(self.device.queue, src, False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None))
+ def copyout(self, dest:memoryview, src:Tuple[ctypes._CData, BufferOptions]):
+ if src[1].image is not None:
+ check(cl.clEnqueueReadImage(self.device.queue, src[0], False, (ctypes.c_size_t * 3)(0,0,0),
+ (ctypes.c_size_t * 3)(src[1].image.shape[1],src[1].image.shape[0],1), 0, 0, from_mv(dest), 0, None, None))
+ else:
+ check(cl.clEnqueueReadBuffer(self.device.queue, src[0], False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None))
  self.device.synchronize()

  class CLDevice(Compiled):
@@ -90,12 +100,15 @@ class CLDevice(Compiled):
  self.device_id = CLDevice.device_ids[0 if ":" not in device else int(device.split(":")[1])]
  self.device_name = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_NAME, 256, buf := ctypes.create_string_buffer(256), None), buf.value.decode())[1] # noqa: E501
  self.driver_version = (cl.clGetDeviceInfo(self.device_id, cl.CL_DRIVER_VERSION, 256, buf := ctypes.create_string_buffer(256), None), buf.value.decode())[1] # noqa: E501
+ if DEBUG >= 1: print(f"CLDevice: opening {self.device_name} with version {self.driver_version}")
  self.context = checked(cl.clCreateContext(None, 1, self.device_id, cl.clCreateContext.argtypes[3](), None, status := ctypes.c_int32()), status)
  self.queue = checked(cl.clCreateCommandQueue(self.context, self.device_id, cl.CL_QUEUE_PROFILING_ENABLE, status), status)
  self.pending_copyin: List[memoryview] = []
+ self.device_exts = (cl.clGetDeviceInfo(self.device_id, cl.CL_DEVICE_EXTENSIONS, 4096, ctypes.byref(buf := ctypes.create_string_buffer(4096)), ctypes.byref(total := ctypes.c_size_t())), ctypes.string_at(buf, size=total.value).decode())[1] # noqa: E501

  compile_key = hashlib.md5(self.device_name.encode() + self.driver_version.encode()).hexdigest()
- super().__init__(device, CLAllocator(self), OpenCLRenderer(), CLCompiler(self, f"compile_cl_{compile_key}"), functools.partial(CLProgram, self))
+ renderer = IntelRenderer() if "cl_intel_subgroup_matrix_multiply_accumulate" in self.device_exts and getenv("INTEL") else OpenCLRenderer()
+ super().__init__(device, CLAllocator(self), renderer, CLCompiler(self, f"compile_cl_{compile_key}"), functools.partial(CLProgram, self))
  def synchronize(self):
  check(cl.clFinish(self.queue))
  self.pending_copyin.clear()
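
Note (not part of the diff): the new cl_errors table inverts the autogenerated OpenCL constants so that check() can report a name alongside the raw status code. A rough sketch of the lookup, assuming the generated bindings expose the usual negative error constants from cl.h:

from tinygrad.runtime.autogen import opencl as cl

# Map every non-positive CL_* integer constant back to its name (for duplicate values, the last name iterated wins).
cl_errors = {attr: k for k in dir(cl) if k.startswith("CL_") and isinstance(attr := getattr(cl, k), int) and attr <= 0}
# check(-52) would now raise "OpenCL Error -52: CL_INVALID_KERNEL_ARGS" on typical headers.
print(cl_errors.get(-52, "Unknown error"))

The other user-visible change in this file is renderer selection: on devices that advertise cl_intel_subgroup_matrix_multiply_accumulate, running with INTEL=1 swaps OpenCLRenderer for IntelRenderer.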
tinygrad/runtime/ops_hip.py (new file)
@@ -0,0 +1,68 @@
+ from __future__ import annotations
+ import ctypes, functools
+ from typing import Tuple
+ from tinygrad.helpers import init_c_var, from_mv, init_c_struct_t, getenv
+ from tinygrad.device import Compiled, LRUAllocator, BufferOptions
+ from tinygrad.runtime.autogen import hip
+ from tinygrad.runtime.support.compiler_hip import AMDCompiler
+ from tinygrad.renderer.cstyle import HIPRenderer
+ if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
+
+ def check(status):
+ if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}")
+
+ class HIPProgram:
+ def __init__(self, device:HIPDevice, name:str, lib:bytes):
+ self.device, self.name, self.lib = device, name, lib
+ check(hip.hipSetDevice(self.device.device_id))
+ self.module = init_c_var(hip.hipModule_t(), lambda x: check(hip.hipModuleLoadData(ctypes.byref(x), lib)))
+ self.prg = init_c_var(hip.hipFunction_t(), lambda x: check(hip.hipModuleGetFunction(ctypes.byref(x), self.module, name.encode("utf-8"))))
+
+ def __del__(self):
+ if hasattr(self, 'module'): check(hip.hipModuleUnload(self.module))
+
+ def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
+ check(hip.hipSetDevice(self.device.device_id))
+ if not hasattr(self, "vargs"):
+ self.c_args = init_c_struct_t(tuple([(f'f{i}', hip.hipDeviceptr_t) for i in range(len(args))] +
+ [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))(*args, *vals)
+ self.vargs = (ctypes.c_void_p * 5)(1, ctypes.cast(ctypes.byref(self.c_args), ctypes.c_void_p), 2,
+ ctypes.cast(ctypes.pointer(ctypes.c_size_t(ctypes.sizeof(self.c_args))), ctypes.c_void_p), 3)
+
+ for i in range(len(args)): self.c_args.__setattr__(f'f{i}', args[i])
+ for i in range(len(vals)): self.c_args.__setattr__(f'v{i}', vals[i])
+
+ if wait: check(hip.hipEventRecord(self.device.time_event_st, None))
+
+ check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, self.vargs))
+
+ if wait:
+ check(hip.hipEventRecord(self.device.time_event_en, None))
+ check(hip.hipEventSynchronize(self.device.time_event_en))
+ check(hip.hipEventElapsedTime(ctypes.byref(ret := ctypes.c_float()), self.device.time_event_st, self.device.time_event_en))
+ return ret.value * 1e-3
+
+ class HIPAllocator(LRUAllocator):
+ def __init__(self, device:HIPDevice):
+ self.device = device
+ super().__init__()
+ def _alloc(self, size:int, options:BufferOptions):
+ check(hip.hipSetDevice(self.device.device_id))
+ return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size)))
+ def _free(self, opaque, options:BufferOptions): check(hip.hipFree(opaque))
+ def copyin(self, dest, src: memoryview):
+ check(hip.hipSetDevice(self.device.device_id))
+ check(hip.hipMemcpy(dest, from_mv(src), len(src), hip.hipMemcpyHostToDevice))
+ def copyout(self, dest:memoryview, src):
+ self.device.synchronize()
+ check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost))
+
+ class HIPDevice(Compiled):
+ def __init__(self, device:str=""):
+ self.device_id = int(device.split(":")[1]) if ":" in device else 0
+ self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device_id))).gcnArchName.decode()
+ self.time_event_st, self.time_event_en = [init_c_var(hip.hipEvent_t(), lambda x: hip.hipEventCreate(ctypes.byref(x), 0)) for _ in range(2)]
+ super().__init__(device, HIPAllocator(self), HIPRenderer(), AMDCompiler(self.arch), functools.partial(HIPProgram, self))
+ def synchronize(self):
+ check(hip.hipSetDevice(self.device_id))
+ check(hip.hipDeviceSynchronize())
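
Note (not part of the diff): HIPProgram.__call__ packs all kernel arguments into one ctypes struct and passes it to hipModuleLaunchKernel through the extra-params array instead of an explicit kernel-parameter list. The literal 1/2/3 markers in self.vargs line up with HIP's launch-parameter sentinels (values as defined in hip_runtime_api.h; the names below are for reference only and are not imported from the autogen bindings):

# Illustrative only: what the 5-element vargs array built above encodes.
HIP_LAUNCH_PARAM_BUFFER_POINTER, HIP_LAUNCH_PARAM_BUFFER_SIZE, HIP_LAUNCH_PARAM_END = 1, 2, 3
# vargs == (BUFFER_POINTER, &c_args, BUFFER_SIZE, &sizeof(c_args), END)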
tinygrad/runtime/ops_llvm.py
@@ -2,27 +2,35 @@ from __future__ import annotations
  import ctypes, functools
  from typing import Tuple
  from tinygrad.device import Compiled, Compiler, MallocAllocator
- from tinygrad.helpers import DEBUG, cpu_time_execution, cpu_objdump
+ from tinygrad.helpers import cpu_time_execution, getenv, cpu_objdump
  from tinygrad.renderer.llvmir import LLVMRenderer
  import llvmlite.binding as llvm

  class LLVMCompiler(Compiler):
- def __init__(self, device:LLVMDevice):
+ def __init__(self, device:LLVMDevice, opt:bool=False):
  self.device = device
- super().__init__("compile_llvm")
+ self.optimizer: llvm.passmanagers.ModulePassManager = llvm.create_module_pass_manager()
+ self.device.target_machine.add_analysis_passes(self.optimizer)
+ if opt:
+ with llvm.create_pass_manager_builder() as builder:
+ builder.opt_level = 3; builder.size_level = 0; builder.loop_vectorize = True; builder.slp_vectorize = True # noqa: E702
+ builder.populate(self.optimizer)
+ super().__init__("compile_llvm_opt" if opt else "compile_llvm")
+
  def compile(self, src:str) -> bytes:
  mod = llvm.parse_assembly(src)
  mod.verify()
- self.device.optimizer.run(mod)
- if DEBUG >= 5: print(self.device.target_machine.emit_assembly(mod))
+ self.optimizer.run(mod)
  return self.device.target_machine.emit_object(mod)

+ def disassemble(self, lib:bytes): cpu_objdump(lib)
+
  class LLVMProgram:
  def __init__(self, device:LLVMDevice, name:str, lib:bytes):
- if DEBUG >= 6: cpu_objdump(lib)
  self.name, self.lib = name, lib
  device.engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib))
  self.fxn = device.engine.get_function_address(name)
+ assert self.fxn != 0, "LLVM failed to get function address"

  def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False):
  if not hasattr(self, 'cfunc'):
@@ -35,12 +43,9 @@ class LLVMDevice(Compiled):
  llvm.initialize_native_target()
  llvm.initialize_native_asmprinter()
  llvm.initialize_native_asmparser()
- self.optimizer: llvm.passmanagers.ModulePassManager = llvm.create_module_pass_manager()
  # this opt actually can change things. ex: opt=3 means no FMA, opt=2 means FMA
  self.target_machine: llvm.targets.TargetMachine = llvm.Target.from_triple(llvm.get_process_triple()).create_target_machine(opt=2)
- self.target_machine.add_analysis_passes(self.optimizer)
- self.target_machine.set_asm_verbosity(True)
  backing_mod = llvm.parse_assembly(str())
  backing_mod.triple = llvm.get_process_triple()
  self.engine: llvm.executionengine.ExecutionEngine = llvm.create_mcjit_compiler(backing_mod, self.target_machine)
- super().__init__(device, MallocAllocator, LLVMRenderer(), LLVMCompiler(self), functools.partial(LLVMProgram, self))
+ super().__init__(device, MallocAllocator, LLVMRenderer(), LLVMCompiler(self, getenv("LLVMOPT")), functools.partial(LLVMProgram, self))
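
Note (not part of the diff): the LLVM pass pipeline now lives in LLVMCompiler, and the opt-level-3 / vectorizing configuration is only populated when the compiler is constructed with opt=True, which the device wires to getenv("LLVMOPT"). A hypothetical usage sketch, assuming tinygrad's usual environment-variable device selection:

# Run from the shell with the LLVM backend selected and the optimizing pipeline enabled, e.g.:
#   LLVM=1 LLVMOPT=1 python3 this_script.py
# Both variables are read via getenv() when the device is created, so set them before tinygrad starts.
from tinygrad import Tensor
print((Tensor.arange(8) * 2).tolist())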