tinygrad 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. tinygrad/__init__.py +11 -6
  2. tinygrad/codegen/kernel.py +308 -175
  3. tinygrad/codegen/linearize.py +95 -0
  4. tinygrad/codegen/lowerer.py +143 -0
  5. tinygrad/codegen/transcendental.py +257 -0
  6. tinygrad/codegen/uopgraph.py +506 -0
  7. tinygrad/device.py +72 -171
  8. tinygrad/dtype.py +122 -47
  9. tinygrad/engine/jit.py +184 -87
  10. tinygrad/{lazy.py → engine/lazy.py} +74 -66
  11. tinygrad/engine/memory.py +51 -0
  12. tinygrad/engine/realize.py +86 -61
  13. tinygrad/engine/schedule.py +366 -317
  14. tinygrad/engine/search.py +58 -47
  15. tinygrad/function.py +59 -58
  16. tinygrad/helpers.py +120 -102
  17. tinygrad/multi.py +82 -78
  18. tinygrad/nn/__init__.py +116 -67
  19. tinygrad/nn/datasets.py +12 -5
  20. tinygrad/nn/optim.py +1 -1
  21. tinygrad/nn/state.py +91 -6
  22. tinygrad/ops.py +1126 -143
  23. tinygrad/renderer/__init__.py +47 -23
  24. tinygrad/renderer/cstyle.py +338 -265
  25. tinygrad/renderer/llvmir.py +125 -143
  26. tinygrad/renderer/ptx.py +225 -0
  27. tinygrad/runtime/autogen/adreno.py +17904 -0
  28. tinygrad/runtime/autogen/amd_gpu.py +46974 -11993
  29. tinygrad/runtime/autogen/cuda.py +6 -162
  30. tinygrad/runtime/autogen/io_uring.py +97 -63
  31. tinygrad/runtime/autogen/kfd.py +60 -47
  32. tinygrad/runtime/autogen/kgsl.py +1386 -0
  33. tinygrad/runtime/autogen/libc.py +5462 -0
  34. tinygrad/runtime/autogen/nv_gpu.py +1976 -1957
  35. tinygrad/runtime/autogen/nvrtc.py +579 -0
  36. tinygrad/runtime/autogen/opencl.py +11 -11
  37. tinygrad/runtime/autogen/qcom_dsp.py +1739 -0
  38. tinygrad/runtime/graph/clang.py +3 -3
  39. tinygrad/runtime/graph/cuda.py +11 -15
  40. tinygrad/runtime/graph/hcq.py +120 -107
  41. tinygrad/runtime/graph/metal.py +71 -43
  42. tinygrad/runtime/ops_amd.py +244 -323
  43. tinygrad/runtime/ops_clang.py +12 -5
  44. tinygrad/runtime/ops_cloud.py +220 -0
  45. tinygrad/runtime/ops_cuda.py +42 -99
  46. tinygrad/runtime/ops_disk.py +25 -26
  47. tinygrad/runtime/ops_dsp.py +181 -0
  48. tinygrad/runtime/ops_gpu.py +29 -16
  49. tinygrad/runtime/ops_hip.py +68 -0
  50. tinygrad/runtime/ops_llvm.py +15 -10
  51. tinygrad/runtime/ops_metal.py +147 -64
  52. tinygrad/runtime/ops_nv.py +356 -397
  53. tinygrad/runtime/ops_python.py +78 -79
  54. tinygrad/runtime/ops_qcom.py +405 -0
  55. tinygrad/runtime/support/__init__.py +0 -0
  56. tinygrad/runtime/support/compiler_cuda.py +77 -0
  57. tinygrad/runtime/{driver/hip_comgr.py → support/compiler_hip.py} +13 -1
  58. tinygrad/runtime/support/elf.py +38 -0
  59. tinygrad/runtime/support/hcq.py +539 -0
  60. tinygrad/shape/shapetracker.py +40 -50
  61. tinygrad/shape/view.py +102 -63
  62. tinygrad/tensor.py +1109 -365
  63. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/METADATA +54 -50
  64. tinygrad-0.10.0.dist-info/RECORD +77 -0
  65. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/WHEEL +1 -1
  66. tinygrad/codegen/linearizer.py +0 -528
  67. tinygrad/codegen/uops.py +0 -451
  68. tinygrad/engine/graph.py +0 -100
  69. tinygrad/renderer/assembly.py +0 -269
  70. tinygrad/shape/symbolic.py +0 -327
  71. tinygrad-0.9.1.dist-info/RECORD +0 -63
  72. /tinygrad/{runtime/driver/__init__.py → py.typed} +0 -0
  73. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/LICENSE +0 -0
  74. {tinygrad-0.9.1.dist-info → tinygrad-0.10.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_clang.py
@@ -1,19 +1,26 @@
+from typing import Optional, List
 import ctypes, subprocess, pathlib, tempfile
 from tinygrad.device import Compiled, Compiler, MallocAllocator
-from tinygrad.helpers import cpu_time_execution, DEBUG, cpu_objdump
+from tinygrad.helpers import cpu_time_execution, cpu_objdump
 from tinygrad.renderer.cstyle import ClangRenderer
 
 class ClangCompiler(Compiler):
+  def __init__(self, cachekey="compile_clang", args:Optional[List[str]]=None, objdump_tool='objdump'):
+    self.args = ['-march=native'] if args is None else args
+    self.objdump_tool = objdump_tool
+    super().__init__(cachekey)
+
   def compile(self, src:str) -> bytes:
     # TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
     with tempfile.NamedTemporaryFile(delete=True) as output_file:
-      subprocess.check_output(['clang', '-include', 'tgmath.h', '-shared', '-march=native', '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-',
-                               '-o', str(output_file.name)], input=src.encode('utf-8'))
+      subprocess.check_output(['clang', '-shared', *self.args, '-O2', '-Wall', '-Werror', '-x', 'c', '-fPIC', '-ffreestanding', '-nostdlib',
+                               '-', '-o', str(output_file.name)], input=src.encode('utf-8'))
       return pathlib.Path(output_file.name).read_bytes()
 
+  def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool)
+
 class ClangProgram:
   def __init__(self, name:str, lib:bytes):
-    if DEBUG >= 6: cpu_objdump(lib)
     self.name, self.lib = name, lib
     # write to disk so we can load it
     with tempfile.NamedTemporaryFile(delete=True) as cached_file_path:
@@ -25,4 +32,4 @@ class ClangProgram:
 class ClangDevice(Compiled):
   def __init__(self, device:str):
     from tinygrad.runtime.graph.clang import ClangGraph
-    super().__init__(device, MallocAllocator, ClangRenderer(), ClangCompiler("compile_clang"), ClangProgram, ClangGraph)
+    super().__init__(device, MallocAllocator, ClangRenderer(), ClangCompiler(), ClangProgram, ClangGraph)
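To make the reworked compiler concrete: `ClangCompiler` now takes its flags as a constructor argument (defaulting to `['-march=native']`), compiles freestanding code with `-ffreestanding -nostdlib`, and exposes disassembly as an explicit `disassemble` method instead of a `DEBUG>=6` side effect. A minimal usage sketch follows; it assumes `ClangProgram.__call__` still forwards buffers directly to the loaded symbol (that method is unchanged and not shown in this hunk), a `clang` binary on PATH, and a hypothetical kernel named `add1`.

```python
# sketch: drive the new ClangCompiler/ClangProgram pair by hand
import ctypes
from tinygrad.runtime.ops_clang import ClangCompiler, ClangProgram

src = "void add1(float *out, const float *in) { for (int i = 0; i < 4; i++) out[i] = in[i] + 1.0f; }"
lib = ClangCompiler().compile(src)   # clang -shared -march=native ... -ffreestanding -nostdlib
prg = ClangProgram("add1", lib)      # writes the .so to a temp file and dlopen's it

out, inp = (ctypes.c_float * 4)(), (ctypes.c_float * 4)(1, 2, 3, 4)
prg(out, inp)                        # call the compiled symbol with ctypes buffers
print(list(out))                     # expected: [2.0, 3.0, 4.0, 5.0]
```

The `args`/`objdump_tool` parameters let callers swap in cross-compilation flags or a different objdump binary without subclassing the compile logic.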

tinygrad/runtime/ops_cloud.py (new file)
@@ -0,0 +1,220 @@
+# the CLOUD=1 device is a process boundary between the frontend/runtime
+# normally tinygrad is frontend <-> middleware <-> runtime <-> hardware
+# with CLOUD tinygrad is frontend <-> middleware <-> CloudDevice ///HTTP/// cloud_server <-> runtime <-> hardware
+# this client and server can be on the same machine, same network, or just same internet
+# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
+
+from __future__ import annotations
+from typing import Tuple, Optional, Dict, Any, DefaultDict, List
+from collections import defaultdict
+from dataclasses import dataclass, field
+import multiprocessing, functools, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib
+from http.server import HTTPServer, BaseHTTPRequestHandler
+from tinygrad.renderer import Renderer
+from tinygrad.dtype import dtypes
+from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
+from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions
+
+# ***** API *****
+
+class CloudRequest: pass
+
+@dataclass(frozen=True)
+class BufferAlloc(CloudRequest): buffer_num: int; size: int; options: BufferOptions # noqa: E702
+
+@dataclass(frozen=True)
+class BufferFree(CloudRequest): buffer_num: int # noqa: E702
+
+@dataclass(frozen=True)
+class CopyIn(CloudRequest): buffer_num: int; datahash: str # noqa: E702
+
+@dataclass(frozen=True)
+class CopyOut(CloudRequest): buffer_num: int
+
+@dataclass(frozen=True)
+class ProgramAlloc(CloudRequest): name: str; datahash: str # noqa: E702
+
+@dataclass(frozen=True)
+class ProgramFree(CloudRequest): name: str; datahash: str # noqa: E702
+
+@dataclass(frozen=True)
+class ProgramExec(CloudRequest):
+  name: str; datahash: str; bufs: Tuple[int, ...]; vals: Tuple[int, ...] # noqa: E702
+  global_size: Optional[Tuple[int, ...]]; local_size: Optional[Tuple[int, ...]]; wait: bool # noqa: E702
+
+# for safe deserialization
+whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferOptions]}
+eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
+  ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
+  ast.Name: lambda x: whitelist[x.id], ast.Attribute: lambda x: {"imagef": dtypes.imagef, "imageh": dtypes.imageh}[x.attr]}
+def safe_eval(node): return eval_fxns[node.__class__](node)
+
+class BatchRequest:
+  def __init__(self):
+    self._q: List[CloudRequest] = []
+    self._h: Dict[str, bytes] = {}
+  def h(self, d:bytes) -> str:
+    binhash = hashlib.sha256(d).digest()
+    self._h[datahash:=binascii.hexlify(binhash).decode()] = binhash+struct.pack("<Q", len(d))+d
+    return datahash
+  def q(self, x:CloudRequest): self._q.append(x)
+  def serialize(self) -> bytes:
+    self.h(repr(self._q).encode())
+    return b''.join(self._h.values())
+  def deserialize(self, dat:bytes) -> BatchRequest:
+    ptr = 0
+    while ptr < len(dat):
+      datahash, datalen = binascii.hexlify(dat[ptr:ptr+0x20]).decode(), struct.unpack("<Q", dat[ptr+0x20:ptr+0x28])[0]
+      self._h[datahash] = dat[ptr+0x28:ptr+0x28+datalen]
+      ptr += 0x28+datalen
+    self._q = safe_eval(ast.parse(self._h[datahash], mode="eval").body)
+    return self
+
+# ***** backend *****
+
+@dataclass
+class CloudSession:
+  programs: Dict[Tuple[str, str], Any] = field(default_factory=dict)
+  # TODO: the buffer should track this internally
+  buffers: Dict[int, Tuple[Any, int, Optional[BufferOptions]]] = field(default_factory=dict)
+
+class CloudHandler(BaseHTTPRequestHandler):
+  protocol_version = 'HTTP/1.1'
+  dname: str
+  sessions: DefaultDict[str, CloudSession] = defaultdict(CloudSession)
+
+  def setup(self):
+    super().setup()
+    print(f"connection established with {self.client_address}, socket: {self.connection.fileno()}")
+
+  def _do(self, method):
+    session = CloudHandler.sessions[unwrap(self.headers.get("Cookie")).split("session=")[1]]
+    ret, status_code = b"", 200
+    if self.path == "/batch" and method == "POST":
+      # TODO: streaming deserialize?
+      req = BatchRequest().deserialize(self.rfile.read(int(unwrap(self.headers.get('Content-Length')))))
+      # the cmds are always last (currently in datahash)
+      for c in req._q:
+        if DEBUG >= 1: print(c)
+        match c:
+          case BufferAlloc():
+            assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
+            session.buffers[c.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(c.size, c.options), c.size, c.options)
+          case BufferFree():
+            buf,sz,buffer_options = session.buffers[c.buffer_num]
+            Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
+            del session.buffers[c.buffer_num]
+          case CopyIn(): Device[CloudHandler.dname].allocator.copyin(session.buffers[c.buffer_num][0], memoryview(bytearray(req._h[c.datahash])))
+          case CopyOut():
+            buf,sz,_ = session.buffers[c.buffer_num]
+            Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf)
+          case ProgramAlloc():
+            lib = Device[CloudHandler.dname].compiler.compile_cached(req._h[c.datahash].decode())
+            session.programs[(c.name, c.datahash)] = Device[CloudHandler.dname].runtime(c.name, lib)
+          case ProgramFree(): del session.programs[(c.name, c.datahash)]
+          case ProgramExec():
+            bufs = [session.buffers[x][0] for x in c.bufs]
+            extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None}
+            r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
+            if r is not None: ret = str(r).encode()
+    elif self.path == "/renderer" and method == "GET":
+      cls, args = Device[CloudHandler.dname].renderer.__reduce__()
+      ret = json.dumps((cls.__module__, cls.__name__, args)).encode()
+    else: status_code = 404
+    self.send_response(status_code)
+    self.send_header('Content-Length', str(len(ret)))
+    self.end_headers()
+    return self.wfile.write(ret)
+
+  def do_GET(self): return self._do("GET")
+  def do_POST(self): return self._do("POST")
+
+def cloud_server(port:int):
+  multiprocessing.current_process().name = "MainProcess"
+  CloudHandler.dname = getenv("CLOUDDEV", "METAL") if Device.DEFAULT == "CLOUD" else Device.DEFAULT
+  print(f"start cloud server on {port} with device {CloudHandler.dname}")
+  server = HTTPServer(('', port), CloudHandler)
+  server.serve_forever()
+
+# ***** frontend *****
+
+class CloudAllocator(Allocator):
+  def __init__(self, device:CloudDevice):
+    self.device = device
+    super().__init__()
+  # TODO: ideally we shouldn't have to deal with images here
+  def _alloc(self, size:int, options:BufferOptions) -> int:
+    self.device.buffer_num += 1
+    self.device.req.q(BufferAlloc(self.device.buffer_num, size, options))
+    return self.device.buffer_num
+  # TODO: options should not be here in any Allocator
+  def _free(self, opaque:int, options): self.device.req.q(BufferFree(opaque))
+  def copyin(self, dest:int, src:memoryview): self.device.req.q(CopyIn(dest, self.device.req.h(bytes(src))))
+  def copyout(self, dest:memoryview, src:int):
+    self.device.req.q(CopyOut(src))
+    resp = self.device.batch_submit()
+    assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
+    dest[:] = resp
+
+class CloudProgram:
+  def __init__(self, device:CloudDevice, name:str, lib:bytes):
+    self.device, self.name = device, name
+    self.datahash = self.device.req.h(lib)
+    self.device.req.q(ProgramAlloc(self.name, self.datahash))
+    super().__init__()
+  def __del__(self): self.device.req.q(ProgramFree(self.name, self.datahash))
+
+  def __call__(self, *bufs, global_size=None, local_size=None, vals:Tuple[int, ...]=(), wait=False):
+    self.device.req.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait))
+    if wait: return float(self.device.batch_submit())
+
+class CloudDevice(Compiled):
+  def __init__(self, device:str):
+    if (host:=getenv("HOST", "")) != "": self.host = host
+    else:
+      p = multiprocessing.Process(target=cloud_server, args=(6667,))
+      p.daemon = True
+      p.start()
+      self.host = "127.0.0.1:6667"
+
+    # state for the connection
+    self.session = binascii.hexlify(os.urandom(0x10)).decode()
+    self.buffer_num = 0
+    self.req: BatchRequest = BatchRequest()
+
+    if DEBUG >= 1: print(f"cloud with host {self.host}")
+    while 1:
+      try:
+        self.conn = http.client.HTTPConnection(self.host, timeout=60.0)
+        clouddev = json.loads(self.send("GET", "renderer").decode())
+        break
+      except Exception as e:
+        print(e)
+        time.sleep(0.1)
+    if DEBUG >= 1: print(f"remote has device {clouddev}")
+    # TODO: how to we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
+    if not clouddev[0].startswith("tinygrad.renderer.") or not clouddev[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {clouddev}")
+    renderer_class = fromimport(clouddev[0], clouddev[1]) # TODO: is this secure?
+    if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {clouddev}")
+    super().__init__(device, CloudAllocator(self), renderer_class(*clouddev[2]), Compiler(), functools.partial(CloudProgram, self))
+
+  def __del__(self):
+    # TODO: this is never being called
+    # TODO: should close the whole session
+    with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected): self.batch_submit()
+
+  def batch_submit(self):
+    data = self.req.serialize()
+    with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1):
+      ret = self.send("POST", "batch", data)
+    self.req = BatchRequest()
+    return ret
+
+  def send(self, method, path, data:Optional[bytes]=None) -> bytes:
+    # TODO: retry logic
+    self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.session}"})
+    response = self.conn.getresponse()
+    assert response.status == 200, f"failed on {method} {path}"
+    return response.read()
+
+if __name__ == "__main__": cloud_server(getenv("PORT", 6667))
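Because the CLOUD boundary deliberately avoids pickle, everything on the wire is the `repr` of the request queue plus raw hash-addressed payloads, each framed as a `sha256 digest, little-endian u64 length, data` record and rebuilt with `safe_eval` against the class whitelist. A small round-trip sketch (runs locally with no server or GPU; assumes `BufferOptions()` is constructible with its defaults):

```python
# sketch: serialize a batch the way CloudDevice.batch_submit would, then
# deserialize it the way the server's /batch handler does
from tinygrad.runtime.ops_cloud import BatchRequest, BufferAlloc, CopyIn
from tinygrad.device import BufferOptions

req = BatchRequest()
req.q(BufferAlloc(buffer_num=1, size=16, options=BufferOptions()))
req.q(CopyIn(buffer_num=1, datahash=req.h(b"\x00" * 16)))  # raw bytes travel by hash
wire = req.serialize()                                      # one POST /batch body

back = BatchRequest().deserialize(wire)
print(back._q)       # [BufferAlloc(...), CopyIn(...)] rebuilt via the AST whitelist
print(len(back._h))  # 2 records: the CopyIn payload and the repr'd command queue
```

With DEBUG=1, `batch_submit` logs the same request and hash counts for a live connection.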

tinygrad/runtime/ops_cuda.py
@@ -1,30 +1,14 @@
 from __future__ import annotations
-import subprocess, hashlib, tempfile, ctypes, ctypes.util, functools, re
-from pathlib import Path
+import ctypes, ctypes.util, functools
 from typing import Tuple, Optional, List
-import tinygrad.runtime.autogen.cuda as cuda
-from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, init_c_struct_t, colored, cpu_time_execution
-from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator, MallocAllocator
+from tinygrad.helpers import DEBUG, getenv, from_mv, init_c_var, init_c_struct_t
+from tinygrad.device import Compiled, BufferOptions, LRUAllocator
 from tinygrad.renderer.cstyle import CUDARenderer
-from tinygrad.renderer.assembly import PTXRenderer
+from tinygrad.renderer.ptx import PTXRenderer
+from tinygrad.runtime.autogen import cuda
+from tinygrad.runtime.support.compiler_cuda import cuda_disassemble, pretty_ptx, CUDACompiler, PTXCompiler, PTX
 if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401 # pylint: disable=unused-import
 
-def pretty_ptx(s):
-  # all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
-  s = re.sub(r'([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])', lambda m:m[1]+colored(m[2], "blue")+m[3], s, flags=re.M) # identifiers # noqa: E501
-  s = re.sub(r'(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])', lambda m:m[1]+colored(m[2], "green")+m[3], s, flags=re.M) # types
-  s = re.sub(r'^(\s*)([\w]+)(.*?;$)', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # instructions
-  s = re.sub(r'([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])', lambda m:m[1]+colored(m[2], "yellow")+m[3], s, flags=re.M) # numbers # noqa: E501
-  s = re.sub(r'(\.)(param|reg|global)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # space
-  s = re.sub(r'(\.)(version|target|address_size|visible|entry)', lambda m:m[1]+colored(m[2], "magenta"), s, flags=re.M) # derivatives
-  return s
-
-CUDACPU = getenv("CUDACPU") == 1
-if CUDACPU:
-  gpuocelot_lib = ctypes.CDLL(ctypes.util.find_library("gpuocelot"))
-  gpuocelot_lib.ptx_run.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int] # noqa: E501
-  cuda.cuLaunchKernel = lambda src, gx, gy, gz, lx, ly, lz, shared, stream, unused_extra, args: gpuocelot_lib.ptx_run(src, len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args]), lx, ly, lz, gx, gy, gz, shared) # type: ignore # noqa: E501
-
 def check(status):
   if status != 0: raise RuntimeError(f"CUDA Error {status}, {ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(), lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()}") # noqa: E501
 
@@ -36,7 +20,6 @@ def encode_args(args, vals) -> Tuple[ctypes.Structure, ctypes.Array]:
   return c_args, vargs
 
 def cu_time_execution(cb, enable=False) -> Optional[float]:
-  if CUDACPU: return cpu_time_execution(cb, enable=enable)
   if not enable: return cb()
   evs = [init_c_var(cuda.CUevent(), lambda x: cuda.cuEventCreate(ctypes.byref(x), 0)) for _ in range(2)]
   cuda.cuEventRecord(evs[0], None)
@@ -47,70 +30,34 @@ def cu_time_execution(cb, enable=False) -> Optional[float]:
   for ev in evs: cuda.cuEventDestroy_v2(ev)
   return ret.value * 1e-3
 
-def _get_bytes(arg, get_str, get_sz, check) -> bytes:
-  sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))))
-  return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
-
-class PTXCompiler(Compiler):
-  def __init__(self, arch:str):
-    self.arch = arch
-    self.version = "7.8" if arch >= "sm_89" else "7.5"
-    super().__init__(f"compile_ptx_{self.arch}")
-  def compile(self, src:str) -> bytes: return src.replace("TARGET", self.arch).replace("VERSION", self.version).encode()
-
-class CUDACompiler(Compiler):
-  def __init__(self, arch:str):
-    self.arch = arch
-    check(cuda.nvrtcVersion((nvrtcMajor := ctypes.c_int()), (nvrtcMinor := ctypes.c_int())))
-    self.compile_options = [f'--gpu-architecture={arch}', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"]
-    if (nvrtcMajor.value, nvrtcMinor.value) >= (12, 4): self.compile_options.append("--minimal")
-    super().__init__(f"compile_cuda_{self.arch}")
-  def compile(self, src:str) -> bytes:
-    check(cuda.nvrtcCreateProgram(ctypes.byref(prog := cuda.nvrtcProgram()), src.encode(), "<null>".encode(), 0, None, None))
-    status = cuda.nvrtcCompileProgram(prog, len(self.compile_options), to_char_p_p([o.encode() for o in self.compile_options]))
-
-    if status != 0: raise CompileError(f"compile failed: {_get_bytes(prog, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check).decode()}")
-    return _get_bytes(prog, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, check)
-
-def cuda_disassemble(lib, arch):
-  try:
-    fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
-    with open(fn + ".ptx", "wb") as f: f.write(lib)
-    subprocess.run(["ptxas", f"-arch={arch}", "-o", fn, fn+".ptx"], check=True)
-    print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
-  except Exception as e: print("failed to generate SASS", str(e))
-
 class CUDAProgram:
-  def __init__(self, device:CUDADevice, name:str, lib:bytes):
-    self.device, self.name, self.lib = device, name, lib
+  def __init__(self, device:CUDADevice, name:str, lib:bytes, smem:int=0):
+    self.device, self.name, self.lib, self.smem = device, name, lib, smem
     if DEBUG >= 5: print("\n".join([f"{i+1:>3} {line}" for i, line in enumerate(pretty_ptx(lib.decode('utf-8')).split("\n"))]))
     if DEBUG >= 6: cuda_disassemble(lib, device.arch)
 
-    if CUDACPU: self.prg = lib
-    else:
-      check(cuda.cuCtxSetCurrent(self.device.context))
-      self.module = cuda.CUmodule()
-      status = cuda.cuModuleLoadData(ctypes.byref(self.module), lib)
-      if status != 0:
-        del self.module
-        cuda_disassemble(lib, device.arch)
-        raise RuntimeError(f"module load failed with status code {status}: {cuda.cudaError_enum__enumvalues[status]}")
-      check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))
-      self.prg = prg #type: ignore
+    check(cuda.cuCtxSetCurrent(self.device.context))
+    self.module = cuda.CUmodule()
+    status = cuda.cuModuleLoadData(ctypes.byref(self.module), lib)
+    if status != 0:
+      del self.module
+      cuda_disassemble(lib, device.arch)
+      raise RuntimeError(f"module load failed with status code {status}: {cuda.cudaError_enum__enumvalues[status]}")
+    check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))
+    self.prg = prg
+    if self.smem > 0: check(cuda.cuFuncSetAttribute(self.prg, cuda.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, self.smem))
 
   def __del__(self):
     if hasattr(self, 'module'): check(cuda.cuModuleUnload(self.module))
 
   def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
-    if CUDACPU: self.vargs = args+tuple(vals)
+    check(cuda.cuCtxSetCurrent(self.device.context))
+    if not hasattr(self, "vargs"):
+      self.c_args, self.vargs = encode_args(args, vals)
     else:
-      check(cuda.cuCtxSetCurrent(self.device.context))
-      if not hasattr(self, "vargs"):
-        self.c_args, self.vargs = encode_args(args, vals) #type: ignore
-      else:
-        for i in range(len(args)): self.c_args.__setattr__(f'f{i}', args[i])
-        for i in range(len(vals)): self.c_args.__setattr__(f'v{i}', vals[i])
-    return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, self.vargs)), enable=wait)
+      for i in range(len(args)): self.c_args.__setattr__(f'f{i}', args[i])
+      for i in range(len(vals)): self.c_args.__setattr__(f'v{i}', vals[i])
+    return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, self.smem, None, None, self.vargs)), enable=wait)
 
 class CUDAAllocator(LRUAllocator):
   def __init__(self, device:CUDADevice):
@@ -140,7 +87,7 @@ class CUDAAllocator(LRUAllocator):
     check(cuda.cuEventRecord(sync_event, None))
     check(cuda.cuCtxSetCurrent(dest_dev.context))
     check(cuda.cuStreamWaitEvent(None, sync_event, 0)) # sync the default stream on the dest dev
-  def offset(self, buf, size:int, offset:int): return ctypes.c_ulong(buf.value + offset)
+  def offset(self, buf, size:int, offset:int): return cuda.CUdeviceptr_v2(buf.value + offset)
 
 class CUDADevice(Compiled):
   devices: List[CUDADevice] = []
@@ -148,33 +95,29 @@ class CUDADevice(Compiled):
 
   def __init__(self, device:str):
    device_id = int(device.split(":")[1]) if ":" in device else 0
-    if not CUDACPU:
-      check(cuda.cuInit(0))
-      self.cu_device = init_c_var(cuda.CUdevice(), lambda x: check(cuda.cuDeviceGet(ctypes.byref(x), device_id)))
-      self.context = init_c_var(cuda.CUcontext(), lambda x: check(cuda.cuCtxCreate_v2(ctypes.byref(x), 0, self.cu_device)))
-      check(cuda.cuDeviceComputeCapability(ctypes.byref(major := ctypes.c_int()), ctypes.byref(minor := ctypes.c_int()), device_id))
-
-      for dev in CUDADevice.devices:
-        check(cuda.cuDeviceCanAccessPeer(ctypes.byref(val := ctypes.c_int()), self.cu_device, dev.cu_device))
-        if val.value != 1: continue
-        check(cuda.cuCtxSetCurrent(dev.context))
-        check(cuda.cuCtxEnablePeerAccess(self.context, 0))
-        check(cuda.cuCtxSetCurrent(self.context))
-        check(cuda.cuCtxEnablePeerAccess(dev.context, 0))
-        CUDADevice.peer_access = True
-
-    self.arch = f"sm_{major.value}{minor.value}" if not CUDACPU else "sm_35"
+    check(cuda.cuInit(0))
+    self.cu_device = init_c_var(cuda.CUdevice(), lambda x: check(cuda.cuDeviceGet(ctypes.byref(x), device_id)))
+    self.context = init_c_var(cuda.CUcontext(), lambda x: check(cuda.cuCtxCreate_v2(ctypes.byref(x), 0, self.cu_device)))
+    check(cuda.cuDeviceComputeCapability(ctypes.byref(major := ctypes.c_int()), ctypes.byref(minor := ctypes.c_int()), device_id))
+
+    for dev in CUDADevice.devices:
+      check(cuda.cuDeviceCanAccessPeer(ctypes.byref(val := ctypes.c_int()), self.cu_device, dev.cu_device))
+      if val.value != 1: continue
+      check(cuda.cuCtxSetCurrent(dev.context))
+      check(cuda.cuCtxEnablePeerAccess(self.context, 0))
+      check(cuda.cuCtxSetCurrent(self.context))
+      check(cuda.cuCtxEnablePeerAccess(dev.context, 0))
+      CUDADevice.peer_access = True
+
+    self.arch = f"sm_{major.value}{minor.value}"
     self.pending_copyin: List[Tuple[int, int, Optional[BufferOptions]]] = []
    CUDADevice.devices.append(self)
 
     from tinygrad.runtime.graph.cuda import CUDAGraph
-    super().__init__(device, CUDAAllocator(self) if not CUDACPU else MallocAllocator,
-                     PTXRenderer(self.arch) if getenv("PTX") else CUDARenderer(self.arch),
-                     PTXCompiler(self.arch) if getenv("PTX") else CUDACompiler(self.arch),
-                     functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None)
+    super().__init__(device, CUDAAllocator(self), PTXRenderer(self.arch) if PTX else CUDARenderer(self.arch),
+                     PTXCompiler(self.arch) if PTX else CUDACompiler(self.arch), functools.partial(CUDAProgram, self), graph=CUDAGraph)
 
   def synchronize(self):
-    if CUDACPU: return
     check(cuda.cuCtxSetCurrent(self.context))
     check(cuda.cuCtxSynchronize())
     for opaque,sz,options in self.pending_copyin: self.allocator.free(opaque, sz, options)
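The `PTX` switch is now a module-level flag imported from `tinygrad.runtime.support.compiler_cuda` instead of a `getenv("PTX")` call at device construction, and `CUDAProgram` threads a `smem` argument through to `cuFuncSetAttribute`/`cuLaunchKernel`. A sketch of checking which renderer/compiler pair a `CUDADevice` was built with; it requires a working CUDA install and assumes the flag is still populated from the `PTX` environment variable at import time:

```python
# sketch: inspect the backend pair CUDADevice picked after the refactor
import os
os.environ["PTX"] = "1"          # must be set before tinygrad is imported

from tinygrad import Tensor, Device

dev = Device["CUDA"]             # PTXRenderer/PTXCompiler when PTX=1, else CUDARenderer/CUDACompiler
print(type(dev.renderer).__name__, type(dev.compiler).__name__)

# a tiny launch goes through CUDAProgram.__call__, which now passes self.smem
# as the dynamic shared memory argument of cuLaunchKernel
print((Tensor.ones(4, device="CUDA") + 1).tolist())
```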

tinygrad/runtime/ops_disk.py
@@ -1,13 +1,11 @@
 from __future__ import annotations
-import os, mmap, _posixshmem, io, ctypes, ctypes.util, platform, contextlib
+import os, sys, mmap, io, ctypes, ctypes.util, contextlib
 from typing import Optional, Generator, Tuple, Callable, List
 from tinygrad.helpers import OSX, round_up
 from tinygrad.device import Compiled, Allocator
-import tinygrad.runtime.autogen.io_uring as io_uring
-
-libc = ctypes.CDLL(ctypes.util.find_library("c"))
-libc.mmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_long]
-libc.mmap.restype = ctypes.c_void_p
+with contextlib.suppress(ImportError):
+  import _posixshmem
+from tinygrad.runtime.autogen import io_uring, libc
 
 class DiskBuffer:
   def __init__(self, device:DiskDevice, size:int, offset=0):
@@ -27,7 +25,7 @@ class DiskAllocator(Allocator):
   def as_buffer(self, src:DiskBuffer): return src._buf()
   def copyin(self, dest:DiskBuffer, src:memoryview): dest._buf()[:] = src
   def copyout(self, dest:memoryview, src:DiskBuffer):
-    if OSX and hasattr(self.device, 'fd'):
+    if OSX and self.device.fd is not None:
       # OSX doesn't seem great at mmap, this is faster
       with io.FileIO(self.device.fd, "a+b", closefd=False) as fo:
         fo.seek(src.offset)
@@ -76,6 +74,7 @@ class DiskDevice(Compiled):
     if not DiskDevice._tried_io_uring_init: self._iouring_setup()
 
     self.size: Optional[int] = None
+    self.fd: Optional[int] = None
     self.count = 0
     super().__init__(device, DiskAllocator(self), None, None, None)
   def _might_open(self, size):
@@ -85,41 +84,41 @@ class DiskDevice(Compiled):
     filename = self.dname[len("disk:"):]
     self.size = size
 
-    if filename.startswith("shm:"):
+    if sys.platform != "win32" and filename.startswith("shm:"):
       fd = _posixshmem.shm_open("/"+filename[4:].lstrip("/"), os.O_RDWR, 0o600)
       self.mem = mmap.mmap(fd, self.size, mmap.MAP_SHARED | MAP_POPULATE | MAP_LOCKED)
       os.close(fd)
     else:
-      try: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT|(0 if OSX else os.O_DIRECT))
+      try: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT|getattr(os, "O_DIRECT", 0))
      except OSError: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT)
      if os.fstat(self.fd).st_size < self.size: os.ftruncate(self.fd, self.size)
      self.mem = mmap.mmap(self.fd, self.size)
-    if (hp := getattr(mmap, "MADV_HUGEPAGE", None)) is not None:
+    if hasattr(self.mem, 'madvise') and (hp := getattr(mmap, "MADV_HUGEPAGE", None)) is not None:
       with contextlib.suppress(OSError): self.mem.madvise(hp) # some systems have transparent_hugepage disabled
   def _might_close(self):
     self.count -= 1
     if self.count == 0:
-      if hasattr(self, 'fd'): os.close(self.fd)
+      if self.fd is not None: os.close(self.fd)
       self.size = None
   def _iouring_setup(self):
     DiskDevice._tried_io_uring_init = True
 
-    if platform.system() != 'Linux': return
-
-    fd = libc.syscall(io_uring.NR_io_uring_setup, 4096, ctypes.byref(p:=io_uring.struct_io_uring_params()))
-    if fd < 0: return
+    if sys.platform == 'linux' and not hasattr(sys, "getandroidapilevel"):
+      fd = libc.syscall(io_uring.NR_io_uring_setup, 4096, ctypes.byref(p:=io_uring.struct_io_uring_params()))
+      if fd < 0: return
 
-    sq_ptr = libc.mmap(0, p.sq_off.array + p.sq_entries * 4, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, 0)
-    cq_ptr = libc.mmap(0, p.cq_off.cqes + p.cq_entries * ctypes.sizeof(io_uring.struct_io_uring_cqe),
-                       mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, io_uring.IORING_OFF_CQ_RING)
-    sqes = libc.mmap(0, p.sq_entries * ctypes.sizeof(io_uring.struct_io_uring_sqe),
-                     mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, io_uring.IORING_OFF_SQES)
+      sq_ptr = libc.mmap(0, p.sq_off.array + p.sq_entries * 4, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, 0)
+      cq_ptr = libc.mmap(0, p.cq_off.cqes + p.cq_entries * ctypes.sizeof(io_uring.struct_io_uring_cqe),
+                         mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, io_uring.IORING_OFF_CQ_RING)
+      sqes = libc.mmap(0, p.sq_entries * ctypes.sizeof(io_uring.struct_io_uring_sqe),
+                       mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_POPULATE, fd, io_uring.IORING_OFF_SQES)
 
-    def u32ptr(val): return ctypes.cast(val, ctypes.POINTER(ctypes.c_uint32))
-    sqdesc = io_uring.struct_io_uring_sq(khead=u32ptr(sq_ptr+p.sq_off.head), ktail=u32ptr(sq_ptr+p.sq_off.tail), array=u32ptr(sq_ptr+p.sq_off.array),
-                                         kring_mask=u32ptr(sq_ptr+p.sq_off.ring_mask), sqes=ctypes.cast(sqes, ctypes.POINTER(io_uring.struct_io_uring_sqe)))
+      def u32ptr(val): return ctypes.cast(val, ctypes.POINTER(ctypes.c_uint32))
+      sqdesc = io_uring.struct_io_uring_sq(khead=u32ptr(sq_ptr+p.sq_off.head), ktail=u32ptr(sq_ptr+p.sq_off.tail),
+                                           array=u32ptr(sq_ptr+p.sq_off.array),
+                                           kring_mask=u32ptr(sq_ptr+p.sq_off.ring_mask), sqes=ctypes.cast(sqes, ctypes.POINTER(io_uring.struct_io_uring_sqe)))
 
-    cqdesc = io_uring.struct_io_uring_cq(khead=u32ptr(cq_ptr+p.cq_off.head), ktail=u32ptr(cq_ptr+p.cq_off.tail),
-                                         kring_mask=u32ptr(sq_ptr+p.cq_off.ring_mask), cqes=ctypes.cast(cq_ptr+p.cq_off.cqes, ctypes.POINTER(io_uring.struct_io_uring_cqe)))
+      cqdesc = io_uring.struct_io_uring_cq(khead=u32ptr(cq_ptr+p.cq_off.head), ktail=u32ptr(cq_ptr+p.cq_off.tail),
+                                           kring_mask=u32ptr(sq_ptr+p.cq_off.ring_mask), cqes=ctypes.cast(cq_ptr+p.cq_off.cqes, ctypes.POINTER(io_uring.struct_io_uring_cqe)))
 
-    DiskDevice.io_uring = io_uring.struct_io_uring(ring_fd=fd, sq=sqdesc, cq=cqdesc) # type: ignore
+      DiskDevice.io_uring = io_uring.struct_io_uring(ring_fd=fd, sq=sqdesc, cq=cqdesc) # type: ignore
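The simplest way to exercise the updated disk runtime (the new `self.fd` bookkeeping, the `getattr(os, "O_DIRECT", 0)` fallback, and the Linux-only io_uring setup) is through safetensors I/O, which reads and writes via `disk:` devices. A sketch, assuming a POSIX `/tmp` and whatever `Device.DEFAULT` is available locally:

```python
# sketch: round-trip a tensor through the DISK device via safetensors
from tinygrad import Tensor, Device
from tinygrad.nn.state import safe_save, safe_load

safe_save({"w": Tensor.arange(4).float()}, "/tmp/tg_disk_demo.safetensors")   # writes through DiskAllocator.copyin
loaded = safe_load("/tmp/tg_disk_demo.safetensors")["w"].to(Device.DEFAULT)   # copies back off the disk device
print(loaded.tolist())   # expected: [0.0, 1.0, 2.0, 3.0]
```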